What Happens During the Loss Plateau? Understanding Abrupt Learning in Transformers
Review
| ๋๋ค์ | ํ์คํ | ๋ณ์ (0/5) |
|---|---|---|
| ๋ง์คํนํ ์ดํ | abrupt learning ์ด๋ผ๋ ๋จ์ด๋ฅผ ์ฒ์ ๋ด. ์์ค ํ๋ฝ์ด ์ ์ฒด๋๋ค๊ฐ ๊ฐ์๊ธฐ ํฌ๊ฒ ์ผ์ด๋๋ ๊ฒ์ ๋ค๋ฅธ ํ์คํฌ ํ์ต ๋๋ ๋ช๋ฒ ๊ฒฝํํ ์ ์์์ง๋ง, ๊ทธ ์ด์ ๋ฅผ ํ์ ํด๋ณด๊ณ ํํค์ณ๋ณด๋ ์ฐ๊ตฌ๋ ์ฒ์ ๋ณด๋ ๊ฒ ๊ฐ์. ๋ฐ๋ณต์ ํ ํฐ์ด ์ ๋ฐ์ํ๋์ง, attention map์ ์ด๋ป๊ฒ ํ์ํ๊ณ ๊ฐ์ ํ๋์ง, ๊ทธ๋ฆฌ๊ณ ๊ทธ๊ฒ์ ํ์ธํ ๋ฐฉ๋ฒ์ ์ฐธ๊ณ ํ ๋งํด๋ณด์. | 4.0 |
| ๊ทค | Mamba ๊ฐ์ ๋ค๋ฅธ ์ํคํ ์ฒ์์๋ loss plateau์ ๊ฐ์ ํ์์ด ๋ํ๋๋์ง ๊ถ๊ธํจ. Transformer์ plateau๊ฐ attention์ด ์ฌ๋ฐ๋ฅธ ํ ํฐ ์ ๋ ฌ์ ์ฐพ์๊ฐ๋ ๊ณผ์ ์์ ์๊ธฐ๋ ๋ณ๋ชฉ์ด๋ผ๋ฉด, attention์ ์ฌ์ฉํ์ง ์๋ ๋ชจ๋ธ์์๋ plateau๊ฐ ๋ ์ฝํ๊ฒ ๋ํ๋๊ฑฐ๋ ์์ ๋ํ๋์ง ์์ ๊ฐ๋ฅ์ฑ๋ ์์๋ฏ | 3.7 |
| ๋๊น์ค | loss plateau๊ฐ ๋ฐ์ํ๋ค๋ ์ฌ์ค์ ๋ชฐ๋๋๋ฐ plateau์์ ํ์ต ์ข ๋ฃ๋ฅผ ๋๋ฌด ์ผ์ฐ ํด๋ฒ๋ฆฌ๋ฉด plateau๋ฅผ ๋ซ๊ธฐ ์ง์ ์ ํ์ต์ ๋ฉ์ถ ์ ์๊ฒ ๋ค(๋ค์ด์๋ฅผ ์์ ๋๊ณ ๊ด๋ฌผ์บ๊ธฐ ๋ฉ์ถ๋ ๊ฒ์ฒ๋ผ)๋ผ๋ ์๊ฐ๊ณผ ๊ทธ๋ผ ํ์ต ์ข ๋ฃ ํ์ด๋ฐ์ ์ธ์ ๋ก ์ก์์ผ๋์ง ๋ผ๋ ์๊ฐ์ด ๋ฆ | 3.7 |
| ์๋ฉด์ฅ์ | ์์ฌ 1ํ๊ธฐ๋ ์ฝ๋ ์ฐ์ตํ๋ค๊ฐ ์ค์ ๋ก repetition bias, loss plateau๋ฅผ ๊ฒฝํํ ์ ์ด ์๋๋ฐ, ๋
ผ๋ฌธ์์ ์คํผ์
ํ๊ฒ ์ธ๊ธํ๋ ์ ๊ธฐํ๊ณ ์ฌ๋ฐ๊ฒ ์ฝ์์ loss plateau๋ Transformer ๊ตฌ์กฐ์ ๋ฌธ์ ์ผ ๊ฐ๋ฅ์ฑ์ด ํฌ๋ค!! ๊ทผ 2๋ ์ด๋ด๋ก ์ถ์ธ๊ฐ mamba๋ก ์ฎ๊ฒจ๊ฐ๊ฒ ๊ตฌ๋ ! | 4 |
| ์ด์ดํฐ | ํ๋ จ ์ด๊ธฐ ๋จ๊ณ์์ ํ๋ จ ๋ฐ์ดํฐ์ ์ํฅ์ด ํฌ๋ค๋ ์ด๋ฒ์ฃผ ๋ค๋ฅธ ์คํฐ๋ ๋ ผ๋ฌธ๊ณผ ์ฐ๊ฒฐ์ง์ด ์๊ฐํ๊ฒ ๋๋ค (training data temporal dependence ๋ ผ๋ฌธ). ํ์์ ๋ฐํ๋ด๋ ๋ฐ ์ง์คํ๋๋ฐ ํ์๋ค์ ์ด์ ๊ฐ ๋ ๊ถ๊ธํด์ง๋ค | 3.7 |
| ์ฌ๊ณผ | Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ๋ค์ ์คํํ๋ฉด์ ๊ฐ์๊ธฐ loss๊ฐ ์ฆ๊ฐํ์ฌ ์ด์ํ๋ค๊ณ ์๊ฐํ ์ ์ด ๋ง์๋๋ฐ, loss plateau์์ ์ ์ ์์๋ ๋ ผ๋ฌธ. ์์ผ๋ก์ ์คํ์์ loss ํ์ ์์ ์ ์กฐ์ ํ ์ ์์๋ฏ. | 4.7 |
| 7์ผ | Plateau ๊ตฌ๊ฐ์์ ๊ฒ์ผ๋ก ๋ณด์ด์ง ์๋ representation ๋ณํ๋ฅผ ์ง๊ด์ ์ผ๋ก ์คํํ ๋ถ๋ถ์ด ์ธ์์ ์. MWS ํ์คํฌ์ ์กด์ฌ๋ฅผ ์๊ฒ๋๋๋ฐ, ํน์ task ์ฑ๋ฅ์ด ๊ฐ์๊ธฐ ๋ฌด๋์ง๊ธฐ ์ง์ ์ signal์ ํ์งํ๋๋ก ํ์ฉ๊ฐ๋ฅํด๋ณด์. Catastrophic forgetting๋ ์ค์ผ ์ ์์ง ์์๊น? | 4.4 |
TL; DR
๐ก
Transformer ๋ชจ๋ธ ํ๋ จ ์ ์์คํ๋ฝ์ด ์ด๊ธฐ๋จ๊ณ์์ ์ ์ฒด๋๋ค๊ฐ ๊ฐ์๊ธฐ ํฌ๊ฒ ์ผ์ด๋๋ abrupt learning ํ์ ํ๊ตฌ
Summary
Motivation
- Transformers๋ฅผ ์ํ ํน์ ์๊ณ ๋ฆฌ์ฆ ํ์คํฌ์ ํ๋ จํ ๋ ๋ณด์ด๋ abrupt learning (๊ฐ์์ค๋ฌ์ด ํ์ต) ํ์
- : ๋ชจ๋ธ์ ์ฑ๋ฅ์ด ์ค๋ซ๋์ ์ ์ฒด๋์๋ค๊ฐ ๊ฐ์๊ธฐ ๊ธ๊ฒฉํ๊ฒ ํฅ์๋๋ ํ์
- ๋ณธ ๋ ผ๋ฌธ์ ํ๋ จ ์ ์ด๋ฌํ ํ์์ ๋ณดํธ์ ์ธ ํน์ฑ๊ณผ ๊ธฐ๋ณธ์ ๋ฉ์ปค๋์ฆ์ ๋ฐํ๊ณ ์ ํจ
Contribution
์ํ Transformers๋ฅผ ๊ฐ๋จํ ์๊ณ ๋ฆฌ์ฆ ํ์คํฌ๋ก ํ๋ จํ์ฌ abrupt learning๊ณผ ๊ด๋ จ๋ ์ฌ๋ฌ ํ์ ํ๊ตฌํ๊ณ ์ ํจ
- ์ฌ์ฉ ํ์คํฌ: moving-window-sum (MWS)
- ๏ปฟ ์ํ์ค๊ฐ ์ฃผ์ด์ง๋ฉด, ๏ปฟ ์ดํ ๏ปฟ ์ ์ถ๋ ฅํด์ผ ํจ
- ๏ปฟ๋ 0, 1, 2, โฆ, 17 ์ค ํ๋์ ์ซ์
- ๏ปฟ์ ๏ปฟ ๊ทธ๋๋ก, ๏ปฟ๋ ๏ปฟ๋ฅผ ๏ปฟ(=17)๋ก ๋๋ ๋๋จธ์ง, ๏ปฟ์ ๏ปฟ, โฆ
- โ ground truth๊ฐ ์ ์๋ ค์ ธ ์๋ ํ์คํฌ๋ก์ ๋ชจ๋ธ์ ํ๋ จ ์งํ ์ ๋๋ฅผ ์ ํํ ์ธก์ ํ ์ ์์
- ๏ปฟ ์ํ์ค๊ฐ ์ฃผ์ด์ง๋ฉด, ๏ปฟ ์ดํ ๏ปฟ ์ ์ถ๋ ฅํด์ผ ํจ
- ๋ชจ๋ธ ์ํคํ
์ณ: 1-layer, 1-head Transformer
- ์ด ๊ตฌ์กฐ๋ก๋ ์ฃผ์ด์ง ํ์คํฌ ์๋ฒฝํ ์ํ ๊ฐ๋ฅ
- ๏ปฟ : ์ ๋ ฅ ํ ํฐ ์ํ์ค
- ๏ปฟ : 2-layer NN
- ๏ปฟ : residual connection
- ๏ปฟ : linear layer, mapping hidden state to logits
- greedy decoding ์ฌ์ฉ
- โ ์์ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ ๋ด๋ถ ๋ฉ์ปค๋์ฆ์ ์ฝ๊ฒ ๋ถ์ํ๊ณ ํด์ํ ์ ์์
- ์ด ๊ตฌ์กฐ๋ก๋ ์ฃผ์ด์ง ํ์คํฌ ์๋ฒฝํ ์ํ ๊ฐ๋ฅ
- ํ๋ จ
- ๏ปฟ์ ์ ์ฒด ์ํ์ค์ ๋ํด next-token-prediction cross-entropy loss ์์ค ์ต์ํํ๋๋ก ํ๋ จ
- 256๊ฐ ํ๋ จ ์ํ๋ก 1 ์ํญ ํ๋ จ
- ์ ํ๋ ์ธก์ : output ๋ถ๋ถ์ธ ๏ปฟ์ ๏ปฟ๊ฐ ํ ํฐ์ ๋ํ ์์ธก ์ ํ๋ ํ๊ท
- ๏ปฟ์ ์ ์ฒด ์ํ์ค์ ๋ํด next-token-prediction cross-entropy loss ์์ค ์ต์ํํ๋๋ก ํ๋ จ
- abrupt learning
- ํ๋ จ ์ค training loss๊ฐ sub-optimal ๊ฐ์์ ์๋นํ ๋ง์ step๋์ ์ ์ง๋๋ค๊ฐ ๊ธ๊ฒฉํ ์ ํ๋ ์ฆ๊ฐ์ ์์ค ํ๋ฝ
- โ optimal solution์ ๊ฐ์๊ธฐ ํ์ตํ๋ abrupt learning ํ์ ํ์ธ
- attention map
- ์คํ ํ์คํฌ์ ๋ํด์ ์ต์ ์ attention pattern์ ๊ฐ output token ๏ปฟ๊ฐ ์์ ์ ๊ณ์ฐํ๋๋ฐ ๊ด๋ จ์๋ input token๋ง์ attendํ๋ ๊ฒ
- ๏ปฟ์ ๏ปฟ์ attend, ๏ปฟ๋ ๏ปฟ์ attend, โฆ
- attention progress measure (APM) ์ฌ์ฉํด attention pattern ๋ณํ ๊ด์ฐฐ
- ๏ปฟ : ๏ปฟ-th output token ๊ณ์ฐํ ๋ ๏ปฟ-th token์ ํ ๋น๋๋ attention score
- ๏ปฟ : ์ต์ attention map์ position pair set
- โ ํ๋ จ ์ค APM์ด 0๋ถํฐ 0.8 ์ ๋๊น์ง ๋จ์กฐ ์ฆ๊ฐํ๋ฉฐ, ์์ค/์ ํ๋๋ณด๋ค ์๋งํ ๋ณํ ๊ณก์ ๋ณด์
- step 150 ์ฏค์์ ๊ธ๊ฒฉํ ์์ค ํ๋ฝ ์๋๋ฐ, ๊ทธ ์ด์ ๋ถํฐ APM์ ์ด๋ฏธ ์๋นํ ์ฆ๊ฐํจ
- โ ์์ค ํ๋ฝ์ ๊ธ๊ฒฉํ์ง๋ง ๊ทธ ์ด์ ๋ถํฐ attention pattern learning์ ์ ์ง์ ์ผ๋ก ์งํ๋จ
- ์คํ ํ์คํฌ์ ๋ํด์ ์ต์ ์ attention pattern์ ๊ฐ output token ๏ปฟ๊ฐ ์์ ์ ๊ณ์ฐํ๋๋ฐ ๊ด๋ จ์๋ input token๋ง์ attendํ๋ ๊ฒ
- ์ฌ์ฉ ํ์คํฌ: moving-window-sum (MWS)
Transformers ํ๋ จ์ ์ด๊ธฐ ์ ์ฒด๊ธฐ (early loss plateau period) ๋์ ๋ชจ๋ธ์ ์ข ์ข
partial solution์ ํ์ตํจ- e.g., moving-window-sum ํ์คํฌ์์ ์ฒซ๋ฒ์งธ ์
๋ ฅ ํ ํฐ ๏ปฟ๋ฅผ ๊ทธ๋๋ก ์ถ๋ ฅํ๋ฉด ๋๋ ๏ปฟ ์์ธก์ ๋น ๋ฅด๊ฒ ํ์ตํ์ง๋ง, ์ ์ฒด ์์ค์ ์ฌ์ ํ ๋์ผ๋ฉฐ ์ดํ ํ ํฐ์ ๋ํ ์ ํ๋ ๋จ์ด์ง
- ์ฒซ๋ฒ์งธ ์ถ๋ ฅ ํ ํฐ ์์ธก ์ ํ๋์ธ partial solution accuracy๊ฐ ๋น ๋ฅด๊ฒ ์ฆ๊ฐ
- ๋ฐ๋ฉด ์ ์ฒด loss๋ ๋ง์ ํ๋ จ ์คํ ์ดํ ํ๋ฝ
- โ ์ด๊ธฐ ์ ์ฒด๊ธฐ ๋์ ์ ์ฒด loss๋ ํฌ๊ฒ ์ค์ง ์์ง๋ง ๋ชจ๋ธ์ partial solution ํ์ต ์งํ๋จ
- e.g., moving-window-sum ํ์คํฌ์์ ์ฒซ๋ฒ์งธ ์
๋ ฅ ํ ํฐ ๏ปฟ๋ฅผ ๊ทธ๋๋ก ์ถ๋ ฅํ๋ฉด ๋๋ ๏ปฟ ์์ธก์ ๋น ๋ฅด๊ฒ ํ์ตํ์ง๋ง, ์ ์ฒด ์์ค์ ์ฌ์ ํ ๋์ผ๋ฉฐ ์ดํ ํ ํฐ์ ๋ํ ์ ํ๋ ๋จ์ด์ง
์ ์ฒด๊ธฐ๋์ ๋ชจ๋ธ์ด ๋ฐ๋ณต์ ํ ํฐ์ ์ถ๋ ฅํ๋ ๊ฒฝํฅ์ธ
repetition bias๊ฐ ๊ฐํ๊ฒ ๋ํ๋จ- repetition frequency: repetition bias ์ ๋ํ ์งํ
- ๊ฒฐ๊ณผ
- repetition frequency๊ฐ ํ๋ จ ์์ ์ ์์๋ค๊ฐ ์ฒ์ 50 ์คํ ๋์ 0.8๊น์ง ์์น
- โ ์ด๊ธฐ ์ ์ฒด๊ธฐ์ ๊ฐํ repetition bias ํ์ธ
output repetition bias๋ ๋ค๋ฅธ ํ ํฐ์ ๋ํ hidden representation์ด ๊ฑฐ์ ๋๋ฑํ๊ฒ ๋๋
representation collapse๋ฅผ ๋๋ฐํจ- ์ถ๋ ฅ ์์น ๏ปฟ์์ hidden representation ๊ฐ pairwise cosine similarity:
- ๏ปฟ : ๏ปฟth ํ ํฐ์ hidden representation (logit ๋ณํ ์ง์ )
- cosine similarity๊ฐ ํ๋ จ ์ด๊ธฐ ๋จ๊ณ์์ ๊ธ๊ฒฉํ ์ฆ๊ฐ
- โ partial solution์์ ์ ํํ ์์ธก๋๋ ์ฒซ๋ฒ์งธ ์ถ๋ ฅ ์์น๋ฅผ ์ ์ธํ๊ณ ๋, ํ๋ จ ์ด๊ธฐ ๋จ๊ณ์์ ์ฌ๋ฌ ์ถ๋ ฅ ์์น์ hidden representation์ด ๊ฑฐ์ ๋๋ฑํด์ง
- ์ถ๋ ฅ ์์น ๏ปฟ์์ hidden representation ๊ฐ pairwise cosine similarity:
attention map learning์ด repetition, representation collapse๊ณผ loss plateau ํ์ฑ์๋ ์ค์ํ ์ญํ ํจ์ ๋ณด์- ์ต์ ์ ์ดํ ์ ๋งต์ ํฅํด(๋๋ ๋ฐ๋๋ก) ํธํฅ์์ผ ๋ฐ๋ณต, ํํ ๋ถ๊ดด, ์์ค ์ ์ฒด๊ฐ ๊ฐ์(๋๋ ์ฆํญ)๋๋์ง ํ์ธ
- ๏ปฟ ์ ๋ํด attention mask ๏ปฟ๋ฅผ ๏ปฟ๋ก ์ค์ : ๋๋จธ์ง ๊ฒฝ์ฐ ๏ปฟ
- ๊ธฐ์กด attention์ ์ด attention mask์ hadamard ๊ณฑ ํจ์ผ๋ก์จ ๋ณํํ์ฌ ํ๋ จ๊ณผ ์ถ๋ก ์ ์ฌ์ฉ
- ๏ปฟ ์ด๋ฉด optimal attention map ๋ฐฉํฅ์ผ๋ก ํธํฅ์ํค๊ณ , ๏ปฟ ์ด๋ฉด ๋ฐ๋ ๋ฐฉํฅ์ผ๋ก ํธํฅ์ํค๋ ๊ฒ
- ๏ปฟ : representation collapse ์ํ์ ๋ชจ๋ธ์ด ๋ณด๋ค ์ค๋ ๋จธ๋ฌผ๊ณ ๋ณด๋ค ์ดํ์ ์๋ ด, repetition frequency ๋ํ ์ ์ฒด๊ธฐ๋์ ํฌ๊ฒ ์ ์ง๋จ
- โ attention map learning์ด repetition, representation collapse, loss plateau ํ์ฑ์ ์ค์ํ ์ญํ ํจ
์ํ Transformer๋ง์ด ์๋๋ผ ์ค์ LLM์ ์ฌ์ ํ๋ จ ์ด๊ธฐ ๋จ๊ณ์์๋ ์ด๋ฌํ repetition bias, representation collapse๊ฐ ๋ํ๋๋์ง ํ์ธ
- LLM: Pythia, OLMo-2 (open-source)
- task: ARC-Easy ํ
์คํธ ๋ฐ์ดํฐ์์ ๋๋ค ์ ์ ํ 100๊ฐ ์ง๋ฌธ (์ด๋ฑํ๊ต ์์ค ๊ณผํ ๊ฐ๊ด์ ์ง๋ฌธ)
- ๊ฐ ์ง๋ฌธ์ ๋ํด 8๊ฐ ํ ํฐ์ ์์ฑํ๊ฒ ํ๊ณ hidden representation์ pair-wise cosine similarity ๊ณ์ฐ
- 14M, 1B, 1.4B, 2.8B Pythia ๋ชจ๋ธ์ ์ด๊ธฐ ํ๋ จ ๋จ๊ณ์์ ์ถ๋ ฅ ์ํ์ค์ repetition bias ๋ฐ๊ฒฌ
- ์ด๊ธฐํ ์ ํ๊ท ์ฝ์ฌ์ธ ์ ์ฌ๋๊ฐ ๋น๊ต์ ๋ฎ์ผ๋ (0.4~0.65), ๋ชจ๋ ์ฌ์ด์ฆ์ ๋ชจ๋ธ์ ๋ํด ๋ช ์คํ ๋ง ํ์ตํด๋ 0.9 ์ด์์ผ๋ก ๊ธ๊ฒฉํ ์ฆ๊ฐ
- OLMo-2 7B ๋ชจ๋ธ์์๋ Pythia์ ์ ์ฌํ representation collapse ํ์ ๊ด์ฐฐ
- 150์คํ ์ ์ด๊ธฐ ํ๋ จ ๋จ๊ณ์์ representation ํ๊ท ์ฝ์ฌ์ธ ์ ์ฌ๋๋ 0.93
- 600์คํ ์์๋ 0.43์ผ๋ก ๊ฐ์
- โ repetition bias, representation collapse๊ฐ LLM ์ด๊ธฐ ์ฌ์ ํ๋ จ ๋จ๊ณ์์ ์ค์ ๋ก ๋ฐ์ํ๋ ํ์์
Conclusion
- Transformer ํ๋ จ์ ์ด๊ธฐ๋จ๊ณ์์ repetition bias์ representation collapse๊ฐ ๋ฐ์ํ๋ฉฐ, ์ด๋ loss plateau์ ๋ฐ์ ํ ๊ด๋ จ์ด ์์
- loss plateau๊ฐ ์ต์ ์ attention map ํ์ํ๋ ๊ณผ์ ์ผ ๊ฐ๋ฅ์ฑ ์์
- ๋ณธ ๋ ผ๋ฌธ์ ๋ฐ๊ฒฌ์ ์ด์ด representation collapse์ ๊ฐ์ ํ์์ด๋, attention map์ ๋๋ฆฐ learning์ ์์ธ์ ๋ํ ํฅํ ์ฐ๊ตฌ๊ฐ ์ ๋ง











