Small Transformers Donโt Need LayerNorm at Inference Time: Scaling LayerNorm Removal to GPT-2 XL and Implications for Mechanistic Interpretability
Review
| ๋๋ค์ | Strength & Weakness & Sugguestions | ๋ณ์ (0/5) |
|---|---|---|
| ๋๋ฌผ | โข ๊ฐ์ : ์์ ํ๋ฅผ ์ํด ๋น์ฐ์(?) ์ฌ์ฉ๋์๋ LN์ ๋ถ๋ถ์ ์ผ๋ก ์ ๊ฑฐํจ์ผ๋ก์จ LLM์ ์ถ๋ก ์ฐ์ฐ์ ํจ์จ์ ์ผ๋ก ํ ์ ์์์ ์ ์. LN์ ์ ๊ฑฐํ๋ฉด ์์ ์ฑ์ด ๋จ์ด์ง๋ฏ๋ก, loss spike๊ฐ ๋ฐ์ํ๋๋ฐ, ์ด๊ฒ์ ๋ฎ์ถ๋ ค๋ฉด fine-tuning์ด ๋ ํ์ํ ๊ฑฐ๊ณ , ๋นํจ์จ์ ์ด์ง ์๋? ์๊ฐํ์ง๋ง, inference์์๋ง '๋ถ๋ถ์ '์ผ๋ก ํ๋ค๊ณ ํ๋ ํจ์จ์ ์ธ ์ธก๋ฉด์์๋ ์ข์ ๋ฐฉ๋ฒ์ธ ๊ฒ ๊ฐ์. ๋ฌด์๋ณด๋ค interpretability ์ธก๋ฉด์์๋ ํฐ ์๋ฏธ๊ฐ ์๋ค๊ณ ์๊ฐํจ. โข ์ฝ์ : ๋ฌผ๋ก loss spike๋ ์์ ๋ชจ๋ธ์ ํํด์ ์ฌํ๊ฒ ํ๋ค๊ณ ํ์ง๋ง, GPT-2 ์ธ์ ๋ ํฐ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ๋ ๋ชจ๋ธ์์๋ ์คํ์ ํ์ผ๋ฉด ๋ ์ผ๋ฐ์ฑ์ด ์ข์์ ๊ฒ ๊ฐ์. โข ๋ณด์์ : ํ๋ผ๋ฏธํฐ๊ฐ ๋ ๋ง์ ๋ชจ๋ธ์ ๋ํ ์คํ์ ํตํด ์ผ๋ฐ์ฑ ํ๋ณด. | 3.8 |
| ๋ฉ์ฟ ๋ฆผ๋ณด | Layer norm์ ๋์ฒดํ๊ณ ๊ทธ๊ฒ์ด ์ ์๋ํ๋์ง ์คํํ๋ ๊ฒ์ ๋งค์ฐ noveltyํ๊ณ ์ค์ํจ! ์ ์ฒด์ ์ผ๋ก soundํ๊ณ , ์ ๊ฑฐํ์ ๋, interpretability๊ฐ ์ฆ๊ฐํ ๊ฒ์ ์ข์ ๊ฒฐ๊ณผ์. ๋์ layer norm์ ์ ๊ฑฐํ์ ๋์ ์ทจ์ฝ์ ๋ ๊ฐ์ด ์์ ํ๋ค๋ฉด ์ข์์ ๊ฒ ๊ฐ์. ๋ด ์๊ฐ์๋ ์ข ์์ ์ฑ์ด ๋จ์ด์ ธ์ ๋ค์ํ task์ alignment๊ฐ ์ ๋๋ก ์ํ๋์ง ์์ ์๋ ์๊ฒ ๋ค๋ ์๊ฐ์ด ์กฐ๊ธ ๋ฆ. | 4 |
| thumps-up | โข ์ฅ: ํ์ต ์์ ์ฑ์ ์ํด ์ ์๋ LN์ด ์คํ๋ ค interpretability๋ฅผ ํด์น๋ค๋๊ฒ ๋ํํ
์ ๊ธฐํ ๊ด์ ์ด์๋ค! transformer decoder์ ๊ฐ์ฅ ํด๋์ํ ๋ชจ๋์ธ GPT2์ ๋ํด ๋ถ์ํ๋ ๊ฒ๋ reasonableํจ โข ๋จ: ํฅ๋ฏธ๋กญ๊ธด ํ๋ฐ ๊ทธ๋์ ์ด๊ฒ ๋ญ? โข ๋ณด์: ๋ค๋ฅธ family์ ๋ํด์๋ ๋ถ์ํ์ผ๋ฉด ์์ฑ๋๊ฐ ํจ์ฌ ๋์์๋ฏ! | 2.8 |
| ํผ๋ | โข ๊ฐ์ : ๋น์ฐํ ํ์ํ๋ค๊ณ ์๊ฐํ๋ LN์ ๋ํด ์ถ๋ก ์์ ์ ์์ด๋ ๋๋ค๋ ์ ์ ์คํ์ ํตํด์ ์ ๋์ ์ผ๋ก ์ฆ๋ช
ํจ โข ์ฝ์ : GPT2์ ๋๋ฌด ๊ตฌํ ๋ชจ๋ธ์.. Llama, Qwen์ด๋ Mistral ๊ฐ์ ํ๋ ์ํคํ ์ฒ์์๋ ์ผ๋ฐํ๊ฐ ๊ฐ๋ฅํ์ง ์คํํด๋ดค์ผ๋ฉด ์ข์์ ๋ฏ (์ํ ์ด์ ๊ฐ ์์ง ์์๊น?) โข ๋ณด์์ : ํ๋ ๋ชจ๋ธ ์ํคํ ์ฒ์์๋ ๊ผญ ๋ถ์ํด๋ดค์ผ๋ฉด ํจ | 2.2 |
| ์์ผ๋ฉด์ ๋ณด์ | ์ฅ์ : ๊ตฌ์กฐ์ ๊ธฐ๋ฅ ๋ถ์, ํจ์จ์ ๊ด์ , ํฅ๋ฏธ๋ก์ด ๊ด์ ๊ณผ ๊ธฐ์ฌ, ์คํ. ๋จ์ : ๋ฌธ์ ์ ์ ๋ณด์ด๊ณ ํด๊ฒฐํ ๊ฒ์ด ์๋๋ผ, ๊ทธ๋ฅ ํ์์ ๋ณด๊ณ ํ ๋ ผ๋ฌธ์ด๋ผ ์กฐ๊ธ ์์ฌ์. ๊ฒฐ๊ณผ์ ๋ํ ์์ฌ์ด ์๋ก์๋ก ๋ณด์์ : ์์ ๋ ๊ฒ์ ๋ํด ๊ธฐ๋ฅ์ ์ด๋ ์ฑ๋ฅ์ ์ผ๋ก ๋ค์ํ ํ์คํฌ ๋ฐ ๋ชจ๋ธ์ ๋ํด ํ๊ฐํ์ผ๋ฉด ํจ. ๋ถ์์ ์ผ๋ก. | 3.3 |
| ๋ ์๋ฆฌ์คํ์ | โข ๊ฐ์ : LayerNorm์ด inference์ ํ์์ ์ด์ง ์์ ์ ์๋ค๋ ๊ฑธ ์คํ์ ์ผ๋ก ์ ๋ณด์ฌ์ค โข ์ฝ์ : ๋ ๋ค์ํ ๋ชจ๋ธ๊ตฐ์์์ ์คํ์ด ์์ผ๋ฉด ์ข๊ฒ ์. ๋ํ DLA error์ ๊ฐ์ํ์ง๋ง attribution patching์ ์ข์์ง์ง ์์์. Interpretability ๊ด๋ จํ ์ด์ ์ด ์ ๋ฉด์ ์ผ๋ก ๊ฐ์ ๋๋๊ฑฐ๋ ์๋๋ฏ โข ๋ณด์์ : inference์ latency ๊ด๋ จํด์ ์ธก์ ์ด ์์ผ๋ฉด ์ข์๊ฒ ๊ฐ์. ๊ทธ๋ฆฌ๊ณ trainning ์์์๋ LN์ ์ ๊ฑฐ ์ํฅ์ด ๊ถ๊ธํจ | 4.0 |
| ์์ง | โข ๊ฐ์ : Layer norm์ด ๋น์ ํ์ ํน์ฑ์ด ๊ฐํ๋ฐ, ์ด๋ก ์ธํด LLM component๋ค์ด ์๋ก ์ฝํ์๋๊ฑธ ์ ํ์ผ๋ก ๋ฐ๊ฟ์ผ๋ก์จ ํน์ component์ ์ญํ ์ ๋ถ๋ฆฌํ์ฌ ์ค๋ช
ํ ์ ์๋ค๋๊ฒ ๊ฐ์ฅ ํผ โข ์ฝ์ : Layer Norm ์ ๊ฑฐํ๋ ค๋ฉด ์ถ๊ฐ์ ์ธ ํ์ธํ๋์ด ํ์ํ๋ฐ, ์ด์ ๋ํ ๋น์ฉ ๋ถ์ + ๋ฉ๋ชจ๋ฆฌ๊ฐ ์ผ๋ง๋ ๊ฐ์ํ๋ ์ง ๋ถ์ ์คํ์ด ๋ถ์ฌํจ โข ๋ณด์์ : Reasoning task์ ์ ์ฉํด๋ด์ ์ถ๋ก ์ฑ๋ฅ or ์๊ฐ์ ๋ํ ์ํฅ ๋ถ์ ์ํ | 3.8 |
| ํ์ฝ | โข ์ฅ์ : layer-norm์ ๋ชจ๋ธ ์ฑ๋ฅ์ ํฐ ์์ ์์ด ๋ณด๋ค ํจ์จ์ ์ธ ๋ชจ๋๋ก ๋์ฒดํ ์ ์์์ ๋ฐํ โข ๋จ์ : ์ด๋ฏธ ํ์ต๋ LLM์์ ๊ฐ layer-norm ๋ ์ด์ด๋ฅผ ๋ถ๋ถ์ ์ผ๋ก ์ ๊ฑฐํ๋ฉด์ ๋ณต๊ตฌํ๊ณ ์๋๋ฐ, ์ฒ์๋ถํฐ layer-norm ๋ ์ด์ด๊ฐ ์์๋๋ผ๋ ๊ธฐ์กด์ฒ๋ผ ํ์ต์ด ์ ๋์๊น? โข ๋ณด์์ : ์์ ๋ชจ๋ธ๋ก๋ผ๋ layer-norm ์์ด ์ฒ์๋ถํฐ ํ๋ จํ ๋ชจ๋ธ์ ์ฑ๋ฅ ๋น๊ต | 3.5 |
| ํ์ด์ด | โข ์ฅ์ : Layer-Norm์ ์ ๊ฑฐํ๊ณ ๋ ๊ฐ์ ํจ๊ณผ๋ฅผ ์ป์ผ๋ฉด์ ์ค๋ช
๊ฐ๋ฅ์ฑ์ด ํฅ์๋ ์ ์๋ค๋ ๋ถ๋ถ์ ๋ฐํ๋ธ ๊ฒ์ ๊ธฐ์ฌ๊ฐ ํผ. โข ๋จ์ : Layer-Norm์ ์ ๊ฑฐํ๋๋ผ๋ ๊ณผ์ฐ ํ์ต์ด ์๋ ๊น? โข ๋ณด์: Layer-Norm์ ์ ๊ฑฐํ๊ณ ์คํํ๊ฑฐ๋, ํ์ธ ํ๋ํ ๊ฒฐ๊ณผ์ ๋น๊ตํ ์คํ์ด ์ถ๊ฐ๋์์ผ๋ฉด ํจ. | 3.9 |
| ์ด์ฝ๋ฆฟ | โข ์ฅ์ : LN์ ์ ๊ฑฐํ ํ ๋ชจ๋ธ์ด overconfidentํด์ง๋ค๋ ๊ฒ์ ๋
ผ๋ฌธ์์ ์ธ๊ธํ๊ณ ๋ถ์ํ ์ ์ด ์ข์์ โข ์ฝ์ : DLA error๋ ์ค์๋๋ฐ, attribution patching error๋ ๋๋ก ๋จ์์์ด์, LN์ ์ ๊ฑฐํ๋ฉด interpretability๊ฐ ์ข์์ง๋ค๋ ์ฃผ์ฅ์ด ์์ ํ์ง ์์๊ฒ ๊ฐ์์. โข ๋ณด์์ : Attribution patching error๊ฐ ์ ๊ฐ์ ๋์ง ์๋์ง ์์ธ ๋ถ์์ด ์์์ผ๋ฉด ์ข์์๊ฒ ๊ฐ๋ค | 3.5 |
TL; DR
Layer normalization์ training stability์๋ ์ค์ํ์ง๋ง, inference ๋จ๊ณ์์๋ ๊ผญ ํ์ํ์ง ์์ ์ ์๋ค! GPT-2 ์ ๋ชจ๋ LayerNorm์ ์ ๊ฑฐํ์ฌ ๋ณด์ฌ์ค
Summary
- ์ธ์ฉ์: 5
- openreview forum : https://openreview.net/forum?id=VPtHqcafIY
- ๋ฆฌ๋ทฐ์ด๋ค์๊ฒ๋ 8,8,8,6์ ์ ๋ฐ์์ง๋ง, AC๋ meta reiew ์์ โrejection-likeโ meta review๋ฅผ ์ฃผ์์! ํ๋ฒ ์ฝ์ด๋ณด์๋ ๊ฑธ ์ถ์ฒ ~~
Background & Motivation
- Layer Normalization Layer
- โ : Hadamard (element-wise) product
- ํ์ต ๋จ๊ณ์์์ stabilize๋ฅผ ์ํ Layer ์
- ์ ํ์ฐ๊ตฌ์์๋ LN์ด confidence regulation์ ๊ธฐ์ฌํ๋ค๊ณ ๋ฐํ๊ธฐ๋ ํ์
- ๋ฏธ๋ฆฌ ํ์ตํด๋ ๊ฐ์ ์ฌ์ค์ linear transformation์ฒ๋ผ ์ธ ์ ์๋ Batch normalization ๊ณผ ๋ฌ๋ฆฌ, LN์ inference๋ ๋งค๋ฒ ์ํํด์ผ ํ๋ ์ฐ์ฐ์. (=non-linear function)
- LN์ non-linearity๋ก ์ธํด, model์ mechanistic interpretability๊ฐ ๋ฐฉํด๋จ
- mechanistic interpretability ๋? ๋ชจ๋ธ์ ๋ ์์ component์๋ก ๋ถํดํ๊ณ , ๊ฐ component์์ ๊ฐ๋ณ์ ์ธ ํจ๊ณผ์ ์ํธ ์์ฉ์ ์ดํดํ๋ ๊ฒ
- why? LM์ด residual stream activation์ ๋ฐ๋ผ ๊ฐ component(=sublayer, head, โฆ)๊ฐ ์ํฅ์ ๋ฐ๊ธฐ ๋๋ฌธ
- Eq (1) ์์, ํ์ฌ ์์ ์์์ ํ๊ท ์ ๋นผ๊ณ std๋ฅผ ๋๋ ์ ์ ๊ทํํ๋๋ฐ, ๊ฐ ๋จ๊ณ์์์ ํ๊ท , std๊ฐ ๋ค ๋ค๋ฅด๊ธฐ ๋๋ฌธ์ ๋ค ๋จ๊ณ๊น์ง ์ํฅ์ ๋ผ์นจ
- ๊ทธ๋์ ๊ธฐ์กด interpretability ์ฐ๊ตฌ์์๋ NL์ linear transformation์ผ๋ก ๊ทผ์ฌํํด์ ์ํํ์
- ์ฐธ๊ณ ๋
ผ๋ฌธ:
- [ICLRโ24] Copy Suppression: Comprehensively Understanding an Attention Head https://openreview.net/forum?id=g8oaZRhDcf
- but, ์ ํํ์ง ์๊ณ , ์ด๋ ๊ฒ ํ์ต๋ ๋ชจ๋ธ์ ์ค์ LLM๊ณผ ๋ค๋ฆ
- ์ฐธ๊ณ ๋
ผ๋ฌธ:
- ๋๋ ์์ LN ์ ์์ ๊ณ , element-wise tanh function์ ์ฌ์ฉํจ
- ์ฐธ๊ณ ๋ ผ๋ฌธ: https://arxiv.org/abs/2503.10622
- but, ์ฌ์ ํ non-linear function์ด๊ธฐ ๋๋ฌธ์ interpretability ์ฐ๊ตฌ์ ๋ถ์ ํฉ
โ ์ด๋ฏธ ํ์ต๋ ์ค์ transformer์์ LN์ ์ ๊ฑฐํ ๋ฒ์ ์ ๋ถ์ํด๋ณด์ !!
Contributions (What theyโve revealed)
transformer๋ก๋ถํฐ LN์ ์ ๊ฑฐํ์์ & LN layer ์์ด๋ ์๋ํ๋ฉฐ, original model๊ณผ ์ ์ฌํ cross-entropy loss๋ฅผ ๋ฌ์ฑํ ์ ์์์ ๋ณด์
- ๊ตฌ์ฒด์ ์ธ removal process
- 0๋ฒ์งธ MLP layer์ LN ( ๏ปฟ) ๋ฅผ ์ ๊ฑฐํ๊ณ fine-tuning
- 1๋ฒ์งธ MLP layer์ LN ( ๏ปฟ) ๋ฅผ ์ ๊ฑฐํ๊ณ fine-tuning
- โฆ
- 0๋ฒ์งธ query/key LN ( ๏ปฟ) ์ ์ ๊ฑฐ ํ๊ณ fine-tuning
- 1๋ฒ์งธ query/key LN ( ๏ปฟ) ์ ์ ๊ฑฐ ํ๊ณ fine-tuning
- โฆ
- 0๋ฒ์งธ value LN ( ๏ปฟ) ์ ์ ๊ฑฐ ํ๊ณ fine-tuning
- 1๋ฒ์งธ value LN ( ๏ปฟ) ์ ์ ๊ฑฐ ํ๊ณ fine-tuning
- โฆ
- Final NL ๏ปฟ์ ๊ฑฐ
- ํ๋ฒ์ ์์ ์ง ์๋ ์ด์ ?
: ๋ชจ๋ LN์ ํ ๋ฒ์ ์์ ๋ฉด ๋ชจ๋ธ ์ฑ๋ฅ์ด ํ๋ณต ๋ถ๊ฐ๋ฅํ๊ฒ ๋ถ๊ดดํ๊ธฐ ๋๋ฌธ์, LN block์ ํ๋์ฉ ์ ๊ฑฐํ๊ณ loss spike๊ฐ ๊ฐ๋ผ์์๋๊น์ง (์ผ์ step๋ง) fine-tuningํจ
- ๊ตฌ์ฒด์ ์ธ removal process
LN layer๊ฐ ์์ ๋ model์ interpretability๊ฐ ํฅ์๋์์์ ๊ฒ์ฆ
: ๊ธฐ์กด interpretability ๋ถ์ ์ฐ๊ตฌ์์ ๋ง์ด ํ์ฉ๋๋ direct logit attribution (DLA), attribution patching ์ฌ์ฉ
- LN layer๊ฐ ์์ ๋ DLA error๊ฐ 50% โ 0%๋ก ๊ฐ์ํจ
- direct logit attribution(DLA)๋?
- ์ด๋ค component์ direct effect(=ํน์ component๊ฐ ์ค๊ฐ component๋ฅผ ๊ฑฐ์น์ง ์๊ณ output์ ์ฃผ๋ ํจ๊ณผ)๋ฅผ ์ ํ ๊ทผ์ฌ๋ก ์ถ์ ํ๋ ๋ฐฉ๋ฒ
- ๊ธฐ์กด LN์ nonlinearity๋ฅผ ๊ฐ์ง๊ณ ์๊ธฐ ๋๋ฌธ์, DLA๊ฐ ๊ทผ์ฌํ๋จ
โ ์ด์ ๋ฐ๋ผ DLA์ DE๊ฐ์ ์ฐจ์ด(=error)๊ฐ ์์ ์ ๋ฐ์ ์์!
- direct logit attribution(DLA)๋?
- attribution patching์ LN layer ์ ๋ฌด์ ๋ฌด๊ดํจ
- activation patching์ด๋?
: ์ด๋ค ๋ด๋ถ activation์ด ์ ๋ง๋ก ์์ธ์ธ์ง ํ์ธํ๊ธฐ ์ํด, ์ ๋ต์ ์ ๋งํ๋ clean prompt์์ ๋์จ activation์ corrupted prompt์์ ๋์จ activation์ ๊ฐ์ ์์น๋ก ๋์ฒดํ์ฌ, ๊ฐ component์ ๊ฒฐ๊ณผ๋ฅผ ๊ด์ฐฐํ๋ ๋ฐฉ๋ฒ
- attribution patching์ด๋?
: activation patching์ first-order Taylor approximation
โ ๊ทผ์ฌํ๋ ๊ณผ์ ์์ attribution patching errors ๋ฐ์
- activation patching์ด๋?
- LN layer๊ฐ ์์ ๋ DLA error๊ฐ 50% โ 0%๋ก ๊ฐ์ํจ
์ถ๊ฐ ๋ถ์ ์ํ
- LN layer๊ฐ residual stream geometry๋ฅผ ์ด๋ป๊ฒ ๋ฐ๊พธ๋๊ฐ
- ์ ํ์ฐ๊ตฌ์์, โfirst position token์ hidden representation L2 norm์ด ์ ๋ํ ํฌ๋ฉฐ, ์ด๋ attention sink์ ํต์ฌ mechanism์โ์ ๋ฐํ์
- LN-free model์ ์กฐ๊ธ ๋ overconfidentํจ
- LN-free model์ ๊ฐ๋ฐํ๋ ์์ค์, GPT-2 Medium์ ๊ฒฝ์ฐ, ์ถ๋ ฅ ๋ถํฌ์ ํ๊ท ์ํธ๋กํผ๋ 2.86 โ 2.53 ์ผ๋ก ๊ฐ์ํจ์ ๋ฐ๊ฒฌ
- ์ด์ ๋ฐ๋ผ, entropy neuron(=confidence neuron)์ ๋ถ์ํด๋ด
- ๊ธฐ์กด ์ฐ๊ตฌ์ ๋์ผํ ๋ฐฉ๋ฒ์ผ๋ก ๋ฐ๊ฒฌ ๋ฐ ๋ถ์ ์ํํจ
: weight norm์ด ํฌ๊ณ , ๋ชจ๋ output logit์ ๊ฑฐ์ ๊ฐ์ ์ํฅ์ ์ฃผ๋ (token ranking ์์ฒด๋ฅผ ๋ฐ๊พธ์ง ์๋) neuron
- ๊ธฐ์กด ์ฐ๊ตฌ์ ๋์ผํ ๋ฐฉ๋ฒ์ผ๋ก ๋ฐ๊ฒฌ ๋ฐ ๋ถ์ ์ํํจ
- LN layer๊ฐ residual stream geometry๋ฅผ ์ด๋ป๊ฒ ๋ฐ๊พธ๋๊ฐ











