【实验】vit代码

这里写目录标题

  • 参考
  • 讲解一:代码+理论
  • 版本一:lucidrains
    • 1. 大佬复现版本给的使用案例
    • 2. Transformer结构
    • 3. Attention
    • 4. FeedForward
    • 5. ViT操作流程
  • 版本二 :rwightman 源码可直接运行
    • PatchEmbed模块
    • Attention模块
    • MLP Block(图中的名字)/FeedForward类(代码中的实现)
    • Encoder Block主模块
    • VisionTransformer类

参考

霹雳吧啦Wz-pytorch_classification/vision_transformer
视频:
霹雳吧啦Wz

笔记:
VIT(vision transformer)模型介绍+pytorch代码炸裂解析

Visual Transformer (ViT) 代码实现 PyTorch版本
详细—Vision Transformer——ViT代码解读

讲解一:代码+理论

很详细:理论+代码----Vision Transformer(ViT)PyTorch代码全解析(附图解)

版本一:lucidrains

使用einopseinops和einsum:直接操作张量的利器
代码:
大佬复现-pytorch版
这个版本的代码超级受欢迎且易使用,我看的时候,Git repo已经被star 5.7k次。大家直接 pip install vit-pytorch就好。
所以作为初次接触vit的同学们来说,推荐看第二个版本,结构清晰明了。
笔记:
强推——很详细!-lucidrains-版本讲解
【实验】vit代码_第1张图片

1. 大佬复现版本给的使用案例

大家完全可以把这段代码copy-paste到自己的pycharm里,然后使用调试功能,一步步看ViT的每一步操作。

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,    # 图像大小
    patch_size = 32,     # patch大小(分块的大小)
    num_classes = 1000,  # imagenet数据集1000分类
    dim = 1024,          # position embedding的维度
    depth = 6,           # encoder和decoder中block层数是6
    heads = 16,          # multi-head中head的数量为16
    mlp_dim = 2048,
    dropout = 0.1,       # 
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
print(preds.shape)  # (16, 1000)

2. Transformer结构

进行6次for循环,有6层encoder结构。
for循环内部依次使用multi-head attention和Feed Forward,
对应Transformer的Encoder结构中多头自注意力模块和MLP模块。在自注意力后及feed forward后,有残差连接。

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

PreNorm类代码如下,在使用multi-head attention和Feed Forward之前,首先对输入通过LayerNorm进行处理。

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

3. Attention

torch.chunk(tensor, chunk_num, dim)函数的功能:与torch.cat()刚好相反,它是将tensor按dim(行或列)分割成chunk_num个tensor块,返回的是一个元组。
attention操作的整体流程:

1.首先对输入生成query, key和value,这里的“输入”有可能是整个网络的输入,也可能是某个hidden layer的output。在这里,生成的qkv是个长度为3的元组,每个元组的大小为(1, 65, 1024)
2.对qkv进行处理,重新指定维度,得到的q, k, v维度均为(1, 16, 65, 64)
3.q和k做点乘,得到的dots维度为(1, 16, 65, 65)
4.对dots的最后一维做softmax,得到各个patch对其他patch的注意力得分
5.将attention和value做点乘
6.对各个维度重新排列,得到与输入相同维度的输出 (1, 65, 1024)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)    # 首先生成q,k,v
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

4. FeedForward

FeedForward模块共有2个全连接层,整个结构是:

1.首先过一个全连接层
2.经过GELU()激活函数进行处理
3.nn.Dropout(),以一定概率丢失掉一些神经元,防止过拟合
4.再过一个全连接层
5.nn.Dropout()
注意:GELU(x) = x * Φ(x), 其中,Φ(x)是高斯分布的累积分布函数 。

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),   # dim=1024, hidden_dim=2048
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

5. ViT操作流程

ViT的各个结构都写在了__init__()里,不再细讲,通过forward()来看ViT的整个前向传播过程(操作流程)。
整体流程:

