| ArXiv | https://arxiv.org/abs/2410.10629 |
|---|---|
| Project Page | https://nvlabs.github.io/Sana/ |
| Github Code | https://github.com/NVlabs/Sana |
| Affiliation | NVIDIA, MIT, Tsinghua University |
Key Differentiator
- Efficient Linear DiT design
ReLU ๊ธฐ๋ฐ Linear Attention ๋์
Mix-FFN Block
- Deep Compression Autoencoder
โ ์ด๋ก ์ธํ 32๋ฐฐ ์์ถ ๊ฐ๋ฅ์ผ๋ก ์ฐ์ฐ๋ ๋นจ๋ผ์ง

์ด๋ฒ์๋ SVDQuant์ ์ ์์ธ Song Han์ด ๋ ์ผ์ ๋๋ค.
SANA ๋ผ๋ Diffusion ๋ชจ๋ธ์ NVIDIA์์ ์ ์ํ๋๋ฐ, ์ญ๋๊ธ์ด๋ค.
๋ด๊ฐ ํ๋ ค๋ On-device 4K Diffusion ์ฐ๊ตฌ์๋ ํฌ๊ฒ ๋์๋ ๊ฒ ๊ฐ์์ ์ฝ์ด๋ณด์๋ค.
1. Introduction
์ง๋ 1๋ ๋์ Diffusion ๋ชจ๋ธ์ text-to-image ์ฐ๊ตฌ์์ ์๋นํ ์ง์ ์ ๋ณด์.
ํ์ง๋ง, ์๋์ ๊ฐ์ด ์์ ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๊ฐ ๋งค์ฐ ์ปค์ง โ ๋์ ํ์ต ๋ฐ ์ถ๋ก ๋น์ฉ์ ์ด๋ํ์ฌ ๋น์ฉ์ด ๋ง์ด ๋ค์.
Industry models are becoming increasingly large, with parameter counts escalating from PixArtโs 0.6B parameters to SD3 at 8B, LiDiT at 10B, Flux at 12B, and Playground v3 at 24B.
cloud ๋ฟ๋ง ์๋๋ผ edge devices์์๋ ๋น ๋ฅด๊ฒ ์คํ๋๋ ๊ณ ํด์๋ image generator๋ฅผ ๊ฐ๋ฐํ ์ ์์๊น?

