1:torch.gather对应的numpy操作
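首先需要了解 gather 操作,具体图解可参考《图解PyTorch中的torch.gather函数》(知乎)。对二维张量,其规则为:dim=0 时 out[i][j] = input[index[i][j]][j](index 替换的是行索引);dim=1 时 out[i][j] = input[i][index[i][j]](index 替换的是列索引),输出形状与 index 相同。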
demo1:输入行向量index,并替换行索引(dim=0)
# torch: gather along dim=0, i.e. out[i][j] = tensor0[index[i][j]][j]
tensor0 = torch.arange(3, 12).reshape(3, 3)
index = torch.tensor([[2, 1, 0]])
tensor1 = tensor0.gather(0, index)
for _shown in (tensor0, index, tensor1):
    print(_shown, _shown.shape)
print('********************************')
# numpy: the entries of index act as row numbers, while a column range
# broadcast against them supplies column j for output position [i][j].
tensor0 = np.arange(3, 12).reshape(3, 3)
index = np.array([[2, 1, 0]])
tensor1 = tensor0[(index, np.arange(index.shape[1]))]
for _shown in (tensor0, index, tensor1):
    print(_shown, _shown.shape)
print('********************************')
demo2:输入行向量index,并替换列索引(dim=1)
# torch: gather along dim=1, i.e. out[i][j] = tensor0[i][index[i][j]]
tensor0 = torch.tensor([[3, 4, 5],
                        [6, 7, 8],
                        [9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
tensor1 = torch.gather(tensor0, 1, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
# numpy: index supplies the column numbers. The row number of output
# element [i][j] is i, passed as a (rows, 1) column so it broadcasts across
# every column of index. A constant row 0 would only be correct when index
# has a single row.
tensor0 = np.array([[3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11]])
index = np.array([[2, 1, 0]])
tensor1 = tensor0[np.arange(index.shape[0])[:, None], index]
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
demo3:输入列向量index,并替换行索引(dim=0)
# torch: gather along dim=0 with a (3, 1) column-vector index,
# i.e. out[i][0] = tensor0[index[i][0]][0]
tensor0 = torch.tensor([[3, 4, 5],
                        [6, 7, 8],
                        [9, 10, 11]])
index = torch.tensor([[2, 1, 0]]).t()
tensor1 = torch.gather(tensor0, 0, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
# numpy: build the same (3, 1) index as torch so both halves print the same
# shapes, then use the same dim=0 rule as the row-vector case: index gives
# the row numbers, and a column range broadcast against it gives the columns.
# The result already has index's shape, so no trailing transpose is needed.
# (np.take_along_axis is not a drop-in here: it broadcasts index against the
# input's other axes, which would yield a (3, 3) result.)
tensor0 = np.array([[3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11]])
index = np.array([[2, 1, 0]]).T
tensor1 = tensor0[index, np.arange(index.shape[1])]
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
demo4:输入列向量index,并替换列索引(dim=1)
# torch: gather along dim=1 with a (3, 1) column-vector index,
# i.e. out[i][0] = tensor0[i][index[i][0]]
tensor0 = torch.tensor([[3, 4, 5],
                        [6, 7, 8],
                        [9, 10, 11]])
index = torch.tensor([[2, 1, 0]]).t()
tensor1 = torch.gather(tensor0, 1, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
# numpy: build the same (3, 1) index as torch, then use the same dim=1 rule
# as the row-vector case: index gives the column numbers, and the output
# row i is passed as a (rows, 1) column. The result already has index's
# shape, so no trailing transpose is needed.
tensor0 = np.array([[3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11]])
index = np.array([[2, 1, 0]]).T
tensor1 = tensor0[np.arange(index.shape[0])[:, None], index]
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')
2:torch.view对应的numpy操作
直接替换为 reshape 即可。注意:torch 的 view 只返回共享存储的视图,内存布局不兼容(如非连续张量)时会报错;numpy 的 reshape 在可能时返回视图,否则会返回拷贝。
# torch: view the 3x3 tensor as a single row; -1 lets torch infer the length.
tensor0 = torch.arange(3, 12).reshape(3, 3)
tensor1 = tensor0.view((1, -1))
for _shown in (tensor0, tensor1):
    print(_shown, _shown.shape)
print('********************************')
# numpy: reshape with the same (1, -1) target shape.
tensor0 = np.arange(3, 12).reshape(3, 3)
tensor1 = np.reshape(tensor0, (1, -1))
for _shown in (tensor0, tensor1):
    print(_shown, _shown.shape)
print('********************************')
3:torch.argmax对应的numpy操作
直接替换为 np.argmax 即可
# torch: index of the largest entry in each column (reduce over dim 0).
tensor0 = torch.arange(3, 12).reshape(3, 3)
tensor1 = tensor0.argmax(dim=0)
for _shown in (tensor0, tensor1):
    print(_shown, _shown.shape)
print('********************************')
# numpy: the same reduction, spelled with axis= instead of dim=.
tensor0 = np.arange(3, 12).reshape(3, 3)
tensor1 = tensor0.argmax(axis=0)
for _shown in (tensor0, tensor1):
    print(_shown, _shown.shape)
print('********************************')
4:torch.max对应的numpy操作
最大值替换为 np.max 即可。注意:torch.max(x, dim) 会同时返回 (values, indices),而 np.max 只返回最大值,对应的下标需再用 np.argmax 沿同一轴求得。
# torch: with a dim argument, torch.max returns a (values, indices) pair;
# [0] holds the per-column maxima and [1] the row each one came from.
tensor0 = torch.tensor([[3, 4, 5],
                        [6, 7, 8],
                        [9, 10, 11]])
tensor1 = torch.max(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1[0])
print(tensor1[1])
print('********************************')
# numpy: np.max yields only the values, so the indices half of torch's result
# is computed separately with np.argmax along the same axis.
tensor0 = np.array([[3, 4, 5],
                    [6, 7, 8],
                    [9, 10, 11]])
tensor1 = np.max(tensor0, 0)
tensor1_indices = np.argmax(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1)
print(tensor1_indices)
print('********************************')