Pytorch scatter_()

scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。

import torch
import numpy as np

def one_hot(labels):
    _labels = torch.zeros([2,2,4])
    _labels.scatter_(dim=0, index=labels.long(), value=1)#scatter_(input, dim, index, src)
    return _labels

temp = torch.Tensor([[[1,1,0,0],
                    [1,1,0,0]]])
target2 = one_hot(temp)
print(target2)

输出:channel0是背景,channel1是前景。

tensor([[[0., 0., 1., 1.],
         [0., 0., 1., 1.]],

        [[1., 1., 0., 0.],
         [1., 1., 0., 0.]]])

你可能感兴趣的:(Pytorch)