์ด ๋ ผ๋ฌธ์ 1024 ร 1024 ~ 4096 ร 4096 ๋ฒ์์ ํด์๋์์ ์ด๋ฏธ์ง๋ฅผ ํจ์จ์ ์ด๊ณ ๋น์ฉ ํจ์จ์ ์ผ๋ก ํ๋ จํ๊ณ ํฉ์ฑํ๋๋ก ์ค๊ณ๋ ํ์ดํ ๋ผ์ธ ์ธ SANA๋ฅผ ์ ์
Pixart-ฯ (Chen et al., 2024a)๋ฅผ ์ ์ธํ๊ณ ๋ 4K ํด์๋ ์ด๋ฏธ์ง ์์ฑ์ ์ง์ ํ์ํ์ง ๋ชปํ์ต๋๋ค. ๊ทธ๋ฌ๋ Pixart-ฯ๋ 4K ํด์๋์ ๊ฐ๊น์ด ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ ๊ฒ์ผ๋ก ์ ํ๋๋ฉฐ (3840 ร 2160) ์ด๋ฌํ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑ ํ ๋ ๋น๊ต์ ๋๋ฆฝ๋๋ค. ์ด ์ผ์ฌ ์ฐฌ ๋ชฉํ๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด ๋ช ๊ฐ์ง ํต์ฌ ๋์์ธ์ ์ ์ํฉ๋๋ค.
2. METHODS
2.1 DEEP COMPRESSION AUTOENCODER
2.1.1 PRELIMINARY
์๋ diffusion ๋ชจ๋ธ์ ์ด๋ฏธ์ง ํฝ์ ๊ณต๊ฐ (pixel space) ์์์ ์ง์ ์๋ โ ํ๋ จ, ์ถ๋ก ๋๋ค ๋๋ฌด ๋๋ฆฌ๊ณ ๋ฌด๊ฑฐ์
Latent Diffusion Models
Autoencoder๋ก ์ด๋ฏธ์ง ์์ถ ํ ์์ถ๋ latent ๊ณต๊ฐ ์์์ diffusion์ ๋๋ฆฌ์!
โ 8๋ฐฐ ์์ถ ์ฌ์ฉ
- Pixel space: ๏ปฟ
- Latent space: ๏ปฟ
์ฌ๊ธฐ์ ๏ปฟ๋ latent ์ฑ๋ ์
Diffusion Transformer (DiT)
์ถ๊ฐ๋ก latent feature๋ฅผ Patch ๋จ์๋ก ๋ ๋๋ ์ ์ฒ๋ฆฌ
ํจ์นํฌ๊ธฐ๊ฐ PxP ๋ผ๋ฉด ์ต์ข ์ ์ผ๋ก ๋ค๋ฃจ๋ ํ ํฐ ๊ฐ์๋
๊ธฐ์กด latent diffusion ๋ชจ๋ธ๋ค(PixArt, SD3, Flux ๋ฑ)์ ๋ณดํต ๋ค์ ์ธํ ์ ์
- AE-F8C4P2 ๋๋ AE-F8C16P2
- F8: Autoencoder๊ฐ 8๋ฐฐ ์์ถ
- C4 ๋๋ C16: latent ์ฑ๋ ์ (4๊ฐ๋ 16๊ฐ)
- P2: Patch ํฌ๊ธฐ 2ร2๋ก ๋ฌถ๊ธฐ
๊ธฐ์กด์ฒ๋ผ 8๋ฐฐ ์์ถ๋ง ํ๋ฉด ๊ณ์ฐ๋์ด ์ฌ์ ํ ๋๋ฌด ๋ง์
๊ทธ๋์ SANA๋ ๊ณผ๊ฐํ๊ฒ 32๋ฐฐ ์์ถ(AE-F32)ํ๊ณ , ํจ์น๋ก๋ ๋ฌถ์ง ์์.
2.1.2 AUTOENCODER DESIGN PHILOSOPHY
| ๊ตฌ๋ถ | ๊ธฐ์กด (PixArt, Flux) | SANA |
| AE ์์ถ๋น | 8๋ฐฐ (F=8) | 32๋ฐฐ (F=32) |
| Patchify (P=2) | O (ํจ์น๋ก ๋ฌถ์) | โ๏ธ (ํจ์น ์ ๋ฌถ์) |
| ์ต์ข Token ์ | ์ค์์ง๋ง ์์ง ๋ง์ | ํจ์ฌ ์ ์ (16๋ฐฐ ๊ฐ์) |

์์ ํ๋ฅผ ๋ณด๋ฉด ์ ์ ์๋ฏ์ด 32๋ฐฐ ์์ถํ๋๋ผ๋ ์ ์๊ฐ ํฌ๊ฒ ๋จ์ด์ง์ง ์๋ ๋ชจ์ต์ ๋ณด์
2.1.3 ABLATION OF AUTOENCODER DESIGNS
- ์ด๋์ ์์ถ์ ๋ ํ๋ ๊ฒ ์ข์๊ฐ? (AE vs DiT)

| ์ค์ | ์ค๋ช |
| AE-F8C16P4 | 8๋ฐฐ ์์ถ + ํจ์น ํฌ๊ธฐ 4 |
| AE-F16C32P2 | 16๋ฐฐ ์์ถ + ํจ์น ํฌ๊ธฐ 2 |
| AE-F32C32P1 | 32๋ฐฐ ์์ถ + ํจ์น ์ฌ์ฉํ์ง ์์ (SANA) |
AE-F32C32P1 ์ค์ ์ด ๊ฐ์ฅ ๋ฐ์ด๋ ์ฑ๋ฅ(FID, CLIP Score)์ ๊ธฐ๋ก
Autoencoder๊ฐ ์์ถ์ ์ ์ ์ผ๋ก ๋ด๋นํ๋ ๊ฒ์ด ์ฑ๋ฅ ๋ฐ ํ๋ จ ์์ ์ฑ ๋ชจ๋์์ ๊ฐ์ฅ ์ฐ์
- Autoencoder latent ์ฑ๋ ์๋ฅผ ๋ช ๊ฐ๋ก ํ๋ ๊ฒ ์ข์๊ฐ?

