import torch
import torch.nn as nn

m = nn.Dropout(p=0.3)
input = torch.randn(3, 4, 5)
output = m(input)
print(output)
# Count zeroed elements: mark every non-zero entry as 1, sum, and subtract from 3 * 4 * 5.
print('Number of elements set to 0: {}'.format(3 * 4 * 5 - torch.sum(output.masked_fill(output != 0, 1.0)).item()))
Output:
tensor([[[ 0.0000e+00, -1.7265e-01, -9.1897e-02,  3.5391e-04, -8.7150e-01],
         [-3.3677e-01,  1.2160e+00,  1.0352e+00,  2.4846e+00, -9.1686e-01],
         [ 2.3508e-01, -0.0000e+00,  1.3112e+00,  2.6958e-01, -2.2409e+00],
         [-1.3472e+00, -0.0000e+00,  8.3373e-01, -6.5853e-01, -2.8076e-01]],

        [[ 4.1611e-01,  1.4991e+00, -1.4775e+00, -3.7494e-01,  1.4327e-01],
         [ 1.5931e+00,  0.0000e+00,  0.0000e+00, -9.9996e-01, -0.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  4.2582e-01],
         [ 0.0000e+00,  7.3300e-01, -4.2084e-01,  2.0661e+00,  1.1305e+00]],

        [[ 1.1499e+00,  7.9644e-01,  0.0000e+00,  1.3319e+00, -0.0000e+00],
         [ 1.8994e+00, -0.0000e+00,  2.8527e+00, -1.8907e+00, -9.5148e-01],
         [ 0.0000e+00,  4.1316e+00,  0.0000e+00,  9.0102e-01, -7.8583e-01],
         [ 1.1153e+00, -2.3886e+00,  2.1948e+00,  1.5231e-01,  1.4871e+00]]])
Number of elements set to 0: 16.0
Note:
As the output shows, nn.Dropout on a 3D tensor drops elements independently across all 3 * 4 * 5 = 60 entries (16 were zeroed here, close to the expected p * 60 = 18), and during training the surviving entries are scaled by 1 / (1 - p) = 1 / 0.7 ≈ 1.4286.
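To double-check both points, here is a minimal sketch (the fixed seed and the variable names x / y are illustrative, not part of the example above) that counts the zeroed elements and verifies the 1 / (1 - p) scaling of the survivors:

import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.Dropout(p=0.3)
x = torch.randn(3, 4, 5)
y = m(x)
# Count zeroed entries directly; on average this is p * 60 = 18.
print('zeroed elements:', (y == 0).sum().item())
# Every surviving entry equals the original value divided by (1 - p).
kept = y != 0
print(torch.allclose(y[kept], x[kept] / (1 - 0.3)))  # prints True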
Applying the same dropout mask along the second dimension of a 3D tensor
import torch
import torch.nn as nn

m = nn.Dropout(p=0.3)
# Build a (5, 4) mask: kept positions become 1 / (1 - p) ≈ 1.4286, dropped positions 0.
mask = torch.ones(5, 4)
mask = m(mask)
print(mask)
# Add a singleton dimension so the mask broadcasts over dim 1 of the input.
mask_unsqueeze = mask.unsqueeze(1)  # shape (5, 1, 4)
print(mask_unsqueeze)
input = torch.rand(5, 3, 4)
# Broadcasting applies the same mask row to every slice along the second dimension.
print(mask_unsqueeze * input)
Output:
tensor([[1.4286, 1.4286, 0.0000, 1.4286],
        [0.0000, 0.0000, 1.4286, 1.4286],
        [1.4286, 0.0000, 1.4286, 1.4286],
        [1.4286, 1.4286, 1.4286, 1.4286],
        [1.4286, 0.0000, 0.0000, 0.0000]])
tensor([[[1.4286, 1.4286, 0.0000, 1.4286]],

        [[0.0000, 0.0000, 1.4286, 1.4286]],

        [[1.4286, 0.0000, 1.4286, 1.4286]],

        [[1.4286, 1.4286, 1.4286, 1.4286]],

        [[1.4286, 0.0000, 0.0000, 0.0000]]])
tensor([[[1.1233, 1.1809, 0.0000, 0.8632],
         [0.9204, 0.0115, 0.0000, 0.6817],
         [0.8842, 0.1036, 0.0000, 0.5776]],

        [[0.0000, 0.0000, 0.9940, 0.0435],
         [0.0000, 0.0000, 0.7082, 0.1009],
         [0.0000, 0.0000, 0.3689, 0.4203]],

        [[0.0117, 0.0000, 0.5813, 0.5428],
         [1.3159, 0.0000, 1.1855, 0.8073],
         [0.0849, 0.0000, 0.8872, 0.4103]],

        [[0.2462, 0.3455, 0.4301, 0.6576],
         [0.6236, 0.6201, 1.3619, 1.3616],
         [0.6890, 0.0831, 0.9151, 1.0038]],

        [[0.2905, 0.0000, 0.0000, 0.0000],
         [1.3565, 0.0000, 0.0000, 0.0000],
         [0.9842, 0.0000, 0.0000, 0.0000]]])
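In the last tensor, each (5, 3, 4) slice along the first dimension shares the zero pattern of its mask row, i.e. the same positions are dropped for all three rows along the second dimension. The same trick works for any dimension; below is a minimal sketch (the helper name shared_dropout is introduced here for illustration, not a PyTorch API):

import torch
import torch.nn as nn

def shared_dropout(x, p, dim):
    # Drop the same positions for every slice along `dim`: collapse that
    # dimension to size 1 in the mask and let broadcasting replicate it.
    # nn.Dropout already scales kept values by 1 / (1 - p).
    shape = list(x.shape)
    shape[dim] = 1
    mask = nn.Dropout(p=p)(torch.ones(shape, device=x.device))
    return x * mask

x = torch.rand(5, 3, 4)
print(shared_dropout(x, p=0.3, dim=1))  # same mask shared across dim 1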