pytorch中的掩膜mask

1结果维度不变

y_hat=torch.tensor([[1,2,3],[4,5,6]])
y=torch.tensor([1,0,0],dtype=torch.bool)
print(y_hat[:,y])
#tensor([[1],[4]])     y_hat[:,y].shape ([2,1])

2 结果维度改变

y_hat=torch.tensor([[1,2,3],[4,5,6]])
y=torch.tensor([1,0,0],dtype=torch.bool)
print(y_hat[range(len(y_hat)),y])
print(y_hat[range(len(y_hat)),y].shape)
#tensor([1, 4])    torch.Size([2])

3 结果维度改变

y_hat=torch.tensor([[1,2,3],[4,5,6]])
y_2=torch.tensor([[1,0,0],[0,0,1]]).bool()
print(y_hat[y_2])
#tensor([1, 6])

你可能感兴趣的:(深度学习,pytorch,深度学习,python)