本文是记录自己学习中的一些小知识,比较散
a = torch.randn(4, 3, 28, 28)
print(a[:2].shape)
#out: torch.Size([2, 3, 28, 28])
a[:2] 这里是指,a的第0维中,0到2,不包含2的数据,其他维度默认取全部,等效于a[:2, :, :, :]
让然,很明显,a[:, :, :, :], 这个能控制自己想要去取出的数据
另一个有意思的是,维度中出现的负数。
比如:a[-1:],这个是指 a的0维,从最后一个,取到最后,也就是取最后一个。
a[-2:,],这个是指a的0维,从倒数第二个,取到最后一个
正常情况下,a中第0维有4个数据,也就是0,1,2,3 ,当序号是正数时,对应就是0,1,2,3,序号是负数时,对应是-4,-3,-2,-1
步长控制
print(a[:, :, 0:28:2, 0:28:2].shape)
#out: torch.Size([4, 3, 14, 14])
这里是取,a的第2维度,从2取到28(不包括28),以步长为2
这里步长的写法与matlab稍有区别:
matlab : start: step: end
pytorch : start: end: step
三个点 ...
代表省略,根据实际情况推测,有时候,省略中间,或者前面的一些维度比较方便
a[..., :2]
维度变换
使用permute函数,可以交换原来的维度~
tensor合并
1. cat 将某一维度合并,其余维度要一致;
也就是将tensor在指定的维度拼接起来,不会增加新的维度
a = torch.randn(4, 3, 28, 28)
b = torch.randn(3, 3, 28, 28)
c = torch.cat([a, b], dim= 0)
print(c.shape)
#out: torch.Size([7, 3, 28, 28])
2. Stack
拼接两个tensor,并且给出一个新的维度,要求是拼接的tensor原来的维度都必须一致
a = torch.randn(4, 3, 28, 28)
b = torch.randn(4, 3, 28, 28)
c = torch.stack([a, b], dim= 0)
print(c.shape)
#out: torch.Size([2, 4, 3, 28, 28])
tensor拆分
1.split 在指定的维度,分割为指定的数量
a = torch.randn(4, 3, 28, 28)
#将第0维的4分割为2+1+1, 如果长度都一样,只用给一个长度的值就行
a1, a2, a3 = a.split([2, 1, 1], dim=0)
print(a1.shape, '\n', a2.shape, '\n', a3.shape )
#out : torch.Size([2, 3, 28, 28])
#out : torch.Size([1, 3, 28, 28])
#out : torch.Size([1, 3, 28, 28])
2. chunk 是在指定维度的,分割为指定的数量,指定的那个维度的数量必须被分割数量整除
a = torch.randn(12, 28, 28)
a1, a2, a3, a4 = a.chunk(4, dim=0)
print(a1.shape)
print(a2.shape)
print(a3.shape)
print(a4.shape)
#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])
#out: torch.Size([3, 28, 28])