pytorch与numpy张量拷贝需要注意的点

前言

最近在改一些小demo,让自己在假期之后也不会对相关编程语言过于生疏。今天在修改yolov5的过程中,我注意到了一些我之前没有注意到的点。那便是关于pytorch张量的拷贝到底哪些算是浅拷贝哪些是深拷贝的问题。进一步,我观察一些之前写的数据处理程序,发现不仅是pytorch的拷贝机制和以往不同,numpy的张量拷贝机制也同样和一般情况下是不一样的,接下来进行简要的介绍。

正文

pytorch

浅拷贝

所谓浅拷贝,为共享内存机制,两个变量都指向同一块地址

直接进行变量赋值,内存机制

import torch
a = torch.tensor([1])
b = a
a *= 10
print(a, b)

结果

tensor([10]) tensor([10])

索引操作

b = a[:]
a *= 10

结果

tensor([10]) tensor([10])

改变张量行列

  • view()
b = a.view(-1)
a *= 10

结果

tensor([10]) tensor([10])
  • reshape()
b = a.reshape(-1)
a *= 10

结果

tensor([10]) tensor([10])
  • flatten(input, dim)
b = torch.flatten(a, 0)
a *= 10

结果

tensor([10]) tensor([10])
  • expand(),expand_as()
b = a.expand((1,2))
a *= 10

结果

tensor([10]) tensor([[10, 10]])

与numpy相互转换

  • numpy()
b = a.numpy()
a *= 10

结果

tensor([10]) [10]
  • from_numpy()
a = np.array([1])
b = torch.from_numpy(a)
b *= 10

结果

[10] tensor([10], dtype=torch.int32)
  • as_tensor()
    该函数只有在与numpy转换时会表现出浅拷贝的性质

作为方法参数

def ppp(a):
    print(a)
    b = a
    b *= 10
    print(a, b)

if __name__=='__main__':
    a = torch.tensor([1])
    ppp(a)
    print(a)

结果

tensor([1])
tensor([10]) tensor([10])
tensor([10])

深层拷贝

clone

a = torch.tensor([1])
b = a.clone()
a *= 10

结果

tensor([10]) tensor([1])

copy_()

a = torch.tensor([1])
b = torch.zeros_like(a)
torch.Tensor.copy_(b, a)
a *= 10

结果

tensor([10]) tensor([1])

矩阵拼接

  • stack()
a = torch.tensor([1])
c = torch.zeros_like(a)
b = torch.stack((a, c), 1)
a *= 10

结果

tensor([10]) tensor([[1, 0]])
  • cat()
a = torch.tensor([1])
c = torch.zeros_like(a)
b = torch.cat((a, c), 1)
a *= 10
tensor([10]) tensor([1, 0])

矩阵运算

a = torch.tensor([1])
b = a * 1
b *= 10

结果

tensor([1]) tensor([10])

like类函数

  • ones_like()
  • zeros_like()

Numpy

浅拷贝

直接赋值

a = np.array([1])
b = a
b *= 10

结果

[10] [10]

作为参数传递

def ppp(a):
    print(a)
    b = a
    b *= 10
    print(a, b)

if __name__=='__main__':
    a = np.array([1])
    ppp(a)
    print(a)

结果

[1]
[10] [10]
[10]

深拷贝

copy

a = np.array([1])
b = a.copy()
b *= 10

结果

[1] [10]

矩阵运算

a = np.array([1])
b = a+1
b *= 10

结果

[1] [20]

结束

以上便是笔者目前所总结的,numpy()部分并不是很多,如有错误欢迎大家指正与补充!!!!

你可能感兴趣的:(炼丹心得)