pytorch的数组操作

本文是记录自己学习中的一些小知识,比较散

a = torch.randn(4, 3, 28, 28)
print(a[:2].shape)

#out: torch.Size([2, 3, 28, 28])

a[:2]   这里是指,a的第0维中,0到2,不包含2的数据,其他维度默认取全部,等效于a[:2, :, :, :]

让然,很明显,a[:, :, :, :], 这个能控制自己想要去取出的数据

 

另一个有意思的是,维度中出现的负数

比如:a[-1:],这个是指 a的0维,从最后一个,取到最后,也就是取最后一个。

a[-2:,],这个是指a的0维,从倒数第二个,取到最后一个

正常情况下,a中第0维有4个数据,也就是0,1,2,3  ,当序号是正数时,对应就是0,1,2,3,序号是负数时,对应是-4,-3,-2,-1

 

步长控制

print(a[:, :, 0:28:2, 0:28:2].shape)

#out: torch.Size([4, 3, 14, 14])

这里是取,a的第2维度,从2取到28(不包括28),以步长为2

这里步长的写法与matlab稍有区别:

matlab : start: step: end

pytorch :  start: end: step

 

三个点 ...

代表省略,根据实际情况推测,有时候,省略中间,或者前面的一些维度比较方便

a[..., :2]

 

维度变换

使用permute函数,可以交换原来的维度~

 

tensor合并

1. cat 将某一维度合并,其余维度要一致;

也就是将tensor在指定的维度拼接起来,不会增加新的维度

a = torch.randn(4, 3, 28, 28)
b = torch.randn(3, 3, 28, 28)
c = torch.cat([a, b], dim= 0)

print(c.shape)

#out: torch.Size([7, 3, 28, 28])

2. Stack

拼接两个tensor,并且给出一个新的维度,要求是拼接的tensor原来的维度都必须一致

a = torch.randn(4, 3, 28, 28)
b = torch.randn(4, 3, 28, 28)
c = torch.stack([a, b], dim= 0)

print(c.shape)

#out: torch.Size([2, 4, 3, 28, 28])

tensor拆分

1.split 在指定的维度,分割为指定的数量

a = torch.randn(4, 3, 28, 28)
#将第0维的4分割为2+1+1, 如果长度都一样,只用给一个长度的值就行
a1, a2, a3 = a.split([2, 1, 1], dim=0)
print(a1.shape, '\n', a2.shape, '\n', a3.shape )

#out : torch.Size([2, 3, 28, 28]) 
#out : torch.Size([1, 3, 28, 28]) 
#out : torch.Size([1, 3, 28, 28])

2. chunk 是在指定维度的,分割为指定的数量,指定的那个维度的数量必须被分割数量整除

a = torch.randn(12, 28, 28)
a1, a2, a3, a4 = a.chunk(4, dim=0)
print(a1.shape)
print(a2.shape)
print(a3.shape)
print(a4.shape)


#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])

 

你可能感兴趣的:(pytorch~)