此处对张量(Tensor)的相关属性及计算方法做一个简单汇总
1、数据类型
16位浮点数:torch.float16 或 torch.half,torch.HalfTensor,torch.cuda.HalfTensor
32位浮点数:torch.float32 或 torch.float,torch.FloatTensor,torch.cuda.FloatTensor
64位浮点数:torch.float64 或 torch.double,torch.DoubleTensor,torch.cuda.DoubleTensor
8位无符号整型:torch.uint8,torch.ByteTensor,torch.cuda.ByteTensor
8位有符号整型:torch.int8,torch.CharTensor,torch.cuda.CharTensor
16位有符号整型:torch.int16 或 torch.short,torch.ShortTensor,torch.cuda.ShortTensor
32位有符号整型:torch.int32 或 torch.int,torch.IntTensor,torch.cuda.IntTensor
64位有符号整型:torch.int64 或 torch.long,torch.LongTensor,torch.cuda.LongTensor
pytorch 中默认的数据类型是 torch.float32
对于张量,可使用 .dtype 获取张量类型
2、生成张量
torch.tensor():可将python列表或序列转换为张量,可使用dtype参数指定数据类型
torch.Tensor():生成张量,可指定形状
同时,可以对 numpy 的 ndarray 和 tensor 相互转化,参见此文
torch.empty():生成指定大小的空张量(随机填充)
torch.zeros(),torch.ones() 分别用于生成指定大小的全0张量,全1张量
torch.zeros_like(),torch.ones_like() 分别用于生成与指定张量维度相同的全0张量,全1张量
torch.full():可用于生成填充指定值的张量,反而 torch.zeros(),torch.ones() 限定了填充值
除以上外,还可生成随机张量(理应提前设定随机种子)
torch.rand():生成服从均匀分布的张量,默认区间[0, 1]
torch.randn():生成服从标准正态分布的张量
torch.rand_like(),torch.randn_like() 分别可用于生成与指定张量形状相同的随机数张量
torch.arange():可用于生成张量序列,默认从0开始间隔1,可指定起始值、终止值和间隔
3、张量形状
.shape:获取张量维度,返回 torch.Size 对象
.size():获取张量形状大小,返回 torch.Size 对象
.numel():获得张量中包含的元素数量
.reshape():可用于改变张量的形状
同时,可使用 torch.squeeze() 和 torch.unsqueeze() 对张量进行维度扩增或者维度压缩操作,参见此文
可对张量进行重复扩展
torch.expand():在维数为1的维度上进行拓展,其他维度保持与原来一样即可,可首相使用.size()查看原始维度;如果没有维数为1的维度则不能拓展;
torch.repeat():根据指定的维度进行填充,指定的值为扩展的倍数;指定的维数不能少于原始维数,但可以增加维数以增加新维度
3、张量拼接与拆分
torch.cat():将所给张量按照指定维度进行拼接
torch.stack():将所给张量按照指定的新维度上进行拼接
另有 torch.column_stack(同torch.hstack)、torch.row_stack(同torch.vstack)可分别对列和行角度进行拼接,是一般拼接操作的细化
torch.chunk():将所给张量按指定维度划分为指定数量的张量,不能整除时则最后一个张量较小
torch.split():将张量按指定维度进行划分,可通过list指定划分的块大小
4、张量的比较
torch.eq():逐个元素比较是否相等
torch.equal():判断是否具有相同的形状和元素
torch.ge():逐个元素比较,是否大于等于
torch.gt():逐个元素比较,是否大于
torch.le():逐个元素比较,是否小于等于
torch.lt():逐个元素比较,是否小于
torch.ne():逐个元素比较,是否不等于
torch.isnan():逐个元素判断,是否为nan值
5、张量的基本运算
减加乘除四则运算、幂运算、对数运算、求平方根等均采用广播形式进行,与 numpy 用法同
torch.pow():幂运算
torch.exp():指数运算
torch.log():求自然对数,另有torch.log2(),torch.log10() 等
torch.sqrt():求平方根
6、张量的矩阵运算
torch.matmul():矩阵乘法
torch.t():矩阵转置
torch.inverse():矩阵的逆矩阵
torch.trace():矩阵的迹
7、张量的统计函数
基础的统计功能,最大最小值函数torch.max()、torch.min()等可参见此文
此外还有各种统计参数等
torch.mean():均值
torch.median():中位数
torch.std():标准差
torch.sum():求和
以上可指定维度,若未指定,则对所有元素计算统计
torch.cumsum():累加和,应指定维度
torch.cumprod():累乘积,应指定维度
注:以上仅对常用部分进行了汇总,未包括的部分需要时再查即可