27 March 2026

Small Transformers Don't Need LayerNorm at Inference Time: Scaling LayerNorm Removal to GPT-2 XL and Implications for Mechanistic Interpretability


Review

Nickname · Strengths & Weaknesses & Suggestions · Rating (out of 5)

Tears (3.8/5)
  • Strength: Shows that LN, which had been taken for granted(?) as necessary for stabilization, can be partially removed to make LLM inference more efficient. Since removing LN hurts stability and causes loss spikes, I first wondered whether the extra fine-tuning needed to recover would make the whole thing inefficient; but because the removal is done "partially" and only matters at inference time, it seems like a good method from an efficiency standpoint as well. Above all, I think it is highly meaningful for interpretability.
  • Weakness: The loss spikes are said to be severe only for smaller models, but experiments on models with more parameters than GPT-2 would have strengthened the generality of the claim.
  • Suggestion: Establish generality through experiments on models with more parameters.

Meongkurimbo (4/5)
  • Replacing layer norm and testing whether the model still works is very novel and important! The work is sound overall, and the improved interpretability after removal is a good result. That said, it would have been better to also describe the vulnerabilities introduced by removing layer norm; I somewhat suspect that the reduced stability could keep alignment on diverse tasks from working properly.

thumps-up (2.8/5)
  • Strength: That LN, proposed for training stability, actually harms interpretability was a surprising perspective to me! Analyzing GPT-2, the most classic transformer decoder, is also a reasonable choice.
  • Weakness: Interesting, but so what?
  • Suggestion: Analyzing other model families would have made the work far more complete!

Blood & Sweat (2.2/5)
  • Strength: Quantitatively demonstrates through experiments that LN, which everyone assumed was necessary, is not needed at inference time.
  • Weakness: GPT-2 is a rather dated model... it would have been good to test whether the result generalizes to modern architectures such as Llama, Qwen, or Mistral (is there a reason they didn't?).
  • Suggestion: The analysis really should be repeated on modern model architectures.

Watch with a Smile (3.3/5)
  • Strength: Structural/functional analysis, an efficiency angle, an interesting perspective and contribution, and experiments.
  • Weakness: A bit disappointing that the paper merely reports a phenomenon rather than identifying and solving a problem; doubts about the results keep creeping in.
  • Suggestion: Evaluate the removal, both functionally and in terms of performance, across diverse tasks and models, with proper analysis.

Eagle Five Brothers (4.0/5)
  • Strength: Experimentally shows well that LayerNorm may not be essential at inference time.
  • Weakness: Experiments on a wider range of model families would help. Also, DLA error decreased but attribution patching did not improve, so the interpretability benefits are not across the board.
  • Suggestion: Inference latency measurements would be useful, and I am curious about the effect of removing LN during training as well.

Ppijil (3.8/5)
  • Strength: Layer norm is strongly non-linear, which entangles the LLM's components with one another; making this linear so that the role of a specific component can be isolated and explained is the biggest win.
  • Weakness: Removing LayerNorm requires additional fine-tuning, but there is no cost analysis for it, nor experiments measuring how much memory is saved.
  • Suggestion: Apply the method to reasoning tasks and analyze the impact on reasoning performance or time.

Popcorn (3.5/5)
  • Strength: Shows that layer norm can be replaced by a more efficient module without major damage to model performance.
  • Weakness: They remove each layer-norm layer from an already-trained LLM one piece at a time and recover, but would training have gone just as well if the layer-norm layers had never been there from the start?
  • Suggestion: Compare against a model trained from scratch without layer norm, even a small one.

Fire (3.9/5)
  • Strength: A big contribution is showing that LayerNorm can be removed while keeping the same behavior and improving explainability.
  • Weakness: Even with LayerNorm removed, would training actually go well?
  • Suggestion: Add experiments that train without LayerNorm, or compare against the fine-tuned results.

Chocolate (3.5/5)
  • Strength: Liked that the paper notices and analyzes the model becoming overconfident after LN removal.
  • Weakness: DLA error went down, but attribution patching error remained as-is, so the claim that removing LN improves interpretability felt incomplete.
  • Suggestion: A root-cause analysis of why attribution patching error does not improve would have been nice.

TL;DR

💡 Layer normalization matters for training stability, but it may not be strictly necessary at inference time! The authors show this by removing every LayerNorm from GPT-2.

Summary

  • Researchers
  • Citations: 5
  • OpenReview forum: https://openreview.net/forum?id=VPtHqcafIY
    • The reviewers gave scores of 8, 8, 8, 6, but the AC wrote a "rejection-like" meta review! Worth a read ~~

Background & Motivation

  • Layer Normalization (LN) layer
    • Eq. (1): LN(x) = ((x − μ(x)) / σ(x)) ⊙ γ + β
      • ⊙ : Hadamard (element-wise) product
    • A layer introduced to stabilize the training phase
      • Unlike batch normalization, whose pre-computed statistics effectively let it run as a linear transformation, LN is an operation that must be recomputed on every input at inference time (i.e., it is a non-linear function).
      • LN's non-linearity gets in the way of the model's mechanistic interpretability
        • What is mechanistic interpretability? Decomposing a model into smaller components and understanding each component's individual effect and their interactions
        • Why does LN interfere? Because each component of the LM (sublayer, head, …) is driven by the residual stream activations
          • In Eq. (1), normalization subtracts the current mean and divides by the std; since these statistics differ at every step, their effect carries over into all later stages (a tiny demo of this non-additivity follows after this list)
        • Prior interpretability work therefore approximated LN as a linear transformation
          • But the approximation is inexact, and the approximated model differs from the actual LLM
        • Alternatively, LN has been removed outright in favor of an element-wise tanh function
          • But tanh is still a non-linear function, so it remains ill-suited to interpretability research

        ⇒ So: take an actual, already-trained transformer, remove its LN, and analyze that version!!

Contributions (What they've revealed)

  • They remove LN from the transformer & show that the model works without LN layers, reaching a cross-entropy loss similar to the original model's
    • Definition of the FakeLN block (a minimal sketch follows after this block)
      • Instead of the input's standard deviation, a fixed scalar σ̄_avg is used, estimated as the average std over a calibration batch (BS: batch size)
    • Concrete removal process
      1. Remove the LN of the 0th MLP layer (LN^0_{MLP}) and fine-tune
      2. Remove the LN of the 1st MLP layer (LN^1_{MLP}) and fine-tune
      3. …
      4. Remove the 0th query/key LN (LN^0_{qk}) and fine-tune
      5. Remove the 1st query/key LN (LN^1_{qk}) and fine-tune
      6. …
      7. Remove the 0th value LN (LN^0_{v}) and fine-tune
      8. Remove the 1st value LN (LN^1_{v}) and fine-tune
      9. …
      10. Remove the final LN (LN^f)
      • Why not remove them all at once?

        : Removing every LN at once collapses model performance beyond recovery, so the LN blocks are removed one at a time, fine-tuning (for a fixed number of steps) until the loss spike settles.

      • An auxiliary loss is used to make LN removal more stable
        • λ: a hyperparameter weighting the auxiliary term
        • The Limitations section notes that this loss was introduced specifically to stabilize the GPT-2 Large/XL models!!
    • Training results
      • (Figure: training loss curves)
      • After fine-tuning, the model still reaches a cross-entropy loss comparable to the original

  • They verify that the model's interpretability improves without LN layers

    : Evaluated with direct logit attribution (DLA) and attribution patching, two techniques widely used in prior interpretability research (a compact sketch of both follows after this list)

    1. Without LN layers, the DLA error drops from ~50% → 0%
      • What is direct logit attribution (DLA)?
        • A method that estimates a component's direct effect (DE; the effect a component has on the output without passing through intermediate components) via a linear approximation
        • Because standard LN carries a nonlinearity, DLA is only an approximation under it

          ⇒ Hence there is necessarily a gap (= error) between DLA and the DE!

      • To evaluate how far DLA deviates from the DE on average, they measure the Normalized Mean Absolute Error (NMAE)
        • It drops from 49.07% w/ LN → 0.00% w/ FakeLN
    2. Attribution patching is unaffected by whether LN layers are present
      • What is activation patching?

        : To check which internal activation is truly responsible, an activation taken from a clean prompt (one the model answers correctly) is patched into the same position of a run on a corrupted prompt, and the effect on each component's output is observed

        ⇒ Accurate, but computationally expensive, since activations must be swapped in and checked one by one

      • What is attribution patching?

        : A first-order Taylor approximation of activation patching

        ⇒ The approximation step introduces attribution patching errors

      • Result
        • Even in the LN-free model, attribution patching errors persist
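
For reference, a compact sketch of both techniques in my own notation (not the paper's code); W_U is the unembedding matrix and logit_dir a direction in logit space.

```python
import torch

def dla(component_out, W_U, logit_dir):
    # Direct logit attribution: project a component's residual-stream output
    # straight onto a logit direction. This is exact only when everything
    # between the component and the logits is linear -- which is precisely
    # what replacing the final LN with FakeLN provides.
    return component_out @ W_U @ logit_dir

def attribution_patching(clean_act, corrupt_act, grad_wrt_act):
    # First-order Taylor approximation of activation patching:
    #   delta_metric ≈ grad(metric, act) . (clean_act - corrupt_act)
    # One forward/backward pass scores every component at once, but the
    # linearization error is unrelated to LN, so it persists in LN-free models.
    return (grad_wrt_act * (clean_act - corrupt_act)).sum()
```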
  • Additional analyses
    1. How do LN layers change the residual stream geometry?
      • They check whether the same phenomena still appear in the LN-free model
        • The LN-free model treats the first token almost the same as any other token
        • The attention sink rate also drops from 55.3% → 45.3%
    2. The LN-free model is slightly more overconfident
      • While building the LN-free models they found that, for GPT-2 Medium, the mean entropy of the output distribution falls from 2.86 → 2.53 (a sketch of this measurement follows below)
      • This prompted an analysis of entropy neurons (= confidence neurons)
        • Found and analyzed with the same method as prior work

          : neurons with a large weight norm that affect nearly all output logits almost equally (i.e., without changing the token ranking itself)

        • Analysis of the results
          • The cross-entropy loss impact of entropy neurons 1083, 1108, and 3144 drops sharply in the LN-free model

            ⇒ That is, the confidence neurons no longer do their job
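
A small sketch (my own illustration, with assumed tensor shapes) of the overconfidence measurement: the mean entropy of the next-token distribution, the quantity reported as falling from 2.86 to 2.53 for GPT-2 Medium.

```python
import torch
import torch.nn.functional as F

def mean_output_entropy(logits: torch.Tensor) -> torch.Tensor:
    # logits: [batch, seq, vocab]. Entropy (in nats) of the next-token
    # distribution at each position, averaged; lower = more confident.
    logp = F.log_softmax(logits, dim=-1)
    return -(logp.exp() * logp).sum(dim=-1).mean()
```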

Categories

PROBING, research