pytorch中张量的分块.chunk()方法和拆分.split()方法

张量的分块和拆分方法

1.分块:.chunk()方法

.chunk()方法能够按照某维度,对张量进行均匀切分,并且返回结果是原张量的视图。
(1).chunk()函数的使用,第一个参数:目标张量,第二个参数:等分的块数,第三个参数:按照的维度

t2 = torch.arange(12).reshape(4, 3)
#结果为:tensor([[ 0,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8],
                 [ 9, 10, 11]])
tc = torch.chunk(t2, 4, dim=0)   # 在第零个维度上(按行),进行四等分,每一行都为一个二维张量
#结果为:(tensor([[0, 1, 2]]),
          tensor([[3, 4, 5]]),
          tensor([[6, 7, 8]]),
          tensor([[ 9, 10, 11]]))

注意:chunk返回结果是一个视图,不是新生成了一个对象

tc[0][0]
#结果为:tensor([0, 1, 2])
tc[0][0][0] = 1     # 修改tc中的值
print(tc)
#结果为:(tensor([[1, 1, 2]]),
          tensor([[3, 4, 5]]),
          tensor([[6, 7, 8]]),
          tensor([[ 9, 10, 11]]))
print(t2)      # 原张量也会对应发生变化
#结果为:tensor([[ 1,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8],
                 [ 9, 10, 11]])

注:当原张量不能均分时,chunk不会报错,但会返回其他均分的结果

torch.chunk(t2, 3, dim=0)     # 返回次一级均分结果,即2等分
#结果为:(tensor([[1, 1, 2],
                  [3, 4, 5]]),
          tensor([[ 6,  7,  8],
                  [ 9, 10, 11]]))
len(torch.chunk(t2, 3, dim=0))
#结果为:2
torch.chunk(t2, 5, dim=0)     # 次一级均分结果,即4等分
#结果为:(tensor([[1, 1, 2]]),
          tensor([[3, 4, 5]]),
          tensor([[6, 7, 8]]),
          tensor([[ 9, 10, 11]]))

2.拆分:.split()方法

.split()方法既能进行均分,也能进行自定义切分。当然,需要注意的是,和chunk函数一样,split返回结果也是view。
(1).split()函数的使用
均分情况

t2 = torch.arange(12).reshape(4, 3)
#结果为:tensor([[ 0,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8],
                 [ 9, 10, 11]])
torch.split(t2, 2, 0)    # 第二个参数只输入一个数值时表示均分,第三个参数表示切分的维度
#结果为:(tensor([[0, 1, 2],
                  [3, 4, 5]]),
          tensor([[6, 7, 8],
                  [ 9, 10, 11]]))

按照索引值+1切分,序列中1代表的是第一行或第一列,不是第二行或第二列

torch.split(t2, [1, 3], 0)   # 第二个参数输入一个序列时,表示按照序列数值进行切分,也就是1/3分
#结果为:(tensor([[0, 1, 2]]),
          tensor([[ 3,  4,  5],
                  [ 6,  7,  8],
                  [ 9, 10, 11]]))

注意:当第二个参数位输入一个序列时,序列的各数值的和必须等于对应维度下形状分量的取值。例如,上述代码中,是按照第一个维度进行切分,而t2总共有4行,因此序列的求和必须等于4,也就是1+3=4,而序列中每个分量的取值,则代表切块大小。

torch.split(t2, [1, 1, 1, 1], 0)  
#结果为:(tensor([[0, 1, 2]]),
          tensor([[3, 4, 5]]),
          tensor([[6, 7, 8]]),
          tensor([[ 9, 10, 11]]))
          
torch.split(t2, [1, 1, 2], 0)
#结果为:(tensor([[0, 1, 2]]),
          tensor([[3, 4, 5]]),
          tensor([[ 6,  7,  8],
                  [ 9, 10, 11]]))

ts = torch.split(t2, [1, 2], 1)
#结果为:(tensor([[0],
                  [3],
                  [6],
                  [9]]),
          tensor([[1, 2],
                  [4, 5],
                  [7, 8],
                  [10, 11]]))

ts[0][0] = 1     # view进行修改
print(t2)        # 原对象同步改变
#结果为:tensor([[ 1,  1,  2],
                 [ 3,  4,  5],
                 [ 6,  7,  8],
                 [ 9, 10, 11]])

注意:tensor的split方法和array的split方法有很大的区别,array的split方法是根据索引进行切分,tensor的split方法是根据索引值再加上1来切分的。

你可能感兴趣的:(pytorch深度学习,pytorch,python,深度学习)