| ArXiv | https://arxiv.org/abs/2403.04692 |
|---|---|
| Project Page | https://pixart-alpha.github.io/PixArt-sigma-project/ |
| Github Code | https://github.com/PixArt-alpha/PixArt-sigma |
| Affiliation | Huawei Noahโs Ark Lab, Dalian University of Technology, HKU, HKUST |
Key Differentiator
- ๊ธฐ์กด ์ฐ๊ตฌ์๋ PixArt-ฮฑ์์ ์ต์ ํ๋ฅผ ํตํด 4K ์ด๊ณ ํด์๋๊น์ง ๊ฐ๋ฅํ๋๋ก ์ฐ๊ตฌ
- 4K๋ฅผ transformer๋ฅผ ํ์ฉํด directly๋ก ํ๋ฒ์ ์์ฑ
2. Related Work
PixArt-ฮฑ (ICLR 2024 Spotlight)
- ์ต์ด์ Transformer ๊ธฐ๋ฐ Diffusion Model (DiT)๋ก 1024ร1024 ํด์๋๊น์ง ์์ฑ ๊ฐ๋ฅ
Stable Diffusion XL (SDXL, 2023)
- Latent Diffusion Model (LDM) ๊ตฌ์กฐ๋ฅผ ํ์ฉํ์ฌ 1024ร1024 ์ด์์ ๊ณ ํด์๋ ์ด๋ฏธ์ง ์์ฑ ๊ฐ๋ฅ
GigaGAN (Adobe, 2023)
- GAN ๊ธฐ๋ฐ ์ด๊ณ ํด์๋ ์ด๋ฏธ์ง ์์ฑ ๋ชจ๋ธ (1024px ์ด์ ์ง์)
LLaVA (Visual Instruction Tuning, 2023)
- ์ด๋ฏธ์ง-ํ ์คํธ ์ ๋ ฌ์ ํ์ตํ์ฌ ์ด๋ฏธ์ง์ ๋ํ ์ค๋ช (์บก์ )์ ์๋์ผ๋ก ์์ฑํ๋ ๋ชจ๋ธ
DALLยทE 3 (OpenAI, 2023)
- GPT-4 ๊ธฐ๋ฐ ํ ์คํธ ์ดํด๋ ฅ์ ํ์ฉํ์ฌ ํ๋กฌํํธ๋ฅผ ๋ ์ ๋ฐํ๊ฒ ๋ฐ์
3. Framework
3.1 Data Analysis
| Data | ||
| Internal-ฮฑ | 14M | |
| Internal-ฮฃ | 33M | >=1K (33M) real photo 4K (8M) |
| SD v1.5 (open-source) | 2B |
a ๋๋ณด๋ค ๋ฐ์ดํฐ๊ฐ ๋ง์ด ๋์๊ณ , 4K real photo๋ ์ถ๊ฐํจ.
ํ์ง๋ง SD v1.5๊ฐ 2B ๋ฐ์ดํฐ์ธ๊ฑธ ๊ฐ์ํ๋ฉด ์์ฃผ ์ ํ์ ์ธ ๋ฐ์ดํฐ.
ํ์ง๋ง ํจ๊ณผ์ ์ผ๋ก trainingํจ.
์ด๋ฏธ์ง์ ์์ ์ ํ์ง์ ํ๊ฐํ๋ Aesthetic Scoring Model(AES)์ ์ฌ์ฉํ์ฌ 2M(200๋ง ์ฅ)์ ๊ณ ํ์ง ์ด๋ฏธ์ง ์ ๋ณ.
โ ํด์๋๊ฐ ๋์์ง์๋ก ๋ชจ๋ธ์ ์ถฉ์ค๋(ํ๋ ์ ฐ ์ด์ ๊ฑฐ๋ฆฌ(FID) [18])์ ์๋ฏธ์ ์ ๋ ฌ(CLIP ์ ์)์ด ํฅ์
Better Text-Image Alignment
ํ ์คํธ ํ๋กฌํํธ(์ค๋ช )์ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ์ผ๋ง๋ ์ผ์นํ๋์ง
์ฆ, ์ฌ์ฉ์๊ฐ ์ ๋ ฅํ ํ ์คํธ(prompt)์ ๋ชจ๋ธ์ด ์์ฑํ ์ด๋ฏธ์ง๊ฐ ์ผ๋ง๋ ์ ํํ๊ฒ ๋์ํ๋์ง๋ฅผ ํ๊ฐํ๋ ๊ฐ๋
PixArt-ฮฑ ๋ LLaVa๋ฅผ ์ฌ์ฉํ์๊ณ , PixArt-ฮฃ๋ Share-Captioner ์ฌ์ฉ
| ํญ๋ชฉ | LLaVA | Share-Captioner |
| ๊ธฐ๋ฐ ๋ชจ๋ธ | CLIP + LLaMA | GPT-4V (GPT-4 with Vision) |
| ํ ์คํธ ์์ฑ | ๋น๊ต์ ๋จ์ | ๋ ๊ธธ๊ณ ์ธ๋ฐํ ์ค๋ช |
| ์ ํ๋ | ๊ฐ๋ ํ๊ฐ ๋ฌธ์ ๋ฐ์ | ๋ ๋์ ์ ํ๋ |
| ์ด๋ฏธ์ง ๋ํ ์ผ ๋ฐ์ | ์ ํ์ (๋จ์ ์ค๋ช ) | ๋ ์ ๋ฐํ ๊ฐ์ฒด ๋ฐ ๊ด๊ณ ์ค๋ช |
| ์บก์ ํ์ง | ์ผ๋ฐ์ ์ธ ์ค๋ช ์์ค | ๊ณ ํ์ง, ๊ตฌ์ฒด์ ์ธ ๋ฌ์ฌ ๊ฐ๋ฅ |
๋ค์๊ณผ ๊ฐ์ ํ๊ฐ (Hallucinations)๊ฐ ๋ฐ์ํ์์

