说明:将输入张量分割成相等形状的chunks(如果可分)。如果沿指定维的张量形状大小不能被整分,则最后一块会小于其他分块。
参数:
tensor(Tensor) -- 待分割张量
split_size(int) -- 单个分块的形状大小
dim(int) -- 沿着此维进行分
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1135, 0.5779, -0.9737, -0.0718],
[ 0.4136, 1.1577, 0.5689, -0.1970],
[ 1.4281, 0.3540, 1.4346, -0.1444]])
>>> torch.split(x, 2, 1)
(tensor([[0.1135, 0.5779],
[0.4136, 1.1577],
[1.4281, 0.3540]]), tensor([[-0.9737, -0.0718],
[ 0.5689, -0.1970],
[ 1.4346, -0.1444]]))
>>> torch.split(x, 2, 0)
(tensor([[ 0.1135, 0.5779, -0.9737, -0.0718],
[ 0.4136, 1.1577, 0.5689, -0.1970]]), tensor([[ 1.4281, 0.3540, 1.4346, -0.1444]]))
原文链接:https://blog.csdn.net/gyt15663668337/article/details/91345951