pytorch基础知识七【拼接与拆分】

拼接与拆分

  • 1. 拼接
    • 1.1 cat()
    • 1.2 stack()
    • 1.3 cat() VS stack()
  • 2. 拆分
    • 2.1 split()
    • 2.2 chunk()

1. 拼接

1.1 cat()

torch.cat([tensor1,tensor2],dim)

[tensor1,tensor2]表示需要拼接的张量;
dim 表示在哪个维度上拼接

注意:拼接时,除了要拼接的维度外,其他维度的形状必须相同。
pytorch基础知识七【拼接与拆分】_第1张图片
pytorch基础知识七【拼接与拆分】_第2张图片
示意图:根据不同的维度拼接
pytorch基础知识七【拼接与拆分】_第3张图片

pytorch基础知识七【拼接与拆分】_第4张图片
pytorch基础知识七【拼接与拆分】_第5张图片

1.2 stack()

stack在拼接张量时,会创建新的维度。
pytorch基础知识七【拼接与拆分】_第6张图片

1.3 cat() VS stack()

2. 拆分

2.1 split()

.split(len,dim=0) # 按固定长度将张量拆分,dim表示在哪个维度上拆分,拆分后的每个张量在维度dim上的形状都是len。

.split([len1,len2,...],dim=0) # 按[len1,len2,...]中的长度将张量拆分为dim维度上形状不固定的若干个张量,dim表示在哪个维度上拆分。

pytorch基础知识七【拼接与拆分】_第7张图片

2.2 chunk()

.chunk(num,dim)  # 按数量拆分,表示张量按dim维度分成num块。

pytorch基础知识七【拼接与拆分】_第8张图片

你可能感兴趣的:(pytorch,pytorch,深度学习,人工智能)