| ํญ๋ชฉ | PixArt-ฮฑ | PixArt-ฮฃ |
| ํ ์คํธ ํด์ ๊ธธ์ด | 120 ํ ํฐ | 300 ํ ํฐ (2.5๋ฐฐ ์ฆ๊ฐ) |
| ์บก์ ์์ฑ ๋ชจ๋ธ | LLaVA (๋จ์ํจ) | Share-Captioner (์ ํํ ์ค๋ช ) |
| CLIP Score | 0.2787 | 0.2797 (ํฅ์๋จ) |
| ํ๊ฐ ๋ฌธ์ ํด๊ฒฐ | ์ผ๋ถ ์กด์ฌ | ํ๊ฐ ๊ฐ์ (๋ ์ ๋ฐํ ์บก์ ์ฌ์ฉ) |
PixArt-ฮฃ๋ ๋ ๊ธด ๋ฌธ์ฅ์ ํด์ํ๊ณ , ๋ ์ ๊ตํ ์บก์ ์ ์ฌ์ฉํ์ฌ ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ์ฑ๋ฅ์ ๋์์.
Share-Captioner๋ฅผ ์ฌ์ฉํ์ฌ ํ ์คํธ์ ์ด๋ฏธ์ง ๊ฐ ์ ๋ณด ์ผ์น๋๋ฅผ ๊ฐ์ ํจ.
ํ๊ฐ ๋ฐ์ดํฐ์ ๊ตฌ์ฑ (High-Quality Evaluation Dataset)
- ๊ธฐ์กด ๋ชจ๋ธ๋ค์ด ์ฌ์ฉํ๋ MSCOCO ๋ฐ์ดํฐ์ ์ ์์ ์ ํ์ง๊ณผ ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ์ ํ๊ฐํ๊ธฐ์ ์ถฉ๋ถํ์ง ์์.
- ๋ฐ๋ผ์ PixArt-ฮฃ๋ ์๋ก์ด ํ๊ฐ ๋ฐ์ดํฐ์ (30,000๊ฐ ์ํ) ๊ตฌ์ถ.
- ํ๊ฐ ํญ๋ชฉ:
- Frรฉchet Inception Distance (FID) โ ์ด๋ฏธ์ง ํ์ง ํ๊ฐ
- CLIP Score โ ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ์ฑ๋ฅ ํ๊ฐ

