pytorch Tensor基础操作汇总

以下内容来自龙良曲老师的pytorch课程

目录

  • 1、数据类型
  • 2、维度变换
    • view/reshape
    • Squeese/unsqueeze
    • Expand/repeat
    • permute
  • 3、Broadcast
    • 什么时候用broadcast
  • 4、拼接和拆分
    • cat
    • stack
    • split
    • chunk
  • 5、数学运算
    • 基本运算(四则)
    • 矩阵相乘 matmul
    • power
    • 近似值
    • clamp
  • 6、统计属性
    • norm 范数
    • mean,sum,min,max,prod
    • dim,keepdim
    • Top-k
    • 比较
  • 7、进阶操作
    • where
    • gather

1、数据类型

在python中的各种数据类型都用Tensor进行概括:
pytorch Tensor基础操作汇总_第1张图片
对于string类型,pytorch中要计算string类型的数据,需要先将其也转化为可以处理的Tensor类型

  • one-hot 编码
    [0,1,0,0],[1,0,0,0]…
  • Embedding
    word2vec,glove

pytorch中的数据类型如下所列:

pytorch Tensor基础操作汇总_第2张图片
在程序中可以用.type()和isinstance()检验类型
pytorch Tensor基础操作汇总_第3张图片
注意部署在CPU和GPU上是不一样的
pytorch Tensor基础操作汇总_第4张图片
注意pytorch里标量是0维的,生成方法如下:
pytorch Tensor基础操作汇总_第5张图片
注意pytorch中标量的shape和size都是空数组,长度为0
pytorch Tensor基础操作汇总_第6张图片
注意和dim为1的张量作区分:
在这里插入图片描述
还可以从numpy中转化得到Tensor
在这里插入图片描述
多维的情况(Dim=3)
pytorch Tensor基础操作汇总_第7张图片

2、维度变换

pytorch Tensor基础操作汇总_第8张图片

view/reshape

view是之前版本的api,与reshape完全一致

举例,现在有一个四维的Tensor。在mnist数据集中,它可以代表4张图片(batch size)灰度信息,尺寸是28*28

a=torch.rand(4,1,28,28)

#view函数 要满足prod相等,注意要表示正确的实际意义意义
#此操作的意义是对每一张图片,直接784个数字作为一维,忽略了二维位置信息,适用于全连接层
a.view(4,28*28)
#此操作看成四个二维数组
a.view(4*1,28,28)

Squeese/unsqueeze

unsqueeze 可以添加更高维度。参数为非负数的话在之前插入维度,负数的话在索引之后插入
在这里插入图片描述

pytorch Tensor基础操作汇总_第9张图片
举例:
pytorch Tensor基础操作汇总_第10张图片

当维度不同的Tensor相加的时候,要先用unsqueeze进行维度展开,然后将各个维度大小进行变换后相加:
pytorch Tensor基础操作汇总_第11张图片
squeeze:删减维度,无参数则挤压掉所有可以挤压的维度(dim size=1的维度)给出索引则挤压掉指定维度。如果输入了不能挤压的维度,不会报错,但是Tensor不变
pytorch Tensor基础操作汇总_第12张图片

Expand/repeat

比如现在有维度为[32]和[4,32,14,14]的两个Tensor,可以先用unsqueeze将维度扩展为[1,32,1,1]用expand就可以进行维度大小扩展(重复值,但是不重新分配内存)
pytorch Tensor基础操作汇总_第13张图片
如果参数是-1 则维度不变 。如果输入一个除了-1的负数,维度会变成这个负数(一般不这样用)
repeat传入的参数为每一维要拷贝的次数
pytorch Tensor基础操作汇总_第14张图片

permute

可以交换不同的维度,参数是原来的维度索引
在这里插入图片描述在这里插入图片描述

3、Broadcast

Broadcast总是在“大维度”上进行自动扩张,可以认为左边的维度是大维度。
实际问题:有一个Feature map:[4,32,14,14],分别代表batch size、通道数、长、宽
需要加上一个偏置[32,1,1]
pytorch Tensor基础操作汇总_第15张图片
相当于unsqueeze+expand
pytorch Tensor基础操作汇总_第16张图片
实际中,可能要在一个高维Tensor上加上一个标量,就用到了broadcasting
pytorch Tensor基础操作汇总_第17张图片
另外用broadcast可以节省内存

什么时候用broadcast

低维度要么是1(可以自动扩展相加),要么和被加Tensor的低维度匹配
比如A [4,32,8] B为标量,可以用broadcast机制相加。先将低维扩展成维度=8,再扩展出两个高维。如果B为Tensor,维数大小与A的低维大小相同,也可以自动扩展高维之后相加,比如B为[1,8]
情形1
pytorch Tensor基础操作汇总_第18张图片
情形2
在每张图片的每个通道都叠加一个二维Tensor
pytorch Tensor基础操作汇总_第19张图片
不可用情形
高维只给了两张的信息,操作无法完成。可以用B[n]举出某一张的Tensor然后相加
pytorch Tensor基础操作汇总_第20张图片