1.首先对输入进来的img(256256大小),划分为3232大小的patch,共有8*8个。并将patch转换成embedding。(对应第26行代码)
2.生成cls_tokens (对应第28行代码)
3.将cls_tokens沿dim=1维与x进行拼接 (对应第29行代码)
4.生成随机的position embedding,每个embedding都是1024维 (对应代码14行和30行)
5.对输入经过Transformer进行编码(对应代码第32行)
6.如果是分类任务的话,截取第一个可学习的class embedding
7.最后过一个MLP Head用于分类。

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert  image_height % patch_height ==0 and image_width % patch_width == 0

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))					# nn.Parameter()定义可学习参数
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)        # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        b, n, _ = x.shape           # b表示batchSize, n表示每个块的空间分辨率, _表示一个块内有多少个值

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)  # self.cls_token: (1, 1, dim) -> cls_tokens: (batchSize, 1, dim)  
        x = torch.cat((cls_tokens, x), dim=1)               # 将cls_token拼接到patch token中去       (b, 65, dim)
        x += self.pos_embedding[:, :(n+1)]                  # 加位置嵌入(直接加)      (b, 65, dim)
        x = self.dropout(x)

        x = self.transformer(x)                                                 # (b, 65, dim)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]                   # (b, dim)

        x = self.to_latent(x)                                                   # Identity (b, dim)
        print(x.shape)

        return self.mlp_head(x)                                                 #  (b, num_classes)


版本二 :rwightman 源码可直接运行

代码来自于哔哩哔哩博主:霹雳吧啦Wz
提供源码的视频链接:[使用pytorch搭建Vision Transformer(vit)模型](https://www.bilibili.com/video/BV1AL411W7dT/?
笔记:【超详细】初学者包会的Vision Transformer(ViT)的PyTorch实现代码学习

【实验】vit代码_第2张图片
spm_id_from=333.788)

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import os
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

PatchEmbed模块

class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])  #grid_size=224÷16=14
        self.num_patches = self.grid_size[0] * self.grid_size[1]  
        #num_patches=14*14
        self.flatten = flatten
        
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        #proj使用卷积,embed_dimension这一参数在vision transformer的base16模型用到的是768,所以默认是768。但是如果是large或者huge模型的话embed_dim也会变。
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        #norm_layer默认是None,就是进行nn.Identity()也就是不做任何操作;如果有传入(非None),则会进行初始化一个norm_layer。
        
    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            #assert:进行判断,如果代码模型定义和实际输入尺寸不同则会报错
        x = self.proj(x)  #用卷积实现序列化
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
            #flatten(2)操作实现了[B,C,H,W,]->[B,C,HW],指从维度2开始进行展平
            #transpose(1,2)操作实现了[B,C,HW]->[B,HW,C]
        x = self.norm(x)
        #通过norm层输出
        return x

Attention模块

该模块实现多头注意力机制。

