pytorch一些函数的基本用法及作用

Pytorch一些函数的基本用法及作用

  • Pytorch的一些函数
    • 1.torch.cat()
    • 2.torch.transpose()
    • 3.tensor.view()
    • 4.tensor.permute()
    • 5.tensor.squeeze()
    • 6.tensor.unsqueeze()
    • 7.torch.bmm()
    • 8.torch.dot()
    • 9.tensor.dtype()
    • 10.torch.tensor()
    • 11.tensor.long()、tensor.float()
    • 12.torch.nonzero()
    • 13.torch.cat()
    • 14.tensor.chunk()


Pytorch的一些函数

1.torch.cat()

之前写写过
见pytorch中的torch.cat()矩阵拼接的用法及理解

2.torch.transpose()

这个函数的作用主要是将矩阵交换两个维度。

import torch

a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [1, 1, 1]]])
aa = torch.transpose(a, 0, 1)  # 交换第一维和第二维
bb = torch.transpose(b, 0, -1)  # 交换第一维和最后一维

print(a.shape, "----", aa.shape)
print(b.shape, "----", bb.shape)
print(aa)
print(bb)

结果
pytorch一些函数的基本用法及作用_第1张图片

3.tensor.view()

这个函数能对维度进行变换,view中的参数即为每一维应当包含的元素个数,如果-1即让其自动计算出。

import torch

a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [3, 1, 1]]])
aa = a.view(3, -1) # 第一维3个元素,第二维任意
bb = b.view(3, 2)  # 第一维3个元素第二维2个元素
bbb = b.view(1, 1, 6)  # 最后一维6个元素

print(aa.shape)
print(bb.shape)
print(bbb.shape)
print(bb)

运行结果
pytorch一些函数的基本用法及作用_第2张图片

4.tensor.permute()

这个函数能够对任意维度进行任意交换,transpose只能对两个维度进行交换。用法是:原本0,1,2代表第1,2,3维,那么现在b.permute(2,0,1)即现在第一维是原来的第三维,第二维是原来的第一维,第三维是原来的第二维。

import torch

a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [3, 1, 1]]])
bb = b.permute(2, 0, 1)  
print(b.shape)
print(bb.shape)

运行结果
在这里插入图片描述

5.tensor.squeeze()

作用就是降维,去掉维度中元素为1的维度

import torch

a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[1, 1, 1], [1, 1, 1]]])

aa = a.squeeze()
bb = b.squeeze()
aaa = a.squeeze(0)  # 0代表将第一维去掉

print(a.shape)
print(b.shape)

print(aa.shape)
print(bb.shape)

print(aaa.shape)

运行结果
pytorch一些函数的基本用法及作用_第3张图片

6.tensor.unsqueeze()

作用就是升1维

import torch

a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[1, 1, 1], [1, 1, 1]]])

aa = a.unsqueeze(0)  # 在第一维前增加一维
bb = b.unsqueeze(-1)  # 在最后一维后增加一维

print(a.shape)
print(b.shape)
 
print(aa.shape)
print(bb.shape)

运行结果
pytorch一些函数的基本用法及作用_第4张图片

7.torch.bmm()

实现的是矩阵乘法不过必须是三维的,最后两维需要满足矩阵乘法的格式

import torch

a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[0, 1, 1], [1, 2, 1]]])

print(a.shape)
print(b.shape)

print("----------------")
a = a.view(1, -1, 3)  # 维度转换使其能够进行矩阵乘法
b = b.view(1, 3, -1)
print(a.shape)
print(b.shape)
print("----------------")
print(a)
print(b)
c = torch.bmm(a, b)
print(c)
print(c.shape)

运行结果
pytorch一些函数的基本用法及作用_第5张图片

8.torch.dot()

求一维向量的内积,即对应元素相乘然后求和

import torch

a = torch.Tensor([1, 1, 2, 3])
b = torch.Tensor([1, 1, 1, 1])
c = torch.dot(a, b)
print(c)

运行结果
1x1 + 1x1 + 2x1 + 3x1 = 7
在这里插入图片描述

9.tensor.dtype()

具体查看是什么tensor类型

# coding:utf-8
import torch
a = torch.Tensor([5])
print(a.dtype)

运行结果
在这里插入图片描述

10.torch.tensor()

能够生成不同类型的tensor

# coding:utf-8
import torch
a = torch.tensor([5], dtype=torch.float64)
b = torch.tensor([6], dtype=torch.float32)
c = torch.tensor([7], dtype=torch.int64)
print(a.dtype)
print(b.dtype)
print(c.dtype)

结果
在这里插入图片描述

11.tensor.long()、tensor.float()

转换tensor的数据类型

# coding:utf-8
import torch
a = torch.tensor([5], dtype=torch.float64).long()
b = torch.tensor([6], dtype=torch.float32).int()
c = torch.tensor([7], dtype=torch.int64).float()
print(a.dtype)
print(b.dtype)
print(c.dtype)

结果如下
在这里插入图片描述

12.torch.nonzero()

可以用来找对应元素的下标

batch = torch.tensor([[1, 2, 3, 0, 0], [1, 0, 3, 1, 1]])
pos_sen = torch.nonzero(batch==0).squeeze()
print(pos_sen)

输出:
pytorch一些函数的基本用法及作用_第6张图片

13.torch.cat()

用于矩阵的拼接

import torch
a = torch.tensor([[1,2,3]])
b = torch.tensor([[4,5,6]])
c = torch.cat((a, b), axis=0)
d = torch.cat((a, b), axis=1)
print(c)
print(d)

输出:
pytorch一些函数的基本用法及作用_第7张图片

14.tensor.chunk()

用于矩阵的拆分

import torch
a = torch.tensor([[1,2,3], [2,2,3]])
b, c = a.chunk(2, dim=0)
print(b)
print(c)

a = torch.tensor([[1,2,3], [2,2,3]])
b, c, d = a.chunk(3, dim=1)
print(b)
print(c)
print(d)

结果:
pytorch一些函数的基本用法及作用_第8张图片

你可能感兴趣的:(pytorch,pytorch,深度学习,python)