An Analysis for Reasoning Bias of Language Models with Small Initialization
Review
| ๋๋ค์ | ํ์คํ | ๋ณ์ (0/5) |
|---|---|---|
| ๋งน๊ตฌ | memory์ reasoning์ ๋ค๋ฅธ ๊ฒ์ด๋ผ๊ณ ์์ํ๋ ๊ฒ ๊ฐ์. ์ด๋ฒ์ฃผ ๋ค๋ฅธ ๋ ผ๋ฌธ์์ ๊ฒฐ๊ตญ LLM์ memory ๊ธฐ๋ฐ ์ถ๋ก ์ ํ๋ ๊ฒ์ด๋ค ๋๋์ ๊ฒฐ๋ก ์ด ๋์๋๋ฐ, ์ด ๋ ผ๋ฌธ์ ๋ ๋ํ ์ผํ๊ฒ ํ์ธํด๋ณด๋ ค๊ณ ํ ๊ฒ ๊ฐ๋ค. LLM๋ ๊ฒฐ๊ตญ Transformer๋๊น, ์ด๋ฐ ์คํ๋ ๊ฐ๋ฅํ๊ตฌ๋ ๋ผ๋ ์๊ฐ์ ํ๊ฒ ๋์์. ๊ผญ ํฐ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ์ ๊ณผ์ ์ ์ง์ฐฉํ ํ์๋ ์๊ตฌ๋ ๋ผ๋ ์๊ฐ์ด ๋ค๊ธฐ๋ ํจ! | 4.2 |
| ๊ณ๋์ด๋ฐฅ | reasoning bias๋ฅผ ๊ณต๋ก ํ(?)ํ๊ณ , ์ด๋ฅผ ์ํด reasoning๊ณผ memory๋ฅผ ๋ช ์์ ์ผ๋ก ๊ตฌ๋ถ ๋ฐ ๋ถ์ํ ๋ ผ๋ฌธ! ๊ตฌ๋ถ ๊ธฐ์ค์ด ๋ช ํํ๊ณ , ์คํ ์ค๊ณ๋ ๊ฐ์ฐ์ ์ด๊ณ ๋ ผ๋ฆฌ์ ์ด๋ฉฐ, ์๋ฒ ๋ฉ๋จ์์ ๋ณด์ด๋ ๊ฒฝํฅ๋ ์ ํํํ๋ค. | 4.3 |
| ๊ตญ๋ฐฅ | LLM์ด ์ถ๋ก ์ ํ๋์ง ์๊ธฐ๋ฅผ ํ๋์ง๋ ๋ฐ์ดํฐ๋ ์ํคํ ์ณ๋ง์ ๋ฌธ์ ๊ฐ ์๋๋ผ๋ ๊ฒ์ ์คํ์ ์ผ๋ก ์ ํ์ด๋ธ๊ฑฐ๊ฐ๋ค. ์๋ฒ ๋ฉ ๋ถ๋ฆฌ ์คํ๋ ์ง๊ด์ ์ด๋ผ ์ข๋ค. Query๋ง๋ค ๋ค๋ฅธ KG๊ฐ ์ฃผ์ด์ง๋ฉด..? ์ด๊ฑด ์๊ธฐ๋ณด๋ค๋ ๊ตฌ์กฐ์ ์ธ ์ผ๋ฐํ๊ฐ ์ ๋ฆฌํด๋ณด์ด๋๋ฐ small init์ด ๋ง์ ๋ฏ | 4.4 |
| ํ๋ฒ๊ฑฐ | ๋ชจ๋ธ์ reasoning/ memorization ํธํฅ์ด ์ด๊ธฐํ ์ค์ผ์ผ์ ์ํด ๋ฌ๋ผ์ง ์ ์๋ค๋ ์ ์ ๋ช ํํ๊ฒ ๋ณด์ฌ์ค๋ฏ. ๊ฒฐ๊ตญ ํ์ต ์ด๊ธฐ์ ์ธํ ์ด ์ดํ ํ์ต ๊ณผ์ ์์ ๋ชจ๋ธ์ โ์ฑํฅโ์ด๋ ์๋ ด ๋ฐฉํฅ์ ํฌ๊ฒ ์ข์ฐํ ์ ์์ผ๋ฏ๋ก ๋ชจ๋ธ์ ๋ชฉํ ํน์ฑ์ ๋ง์ถฐ ์ด๊ธฐํ ์ค๊ณ๋ฅผ ์ ์ ํ ์ ์ฉํ ์ ์์๋ฏ | 4.4 |
| ํผ์ | ๋ชจ๋ธ์ ์ด๊ธฐ ํ์ต ์ธํ ๋ฐ Scale, ๋ฐ์ดํฐ์ ์ด ํ์ต์ ์งํ ๋ฐฉํฅ๋ง์ ๋ ํฌ๊ฒ ์ํฅ์ ๋ฏธ์น๋ค๋ ๋ด์ฉ์ Embedding Space๋ก ์ฆ๋ช ํจ์ผ๋ก์จ ํฌ๊ฒ ์๋ฏธ๊ฐ ์๋ ์ฐ๊ตฌ๋ผ๊ณ ๋ณด์ฌ์ง. | 4.1 |
| ์นํจ | ์์ฆ ํ์ต ๋จ๊ณ๋ณ ์ค์ผ์ผ๋ง ๊ด๋ จ ๋ ผ๋ฌธ๋ค์ด ๋ง์ด ๋ณด์ด๋๋ฐ ๊ฒฐ๊ตญ ์ ํด์ง ํ์ต ๋น์ฉ ๋ด์์ ์์ ๋๋น ํจ์จ์ ๊ทน๋ํํ๊ธฐ ์ํด์๊ฒ ์ง? reasoning bias๋ฅผ ์ ์ฆํด๋ธ ์คํ์ด ๋ช ํํ๊ฒ ์ดํด๊ฐ ๋์ ์ข์๋๊ฑฐ ๊ฐ๋ค | 4.6 |
| ํ๋ธ๋ฆฌ์ฆ | memory์ reasoning์ด ๋ค๋ฅธ ๊ฒ์ธ์ง, ๋ค๋ฅด๋ค๋ฉด ๋ญ๊ฐ ์ฐ์ ์ธ ๊ฒ์ธ์ง ๋ ผ์ํ๋ ๋ ผ๋ฌธ ์ค์ ํ๋. ํํธ์ผ๋ก ์ด ์ฃผ์ ๋ก๋ ์ต๋ํ ํฐ ๋ชจ๋ธ์ ๋ง์ด ํ์ต์ํค๊ณ ์ ์คํํ๊ณ ๋ ผ์ํ๋ ๊ฒ ๋ง์ง ์๋ ์ถ๊ธฐ๋ ํ๊ณ .. | 4.2 |
TL; DR
Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ์์ ์ด๊ธฐํ Scale์ ๋ฐ๋ผ์ ์ถ๋ก ์ ๋จผ์ ๋ฐฐ์ฐ๋๊ฐ, ์๊ธฐ๋ฅผ ๋จผ์ ๋ฐฐ์ฐ๋๊ฐ์ ํธํฅ์ด ์กด์ฌํ๋ค!
Summary
์ฐ๊ตฌ์ง: ๋ฏธ๊ตญ Duke University, ์ค๊ตญ ์ํ์ด๊ตํต๋
Cite: 4
- ํธ๋์คํฌ๋จธ ๊ธฐ๋ฐ ์ธ์ด ๋ชจ๋ธ์์ Parameter Initialization Scale์ ๋ฐ๋ผ ํ์ต์ด๋ LLM์ Task Preference์ ์ํฅ์ ๋ํด ์กฐ์ฌํจ
- Small Initialization Scale์์๋ ๋ชจ๋ธ์ด Reasoning Task๋ฅผ ์ ์ํํ๋๋ก Encourage๋์์
- Large Initialization Scale์์๋ ๋ชจ๋ธ์ด Memorization Task๋ฅผ ์ ์ํํ๋๋ก Preference๊ฐ ์ ๋๋จ
โ ์ด๊ธฐํ ์ค์ผ์ผ์ ๋ฐ๋ผ ๋ชจ๋ธ์ โํ์ต ์ฑํฅ(bias)โ๊ฐ ๋ณํจ
- ์ค์ ๋ฐ์ดํฐ์ ๊ณผ Anchor Function์ผ๋ก ์ด ์ฑํฅ์ ๊ฒ์ฆ
- ์ด๊ธฐํ ์ค์ผ์ผ์ ๋ฐ๋ฅธ ํ์์ ์์ธ์ Embedding Space์ Self-Attention Mechanism์ ํตํด ๋ถ์
- Model Training์ ๋์ญํ(dynamics) ๊ด์ ์์ ์ด ํ์์ด ์๊ธฐ๋ ์์ธ์ ์ค๋ช ํ๋ ์ด๋ก ์ ํ๋ ์์ํฌ๋ฅผ ๋ ผ๋ฌธ์์ ์ ์
- LLM์ ์ด๊ธฐํ๊ฐ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ด๋ป๊ฒ ์ํฅ์ ๋ฏธ์น๋์ง ์ดํด๋ฅผ ๋์ด๋ ์ฐ๊ตฌ
Introduction
Motivation
- LLM์ Reasoning Task์ ๋ํด์๋ RHO-1๊ณผ ๊ฐ์ Data-driven Approach๊ฐ ๋ง์ด ์ ์๋์ด ์์ผ๋, LLM์ด ์ง์ง logical rule์ ์ดํดํ๊ณ , reasoning์ ์ํํ๋์ง ์๋๋ฉด ์ฃผ์ด์ง ๊ท์น์ ๋จ์ํ ๋ฐ๋ผ๋ง ํ๋์ง์ ๋ํ ์๋ฌธ
์์ initialization scale์์๋ ๋ชจ๋ธ์ด ์์ ๋จ์ ๋ ๋ฒจ์ ๊ธฐ๋ฅ๊ณผ ๋ณต์กํ ๊ท์น์ ํ์ตํจ์ผ๋ก์จ data์ fit๋๋๋ก ์ ๋ํจNeuron condensation effect๊ฐ ํ์ต ๊ณผ์ ์์ ์๊ฒจ๋จNeuron Condenstation Effect: ๋์ผํ ๊ณ์ธต์ ๋ด๋ฐ๋ค์ด ์ ์ฌํ ์ถ๋ ฅ์ ๋ด๋๋ก ๋ญ์ณ์ง๋ ํ์
- ๊ฐ์ ๋ ์ด์ด์ Neuron์ด ์ ์ฌํ ํจํด์ผ๋ก ๋ง์ถฐ์ง๋ ํ์์ผ๋ก ์ธํด ๋ฐ์ํ์ฌ ์ต์ ๋ณต์ก๋๋ก data fitting์ด ๋๋๋ก ํจ
- ํํ๋ ฅ์ ์ถฉ๋ถํ์ง๋ง ์ค์ง์ ์ผ๋ก ์ฌ์ฉ ๊ฐ๋ฅํ ์์ ๋๊ฐ ์ ์ด์ง
- ๊ฐ๋ณ ์ํ์ ๋ฐ๋ก ์ธ์ธ ์ ์๋ ํ๋ผ๋ฏธํฐ ๋ถ๋ฆฌ๊ฐ ์ด๋ ค์ ์๊ธฐ ์ฑ๋ฅ์ ๋จ์ด์ง
- ๊ณตํต์ผ๋ก ์ ์ฉ๋๋ ๊ฐ๋จํ ๊ท์น์ ์ฐพ๋ ๋ฐฉํฅ์ผ๋ก ํ์ต์ด ์งํ
ํฐ initialization scale์์๋ ๋ชจ๋ธ์ด input-output ๋งคํ์ ๋ํ ๊ธฐ์ต์ ํ๋๋ก ์ ๋ํ์ฌ ์๊ธฐ ์ฑ๋ฅ์ด ์ฌ๋ผ๊ฐ
Contribution
- Reasoning bias๋ฅผ ์ค์ ์์ฐ์ด ํ์ต ์ค์ ์์ ๋ณด์ฌ์ฃผ๋ ์คํ
- ๋ชจ๋ธ์ Initialization Scale์ด Reasoning Behavior (bias)์ ๋ฏธ์น๋ ์ํฅ์ด ์๋น์ ์ค๋ช ํ๋ ์ฐ๊ตฌ
Experiments
- Reasoning Bias๋ฅผ ์๋ณํ๊ธฐ ์ํด neural network๋ฅผ Small Parameter Scale๋ก, ์๋ก ๋ค๋ฅธ reasoning ๋ณต์ก์ฑ์ ๊ฐ์ง ๋ ๊ฐ์ง ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ตํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ๋น๊ตํจ
- ๋ ๋ฐ์ดํฐ์
์ GPT-2 ์์์ ์์ด์ ํ์ต์ ์งํํ ๊ฒฐ๊ณผ
PrOntoQA๋ฐ์ดํฐ์ - CoT๋ฅผ ํฌํจํ๋ QA ๋ฐ์ดํฐ์
- ์ง๋ฌธ์ ์ ํํ ๋ง์ถ๊ธฐ ์ํ
Reasoning์ ๋ช ์์ ์ผ๋ก ํํํ๋ ๋ด์ฉ
TinyStories๋ฐ์ดํฐ์ - 3-4์ธ์ ์์ด๊ฐ ์ดํดํ ์ ์๋ ๋จ์ด๋ก ์ด๋ฃจ์ด์ง ์งง์ ํฉ์ฑ๋ ์คํ ๋ฆฌ(์๊ธฐ(Memory) ์์ฃผ)
- PrOntoQA ๋ฐ์ดํฐ์ ์์ Loss๊ฐ ๋น ๋ฅด๊ฒ ์ค์ด๋๋ ๊ฒ์ผ๋ก ๋ณผ ๋, ๋ชจ๋ธ์ด Reasoning pattern์ ๋ ์ ํ์ ํจ์ ์ ์ ์์
- ์ถ๋ก (Reasoning) ๊ณผ์ ๊ฐ ํ์ต ์ด๊ธฐ์ ๋นจ๋ฆฌ ์ต๋๋๋ ์ด์
- ํ์ต ๊ณผ์ ์์ ์ด๊ธฐ์ Embedding space๊ฐ ๋ ๋ถํ๋๋ ํน์ฑ์ด ์กด์ฌํจ
- ์๋ฒ ๋ฉ์ด ๋ถํ๋๋ค: ์๋ฒ ๋ฉ ๊ณต๊ฐ์์ Vector๋ค์ด ์๋ก ๋ค๋ฅธ ๋ฐฉํฅ, ์์น๋ก ์ด๋
- ํ ํฐ t๋ one-hot์ด๊ธฐ ๋๋ฌธ์ ์๋ฒ ๋ฉ์ผ๋ก ๋ณํ ํ ๊ทธ ํ ํฐ์ด ๋ฑ์ฅํ ๋ชจ๋ ์ํ์ loss gradient๋ฅผ ๋์ ํด์ ์
๋ฐ์ดํธํ๋๋ฐ ํ ํฐ๋ง๋ค ์ด๋ค ๋ผ๋ฒจ๊ณผ ํจ๊ป ๋ฑ์ฅํ๋๊ฐ๊ฐ ์๋ฒ ๋ฉ ๋ฐฉํฅ์ ๊ฒฐ์
Reasoning Task์์ ์๋ฒ ๋ฉ ๋ถํ๊ฐ ๋น ๋ฅธ ์ด์- ํน์ ํ ํฐ์ ํน์ ์ ํ์ ๋ผ๋ฒจ๊ณผ ๊ฐํ๊ฒ ์ฐ๊ด๋์ด ํ ํฐ๋ณ ๋ผ๋ฒจ ๋ถํฌ๊ฐ ์๋ก ๋ค๋ฆ
- ์๋ฒ ๋ฉ์ด ํ์ต ์ด๋ฐ๋ถํฐ ๋ค๋ฅธ ๋ฐฉํฅ์ผ๋ก ์ด๋
Memory Task์์ ์๋ฒ ๋ฉ ๋ถํ๊ฐ ๋๋ฆฐ ์ด์- ์๋ก ๋ค๋ฅธ Memory (์๊ธฐ) ํ ํฐ๋ค์ด ๋น์ทํ ๋ผ๋ฒจ ๋ถํฌ๋ฅผ ๊ฐ์ง
- Gradient ๋ฐฉํฅ์ด ์๋ก ์ ์ฌ, ์ด๊ธฐ์ ์๋ฒ ๋ฉ์ด ์๋ก ๊ตฌ๋ถ๋์ง ์์
- ํ์ต ๊ณผ์ ์์ ์ด๊ธฐ์ Embedding space๊ฐ ๋ ๋ถํ๋๋ ํน์ฑ์ด ์กด์ฌํจ
Result
- Transformer์์ โ์์ ์ด๊ธฐํโ๋ฅผ ์ฌ์ฉํ์ฌ ์คํ์ ์งํ
- Biased๋ ํ์์ ์์ธํ ๊ด์ฐฐํ๊ธฐ ์ํ์ฌ ์๋ฒ ๋ฉ ๋ ์ด์ด์ Multi-layer Perceptron์ผ๋ก ๊ตฌ์ฑ๋ ๊ฐ๋ตํ๋ ๋ชจ๋ธ์ ์ ์
- ํ ํฐ ์๋ฒ ๋ฉ์ ํด๋น ํ ํฐ์ด ๋ฑ์ฅํ ์ํ๋ค์ ๋ผ๋ฒจ ๋ถํฌ์ ์ํด ํ์ต๋๋ ๊ฒ์ ์ด์ฉํ์ฌ ์คํ ์ค๊ณ
- Reasoning Anchor
- ํ ํฐ ์์ฒด๋ง์ผ๋ก ์ ๋ต์ด ๊ฒฐ์ ๋์ง ์๊ณ ๋ค๋ฅธ ํ ํฐ๋ค๊ณผ์ Composition์ ๋ฐ๋ผ ๋ผ๋ฒจ์ด ๋ฌ๋ผ์ง
- Gradient์ ๋ถ์ฐ์ด ํผ
- ๋ผ๋ฒจ ๋ถํฌ๊ฐ ๋ค์ํ๋ฏ๋ก ์ด๊ธฐ ํ์ต ๋จ๊ณ์์ ์๋ฒ ๋ฉ ๋ถํ๊ฐ ๋น ๋ฆ
- Memory Anchor
- ํน์ ํ ํฐ์ด ๊ฑฐ์ ๊ฐ์ ์ ๋ต(label)๊ณผ ์ฐ๊ฒฐ๋จ
- ๋ผ๋ฒจ ๋ถํฌ์ ๋ถ์ฐ์ด ์๊ณ ์๋ฒ ๋ฉ ์ ๋ฐ์ดํธ๋ ๋ถํ๊ฐ ์์
- Reasoning Anchor
Reasoning Bias in Transformer with Composite Anchor Functions
- 0.3, 0.5, 0.8๋ก Gamma ํฌ๊ธฐ๋ฅผ ๋ณํํ์์ ๋, ์๋ ํ์ต ์ค loss์ ๋ณํ, ์๋ ํ์ Prediction Accuracy์ ๋ณํ
- ์ฃผ์: Gammaํฌ๊ธฐ๊ฐ ํด์๋ก ์ด๊ธฐํ ์ค์ผ์ผ์ ์์ ๊ฒ์
- Amem (์๊ธฐ๋ก ํ ์ ์๋ ํ ํฐ), Arsn (๊ท์น์ ์์์ผ ํ ์ ์๋ ํ ํฐ), Z (์ผ๋ฐ ํ ํฐ), M (ํน์ Reasoning Anchor ์์๋ง ์๋ฏธ ์๋ ๊ท์น)์ผ๋ก ๋ฐ์ดํฐ์ ๊ตฌ์ฑ
- ๋ฐ์ดํฐ์
- ์ด 200,000๊ฐ ์ํ
- ์๊ธฐ๋ง ํ๋ฉด ๋๋ ๋ฐ์ดํฐ, Reasoning์ด ํ์ํ train ๋ฐ์ดํฐ, Reasoning์ด ํ์ํ test๋ฐ์ดํฐ๋ก ๋ถ๋ฆฌ
- ๋ชจ๋ธ
- Decoder-only Transformer (2 layers, 1 attention head)
- Loss: Cross Entropy + AdamW
- ์ด๊ธฐํ ์ค์ผ์ผ: 0.3, 0.5, 0.8
- 0.3์ ํฐ ์ด๊ธฐํ
- ํ๋ จ ๋ฐ์ดํฐ์์๋ ์๊ธฐ ๋ฐ์ดํฐ, Reasoning์ด ํ์ํ ๋ฐ์ดํฐ ๋ชจ๋ ์ ๋ง์ถค
- ํ ์คํธ ์ถ๋ก ๋ฐ์ดํฐ์์๋ loss๊ฐ ๊ฑฐ์ ์ค์ง ์์
- ํ๋ จ ์ํ ์์ฒด๋ฅผ ์๊ธฐํ๊ณ ์์์ ์ ์ ์์
Memory ๋ฐ์ดํฐ์ loss๊ฐ ๋ค์ ๋น ๋ฅด๊ฒ ํ๊ฐํ๊ณ ์์
- 0.8์ ์์ ์ด๊ธฐํ
- ์ถ๋ก ๋ฐ์ดํฐ๋ ํ๋ จ ๋ฐ ํ ์คํธ ๋ฐ์ดํฐ ๋ชจ๋ loss๊ฐ ์ ํ๊ฐ
Memory ๋ฐ์ดํฐ์ loss ํ๊ฐ์ด ๋ค์ ๋๋ฆผ
- ๋จ์ ์๊ธฐ๋ณด๋ค ๊ท์น์ ๋จผ์ ํ์ต
- Reasoning Bias๊ฐ ๋ฐ์ํจ
โ ๋ชจ๋ธ์ Learning Bias๊ฐ Initialization scale์ ์ํฅ์ ๋ฐ์
Simplified Model
Bias๋ฅผ ๋ ์ ์ดํดํ๊ธฐ ์ํด 2 layer์ ์์ Fully Connected Network์์ ์คํ
- ๋ชจ๋ธ์ ์ ์
2 layer์ ๋ชจ๋ธ๋ก ๊ตฌ์ฑํ๊ณ , W(1)์ ์ ๋ ฅ ํ ํฐ์์ hidden state๋ฅผ ์ถ์ถํ๋ Weight๋ก, W(2)๋ฅผ hidden state์์ ์ถ๋ ฅ ํ ํฐ์ ์ถ์ถํ๋ Weight๋ก ๊ตฌ์ฑํ๊ณ ํ์ฑํ ํจ์(sigmoid)๋ฅผ ์ฌ์ฉ
- Embedding Space๊ฐ ๋ณด์ด๋ ํจํด ๋น๊ต
Memory Anchor๊ฐ ํ์ต Epoch์ด ์ฆ๊ฐํ์์์๋ ์ผ๊ด์ ์ผ๋ก ๊ฑฐ์ ๊ฐ์ ๋ฐฉํฅ์ ๊ฐ๋ฆฌํด๊ณผReasoning Anchor๊ฐ ๊ฐ๊น์ด ๊ฑฐ๋ฆฌ์ผ์๋ก ๊ฐ์ ๋ฐฉํฅ(Cosine ์ ์ฌ๋๊ฐ ๋์)์ ๊ฐ๋ฆฌํค๊ณ , ๊ฑฐ๋ฆฌ๊ฐ ์ปค์ง์๋ก ์ ์ฌ๋๊ฐ ๊ฐ์ํ๋ ์ฐ์์ ์ธ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ์ง์ ๋ณด์ฌ์ค
Memory Anchor์Reasoning Anchor์ ์๋ฒ ๋ฉ์์ Cosine ์ ์ฌ๋๋ฅผ ๋ณด์์ ๋,
Reasoning Anchor์์๋ Anchor๊ฐ ๊ฑฐ๋ฆฌ๊ฐ ์ปค์ง์๋ก Cosine ์ ์ฌ๋๊ฐ ๊ฐ์ํ๋ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌReasoning Anchor๊ฐ ๋น ๋ฅด๊ฒ ์๋ฒ ๋ฉ ๊ณต๊ฐ์์ ์ฐ์์ , ๊ณ์ธต์ ๊ตฌ์กฐ๊ฐ ๋ง๋ค์ด์ง
Memory Anchor์์๋ ๋ชจ๋ Memory Anchor๊ฐ ๊ฑฐ์ ๊ฐ์ ๋ฐฉํฅ์ ๊ฐ๋ฆฌํด- ๋ชจ๋ธ์ Primitive-level ๋งคํ, ์ฆ, ์๋ฒ ๋ฉ์ด ๋ ๋ค์ํด์ ธ์ผ ํจ
- ๋ณต์ก์ฑ๊ณผ ๋ค์์ฑ์ด ๋์ฑ ์ฆ๋๋์ด์ผ ํจ
- โ ๊ทธ๋ฌ๋ ๊ฒฐ๊ณผ์ ์ผ๋ก Reasoning Anchor์ ๋นํด Memory Anchor๊ฐ ๋ค์ํ์ง ๋ชปํ๊ณ , ๋ถํ๊ฐ ์ ์๋๊ณ ์์
- Target ๋ถํฌ๊ฐ Embedding์ ๊ฒฐ์ ํ๋ ์ด์
- Assumption
- ์์ ์
๋ ฅ์์๋ ํ์ฑํ๊ฐ ๊ฑฐ์ ์ ํ์ด๊ณ , Gradient๊ฐ ํญ์ฃผํ์ง ์์์ ๊ฐ์ ํจ
โ Small initialization์์๋ Emb-MLP (ํฉ์ฑ๋ ๋ชจ๋ธ)๊ฐ ๊ฑฐ์ ์ ํ ๋ชจ๋ธ์ฒ๋ผ ์๋ํ๋ค
- Hidden Layer์ ๋น์ ํ์ฑ์ด ์ฌ๋ผ์ง๋ฏ๋ก, Target Distribution๋ง ๋ณด๊ณ Embedding์ด ์์ง์
- ํ ํฐ s์ ๋ํ Embedding์ Gradient๋ s๊ฐ ๋ฑ์ฅํ ๋ชจ๋ ์ํ์ ์ ๋ต ๋ ์ด๋ธ๊ณผ uniform ๋ถํฌ์ ์ฐจ์ด์ ์ํด ๋์ ๋จ
- Proposition
- ๋๋ค ๋ณ์: ํ ํฐ s๋ฅผ ํฌํจํ ์ํ์ ํ๋ ๋ฌด์์๋ก ๋ฝ์์ ๋์ ์ ๋ต ๋ ์ด๋ธ
๋ถํฌ: ํ ํฐ s๊ฐ ์ด๋ค ๋ ์ด๋ธ๊ณผ ์ผ๋ง๋ ์์ฃผ ํจ๊ป ๋ฑ์ฅํ๋๊ฐ
- Embedding์ ์ด๋ ๋ฐฉํฅ์ ์ ๋ต ๋ถํฌ P์ ์์ ๊ท ๋ฑ ๋ถํฌ์ ์ฐจ์ด์ ์ํด ๊ฒฐ์ ๋จ
๋ชจ๋ธ ๊ตฌ์กฐ, ๋ค๋ฅธ ํ ํฐ์ ๊ฑฐ์ ๊ด์ฌํ์ง ์์
- Results
- Memory Anchor๊ฐ ๋ชจ๋๋ค ๊ฑฐ์ ๊ฐ์ ๋ฐฉํฅ์ผ๋ก Align ๋๋ ์ด์
- ์ด๋ค Memory Anchor๊ฐ ๋ฑ์ฅํด๋ ์ ๋ต ๋ ์ด๋ธ ๋ถํฌ๊ฐ ๋์ผ(Uniform ๋ถํฌ์ ์ฐจ์ด๊ฐ ๊ฑฐ์ ๋ฐ์ํ์ง ์์)
- ๋ชจ๋ Memory Anchor์ Gradient ๋ฐฉํฅ์ด ๊ฐ์
- โ ๋ฐ๋ผ์ Embedding์ด ๊ฐ์ ๋ฐฉํฅ์ผ๋ก๋ง ๊ณ์ ์์ง์
- Reasoning Anchor๊ฐ ๋ถํ๋๋ ์ด์
- Reasoning Anchor s์ ๋ํ์ฌ ์ ๋ต ๋ ์ด๋ธ์ด ๋ชจ๋ ๋์ผํ์ง ์์
- ํ๊ท (๊ธฐ๋๊ฐ)์ด s๋ง๋ค ๋ค๋ฆ
- Embedding Gradient ๋ฐฉํฅ์ด ๋ฌ๋ผ์ง
- โ ์ด๊ธฐ ๋จ๊ณ์์๋ถํฐ ๋น ๋ฅด๊ฒ ๋ถํ
์ผ๋ฐ์ ์ธ Task์์์ Transformer
Transformer์์์ Bias ์คํ ๊ฐ์
- MLP (Multi-layer Perceptron) ๋ชจ๋ธ์ด Noise Sequence์์ ์คํจํ๋ค๋ ์ ์์ ์ด๋ฌํ ์คํจ๊ฐ ์ ์ Transformer ๋ชจ๋ธ์ด ๋ ๋์ ์ ์ ๋ณด์ฌ์ฃผ์ง๋ง, ๊ทธ๋ผ์๋
Reasoning bias๊ฐ ์ ์ง๋๋์ง ๋ณด์ฌ์ฃผ๋ ๋ถ๋ถ
- Transformer์ ์๋ฒ ๋ฉ space๊ฐ Emb-MLP์์ ๋ณด์ธ ๊ฒ๊ณผ ์ ์ฌํ ํ์์ ๋ณด์ด๋์ง ์ดํด๋ณด๊ณ , ๋ชจ๋ธ์ด ์ ๋ ฅ์ผ๋ก๋ถํฐ ์ ๋ณด๋ฅผ ์ด๋ป๊ฒ ํฌ์ฐฉํ๋์ง ํ์ธํ๋ ์คํ
์คํ ๊ฒฐ๊ณผ
- ์์ ํ์์ด ํฉ์ฑ ๋ชจ๋ธ์ด ์๋ ์ค์ Transformer ๋ชจ๋ธ์์๋ ๊ทธ๋๋ก ๋ํ๋๋ค๋ ์ ์ ์ค๋ช ํ Figure
- B: Memory Anchor์ Reasoning Anchor์ PCA ๋ถํฌ์ ๋น๊ต
- C: ์ค์ Transformer Embedding๊ณผ ํฉ์ฑ ๋ชจ๋ธ์ ์ฝ์ฌ์ธ ์ ์ฌ๋ ๋น๊ต ๊ทธ๋ํ
- D: ์ด๋ก ์ ์ผ๋ก ๊ตฌ์ฑํ ์๋ฒ ๋ฉ์ PCA
- Embedding Space
- Transformer์
Embedding Space๋ Emb-MLP์ ๊ฑฐ์ ๋์ผ
Reasoning Anchor๋ ๊ณ์ธต์ , ์ฐ์์ ๊ตฌ์กฐ๋ฅผ ๋ณด์์ผ๋ฉฐ, ๊ฑฐ๋ฆฌ๊ฐ ๋ฉ์๋ก Cosine ์ ์ฌ๋๊ฐ ๊ฐ์ํจ์ ์ ์ ์์
Memory Anchor๋ ๋์กฐ์ ์ผ๋ก Alignment์ ๋ชจ๋ ์ ์ฌ์ฑ์ ๋ณด์
- PCA (์ฃผ์ฑ๋ถ๋ถ์)์์ ์ ์ฒด ์๋ฒ ๋ฉ ๊ณต๊ฐ์ ๊ตฌ์กฐ์ ํน์ฑ์ ๋ถ์ํจ
โAttention์ ์๋ฒ ๋ฉ์ ์๋ก ๋ง๋ค์ง ์๊ณ , ๊ธฐ์กด ์๋ฒ ๋ฉ์ bias๋ฅผ ํค์ฐ๋ ์ญํ ๋ง ํ๋ค๋ ๊ฒ์ ๋ณด์ฌ์ค
โ Transformer๋ Emb-MLP์ ๊ฐ์ ๊ฒฝํฅ์ฑ์ ๋ณด์ธ๋ค!
- Transformer์
- First Attention Module
- i๋ฒ์งธ token์ ์ถ๋ ฅ์ด ๊ทธ ์ด์ ๋ชจ๋ token์ ํ๊ท ์ด ๋๋ ๊ฒ
- Query-Key์ Dot product๊ฐ ์ ์ฐจ ๋์ผํด์ง
- ๊ฒฐ๊ณผ์ ์ผ๋ก mask๋ก ์ธํด Prefix Average ์ฐ์ฐ์ผ๋ก ๋จ
- ์ฒซ Attention Layer๋ ์ด๋ค ํ ํฐ์ด ์ค์ํ์ง ๊ณ ๋ฅด๋ ์ญํ ์ด ์๋ ์ง๊ธ๊น์ง ๋ฑ์ฅํ Token์ ๋์ ํ๋ ์ฅ์น
- ์ ๊ฒฝ์ฐ ๊ฐ์ฅ ํฐ Singular Value๊ฐ ๋๋จธ์ง๋ณด๋ค ํจ์ฌ ํผ
- Corresponding Singular Vector๊ฐ
Reasoning Anchor์๋ ๊ฐ๊น๊ฒ align๋์ง๋ง,Memory Anchor๋ ๊ฑฐ์ ์์ง์Reasoning Vector๋ ๊ฑฐ์ ๋๋ถ๋ถ W์์ ํฌ์ฐฉ๋๋ฉฐ, ๋ชจ๋ subsequent ํ ํฐ๋ค๋ก ์ ํ๋จ
- Second Attention Module
- ์ค์ํ ์ ๋ณด๊ฐ ์ด๋์๋์ง ์ฐพ๊ณ , ๋ง์ง๋ง ์ ๋ณด๋ฅผ ๋ชจ์ผ๋ ์ญํ
- [Definition 2] One-layer Transformer
- Layer Normalization์ final projection layer๋ ์ ์ธํจ
(๊ฒฐ๊ณผ์ ์ํฅ์ ๋ฏธ์น์ง ์๋ ์์์)
- ์์ ๊ด์ฐฐ ๊ฒฐ๊ณผ์ ๊ฐ์ด Small initialization scale์์๋ Attention A๊ฐ average๋ก ํด์๋ ์ ์์
- ์์ ์ด๊ธฐํ์ผ ๋โ Self-Attention์ด Prefix Average์ ๊ฑฐ์ ์ ์ฌํจ
- Q, K์ scale์ด ์๊ณ softmax ์ ๋ ฅ์ด ๊ฑฐ์0์ด๋ฉด softmax ๊ฒฐ๊ณผ๊ฐ ๊ฑฐ์ uniform
- ์์ฐจ์ ์ ๋ณด ๋์ ์ผ๋ก Reasoning์ ์ ๋ฆฌ
- Proposition 2
Memory Anchor๋ค์ Embedding ๋ฐฉํฅ์ด ๋น์ทํด์ง๊ณ ํ ๊ณณ์ผ๋ก ๋ญ์นจ
- Proposition 3
Reasoning Anchor์ ์๋ฒ ๋ฉ ์ ๋ฐ์ดํธ
Reasoning Token์ ์ด์ ํ ํฐ๋ค๊ณผ ์กฐํฉ๋์ด Label์ด ๊ฒฐ์ ๋จ
- Label ๋ถํฌ P๊ฐ s๋ฅผ ์ค์ฌ์ผ๋ก ํผ์ง ๋ถํฌ
- ์๋ฒ ๋ฉ์ด ์ ์ง์ ์ผ๋ก ๋ถํ
- Theorem1
- Reasoning Embedding์ ๊ทผ์ฌ ํํ
- ๊ฒฐ๊ณผ์ ์ผ๋ก Transformer์์๋ Reasoning bias๊ฐ ์ผ์ด๋จ์ ์ํ์ ์ผ๋ก ์ฆ๋ช ํจ
- Layer Normalization์ final projection layer๋ ์ ์ธํจ
Real Language Tasks
- โL: Reasoning Bias๋ฅผ ์ ๋ํํ ์งํ
๋ถ์: L(TinyStories)-L(PrOntoQA): ๋ task ๊ฐ์ loss์ฐจ์ด
๋ถ๋ชจ(L(PrOntoQA)): ์ถ๋ก ๊ณผ์ ์ loss๋ก ์ ๊ทํ
Reasoning ๊ณผ์ ๋๋น, Memory ๊ณผ์ ๊ฐ ์ผ๋ง๋ ๋ ์ด๋ ต๊ฒ ํ์ต๋๊ณ ์๋๊ฐ์ ๋ํ ์งํ
- โL์ด ์ฆ๊ฐํ๋ ๊ฒ์ ์๋ฏธ
- L(PrOntoQA)๊ฐ ์๋์ ์ผ๋ก ๋ ๋น ๋ฅด๊ฒ ๊ฐ์
- ๋ชจ๋ธ์ด ์ถ๋ก ๊ณผ์ ๋ฅผ ๋ ์ ํ์ตํ๋ ๋ฐฉํฅ์ผ๋ก biased
- ์ด ํ์์ ์์ธ
- GPT-2๋ ์์ ์ด๊ธฐํ scale๋ก ํ์ต๋จ
- ์ด๊ธฐ ํ์ต ๋จ๊ณ์์์ Representation์ด ๋ถํํ์์์ ์ ์ ์์
Conclusion
- ์์ ์ด๊ธฐํ๊ฐ ์ถ๋ก ์ ํธ bias๋ฅผ ๋ง๋ฆ
- Label distribution์ด ์๋ฒ ๋ฉ ๊ณต๊ฐ์ ๋ง๋๋๋ฐ ํต์ฌ ์ญํ ์ ํ๊ณ ํ์ต์ ์์ด ๋์ญํ์ ์ํฅ์ ๋ฏธ์นจ
- Next-token prediction training๊ณผ ๊ฐ์ ์ ์ฌํ task์ ํ์ฉ ๊ฐ๋ฅ
- ์คํ์ ๊ด์ฐฐ๊ณผ ์ด๋ก ์ ์ธ ์์์ผ๋ก ์ฆ๋ช ํจ