- C=16, C=32, C=64 ์คํ ์ํ
- C=16์ ์ ๋ณด ์์ค๋ก ์ธํด ํ์ง ์ ํ ๋ฐ์
- C=64๋ ๋ณต์ ํ์ง์ ์ข์์ผ๋ ๋ชจ๋ธ ๋ณต์ก๋๊ฐ ๊ธ๊ฒฉํ ์ฆ๊ฐํ์ฌ ๋นํจ์จ์ ์
- C=32๊ฐ ์ฑ๋ฅ๊ณผ ํจ์จ ์ฌ์ด์์ ์ต์ ๊ท ํ์ ๋ฌ์ฑํจ
2.2 EFFICIENT LINEAR DIT DESIGN
- ๊ธฐ์กด diffusion transformer(์: DiT) ๊ตฌ์กฐ๋ Self-Attention์ ์ฌ์ฉํจ.
- Self-Attention์ ์ฐ์ฐ๋์ O(Nยฒ) ์ ๋น๋กํจ.
- NNN์ ์ ๋ ฅ ํ ํฐ ์
- ํ ํฐ ์๊ฐ ๋ง์์ง๋ฉด ์ฐ์ฐ๋์ด ๊ธ๊ฒฉํ ์ปค์ง
- 4K ํด์๋ ์ด๋ฏธ์ง๋ฅผ ๋ค๋ฃจ๋ ค๋ฉด, latent token ์๊ฐ ๋ง์์ง ์๋ฐ์ ์์.
โ ์ด๋, ๊ธฐ์กด ์ฐ๊ตฌ๋ค์ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋ ค๊ณ ํด์๋ ๋ฎ์ถ๊ฑฐ๋ Token ์๋ฅผ ์ค์์.
ReLU ๊ธฐ๋ฐ Linear Attention ๋์
๊ธฐ์กด Softmax ๊ธฐ๋ฐ Attention์ ์ ๊ฑฐํ๊ณ , ReLU๋ฅผ ์ด์ฉํ Linear Attention์ ์ฑํ
Softmax๋ ๋ชจ๋ Query-Token ์กฐํฉ์ ๋ค ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ O(Nยฒ) ๋ณต์ก๋๊ฐ ๋ฐ์

- SANA์์๋ ๋ค์์ฒ๋ผ ๊ณ์ฐ ๊ตฌ์กฐ๋ฅผ ๋ณ๊ฒฝ:
- ๊ฐ Key์ ReLU๋ฅผ ์ ์ฉํจ: ReLU(K)
- ๋ ๊ฐ์ง ๊ณต์ term์ ๋ฏธ๋ฆฌ ๊ณ์ฐ
- ๏ปฟ (dxd matrix)
- ๏ปฟ (dร1 vector)
- ์ดํ, ๊ฐ Query์ ๋ํด ์ด pre-computed shared term์ ์ฌ์ฌ์ฉํ์ฌ Attention์ ๊ณ์ฐ
- ์ด ๋ฐฉ์์ Query๋ง๋ค ๊ฐ๋ณ์ ์ผ๋ก ์ฐ์ฐํ ํ์๊ฐ ์์ด์, ์ ์ฒด Attention ๊ณ์ฐ์ด O(N) ์ผ๋ก ์ค์ด๋ฆ.
PixArt๋ Linear์ด๋ฌ๋๋ฐ ๋ค๋ฅธ์ ์?
PixArt์์๋ Key์ Value ํ ํฐ์ ์์ถํ์ฌ ์ฐ์ฐ๋์ ์ค์ฌ์ Engineering Optimization์ผ๋ก O(N)๊ณผ ๋น์ทํ๊ฒ ํ๋ ๋ฐฉ์, SANA์์๋ ์์ ์ํ์ ์ผ๋ก ๊ณ์ฐ๋์ด O(N)
Mix-FFN Block

