torch.cat()函数

一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224

在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]

如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如 [[[[1,2,3]]]],是四维张量。

torch.cat()函数,官方文档是这样写的torch.cat(tensors, dim=0, *, out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。

废话少说开始演示。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

在维度0上进行合并,然后输出维度为[2,1,2,3],所以得出结论 四维张量在0维合并的时候 其实是在批量大小维度上进行合并。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

在1维度上进行合并,输出维度为[1,2,2,3],即在1维上合并是在通道维度上进行合并。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

#在维度2上进行合并

x=torch.cat((a,b),dim=2)

print(x.shape)

#torch.Size([1, 1, 4, 3])

在维度2上进行合并,输出维度为[1,1,4,3]。即在2维上进行合并是在高上进行合并(也可以说是在行维度进行合并)

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

#在维度2上进行合并

x=torch.cat((a,b),dim=2)

print(x.shape)

#torch.Size([1, 1, 4, 3])

#在维度3上进行合并

x=torch.cat((a,b),dim=3)

print(x.shape)

#torch.Size([1, 1, 2, 6])

在维度3上进行合并,输出维度为[1,1,2,6],即在3维上进行合并是在宽维度进行合并(也可以说是列)

注:在拼接时 除了选择拼接的维度可以不同,其他维度要相同。什么意思?看代码

import torch

#定义两个变量[batch_size,channel,height,width]

a=torch.randn(size=(1,1,2,3))

b=torch.randn(size=(1,2,2,3))

#选择在1维度进行合并(也就是通道维度),注意a,b的通道维度不同,其他维度都相同。

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 3, 2, 3])

也就是选择合并的那个维度可以不同,其他维度要相同

如果不同,报错,如下。

import torch

#定义两个变量[batch_size,channel,height,width]

a=torch.randn(size=(1,1,2,3))

b=torch.randn(size=(2,2,2,3))

#选择在1维度进行合并(也就是通道维度),注意a,b的批量大小不同,维度不同,其他维度都相同。

x=torch.cat((a,b),dim=1)

print(x.shape)

#RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0

可以看到当我们选择在通道维度合并时(通道数可以不同),但是其他的维度要相同(下面的a,b的批量大小也不同)。所以直接报错。

原文链接:https://blog.csdn.net/zwb619/article/details/127022873

你可能感兴趣的:(机器学习,pytorch,pytorch)