19 March 2026

Diffusion Alignment as Variational Expectation-Maximization

๐Ÿ’กDiffusion ๋ชจ๋ธ์„ ๋ชฉ์  ํ•จ์ˆ˜์— ๋งž๊ฒŒ diffusion alignmentํ•  ๋•Œ ๋ฐœ์ƒํ•˜๋Š” reward over-optimization ๊ณผ mode collapse ๋ฌธ์ œ๋ฅผ EM์•Œ๊ณ ๋ฆฌ์ฆ˜ (E๋‹จ๊ณ„(test time search) โ†’ M๋‹จ๊ณ„(forward-KL)์˜ ๋ฐ˜๋ณต)์œผ๋กœ ํ•ด๊ฒฐํ•˜์ž!

์ด๋‘ํ˜ธ
์ด๋‘ํ˜ธ


Review

Nickname · Strength & Weakness & Suggestions · Score (out of 5)

  • ์ฝ”์Šคํ”ผ — 3.9
    • Strength: Diffusion optimization can be tackled without changing the model weights, so the approach is efficient.
    • Weakness: Since this is a diffusion model, won't repeatedly iterating the EM steps become computationally heavy?
    • Suggestion: Tuning the number of timesteps or EM iterations to improve performance seems necessary.
  • ์–ผ๋ผ — 3.8
    • Strength: Considers reward and diversity together and achieves SOTA performance in most of the experiments.
    • Weakness: Seems to depend heavily on the quality of the test-time search, and to require a lot of computation.
    • Suggestion: Follow-up work that minimizes search while maximizing performance would be welcome.
  • ๋น„์š”๋œจ — 4.1
    • Strength: Balances well, via EM, the collapse that can occur when optimizing reward alone. This is the first time I've read about EM applied to diffusion, so I wasn't familiar with it, but the EM training loop seems to fit diffusion's sampling/training structure nicely.
    • Weakness: The exploration cost in the E-step seems very high.
    • Suggestion: Instead of drawing M samples at every timestep, couldn't the number of samples vary per reward-sensitive interval?
  • ์นซ์†” — 3.6
    • Strength: Applies the EM algorithm to diffusion alignment in a new way and also combines reverse/forward KL.
    • Weakness: Time efficiency, due to test-time search.
    • Suggestion: Add a constraint that guarantees the test-time search quality stays above some threshold.
  • ์„คํ–ฅ๋”ธ๊ธฐ — 3.7
    • Strength: Proposes a way to improve performance while keeping the diversity that diffusion offers. Recent RL methods seem to be evolving to handle multiple objectives at once, and I think this methodology fits that trend.
    • Weakness: Doesn't feel novel — isn't it just a combination of two existing methods?
    • Suggestion: Calibrating the reward model accurately might be a better fix for over-optimization. Rather than targeting the diffusion model itself, how about adjusting the reward model or another model?
  • ๋‚˜์Šค๋‹ฅ — 3.5
    • Strength: Rich mathematical soundness! Diffusion + RL seems like a fresh combination.
    • Weakness: A comparison of how lightweight and fast it is against other methods would have been nice!
    • Suggestion: A user study seems necessary to demonstrate the alignment gains claimed in the motivation!
    • Aside: Every time, I get the feeling that diffusion and NLP don't fit together well.
  • ์ปคํ”ผ — 3.8
    • Strength: For reverse-KL, the cause of the existing mode collapse and compute-cost problems in diffusion, it's novel to change the setting by reusing test-time search as-is. Also, sampling via test-time search avoids using reward gradients, which makes the method more general and meaningful.
    • Weakness: Test-time search is still used, so there probably isn't a big improvement in search cost. Also, if the quality of the drawn samples is uniform, wouldn't the mode-collapse problem of reverse-KL likewise see little improvement?
    • Suggestion: Additional work on test-time search and on securing sample quality would have made the argument more solid.
  • AI — 3.9
    • Strength: Resolves the core trade-off of diffusion alignment — optimizing reward while preserving diversity; the DNA-domain experiment is also compelling.
    • Weakness: Even if diversity is preserved, the model's underlying bias seems fundamentally hard to fix.
    • Suggestion: The reward is always trusted absolutely here — could uncertainty be taken into account?
  • 404 — 4.2
    • Strength: The very attempt to graft diffusion onto existing preference optimization has high novelty and good soundness! The way diffusion is currently used in vision aligns intuitively with the authors' motivation, which made it an interesting read.
    • Weakness: Everything other than diversity — e.g., time cost and bias are not considered (+ no architecture figure, which really hurts readability).
    • Suggestion: Apply it to NLP downstream tasks.
  • ๊ตญ๋ฐฅ — 3.8
    • Strength: Switching the mode-seeking problem to forward-KL is a simple but effective idea. Validated in both continuous and discrete domains at once.
    • Weakness: The E-step test-time search cost recurs at every iteration, and there is no comparison of actual training time against existing methods.
    • Suggestion: Compare the number of E-step searches against performance.

Summary

  • Research team: KAIST, MongooseAI, Mila, University of Edinburgh, Omelet

Background & Motivation

Diffusion ๋ชจ๋ธ์€ ์ด๋ฏธ์ง€, ๋กœ๋ณดํ‹ฑ์Šค, ์ƒ๋ฌผํ•™ ๋“ฑ ๋‹ค์–‘ํ•œ ๋„๋ฉ”์ธ์—์„œ high-fidelity ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ๋›ฐ์–ด๋‚จ.

โ†’ but ์‹ค์ œ ์‘์šฉ์—์„œ๋Š” ๋‹จ์ˆœํžˆ ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ ์™ธ์—๋„, ์™ธ๋ถ€ ๊ธฐ์ค€(์ด๋ฏธ์ง€์˜ ๋ฏธ์  ํ’ˆ์งˆ, DNA enhancersํ™œ์„ฑ๋„ ๋“ฑ)์— ๋งž์ถ˜ ์ƒ˜ํ”Œ์ด ํ•„์š”ํ•จ

โ†’ ์ด๋ฅผ ์œ„ํ•ด diffusion alignment(์‚ฌ์ „ํ•™์Šต๋œ diffusion ๋ชจ๋ธ์„ downstream objective์— ๋งž๊ฒŒ fine tuning)์ด ํ•„์š”

Diffusion Alignment์˜ ๊ธฐ์กด ์ ‘๊ทผ๋ฒ•

  • RL ๊ธฐ๋ฐ˜ fine-tuning (DDPO, DPOK)
    • on-policy ๋ฐ์ดํ„ฐ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ reverse-KL objective๋ฅผ ํ†ตํ•ด ๋””ํ“จ์ „ ๋ชจ๋ธ์„ ํŒŒ์ธํŠœ๋‹
    • Denoising ๊ณผ์ •์„ sequential decision making์œผ๋กœ ๋ณด๊ณ , black box reward function์„ ์ตœ๋Œ€ํ™”ํ•˜๋„๋ก policy ์ตœ์ ํ™”
    • Reverse-KL objective ์‚ฌ์šฉ โ†’ mode-seeking ํ–‰๋™ โ†’ mode collapse ๋ฐœ์ƒ
  • Direct backpropagation(DRaFT, AlignProp)
    • ๋ฏธ๋ถ„ ๊ฐ€๋Šฅํ•œ reward functio์œผ๋กœ๋ถ€ํ„ฐ gradient๋ฅผ denoising chain์„ ํ†ตํ•ด ์ง์ ‘ ์—ญ์ „ํŒŒ
    • ์ƒ˜ํ”Œ ํšจ์œจ์„ฑ์€ ๋†’์ง€๋งŒ, reward model์˜ gradient ๊ฐ’์— ์˜์กด โ†’ reward ๋ชจ๋ธ ์ž์ฒด๊ฐ€ ์™„์ „ํ•˜์ง€ ์•Š์Œ โ†’ reward over-optimization ๋ฐœ์ƒ

โ†’ ๊ธฐ์กด์˜ ๋‘ ๋ฐฉ๋ฒ•์—์„œ mode collapse(์ƒ์„ฑ๋œ ์ƒ˜ํ”Œ์ด ํ•˜๋‚˜์˜ mode๋กœ๋งŒ ์ƒ์„ฑ์ด ๋˜์–ด์„œ ๋‹ค์–‘์„ฑ์ด ๋–จ์–ด์ง), reward over optimization(reward ์ ์ˆ˜๋Š” ๋†’์ง€๋งŒ ์‹ค์ œ ํ’ˆ์งˆ์€ ์˜คํžˆ๋ ค ๋–จ์–ด์ง) ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒ

Fine-tuning approaches

  • Liu et al. (2024) and Domingo-Enrich et al. (2025) propose fine-tuning continuous diffusion models with gradient signals from the reward function while keeping them close to the pretrained distribution

Test-time search methods

  • Invest extra computation at inference time without changing model weights
  • Two families
    • Guidance-based: at each denoising step, nudge the sample in a direction that raises the reward; but the signal is an approximation, so under-optimization occurs
    • Search-based: generate several candidates at each step and keep the best one; computationally expensive
  • Existing test-time search is therefore compute-heavy and prone to under-optimization

+์—ฐ์†๊ณผ ์ด์‚ฐ diffusion ๋ชจ๋‘์— ์ ์šฉ ๊ฐ€๋Šฅํ•œ ํ”„๋ ˆ์ž„์›Œํฌ๋Š” ์—†์Œ

  • ๊ธฐ์กด ๋ฐฉ๋ฒ•๋“ค์€ ๋ฏธ๋ถ„ ๊ฐ€๋Šฅํ•œ reward์™€ ์—ฐ์† diffusion์— ํ•œ์ •๋จ

→ We need a diffusion fine-tuning framework that maximizes reward while improving diversity and naturalness, and that applies to both continuous and discrete domains.

→ DAV unifies the two paradigms: collect samples via test-time search → distill the collected samples into the diffusion model.

Contributions

DAV (Diffusion Alignment as Variational EM) framework

  • Casts diffusion alignment as a variational EM algorithm
  • Alternates an E-step (exploration) and an M-step (distillation) to achieve reward optimization and diversity preservation at the same time

Posterior inference via test-time search in the E-step

  • Soft-Q-function-based test-time search draws diverse, high-reward samples from the variational posterior
  • Overcomes a weakness of prior EM-based RL approaches (reweighting on-policy samples to approximate the posterior misspecifies the posterior)

Model update via forward-KL distillation in the M-step

  • Uses forward-KL (mode-covering) instead of reverse-KL (mode-seeking) to preserve diversity
  • The model is trained to cover all of the diverse modes discovered in the E-step

Applicable to both continuous and discrete diffusion

  • Experimentally validated on text-to-image generation and DNA sequence design
  • No differentiability assumption on the reward function (no gradients are used) → a more general framework (applies to both continuous and discrete settings)

→ The significance: the forward-KL view fixes the two failure modes of prior diffusion-alignment approaches, while test-time search is used only during training — adding no extra compute overhead at inference.

Method

์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ: E-step๊ณผ M-step์„ ๋ฐ˜๋ณตํ•˜๋ฉฐ, E-step์—์„œ ๋ฐœ๊ฒฌํ•œ ๋†’์€ ๋ณด์ƒ ์ƒ˜ํ”Œ์„ M-step์—์„œ ๋ชจ๋ธ์— distillation.

E-step: ํƒ์ƒ‰. test time search๋กœ ๋ณด์ƒ์ด ๋†’๊ณ  ๋‹ค์–‘ํ•œ ์ƒ˜ํ”Œ์„ ๋ฐœ๊ฒฌ
โ†’
M-step: ์ฆ๋ฅ˜. ๋ฐœ๊ฒฌํ•œ ์ƒ˜ํ”Œ๋“ค์„ forward-KL๋กœ ๋ชจ๋ธ์— distillation
โ†’
๋ฐ˜๋ณต
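The E-step/M-step loop can be sketched on a toy 1-D problem. Everything here is an illustrative stand-in, not the paper's setup: a single Gaussian plays the role of the diffusion model, a hand-made bimodal function plays the reward, and alpha is an assumed temperature.

```python
import numpy as np

rng = np.random.default_rng(0)

def reward(x):
    # Hypothetical bimodal reward with peaks near -2 and +2.
    return -np.minimum((x - 2.0) ** 2, (x + 2.0) ** 2)

# Toy stand-in for the diffusion model: a single Gaussian.
mu, sigma, alpha = 0.0, 3.0, 1.0

for _ in range(3):
    # E-step (exploration): sample from the current model and weight each
    # sample toward the tilted target p_model(x) * exp(reward(x) / alpha).
    xs = rng.normal(mu, sigma, size=4000)
    w = np.exp((reward(xs) - reward(xs).max()) / alpha)
    w /= w.sum()
    # M-step (distillation): forward-KL projection of the weighted samples
    # back onto the model family = weighted maximum-likelihood refit.
    mu = float(np.sum(w * xs))
    sigma = float(np.sqrt(np.sum(w * (xs - mu) ** 2)))
```

Because the M-step is a forward-KL (mode-covering) projection, the fitted Gaussian ends up straddling both reward peaks instead of collapsing onto one of them.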


Variational EM formulation

  • Introduce an optimality variable O
    • O = 1 means a good outcome; the goal is to maximize its probability
    • The denoising trajectory τ plays the role of the latent variable
    • The higher the reward of τ, the higher the probability that O = 1
  • Direct optimization is hard, so introduce a variational distribution η(τ) and maximize the ELBO
  • Add a discount factor γ to damp the influence of the early, high-noise steps
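In symbols, the pieces above fit together roughly as follows. This is a hedged reconstruction from the summary — the temperature α and the trajectory notation are assumptions and may differ from the paper's exact form:

```latex
p(O = 1 \mid \tau) \propto \exp\!\big(r(x_0)/\alpha\big), \qquad
\log p_\theta(O = 1)
  \;\ge\; \mathbb{E}_{\eta(\tau)}\!\Big[\tfrac{1}{\alpha}\, r(x_0)
      + \log p_\theta(\tau) - \log \eta(\tau)\Big]
  \;=\; \mathrm{ELBO}(\eta, \theta)
```

The E-step maximizes the ELBO over η at fixed θ, giving η*(τ) ∝ p_θ(τ) exp(r(x_0)/α); the M-step maximizes it over θ at fixed η, which is exactly minimizing the forward KL(η* ‖ p_θ).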

E-step: posterior inference via test-time search

  • The optimal variational distribution η* is a Boltzmann distribution proportional to (current model probability) × exp(soft Q-function)
  • The goal is to sample diverse, high-reward trajectories from η*_k
  • Direct sampling is intractable, so a two-stage approximation is used:
    1. Build a proposal distribution with gradient guidance (reward gradients steer sampling toward promising directions)
    2. Correct with importance sampling (reweight by the mismatch against the true optimal distribution)
  • The design is modular: the search algorithm can be swapped out when better ones appear
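The two-stage approximation can be made concrete on a 1-D toy. All names and numbers are illustrative: a unit-variance Gaussian stands in for the model at one denoising step, and the soft Q-function is just a quadratic reward.

```python
import numpy as np

rng = np.random.default_rng(1)

def soft_q(x):
    # Hypothetical soft Q-function (here simply a quadratic reward).
    return -(x - 1.0) ** 2

def model_logpdf(x, mean):
    # Unnormalized log-density of a unit-variance Gaussian "model".
    return -0.5 * (x - mean) ** 2

def e_step_sample(mean=0.0, shift=0.5, n=512, alpha=0.5):
    """One E-step approximation: (1) a gradient-guided proposal (shifted in
    the reward-ascending direction), (2) importance weights against the
    Boltzmann target p_model(x) * exp(Q(x)/alpha), (3) resampling."""
    # 1. proposal: model mean shifted toward higher reward
    xs = rng.normal(mean + shift, 1.0, size=n)
    # 2. self-normalized importance weights: target / proposal
    log_ratio = (model_logpdf(xs, mean) + soft_q(xs) / alpha
                 - model_logpdf(xs, mean + shift))
    w = np.exp(log_ratio - log_ratio.max())
    w /= w.sum()
    # 3. resample according to the corrected weights
    return rng.choice(xs, size=n, p=w)

samples = e_step_sample()
```

The guided proposal alone would be biased toward its own shift; the importance weights correct it back to the Boltzmann target, which in this toy is a Gaussian centered near 0.8 with a standard deviation of about 0.45.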

M-step: updating the diffusion model with forward-KL

  • Train by maximizing the log-likelihood of the trajectories discovered in the E-step = forward-KL minimization
  • Forward-KL → mode-covering → the model learns to cover all the diverse modes
    • Conventional RL uses reverse-KL → mode-seeking → concentrates on a single mode
  • DAV-KL variant: adds a KL penalty against the pretrained model to preserve even more diversity
  • DAV vs. DAV-KL?
    • DAV: pushes the reward more aggressively → higher score, but relatively lower diversity
    • DAV-KL: somewhat lower reward, but adds a constraint weight λ to better preserve the pretrained model's characteristics

Forward-KL์€ mode-covering objective ํ•จ โ†’ E-step์—์„œ ๋ฐœ๊ฒฌํ•œ ๋ชจ๋“  ๋‹ค์–‘ํ•œ mode๋ฅผ ์ปค๋ฒ„ํ•˜๋„๋ก ๋ชจ๋ธ ์—…๋ฐ์ดํŠธ

Experiments

Experiment 1: Text-to-image generation (continuous diffusion)

  • Setup
    • Model: Stable Diffusion v1.5 (fine-tuned with LoRA rank 4)
    • Reward: LAION aesthetic score (aesthetic quality, differentiable)
    • Prompts: 40 animal prompts
    • Metrics:
      • Besides the reward (aesthetic quality), evaluate the two main failure modes of prior methods: over-optimization and diversity collapse (mode collapse)
        • Aesthetic Score (LAION aesthetic score): aesthetic quality (differentiable) → the reward, i.e., the score the model is trained to optimize
        • ImageReward: human preference score → a held-out metric not used in training (to detect over-optimization)
        • CLIP Score: prompt-image alignment (over-optimization detection)
        • LPIPS-A/P: sample diversity
    • Baselines: DDPO (RL-based fine-tuning), DRaFT (direct backpropagation), TDPO (gradient-free RL), DAS (test-time search)
  • Results
    • DAV's reward (8.04) is far above DDPO (6.83) and DRaFT (7.22) while its ImageReward (0.95) stays at the pretrained level
      • → it produces genuinely good images without gaming the reward (no over-optimization)
    • For prior methods, pushing the reward up degrades ImageReward, CLIP score, and diversity (over-optimization)
      • → the existing methods (DDPO, DRaFT) over-optimize the Aesthetic Score metric so hard that the images get worse by the held-out measure (ImageReward)
    • DAV-KL is best on diversity and ImageReward
    • DAV Posterior (with extra test-time search) reaches the highest aesthetic score, 9.18
    • What is DAV Posterior?
      • DAS: the existing test-time search, searching from the untrained base model p0
      • DAV: sampling from the model alone, after training it with the paper's EM algorithm
      • DAV Posterior: after training, additionally running test-time search at every inference; higher performance, but somewhat slower

Experiment 2: DNA sequence design (discrete diffusion)

  • Setup
    • Model: Masked Diffusion Language Model (MDLM)
    • Data: 700K DNA enhancer sequences (200 bp)
    • Reward: predicted enhancer activity from an Enformer model
    • Metrics:
      • Pred-Activity: predicted activity (the reward)
      • ATAC-Acc: chromatin accessibility (biological validity; over-optimization detection)
      • 3-mer Corr: k-mer frequency correlation (naturalness)
      • Levenshtein Diversity: edit distance between sequences (diversity)
    • Baselines: DRAKES (direct backpropagation), DDPO/VIDD (RL-based)
  • Results
    • DAV is balanced across reward, diversity, and naturalness
    • DDPO/VIDD reach high rewards but low diversity and validity (over-optimization)
    • DAV Posterior is best on both reward (9.24) and validity (0.920)
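As a concrete reading of the naturalness metric, a 3-mer-Corr-style score can be computed as below. This is a plausible sketch, not the paper's exact definition: the k-mer frequency profiles of a generated and a reference sequence set are compared by Pearson correlation.

```python
import itertools

import numpy as np

def kmer_counts(seqs, k=3):
    """Normalized k-mer frequency profile over a list of DNA sequences."""
    kmers = ["".join(p) for p in itertools.product("ACGT", repeat=k)]
    index = {km: i for i, km in enumerate(kmers)}
    counts = np.zeros(len(kmers))
    for s in seqs:
        for i in range(len(s) - k + 1):
            counts[index[s[i:i + k]]] += 1
    return counts / counts.sum()

def kmer_corr(generated, reference, k=3):
    """3-mer-Corr-style naturalness proxy: Pearson correlation between the
    k-mer frequency profiles of generated and reference sequences."""
    a, b = kmer_counts(generated, k), kmer_counts(reference, k)
    return float(np.corrcoef(a, b)[0, 1])
```

Generated sequences with natural-looking local composition score near 1, while degenerate output (e.g., all-A sequences, the kind of thing reward over-optimization can produce) scores far lower against the same reference set.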

Categories

DIFFUSION RL research