之前写写过
见pytorch中的torch.cat()矩阵拼接的用法及理解
这个函数的作用主要是将矩阵交换两个维度。
import torch
a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [1, 1, 1]]])
aa = torch.transpose(a, 0, 1) # 交换第一维和第二维
bb = torch.transpose(b, 0, -1) # 交换第一维和最后一维
print(a.shape, "----", aa.shape)
print(b.shape, "----", bb.shape)
print(aa)
print(bb)
这个函数能对维度进行变换,view中的参数即为每一维应当包含的元素个数,如果-1即让其自动计算出。
import torch
a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [3, 1, 1]]])
aa = a.view(3, -1) # 第一维3个元素,第二维任意
bb = b.view(3, 2) # 第一维3个元素第二维2个元素
bbb = b.view(1, 1, 6) # 最后一维6个元素
print(aa.shape)
print(bb.shape)
print(bbb.shape)
print(bb)
这个函数能够对任意维度进行任意交换,transpose只能对两个维度进行交换。用法是:原本0,1,2代表第1,2,3维,那么现在b.permute(2,0,1)即现在第一维是原来的第三维,第二维是原来的第一维,第三维是原来的第二维。
import torch
a = torch.Tensor([[1, 2, 3]])
b = torch.Tensor([[[1, 2, 1], [3, 1, 1]]])
bb = b.permute(2, 0, 1)
print(b.shape)
print(bb.shape)
作用就是降维,去掉维度中元素为1的维度
import torch
a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[1, 1, 1], [1, 1, 1]]])
aa = a.squeeze()
bb = b.squeeze()
aaa = a.squeeze(0) # 0代表将第一维去掉
print(a.shape)
print(b.shape)
print(aa.shape)
print(bb.shape)
print(aaa.shape)
作用就是升1维
import torch
a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[1, 1, 1], [1, 1, 1]]])
aa = a.unsqueeze(0) # 在第一维前增加一维
bb = b.unsqueeze(-1) # 在最后一维后增加一维
print(a.shape)
print(b.shape)
print(aa.shape)
print(bb.shape)
实现的是矩阵乘法不过必须是三维的,最后两维需要满足矩阵乘法的格式
import torch
a = torch.Tensor([[[[1], [2], [3]]]])
b = torch.Tensor([[[0, 1, 1], [1, 2, 1]]])
print(a.shape)
print(b.shape)
print("----------------")
a = a.view(1, -1, 3) # 维度转换使其能够进行矩阵乘法
b = b.view(1, 3, -1)
print(a.shape)
print(b.shape)
print("----------------")
print(a)
print(b)
c = torch.bmm(a, b)
print(c)
print(c.shape)
求一维向量的内积,即对应元素相乘然后求和
import torch
a = torch.Tensor([1, 1, 2, 3])
b = torch.Tensor([1, 1, 1, 1])
c = torch.dot(a, b)
print(c)
运行结果
1x1 + 1x1 + 2x1 + 3x1 = 7
具体查看是什么tensor类型
# coding:utf-8
import torch
a = torch.Tensor([5])
print(a.dtype)
能够生成不同类型的tensor
# coding:utf-8
import torch
a = torch.tensor([5], dtype=torch.float64)
b = torch.tensor([6], dtype=torch.float32)
c = torch.tensor([7], dtype=torch.int64)
print(a.dtype)
print(b.dtype)
print(c.dtype)
转换tensor的数据类型
# coding:utf-8
import torch
a = torch.tensor([5], dtype=torch.float64).long()
b = torch.tensor([6], dtype=torch.float32).int()
c = torch.tensor([7], dtype=torch.int64).float()
print(a.dtype)
print(b.dtype)
print(c.dtype)
可以用来找对应元素的下标
batch = torch.tensor([[1, 2, 3, 0, 0], [1, 0, 3, 1, 1]])
pos_sen = torch.nonzero(batch==0).squeeze()
print(pos_sen)
用于矩阵的拼接
import torch
a = torch.tensor([[1,2,3]])
b = torch.tensor([[4,5,6]])
c = torch.cat((a, b), axis=0)
d = torch.cat((a, b), axis=1)
print(c)
print(d)
用于矩阵的拆分
import torch
a = torch.tensor([[1,2,3], [2,2,3]])
b, c = a.chunk(2, dim=0)
print(b)
print(c)
a = torch.tensor([[1,2,3], [2,2,3]])
b, c, d = a.chunk(3, dim=1)
print(b)
print(c)
print(d)