๊ธฐ์กด Transformer์ FFN (Feed-Forward Network)์ ๋จ์ํ 2๊ฐ์ Linear Layer๋ก ๊ตฌ์ฑ๋์ด ์์์.
FFN์ ์ ์ญ์ ์ธ ์ ๋ณด๋ ์ ์ฒ๋ฆฌํ์ง๋ง, ์ง์ญ์ ์ธ(local) ๋ํ ์ผ ๋ณต์์๋ ์ฝํ์.
SANA์ ํด๊ฒฐ์ฑ :
- ๊ธฐ์กด MLP ์ฌ์ด์ 3ร3 Depthwise Convolution์ ์ฝ์
- ์ด๋ฅผ ํตํด ์ง์ญ ๊ตฌ์กฐ(local structure) ํ์ต์ ๊ฐํ
- ๊ฒฐ๊ณผ์ ์ผ๋ก ํ ์ค์ฒ, ๊ฒฝ๊ณ์ , ์ด๋ฐ ์ธ๋ฐํ ๋ถ๋ถ ๋ณต์์ ์ ๋ฆฌ
DiT without Positional Encoding (NoPE)
๊ธฐ์กด Transformer ๊ตฌ์กฐ๋ ์ ๋ ฅ ์์๋ฅผ ์ธ์ํ๋๋ก Positional Encoding์ ์ฌ์ฉํ์.
์๋๋ฉด Transformer๋ ์ ๋ ฅ ์์๋ฅผ ๊ตฌ๋ณํ ์ ์์๊ธฐ ๋๋ฌธ์โฆ
ํ์ง๋ง 4K ๊ณ ํด์๋ latent์ฒ๋ผ ํ ํฐ ์๊ฐ ๋ง์ ๋, Positional Encoding์ ๊ณ์ฐํ๊ณ ์ ์ฅํ๋ ๋ฐ๋ ๋น์ฉ์ด ํผ.
SANA์์๋ ์์ Positional Encoding์ ์ ๊ฑฐํจ
- 3ร3 Depthwise Convolution์ด Mix-FFN์ ์ถ๊ฐ๋์ด์ ์ง์ญ์ ์์น ๊ด๊ณ๋ฅผ ํ์ตํ ์ ์์.
- Linear Attention์ ์ ์ญ ๊ด๊ณ๋ฅผ ์์ฐ์ค๋ฝ๊ฒ ํฌ์ฐฉํ ์ ์์.
โ ๋ณ๋๋ก ์์น ์ ๋ณด๋ฅผ ๋ถ์ฌํ์ง ์์๋ ์ถฉ๋ถํ ํจํด๊ณผ ๊ตฌ์กฐ๋ฅผ ํ์ตํ ์ ์์.
๊ฒฐ๊ณผ์ ์ผ๋ก ํ์ง์ด ์ ์ง๋๋ฉด์ ๊ตฌ์กฐ๊ฐ ๊ฐ๋จํด์ง๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฐ์ฐ๋์ด ๊ฐ์ํจ

Triton Acceleration Training/Inference
Appendix์ ์ถ๊ฐํ๋ค๊ณ ๋์ด์๋๋ฐ ์์ง ๊ด๋ จ ๋ด์ฉ ์์.
GPT๊ฐ ์๋ ค์ค ๋ด์ฉ
- Linear Attention์ ๊ตฌํํ ๋, ๋จ์ํ ์๊ณ ๋ฆฌ์ฆ๋ง ๊ฐ์ ํ๋ ๊ฒ์ผ๋ก๋ ๋ถ์กฑํจ.
- ์ค์ ์ฐ์ฐ ํจ์จ๊น์ง ๊ทน๋ํํ๋ ค๋ฉด, GPU kernel ๋ ๋ฒจ ์ต์ ํ๊ฐ ํ์ํจ.
- SANA๋ Triton์ ์ฌ์ฉํ์ฌ Linear Attention ์ฐ์ฐ์ ์ง์ ์ต์ ํํจ.
- Triton์ NVIDIA๊ฐ ์ง์ํ๋ ์ปค์คํ GPU ์ปค๋ ํ๋ก๊ทธ๋๋ฐ ํ๋ ์์ํฌ์.
- CUDA๋ณด๋ค ๋จ์ํ ๋ฌธ๋ฒ์ผ๋ก, ๊ณ ์ฑ๋ฅ ์ปค๋์ ์์ฑํ ์ ์์.
Triton์ผ๋ก ์ต์ ํํ ๊ฒฐ๊ณผ:
- Matrix ๊ณฑ ์ฐ์ฐ(GEMM)๊ณผ Memory Access๋ฅผ ์ค์.
- ์ค์ latency(์ง์ฐ ์๊ฐ)์ memory bandwidth ์๋ชจ๋ฅผ ํฌ๊ฒ ๊ฐ์
2.3 TEXT ENCODER DESIGN
์ T5 ๋์ Decoder-only LLM์ ์ฌ์ฉํ๋๊ฐ?
SANA๋ Gemma๋ฅผ Text Encoder๋ก ์ฌ์ฉํ๊ธฐ๋ก ์ฑํ
| ํญ๋ชฉ | ๊ธฐ์กด (T5) | SANA (Gemma-2) |
| ๋ชจ๋ธ ๊ตฌ์กฐ | Encoder-Decoder | Decoder-only |
| Reasoning ๋ฅ๋ ฅ | ์ ํ์ | ๋งค์ฐ ๊ฐํจ (CoT, ICL ๊ฐ๋ฅ) |
| ์ถ๋ก ์๋ | ๋๋ฆผ (T5-XXL) | 6๋ฐฐ ๋น ๋ฆ (Gemma-2-2B) |