3.2 Efficient DiT Design
Key-Value (KV) Token Compression ๊ธฐ๋ฒ
๊ธฐ์กด Attention ์ฐ์ฐ ๋ฌธ์
- Self-Attention์ Query(Q), Key(K), Value(V)์ ๊ณฑ์ ๊ณ์ฐํ๋ ๋ฐฉ์์ด๋ฏ๋ก,ํ ํฐ ๊ฐ์๊ฐ ๋ง์์ง์๋ก ์ฐ์ฐ๋์ด O(Nยฒ)์ผ๋ก ์ฆ๊ฐํจ.
- ํด๊ฒฐ ๋ฐฉ๋ฒ: Key์ Value ํ ํฐ์ ์์ถํ์ฌ ์ฐ์ฐ๋์ ์ค์.
PixArt-ฮฃ์ KV Token Compression ๋ฐฉ์
- PixArt-ฮฃ (ํ ํฐ ์์ถ ์ ์ฉ):
- Key(K)์ Value(V)๋ฅผ Stride 2์ Group Convolution์ ์ฌ์ฉํด ์์ถ
- ์ด๋ฅผ ํตํด ํ ํฐ ๊ฐ์๋ฅผ N โ N/R^2 ์ผ๋ก ์ค์
- ์ ํ๋๊ฐ ํฌ๊ฒ ๋จ์ด์ง์ง ์๋ ์ ์์ R์ ์กฐ์ (1~4)ํ๊ธฐ
- ์ต์ข ์ ์ผ๋ก ์ฐ์ฐ๋์ ๊ธฐ์กด ๋๋น ์ฝ 34% ์ ๊ฐ
ํต์ฌ ํจ๊ณผ
- 4K ํด์๋ ์ด๋ฏธ์ง ์์ฑ ์๋ ํฅ์ (์ฐ์ฐ๋ ๊ฐ์)
- ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๊ฐ์ โ ๋ ์์ GPU์์๋ ์คํ ๊ฐ๋ฅ
- ๊ธฐ์กด PixArt-ฮฑ ๋ชจ๋ธ์์ ์์ฐ์ค๋ฝ๊ฒ ์ ๊ทธ๋ ์ด๋ ๊ฐ๋ฅ (๊ธฐ์กด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ํ์ฉ)


3.3 Weak-to-Strong Training Strategy
PixArt-ฮฃ์ Weak-to-Strong Training์ ๊ธฐ์กด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ํ์ฉํ์ฌ ๋น ๋ฅด๊ฒ ์ ์ํ๋๋ก ์ค๊ณ๋จ.
์ด ๊ณผ์ ์์ 3๋จ๊ณ์ ํ์ต ์ ๋ต์ด ์ ์ฉ๋จ.
(1) VAE ์ ์ (VAE Adaptation)
- PixArt-ฮฑ์์ ์ฌ์ฉํ๋ ๊ธฐ์กด VAE๋ฅผ Stable Diffusion XL(SDXL)์ VAE๋ก ๊ต์ฒด
- VAE ๊ต์ฒด ํ ๋น ๋ฅธ ์ ์์ ์ํด 2K Training Steps ๋ง์ ์๋ ดํ๋๋ก ํ์ต ์ ๋ต ์ ์ฉ.
- ์๋ก์ด VAE ์ ์ฉ ํ์๋ ๊ธฐ์กด ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฌ์ฉํ์ฌ ๋น ๋ฅด๊ฒ ํ์ต ๊ฐ๋ฅ.

