Pytorch中常用函数

1 矩阵方面

1.1 torch.unsqueeze(input, dim, out=None)

作用:拓展维度
参数:
tensor (Tensor) – 输入张量
dim (int) – 插入维度的索引
out (Tensor, optional) – 返回张量

import torch
x = torch.Tensor([1, 2, 3, 4])  # torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称。

print('-' * 50)
print(x)  # tensor([1., 2., 3., 4.])
print(x.size())  # torch.Size([4])
print(x.dim())  # 1
print(x.numpy())  # [1. 2. 3. 4.]

print('-' * 50)
print(torch.unsqueeze(x, 0)) # tensor([[1., 2., 3., 4.]])一维变二维
print(torch.unsqueeze(x, 0).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim())  # 2
print(torch.unsqueeze(x, 0).numpy())  # [[1. 2. 3. 4.]]

print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, 1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim())  # 2

print('-' * 50)
print(torch.unsqueeze(x, -1))
#相当于torch.unsqueeze(x, 1)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, -1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim())  # 2

1.2 torch.squeeze_()

这里提一下torch.unsqueeze()和torch.unsqueeze_()
torch的F_()函数和F()函数作用上是一致的,只是F_()函数是作用于变量本身,也就是不占用额外内存(in place),比如:

import torch
x=torch.tensor([1,2,3,4])
a=torch.unsqueeze(x,0)
print(a) # tensor([[1., 2., 3., 4.]])

print(x.unzqueeze_(0)) #tensor([[1., 2., 3., 4.]])
print(x) #tensor([[1., 2., 3., 4.]])
# x本身的值已经改变

1.3 torch.squeeze(input, dim=None, out=None)

作用:降维
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])

n = torch.squeeze(m, 1) # 当给定dim时,那么挤压操作只在给定维度上
print(n.size())  # torch.Size([2, 2, 1, 2])

n = torch.squeeze(m, 0)  #给定的dim不是1,不会挤压
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m, 2)
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m, 3)
print(n.size())  # torch.Size([2, 1, 2, 2])

1.4 torch.view(dim1, dim2)

作用:相当于resize

import torch

a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6)) #tensor([[1., 2., 3., 4., 5., 6.]])
print(b.view(1,6))
print(a.view(3,2))
#tensor([[1., 2.],
#       [3., 4.],
#        [5., 6.]])

1.4 torch.permute()

作用:将tensor维度换位

2 网络操作

你可能感兴趣的:(pytorch,Deep,Learning,pytorch,python,numpy)