Judge Decoding: Faster Speculative Sampling Requires Going Beyond Model Alignment
Review
| ๋๋ค์ | ํ์คํ | ๋ณ์ (0/5) |
|---|---|---|
| ์ฐฐ๋ | LLM-as-judge ๊ฐ ๊ฐ์ง ๋ฌธ์ ์ค ๊ฐ ๋ชจ๋ธ์ ์์ฑ๊ณผ ์ ๋ ฌ๋๋ ๊ธฐ์ค์ผ๋ก ํ๊ฐํ๋ค๋ ๋ฌธ์ ๋ ์ ์๋ ค์ง ๊ฒ ๊ฐ์. ์คํ์ผ๋ก ์ธํ ์ํฅ ๋ฑ์ด ์์๋ ๊ฒ ๊ฐ์๋ฐ, ์ง๊ด์ ์ผ๋ก ์ด ํ์์ ์ ๊ท๋ช ํ๊ณ ๊ฐ์ ธ์จ ๊ฒ ๊ฐ์. ์ข์ ์์ด๋์ด๋ ์ฌํํ ๋ฐฉ๋ฒ๋ก ์ผ๋ก, ๋ช ํํ๊ฒ ํด๊ฒฐ๋ ์ ์๋ค๋ ๊ฒ์ ๋๊ผ์. | 4.3 |
| ์์ฌ๋น๊ฝ๊ฒ๋ | ๊ธฐ์กด์ speculative decoding์ 'draft๊ฐ target๊ณผ ์ผ๋ง๋ ๋น์ทํ๊ฐ'์ ์์กดํด์ alignment์ชฝ์ ์ฐ๊ตฌํ๋ค๋ฉด, ํด๋น ๋ ผ๋ฌธ์ ๊ด์ ์ ์ข ๋ฐ๊ฟ์ '์ด ํ ํฐ์ด ๋ฐ์๋ค์ฌ์ง์ง๋ฅผ ์์ธก'ํ๋๊ฒ์ผ๋ก ๋ฌธ์ ๋ฅผ ๋ฐ๊ฟ. ๋์ผํ ๋ฌธ์ ๋ฅผ ์๋ก์ด ๊ด์ ์ผ๋ก ๋ฐ๋ผ๋ณด๋ ์๊ฐ์ด ํ์ํ๊ฑฐ ๊ฐ์ | 4 |
| ๋ฉ๊ฐ์ปคํผ | ๊ฒ์ฆ ์์ฒด์ ๊ด์ ์ ๋ฐ๊ฟ์ ์คํ์ ํ๋ค๋ ์ ์์ Novelty๊ฐ ์๋ค. ์คํ์์ ํ์คํฌ์ ๋ํ accuarcy๋ฅผ ์ ์งํ ์ฑ acceptance๋ฅผ ๋์๋ค๋ ์ ์์ ๋ ผ๋ฌธ์ ์ค๋๋ ฅ๊ณผ ๋ฐฉ๋ฒ๋ก ์ ์ ๋น์ฑ์ ๋์์ ์ข๋ค | 4.1 |
| ์๋ฆฌ๊ดด๋ฌผ | ๋ณดํต LLM-as-judge๋ ์์ฒญ ๊ธด ํ๋กฌํํธ์ CoT ๊ธฐ๋ฐ์ ๋๋ฆฐ ์ถ๋ก ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋๋ฐ... ๋จ์ํ๊ฒ ์์ ์ด์ง ๋ถ๋ฅ๊ธฐ ํ๋๋ก ๋น ๋ฅด๊ฒ ํ์ต์ด ๊ฐ๋ฅํ๊ฒ ํ ์ ์ด ๋๋๋ค. ์ค์ง์ ์ผ๋ก ์ด๋ฐ์ reject๋๋ ํ ํฐ์๊ฐ ํ์ฐํ ์ค๊ฒ ๋ค | 4.3 |
| ์์ฐ๊นก | ์ง์ ํ๋ ๊ธฐ์กด ๋ฐฉ์์ ํ๊ณ์ ํด๊ฒฐ์ฑ ์ด ์์ ๋ฉ๋์ด ๊ฐ๋ค. ์ ํ๋๋ณด๋ค judge ๋ชจ๋ธ์ ์ ํธ๋ ๊ธฐ์ค์ผ๋ก ๋์ฝ๋ฉ ๊ฒฐ์ ํ ์ ์๋ ๋ฌธ์ ๋ฅผ ์ ํ๋ ๋ถ๋ฅ๊ธฐ๋ก ํด๊ฒฐํ๋๋ฐ, ์ด๋ฐ ๋ฐฉ๋ฒ์ ๋์ฝ๋ฉ์ ์ ์ฉํ๋ค๋ ๊ฒ ์๋ก์ | 4.4 |
| ๊ณ ๊ตฌ๋ง๋ง๋๋ฆฌ | LLM์ ๋ณธ์ง์ ์ธ ํน์ฑ, ๊ธฐ์กด ๊ฒ์ฆ ์ฐ๊ตฌ์ ํ๊ณ์ ์์ ์์ํด์ ์คํ, ์ธ์ฌ์ดํธ๊น์ง ๋ ผ๋ฆฌ์ ์ด๊ณ ์ ๊ตํ๊ณ ๋ ์ ์ฉํ๋ค! ๊ทธ์น ์ ์ด์ ๋ถ์์ ํ ๊ฒ๊ณผ alignํ์ฌ ํ๊ฐํ๋๊ฒ ์ด์ํ๊ธด ํ๋ค! | 5 |
| ์์ฑ์ฌ | Motivation, Technical soundness, performance, Research impact ์๋ฒฝํฉ๋๋ค. ์ธ์ด๋ชจ๋ธ ๋ต๊ฒ softํ๊ฒ ์ฒ๋ฆฌํ์๋ ์์ด๋์ด๊ฐ ๊ทธ์ค์์๋ ๋๋ณด์ด๋ค์. ์์กด์ ๋๋ค. | 5 |
| ์คํ๋ฒ ์ค | Embedding ์์ ์ด์ง ๋ถ๋ฅ๊ธฐ๋ฅผ ๋ถ์์ผ๋ก์จ ๊ธฐ์กด์ ๋๋ฆฐ ๊ฒ์ฆ ๋ฌธ์ ๋ฅผ ํ ๋ฒ์ ํด๊ฒฐํ ์ ์๋ค๋๊ฒ Novelty๊ฐ ์๋ ๊ฒ ๊ฐ๋ค. ๋ฐฉ๋ฒ๋ก ์์ฒด๋ ๋จ์ํ์ง๋ง, ๋ช ํํ๊ณ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ ์ผ๋ค๋ ๊ฒ์ด ์๋ฏธ๊ฐ ํฐ ๊ฒ ๊ฐ๋ค. | 4.8 |
TL; DR
Speculative Decoding์์ ๋ฐ์ํ๋ ๋ณ๋ชฉ์ด Target model์ ์ ๋ ฌ(alignment) ๊ธฐ๋ฐ ๊ฒ์ฆ ๋๋ฌธ์์ ๋ฐํ๊ณ , Target model์ ์๋ฒ ๋ฉ์ผ๋ก ํ ํฐ์ ์ ๋ต์ฑ(correctness)์ ํ์ ํ๋ ์๋ก์ด ๊ฒ์ฆ ๋ฐฉ์์ธ Judge Decoding ๋ฐฉ์์ ๋์ ํจ!
- ์ ์
- ๋ฉํ ์ํผ ์ธํ ๋ฆฌ์ ์ค ๋ฉ (์์ 5๋ช ), ์ทจ๋ฆฌํ ์ฐ๋ฐฉ ๊ณต๊ณผ๋ํ๊ต, ์ํธ๋กํฝ, GenAI, MAI
- cited: 28
Preliminary: Speculative Decoding
paper: https://arxiv.org/abs/2211.17192
์ LLM์ ๋๋ฆฐ๊ฐ
- LLM์ ๊ตฌ์กฐ์ ์ธ ํ๊ณ: Auto-regressive ํ ๋์ฝ๋ฉ ๋ฐฉ์
- ๋จ์ด ํ๋๋ฅผ ์ถ๋ ฅํ ๋๋ง๋ค ์์ฒญ๋ ์์ ๋ฉ๋ชจ๋ฆฌ์ ๊ณ์ฐ์ด ํ์
Speculative Decoding์ผ๋ก ํด๊ฒฐํด๋ณด์!
- ๋ชฉํ: LLM์ Inference Speed ํฅ์
- ํต์ฌ ์์ด๋์ด: ์์ฑ์ ๋๋ฆฌ์ง๋ง, ๊ฒ์ฆ์ ๋ณ๋ ฌ๋ก ๋น ๋ฅด๊ฒ ํ ์ ์๋ค
- ๊ตฌ์ฑ์์
- Draft Model (DM): ์์ฃผ ๋น ๋ฅด์ง๋ง ์ฝ๊ฐ ๋ ๋๋ํ ํ์ ์ญํ . ๋๋ต์ ์ธ ์ด์์ ๋น ๋ฅด๊ฒ ์์ฑ
- Target Model (TM): ๋๋ฆฌ์ง๋ง ์์ฃผ ๋๋ํ๊ณ ์ ํํ ์ ์๋ ์ญํ . Draft ๊ฒฐ๊ณผ๋ฅผ ๊ฒํ
- ์๋์๋ฆฌ
- Drafting: DM์ด ๋จผ์ ๋ฌธ์ฅ์ ๋ท๋ถ๋ถ์ ์ถ์ธกํด์ ๋ฏธ๋์ ์ฌ ํ ํฐ ๏ปฟ๊ฐ(์: 4๊ฐ)๋ฅผ ๋น ๋ฅด๊ฒ ์์ฑ
- e.g., "The cat is [sitting on the mat]" (๊ดํธ ์์ด DM์ด ์ถ์ธกํ ๋ถ๋ถ)
- Verification: TM์ด DM์ด ์์ฑํ ๏ปฟ๊ฐ์ ํ ํฐ์ ์
๋ ฅ ๋ฐ๊ณ ํ ๋ฒ์(Parallel) ์ฐ์ฐ์ผ๋ก ๏ปฟ๊ฐ์ ํ ํฐ ํ๋ฅ ์ ํ๋ฒ์ (Foward Pass)๋ก ๊ณ์ฐ
- ์์ฑ(Generation)์ ์์ฐจ์ ์ด์ด์ผ ํ์ง๋ง, ๊ฒ์ฆ(Verification)์ ๋ณ๋ ฌ๋ก ํ ์ ์์
- Teacher forcing ๋ฐฉ์์ผ๋ก ์ ์ฒด ๋ฌธ๋งฅ์ ํ ๋ฒ์ ํ์ธ
- teacher forcing: target word(Ground Truth)๋ฅผ ๋์ฝ๋์ ๋ค์ ์ ๋ ฅ์ผ๋ก ๋ฃ์ด์ฃผ๋ ๊ธฐ๋ฒ
- Accept/Reject: ๋ง์ฝ TM์ด DM์ ์ถ๋ ฅ์ด ์ณ๋ค๊ณ ํ๋จํ๋ฉด, TM์ ์น์ธ(Accept)๋ง ํ๋ฉด ๋จ (์๊ฐ ์ ์ฝ!)
- ์ค๊ฐ์ ํ๋ฆฐ ๋ถ๋ถ์ด ์๋ค๋ฉด(e.g., DM์ "mat"๋ผ๊ณ ์ผ๋๋ฐ TM์ "sofa"๋ผ๊ณ ์๊ฐํจ), ๊ทธ ์ดํ์ ํ ํฐ์ ๋ชจ๋ ๋ฒ๋ฆฌ๊ณ (Reject), ํด๋น ์ง์ ๋ถํฐ ๋ค์ ์์ฑํจ
- Drafting: DM์ด ๋จผ์ ๋ฌธ์ฅ์ ๋ท๋ถ๋ถ์ ์ถ์ธกํด์ ๋ฏธ๋์ ์ฌ ํ ํฐ ๏ปฟ๊ฐ(์: 4๊ฐ)๋ฅผ ๋น ๋ฅด๊ฒ ์์ฑ
Introduction
- Scaling Law์ ํ์ค์ ์ธ ๋ฌธ์
- Meta๋ ์ต๊ทผ 4,050์ต ๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง ์ฌ์ ์ต๋ ๊ท๋ชจ, ์ต๊ณ ์ฑ๋ฅ์ ๋ชจ๋ธ์ธ Llama-3.1-405B๋ฅผ ๊ณต๊ฐํจ
- ์ด๋ฐ ๋ํ ๋ชจ๋ธ๋ค์ ๋ฐฐํฌ์ ๋ง๋ํ ์์์ ์๊ตฌํ๋ฉฐ, ์ถ๋ก ํจ์จ์ฑ์ด ์ค์ํ ๋ฌธ์ ๋ก ๋ ์ค๋ฆ
- ์ด๋ฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Speculative Decoding (SD)์ด ์ ์๋จ
- Speculative Decoding์ ๋ฌธ์ ์
- ๊ธฐ์กด Standard Speculative Decoding์ ๊ฒ์ฆ ๋ฐฉ์:
- ์ด ํ ํฐ์ด ๋ฌธ๋งฅ์ ์ผ๋ก ์ข์๊ฐ?๊ฐ ์๋๋ผ ์ด ํ ํฐ์ด TM์ด ๊ณ ๋ฅผ ํ ํฐ๊ณผ ์ผ๋ง๋ ์ผ์น(alignment)ํ๋๊ฐ?๋ก ํ๋จํจ!
๋ฌธ์ ์ DM์ด ์ถฉ๋ถํ ๋ง๋ ๋ต์ ์์ฑํด๋ TM๊ณผ Align์ด ์๋๋ฉด ์ด๋ฐ์ reject๊ฐ ์์ฃผ ๋ฐ์
- ๊ทธ๋์ ์ค์ ํ๊ฒฝ์์ M (DM์ด ๋ฏธ๋ฆฌ ๋ฝ๋ ํ๋ณด ํ ํฐ ์) ๋ฅผ ํฌ๊ฒ ๋ชป ํค์ (๋ณดํต 5~7 ์ ๋์ ์์ ๊ฐ์ผ๋ก ์ค์ )
- ์ด ํ ํฐ์ด ๋ฌธ๋งฅ์ ์ผ๋ก ์ข์๊ฐ?๊ฐ ์๋๋ผ ์ด ํ ํฐ์ด TM์ด ๊ณ ๋ฅผ ํ ํฐ๊ณผ ์ผ๋ง๋ ์ผ์น(alignment)ํ๋๊ฐ?๋ก ํ๋จํจ!
RQ๊ฒ์ฆ ๊ณผ์ ์ TM๊ณผ์ ์ ๋ ฌ(Alignment)์ด ์๋ ํ ํฐ ์์ฒด์ ํ์ง์ ํ๊ฐํ๋๋ก ๋ฐ๊ฟ ์ ์์๊น?- ํต์ฌ ์์ด๋์ด: LLM-as-a-judge
- LLM judge๋ ์ ์ฐํ ๋ฐฉ์์ผ๋ก ๋ต๋ณ์ ํ๊ฐํจ
- target๊ณผ ์์ ํ ์ ๋ ฌ(Align)๋์ง ์์๋๋ผ๋ ์ฌ๋ฐ๋ฅธ(Correct) ์๋ต์ ๊ธ์ ์ ์ผ๋ก ํ๊ฐ
โ LLM judge๋ฅผ ํตํด alignment๊ณผ correctness๋ฅผ ๊ตฌ๋ถํด๋ณด์!
- ํต์ฌ ์์ด๋์ด: LLM-as-a-judge
- ๊ธฐ์กด Standard Speculative Decoding์ ๊ฒ์ฆ ๋ฐฉ์:
Contribution
- ๊ธฐ์กด SD๊ฐ ๊ณ ํ์ง ํ ํฐ์ ๋ง์ด Rejectํ๋ค๋ ํ๊ณ๋ฅผ ์คํ์ ์ผ๋ก ์ ์ฆ
- LLM-as-a-judge ๊ฐ๋ ์ SD ๊ฒ์ฆ์ ์ ์ฉ
- Llama-8B/70B-Judge๋ก ์ต๋ 9๋ฐฐ์ ์๋ ํฅ์, Llama-405B ์์ค ํ์ง ์ ์ง
Judge Decoding
๊ธฐ์กด Speculative Decoding์ ๊ฒ์ฆ ๋ฐฉ์
์ฉ์ด ์ ๋ฆฌ
- ๏ปฟ : target model
- ๏ปฟ : draft model
- ๏ปฟ: ์ดํ ์งํฉ V
- ๏ปฟ: ํ๋ณด ํ ํฐ์ ๊ฐ์
- ๏ปฟ: ์ค์ Accept๋ ํ ํฐ์ ๊ฐ์
- ๏ปฟ: ํ์ฌ ๋ฌธ๋งฅ(Context)
- ๏ปฟ
- ๋ฌธ๋งฅ s๊ฐ ์ฃผ์ด์ก์ ๋, LLM์ผ๋ก๋ถํฐ m๊ฐ์ ํ ํฐ์ auto-regressive ํ๊ฒ ์ํ๋งํ ๊ฒฐ๊ณผ
- ๏ปฟ: ๏ปฟ ๋ฒ์งธ๋ก ์์ฑ๋ ํ ํฐ
- ๏ปฟ: ํด๋น ์์ ์ softmax ๋ถํฌ
- ๋ฌธ๋งฅ s๊ฐ ์ฃผ์ด์ก์ ๋, LLM์ผ๋ก๋ถํฐ m๊ฐ์ ํ ํฐ์ auto-regressive ํ๊ฒ ์ํ๋งํ ๊ฒฐ๊ณผ
- ๏ปฟ
- ๋ฌธ๋งฅ ๏ปฟ ์ ํ ํฐ ๏ปฟ์ ํ ๋ฒ์ ์ ๋ ฅ
- Target ๋ชจ๋ธ์ ๋ณ๋ ฌ forward pass๋ก ์คํ
- ์ด ๋์ ๊ฐ ์์น์์์ ํ๋ฅ ๋ถํฌ๋ฅผ ๏ปฟ๋ก ์ ์
Draft Model
- DM์ ํ์ฌ ๋ฌธ๋งฅ s๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ณดํต greedy decoding ๋ฐฉ์์ ํตํด M๊ฐ์ ํ๋ณด ํ ํฐ์ ์์ฑ
- ๏ปฟ: draft ๋ชจ๋ธ์ด ์์ฑํ M๊ฐ์ ํ๋ณด ํ ํฐ๊ณผ ๊ฐ ํ ํฐ ํ๋ฅ
- ๏ปฟ: ํ๋ณด ํ ํฐ
- ๏ปฟ: DM์ Softmax ๋ถํฌ
- ๏ปฟ: DM์ด ํ ํฐ ๏ปฟ์ ๋ถ์ฌํ ํ๋ฅ
Target model์ ๊ฒ์ฆ ๋ฐฉ์(Acceptance Rule)
- target ๋ชจ๋ธ์ ์ด ํ๋ณด ํ ํฐ๋ค์ ๋ณ๋ ฌ๋ก ์ฒ๋ฆฌํ์ฌ ๏ปฟ ํ๋ฅ ๋ฒกํฐ๋ฅผ ์์ฑ
- ๊ฐ ํ๋ณด ํ ํฐ ๏ปฟ์ ๋ํ์ฌ ์๋ ์กฐ๊ฑด ๊ฒ์ฌ(์ด์ ํ ํฐ๋ค์ด ๋ชจ๋ Accept๋ ๊ฒฝ์ฐ์๋ง ํ์ฌ ํ ํฐ์ ๊ฒ์ฌ)
- Acceptance Rule: ์ด์ ํ ํฐ๋ค์ด ๋ชจ๋ Accept๋์๊ณ , ๊ท ๋ฑ๋ถํฌ์์ ์ํํ ๏ปฟ๊ฐ ๏ปฟ (TM์ ๏ปฟ ์์ฑ ํ๋ฅ ๋๋๊ธฐ draft ๋ชจ๋ธ์์ ๏ปฟ ์์ฑ ํ๋ฅ ) ๋ณด๋ค ์์ผ๋ฉด ํ ํฐ ๏ปฟ๋ฅผ Accept
- alignment ๊ธฐ๋ฐ ๊ฒ์ฆ
- ๏ปฟ (TM ํ๋ฅ โฅ DM ํ๋ฅ ) โ ๏ปฟ โ ๋ฌด์กฐ๊ฑด accept
- ๏ปฟ(TM ํ๋ฅ < DM ํ๋ฅ ) โ ํ๋ฅ ์ ์ผ๋ก accept
- e.g., ๋น์จ์ด 0.3์ด๋ฉด โ 30% ํ๋ฅ ๋ก Accept/70%๋ Reject
โ TM์ด ํด๋น ํ ํฐ์ DM๋ณด๋ค ๋ ๋์ ํ๋ฅ ๋ก ํ๊ฐํ๋ฉด accept
- alignment ๊ธฐ๋ฐ ๊ฒ์ฆ
โ Standard Speculative Decoding์์๋ ๊ฒ์ฆ ๋ฐฉ์(alignment) ์์ฒด์ ํ๊ณ ๋๋ฌธ์ draft ํ ํฐ ์๋ฅผ ๋๋ ค๋ acceptance๊ฐ ํฌํ๋์ด M์ ํฌ๊ฒ ์ฐ๋ ๊ฒ์ด ์คํ๋ ค ๋นํจ์จ์ ์
๊ธฐ์กด ๊ฒ์ฆ ๋ฐฉ์(Alignment)์ ํ๊ณ
RQ ์ด๋ค ์ข
๋ฅ์ ํ ํฐ๋ค์ด ๊ฑฐ์ ๋๋๊ฐ?
โ GSM8K, MT-Bench, HumanEval ๋ฑ์ ์ฌ๋ฌ ๋ฒค์น๋งํฌ์์ SD์ ๋์์ ๋ถ์
- draft ๋ชจ๋ธ๋ก๋ Llama-8B๋ฅผ, target ๋ชจ๋ธ๋ก๋ Llama-405B ์ฌ์ฉ
- draft ๋ชจ๋ธ๋ ์ฑ๋ฅ ๋์์ง ์๊ธฐ ๋๋ฌธ์ acceptance rate๋ฅผ ๋์ฌ๋ ํ์ง์ด ๋ฐ๋์ ์ ํ๋์ง๋ ์์
- ํนํ ๋น๊ต์ ๋จ์ํ ์ง๋ฌธ์ ๊ฒฝ์ฐ ๋ง์ draft ๋ต๋ณ์ ๊ทธ๋๋ก Accept๋์ด๋ ๊ฐ ์ถ
- GSM8K ๊ฐ์ ๋ฌธ์ ์์ ํนํ ๊ฐํจ
- ๋ฌธ์ : draft ๋ชจ๋ธ์ด ์์ ํ ์ ํํ ํด๋ต์ ์์ฑํ ๊ฒฝ์ฐ๋, ๊ธฐ์กด์ ๊ฒ์ฆ ๋ฐฉ์์ ํ๊ณ๋ก ์ธํด target ๋ชจ๋ธ์ด ๋ง์ ํ ํฐ์ ์์ฃผ ๊ฑฐ์ ํจ!
- ์ด๋ฌํ ๊ฑฐ์ ์ ์ด์ ๋ DM์ ์๋ต์ correctness์ ๋ฌด๊ดํ๊ฒ ๋ฐ์
- Standard SD์ ๋ฌธ์ ์
- Standard SD๋ ํ ํฐ์ Accept/Reject ํ ๋ ์ ๋ ฌ(alignment)๋ง ๋ด
- draft๊ฐ ๋ง๋ ํ ํฐ์ด ์ ๋ต์ด๊ณ ๋ฌธ๋งฅ์ ์ผ๋ก ์ข์๋ฐ๋, TM์ด ์ ํธํ๋ ํํ๊ณผ ๋ค๋ฅด๋ฉด Reject ๋๋ ค๋ฒ๋ฆผ
โ ๋ชฉํ: ํ๋ณด ํ ํฐ์ด ๋ฌธ๋งฅ์ ์ผ๋ก ์ฌ๋ฐ๋ฅธ(Correct) ๊ฒฝ์ฐ Acceptํ๋๋ก TM์ ํ์ตํ์!
์๋ก์ด ๊ฒ์ฆ ๋ฐฉ์์ธ Judge Decoding ์ ์
- Judge Decoding ๋ชฉํ: TM๊ณผ์ alignment ๋ง๊ณ , ํ ํฐ์ด ํ๋ ธ๋์ง/๋ง๋์ง(correctness)๋ฅผ ํ๋จํด์ Accept๋ฅผ ๋๋ฆฌ์!!
- LLM-as-a-Judge์์ ์ฐฉ์
- But, LLM-as-a-Judge์ ๋ฌธ์ ์
- ๊ธด ์์คํ ํ๋กฌํํธ์ CoT ์ถ๋ก ์ด ํ์ํ๋ฐ ์ด๊ฒ๋ค์ด ์ถ๋ก ์๋๋ฅผ ์ ํ์ํด
- LLM judge๋ ์ ์ฒด ๋ต๋ณ์ ํ๊ฐํ๋ ๋ฐฉ์์ ์ฌ์ฉํ๋๋ฐ SD๋ ์งง๊ณ ๋ถ๋ถ์ ์ธ ์ฐ์ ํ ํฐ์ ํ๊ฐํด์ผ ํ๋ค๋ ์ ์์ ์ค์ฉ์ ์ด์ง ์์
โ ๊ธฐ์กด ๋ฐฉ์์ ์ฅ์ ์ ์ด๋ฆฌ๋, LLM-as-a-Judge ์ ๋๋์ ์ด๋ฆฌ๋๋ก ์ค๊ณํด๋ณด์!
์๋ฒ ๋ฉ์ ์ด๋ฏธ ์ค๋ฅ๋ฅผ ์๊ณ ์๋ค..!
- TM์ ์๋ชป๋ ํ ํฐ์ ์ฒ๋ฆฌํ๋ฉด, ๋ง์ง๋ง hidden layer embedding์์ ์ด์ ๊ฐ์ง ์ ํธ๋ฅผ ๋ฐ์์ํด!
- ๋ชจ๋ธ์ ์ดํ ํ ํฐ์์ ํด๋น ์ค๋ฅ๋ฅผ ์์ ํ๋ ค๋ ๋ฐฉํฅ์ผ๋ก ์ถ๋ ฅ์ ์์ฑํ๊ฒ ๋จ
(์ ๊ทธ๋ฆผ์ ์ผ์ชฝ ์ด์์คํดํธ ์ฐธ๊ณ )
๋ชจ๋ธ์ ์ด๋ฏธ ์ด ํ ํฐ์ด ํ๋ ธ๋ค๋ ๊ฑธ ๋ด๋ถ์ ์ผ๋ก ์๊ณ ์๋ค..!
- ๋ชจ๋ธ์ ์ดํ ํ ํฐ์์ ํด๋น ์ค๋ฅ๋ฅผ ์์ ํ๋ ค๋ ๋ฐฉํฅ์ผ๋ก ์ถ๋ ฅ์ ์์ฑํ๊ฒ ๋จ
Judge Head
- TM์ embedding ์์ ๋ถ๋ ์์ ์ด์ง(binary) ๋ถ๋ฅ๊ธฐ
- ๋ชฉ์ : target ์๋ฒ ๋ฉ์ ๋ด๊ธด ์ ํธ๋ฅผ ์ด์ฉํด, ๊ฐ ํ๋ณด ํ ํฐ ๏ปฟ์ ๋ํด ์ด ํ ํฐ์ ๋ฌธ๋งฅ์(correctness) ํต๊ณผ์์ผ๋ ๋๋๊ฐ?๋ฅผ ๋น ๋ฅด๊ฒ ํ์
- ์ ๋ ฅ: ํ ํฐ ์๋ฒ ๋ฉ ๏ปฟ
- ์ถ๋ ฅ: ํด๋น ํ ํฐ์ด ํต๊ณผ(accept) ๊ฐ๋ฅํ ํ๋ฅ (score) โ ๏ปฟ
- ๊ฒฐ์ : ์๊ณ๊ฐ ๏ปฟ๋ฅผ ๋์ผ๋ฉด Accept โ ๏ปฟ
- Linear head (logistic regression)๋ก ๊ตฌํ
- ์ด๋ฏธ target embedding์ โ์ค๋ฅ ์ ํธโ๊ฐ ์กด์ฌ โ ๋ณต์กํ ๋ชจ๋ธ์ด ํ์ ์์
- ๋ณต์กํ MLP/Transformer๋ ์คํ๋ ค ๊ณผ์ ํฉ ์ํ
- ์ฅ์
- ํ๋ผ๋ฏธํฐ๊ฐ ๋งค์ฐ ์๊ณ (์ฝ 16.4k)
- ํ์ต ๋น ๋ฆ (~1.5h)
- TM์ ํ๋ผ๋ฏธํฐ๋ ๋๊ฒฐ(frozen)
- ์ด๋ฏธ target embedding์ โ์ค๋ฅ ์ ํธโ๊ฐ ์กด์ฌ โ ๋ณต์กํ ๋ชจ๋ธ์ด ํ์ ์์
Judge head ํ์ต์ ์ํ ๋ฐ์ดํฐ์ ๊ตฌ์ถ
- ์ด 500๊ฐ์ ๊ณ ํ์ง ์ฌ์ฉ์ ์ง๋ฌธ๊ณผ ๊ทธ์ ๋ํ ์ ๋ต/์ค๋ต ๋ต๋ณ ์์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐ์ดํฐ์
๊ตฌ์ถ
- ์๋กญ๊ฒ ์์ฑํ ์ง๋ฌธ๊ณผ Alpaca์ ARC ๋ฐ์ดํฐ์ ์์ ํํฐ๋งํ ์ง๋ฌธ๋ค ์ฌ์ฉ
- ์ ๋ ฅ ์ง๋ฌธ๋ง ์ฌ์ฉํ๊ณ ๊ทธ์ ๋ํ ์ ๋ต์ ์ฌ์ฉํ์ง ์์์!
- Mistral-Large-2, Llama-8B, Llama-405B์ ํ์ฉํด ์ ๋ต๊ณผ ์ค๋ต์ ๋ค์ํ๊ฒ ์์ฑ
- ์ธ๊ฐ์ด ์ค์ ๋ก ์ค๋ฅ ํ ํฐ์ ์ฃผ์๋ ๋ฌ์
- ํ์ต ๊ณผ์ : ์ ๋ต ๋ต๋ณ์ ํฌํจ๋ ๋ชจ๋ ํ ํฐ์ positive ๋ก ๋ผ๋ฒจ๋ง, ์ค๋ต ๋ต๋ณ์์๋ ์ค๋ฅ๊ฐ ๋ฐ์ํ๊ธฐ ์ ๊น์ง์ ๋ชจ๋ ํ ํฐ์ positive๋ก, ์ค๋ฅ๊ฐ ๋ฐ์ํ ํ ํฐ๋ค์ negative์ผ๋ก ๋ผ๋ฒจ๋งํ์์
- positive๊ฐ negative ๋ณด๋ค 20๋ฐฐ ๋ง์
๋ชจ๋ธ ์ค๊ณ ๋ฐ ํ์ต
- ๊ตฌ์ถํ ๋ฐ์ดํฐ์
์ ๋ฐํ์ผ๋ก target ๋ชจ๋ธ์ ์๋ฒ ๋ฉ ์์ linear head์ธ ๏ปฟ๋ฅผ ํ์ต์ํด
- ๋ฐ์ดํฐ ๋ถ๊ท ํ์ ๋ณด์ ํ๊ธฐ ์ํด ๊ฐ์ค cross-entropy loss ์ฌ์ฉ
- ์๋ชป๋ ํ ํฐ์ ์ ์ก์๋ด๋๋ก negative ์ํ์ ๋ ํฐ ๊ฐ์ค์น๋ฅผ ๋
- ํ์ต ํ๋ผ๋ฏธํฐ: 16.4k
- ํ์ต ๋ฐ์ดํฐ: 30k ํ ํฐ
- ํ์ต ์๊ฐ: 1.5์๊ฐ ์ด๋ด
- target ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ: ๊ณ ์ (frozen)
์ค์ Inference ๊ณผ์ (Judge + Standard SD ๊ฒฐํฉ)
- TM์ด ํ๋ณด ํ ํฐ์ ๋ณด๊ณ ๋ ๊ฐ์ง ์ง๋ฌธ์ ๋์์ ํจ
- alignment: ๋ด๊ฐ ์๋ ๋ฝ์ ํ ํฐ์ด๋ ๋น์ทํด?
- correctness: ๋ฌธ๋งฅ์ ๋ง๋ ํ ํฐ์ธ์ง
โ ๋ ์ค ํ๋๋ผ๋ Yes๋ฉด Accept
==== ์ถ๋ก ๊ณผ์ ====
- DM์ด ํ๋ณด ํ ํฐ M๊ฐ ์์ฑ (๊ธฐ์กด ๋ฐฉ์๊ณผ ๋์ผ)
๏ปฟ
- TM์ด ํ๋ณด ํ ํฐ์ ํ ๋ฒ์ ๊ฒํ (2๊ฐ์ ํ์ ๋์ ์ํ)
- Standard SD ๊ฒ์ฆ ๋ง์คํฌ ๏ปฟ
- ๊ธฐ์ค: alignment (ํ๋ฅ ๋น์จ)
- ๏ปฟ โ Accept
- ๏ปฟโ Reject
- ๊ธฐ์ค: alignment (ํ๋ฅ ๋น์จ)
- Judge ๋ง์คํฌ ๏ปฟ
- ๊ธฐ์ค: correctness (embedding ๊ธฐ๋ฐ)
- ํ ํฐ ์๋ฒ ๋ฉ ๏ปฟ ๋ฅผ ๋ณด๊ณ judge head๊ฐ ์ ์ ๊ณ์ฐ
- ๊ณ์ฐ: ๏ปฟ
- ๊ฒฐ๊ณผ:
- ์ ์ > ฮด ๏ปฟ๏ปฟโ Accept
- ์ ์โค ฮด โ ๏ปฟ โ Reject
- ๊ธฐ์ค: correctness (embedding ๊ธฐ๋ฐ)
- Standard SD ๊ฒ์ฆ ๋ง์คํฌ ๏ปฟ
- ์ต์ข
Accept/Reject๋ OR๋ก ๊ฒฐํฉ
- ๏ปฟ
- Standard SD๊ฐ Accept๋ฉด โ ๋ฌด์กฐ๊ฑด Accept
- Standard SD๊ฐ Reject์ฌ๋, Judge๊ฐ Accept๋ฉด โ Accept
- ๏ปฟ
Experiment
Draft ํ์ง์ด ์์ฃผ ์ข์ ๊ฒฝ์ฐ (GPT-4o)
- ์คํ ๋ชฉ์ : DM์ ์ฌ๋ฐ๋ฅธ ์๋ต(correct)์กฐ์ฐจ ๋์ Reject์ ๊ฒช๋๋ค๋ ์ ์ ์ฆ๋ช ํ๊ธฐ ใ ์ดํด
- ์คํ ์
์
:
- DM: GPT-4o
- TM: Llama-405B
- ๋ฐ์ดํฐ์ : MT-Bench, GSM8K, HumanEval
- ์คํ ๋ฐฉ์: ๋ฐ์ดํฐ์ ์ ์ง๋ฌธ์ ๋ํ ์ ์ฒด ๋ต๋ณ์ ์์ฑํ ๋ค, greedy ๊ฒ์ฆ์์ ์ฒซ ๊ฑฐ์ ์ด ๋ฐ์ํ๊ธฐ ์ ๊น์ง TM์ด ๋ช ๊ฐ์ ํ ํฐ์ Acceptํ๋์ง ์ธก์
- ์คํ ๊ฒฐ๊ณผ:
- Standard SD acceptance: ์ฝ 2๊ฐ ํ ํฐ Accept
- Judge SD acceptance: 20~27 ํ ํฐ accept
- insight
- Draft์ ํ์ง์ด ์ข์์ง๋ค๊ณ acceptance๊ฐ ์ข์์ง์ง ์๋๋ค!
- Judge Decoding ๋ฐฉ์์ ์ฐ์..!
์คํ ์ธํ ์ ๋ฐ๋(draft โ target ๋ชจ๋ธ ๋ฐ๊พธ๊ธฐ)๋ก ํด๋ ๊ฒฐ๊ณผ๋ ๋์ผ!
- draft ๋ชจ๋ธ: Llama-405B
- target ๋ชจ๋ธ: GPT-4o (์ ์คํ๊ณผ ๋ฐ๋ ์ธํ )
- ์คํ ๊ฒฐ๊ณผ
- 8B/405B ์ผ ๋์ acceptance โ 6.6 ํ ํฐ
- 405B/8B ์ผ ๋์ acceptance โ 6.3 ํ ํฐ
โ ๊ฑฐ์ ์ฐจ์ด ์์
Human expert drafting
- ์ธ๊ฐ ์ ๋ฌธ๊ฐ๊ฐ ์์ฑํ Draft ํ ํฐ(์ฑ๋ฅ ์ต์)์ ๊ฒ์ฆํด๋ณด์!
- ๊ณ ํ์ง์ ์ปค๋ฎค๋ํฐ ๊ฒ์ฆ ์์ฝ๋ฌธ์ ํฌํจํ wikipedia-summary ๋ฐ์ดํฐ์ ์ผ๋ถ๋ฅผ ์ฌ์ฉํด, greedy SD ๊ฒ์ฆ์์์ ํ ํฐ Accept๋ฅ ์ ํ๊ฐ
- ์คํ ๊ฒฐ๊ณผ
- Standard SD ์ผ ๋์ acceptance โ 3.1 ํ ํฐ
- Judge SD ์ผ ๋์ acceptance โ 12.3 ํ ํฐ
- โ human์ด ์ง์ draft ํ ํ ํฐ๋ค๋ reject ๋๋ฆฌ๋๊ฑฐ๋ฉด TM์ ๊ธฐ์กด ๊ฒ์ฆ ๋ฐฉ์ ๋ฌธ์ ํ์คํ ์์!
โ target model๊ณผ์ alignment๋ฅผ ํตํด ๊ฒ์ฆํ๋ ๊ธฐ์กด์ ๊ฒ์ฆ ๋ฐฉ์์ด ๋ฌธ์ ์์!!
Judge Decoding Benchmark Results
- ์คํ ์ธํ
- ๋ชจ๋ธ: DM (Llama-8B) + TM (Llama-70B/405B)
- ๋น๊ต ๋ฐฉ๋ฒ(Decoding / Verification)
- Draft only (Llama-8B)
- Target only (Llama-70B, 405B)
- Top-K verification (ํ์ค SD ์ํ ํด๋ฆฌ์คํฑ): Target ๋ชจ๋ธ์ด ๋ณด๊ธฐ์ ํ๋ฅ ์ด ๋์ K๊ฐ ์์ ๋ค๋ฉด ๊ทธ๋ฅ ํต๊ณผ์ํค๋ ๊ฒ์ฆ ๋ฐฉ์
- Judge Decoding
- ํ์ดํผํ๋ผ๋ฏธํฐ
- Top-K: M=10
- Judge Decoding: M=25
- ๋ฒค์น๋งํฌ: GSM8K, HumanEval, ARC-Challenge, MMLU, MT-Bench
- ์คํ ๋ชฉ์
- Top-K (ํด๋ฆฌ์คํฑ ์ํ): ๊ฒ์ฆ ๊ธฐ์ค์ ๋์จํ๊ฒ ํ์ ๋ ์ ํ๋๊ฐ ์ผ๋ง๋ ๋ฌด๋์ง๋์ง ํ์ธ
- Judge Decoding์ด ๋ ๊ธด ํ๋ณด ์ํ์ค(M=25) ๋ฅผ acceptํ๋ฉด์๋ Target ์์ค ์ ํ๋๋ฅผ ์ ์งํ๋์ง ๊ฒ์ฆ
- ์คํ ๊ฒฐ๊ณผ
- Top-K๋ ์ ํ๋ ํฌ๊ฒ ํ๋ฝ
- ์ผ๋ถ ๋ฒค์น๋งํฌ์์ ๋๋๋ฌ์ง ์ฑ๋ฅ ์ ํ ๋ฐ์
- ์์ K ์์๋ง ๋ค๋ฉด acceptํ๋ Top-K ๋ฐฉ์์ ํ๋ฆฐ ํ ํฐ๋ ํต๊ณผ์ํค๊ธฐ ์ฌ์
- ํ์ง-์๋ trade-off๊ฐ ์ฌํ๊ฒ ๋ฐ์(์๋๋ ๋น ๋ฅด์ง๋ง ํ์ง์ ๋์๋ค๋ ๋ป)
- Judge Decoding์ ์ ํ๋ ๊ฑฐ์ ๋ณด์กด
- ๋ชจ๋ ๋ฒค์น๋งํฌ์์ Target-only ๋๋น ๊ฑฐ์ ์ฐจ์ด ์์ด ์ ์ง
- ์ฆ, ์ฝ 20๊ฐ ์์ค ํ ํฐ์ ํ ๋ฒ์ ๋ ๋ง์ด acceptํด๋ ํ์ง์ด ๊นจ์ง์ง ์์์ ๋ณด์ฌ์ค
- 70B/405B ๋ชจ๋์์ ์ผ๊ด๋ ๊ฒฝํฅ
- Target์ด ์ปค์ ธ๋(70Bโ405B) ๊ฒฐ๊ณผ ํจํด์ด ์ ์ง๋จ
โ ๋ฐฉ๋ฒ์ด ํน์ ๋ชจ๋ธ ํฌ๊ธฐ์๋ง ๋ง๋ ํธ๋ฆญ์ด ์๋๋ผ๋ ๊ทผ๊ฑฐ
- Target์ด ์ปค์ ธ๋(70Bโ405B) ๊ฒฐ๊ณผ ํจํด์ด ์ ์ง๋จ
- Top-K๋ ์ ํ๋ ํฌ๊ฒ ํ๋ฝ
- ์ธ์ฌ์ดํธ
- Judge Decoding์ top-K๋ณด๋ค 2.5๋ฐฐ ๋ง์ ์ํ์ ์์ฑํ๋ฉด์๋ ๋ ๋์ ์ ํ๋ ๋ฌ์ฑ
โ ์ํ ๊ฐ์๋ณด๋ค ํ์ง๊ณผ ๊ฒ์ฆ ๋ฉ์ปค๋์ฆ์ด ์ค์- Top-K: ๋จ์ ๊ท์น์ด๋ผ correctness ํ๋จ ์คํจ โ ์ ํ๋ ํ๋ฝ
- Judge Decoding: ์๋ฒ ๋ฉ ๊ธฐ๋ฐ correctness ํ์ โ ๋ ๊ธธ๊ฒ acceptํด๋ ์ ํ๋ ์ ์ง
- Target ๋ชจ๋ธ์ด ์ง์ ์ต์ข ํ ํฐ์ ์ ํํ์ฌ ๋ชจ๋ธ ์ ๋ ฌ(alignment) ๋ฌธ์ ๋ฅผ ๊ทผ๋ณธ์ ์ผ๋ก ํด๊ฒฐ
- Judge Decoding์ top-K๋ณด๋ค 2.5๋ฐฐ ๋ง์ ์ํ์ ์์ฑํ๋ฉด์๋ ๋ ๋์ ์ ํ๋ ๋ฌ์ฑ
๋ถํฌ ์ธ ์ผ๋ฐํ ์คํ
- ํ์ต๋์ง ์์ ์ํฉ์์๋ judge decoding์ด ์ผ๋ง๋ ์ผ๋ฐํ๋๋์ง๋ฅผ ํ๊ฐ
- ์ฝ๋ฉ ์์ ๋ฅผ ์ ๊ฑฐํ ๋ฐ์ดํฐ๋ก judge๋ฅผ ํ์ตํ ๋ค, HumanEval์์ ํ๊ฐ
- ์ฑ๋ฅ์ด 86.6%์์ 80.4%๋ก ํ๋ฝํ๊ธด ํ์ง๋ง, ์ฌ์ ํ DM(71.3%)๋ณด๋ค๋ ํจ์ฌ ๋์