(2) ํด์๋ ์ ๊ทธ๋ ์ด๋ (Resolution Upscaling)
- 256px โ 512px โ 1024px โ 4K๋ก ์ ์ง์ ์ผ๋ก ํด์๋๋ฅผ ์ฆ๊ฐ์ํค๋ฉฐ ํ์ต.
- PE Interpolation(์์น ์๋ฒ ๋ฉ ๋ณด๊ฐ๋ฒ)์ ์ ์ฉํ์ฌ, ๊ธฐ์กด ํด์๋์ ๊ฐ์ค์น๋ฅผ ์ ํด์๋์์๋ ์์ฐ์ค๋ฝ๊ฒ ์ฌ์ฉ ๊ฐ๋ฅํ๋๋ก ์กฐ์ .
- ๋ณด๊ฐ๋ฒ (Interpolation)์ย ์๋ ค์ง ๊ฐ์ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ์ ๊ณ์ฐํ๋ ํ๋ก์ธ์ค
- Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ(์: DiT, ViT ๋ฑ)์ ์ ๋ ฅ ์ด๋ฏธ์ง์ ๊ฐ ์์น ์ ๋ณด๋ฅผ ํํํ๊ธฐ ์ํด ์์น ์๋ฒ ๋ฉ์ ์ฌ์ฉ.
- ๋ชจ๋ธ์ด 256ร256์์ ํ์ต๋์๋ค๋ฉด, 256ร256 ํด์๋์ ์ต์ ํ๋ ์์น ์๋ฒ ๋ฉ์ ํ์ตํจ.
- ํ์ง๋ง ํด์๋๋ฅผ 1024ร1024๋ก ์ฆ๊ฐ์ํค๋ฉด, ๊ธฐ์กด 256ร256 ์์น ์๋ฒ ๋ฉ๊ณผ ๊ตฌ์กฐ๊ฐ ๋ฌ๋ผ์ ธ ๋ชจ๋ธ ์ฑ๋ฅ์ด ๊ธ๊ฒฉํ ์ ํ๋จ.
- ๊ธฐ์กด ์์น ์๋ฒ ๋ฉ์ 1024ร1024 ํฌ๊ธฐ๋ก ๋ณด๊ฐ(interpolation)
- ์ฆ, 256๊ฐ์ ๊ฐ์ 1024๊ฐ๋ก ํ์ฅํ๋ ๊ณผ์ ์์ ์์ฐ์ค๋ฝ๊ฒ ๋งค๋๋ฌ์ด ๊ฐ์ผ๋ก ๋ณํ๋จ.
- ์ด๋ฅผ ํตํด ์๋ก์ด ํด์๋์์๋ ๊ธฐ์กด ๋ชจ๋ธ์ ๊ณต๊ฐ ์ ๋ณด๊ฐ ์ ์ง๋จ.
- ๋จ 1000 Training Steps๋ง์ผ๋ก๋ ํด์๋ ์ฆ๊ฐ์ ์ ์ ๊ฐ๋ฅ.

(3) KV Token Compression ๋์ (์ฐ์ฐ ์ต์ ํ)
- PixArt-ฮฃ ๋ชจ๋ธ์ KV Token Compression์ ์ ์ฉํ์
- ํ์ง๋ง ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ๊ตฌ์กฐ๊ฐ ๋ฌ๋ผ์ ์ฑ๋ฅ ์ ํ ์ํ์ด ์์.
- PixArt-ฮฃ์์๋ ๊ธฐ์กด ๋ชจ๋ธ์์ ์์ฐ์ค๋ฝ๊ฒ ์ ์ํ๋๋ก "Conv Avg Init." ์ ๋ต ์ ์ฉ.
ํ๊ท ์ฐ์ฐ(Averaging) ๊ธฐ๋ฐ ์ด๊ธฐํ
- Conv Avg Init์ ๊ฐ์ค์น ๊ฐ์
1/Rยฒ๋ก ์ค์ ํ์ฌ, ๊ธฐ์กด ์ ๋ณด๋ฅผ ์ต๋ํ ์ ์งํ๋ฉด์ ๋ถ๋๋ฝ๊ฒ ์ ํํจ.
- ์ฆ, ๋จ์ํ ์์ถํ๋ ๊ฒ์ด ์๋๋ผ ๊ธฐ์กด ๊ณต๊ฐ ์ ๋ณด๋ฅผ ์ต๋ํ ๋ณด์กดํ๋ ๋ฐฉ์.
- ์ด๊ธฐ์๋ ์์ถ ์์ด ํ์ต ํ, ํ์ต์ด ์์ ํ๋๋ฉด KV Compression์ ์ ์ฉํ์ฌ ์ฐ์ฐ๋ ๊ฐ์.
- 4K ์ด๋ฏธ์ง ์์ฑ ์ ์ฐ์ฐ๋ 34% ์ ๊ฐ.
๊ฒฐ๊ณผ์ ์ผ๋ก, ๊ธฐ์กด PixArt-ฮฑ ๋๋น ์ ์ ์ฐ์ฐ๋๊ณผ ๋น ๋ฅธ ํ์ต์ผ๋ก 4K ์ด๋ฏธ์ง ์์ฑ์ด ๊ฐ๋ฅํด์ง.