โ ๋น ๋ฅธ๋ฐ๋ ๋ถ๊ตฌํ๊ณ , CLIP Score์ FID(์ด๋ฏธ์ง ํ์ง ์งํ)์์๋ ์ฑ๋ฅ์ด ๋น์ทํจ
Decoder-only LLM์ Text Encoder๋ก ์ฐ๋ฉด์ ์๊ธด ๋ฌธ์ ํด๊ฒฐ
Decoder-only LLM (Gemma, Qwen ๋ฑ)์ ํ ์คํธ ์๋ฒ ๋ฉ์ Variance๊ฐ ํจ์ฌ ํผ.
- ํฐ ๊ฐ์ด ํ ์คํธ ์๋ฒ ๋ฉ ์์ ๋ง์ด ํฌํจ๋์ด ์์.
- Cross-Attention ์ฐ์ฐ ์ค ์์น ํญ๋ฐ(NaN)๋ก ์ด์ด์ง.
๋ฐฉ๋ฒ 1: RMSNorm ์ถ๊ฐ
Gemma-2์ ํ ์คํธ ์๋ฒ ๋ฉ ์ถ๋ ฅ์ RMSNorm์ ์ ์ฉ
RMSNorm?
- ์ ๋ ฅ ๋ฒกํฐ์ Variance๋ฅผ 1.0์ผ๋ก ์ ๊ทํ
- ํฐ ๊ฐ์ด๋ ์์ ๊ฐ๋ค์ ๊ท ์ผํ๊ฒ ๋ง๋ค์ด ์์น ํญ๋ฐ ๋ฐฉ์ง
๋ฐฉ๋ฒ 2: Learnable Scale Factor ์ถ๊ฐ
- ์ถ๊ฐ๋ก, ํ ์คํธ ์๋ฒ ๋ฉ์ ํ์ต ๊ฐ๋ฅํ ์์ ์ค์ผ์ผ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ณฑํจ
- ์ด๊ธฐ ๊ฐ์ ๋งค์ฐ ์๊ฒ ์ค์ ํจ (์: 0.01)
- ์ด ํ๋ผ๋ฏธํฐ๊ฐ ํ์ต์ ํตํด ์ ์ ํ ํฌ๊ธฐ๋ก ์กฐ์ ๋๋ฉด์ ๋ชจ๋ธ ์๋ ด ์๋๊ฐ ๋นจ๋ผ์ง
โ ํ๋ จ ์์ ์ฑ ํ๋ณด + ์๋ ด ์๋ ํฅ์

Complex Human Instruction Improves Text-Image Alignment
Gemma๋ ๊ฐ๋ ฅํ LLM์ด์ง๋ง, ์ฌ์ฉ์๊ฐ ์งง๊ฑฐ๋ ๋ชจํธํ ํ๋กฌํํธ๋ฅผ ์ ๋ ฅํ๋ฉด (์: "a cat")
LLM์ด ์ด์ ์ ์๊ณ ์๋ฑํ ๋ต๋ณ์ ํ ์๋ ์์.
โ LLM์ด ํ๋กฌํํธ์๋ง ์ง์คํ๊ฒ ๋ง๋๋ ์ถ๊ฐ ์ง์๋ฌธ์ด ํ์ํจ.
CHI๊ฐ ๊ทธ๋์ ๋ญ๋?
LLM์ In-Context Learning ๋ฅ๋ ฅ์ ํ์ฉํ์ฌ ํ๋กฌํํธ๋ฅผ ์ฃผ๊ธฐ ์ ์,
LLM์๊ฒ "์์, ํฌ๊ธฐ, ์์น ๊ด๊ณ ๊ฐ์ ์ธ๋ถ ๋ฌ์ฌ๋ฅผ ์ถ๊ฐํด๋ผ"์ ๊ฐ์ ๋ณต์กํ ๋ช ๋ น ์ธํธ๋ฅผ ํจ๊ป ์ ๊ณตํ๋ ๊ฒ
๊ฒฐ๊ณผ 1

