【人工智能笔记】第三十四节:TF2实现VITGAN对抗生成网络,PositionalEmbedding 实现

【人工智能笔记】第三十四节:TF2实现VITGAN对抗生成网络,PositionalEmbedding 实现_第1张图片
该章节介绍VITGAN对抗生成网络中,PositionalEmbedding 部分的代码实现。

目录(文章发布后会补上链接):

  1. 网络结构简介
  2. Mapping NetWork 实现
  3. PositionalEmbedding 实现
  4. MLP 实现
  5. MSA多头注意力 实现
  6. SLN自调制 实现
  7. CoordinatesPositionalEmbedding 实现
  8. ModulatedLinear 实现
  9. Siren 实现
  10. Generator生成器 实现
  11. PatchEmbedding 实现
  12. ISN 实现
  13. Discriminator鉴别器 实现
  14. VITGAN 实现

PositionalEmbedding 简介

【人工智能笔记】第三十四节:TF2实现VITGAN对抗生成网络,PositionalEmbedding 实现_第2张图片
PositionalEmbedding 就是图中1-N的位置编码,根据论文中描述,位置编号1-N(N为图片块数量),经过全连接层映射,再用sin函数约束值的范围[-1,1]。

代码实现

import tensorflow as tf


class PositionalEmbedding(tf.Module):
    """
    输入位置编码
    """

    def __init__(
        self,
        sequence_length,
        emb_dim,
        name=None,
    ):
        super().__init__(name=name)
        self.emb_dim = emb_dim
        self.sequence_length = sequence_length
        self.pos_emb = tf.keras.layers.Dense(emb_dim, use_bias=False)
        self.pos_input = tf.linspace(-1, 1, sequence_length)[tf.newaxis, :, tf.newaxis]

    def __call__(self):
        x = self.pos_emb(self.pos_input)
        x = tf.math.sin(x)
        return x


if __name__ == "__main__":
    layer = PositionalEmbedding(
        sequence_length=196,
        emb_dim=768
    )
    o1 = layer()
    tf.print('o1:', tf.shape(o1))

参考资料:

  • 论文地址:VITGAN: Training GANs with Vision Transformers
  • 源码:https://github.com/tfwcn/VITGAN-tf2

你可能感兴趣的:(深度学习,人工智能,tensorflow,深度学习,VITGAN,位置编码)