torch.squeeze(input, dim=None, *, out=None)
x = torch.zeros(2, 1, 2, 1, 2)
print(x.shape)
x = x.squeeze()
print(x)
print(x.shape)
x = torch.zeros(2, 1, 2, 1, 2)
print(x.shape)
x = x.squeeze(0)
print(x.shape)
x = x.squeeze(1)
print(x.shape)
多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。
在多维张量中,如果某一个维度是1,那么这个维度是为了扩充维度,所以为了加快计算,进行降维操作时可以去掉1的维度。
torch.squeeze是为了降维,那么torch.unsqueeze是了升维。
torch.unsqueeze(input, dim)
x = torch.tensor([1, 2, 3, 4])
print(x)
print(x.size())
print('*'*50)
x = x.unsqueeze(1)
print(x)
print(x.size())
squeeze_和unsqueeze_分别在squeeze和unsqueeze的基础上增加下划线,区别在于是否改变原来张量。
加上“_”,将会直接改变原始张量,否则不直接改变原始张量。
x = torch.zeros(2, 1, 2, 1, 2)
y = torch.zeros(2, 1, 2, 1, 2)
x_t = x.squeeze_(1)
y_t = y.squeeze(1)
print('squeeze原始张量:',y.size())
print('squeeze变化张量:',y_t.size())
print('squeeze_原始张量:',x.size())
print('squeeze_变化张量:',x_t.size())
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([1, 2, 3, 4])
x_t = x.unsqueeze(1)
y_t = y.unsqueeze_(1)
print('unsqueeze原始张量:',x.size())
print('unsqueeze变化张量:',x_t.size())
print('unsqueeze_原始张量:',y.size())
print('unsqueeze_变化张量:',y_t.size())