Pytorch 的复制函数 torch.repeat_interleave() 和 torch.repeat()

1、torch.repeat_interleave()

repeat_interleave(self, repeats, dim)

self:张量数据

repeats:需要复制的份数

dim:需要复制的维度

import torch

a=torch.Tensor([[1,2,3,4],
                [2,3,4,5]])

print("--------------------------")
print(a.repeat_interleave(2, dim=0))

print("--------------------------")
print(a.repeat_interleave(2, dim=1))

# D:\Anaconda3\python.exe D:/0_me_python/3.py
# --------------------------
# tensor([[1., 2., 3., 4.],
#         [1., 2., 3., 4.],
#         [2., 3., 4., 5.],
#         [2., 3., 4., 5.]])
# --------------------------
# tensor([[1., 1., 2., 2., 3., 3., 4., 4.],
#         [2., 2., 3., 3., 4., 4., 5., 5.]])

2、torch.repeat()

repeat_interleave(self, repeats)

self:张量数据

repeats:需要复制的份数

注意:repeats复制的维度需要与目标数据维度相同,即数据为2维度,则复制的份数也必须是2维度的。

import torch

a=torch.Tensor([[1,2,3,4],
                [2,3,4,5]])

print("--------------------------")
print(a.repeat(1, 2))

print("--------------------------")
print(a.repeat(2, 1))

print("--------------------------")
print(a.repeat(4, 2))

# D:\Anaconda3\python.exe D:/0_me_python/3.py
# --------------------------
# tensor([[1., 2., 3., 4., 1., 2., 3., 4.],
#         [2., 3., 4., 5., 2., 3., 4., 5.]])
# --------------------------
# tensor([[1., 2., 3., 4.],
#         [2., 3., 4., 5.],
#         [1., 2., 3., 4.],
#         [2., 3., 4., 5.]])
# --------------------------
# tensor([[1., 2., 3., 4., 1., 2., 3., 4.],
#         [2., 3., 4., 5., 2., 3., 4., 5.],
#         [1., 2., 3., 4., 1., 2., 3., 4.],
#         [2., 3., 4., 5., 2., 3., 4., 5.],
#         [1., 2., 3., 4., 1., 2., 3., 4.],
#         [2., 3., 4., 5., 2., 3., 4., 5.],
#         [1., 2., 3., 4., 1., 2., 3., 4.],
#         [2., 3., 4., 5., 2., 3., 4., 5.]])
# 
# Process finished with exit code 0

你可能感兴趣的:(函数参数,pytorch,深度学习,机器学习)