PyTorch F.dropout 为什么要加 self.training，以及与 nn.Dropout 的区别

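PyTorch的F.dropout为什么要加self.training

F.dropout(input, p=0.5, training=True, inplace=False) 是一个普通函数，它的 training 参数默认为 True，它本身并不知道所在模型处于训练模式还是评估模式。因此在自定义 Module 的 forward 中调用时，必须写成 F.dropout(x, p, training=self.training)，否则调用 model.eval() 之后 dropout 仍会生效。nn.Dropout 则是一个 Module，它的 forward 内部会把自身的 self.training 传给 F.dropout，因此会随 model.train() / model.eval() 自动切换：训练时随机置零并缩放，评估时直接返回输入。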

nn.Dropout 的用法详解

import torch.nn as nn
import torch

# Apply nn.Dropout (p=0.3) to a random 3-D tensor. A freshly constructed
# module is in training mode, so each element is zeroed independently with
# probability p and the surviving elements are scaled by 1 / (1 - p).
m = nn.Dropout(p=0.3)
x = torch.randn(3, 4, 5)  # named `x` to avoid shadowing the builtin `input`
output = m(x)
print(output)

# Count dropped elements directly instead of subtracting from a hard-coded
# 3*4*5, so the count stays correct if the tensor shape above is changed.
num_zeros = (output == 0).sum().item()
print('被置为0的个数为{}'.format(num_zeros))

输出:

tensor([[[ 0.0000e+00, -1.7265e-01, -9.1897e-02,  3.5391e-04, -8.7150e-01],
         [-3.3677e-01,  1.2160e+00,  1.0352e+00,  2.4846e+00, -9.1686e-01],
         [ 2.3508e-01, -0.0000e+00,  1.3112e+00,  2.6958e-01, -2.2409e+00],
         [-1.3472e+00, -0.0000e+00,  8.3373e-01, -6.5853e-01, -2.8076e-01]],

        [[ 4.1611e-01,  1.4991e+00, -1.4775e+00, -3.7494e-01,  1.4327e-01],
         [ 1.5931e+00,  0.0000e+00,  0.0000e+00, -9.9996e-01, -0.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  4.2582e-01],
         [ 0.0000e+00,  7.3300e-01, -4.2084e-01,  2.0661e+00,  1.1305e+00]],

        [[ 1.1499e+00,  7.9644e-01,  0.0000e+00,  1.3319e+00, -0.0000e+00],
         [ 1.8994e+00, -0.0000e+00,  2.8527e+00, -1.8907e+00, -9.5148e-01],
         [ 0.0000e+00,  4.1316e+00,  0.0000e+00,  9.0102e-01, -7.8583e-01],
         [ 1.1153e+00, -2.3886e+00,  2.1948e+00,  1.5231e-01,  1.4871e+00]]])
被置为0的个数为16.0

注解:
可以看出，nn.Dropout 对三维 tensor 的全部 3 × 4 × 5 = 60 个元素逐个独立地以概率 p = 0.3 置零（期望约 18 个，本次为 16 个），未被置零的元素会乘以 1/(1-p) ≈ 1.4286 进行缩放，以保持输出的期望不变。

实例应用

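对三维 tensor 沿第二维（dim=1）共享同一个 dropout mask：先对形状为 (5, 4) 的全 1 张量做 dropout 得到 mask，再用 unsqueeze(1) 变为 (5, 1, 4)，借助广播作用到形状为 (5, 3, 4) 的输入上，这样第二维上的 3 个切片使用的是完全相同的 mask。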

import torch.nn as nn
import torch

# Share one dropout mask across dim 1 of a 3-D tensor: draw a 2-D mask over
# (dim 0, dim 2), insert a size-1 axis at dim 1, and let broadcasting repeat
# that mask over every slice along dim 1.
m = nn.Dropout(p=0.3)
x = torch.rand(5, 3, 4)  # named `x` to avoid shadowing the builtin `input`

# Derive the mask shape from x so the two shapes cannot drift out of sync.
# Dropout applied to a tensor of ones yields entries in {0, 1 / (1 - p)}.
# m is in training mode here; after m.eval() the mask would be all ones.
mask = m(torch.ones(x.size(0), x.size(2)))
print(mask)

mask_unsqueeze = mask.unsqueeze(1)  # (x.size(0), 1, x.size(2))
print(mask_unsqueeze)

# Broadcasting expands the size-1 axis, applying the same mask to each slice.
masked = mask_unsqueeze * x
print(masked)

输出:

tensor([[1.4286, 1.4286, 0.0000, 1.4286],
        [0.0000, 0.0000, 1.4286, 1.4286],
        [1.4286, 0.0000, 1.4286, 1.4286],
        [1.4286, 1.4286, 1.4286, 1.4286],
        [1.4286, 0.0000, 0.0000, 0.0000]])
        
        
tensor([[[1.4286, 1.4286, 0.0000, 1.4286]],

        [[0.0000, 0.0000, 1.4286, 1.4286]],

        [[1.4286, 0.0000, 1.4286, 1.4286]],

        [[1.4286, 1.4286, 1.4286, 1.4286]],

        [[1.4286, 0.0000, 0.0000, 0.0000]]])
        
        
tensor([[[1.1233, 1.1809, 0.0000, 0.8632],
         [0.9204, 0.0115, 0.0000, 0.6817],
         [0.8842, 0.1036, 0.0000, 0.5776]],

        [[0.0000, 0.0000, 0.9940, 0.0435],
         [0.0000, 0.0000, 0.7082, 0.1009],
         [0.0000, 0.0000, 0.3689, 0.4203]],

        [[0.0117, 0.0000, 0.5813, 0.5428],
         [1.3159, 0.0000, 1.1855, 0.8073],
         [0.0849, 0.0000, 0.8872, 0.4103]],

        [[0.2462, 0.3455, 0.4301, 0.6576],
         [0.6236, 0.6201, 1.3619, 1.3616],
         [0.6890, 0.0831, 0.9151, 1.0038]],

        [[0.2905, 0.0000, 0.0000, 0.0000],
         [1.3565, 0.0000, 0.0000, 0.0000],
         [0.9842, 0.0000, 0.0000, 0.0000]]])

你可能感兴趣的:(Pytorch)