pytorch中基础函数介绍

文章目录

  • 前言
  • 一、pytorch
    • 1.unsqueeze()函数
    • 2.stack()函数
    • 3.view()函数
    • 4.repeat()函数
    • 5.max()函数
    • 6.torch.squeeze()函数
  • 总结


前言

在深度学习中,需要用到很多pytorch的一些基础函数,这里列出了一些在CNN、RNN、LSTM中对tensor需要用到的常用函数

一、pytorch

1.unsqueeze()函数

代码如下(示例):

x = torch.tensor([1.5, -0.5, 3.0]).unsqueeze(0)

x初始的shape为torch.Size([3])
unsqueeze()表示扩展指定维度返回一个新的张量
unsqueeze(0)后x的shapetorch.Size([1,3])

2.stack()函数

使用堆叠的方式合并多个张量(而cat拼接操作是在现有维度上合并数据,并不会创建新的维度)
代码如下(示例):

T1 = torch.tensor([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
                 [40, 50, 60],
                 [70, 80, 90]])
print(torch.stack((T1,T2),dim=0).shape)

#cat函数
#torch.cat(tensors, dim=0, *, out=None) → Tensor
 x = torch.randn(2, 3)#shape:torch.Size([2,3])
 torch.cat((x, x, x), 0)#shape:torch.Size([6,3])
 torch.cat((x, x, x), 1)#shape:torch.Size([2,9])

T1 、T2的shape为:torch.Size([3, 3])
使用stack函数后,在第0维增加维度,shape变成torch.Size([2, 3, 3])

3.view()函数

代码如下(示例):

batch=3,n_q=2,n_head=8,d_q=256
a=(batch,n_q,n_head*d_q)
b=a.view(batch, n_q, n_head, d_q)
c=b.permute(2, 0, 1, 3)#将tensor的维度换位
d=d.contiguous().view(-1, n_q, d_q)

a的shape为:torch.Size([3, 2,8*256])
b的shape为:torch.Size([3,2,8,256])
c的shape为:torch.Size([8,3,2,256])
d的shape为:torch.Size([24,2,256])

4.repeat()函数

代码如下(示例):

mask = torch.zeros(3, 2, 64).bool()
mask = mask.repeat(8, 1, 1)
mask.shape

repeat()函数对数据进行复制并保存到内存中
repeat后mask的shape为:torch.Size([24, 2,64])

5.max()函数

代码如下(示例):

B=tensor([[[3, 2],
         [1, 4]],

        [[5, 6],
         [7, 8]]])
m = torch.max(B, dim=0)
print(m)
tensor([[5, 6],
        [7, 8]])

m = torch.max(B, dim=1)
print(m)
tensor([[3, 4],
        [7, 8]])

m = torch.max(B, dim=2)
print(m)
tensor([[3, 4],
        [6, 8]])

6.torch.squeeze()函数

代码如下(示例):
torch.unsqueeze(input, dim, out=None)扩展维度
torch.squeeze(input, dim=None, out=None)挤压维度

import torch
x = torch.Tensor([1, 2, 3, 4])  #shape:torch.Size([4])
print(torch.unsqueeze(x, 1))#shape:torch.Size([4,1])

m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])

n = torch.squeeze(m, 1)
print(n.size())  # torch.Size([2, 2, 1, 2])

总结

这里介绍了pytorch的常用函数,在定义网络结构和forward时会经常用到这些函数,希望对大家有所帮助

你可能感兴趣的:(人工智能,深度学习,pytorch)