class Attention(nn.Module):
    def __init__(self, 
                 dim,   #输入token的dim
                 num_heads=8,  #多头注意力中head的个数
                 qkv_bias=False,  #在生成qkv时是否使用偏置,默认否
                 attn_drop=0., 
                 proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  #计算每一个head需要传入的dim
        self.scale = head_dim ** -0.5  #head_dim的-0.5次方,即1/根号d_k,即理论公式里的分母根号d_k
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  #qkv是通过1个全连接层参数为dim和3dim进行初始化的,也可以使用3个全连接层参数为dim和dim进行初始化,二者没有区别,
        self.attn_drop = nn.Dropout(attn_drop)#定义dp层 比率attn_drop
        self.proj = nn.Linear(dim, dim)  #再定义一个全连接层,是 将每一个head的结果进行拼接的时候乘的那个矩阵W^O
        self.proj_drop = nn.Dropout(proj_drop)#定义dp层 比率proj_drop

    def forward(self, x):#正向传播过程
    #输入是[batch_size, 
    #      num_patches+1, (base16模型的这个数是14*14)
    #      total_embed_dim(base16模型的这个数是768)]
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
#qkv->[batchsize, num_patches+1, 3*total_embed_dim]
#reshape->[batchsize, num_patches+1, 3, num_heads, embed_dim_per_head]
#permute->[3, batchsize, num_heads, num_patches+1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]
        # make torchscript happy (cannot use tensor as tuple)
#q、k、v大小均[batchsize, num_heads, num_patches+1, embed_dim_per_head]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        #现在的操作都是对每个head进行操作
        #transpose是转置最后2个维度,@就是矩阵乘法的意思
        #q  [batchsize, num_heads, num_patches+1, embed_dim_per_head]
        #k^T[batchsize, num_heads, embed_dim_per_head, num_patches+1]
        #q*k^T=[batchsize, num_heads, num_patches+1, num_patches+1]
        #self.scale=head_dim的-0.5次方
        #至此完成了(Q*K^T)/根号d_k的操作
        attn = attn.softmax(dim=-1)
        #dim=-1表示在得到的结果的每一行上进行softmax处理,-1就是最后1个维度
        #至此完成了softmax[(Q*K^T)/根号d_k]的操作
        attn = self.attn_drop(attn)
      
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        #@->[batchsize, num_heads, num_patches+1, embed_dim_per_head]
        #这一步矩阵乘积就是加权求和
        #transpose->[batchsize, num_patches+1, num_heads, embed_dim_per_head]
        #reshape->[batchsize, num_patches+1, num_heads*embed_dim_per_head]即[batchsize, num_patches+1, total_embed_dim]
        #reshape实际上就实现了concat拼接
        x = self.proj(x)
        #将上一步concat的结果通过1个线性映射,通常叫做W,此处用全连接层实现
        x = self.proj_drop(x)
        #dropout
        #至此完成了softmax[(Q*K^T)/根号d_k]*V的操作
        #一个head的attention的全部操作就实现了
        return x

MLP Block(图中的名字)/FeedForward类(代码中的实现)

class FeedForward(nn.Module):
#全连接层1+GELU+dropout+全连接层2+dropout
#全连接层1的输出节点个数是输入节点个数的4倍,即mlp_ratio=4.
#全连接层2的输入节点个数是输出节点个数的1/4
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

Encoder Block主模块

【实验】vit代码_第3张图片

class Block(nn.Module):
    def __init__(self, 
                 dim, 
                 num_heads, 
                 mlp_ratio=4., 
                 qkv_bias=False, 
                 drop=0., 
                 #多头注意力模块中的最后的全连接层之后的dropout层对应的drop比率
                 attn_drop=0.,
                 #多头注意力模块中softmax[Q*K^T/根号d_k]之后的dropout层的drop比率
                 drop_path=0.,
                 #本代码用到的是DropPath方法(上面右图的DropPath),所以上面右图的两个droppath层有这个比率
                 act_layer=nn.GELU, 
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        #第一层LN
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        #第一个多头注意力
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        #如果传入的drop_path大于0,就会实例化一个droppath方法;如果传入的drop_path等于0,则执行Identity()不做任何操作
        self.norm2 = norm_layer(dim)
        #第二个LN层
        mlp_hidden_dim = int(dim * mlp_ratio)
        #mlp层的隐层个数是输入的4倍,实例化一个MLP模块的时候需要传入mlp_hidden_dim这个参数,所以在此提前计算
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        #act_layer是激活函数

    def forward(self, x):
    #前向传播过程:
    #第一部分:LN+Mul-Head-Attention+ Dropout之后,加上第一个LN之前的输入
    #第二部分:LN+MLP+Dropout之后,加上第二个LN之前的输入
    #输出x
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

VisionTransformer类

也就是文章里的模型主体了

class VisionTransformer(nn.Module):
  def __init__(self, img_size=224, 
               patch_size=16, 
               in_chans=3, 
               num_classes=1000, 
               embed_dim=768, 
               depth=12,
               num_heads=12, 
               mlp_ratio=4., 
               qkv_bias=True, 
               representation_size=None,
#representation_size是最后的MLP Head中的pre-logits中的全连接层的节点个数,默认为None,此时就不会去构建这个pre-logits,也就是此时在MLP Head中只有一个全连接层,而没有pre-logits层。【pre-logits层是什么:就是全连接层+激活函数】
               distilled=False,
               #distilled后续的DeiT才用到这个参数
               drop_rate=0., 
               attn_drop_rate=0., 
               drop_path_rate=0., 
               embed_layer=PatchEmbed, 
               #这个参数是nn.Module类型,即模块PatchEmbed
               norm_layer=None,
               #这个参数也是nn.Module类型
               act_layer=None, 
               weight_init=''):
    super().__init__()
    self.num_classes = num_classes #复制参数
    self.num_features = self.embed_dim = embed_dim  #复制参数
    # num_features for consistency with other models
    
    self.num_tokens = 2 if distilled else 1 
    norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
    act_layer = act_layer or nn.GELU
    #因为ViT模型的distilled=False,所以前面这三句:
    #num_tokens=1
    #norm_layer=partial(nn.LayerNorm, eps=1e-6)
    #act_layer= nn.GELU
    
    self.patch_embed = embed_layer(
        img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) #对图片进行patch和embed
    num_patches = self.patch_embed.num_patches #得到patches的个数
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    #加上class_token,零矩阵初始化,尺寸1*1*embed_dim.
    #第一个1是batchsize维度,是为了后面进行拼接所以设置成1。
    #第二、三个维度就是1*768
    self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None #这一行可以直接忽略,本文(ViT)模型用不到dist_token
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
    #position embedding,使用零矩阵初始化
    #尺寸为1 *(num_patches + self.num_tokens)* embed_dim
    #第一个维度1是batchsize维度
    #第二个维度:num_tokens=1(见本段代码第29行),num_patches在base16模型中是14*14=196,加一起就是197
    #第三个维度:embed_dim
    self.pos_drop = nn.Dropout(p=drop_rate)
    
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    #根据传入的drop_path_rate参数(默认为0),for i in的语句使得每一层的drop_path层的drop比率是递增的,但是默认为0,则不创建。
    # stochastic depth decay rule(随机深度衰减规则)
    
#下面利用for i in range(depth),即根据模型深度depth(默认=12)堆叠Block
    self.blocks = nn.Sequential(*[
        Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
            attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
        for i in range(depth)])
    self.norm = norm_layer(embed_dim)
    
    #不用管下面这一段因为本模型中representation_size=None
    #前面提过这个参数的作用(本段代码第12行)
    # Representation layer
    if representation_size and not distilled:
        self.num_features = representation_size
        self.pre_logits = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(embed_dim, representation_size)),
            ('act', nn.Tanh())
        ]))#其实就是全连接层+tanh激活函数
    else:
        self.pre_logits = nn.Identity()
    
    #下面就是最终用于分类的全连接层的实现了
    # Classifier head(s)
    self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()#输入向量长度为num_features(定义在本段代码第26行,这个参数值=embed_dim),输出的向量长度为num_classes类别数
    #下面的部分和ViT无关可以不看
    self.head_dist = None
    if distilled:
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
    self.init_weights(weight_init)   

  def forward_features(self, x):
      x = self.patch_embed(x)#这个模块第一部分讲了
      cls_token = self.cls_token.expand(x.shape[0], -1, -1)  #本段第40行附近有解释,原cls_token尺寸1*1*embed_dim,将其在BatchSize维度复制B份,现cls_token尺寸为B*1*embed_dim
      if self.dist_token is None:#本模型中这个值就是None
          x = torch.cat((cls_token, x), dim=1)
          #在维度1上进行拼接,即值为196的维度上拼接。本行之后->[B,14*14+1,embed_dim]
      else:#本模型不执行这句
          x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
          
      x = self.pos_drop(x + self.pos_embed)#加上position embedding再通过51行定义的dropout层
      x = self.blocks(x)#通过58行定义的transformer encoder堆叠模块
      x = self.norm(x)#通过norm
      if self.dist_token is None:#本模型该参数为None
          return self.pre_logits(x[:, 0])
          #x[:, 0]将class_token通过切片取出,因为拼接的时候放在了最前面
          #而前面提过pre_logits层在参数representation_size=None的时候返回nn.Identity()即无操作,所以本句输出就是x[:, 0]
      else:
          return x[:, 0], x[:, 1]
      
  def forward(self, x):
  #前向部分
      x = self.forward_features(x)#
      if self.head_dist is not None:
      #本模型head_dist=None(81行)所以不执行此分支 不用看
          x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
          if self.training and not torch.jit.is_scripting():
              # during inference, return the average of both classifier predictions
              return x, x_dist
          else:
              return (x + x_dist) / 2
      else:
          x = self.head(x)#直接来到这,head是79行定义的分类头
      return x

你可能感兴趣的:(实验代码,pytorch,深度学习,python)