pytorch中copy_()、detach()、data()和clone()操作区别小结

文章目录

  • 1. clone
  • 2. copy_
  • 3. detach
  • 4. data

1. clone

b = a.clone()

创建一个tensor与源tensor有相同的shape,dtype和device,不共享内存地址,但新tensor(b)的梯度会叠加在源tensor(a)上。需要注意的是,b = a.clone()之后,b并非叶子节点,所以不可以访问它的梯度。

import torch

a = torch.tensor([1.,2.,3.],requires_grad=True)
b = a.clone()


print('===========================不共享地址=========================')
print(type(a), a.data_ptr())
print(type(b), b.data_ptr())
print('===========================clone后分别输出=========================')
print('a: ', a) # a:  tensor([1., 2., 3.], requires_grad=True)
print('b: ', b) #b:  tensor([1., 2., 3.], grad_fn=)

c = a ** 2
d = b ** 3
print('===========================反向传播=========================')
c.sum().backward() # 2* a
print('a.grad: ', a.grad)  #a.grad:  tensor([2., 4., 6.])

d.sum().backward() # 3b**2
print('a.grad: ', a.grad) #a.grad:  tensor([ 5., 16., 33.]) ,会将b梯度累加上去
#print('b.grad: ', b.grad) # b.grad:  None , 已经不属于计算图的叶子,不可以访问b.grad

输出:

===========================不共享地址=========================
<class 'torch.Tensor'> 93899916787840
<class 'torch.Tensor'> 93899917014528
===========================clone后分别输出=========================
a:  tensor([1., 2., 3.], requires_grad=True)
b:  tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
===========================反向传播=========================
a.grad:  tensor([2., 4., 6.])
a.grad:  tensor([ 5., 16., 33.])

2. copy_

b = torch.empty_like(a).copy_(a)

copy_()函数是需要一个目标tensor,也就是说需要先构建b,然后将a拷贝给b,而clone操作则不需要。

copy_()函数完成与clone()函数 类似的功能,但也存在区别。调用copy_()的对象是目标tensor,参数是复制操作from的tensor,最后会返回目标tensor;而clone()的调用对象为源tensor,返回一个新tensor。当然clone()函数也可以采用torch.clone()调用,将源tensor作为参数。

import torch

a = torch.tensor([1., 2., 3.],requires_grad=True)
b = torch.empty_like(a).copy_(a)

print('====================copy_内存不一样======================')
print(a.data_ptr())
print(b.data_ptr()) 
print('====================copy_打印======================')
print(a) 
print(b) 

c = a ** 2
d = b ** 3
print('===================c反向传播=======================')
c.sum().backward()
print(a.grad) # tensor([2., 2., 2.])
print('===================d反向传播=======================')
d.sum().backward()
print(a.grad) # 源tensor梯度累加了
#print(b.grad) # None 

输出:

====================copy_内存不一样======================
94358408685568
94358463065088
====================copy_打印======================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.], grad_fn=<CopyBackwards>)
===================c反向传播=======================
tensor([2., 4., 6.])
===================d反向传播=======================
tensor([ 5., 16., 33.])

3. detach

detach()函数返回与调用对象tensor相关的一个tensor,此新tensor与源tensor共享数据内存(那么tensor的数据必然是相同的),但其requires_grad为False,并且不包含源tensor的计算图信息

import torch

a = torch.tensor([1., 2., 3.],requires_grad=True)
b = a.detach()
print('=========================共享内存==============================')
print(a.data_ptr()) 
print(b.data_ptr()) 
print('=========================原值与detach==============================')
print(a) 
print(b) 

c = a * 2
d = b * 3 #不可以反向传播
print('=========================原值反向传播==============================')
c.sum().backward()
print(a.grad) 
print('=========================detach不可以反向传播==============================')
# d.sum().backward()

输出:

=========================共享内存==============================
94503766034432
94503766034432
=========================原值与detach==============================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.])
=========================原值反向传播==============================
tensor([2., 2., 2.])
=========================detach不可以反向传播==============================

由于b已经从计算图脱离出来,pytorch自然也不跟踪其后续计算过程了。如果想要让b重新加入计算图,只需要b.requires_grad_()

pytorch可以继续跟踪b的计算,但梯度不会从b流回a,梯度被截断。但由于b与a共享内存,a与b的值会一直相等

4. data

data方法是得到一个tensor的数据信息,其返回的信息与上面提到的detach()返回的信息是相同的,也具有 内存相同,不保存梯度 信息的特点。但是data有时候不安全,因为它们共享内存,如果改变一个则另一个也跟着改变,而使用detach时候使用反向传播会报错

import torch
import pdb
x = torch.FloatTensor([[1., 2.]]) #默认x.requires_grad == False,只有float类型可以反向传播
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])

w1.requires_grad = True
w2.requires_grad = True

d = torch.matmul(x, w1) #相乘后,d的requires_grad = True(相加操作也是True)

d_ = d.data # d和d_会共享内存,d_的requires_grad  = False
# d_ = d.detach() #d和d_也会共享内存,但是不能反向传播

f = torch.matmul(d, w2)
d_[:] = 1 #d_修改了值,所以d的值也跟着改变

f.backward() #使用data会获取错误的值,使用detach则报错

参考:

https://zhuanlan.zhihu.com/p/38475183

你可能感兴趣的:(pytorch,pytorch,copy_,clone,data,detach)