Commit 50f29be: Initial commit.
74 files changed, +78688 -0 lines
README.md

# ImageBART
#### [NeurIPS 2021](https://nips.cc/)

![teaser](assets/modelfigure.png)
<br/>
[Patrick Esser](https://github.com/pesser)\*,
[Robin Rombach](https://github.com/rromb)\*,
[Andreas Blattmann](https://github.com/ablattmann)\*,
[Björn Ommer](https://ommer-lab.com/)<br/>
\* equal contribution

[arXiv](https://arxiv.org/abs/2108.08827) | [BibTeX](#bibtex) | [Poster](assets/imagebart_poster.pdf)

## Requirements
A suitable [conda](https://conda.io/) environment named `imagebart` can be created
and activated with:

```
conda env create -f environment.yaml
conda activate imagebart
```

## Get the Models

We provide pretrained weights and hyperparameters for models trained on the following datasets:

* FFHQ:
  * [4 scales, geometric noise schedule](https://ommer-lab.com/files/ffhq_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/ffhq_4_scales_geometric.zip`
  * [2 scales, custom noise schedule](https://ommer-lab.com/files/ffhq_2_scales_custom.zip): `wget -c https://ommer-lab.com/files/ffhq_2_scales_custom.zip`
* LSUN, 3 scales, custom noise schedules:
  * [Churches](https://ommer-lab.com/files/churches_3_scales.zip): `wget -c https://ommer-lab.com/files/churches_3_scales.zip`
  * [Bedrooms](https://ommer-lab.com/files/bedrooms_3_scales.zip): `wget -c https://ommer-lab.com/files/bedrooms_3_scales.zip`
  * [Cats](https://ommer-lab.com/files/cats_3_scales.zip): `wget -c https://ommer-lab.com/files/cats_3_scales.zip`
* Class-conditional ImageNet:
  * [5 scales, custom noise schedule](https://ommer-lab.com/files/cin_5_scales_custom.zip): `wget -c https://ommer-lab.com/files/cin_5_scales_custom.zip`
  * [4 scales, geometric noise schedule](https://ommer-lab.com/files/cin_4_scales_geometric.zip): `wget -c https://ommer-lab.com/files/cin_4_scales_geometric.zip`

Download the respective files and extract their contents to a directory `./models/`.

Moreover, we provide all the required VQGANs as a .zip at [https://ommer-lab.com/files/vqgan.zip](https://ommer-lab.com/files/vqgan.zip),
whose contents have to be extracted to `./vqgan/`.
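To fetch everything in one go, the downloads above can be scripted. A minimal sketch, using only the URLs and target directories listed in this section (pick the subset of models you actually need):

```shell
mkdir -p models vqgan

# pretrained ImageBART weights (URLs as listed above; take only what you need)
for name in ffhq_4_scales_geometric ffhq_2_scales_custom \
            churches_3_scales bedrooms_3_scales cats_3_scales \
            cin_5_scales_custom cin_4_scales_geometric; do
  wget -c "https://ommer-lab.com/files/${name}.zip"
  unzip -o "${name}.zip" -d models/
done

# first-stage VQGANs
wget -c https://ommer-lab.com/files/vqgan.zip
unzip -o vqgan.zip -d vqgan/
```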
## Get the Data
Running the training configs or the [inpainting script](scripts/inpaint_imagebart.py) requires
a dataset available locally. For ImageNet and FFHQ, follow the data preparation steps in this repo's parent project, [taming-transformers](https://github.com/CompVis/taming-transformers).
The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
We performed a custom split into training and validation images, and provide the corresponding filenames
at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
After downloading, extract the files to `./data/lsun`. The beds/cats/churches subsets should
also be placed/symlinked at `./data/lsun/bedrooms`, `./data/lsun/cats`, and `./data/lsun/churches`, respectively.
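A minimal sketch of the resulting layout, assuming the LSUN subsets have already been downloaded with the script linked above (the `/path/to/...` locations are placeholders for wherever you stored them):

```shell
mkdir -p data/lsun

# our custom train/validation split filenames
wget -c https://ommer-lab.com/files/lsun.zip
unzip -o lsun.zip -d data/lsun/

# symlink the downloaded subsets to the expected locations (paths are placeholders)
ln -s /path/to/lsun/bedrooms data/lsun/bedrooms
ln -s /path/to/lsun/cats     data/lsun/cats
ln -s /path/to/lsun/churches data/lsun/churches
```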
## Inference

### Unconditional Sampling
We provide a script for sampling from the unconditional models trained on the LSUN-{churches,bedrooms,cats} and FFHQ datasets.

#### FFHQ

On the FFHQ dataset, we provide two distinct pretrained models, one with a chain of length 4 and a geometric noise schedule as proposed by Sohl-Dickstein et al. [[1]](#references), and another one with a chain of length 2 and a custom schedule.
These models can be started with
```shell
CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/ffhq/<config>
```
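For example, to draw samples from the 4-scale geometric model on GPU 0 (this is the same config file that the image-editing example below uses):

```shell
CUDA_VISIBLE_DEVICES=0 streamlit run scripts/sample_imagebart.py configs/sampling/ffhq/ffhq_4scales_geometric.yaml
```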
#### LSUN
For the models trained on the LSUN datasets, use
```shell
CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/lsun/<config>
```
### Class Conditional Sampling on ImageNet

To sample from class-conditional ImageNet models, use
```shell
CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/sample_imagebart.py configs/sampling/imagenet/<config>
```
### Image Editing with Unconditional Models

We also provide a script for image editing with our unconditional models. For our FFHQ model with a geometric schedule, this can be started with
```shell
CUDA_VISIBLE_DEVICES=<gpu_id> streamlit run scripts/inpaint_imagebart.py configs/sampling/ffhq/ffhq_4scales_geometric.yaml
```
resulting in samples similar to the following.
![teaser](assets/image_editing.png)

## Training
In general, there are two options for training the autoregressive transition probabilities of the
reverse Markov chain: (i) train them jointly, taking into account a weighting of the
individual scale contributions, or (ii) train them independently, which means that each
training process optimizes a single transition and the scales must be stacked after training.
We conduct most of our experiments using the latter option, but provide configurations for both cases.
### Training Scales Independently
For training scales independently, each transition requires a separate optimization process, which can be
started via

```
CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/<data>/<config>.yaml -t --gpus 0,
```

We provide training configs for a four-scale training of FFHQ using a geometric schedule,
a four-scale geometric training on ImageNet and various three-scale experiments on LSUN.
See also the overview of our [pretrained models](#get-the-models).
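Because every transition is its own optimization process, training an n-scale model means launching the command above once per scale config and stacking the resulting checkpoints afterwards. A hypothetical sketch for a 2-scale FFHQ run (the per-scale config names here are placeholders; substitute the actual files under `configs/ffhq/`):

```shell
# config names below are placeholders; one training run per transition
for cfg in ffhq_scale1 ffhq_scale2; do
  CUDA_VISIBLE_DEVICES=0 python main.py --base "configs/ffhq/${cfg}.yaml" -t --gpus 0,
done
```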
### Training Scales Jointly

For completeness, we also provide a config to run a joint training with 4 scales on FFHQ.
Training can be started by running

```
CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --base configs/ffhq/ffhq_4_scales_joint-training.yaml -t --gpus 0,
```

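If `main.py` follows the taming-transformers convention it also accepts a `--resume` flag, so an interrupted run could be continued from its log directory. This is an assumption about the CLI rather than a documented feature, and the log path below is a placeholder:

```shell
# assumes main.py inherits taming-transformers' --resume flag; logdir is a placeholder
CUDA_VISIBLE_DEVICES=<gpu_id> python main.py --resume logs/<your_run>/ -t --gpus 0,
```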
## Shout-Outs
Many thanks to all who make their work and implementations publicly available.
For this work, these were in particular:

- The extremely clear and extensible encoder-decoder transformer implementations by [lucidrains](https://github.com/lucidrains):
  https://github.com/lucidrains/x-transformers
- Emiel Hoogeboom et al.'s paper on multinomial diffusion and argmax flows: https://arxiv.org/abs/2102.05379

![teaser](assets/foxchain.png)

## References

[1] Sohl-Dickstein, J., Weiss, E., Maheswaranathan, N. & Ganguli, S. (2015). Deep Unsupervised Learning using Nonequilibrium Thermodynamics. *Proceedings of the 32nd International Conference on Machine Learning*, PMLR 37:2256-2265.
## BibTeX

```
@article{DBLP:journals/corr/abs-2108-08827,
  author  = {Patrick Esser and
             Robin Rombach and
             Andreas Blattmann and
             Bj{\"{o}}rn Ommer},
  title   = {ImageBART: Bidirectional Context with Multinomial Diffusion for Autoregressive
             Image Synthesis},
  journal = {CoRR},
  volume  = {abs/2108.08827},
  year    = {2021}
}
```

Binary assets added in this commit:

- assets/foxchain.png (848 KB)
- assets/image_editing.png (1.43 MB)
- assets/imagebart_poster.pdf (1.76 MB)
- assets/modelfigure.png (672 KB)
- assets/sample-001.jpg (17.6 KB)
Config file (filename not shown in this view):

```yaml
model:
  base_learning_rate: 0.0625
  target: imagebart.models.diffusion.DenoisingXTransformer
  params:
    first_stage_key: image
    monitor: val/loss
    n_scales: 2
    single_scale: 1
    top_k: 548
    alpha: 0.0
    redraw_prob: ffhq_bernoulli_PSIM
    use_ema: true

    scheduler_config:
      target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0
        warm_up_steps: 10000
        max_decay_steps: 1500001
        lr_start: 2.5e-06
        lr_max: 0.0001
        lr_min: 1.0e-08
    transformer_config:
      target: imagebart.modules.xtransformers.x_transformer.XTransformer
      params:
        wrap_decoder: false
        dim: 1152
        enc_num_tokens: 548
        enc_depth: 32
        enc_heads: 16
        enc_max_seq_len: 257
        dec_num_tokens: 548
        dec_depth: 6
        dec_heads: 16
        tie_token_emb: false
        dec_max_seq_len: 256
    first_stage_config:
      target: imagebart.models.vqgan.VQGANWrapper
      params:
        ckpt_path: vqgan/vqgan-ffhq.ckpt
        remap: data/vqgan_indices/ffhq_indices.npy
        sane_index_shape: true
        embed_dim: 256
        n_embed: 1024
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.vqperceptual.DummyLoss

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 18
    num_workers: 32
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256
```
Config file (filename not shown in this view):

```yaml
model:
  base_learning_rate: 0.0625
  target: imagebart.models.diffusion.DecoderOnlyDenoiser
  params:
    first_stage_key: image
    monitor: val/loss
    n_scales: 2
    single_scale: 2
    top_k: 548
    alpha: 1.0
    redraw_prob: ffhq_bernoulli_PSIM
    use_ema: true
    scheduler_config:
      target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0
        warm_up_steps: 10000
        max_decay_steps: 1500001
        lr_start: 2.5e-06
        lr_max: 0.0001
        lr_min: 1.0e-08
    transformer_config:
      target: imagebart.modules.transformer.mingpt.GPT
      params:
        vocab_size: 548
        block_size: 256
        n_layer: 36
        n_head: 16
        n_embd: 1216
    first_stage_config:
      target: imagebart.models.vqgan.VQGANWrapper
      params:
        ckpt_path: vqgan/vqgan-ffhq.ckpt
        remap: data/vqgan_indices/ffhq_indices.npy
        sane_index_shape: true
        embed_dim: 256
        n_embed: 1024
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 1
          - 2
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions:
          - 16
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.vqperceptual.DummyLoss

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 18
    wrap: false
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256
```
Config file (filename not shown in this view):

```yaml
model:
  base_learning_rate: 0.0625
  target: imagebart.models.diffusion.DenoisingXTransformer
  params:
    first_stage_key: "image"
    monitor: "val/loss"
    n_scales: 4
    single_scale: 1
    top_k: 548
    alpha: 0.0
    redraw_prob: geometric
    use_ema: True

    scheduler_config:
      target: imagebart.lr_scheduler.LambdaWarmUpCosineScheduler
      params:
        verbosity_interval: 0   # 0 or negative to disable
        warm_up_steps: 10000
        max_decay_steps: 1500001
        lr_start: 2.5e-6
        lr_max: 1.0e-4
        lr_min: 1.0e-8

    transformer_config:
      target: imagebart.modules.xtransformers.x_transformer.ResidualScaledXTransformer
      params:
        scale_pos: 0
        n_scales: 4
        xt_start: 1
        xt_size: 256   # predict x_{t-1}
        wrap_decoder: False
        dim: 752
        enc_num_tokens: 548
        enc_depth: 18
        enc_heads: 16
        enc_max_seq_len: 257
        dec_num_tokens: 548
        dec_depth: 6
        dec_heads: 16
        tie_token_emb: False
        dec_max_seq_len: 256

    first_stage_config:
      target: imagebart.models.vqgan.VQGANWrapper
      params:
        ckpt_path: vqgan/vqgan-ffhq.ckpt
        remap: "data/vqgan_indices/ffhq_indices.npy"
        sane_index_shape: True
        embed_dim: 256
        n_embed: 1024
        ddconfig:
          double_z: false
          z_channels: 256
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1,1,2,2,4 ]  # num_down = len(ch_mult)-1
          num_res_blocks: 2
          attn_resolutions: [ 16 ]
          dropout: 0.0
        lossconfig:
          target: taming.modules.losses.vqperceptual.DummyLoss

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 16
    num_workers: 32
    wrap: False
    train:
      target: taming.data.faceshq.FFHQTrain
      params:
        size: 256
    validation:
      target: taming.data.faceshq.FFHQValidation
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
  trainer:
    benchmark: True
```