在深度学习中,需要用到很多pytorch的一些基础函数,这里列出了一些在CNN、RNN、LSTM中对tensor需要用到的常用函数
代码如下(示例):
x = torch.tensor([1.5, -0.5, 3.0]).unsqueeze(0)
x初始的shape为torch.Size([3])
unsqueeze()表示扩展指定维度返回一个新的张量
unsqueeze(0)后x的shapetorch.Size([1,3])
使用堆叠的方式合并多个张量(而cat拼接操作是在现有维度上合并数据,并不会创建新的维度)
代码如下(示例):
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1,T2),dim=0).shape)
#cat函数
#torch.cat(tensors, dim=0, *, out=None) → Tensor
x = torch.randn(2, 3)#shape:torch.Size([2,3])
torch.cat((x, x, x), 0)#shape:torch.Size([6,3])
torch.cat((x, x, x), 1)#shape:torch.Size([2,9])
T1 、T2的shape为:torch.Size([3, 3])
使用stack函数后,在第0维增加维度,shape变成torch.Size([2, 3, 3])
代码如下(示例):
batch=3,n_q=2,n_head=8,d_q=256
a=(batch,n_q,n_head*d_q)
b=a.view(batch, n_q, n_head, d_q)
c=b.permute(2, 0, 1, 3)#将tensor的维度换位
d=d.contiguous().view(-1, n_q, d_q)
a的shape为:torch.Size([3, 2,8*256])
b的shape为:torch.Size([3,2,8,256])
c的shape为:torch.Size([8,3,2,256])
d的shape为:torch.Size([24,2,256])
代码如下(示例):
mask = torch.zeros(3, 2, 64).bool()
mask = mask.repeat(8, 1, 1)
mask.shape
repeat()函数对数据进行复制并保存到内存中
repeat后mask的shape为:torch.Size([24, 2,64])
代码如下(示例):
B=tensor([[[3, 2],
[1, 4]],
[[5, 6],
[7, 8]]])
m = torch.max(B, dim=0)
print(m)
tensor([[5, 6],
[7, 8]])
m = torch.max(B, dim=1)
print(m)
tensor([[3, 4],
[7, 8]])
m = torch.max(B, dim=2)
print(m)
tensor([[3, 4],
[6, 8]])
代码如下(示例):
torch.unsqueeze(input, dim, out=None)扩展维度
torch.squeeze(input, dim=None, out=None)挤压维度
import torch
x = torch.Tensor([1, 2, 3, 4]) #shape:torch.Size([4])
print(torch.unsqueeze(x, 1))#shape:torch.Size([4,1])
m = torch.zeros(2, 1, 2, 1, 2)
print(m.size()) # torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m)
print(n.size()) # torch.Size([2, 2, 2])
n = torch.squeeze(m, 1)
print(n.size()) # torch.Size([2, 2, 1, 2])
这里介绍了pytorch的常用函数,在定义网络结构和forward时会经常用到这些函数,希望对大家有所帮助