4. Experiment
4.1 Implementation Details (๊ตฌํ ์ธ๋ถ์ฌํญ)
1. ๋ชจ๋ธ ๊ตฌ์ฑ
ํ ์คํธ ์ธ์ฝ๋
- Flan-T5-XXL ์ฌ์ฉ (Imagen ๋ฐ PixArt-ฮฑ์ ๋์ผ)
- ๊ธฐ์กด ๋ชจ๋ธ์์ 120๊ฐ ํ ํฐ์ ์ฌ์ฉํ๋ ๊ฒ์ 300๊ฐ ํ ํฐ๊น์ง ํ์ฅํ์ฌ ๋ ์ ๋ฐํ ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ๊ฐ๋ฅ.
VAE (Variational Autoencoder) ์ ์ฉ
- Stable Diffusion XL(SDXL)์ VAE ์ฌ์ฉ
- ๋ ๋์ ํ์ง์ ์ด๋ฏธ์ง ๋์ฝ๋ฉ ๊ฐ๋ฅ โ ์ธ๋ฐํ ๋ํ ์ผ ๋ณด์กด
๊ธฐ๋ฐ ๋ชจ๋ธ
- PixArt-ฮฑ๋ฅผ ๋ฒ ์ด์ค ๋ชจ๋ธ๋ก ์ฌ์ฉ
- 256px ์ฌ์ ํ์ต๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ํ์ฉํ์ฌ 4K๊น์ง ํ์ฅ
KV Token Compression ์ ์ฉ
- ์ฐ์ฐ๋ 34% ์ ๊ฐ
- ์ด๊ณ ํด์๋(4K) ์ด๋ฏธ์ง ์์ฑ์ ๊ฐ๋ฅํ๊ฒ ํจ
2. ํ์ต ํ๊ฒฝ ๋ฐ ํ๋์จ์ด
ํ๋ จ GPU ํ๊ฒฝ
- 1K ๋ชจ๋ธ ํ์ต: 32 V100 GPUs ์ฌ์ฉ
- 2K & 4K ๋ชจ๋ธ ํ์ต: 16 A100 GPUs ์ฌ์ฉ
์ต์ ํ ์๊ณ ๋ฆฌ์ฆ
- CAME Optimizer ์ฌ์ฉ (AdamW ๋์ )
- ํ์ต๋ฅ : 2e-5 (๊ณ ์ Learning Rate ์ฌ์ฉ)
- Weight Decay: 0
Position Embedding Interpolation (PE Interp.) ์ ์ฉ
- ๋ฎ์ ํด์๋์์ ํ์ต๋ ๋ชจ๋ธ์ ๊ณ ํด์๋๋ก ๋ณํํ ๋ ์์น ์๋ฒ ๋ฉ์ ๋ณด๊ฐ(interpolation)ํ์ฌ ์ ์ฉ.
- ์ด๋ฅผ ํตํด ๊ณ ํด์๋๋ก ํ์ฅ ์ ์ฑ๋ฅ ์ ํ ์์ด ๋น ๋ฅด๊ฒ ์ ์ ๊ฐ๋ฅ.
3. ํ์ต ๋ฐ์ดํฐ ๋ฐ ํ๋ จ ๊ณผ์
ํ๋ จ ๋ฐ์ดํฐ์
- ์ด 33M(3,300๋ง ๊ฐ)์ ๊ณ ํด์๋ ์ด๋ฏธ์ง ์ฌ์ฉ
- 1K ํด์๋ ์ด์์ ๋ฐ์ดํฐ๋ง ํฌํจ
- 4K ํด์๋ ์ด๋ฏธ์ง 2.3M(230๋ง ๊ฐ) ํฌํจ
- Aesthetic Scoring Model(AES) ์ ์ฉํ์ฌ ๊ณ ํ์ง ์ด๋ฏธ์ง ์ ๋ณ
ํ๋ จ ๊ณผ์
- 256px โ 512px โ 1024px โ 4K ํด์๋๋ก ์ ์ง์ ์ ์ค์ผ์ผ๋ง ์ ์ฉ
- VAE ๊ต์ฒด ํ 2K Training Steps ๋ด ๋น ๋ฅด๊ฒ ์ ์
- PE Interpolation์ ์ ์ฉํ์ฌ ๊ณ ํด์๋์์ ์ถ๊ฐ ํ์ต ๋น์ฉ ์ ๊ฐ
ํ์ต ๋น์ฉ ์ ๊ฐ
- ๊ธฐ์กด PixArt-ฮฑ ๋๋น ํ๋ จ ๋น์ฉ 9%๋ง ์ฌ์ฉํ์ฌ 1K ์์ฑ ๊ฐ๋ฅ
- KV Compression๊ณผ Weak-to-Strong Training์ ๊ฒฐํฉํ์ฌ GPU ๋น์ฉ ์ ๊ฐ
4.2 ์คํ ๊ฒฐ๊ณผ
์ด๋ฏธ์ง ํ์ง ๋น๊ต (Qualitative Evaluation)
PixArt-ฮฃ๋ ํฌํ ๋ฆฌ์ผ๋ฆฌ์ฆ(Photorealism), ๋ํ ์ผ ์์ค, ์คํ์ผ ๋ค์์ฑ ์ธก๋ฉด์์ ์ด์ ๋ชจ๋ธ๋ณด๋ค ๊ฐ์ ๋จ.
์๋์ ๊ฐ์ ๋ชจ๋ธ๋ค๊ณผ ๋น๊ต๋จ:

