b = a.clone()
创建一个tensor与源tensor有相同的shape,dtype和device,不共享内存地址,但新tensor(b)的梯度会叠加在源tensor(a)上
。需要注意的是,b = a.clone()之后,b并非叶子节点,所以不可以访问它的梯度。
import torch
a = torch.tensor([1.,2.,3.],requires_grad=True)
b = a.clone()
print('===========================不共享地址=========================')
print(type(a), a.data_ptr())
print(type(b), b.data_ptr())
print('===========================clone后分别输出=========================')
print('a: ', a) # a: tensor([1., 2., 3.], requires_grad=True)
print('b: ', b) #b: tensor([1., 2., 3.], grad_fn=)
c = a ** 2
d = b ** 3
print('===========================反向传播=========================')
c.sum().backward() # 2* a
print('a.grad: ', a.grad) #a.grad: tensor([2., 4., 6.])
d.sum().backward() # 3b**2
print('a.grad: ', a.grad) #a.grad: tensor([ 5., 16., 33.]) ,会将b梯度累加上去
#print('b.grad: ', b.grad) # b.grad: None , 已经不属于计算图的叶子,不可以访问b.grad
输出:
===========================不共享地址=========================
<class 'torch.Tensor'> 93899916787840
<class 'torch.Tensor'> 93899917014528
===========================clone后分别输出=========================
a: tensor([1., 2., 3.], requires_grad=True)
b: tensor([1., 2., 3.], grad_fn=<CloneBackward0>)
===========================反向传播=========================
a.grad: tensor([2., 4., 6.])
a.grad: tensor([ 5., 16., 33.])
b = torch.empty_like(a).copy_(a)
copy_()函数是需要一个目标tensor,也就是说需要先构建b,然后将a拷贝给b,而clone操作则不需要。
copy_()函数完成与clone()函数 类似的功能
,但也存在区别。调用copy_()的对象是目标tensor,参数是复制操作from的tensor,最后会返回目标tensor;而clone()的调用对象为源tensor,返回一个新tensor。当然clone()函数也可以采用torch.clone()调用,将源tensor作为参数。
import torch
a = torch.tensor([1., 2., 3.],requires_grad=True)
b = torch.empty_like(a).copy_(a)
print('====================copy_内存不一样======================')
print(a.data_ptr())
print(b.data_ptr())
print('====================copy_打印======================')
print(a)
print(b)
c = a ** 2
d = b ** 3
print('===================c反向传播=======================')
c.sum().backward()
print(a.grad) # tensor([2., 2., 2.])
print('===================d反向传播=======================')
d.sum().backward()
print(a.grad) # 源tensor梯度累加了
#print(b.grad) # None
输出:
====================copy_内存不一样======================
94358408685568
94358463065088
====================copy_打印======================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.], grad_fn=<CopyBackwards>)
===================c反向传播=======================
tensor([2., 4., 6.])
===================d反向传播=======================
tensor([ 5., 16., 33.])
detach()函数返回与调用对象tensor相关的一个tensor,此新tensor与源tensor共享数据内存
(那么tensor的数据必然是相同的),但其requires_grad为False,并且不包含源tensor的计算图信息
。
import torch
a = torch.tensor([1., 2., 3.],requires_grad=True)
b = a.detach()
print('=========================共享内存==============================')
print(a.data_ptr())
print(b.data_ptr())
print('=========================原值与detach==============================')
print(a)
print(b)
c = a * 2
d = b * 3 #不可以反向传播
print('=========================原值反向传播==============================')
c.sum().backward()
print(a.grad)
print('=========================detach不可以反向传播==============================')
# d.sum().backward()
输出:
=========================共享内存==============================
94503766034432
94503766034432
=========================原值与detach==============================
tensor([1., 2., 3.], requires_grad=True)
tensor([1., 2., 3.])
=========================原值反向传播==============================
tensor([2., 2., 2.])
=========================detach不可以反向传播==============================
由于b已经从计算图脱离出来,pytorch自然也不跟踪其后续计算过程了。如果想要让b重新加入计算图,只需要b.requires_grad_()
。
pytorch可以继续跟踪b的计算,但梯度不会从b流回a,梯度被截断。但由于b与a共享内存,a与b的值会一直相等
。
data方法是得到一个tensor的数据信息,其返回的信息与上面提到的detach()返回的信息是相同的,也具有 内存相同,不保存梯度
信息的特点。但是data有时候不安全,因为它们共享内存,如果改变一个则另一个也跟着改变,而使用detach时候使用反向传播会报错
。
import torch
import pdb
x = torch.FloatTensor([[1., 2.]]) #默认x.requires_grad == False,只有float类型可以反向传播
w1 = torch.FloatTensor([[2.], [1.]])
w2 = torch.FloatTensor([3.])
w1.requires_grad = True
w2.requires_grad = True
d = torch.matmul(x, w1) #相乘后,d的requires_grad = True(相加操作也是True)
d_ = d.data # d和d_会共享内存,d_的requires_grad = False
# d_ = d.detach() #d和d_也会共享内存,但是不能反向传播
f = torch.matmul(d, w2)
d_[:] = 1 #d_修改了值,所以d的值也跟着改变
f.backward() #使用data会获取错误的值,使用detach则报错
参考:
https://zhuanlan.zhihu.com/p/38475183