20220222:技巧记录-pytorch和numpy的互转

1:torch.gather对应的numpy操作

首先需要了解gather操作,具体图解可参考 图解PyTorch中的torch.gather函数 - 知乎

 demo1:输入行向量index,并替换列索引(dim=0) 

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
tensor1 = torch.gather(tensor0, 0, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy 
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
index = np.array([[2, 1, 0]])
tensor1 = tensor0[index, np.arange(index.shape[1],dtype=int)]
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第1张图片

demo2:输入行向量index,并替换列索引(dim=1)

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]])
tensor1 = torch.gather(tensor0, 1, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
index = np.array([[2, 1, 0]])
tensor1 = tensor0[np.zeros(index.shape[1],dtype=int), index]
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第2张图片    

 demo3:输入列向量index,并替换列索引(dim=0) 

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]]).t()
tensor1 = torch.gather(tensor0, 0, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
index = np.array([[2, 1, 0]])
tensor1 = tensor0[index, np.zeros(index.shape[1],dtype=int)].T
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第3张图片 

 demo4:输入列向量index,并替换行索引(dim=1) 

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
index = torch.tensor([[2, 1, 0]]).t()
tensor1 = torch.gather(tensor0, 1, index)
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
index = np.array([[2, 1, 0]])
tensor1 = tensor0[np.arange(index.shape[1],dtype=int), index].T
print(tensor0, tensor0.shape)
print(index, index.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第4张图片

 

2:torch.view对应的numpy操作

直接替换为 reshape 即可

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
tensor1 = tensor0.view(1, -1)
print(tensor0, tensor0.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
tensor1 = tensor0.reshape(1, -1)
print(tensor0, tensor0.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第5张图片

 

3:torch.argmax对应的numpy操作

直接替换为 np.argmax 即可

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
tensor1 = torch.argmax(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1, tensor1.shape)
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
tensor1 = np.argmax(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1, tensor1.shape)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第6张图片

4:torch.max对应的numpy操作

直接替换为 np.max 即可

# torch
tensor0 = torch.tensor([[ 3,  4,  5],
                        [ 6,  7,  8],
                        [ 9, 10, 11]])
tensor1 = torch.max(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1[0])
print(tensor1[1])
print('********************************')

# numpy
tensor0 = np.array([[ 3,  4,  5],
                    [ 6,  7,  8],
                    [ 9, 10, 11]])
tensor1 = np.max(tensor0, 0)
print(tensor0, tensor0.shape)
print(tensor1)
print('********************************')

20220222:技巧记录-pytorch和numpy的互转_第7张图片 

 

你可能感兴趣的:(深度学习trick,pytorch,深度学习,python)