4、拼接和拆分

pytorch Tensor基础操作汇总_第21张图片

cat

假设有两份成绩单,一份是1-4班的成绩单,一份是5-9班的成绩单,成绩单Tensor有三个维度,分别代表班级、学生和课程。
在这里插入图片描述
现在要将此两个Tensor拼接在一起,就可以用cat,传入两个Tensor以及合并的维度
在这里插入图片描述
举二维Tensor的例子,在dim=0上拼接,即按照行来拼接:[4,4] [4,4] [4,4] 得到[12,4]
pytorch Tensor基础操作汇总_第22张图片
在dim=1上拼接,[4,4] [4,3]就得到[4,7]
pytorch Tensor基础操作汇总_第23张图片
注意在cat的时候除了拼接的维度,其他维度的size要一样

stack

在拼接的时候创建新维度
对比:
pytorch Tensor基础操作汇总_第24张图片
比如有两张表,都是328,如果用cat会合成一张648的大表,但是有时候我们想要分开存放,便于调用,这样就要用stack,创建了一个新的维度(班级),便于调用管理。
pytorch Tensor基础操作汇总_第25张图片
注意用stack的话,除了生成的新维度,其他维度都要相同。

split

参数格式1:切分后每个单元的长度(如果是一个数代表每个的长度都是这么多;如果是一个列表则分别代表每个拆分后的维度大小)+维度索引
拆分之前:[2,32,8]
在这里插入图片描述

在这里插入图片描述

chunk

按照数量来拆分
在这里插入图片描述

5、数学运算

pytorch Tensor基础操作汇总_第26张图片

基本运算(四则)

下面的add用到了广播机制
pytorch Tensor基础操作汇总_第27张图片
sub,mul,div 操作同理
pytorch Tensor基础操作汇总_第28张图片

矩阵相乘 matmul

pytorch Tensor基础操作汇总_第29张图片
一般就使用matmul
pytorch Tensor基础操作汇总_第30张图片
实例:降低某一维度的长度:
乘以一个[784,512]矩阵
可以将[4,784]->[512,784]
这里反着写是因为pytorch约定chanel-out chanel-in的顺序,后面进行矩阵相乘的时候用.t()转置一下
pytorch Tensor基础操作汇总_第31张图片
二维以上的矩阵相乘,只对后面两维作相乘运算
在这里插入图片描述
在这里插入图片描述
如果之前的维数不一样,由于broadcast机制可以自动扩展相乘
在这里插入图片描述
broadcast在0维扩展,如果无法用broadcast扩展,则会报错
pytorch Tensor基础操作汇总_第32张图片

power

接收矩阵和每个元素的pow
pytorch Tensor基础操作汇总_第33张图片
其他操作同理 rsqrt是平方根的倒数

pytorch Tensor基础操作汇总_第34张图片
exp和log:
pytorch Tensor基础操作汇总_第35张图片

近似值

取下、取上、四舍五入、取整、取小数
pytorch Tensor基础操作汇总_第36张图片

clamp

如果只有一个参数,限定最小值
如果两个参数,限定最小值和最大值
pytorch Tensor基础操作汇总_第37张图片

6、统计属性

norm 范数

第一范数是和 第二范数是平方和开根号
pytorch Tensor基础操作汇总_第38张图片

mean,sum,min,max,prod

注意argmax和argmin返回的是索引。且是变成vector后的索引
pytorch Tensor基础操作汇总_第39张图片
在用argmax的时候可以输入维度索引:
pytorch Tensor基础操作汇总_第40张图片

dim,keepdim

统计信息会消除dimension,用keepdim可以避免消除dimension
pytorch Tensor基础操作汇总_第41张图片

Top-k

取最大的前k个,同样可以用dim指定维度
将largest设置为false(默认为true)可以求前k小的
kthvalue返回第k小的值和索引
pytorch Tensor基础操作汇总_第42张图片

比较

可以对Tensor的每个元素进行比较,返回与原Tensor维度相同的0-1Tensor作为结果
注意eq和equal的区别,前者逐个比较Tensor元素,后者比较Tensor整体
pytorch Tensor基础操作汇总_第43张图片

7、进阶操作

pytorch Tensor基础操作汇总_第44张图片

where

参数:条件+原数据A+原数据B
条件是一个和A和B维度相同的Tensor,0代表来自ATensor该位置的元素,1代表来自BTensor该位置的元素
pytorch Tensor基础操作汇总_第45张图片
例子:使一个Tensor中大于0.5的数字取0,小于等于0.5的数字取1:
pytorch Tensor基础操作汇总_第46张图片
为什么要用where而不用for循环一个个比较,是因为后者完全使用cpu,用where是用的并行运算,用的GPU,速度会更快。

gather

pytorch Tensor基础操作汇总_第47张图片
根据索引表,从一个表中采集不同的元素
pytorch Tensor基础操作汇总_第48张图片
用gather进行查表操作用法如下
在这里插入图片描述
例子:
pytorch Tensor基础操作汇总_第49张图片

你可能感兴趣的:(pytorch)