pytorch 重复采样 与 非重复采样

import torch
import torch.nn.functional as F
from torch.autograd import *

a = Variable(torch.FloatTensor([[0,0,0,0,0,0,90,100]]))
b=F.softmax(a,-1)

print(b.multinomial()) # 76
print(b.multinomial(2)) # 6,77,6
print(b.multinomial(2,True)) # 7,77,66,76,6

也可以试试
WeightedRandomSampler
主要是replace也就是True False那个参数决定采样数据是否重复

你可能感兴趣的:(PyTorch)