
SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers

โ†Paper Review

ArXiv: https://arxiv.org/abs/2410.10629
Project Page: https://nvlabs.github.io/Sana/
GitHub Code: https://github.com/NVlabs/Sana
Affiliation: NVIDIA, MIT, Tsinghua University
💡

Key Differentiator

  1. Efficient Linear DiT design

    Introduces ReLU-based linear attention

    Mix-FFN block

  2. Deep Compression Autoencoder

    → enables 32× compression, which also speeds up computation

Blog Image

Song Han, the author behind SVDQuant, has done it again.

NVIDIA has built a diffusion model called SANA, and it is remarkable.

It also looked highly relevant to the on-device 4K diffusion research I have been planning, so I read it.

1. Introduction

Over the past year, diffusion models have made substantial progress in text-to-image research.

However, as shown below, commercial models have grown very large, which drives up both training and inference costs.

Industry models are becoming increasingly large, with parameter counts escalating from PixArt's 0.6B parameters to SD3 at 8B, LiDiT at 10B, Flux at 12B, and Playground v3 at 24B.

Can we build a high-resolution image generator that runs fast not only in the cloud but also on edge devices?

Blog Image

This paper proposes SANA, a pipeline designed to train and synthesize images efficiently and cost-effectively at resolutions from 1024 × 1024 up to 4096 × 4096.

Apart from PixArt-σ (Chen et al., 2024a), prior work has not directly explored 4K image generation, and even PixArt-σ only gets close to 4K (3840 × 2160) and is relatively slow when generating such high-resolution images. To achieve this ambitious goal, the paper proposes several key designs.

2. METHODS

2.1 DEEP COMPRESSION AUTOENCODER

2.1.1 PRELIMINARY

Vanilla diffusion models operate directly in pixel space → both training and inference are far too slow and heavy.

Latent Diffusion Models

Compress the image with an autoencoder first, then run diffusion in the compressed latent space!

→ uses 8× compression

  • Pixel space: X ∈ ℝ^(H × W × 3)
  • Latent space: Z ∈ ℝ^(H/F × W/F × C)
    where F is the compression factor and C is the number of latent channels
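To make the shapes concrete, here is a tiny sketch of the autoencoder's output shape (plain Python; the function name and example numbers are mine, not the paper's):

```python
def latent_shape(h, w, f, c):
    """Autoencoder output: spatial dims shrink by f, channels become c."""
    return (h // f, w // f, c)

# A 1024 x 1024 RGB image under the common F=8, C=4 setting vs SANA's F=32, C=32:
print(latent_shape(1024, 1024, f=8, c=4))    # (128, 128, 4)
print(latent_shape(1024, 1024, f=32, c=32))  # (32, 32, 32)
```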

Diffusion Transformer (DiT)

On top of that, DiT splits the latent features into patches before processing.

With patch size P × P, the final number of tokens is N = (H / (F·P)) × (W / (F·P)).

Existing latent diffusion models (PixArt, SD3, Flux, etc.) typically use the following setup:

  • AE-F8C4P2 or AE-F8C16P2
    • F8: the autoencoder compresses 8×
    • C4 or C16: number of latent channels (4 or 16)
    • P2: latents grouped into 2×2 patches

With only 8× compression, the compute load is still far too high.

So SANA boldly compresses 32× (AE-F32) and skips patchification entirely.
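A quick sanity check of the token arithmetic (plain Python; the function name is mine, numbers are for a 1024 × 1024 input):

```python
def dit_token_count(resolution: int, ae_factor: int, patch: int) -> int:
    """Tokens the DiT actually sees: (H / (F * P)) * (W / (F * P))."""
    side = resolution // (ae_factor * patch)
    return side * side

# Common prior setup, e.g. AE-F8...P2: effective downsampling 8 * 2 = 16
print(dit_token_count(1024, ae_factor=8, patch=2))    # 4096 tokens
# SANA's AE-F32...P1: all compression lives in the autoencoder
print(dit_token_count(1024, ae_factor=32, patch=1))   # 1024 tokens
```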

2.1.2 AUTOENCODER DESIGN PHILOSOPHY

Aspect             | Prior (PixArt, Flux)      | SANA
AE compression     | 8× (F=8)                  | 32× (F=32)
Patchify (P=2)     | Yes (grouped into patches)| No (no patchify)
Final token count  | Reduced, but still large  | Far fewer (16× reduction)
Blog Image

์œ„์˜ ํ‘œ๋ฅผ ๋ณด๋ฉด ์•Œ ์ˆ˜ ์žˆ๋“ฏ์ด 32๋ฐฐ ์••์ถ•ํ•˜๋”๋ผ๋„ ์ ์ˆ˜๊ฐ€ ํฌ๊ฒŒ ๋–จ์–ด์ง€์ง€ ์•Š๋Š” ๋ชจ์Šต์„ ๋ณด์ž„

2.1.3 ABLATION OF AUTOENCODER DESIGNS

  • Where should the compression happen? (AE vs DiT)
Blog Image
Setting      | Description
AE-F8C16P4   | 8× compression + patch size 4
AE-F16C32P2  | 16× compression + patch size 2
AE-F32C32P1  | 32× compression + no patchify (SANA)

The AE-F32C32P1 setting records the best performance (FID, CLIP Score).

Letting the autoencoder handle the compression entirely is best for both quality and training stability.

  • How many latent channels should the autoencoder use?
Blog Image
  • Experiments with C=16, C=32, and C=64
  • C=16 loses information and degrades quality
  • C=64 reconstructs well, but model complexity grows sharply, making it inefficient
  • C=32 strikes the best balance between quality and efficiency

2.2 EFFICIENT LINEAR DIT DESIGN

  • Conventional diffusion transformers (e.g., DiT) use self-attention.
  • Self-attention costs O(N²) compute.
    • N is the number of input tokens
    • As the token count grows, compute explodes
  • Handling 4K images inevitably means many latent tokens.

→ Prior work tackled this by lowering the resolution or reducing the token count.

Introducing ReLU-based Linear Attention

SANA drops the conventional softmax attention and adopts ReLU-based linear attention.

Softmax attention computes every query-key pair, which is where the O(N²) complexity comes from.

Blog Image
  • SANA restructures the computation as follows:
    1. Apply ReLU to the keys: ReLU(K)
    2. Precompute two shared terms:
      • Σⱼ ReLU(Kⱼ)ᵀ Vⱼ (a d×d matrix)
      • Σⱼ ReLU(Kⱼ)ᵀ (a d×1 vector)
    3. Then, for each query, reuse these precomputed shared terms to compute the attention output.
  • Since the shared terms never have to be recomputed per query, the whole attention computation drops to O(N).
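The three steps can be sketched in a few lines of NumPy (single head, no scaling; `eps` guards the division; all names are illustrative, not the paper's code):

```python
import numpy as np

def relu_linear_attention(Q, K, V, eps=1e-6):
    """O(N) attention: shared terms are computed once and reused for every query."""
    Kr = np.maximum(K, 0.0)          # step 1: ReLU on the keys
    S = Kr.T @ V                     # step 2a: sum_j ReLU(K_j)^T V_j  -> (d, d)
    z = Kr.sum(axis=0)               # step 2b: sum_j ReLU(K_j)^T     -> (d,)
    Qr = np.maximum(Q, 0.0)
    return (Qr @ S) / (Qr @ z + eps)[:, None]   # step 3: reuse shared terms

def relu_quadratic_attention(Q, K, V, eps=1e-6):
    """Equivalent O(N^2) form, for checking: forms the full N x N score matrix."""
    A = np.maximum(Q, 0.0) @ np.maximum(K, 0.0).T
    return (A @ V) / (A.sum(axis=1) + eps)[:, None]
```

Both functions return the same output; the linear version simply never materializes the N×N score matrix.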

PixArt๋„ Linear์ด๋žฌ๋Š”๋ฐ ๋‹ค๋ฅธ์ ์€?

PixArt์—์„œ๋Š” Key์™€ Value ํ† ํฐ์„ ์••์ถ•ํ•˜์—ฌ ์—ฐ์‚ฐ๋Ÿ‰์„ ์ค„์—ฌ์„œ Engineering Optimization์œผ๋กœ O(N)๊ณผ ๋น„์Šทํ•˜๊ฒŒ ํ•˜๋Š” ๋ฐฉ์‹, SANA์—์„œ๋Š” ์•„์˜ˆ ์ˆ˜ํ•™์ ์œผ๋กœ ๊ณ„์‚ฐ๋Ÿ‰์ด O(N)

Mix-FFN Block

Blog Image

The FFN (feed-forward network) in a standard Transformer is just two linear layers.

An FFN handles global information well, but it is weak at recovering local detail.

SANA's fix:

  • Insert a 3×3 depthwise convolution between the MLP layers
  • This strengthens learning of local structure
  • As a result, fine details like textures and edges are reconstructed better
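A minimal NumPy sketch of the idea (channels-last, "same" padding; it ignores SANA's exact block details such as gating and normalization, and all weight names are made up for illustration):

```python
import numpy as np

def depthwise_conv3x3(x, w):
    """3x3 depthwise conv: each channel is filtered independently.
    x: (H, W, C), w: (3, 3, C), zero 'same' padding, stride 1."""
    H, W, C = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros_like(x)
    for i in range(3):
        for j in range(3):
            out += xp[i:i + H, j:j + W, :] * w[i, j, :]
    return out

def mix_ffn(x, w_in, w_dw, w_out):
    """FFN with a depthwise conv between the two linear layers for local mixing."""
    h = np.maximum(x @ w_in, 0.0)                    # pointwise expand (per token)
    h = np.maximum(depthwise_conv3x3(h, w_dw), 0.0)  # 3x3 local neighborhood mixing
    return h @ w_out                                 # pointwise project back
```

Because the depthwise conv sees each token's 3×3 spatial neighborhood, it injects exactly the local structure a plain two-layer MLP cannot.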

DiT without Positional Encoding (NoPE)

Standard Transformers use positional encodings so the model can tell input positions apart.

This is needed because attention by itself cannot distinguish input order…

But when the token count is huge, as with 4K latents, even computing and storing the positional encodings gets expensive.

SANA removes positional encodings entirely

  • The 3×3 depthwise convolution added in Mix-FFN lets the model learn local positional relationships.
  • Linear attention naturally captures global relationships.

    → Even without explicit positional information, the model can learn patterns and structure just fine.

As a result, quality holds up while the architecture gets simpler and memory and compute drop.

Blog Image

Triton Acceleration Training/Inference

The paper says details will be added to the appendix, but nothing is there yet.

What GPT told me

  • When implementing linear attention, improving the algorithm alone is not enough.
  • To actually maximize compute efficiency, GPU-kernel-level optimization is needed.
  • SANA uses Triton to hand-optimize the linear attention kernels.
    • Triton is an open-source custom GPU-kernel programming framework (developed by OpenAI) that targets NVIDIA GPUs.
    • It lets you write high-performance kernels with far simpler syntax than CUDA.

Results of the Triton optimization:

  • Reduces matrix-multiplication (GEMM) work and memory accesses.
  • Significantly improves real-world latency and memory-bandwidth consumption.

2.3 TEXT ENCODER DESIGN

Why use a decoder-only LLM instead of T5?

SANA adopts Gemma as its text encoder.

Aspect          | Prior (T5)      | SANA (Gemma-2)
Architecture    | Encoder-decoder | Decoder-only
Reasoning       | Limited         | Very strong (CoT, ICL possible)
Inference speed | Slow (T5-XXL)   | 6× faster (Gemma-2-2B)

Blog Image

→ Despite being much faster, performance on CLIP Score and FID (image-quality metrics) is comparable.

Fixing the problems that come with using a decoder-only LLM as a text encoder

Text embeddings from decoder-only LLMs (Gemma, Qwen, etc.) have far higher variance.

  • The embeddings contain many large values.
  • During cross-attention this leads to numerical blow-ups (NaNs).

Method 1: add RMSNorm

Apply RMSNorm to Gemma-2's output text embeddings.

RMSNorm?

  • Normalizes the variance of the input vector to 1.0
  • Evens out large and small values, preventing numerical blow-up

Method 2: add a learnable scale factor

  • Additionally, multiply the text embeddings by a small learnable scale parameter
  • Its initial value is set very small (e.g., 0.01)
  • As this parameter adjusts to an appropriate magnitude during training, the model converges faster

→ Stable training + faster convergence
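Both fixes fit in a few lines. A NumPy sketch (the 0.01 init follows the post; the class and variable names are mine):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    """Normalize each embedding so its root-mean-square is 1."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

class StabilizedTextEmbedding:
    def __init__(self, init_scale=0.01):
        self.scale = init_scale          # a learnable scalar in the real model

    def __call__(self, emb):
        return self.scale * rms_norm(emb)

# A large-variance "LLM embedding" gets squashed to a tame range before cross-attention:
emb = 50.0 * np.random.default_rng(2).standard_normal((4, 8))
out = StabilizedTextEmbedding()(emb)
```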

Blog Image

Complex Human Instruction Improves Text-Image Alignment

Gemma is a strong LLM, but when the user enters a short or ambiguous prompt (e.g., "a cat"),

the LLM can lose focus and produce an off-target response.

→ We need extra instructions that keep the LLM focused on the prompt.

So what exactly is CHI?

It leverages the LLM's in-context learning ability: before handing over the user prompt,

you also provide a set of complex instructions such as "add detailed descriptions of color, size, and spatial relationships."

Result 1

Blog Image

With CHI applied, text-image alignment improves whether the model is trained from scratch (fresh training) or fine-tuned from an existing checkpoint.

Result 2

Blog Image

์งง์€ ํ”„๋กฌํ”„ํŠธ(์˜ˆ: "a cat")๋ฅผ ์ž…๋ ฅํ–ˆ์„ ๋•Œ,

CHI๊ฐ€ ์—†์œผ๋ฉด, ๋ชจ๋ธ์ด ์—‰๋šฑํ•œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ฑฐ๋‚˜ ํ’ˆ์งˆ์ด ๋ถˆ์•ˆ์ •ํ•ด์ง.

CHI๊ฐ€ ์žˆ์œผ๋ฉด, ๋ชจ๋ธ์ด ํ”„๋กฌํ”„ํŠธ์— ์ •ํ™•ํžˆ ๋งž๋Š” ์•ˆ์ •์ ์ธ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑ

3 EFFICIENT TRAINING/INFERENCE

3.1 DATA CURATION AND BLENDING

1. Multi-Caption Auto-labelling Pipeline

Each image is captioned with 4 VLMs (vision-language models):

  • VILA-3B
  • VILA-13B
  • InternVL2-8B
  • InternVL2-26B

→ More accurate captions (using several models reduces errors versus using one)

→ More diverse phrasing (the same image gets described from multiple perspectives)


2. CLIP-Score-based Caption Sampler

The problem

  • We now have multiple captions per image,

    so which caption do we pick at training time?

  • Picking one at random risks:
    • sampling a low-quality caption
    • which slows training or hurts model quality

The fix

  • Use CLIP scores to sample high-quality captions with higher probability
    • CLIP scores how well an image and a text match.
  • Procedure:
    1. Compute the CLIP score c_i of each caption generated for the image
    2. Give higher-scoring captions a higher chance of being sampled
    3. Sampling probability formula:

      p_i = exp(c_i / τ) / Σⱼ exp(c_j / τ)

where

  • τ is a hyperparameter called the "temperature".
  • Adjusting the temperature controls how aggressively sampling favors the top caption:
    • small τ: almost always selects the highest-scoring caption
    • large τ: captions are selected more evenly

Results

  • According to Table 4:
    • Sampling diverse captions barely changes image quality (FID)
    • But text-image (semantic) alignment during training improves substantially

3. Cascade Resolution Training

The usual recipe

  • Most diffusion models pre-train on 256px images first.
  • The reason is to reduce compute cost.

The problem

  • 256px images lose a lot of detail.
  • So if training starts at low resolution:
    • the model struggles to learn fine structure and texture
    • and ends up learning more slowly once it moves to high resolution.

SANA's recipe

  • Because SANA uses the AE-F32C32P1 design,

    the latent space is tiny → the compute burden is small.

  • So there is no need to start at 256px; training starts directly at 512px.
  • Training schedule:
    • 512px → 1024px → 2K → 4K, raising the resolution progressively (fine-tuning at each stage).

3.2 FLOW-BASED TRAINING / INFERENCE

Flow-based Training

The old way: noise prediction, x_t = α_t · x_0 + σ_t · ε

  • x_0: original image
  • ε: random Gaussian noise
  • α_t, σ_t: diffusion schedule hyperparameters

→ The training objective is to predict the noise ε

Problem: as t grows (near the end of the diffusion process), the noise dominates and the prediction becomes unstable.

The new way: EDM, RF

Predict the data or the velocity (the difference between the noise and the original image) instead of the noise.

EDM: predict x̂_0 (the original data)

RF: predict v = ε − x_0 (the velocity)

In the end, using RF reduces the cumulative error.
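Under the rectified-flow convention x_t = (1 − t)·x_0 + t·ε, the velocity v = ε − x_0 is exactly what moves a sample along the straight path between data and noise. A tiny numerical check (toy 4-dimensional "image", names are mine):

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)          # a "clean image"
eps = rng.standard_normal(4)         # Gaussian noise

def x_at(t):
    return (1.0 - t) * x0 + t * eps  # straight-line interpolation between data and noise

v = eps - x0                         # the velocity target the network regresses

# One Euler step of dx/dt = v from t=0.7 back to t=0.6 lands exactly on the path:
step = x_at(0.7) - 0.1 * v
print(np.allclose(step, x_at(0.6)))  # True: the path is perfectly straight
```

Because the target v is constant along the path, integration errors do not compound the way they do for noise prediction.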

Flow-based Inference

Before: DPM-Solver++

→ required 28-50 steps for high-quality samples

Now: Flow-DPM-Solver

1. Predicts velocity instead of the original data

2. Substitutes the scaling factor α_t with 1 − σ_t

3. Redefines the time steps over the range [0, 1] instead of [1, 1000]

→ Generates high-quality samples in 14-20 steps

5. Experiments

1. Model Details

Sana-0.6B

  • ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: 590M
  • ๊ตฌ์กฐ: DiT-XL ๋ฐ PixArt-ฮฃ์™€ ๊ฑฐ์˜ ๋™์ผํ•œ ๋ ˆ์ด์–ด ์ˆ˜์™€ ์ฑ„๋„ ์ˆ˜ ์‚ฌ์šฉ
  • ๋ชฉ์ : ์†Œํ˜• ๋ชจ๋ธ๋กœ๋„ ํšจ์œจ์„ฑ๊ณผ ํ’ˆ์งˆ์„ ๋™์‹œ์— ํ™•๋ณด

Sana-1.6B

  • ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜: 1.6B
  • ๊ตฌ์กฐ:
    • 20๊ฐœ์˜ Transformer ๋ ˆ์ด์–ด
    • ๊ฐ ๋ ˆ์ด์–ด๋งˆ๋‹ค 2240๊ฐœ์˜ ์ฑ„๋„
    • FFN ๋‚ด๋ถ€ ์ฑ„๋„ ์ˆ˜๋Š” 5600
  • ์ด ๊ตฌ์„ฑ์€ ํ•™์Šต ํšจ์œจ์„ฑ๊ณผ ์ƒ์„ฑ ํ’ˆ์งˆ ์‚ฌ์ด์˜ ๊ท ํ˜•์„ ๊ณ ๋ คํ•œ ๊ฒƒ์ž„

2. Evaluation Details

SANA evaluates performance with 5 standard metrics.

Metric                           | Description
FID (Fréchet Inception Distance) | Measures image quality numerically; lower is better
CLIP Score                       | Measures semantic image-text alignment; higher is better
GenEval (Ghosh et al., 2024)     | Text-image alignment benchmark; 553 prompts
DPG-Bench (Hu et al., 2024)      | Fine-grained text-image alignment test; 1,065 prompts
ImageReward (Xu et al., 2024)    | Score reflecting human preference; measured on 100 prompts

3. Evaluation Dataset

  • MJHQ-30K (Li et al., 2024a)
    • 30,000 high-quality images collected from Midjourney
    • Used to measure FID and CLIP Score

Blog Image

Blog Image

Limitation

Blog Image

The code is designed end-to-end for NVIDIA chips, so it will not run on other GPU hardware, let alone on mobile devices.

Given that this is an NVIDIA paper, it seems deliberately tied to NVIDIA chips, partly as promotion for the Blackwell chip.
์•„๋ฌด๋ž˜๋„ NVIDIA์—์„œ ๋‚ธ ๋…ผ๋ฌธ์ด๊ธฐ ๋•Œ๋ฌธ์— Blackwell chip ํ™๋ณด ๊ฒธ NVIDIA chip์—์„œ๋งŒ ๊ฐ€๋Šฅํ•˜๋„๋ก ํ•œ๋“ฏ.