Reference: DALL·E 2【论文精读】 by 跟李沐学AI (more papers: https://github.com/mli/paper-reading): https://www.bilibili.com/video/BV17r4y1u77B?spm_id_from=333.999.0.0&vd_source=4aed82e35f26bb600bc5b46e65e25c22
Most of the DALL·E 2 explanations available online are not very clear. The three main directions of generative modeling are VAEs, GANs, and diffusion models. One lineage runs AE -> DAE -> VAE -> VQ-VAE -> diffusion; within diffusion models the progression is DDPM -> Improved DDPM -> Diffusion Models Beat GANs -> GLIDE -> DALL·E 2.
1. Introduction
CLIP is robust to shifts in image distribution and supports zero-shot transfer, while diffusion models offer good sample diversity with decent fidelity. DALL·E 2 combines the strengths of both.
2. Method
The figure above is well drawn, and the method is easiest to follow with it. Above the dashed line is CLIP. This CLIP is pretrained in advance and is not trained further during DALL·E 2 training; its weights are frozen. During DALL·E 2 training the input is still a pair: a text and its corresponding image. The text goes through CLIP's text encoder (CLIP uses a ViT for images and a Transformer text encoder; CLIP itself is plain contrastive learning, so the per-modality encoders matter, and the two embeddings are compared directly by cosine similarity). The image goes through CLIP's image encoder to produce an image embedding, and this image embedding serves as the ground truth.

The text embedding is fed into the first model, the prior. Here it is a diffusion model (an autoregressive transformer can also be used). The prior outputs an image embedding, which is supervised against the image embedding produced by CLIP, so this stage is effectively a supervised model. After it comes the decoder module. In the original DALL·E, the encoder and decoder were trained together inside a dVAE, but here the decoder is trained separately and is also a diffusion model. In other words, below the dashed line the generative model splits a single end-to-end generation step into an explicit two-stage pipeline, and the authors show experimentally that this explicit generation works better. The paper calls the model unCLIP: CLIP maps input text and images to features, while DALL·E 2 runs the other way, mapping a text feature to an image feature and then to an image; the image-feature-to-image step is itself a diffusion model.

The decoder uses both classifier-free guidance and CLIP guidance. Guidance here means: during decoding, the input at step t is a noisy image and the final output is an image; the feature map produced by the U-Net at each step can be scored by an image classifier (typically with a cross-entropy loss), and the gradient of that classifier score with respect to the noisy input is used to steer the diffusion toward a better decode. That describes the classifier-guided form; classifier-free guidance instead drops the conditioning at random during training and, at sampling time, blends the conditional and unconditional predictions, as sketched below.
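To make the guidance idea concrete, here is a minimal sketch of the classifier-free variant at a single denoising step, assuming a generic unet callable; the argument names are illustrative and not the exact dalle2-pytorch signature.

import torch

def guided_noise_pred(unet, x_t, t, image_embed, cond_scale=3.0):
    # run the U-Net once with the conditioning and once with it dropped ("null" branch)
    eps_cond = unet(x_t, t, image_embed=image_embed)
    eps_uncond = unet(x_t, t, image_embed=None)
    # blend the two predictions: push further in the direction the condition suggests
    return eps_uncond + cond_scale * (eps_cond - eps_uncond)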
3. Code
GitHub - lucidrains/DALLE2-pytorch: Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
The core is to train a prior model and a decoder model, both of which are diffusion models here. The prior can instead be an autoregressive model, which is essentially the original DALL·E approach. CLIP itself is just any pretrained checkpoint you choose; its role is simply to provide good image and text embeddings. A minimal wiring sketch follows.
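As a concrete picture of the two stages, here is a minimal wiring sketch in the style of the repo's README usage; the class names come from dalle2-pytorch, but the hyperparameters and keyword arguments shown are illustrative and may differ across versions.

import torch
from dalle2_pytorch import DALLE2, DiffusionPrior, DiffusionPriorNetwork, Unet, Decoder, OpenAIClipAdapter

# frozen, pretrained CLIP that supplies both the text and image embedding spaces
clip = OpenAIClipAdapter()

# stage 1: the prior maps a CLIP text embedding to a CLIP image embedding
prior_network = DiffusionPriorNetwork(dim=512, depth=6, dim_head=64, heads=8)
diffusion_prior = DiffusionPrior(net=prior_network, clip=clip, timesteps=100, cond_drop_prob=0.2)

# stage 2: the decoder (a diffusion U-Net) maps the image embedding back to pixels
unet = Unet(dim=128, image_embed_dim=512, cond_dim=128, channels=3, dim_mults=(1, 2, 4, 8))
decoder = Decoder(unet=unet, clip=clip, timesteps=100, image_cond_drop_prob=0.1, text_cond_drop_prob=0.5)

# after both stages are trained separately, they are chained for text-to-image sampling
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)
images = dalle2(['a corgi wearing sunglasses'], cond_scale=2.0)  # classifier-free guidance scale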
The input to train_diffusion_prior is produced by img2dataset and clip-retrieval: img2dataset downloads the data, and clip-retrieval generates the img_emb/text_emb embeddings and the meta_url metadata that prior training needs. train_decoder instead trains on the tar files produced by img2dataset; the tars only need to contain the images, since dalle2-pytorch checks the inputs and uses a built-in CLIP to extract features from them. In practice the prior and the generative decoder, being explicitly separated, are trained independently. A sketch of the expected embedding layout follows.
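For reference, a minimal sketch of what the prior-training inputs look like on disk, assuming the usual clip-retrieval embedding layout; the folder and file names here are assumptions for illustration, not guaranteed by the tools.

import numpy as np
import pandas as pd

# assumed clip-retrieval output layout (illustrative paths):
#   embeddings/img_emb/img_emb_0.npy        -- CLIP image embeddings, shape (N, D)
#   embeddings/text_emb/text_emb_0.npy      -- CLIP text embeddings,  shape (N, D)
#   embeddings/metadata/metadata_0.parquet  -- captions, urls, etc.
img_emb = np.load("embeddings/img_emb/img_emb_0.npy")
text_emb = np.load("embeddings/text_emb/text_emb_0.npy")
meta = pd.read_parquet("embeddings/metadata/metadata_0.parquet")

assert img_emb.shape == text_emb.shape  # one image/text embedding pair per caption
print(img_emb.shape, meta.columns.tolist())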
diffusion_prior:
DiffusionPrior(
(noise_scheduler): NoiseScheduler()
(clip): OpenAIClipAdapter(
(clip): CLIP(
(visual): VisionTransformer(
(conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(1): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(2): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(3): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(4): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(5): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(6): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(7): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(8): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(9): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(10): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(11): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(12): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(13): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(14): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(15): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(16): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(17): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(18): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(19): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(20): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(21): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(22): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(23): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
)
)
(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(1): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(2): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(3): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(4): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(5): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(6): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(7): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(8): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(9): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(10): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(11): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(token_embedding): Embedding(49408, 768)
(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
(net): DiffusionPriorNetwork(
(to_text_embeds): Sequential(
(0): Identity()
(1): Rearrange('b (n d) -> b n d', n=1)
)
(to_time_embeds): Sequential(
(0): Embedding(1000, 768)
(1): Rearrange('b (n d) -> b n d', n=1)
)
(to_image_embeds): Sequential(
(0): Identity()
(1): Rearrange('b (n d) -> b n d', n=1)
)
(causal_transformer): CausalTransformer(
(init_norm): Identity()
(rel_pos_bias): RelPosBias(
(relative_attention_bias): Embedding(32, 12)
)
(layers): ModuleList(
(0): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(1): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(2): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(3): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(4): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(5): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(6): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(7): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(8): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(9): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(10): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
(11): ModuleList(
(0): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.05, inplace=False)
(to_q): Linear(in_features=768, out_features=768, bias=False)
(to_kv): Linear(in_features=768, out_features=128, bias=False)
(rotary_emb): RotaryEmbedding()
(to_out): Sequential(
(0): Linear(in_features=768, out_features=768, bias=False)
(1): LayerNorm()
)
)
(1): Sequential(
(0): LayerNorm()
(1): Linear(in_features=768, out_features=6144, bias=False)
(2): SwiGLU()
(3): LayerNorm()
(4): Dropout(p=0.05, inplace=False)
(5): Linear(in_features=3072, out_features=768, bias=False)
)
)
)
(norm): LayerNorm()
(project_out): Linear(in_features=768, out_features=768, bias=False)
)
)
)
The prior's denoising network is a causal (autoregressive-style) transformer wrapped in a diffusion prior; the training flow is as follows (a minimal sketch of the loss comes after the trace):
TrainDiffusionPriorConfig.from_json_path() -> prior/data/train/tracker (saved) -> train: DiffusionPriorTrainer = make_model() -> (prior_config) DiffusionPriorConfig / DiffusionPriorTrainConfig -> diffusion_prior = prior_config.create() -> clip = clip.create() -> AdapterConfig.create() -> OpenAIClipAdapter() -> diffusion_prior_network = net.create() -> DiffusionPriorNetworkConfig.create() -> DiffusionPriorNetwork -> trainer = DiffusionPriorTrainer -> tracker = create_tracker() -> TrackerConfig.create()
-> img_reader = get_reader() (the inputs here are the three sources img_url / text_url / meta_url) -> image_reader = EmbeddingReader() / text_reader = EmbeddingReader() -> train_loader / eval_loader / test_loader = make_splits
-> train: DiffusionPriorTrainer / Tracker / DiffusionPriorTrainConfig -> img: 16x768, txt: 16x77 -> DiffusionPrior.forward: net: DiffusionPriorNetwork -> image_embed, _ = self.clip.embed_image() / text_embed, text_encodings = self.clip.embed_text() -> times = self.noise_scheduler.sample_random_times -> self.p_losses -> image_embed_noisy = self.noise_scheduler.q_sample (NoiseScheduler) -> pred = self.net(image_embed_noisy: 16x768, text_cond: text_embed 16x768 / text_encodings 16x77x768) -> DiffusionPriorNetwork.forward() -> image_embed: 16x768, text_embed: 16x768 -> tokens: 16x81x768 -> pred_image_embed: 16x768 -> target = noise: 16x768 -> loss = self.noise_scheduler.loss_fn (l2: mse) -> trainer.update
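To make the trace concrete, here is a minimal sketch of the loss computed inside p_losses as described above; clip, net and noise_scheduler stand in for the repo's components, and the net forward signature is simplified.

import torch
import torch.nn.functional as F

def prior_loss(net, clip, noise_scheduler, images, text_tokens):
    # CLIP provides the regression target (image embedding) and the text conditioning
    image_embed, _ = clip.embed_image(images)                  # (16, 768)
    text_embed, text_encodings = clip.embed_text(text_tokens)  # (16, 768), (16, 77, 768)

    # forward diffusion on the *embedding*: sample a random timestep and add noise
    times = noise_scheduler.sample_random_times(image_embed.shape[0])
    noise = torch.randn_like(image_embed)
    image_embed_noisy = noise_scheduler.q_sample(image_embed, times, noise=noise)

    # the causal-transformer network predicts the noise from the noisy embedding,
    # conditioned on the timestep and the text embedding/encodings
    pred = net(image_embed_noisy, times, text_embed=text_embed, text_encodings=text_encodings)

    # l2 / MSE against the injected noise, matching target = noise in the trace
    return F.mse_loss(pred, noise)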
decoder:
Decoder(
(clip): OpenAIClipAdapter(
(clip): CLIP(
(visual): VisionTransformer(
(conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
(ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(1): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(2): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(3): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(4): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(5): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(6): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(7): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(8): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(9): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(10): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(11): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(12): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(13): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(14): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(15): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(16): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(17): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(18): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(19): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(20): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(21): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(22): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(23): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
)
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=1024, out_features=4096, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=4096, out_features=1024, bias=True)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
)
)
(ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(transformer): Transformer(
(resblocks): Sequential(
(0): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(1): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(2): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(3): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(4): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(5): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(6): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(7): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(8): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(9): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(10): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(11): ResidualAttentionBlock(
(attn): MultiheadAttention(
(out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
)
(ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(mlp): Sequential(
(c_fc): Linear(in_features=768, out_features=3072, bias=True)
(gelu): QuickGELU()
(c_proj): Linear(in_features=3072, out_features=768, bias=True)
)
(ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(token_embedding): Embedding(49408, 768)
(ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
(clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
(unets): ModuleList(
(0): Unet(
(init_conv): CrossEmbedLayer(
(convs): ModuleList(
(0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): Conv2d(3, 4, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
(2): Conv2d(3, 4, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7))
)
)
(to_time_hiddens): Sequential(
(0): SinusoidalPosEmb()
(1): Linear(in_features=16, out_features=64, bias=True)
(2): GELU()
)
(to_time_tokens): Sequential(
(0): Linear(in_features=64, out_features=32, bias=True)
(1): Rearrange('b (r d) -> b r d', r=2)
)
(to_time_cond): Sequential(
(0): Linear(in_features=64, out_features=64, bias=True)
)
(image_to_tokens): Identity()
(norm_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(norm_mid_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(downs): ModuleList(
(0): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): Identity()
(4): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(1): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=16, out_features=64, bias=False)
(to_kv): Linear(in_features=16, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=16, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(2): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=64, bias=False)
(to_kv): Linear(in_features=32, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(3): ModuleList(
(0): None
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
)
(3): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=64, bias=False)
(to_kv): Linear(in_features=64, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
)
(4): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(ups): ModuleList(
(0): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=64, bias=False)
(to_kv): Linear(in_features=128, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=128, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 64, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=64, out_features=64, bias=False)
(to_kv): Linear(in_features=64, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=64, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(2): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=64, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 32, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=32, out_features=64, bias=False)
(to_kv): Linear(in_features=32, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=32, bias=False)
(1): LayerNorm()
)
)
)
)
(3): PixelShuffleUpsample(
(net): Sequential(
(0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
(1): SiLU()
(2): PixelShuffle(upscale_factor=2)
)
)
)
(3): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(1): ModuleList(
(0): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
)
(2): Identity()
(3): Identity()
)
)
(mid_block1): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(mid_attn): EinopsToAndFrom(
(fn): Residual(
(fn): Attention(
(norm): LayerNorm()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=64, bias=False)
(to_kv): Linear(in_features=128, out_features=32, bias=False)
(to_out): Sequential(
(0): Linear(in_features=64, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
)
(mid_block2): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=256, bias=True)
)
(cross_attn): EinopsToAndFrom(
(fn): CrossAttention(
(norm): LayerNorm()
(norm_context): Identity()
(dropout): Dropout(p=0.0, inplace=False)
(to_q): Linear(in_features=128, out_features=512, bias=False)
(to_kv): Linear(in_features=16, out_features=1024, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=128, bias=False)
(1): LayerNorm()
)
)
)
(block1): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 128, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Identity()
)
(upsample_combiner): UpsampleCombiner()
(final_resnet_block): ResnetBlock(
(time_mlp): Sequential(
(0): SiLU()
(1): Linear(in_features=64, out_features=32, bias=True)
)
(block1): Block(
(project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(block2): Block(
(project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): GroupNorm(8, 16, eps=1e-05, affine=True)
(act): SiLU()
)
(res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
)
(to_out): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
(vaes): ModuleList(
(0): NullVQGanVAE()
)
(noise_schedulers): ModuleList(
(0): NoiseScheduler()
)
(lowres_conds): ModuleList(
(0): None
)
)
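The decoder Unet above is built from two repeating pieces: a ResnetBlock (time_mlp + optional cross-attention + block1/block2 + a 1x1 res_conv on the skip path) and a PixelShuffleUpsample between resolution stages; the Attention/CrossAttention modules are wrapped in EinopsToAndFrom, which just flattens the spatial dims into a sequence for attention and restores them afterwards. Below is a simplified sketch of these two building blocks to make the dump easier to read; it is not the exact dalle2_pytorch code (the real Block applies the scale-shift inside itself, and the cross-attention wrapper is omitted here), just an approximation of the shapes seen above.

import torch
from torch import nn

class Block(nn.Module):
    # Conv 3x3 -> GroupNorm -> SiLU, e.g. Conv2d(192, 128) + GroupNorm(8, 128) + SiLU in the dump
    def __init__(self, dim_in, dim_out, groups=8):
        super().__init__()
        self.project = nn.Conv2d(dim_in, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()
    def forward(self, x):
        return self.act(self.norm(self.project(x)))

class ResnetBlock(nn.Module):
    # time_mlp maps the 64-dim time embedding to 2 * dim_out (e.g. Linear(64, 256) for dim_out=128),
    # split into a scale and a shift that modulate the features between block1 and block2
    def __init__(self, dim_in, dim_out, time_dim=64, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, dim_out * 2))
        self.block1 = Block(dim_in, dim_out, groups)
        self.block2 = Block(dim_out, dim_out, groups)
        # 1x1 conv on the skip path when the channel count changes, otherwise Identity
        self.res_conv = nn.Conv2d(dim_in, dim_out, 1) if dim_in != dim_out else nn.Identity()
    def forward(self, x, time_emb):
        scale, shift = self.time_mlp(time_emb)[:, :, None, None].chunk(2, dim=1)
        h = self.block1(x)
        h = h * (scale + 1) + shift
        h = self.block2(h)
        return h + self.res_conv(x)

class PixelShuffleUpsample(nn.Module):
    # 1x1 conv to 4 * dim_out channels, then PixelShuffle(2) trades those channels for 2x resolution
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim_in, dim_out * 4, 1),
            nn.SiLU(),
            nn.PixelShuffle(2),
        )
    def forward(self, x):
        return self.net(x)

# shape check against the dump: 192 -> 128 at constant resolution, then 128 -> 64 at 2x resolution
block = ResnetBlock(192, 128)
up = PixelShuffleUpsample(128, 64)
x = torch.randn(4, 192, 32, 32)
t = torch.randn(4, 64)
h = block(x, t)              # (4, 128, 32, 32)
print(h.shape, up(h).shape)  # (4, 128, 32, 32) (4, 64, 64, 64)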
train_decoder call path (decoder training trace):
TrainDecoderConfig -> DecoderConfig / DecoderDataConfig / DecoderTrainConfig / DecoderEvaluateConfig / TrackerConfig
dataloader = create_dataloaders -> create_image_embedding_dataloader
decoder = config.decoder.create() -> DecoderConfig.create() -> Unet (from the unet configs) -> clip = clip.create() -> OpenAIClipAdapter -> Decoder
tracker = create_tracker
train: DecoderTrainer, with a batch of img: 4,3,224,224 and txt: cat/sea/tree/motel
DecoderTrainer.forward() -> self.decoder -> Decoder.forward() -> resize_image_to: image 4,3,64,64 -> image = vae.encode(image)
p_losses -> x_noisy = noise_scheduler.q_sample() -> model_output = unet(): 4,3,64,64
target: 4,3,64,64, pred: 4,3,64,64 -> loss = noise_scheduler.loss_fn (l2, i.e. mse)
mean, var = noise_scheduler.q_posterior() -> kl = normal_kl -> total loss = loss + vb_loss
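The same flow can be reproduced directly with the public dalle2_pytorch API instead of the config-file driven train_decoder script. A minimal sketch follows, with hyperparameters read off the dump (base width 16, dim_mults (1,2,4,8) giving the 16/32/64/128 stages, 64x64 target resolution); argument names follow the repo README and may differ slightly between versions, and the text conditioning from the txt batch is omitted for brevity since the frozen CLIP derives the image embedding from the input images.

import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter, DecoderTrainer

# frozen CLIP; ViT-L/14 matches the 1024-wide vision transformer shown earlier (embed dim 768)
clip = OpenAIClipAdapter('ViT-L/14')

unet = Unet(
    dim = 16,                 # base channel width, multiplied up to 128 by dim_mults
    image_embed_dim = 768,    # CLIP image embedding dimension for ViT-L/14
    cond_dim = 16,            # context dim of the cross-attention (to_kv in_features=16 in the dump)
    channels = 3,
    dim_mults = (1, 2, 4, 8),
)

decoder = Decoder(
    unet = unet,
    clip = clip,
    image_sizes = (64,),      # resize_image_to -> 4,3,64,64 in the trace above
    timesteps = 1000,
)

trainer = DecoderTrainer(decoder, lr = 3e-4, wd = 1e-2)

images = torch.randn(4, 3, 224, 224)  # stands in for the img 4,3,224,224 batch in the trace
# forward: clip.embed_image -> q_sample adds noise -> unet predicts -> l2/mse (+ variational bound)
loss = trainer(images, unet_number = 1)
trainer.update(unet_number = 1)
print(loss)

The trainer handles the backward pass and the EMA copy of the unet internally; calling update() steps the optimizer for the chosen unet, which is why the README-style loop is just forward, then update, per unet.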