PixArt-ฮฑ vs PixArt-ฮฃ
| ํญ๋ชฉ | PixArt-ฮฑ (๊ธฐ์กด) | PixArt-ฮฃ (๊ฐ์ ) |
| ์ต๋ ํด์๋ | 1K (1024ร1024) | 4K (3840ร2160) ์ง์ |
| ์ฐ์ฐ๋ ์ต์ ํ | ์์ | KV Token Compression ์ ์ฉ (์ฐ์ฐ๋ 34% ๊ฐ์) |
| VAE ๋ชจ๋ธ | ๊ธฐ๋ณธ VAE | SDXL VAE๋ก ๋ณ๊ฒฝ (๊ณ ํ์ง ์ด๋ฏธ์ง ์์ฑ ๊ฐ๋ฅ) |
| ํ์ต ์ ๋ต | ์ผ๋ฐ ํ์ต | Weak-to-Strong Training (๊ธฐ์กด ๋ชจ๋ธ ํ์ฉํ์ฌ ๋น ๋ฅด๊ฒ ํ์ต) |
| ํ ์คํธ ๊ธธ์ด | 120 ํ ํฐ | 300 ํ ํฐ์ผ๋ก ํ์ฅ (๋ ์ ๋ฐํ ํ ์คํธ-์ด๋ฏธ์ง ์ ๋ ฌ ๊ฐ๋ฅ) |
| ํ๋ จ ๋น์ฉ | ๋์ | ๊ธฐ์กด ๋๋น GPU ๋น์ฉ 9%๋ก ์ ๊ฐ |