dalle2: hierarchical text-conditional image generation with clip latents

Reference video: DALL·E 2【论文精读】 by 跟李沐学AI on bilibili, https://www.bilibili.com/video/BV17r4y1u77B; more paper-reading notes at https://github.com/mli/paper-reading. Most of the DALL·E 2 explanations floating around are not very clear. The three main directions in generative modeling are VAEs, GANs, and diffusion models; one lineage runs AE -> DAE -> VAE -> VQ-VAE -> diffusion, and within diffusion models the line goes DDPM -> Improved DDPM -> Diffusion Models Beat GANs -> GLIDE -> DALL·E 2.

1.introduction

        CLIP is robust to shifts in image distribution and supports zero-shot transfer, while diffusion models deliver good sample diversity with decent fidelity. DALL·E 2 combines the strengths of both.

2.method

[Figure 1: DALL·E 2 (unCLIP) overview. The frozen CLIP sits above the dashed line; the prior and decoder sit below.]

The figure above is worth walking through. Everything above the dashed line is CLIP, which is pretrained ahead of time; during DALL·E 2's training CLIP is never updated, its weights are frozen. The training input is again a pair, a caption and its corresponding image. The caption goes through CLIP's text encoder (CLIP encodes images with a ViT and text with a Transformer text encoder; it is trained with plain contrastive learning, so the two modality encoders are the key pieces, and after encoding the two vectors are compared directly by cosine similarity). The image goes through CLIP's image encoder, producing an image embedding, and that image embedding is effectively the ground truth.

The text embedding is fed into the first model, the prior. Here it is a diffusion model (an autoregressive Transformer also works). The prior outputs an image embedding, which is supervised against the image embedding produced by CLIP, so this stage is really a supervised model. After it comes a decoder module. In the original DALL·E the encoder and decoder were trained together inside a dVAE, but here the decoder is trained on its own and is also a diffusion model. In other words, the generative model below the dashed line turns a single end-to-end generation step into an explicit two-stage generation, and the authors' experiments show this explicit factorization generates better images. The paper calls itself unCLIP: CLIP turns input text and images into features, while DALL·E 2 goes from a text feature to an image feature and then to an image, with the feature-to-image step implemented by a diffusion model.

The decoder uses both classifier-free guidance and CLIP guidance. Guidance here means that during decoding the input at step t is a noisy image and the final output is an image; the feature map the U-Net produces at each step can be scored by an image classifier (usually a simple binary classification with a cross-entropy loss), and the gradient from that classifier is used to steer the diffusion toward a better decode. Strictly speaking this gradient-based steering is classifier (or CLIP) guidance; classifier-free guidance instead randomly drops the conditioning during training and extrapolates between the conditional and unconditional predictions at sampling time.
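
To make the two-stage pipeline concrete, here is a minimal sampling sketch in the spirit of the lucidrains/DALLE2-pytorch README (discussed in the code section below). diffusion_prior and decoder are assumed to be already-trained instances, and the exact argument names may vary between library versions.

from dalle2_pytorch import DALLE2

# text -> CLIP text embedding -> prior -> image embedding -> decoder -> image
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

images = dalle2(
    ['a corgi playing a flame throwing trumpet'],
    cond_scale=2.0,  # classifier-free guidance strength used by the decoder
)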

3.code

GitHub - lucidrains/DALLE2-pytorch: Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

The core is to train a prior model and a decoder model, both diffusion models; the prior could also be an autoregressive model, in which case you are back to the DALL·E 1 approach. For CLIP you simply pick a pretrained checkpoint, since its only role here is to provide good image and text embeddings.
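
As an illustration, a minimal construction sketch roughly following the lucidrains/DALLE2-pytorch README: a frozen pretrained CLIP is wrapped in OpenAIClipAdapter, and the prior and decoder are built around it as two separate diffusion models. Argument names and values below follow the README from memory and may differ between library versions, so treat them as placeholders rather than the definitive API.

from dalle2_pytorch import OpenAIClipAdapter, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder

# pretrained, frozen CLIP; its only job is to supply image/text embeddings
clip = OpenAIClipAdapter('ViT-L/14')  # matches the ViT-L/14 printout below

# prior: a diffusion model whose denoising net is a causal Transformer
prior_network = DiffusionPriorNetwork(dim=768, depth=6, dim_head=64, heads=8)
diffusion_prior = DiffusionPrior(net=prior_network, clip=clip, timesteps=1000, cond_drop_prob=0.2)

# decoder: a separate diffusion model (U-Net) conditioned on the image embedding
unet = Unet(dim=128, image_embed_dim=768, cond_dim=128, channels=3, dim_mults=(1, 2, 4, 8))
decoder = Decoder(unet=unet, clip=clip, timesteps=1000, image_cond_drop_prob=0.1, text_cond_drop_prob=0.5)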

The input to train_diffusion_prior is data that has gone through img2dataset and clip-retrieval: img2dataset downloads the images, and clip-retrieval produces the img_emb / text_emb files plus the meta_url metadata needed to train the prior. train_decoder instead takes the tar shards produced by img2dataset; the tars only need to contain images, because dalle2-pytorch checks the inputs and runs a built-in CLIP to extract the embeddings. In practice the prior and the generative decoder, being explicitly separated modules, are trained independently (see the sketch below for reading the precomputed embeddings).
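
For illustration, a minimal sketch of consuming such precomputed embeddings. The shard paths are hypothetical placeholders, since the exact file layout depends on how clip-retrieval was invoked.

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# hypothetical shard names; clip-retrieval writes numbered .npy embedding shards
img_emb = np.load("embeddings/img_emb/img_emb_0.npy")      # shape (N, 768) for ViT-L/14
text_emb = np.load("embeddings/text_emb/text_emb_0.npy")   # shape (N, 768)

dataset = TensorDataset(torch.from_numpy(img_emb).float(),
                        torch.from_numpy(text_emb).float())
loader = DataLoader(dataset, batch_size=16, shuffle=True)

image_embed, text_embed = next(iter(loader))
print(image_embed.shape, text_embed.shape)  # torch.Size([16, 768]) each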

diffusion_prior

DiffusionPrior(
  (noise_scheduler): NoiseScheduler()
  (clip): OpenAIClipAdapter(
    (clip): CLIP(
      (visual): VisionTransformer(
        (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (transformer): Transformer(
          (resblocks): Sequential(
            (0): ResidualAttentionBlock(
              (attn): MultiheadAttention(
                (out_proj): _LinearWithBias(in_features=1024, out_features=1024, bias=True)
              )
              (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (mlp): Sequential(
                (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
                (gelu): QuickGELU()
                (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
              )
              (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            )
            ... blocks (1)-(23): 23 more ResidualAttentionBlock modules identical to (0) ...
          )
        )
        (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): _LinearWithBias(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          ... blocks (1)-(11): 11 more ResidualAttentionBlock modules identical to (0) ...
        )
      )
      (token_embedding): Embedding(49408, 768)
      (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (clip_normalize): Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
  )
  (net): DiffusionPriorNetwork(
    (to_text_embeds): Sequential(
      (0): Identity()
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (to_time_embeds): Sequential(
      (0): Embedding(1000, 768)
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (to_image_embeds): Sequential(
      (0): Identity()
      (1): Rearrange('b (n d) -> b n d', n=1)
    )
    (causal_transformer): CausalTransformer(
      (init_norm): Identity()
      (rel_pos_bias): RelPosBias(
        (relative_attention_bias): Embedding(32, 12)
      )
      (layers): ModuleList(
        (0): ModuleList(
          (0): Attention(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.05, inplace=False)
            (to_q): Linear(in_features=768, out_features=768, bias=False)
            (to_kv): Linear(in_features=768, out_features=128, bias=False)
            (rotary_emb): RotaryEmbedding()
            (to_out): Sequential(
              (0): Linear(in_features=768, out_features=768, bias=False)
              (1): LayerNorm()
            )
          )
          (1): Sequential(
            (0): LayerNorm()
            (1): Linear(in_features=768, out_features=6144, bias=False)
            (2): SwiGLU()
            (3): LayerNorm()
            (4): Dropout(p=0.05, inplace=False)
            (5): Linear(in_features=3072, out_features=768, bias=False)
          )
        )
        ... layers (1)-(11): 11 more ModuleList layers identical to (0) ...
      )
      (norm): LayerNorm()
      (project_out): Linear(in_features=768, out_features=768, bias=False)
    )
  )
)

The prior is a diffusion model whose denoising network is a causal (autoregressive-style) Transformer; the training flow is as follows:

TrainDiffusionPriorConfig.from_json_path() -> prior / data / train / tracker configs (saved)
train: DiffusionPriorTrainer = make_model() -> (prior_config) DiffusionPriorConfig / DiffusionPriorTrainConfig
diffusion_prior = prior_config.create() -> clip = clip.create() -> AdapterConfig.create() -> OpenAIClipAdapter()
diffusion_prior_network = net.create() -> DiffusionPriorNetworkConfig.create() -> DiffusionPriorNetwork
trainer = DiffusionPriorTrainer -> tracker = create_tracker() -> TrackerConfig.create()
img_reader = get_reader(), whose inputs are the three sets img_url / text_url / meta_url -> image_reader = EmbeddingReader() / text_reader = EmbeddingReader()
train_loader / eval_loader / test_loader = make_splits
train: DiffusionPriorTrainer / Tracker / DiffusionPriorTrainConfig, with batches img: 16x768, txt: 16x77
DiffusionPrior.forward (net: DiffusionPriorNetwork) -> image_embed, _ = self.clip.embed_image() / text_embed, text_encodings = self.clip.embed_text()
times = self.noise_scheduler.sample_random_times() -> self.p_losses -> image_embed_noisy = self.noise_scheduler.q_sample() (NoiseScheduler)
pred = self.net(image_embed_noisy: 16x768, text_cond: text_embed 16x768 / text_encodings 16x77x768) -> DiffusionPriorNetwork.forward() -> image_embed: 16x768, text_embed: 16x768 -> tokens: 16x81x768 -> pred_image_embed: 16x768
target = noise: 16x768 -> loss = self.noise_scheduler.loss_fn (l2: MSE) -> trainer.update()
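
The core of that trace is the p_losses step. Below is a toy, self-contained sketch of the same computation; the tiny MLP stands in for the causal-Transformer DiffusionPriorNetwork and the noise schedule values are made up, so this only mirrors the shapes and the loss, not the real network.

import torch
import torch.nn.functional as F

B, D, T = 16, 768, 1000  # batch, CLIP embedding dim, diffusion timesteps (as in the trace)

# stand-ins for self.clip.embed_image(...) / self.clip.embed_text(...)
image_embed = torch.randn(B, D)   # ground-truth CLIP image embedding
text_embed = torch.randn(B, D)    # CLIP text embedding used as conditioning

# toy denoising net in place of DiffusionPriorNetwork (a causal Transformer in the real code)
net = torch.nn.Sequential(torch.nn.Linear(2 * D + 1, 1024), torch.nn.GELU(), torch.nn.Linear(1024, D))

# linear beta schedule -> cumulative alpha_bar, as in NoiseScheduler
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

times = torch.randint(0, T, (B,))                 # sample_random_times
noise = torch.randn_like(image_embed)             # target = noise, as in the trace
a = alpha_bar[times].unsqueeze(-1)
image_embed_noisy = a.sqrt() * image_embed + (1 - a).sqrt() * noise   # q_sample

# predict the noise from (noisy image embed, text embed, timestep)
inp = torch.cat([image_embed_noisy, text_embed, times.float().unsqueeze(-1) / T], dim=-1)
pred = net(inp)

loss = F.mse_loss(pred, noise)                    # l2 loss_fn in the trace
loss.backward()                                   # trainer.update would then step the optimizer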

decoder:

Decoder(
  (clip): OpenAIClipAdapter( ... identical to the OpenAIClipAdapter (CLIP ViT-L/14) printed in the DiffusionPrior dump above ... )
  (unets): ModuleList(
    (0): Unet(
      (init_conv): CrossEmbedLayer(
        (convs): ModuleList(
          (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): Conv2d(3, 4, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (2): Conv2d(3, 4, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7))
        )
      )
      (to_time_hiddens): Sequential(
        (0): SinusoidalPosEmb()
        (1): Linear(in_features=16, out_features=64, bias=True)
        (2): GELU()
      )
      (to_time_tokens): Sequential(
        (0): Linear(in_features=64, out_features=32, bias=True)
        (1): Rearrange('b (r d) -> b r d', r=2)
      )
      (to_time_cond): Sequential(
        (0): Linear(in_features=64, out_features=64, bias=True)
      )
      (image_to_tokens): Identity()
      (norm_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (norm_mid_cond): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (downs): ModuleList(
        (0): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): Identity()
          (4): Conv2d(16, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (1): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=16, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=16, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=16, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=16, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=16, out_features=64, bias=False)
                (to_kv): Linear(in_features=16, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=16, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (2): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=64, bias=True)
            )
            (block1): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=64, bias=False)
                (to_kv): Linear(in_features=32, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        )
        (3): ModuleList(
          (0): None
          (1): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=128, bias=True)
            )
            (block1): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Identity()
          )
          (2): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Identity()
            )
          )
          (3): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=64, bias=False)
                (to_kv): Linear(in_features=64, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (4): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (ups): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=256, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=128, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=128, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=256, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=128, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=128, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=256, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=128, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=128, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=128, out_features=64, bias=False)
                (to_kv): Linear(in_features=128, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=128, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (1): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=128, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=128, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=64, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=64, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=64, out_features=64, bias=False)
                (to_kv): Linear(in_features=64, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=64, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (2): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=64, bias=True)
            )
            (cross_attn): EinopsToAndFrom(
              (fn): CrossAttention(
                (norm): LayerNorm()
                (norm_context): Identity()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=512, bias=False)
                (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=512, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
            (block1): Block(
              (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=64, bias=True)
              )
              (cross_attn): EinopsToAndFrom(
                (fn): CrossAttention(
                  (norm): LayerNorm()
                  (norm_context): Identity()
                  (dropout): Dropout(p=0.0, inplace=False)
                  (to_q): Linear(in_features=32, out_features=512, bias=False)
                  (to_kv): Linear(in_features=16, out_features=1024, bias=False)
                  (to_out): Sequential(
                    (0): Linear(in_features=512, out_features=32, bias=False)
                    (1): LayerNorm()
                  )
                )
              )
              (block1): Block(
                (project): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 32, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): EinopsToAndFrom(
            (fn): Residual(
              (fn): Attention(
                (norm): LayerNorm()
                (dropout): Dropout(p=0.0, inplace=False)
                (to_q): Linear(in_features=32, out_features=64, bias=False)
                (to_kv): Linear(in_features=32, out_features=32, bias=False)
                (to_out): Sequential(
                  (0): Linear(in_features=64, out_features=32, bias=False)
                  (1): LayerNorm()
                )
              )
            )
          )
          (3): PixelShuffleUpsample(
            (net): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
              (1): SiLU()
              (2): PixelShuffle(upscale_factor=2)
            )
          )
        )
        (3): ModuleList(
          (0): ResnetBlock(
            (time_mlp): Sequential(
              (0): SiLU()
              (1): Linear(in_features=64, out_features=32, bias=True)
            )
            (block1): Block(
              (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (block2): Block(
              (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
              (act): SiLU()
            )
            (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): ModuleList(
            (0): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
            (1): ResnetBlock(
              (time_mlp): Sequential(
                (0): SiLU()
                (1): Linear(in_features=64, out_features=32, bias=True)
              )
              (block1): Block(
                (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (block2): Block(
                (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
                (act): SiLU()
              )
              (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (2): Identity()
          (3): Identity()
        )
      )
      (mid_block1): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=256, bias=True)
        )
        (cross_attn): EinopsToAndFrom(
          (fn): CrossAttention(
            (norm): LayerNorm()
            (norm_context): Identity()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=512, bias=False)
            (to_kv): Linear(in_features=16, out_features=1024, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
        (block1): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (mid_attn): EinopsToAndFrom(
        (fn): Residual(
          (fn): Attention(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=64, bias=False)
            (to_kv): Linear(in_features=128, out_features=32, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=64, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
      )
      (mid_block2): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=256, bias=True)
        )
        (cross_attn): EinopsToAndFrom(
          (fn): CrossAttention(
            (norm): LayerNorm()
            (norm_context): Identity()
            (dropout): Dropout(p=0.0, inplace=False)
            (to_q): Linear(in_features=128, out_features=512, bias=False)
            (to_kv): Linear(in_features=16, out_features=1024, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=128, bias=False)
              (1): LayerNorm()
            )
          )
        )
        (block1): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (upsample_combiner): UpsampleCombiner()
      (final_resnet_block): ResnetBlock(
        (time_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=64, out_features=32, bias=True)
        )
        (block1): Block(
          (project): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): Block(
          (project): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 16, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (to_out): Conv2d(16, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (vaes): ModuleList(
    (0): NullVQGanVAE()
  )
  (noise_schedulers): ModuleList(
    (0): NoiseScheduler()
  )
  (lowres_conds): ModuleList(
    (0): None
  )
)
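The module tree that closes here is a decoder in the dalle2-pytorch sense (note the (unets), (vaes): NullVQGanVAE, (noise_schedulers) and (lowres_conds) attributes): a frozen CLIP adapter (the 12-layer, 768-wide text transformer with the 49408-token embedding above), one small U-Net with cross-attention over 16-dim conditioning tokens, no latent VAE (so diffusion runs directly in pixel space), and a single noise scheduler. Below is a minimal sketch of how such a decoder could be built and printed; the dimensions are illustrative, chosen only to roughly match the 16/32/64/128 channel widths in the dump, and keyword names may differ slightly across versions of the library, so treat it as an assumption rather than the exact config that produced this printout.

import torch
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

# frozen CLIP: only used to produce image/text embeddings, its weights are never updated
clip = OpenAIClipAdapter('ViT-L/14')

# a deliberately tiny U-Net; dim=16 gives channel widths similar to the printout above
unet = Unet(
    dim = 16,
    image_embed_dim = 768,   # embedding size of CLIP ViT-L/14
    cond_dim = 16,           # width of the conditioning tokens read by the cross-attention
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

# pixel-space decoder (hence NullVQGanVAE in the dump) with a single 64x64 stage
decoder = Decoder(
    unet = unet,
    clip = clip,
    image_sizes = (64,),          # assumption: keyword may vary between library versions
    timesteps = 1000,
    image_cond_drop_prob = 0.1,   # randomly dropping the image embedding enables classifier-free guidance
    text_cond_drop_prob = 0.5
)

print(decoder)   # prints a module tree like the one shown above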

The decoder training call chain, stage by stage:

TrainDecoderConfig -> DecoderConfig / DecoderDataConfig / DecoderTrainConfig / DecoderEvaluateConfig / TrackerConfig
dataloader = create_dataloaders -> create_image_embedding_dataloader
decoder = config.decoder.create() -> DecoderConfig.create() -> Unet(unet configs) -> clip = clip.create() -> OpenAIClipAdapter -> Decoder
tracker = create_tracker
training loop: DecoderTrainer, batch img: 4,3,224,224, txt: cat / sea / tree / motel
DecoderTrainer.forward() -> self.decoder -> Decoder.forward()
resize_image_to: image -> 4,3,64,64 -> image = vae.encode(image) (NullVQGanVAE, a no-op in pixel space)
p_losses: x_noisy = noise_scheduler.q_sample() -> model_output = unet(): 4,3,64,64
target: 4,3,64,64 vs pred: 4,3,64,64 -> loss = noise_scheduler.loss_fn (l2, i.e. MSE)
mean, var = noise_scheduler.q_posterior() -> kl = normal_kl -> total loss = loss + vb_loss
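Condensed, the chain above is: build the decoder from the config, wrap it in a DecoderTrainer, and repeatedly feed it image batches; the decoder resizes each image to the U-Net resolution, adds noise with q_sample at a random timestep, predicts the noise with the U-Net, and takes an l2/MSE loss plus a small variational term. A minimal sketch of one such step with the dalle2-pytorch trainer API, assuming the decoder built in the previous snippet; the random batch merely stands in for the webdataset loader, and exact trainer arguments may vary with the library version.

import torch
from dalle2_pytorch import DecoderTrainer

trainer = DecoderTrainer(
    decoder,     # the Decoder instance from the previous snippet
    lr = 3e-4,
    wd = 1e-2
)

# stand-in for one batch from create_image_embedding_dataloader (4 images, 224x224)
images = torch.randn(4, 3, 224, 224)

for step in range(2):
    # the image embedding is computed internally by the frozen CLIP when it is not passed in;
    # inside: resize to 64x64 -> q_sample adds noise -> U-Net predicts the noise -> MSE (+ vb term)
    loss = trainer(images, unet_number = 1)
    trainer.update(unet_number = 1)    # optimizer step (and EMA update) for that unet
    print(step, float(loss))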
