ViT中的DropPath代码

DropPath代码

  • DropPath代码

DropPath代码

最近在学习ViT模型,记录一下其中的droppath操作,实际上就是对一个batch中随机选择一定数量的sample,将其特征值变为0:
ViT github源码地址链接

def drop_path(x, drop_prob: float = 0., training: bool = False):

    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # shape (b, 1, 1, 1...)
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    # 向下取整用于确定保存哪些样本
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    # 除以keep_prob是为了保持训练和测试时期望一致
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

你可能感兴趣的:(深度学习,python,pytorch)