例子:根据指定索引换掉torch的值

        一个 (n, 3) 形状的tensorA,一个 (n,) 形状的tensorB,一个 (n,) 形状的tensorC。想将tensorB作为tensorA的dim=1维度的索引,用tensorC替换掉tensorA对应的值。

下面是原地修改实现

import torch

# 创建示例数据
n = 4
tensorA = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
tensorB = torch.tensor([0, 2, 1, 0])
tensorC = torch.tensor([100, 200, 300, 400])

# 使用 tensorB 作为索引,在 tensorA 的 dim=1 维度上进行替换
tensorA.scatter_(1, tensorB.unsqueeze(1), tensorC.unsqueeze(1))

print(tensorA)


# 输出
# tensor([[100,   2,   3],
#         [  4,   5, 200],
#         [  7, 300,   9],
#         [400,  11,  12]])

你可能感兴趣的:(AI,pytorch)