Pytorch入门教程(四):Tensor拆分、合并及基本运算

1.  cat 进行维度拼接

a = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0)  # 按第0维度进行拼接,除拼接之外的维度必须相同
print(c.shape)

结果:torch.Size([9, 32, 8])

2.  stack 产生一个新的维度

a = torch.rand(5, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.stack([a, b], dim=0)  # 产生一个新的维度,待拼接的向量维度相同
print(c.shape)

结果:torch.Size([2, 5, 32, 8])

3.  split:  按所指定的长度拆分

a = torch.rand(6, 32, 8)
b, c = a.split(3, dim=0)  # 所给的是拆分后,每个向量的大小,指定拆分维度
print(b.shape)
print(c.shape)

结果:

torch.Size([3, 32, 8])
torch.Size([3, 32, 8])

4.   chuck:  按所给数量进行拆分

a = torch.rand(6, 32, 8)
b, c, d = a.chunk(3, dim=0)  # 所给的是拆分的个数,即拆分成多少个
print(b.shape)
print(c.shape)

结果:

torch.Size([2, 32, 8])
torch.Size([2, 32, 8])

5.  加减乘除(元素级别)

a = torch.ones(3, 4) * 2
b = torch.ones(3, 4)
print(a+b)
print(a-b)
print(a*b)
print(a/b)

结果:

tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

6.  矩阵乘法

a = torch.ones(2, 2) * 2
b = torch.ones(2, 3)
print(torch.mm(a, b))  # 只适用于2维数组
print(a@b)
a = torch.rand(4, 32, 28, 28)
b = torch.rand(4, 32, 28, 16)
print(torch.matmul(a, b).shape)  # 可以适用于多维数组,直讲最后两个维度相乘

tensor([[4., 4., 4.],
        [4., 4., 4.]])
tensor([[4., 4., 4.],
        [4., 4., 4.]])
torch.Size([4, 32, 28, 16])

7.  幂运算

a = torch.ones(2, 2) * 2
print(a.pow(2))  # 平方
print(a**2)
print(a.sqrt())  # 开方
print(a**0.5)

tensor([[4., 4.],
        [4., 4.]])
tensor([[4., 4.],
        [4., 4.]])
tensor([[1.4142, 1.4142],
        [1.4142, 1.4142]])
tensor([[1.4142, 1.4142],
        [1.4142, 1.4142]])

8.  e运算

a = torch.exp(torch.ones(2, 2))  # e运算
print(a)
print(torch.log(a))  # 取对数,默认以e为底

tensor([[2.7183, 2.7183],
        [2.7183, 2.7183]])
tensor([[1., 1.],
        [1., 1.]])

9.  四舍五入

a = torch.tensor(3.14)
print(a.floor())  # 向下取整
print(a.ceil())  # 向上取整
print(a.trunc())  # 取整数部分
print(a.frac())  # 取小数部分

tensor(3.)
tensor(4.)
tensor(3.)
tensor(0.1400)

10.  clamp 限定数组范围

a = torch.rand(2, 3) * 15
print(a)
print(a.clamp(2))  # 限定最小值为2
print(a.clamp(2, 10))  # 取值范围在0-10

tensor([[ 0.7791,  4.7365,  4.2215],
        [12.7793, 11.7283, 13.1722]])
tensor([[ 2.0000,  4.7365,  4.2215],
        [12.7793, 11.7283, 13.1722]])
tensor([[ 2.0000,  4.7365,  4.2215],
        [10.0000, 10.0000, 10.0000]])

你可能感兴趣的:(Pytorch)