CHI๋ฅผ ์ ์ฉํ์ ๋, ํ์ต์ ์ฒ์๋ถํฐ ํ๋ (fresh training)
์๋๋ฉด ๊ธฐ์กด ๋ชจ๋ธ์ ๋ฏธ์ธ ์กฐ์ (fine-tuning)ํ๋
ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ์ฑ๋ฅ์ด ํฅ์
๊ฒฐ๊ณผ 2

์งง์ ํ๋กฌํํธ(์: "a cat")๋ฅผ ์ ๋ ฅํ์ ๋,
CHI๊ฐ ์์ผ๋ฉด, ๋ชจ๋ธ์ด ์๋ฑํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ฑฐ๋ ํ์ง์ด ๋ถ์์ ํด์ง.
CHI๊ฐ ์์ผ๋ฉด, ๋ชจ๋ธ์ด ํ๋กฌํํธ์ ์ ํํ ๋ง๋ ์์ ์ ์ธ ์ด๋ฏธ์ง๋ฅผ ์์ฑ
3 EFFICIENT TRAINING/INFERENCE
3.1 DATA CURATION AND BLENDING
1. Multi-Caption Auto-labelling Pipeline
์ด๋ฏธ์ง ํ๋๋น 4๊ฐ์ VLM(Vision-Language Models) ์ ์ด์ฉํด ์บก์ ์ ์์ฑํจ.
- VILA-3B
- VILA-13B
- InternVL-28B
- InternVL-26B
โ ์ ํํ ์บก์ ์์ฑ (ํ๋๋ง ์ฐ๋ ๊ฒ๋ณด๋ค ์ค๋ฅ ์ค์)
โ ๋ค์ํ ํํ ํ๋ณด (๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ์ฌ๋ฌ ๊ด์ ์์ ๋ฌ์ฌ ๊ฐ๋ฅ)
2. CLIP-Score-based Caption Sampler
๋ฌธ์ ์ํฉ
- ์บก์
์ ์ฌ๋ฌ ๊ฐ ๋ง๋ค์๋๋ฐ,
ํ๋ จํ ๋ ์ด๋ค ์บก์ ์ ์ ํํ ์ง๊ฐ ๋ฌธ์ ์.
- ๋ฌด์์๋ก(random) ํ๋ ๊ณ ๋ฅด๋ฉด:
- ํ์ง์ด ๋ฎ์ ๋ฌธ์ฅ์ ๋ฝ์ ์ํ์ด ์์
- ๊ทธ๋ฌ๋ฉด ํ๋ จ์ด ๋๋ ค์ง๊ฑฐ๋ ๋ชจ๋ธ ํ์ง์ด ๋จ์ด์ง
ํด๊ฒฐ ๋ฐฉ๋ฒ
- CLIP score๋ฅผ ํ์ฉํด ํ์ง ๋์ ์บก์
์ ๋ฝ๋ ๋ฐฉ์ ์ฌ์ฉ
- CLIP์ ์ด๋ฏธ์ง-ํ ์คํธ ๋งค์นญ ์ ๋๋ฅผ ์ ์๋ก ๊ณ์ฐํด์ค.
- ๊ณผ์ :
- ์ด๋ฏธ์ง์ ๋ํด ์์ฑ๋ ๊ฐ ์บก์ ์ CLIP ์ ์(cic_ici)๋ฅผ ๊ณ์ฐ
- ์ ์๊ฐ ๋์ ์บก์ ์ผ์๋ก ๋ฝํ ํ๋ฅ ์ด ๋๊ฒ ์ค์
- Sampling ํ๋ฅ ๊ณต์:
์ฌ๊ธฐ
- ฯ๋ "temperature"๋ผ๋ ํ์ดํผํ๋ผ๋ฏธํฐ์.
- Temperature ์กฐ์ ์ผ๋ก ๋ฝ๋ ๊ฐ๋๋ฅผ ์กฐ์ ํ ์ ์์:
- ฯ๊ฐ ์์ผ๋ฉด: ์ ์ ๊ฐ์ฅ ๋์ ์บก์ ๋ง ๊ฑฐ์ ํญ์ ์ ํ
- ฯ๊ฐ ํฌ๋ฉด: ๋ค์ํ ์บก์ ์ด ๊ณ ๋ฅด๊ฒ ์ ํ๋จ
์คํ ๊ฒฐ๊ณผ
- Table 4 ๊ฒฐ๊ณผ์ ๋ฐ๋ฅด๋ฉด:
- ์บก์ ์ ๋ค์ํ๊ฒ ๊ณจ๋ผ๋ ์ด๋ฏธ์ง ํ์ง(FID)์ ๊ฑฐ์ ๋ณํ์ง ์์
- ํ์ง๋ง ํ๋ จ ์ค ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ(semantic alignment)์ ํจ์ฌ ์ข์์ง
3. Cascade Resolution Training
๊ธฐ์กด ๋ฐฉ์
- ๋๋ถ๋ถ์ diffusion ๋ชจ๋ธ์ ํด์๋ 256px์ง๋ฆฌ ์ด๋ฏธ์ง๋ก ๋จผ์ pre-training์ ํจ.
- ์ด์ ๋ ์ฐ์ฐ ๋น์ฉ(cost)์ ์ค์ด๊ธฐ ์ํด์์.
๋ฌธ์ ์
- 256px ์ด๋ฏธ์ง๋ ๋ํ ์ผ(detail) ์์ค์ด ์ฌํจ.
- ๋ฐ๋ผ์, ์์ ํด์๋์์ ํ์ต์ ์์ํ๋ฉด:
- ๋ชจ๋ธ์ด fineํ ๊ตฌ์กฐ๋ ํ ์ค์ฒ๋ฅผ ๋ฐฐ์ฐ๊ธฐ ์ด๋ ค์
- ๊ฒฐ๊ตญ ํฐ ํด์๋๋ก ๊ฐ ๋ ๋ ๋๋ฆฌ๊ฒ ํ์ตํจ.
SANA ๋ฐฉ์
- SANA๋ AE-F32C32P1 ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์
latent ๊ณต๊ฐ์ด ๋งค์ฐ ์์ โ ์ฐ์ฐ ๋ถ๋ด์ด ์ ์.
- ๊ทธ๋์ ๊ตณ์ด 256px์์ ์์ํ ํ์ ์์ด ๋ฐ๋ก 512px์์ ํ์ต ์์ํจ.
- ํ์ต ์์:
- 512px โ 1024px โ 2K โ 4K ์์๋ก ์ ์ง์ (fine-tuning)์ผ๋ก ํด์๋๋ฅผ ์ฌ๋ฆผ.
3.2 FLOW-BASED TRAINING / INFERENCE
Flow-based Training
๊ธฐ์กด ๋ฐฉ์ : noise prediction
- ๏ปฟ: ์๋ณธ ์ด๋ฏธ์ง
- ๏ปฟ: ๋๋ค ๋ ธ์ด์ฆ
- ๏ปฟ: diffusion ๊ณผ์ ์ ํ์ดํผํ๋ผ๋ฏธํฐ
โ ๋ ธ์ด์ฆ๋ฅผ ๋ง์ถ๋ ๊ฒ์ด ํ์ต ๋ชฉํ
๋ฌธ์ ์ : t๊ฐ ์ปค์ง๋ฉด (Diffusion ๋ง์ง๋ง ๋จ๊ณ์ ๊ฐ๊น์ฐ๋ฉด) ๋ ธ์ด์ฆ๊ฐ ์ปค์ ธ์ ์์ธก ๋ถ์์
์ ๋ฐฉ์ : EDM ,RF
noise ๋์ data๋ velocity(๋ ธ์ด์ฆ์ ์๋ณธ ์ด๋ฏธ์ง ์ฐจ์ด) ์์ธก
EDM : ๏ปฟ (์๋ณธ ๋ฐ์ดํฐ ์์ธก)
RF : ๏ปฟ (velocity ์์ธก)
๊ฒฐ๊ตญ RF๋ฅผ ์ฌ์ฉํ์ฌ cumulative(๋์ ) error๋ฅผ ์ค์ผ ์ ์์
Flow-based Inference
๊ธฐ์กด : DPM-Solver++
โ required 28-50 steps for high-quality samples
ํ์ฌ : Flow-DPM-Solver
1.Not predict original data, but velocity
2.substituting the scaling factor ฮฑt with 1 โ ฯt
3.time-steps are redefined over the range [0, 1] instead of [1, 1000]
โGenerate high-quality samples in 14-20 steps
5. Experiments
1. Model Details
Sana-0.6B
- ํ๋ผ๋ฏธํฐ ์: 590M
- ๊ตฌ์กฐ: DiT-XL ๋ฐ PixArt-ฮฃ์ ๊ฑฐ์ ๋์ผํ ๋ ์ด์ด ์์ ์ฑ๋ ์ ์ฌ์ฉ
- ๋ชฉ์ : ์ํ ๋ชจ๋ธ๋ก๋ ํจ์จ์ฑ๊ณผ ํ์ง์ ๋์์ ํ๋ณด
Sana-1.6B
- ํ๋ผ๋ฏธํฐ ์: 1.6B
- ๊ตฌ์กฐ:
- 20๊ฐ์ Transformer ๋ ์ด์ด
- ๊ฐ ๋ ์ด์ด๋ง๋ค 2240๊ฐ์ ์ฑ๋
- FFN ๋ด๋ถ ์ฑ๋ ์๋ 5600
- ์ด ๊ตฌ์ฑ์ ํ์ต ํจ์จ์ฑ๊ณผ ์์ฑ ํ์ง ์ฌ์ด์ ๊ท ํ์ ๊ณ ๋ คํ ๊ฒ์
2. Evaluation Details
SANA๋ ์ด 5๊ฐ์ง ๋ํ์ ์ธ ํ๊ฐ ์งํ๋ฅผ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ ํ๊ฐํจ.
| ์งํ | ์ค๋ช |
| FID (Frรฉchet Inception Distance) | ์ด๋ฏธ์ง ํ์ง์ ์์น๋ก ์ธก์ . ๋ฎ์์๋ก ์ข์ |
| CLIP Score | ์ด๋ฏธ์ง์ ํ ์คํธ ๊ฐ ์๋ฏธ์ ์ ๋ ฌ ์ ๋ ํ๊ฐ. ๋์์๋ก ์ข์ |
| GenEval (Ghosh et al., 2024) | ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ํ๊ฐ. ์ด 533๊ฐ์ ํ๋กฌํํธ ์ฌ์ฉ |
| DPG-Bench (Hu et al., 2024) | ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ์ ๋ฐ๋ ํ ์คํธ. 1065๊ฐ์ ํ๋กฌํํธ ์ฌ์ฉ |
| ImageReward (Xu et al., 2024) | ์ธ๊ฐ์ ์ฃผ๊ด์ ์ ํธ๋๋ฅผ ๋ฐ์ํ ์ ์. 100๊ฐ ํ๋กฌํํธ๋ก ์ธก์ |
3. ํ๊ฐ ๋ฐ์ดํฐ์
- MJHQ-30K (Li et al., 2024a)
- Midjourney์์ ์์งํ 30,000๊ฐ ๊ณ ํ์ง ์ด๋ฏธ์ง ํฌํจ
- FID, CLIP Score ์ธก์ ์ ์ฌ์ฉ๋จ


Limitation

์ฝ๋๊ฐ ์ ๋ฐ์ ์ผ๋ก NVIDIA์นฉ๋ง์ ์ํด ์ค๊ณ๋์ด
๋ค๋ฅธ GPU ์ฅ๋น๋ ๋ฌผ๋ก ์ด๊ณ ,
Mobile Device์์๋ ๋น์ฐํ ๋ถ๊ฐ๋ฅํจ.
์๋ฌด๋๋ NVIDIA์์ ๋ธ ๋ ผ๋ฌธ์ด๊ธฐ ๋๋ฌธ์ Blackwell chip ํ๋ณด ๊ฒธ NVIDIA chip์์๋ง ๊ฐ๋ฅํ๋๋ก ํ๋ฏ.