torch.split(tensor, split_size, dim=0)

torch.split(tensor, split_size, dim=0)

说明:将输入张量分割成相等形状的chunks(如果可分)。如果沿指定维的张量形状大小不能被整分,则最后一块会小于其他分块。

参数

tensor(Tensor) -- 待分割张量
split_size(int) -- 单个分块的形状大小
dim(int) -- 沿着此维进行分

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1135,  0.5779, -0.9737, -0.0718],
        [ 0.4136,  1.1577,  0.5689, -0.1970],
        [ 1.4281,  0.3540,  1.4346, -0.1444]])
>>> torch.split(x, 2, 1)
(tensor([[0.1135, 0.5779],
        [0.4136, 1.1577],
        [1.4281, 0.3540]]), tensor([[-0.9737, -0.0718],
        [ 0.5689, -0.1970],
        [ 1.4346, -0.1444]]))
>>> torch.split(x, 2, 0)
(tensor([[ 0.1135,  0.5779, -0.9737, -0.0718],
        [ 0.4136,  1.1577,  0.5689, -0.1970]]), tensor([[ 1.4281,  0.3540,  1.4346, -0.1444]]))


原文链接:https://blog.csdn.net/gyt15663668337/article/details/91345951

 

 

你可能感兴趣的:(Python)