本章主要总结了tensor的一些常用的操作,如果想把每个关于tensor的方法记下是不太现实的。只需要留个印象,在实际的场景中需要用到,知道有这么个东西,再通过查询api进行使用。
torch这个package中包含了多维张量(tensor)数据结构,并定义了对这些张量的数学运算,并且提供了许多工具对张量各种操作。以下总结常用的一些操作。
#导入torch
import torch
作用是检测该tensor中的元素是否非零
参数是只含有一个元素的tensor
返回 布尔值
注意:此方法输入必须是只含有一个元素的tensor a4,a5的输入会报错
a1 = torch.is_nonzero(torch.tensor([0]))
a2 = torch.is_nonzero(torch.tensor([1]))
a3 = torch.is_nonzero(torch.tensor([0.]))
# a4 = torch.is_nonzero(torch.tensor([]))
# a5 = torch.is_nonzero(torch.tensor([1,2,3]))
a1 ,a2 ,a3
#输出
(False, True, False)
此方法返回tensor中共有多少个元素
num = torch.numel(torch.rand(3,3))
num
#输出
9
类似矩阵中的转置操作,行列互换
input (Tensor) – 输入为一个tensor.
input (Tensor) – 输入为一个tensor
dim0 (int) – 第一个要转置的维度
dim1 (int) - 第二个要转置的维度
都实现了张量的转置操作
t(input)只能进行二维张量的转置
transpose(input,dim0, dim1)可以进行高纬张量的转置,但是只能选择两个维度之间的转置
二维数据
data = torch.rand(3,4)
t1 = torch.t(data)
#输出
data,t1
(tensor([[0.6796, 0.0266, 0.4069, 0.3042],
[0.8158, 0.7434, 0.7118, 0.3403],
[0.0744, 0.5714, 0.4415, 0.0747]]),
tensor([[0.6796, 0.8158, 0.0744],
[0.0266, 0.7434, 0.5714],
[0.4069, 0.7118, 0.4415],
[0.3042, 0.3403, 0.0747]]))
注意:如果torch.t(data)传入两维度以上的张量会报一下错误
RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_9232/2496212846.py in <module>
----> 1 torch.t(data)
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
多维张量
data = torch.rand(3,4,2)
data
#元数据
tensor([[[0.4097, 0.1857],
[0.9020, 0.6462],
[0.1544, 0.8468],
[0.6699, 0.4116]],
[[0.6666, 0.8028],
[0.4677, 0.9692],
[0.6497, 0.0196],
[0.7469, 0.6108]],
[[0.8890, 0.8045],
[0.6521, 0.4382],
[0.4315, 0.1732],
[0.4397, 0.6457]]])
不同维度的转置
torch.transpose(data,0,1)
#输出
tensor([[[0.4097, 0.1857],
[0.6666, 0.8028],
[0.8890, 0.8045]],
[[0.9020, 0.6462],
[0.4677, 0.9692],
[0.6521, 0.4382]],
[[0.1544, 0.8468],
[0.6497, 0.0196],
[0.4315, 0.1732]],
[[0.6699, 0.4116],
[0.7469, 0.6108],
[0.4397, 0.6457]]])
torch.transpose(data,-1,1)
#输出
tensor([[[0.4097, 0.9020, 0.1544, 0.6699],
[0.1857, 0.6462, 0.8468, 0.4116]],
[[0.6666, 0.4677, 0.6497, 0.7469],
[0.8028, 0.9692, 0.0196, 0.6108]],
[[0.8890, 0.6521, 0.4315, 0.4397],
[0.8045, 0.4382, 0.1732, 0.6457]]])
torch.cat(tensors, dim=0, *, out=None)连接给定维度的张量,所有的tensor必须是相同的维度
tensors – 具有相同shape的多个tensor组成的元组(tensor1,tensor2)
dim (int, optional) – 张量连接的维度
cat_data1 = torch.rand(2,3)
cat_data2 = torch.rand(2,3)
cat_data3 = torch.rand(2,3)
cat_data = torch.cat((cat_data1,cat_data2,cat_data3))
cat_data
#输出结果cat_data
tensor([[0.4117, 0.1281, 0.1735],
[0.0882, 0.9527, 0.8244],
[0.5485, 0.8877, 0.4297],
[0.9650, 0.3667, 0.4441],
[0.1581, 0.4601, 0.7503],
[0.1222, 0.6008, 0.3165]])
可以选择在不同的维度上进行连接
cat_data = torch.cat((cat_data1,cat_data2,cat_data3),1)
cat_data
#输出结果cat_data
tensor([[0.4117, 0.1281, 0.1735, 0.5485, 0.8877, 0.4297, 0.1581, 0.4601, 0.7503],
[0.0882, 0.9527, 0.8244, 0.9650, 0.3667, 0.4441, 0.1222, 0.6008, 0.3165]])
cat_data = torch.cat((cat_data1,cat_data2,cat_data3),1)
cat_data