【Pytorch基础(3)】张量的拼接,拆分与统计

一、张量的拼接

张量的拼接主要通过cat()和stack()函数实现。其中torch.cat([a, b], dim=n)是在n维度上进行两个张量的拼接,其参数n的含义代表要进行拼接操作的维度,a和b则代表要拼接的张量。在使用cat()方法时需要注意的是两个张量除了拼接的维度可以不同,其他的维度必须相同,否则会报错。示例如下:
Statistics about scores
a [class1-3, students, scores]
b [class4-9, students, scores]

import torch

a = torch.rand(3, 32, 8)
b = torch.rand(6, 32, 8)

print(a.shape)
print(b.shape)
print(torch.cat([a, b], dim=0).shape)

# output
# torch.Size([3, 32, 8])
# torch.Size([6, 32, 8])
# torch.Size([9, 32, 8])

torch.stack([a, b], dim=n)是拼接两个张量a,b时,在维度n之前生成一个新的维度。注意,stack()方法对于带拼接的两个张量形状要求更加严格,具体来说当使用stack()方法时,要保证拼接的两个张量形状是相同的,否则会报错,示例如下:

import torch

a = torch.rand(3, 32, 8)
b = torch.rand(6, 32, 8) 
c = torch.rand(3, 32, 8)

print(torch.stack([a, c], dim=0).shape)
print(torch.stack([a, b], dim=0).shape)

# output
# torch.Size([2, 3, 32, 8])
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# Input In [35], in ()
#       5 c = torch.rand(3, 32, 8)
#       7 print(torch.stack([a, c], dim=0).shape)
# ----> 8 print(torch.stack([a, b], dim=0).shape)

# RuntimeError: stack expects each tensor to be equal size, 
# but got [3, 32, 8] at entry 0 and [6, 32, 8] at entry 1

二、张量的拆分

张量的拆分主要通过split()和chunk()函数实现。其中split()是在某维度上按照定义的间隔进行维度拆分的,方法的格式为torch.split(要拆掉的张量,拆分时的间隔数,要拆分的维度索引) ,拆分后的结果将以列表的形式进行返回。示例如下:

import torch

a = torch.rand(5, 32, 8)
# 对张量a中的第0维以间隔2进行拆分
b = torch.split(a, 2, 0)

print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape)
print(b[2].shape)

# output
# torch.Size([5, 32, 8])
# 3
# torch.Size([2, 32, 8])
# torch.Size([2, 32, 8])
# torch.Size([1, 32, 8])

至于chunk(),则是在某维度上按照定义的数量进行维度拆分的,方法的格式为torch.chunk(要拆掉的张量,拆分后的数量,要拆分的维度索引) ,拆分后的结果将以列表的形式进行返回。示例如下:

import torch

a = torch.rand(5, 32, 8)
# 对张量a中的第1维进行拆分,拆分后可得到两个子集
b = torch.chunk(a, 2, 1)

print(a.shape)
print(len(b))
print(b[0].shape)
print(b[1].shape) 

# output
# torch.Size([5, 32, 8])
# 2
# torch.Size([5, 16, 8])
# torch.Size([5, 16, 8])

二、张量的统计运算

pytorch中,常用的张量的取整方法有五种,分别是:

.floor() 向下取整
.ceil() 向上取整
.round() 四舍五入
.trunc() 裁剪出整数部分
.frac() 裁剪出小数部分
示例如下:

import torch

a = torch.tensor(3.1415926)

print(a.floor())
print(a.ceil())
print(a.round())
print(a.trunc())
print(a.frac())

# output
# tensor(3.)
# tensor(4.)
# tensor(3.)
# tensor(3.)
# tensor(0.1416)

pytorch中,常用的张量统计方法有五种,分别是:

.mean() 求均值
.sum() 求和
.max() 求最大值
.min() 求最小值
.prod() 求乘积
示例如下:

import torch

a = torch.tensor([1., 2., 3., 4., 5., 6., 7.])

print(a.mean())
print(a.sum())
print(a.max())
print(a.min())
print(a.prod())


# output
# tensor(4.)
# tensor(28.)
# tensor(7.)
# tensor(1.)
# tensor(5040.)

pytorch中,我们还可以取到一个张量最大值或最小值的索引,使用的方法是argmin()和argmax()。这个在做识别任务时非常常见,后续会讲到。

示例如下:

import torch

a = torch.tensor([1., 2., 3., 4., 5., 6., 7.])

print(a.argmin())
print(a.argmax()) 

# output
# tensor(0)
# tensor(6)

pytorch中,我们可以使用的方法torch.eq()和torch.equal()方法来判断两个张量是否相等。两者接受的参数都是两个张量,其中eq()方法的返回值是按元素位置返回True或False,False代表不等,True代表相等。而equal()方法的返回值是True或False,当两个张量完全一样时,才会返回True,不然返回False。代码示例如下:

import torch

a = torch.ones(3,3)
b = torch.eye(3,3)

print(torch.eq(a, b))
print(torch.equal(a, b))

# output
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
# False

三、torch.eye()函数

函数原型:

result = torch.eye(n,m=None,out=None)

参数解释:
n:行数
m:列数
out:输出类型
例:

c = torch.eye(3)
print(c)
print(type(c))

输出

tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]])

你可能感兴趣的:(pytorch,pytorch,深度学习,python)