Pytorch中requires_grad_(), detach(), torch.no_grad()的区别

1. 基本概念

Tensor是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32

  • 示例一
>>> a = torch.tensor([1.0])
>>> a.data
tensor([1.])
>>> a.grad
>>> a.requires_grad
False
>>> a.dtype
torch.float32
>>> a.item()
1.0
>>> type(a.item())

Tensor中只有一个数字时,使用torch.Tensor.item()可以得到一个Python数字。requires_gradTrue时,表示需要计算Tensor的梯度。requires_grad=False可以用来冻结部分网络,只更新另一部分网络的参数。

  • 示例二
>>> a = torch.tensor([1.0, 2.0])
>>> b = a.data
>>> id(b)
139808984381768
>>> id(a)
139811772112328
>>> b.grad
>>> a.grad
>>> b[0] = 5.0
>>> b
tensor([5., 2.])
>>> a
tensor([5., 2.])

a.data返回的是一个新的Tensor对象ba, bid不同,说明二者不是同一个Tensor,但ba共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b的元素时,a的元素也对应修改。

2. requires_grad_()与detach()

>>> a = torch.tensor([1.0, 2.0])
>>> a.data
tensor([1., 2.])
>>> a.grad
>>> a.requires_grad
False
>>> a.requires_grad_()
tensor([1., 2.], requires_grad=True)
>>> c = a.pow(2).sum()
>>> c.backward()
>>> a.grad
tensor([2., 4.])
>>> b = a.detach()
>>> b.grad
>>> b.requires_grad
False
>>> b
tensor([1., 2.])
>>> b[0] = 6
>>> b
tensor([6., 2.])
>>> a
tensor([6., 2.], requires_grad=True)
  • requires_grad_()

requires_grad_()函数会改变Tensorrequires_grad属性并返回Tensor,修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=Truerequires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。

  • detach()

detach()就是截断反向传播的梯度流。detach()函数会返回一个新的Tensor对象b,并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。ba共享数据的存储空间,二者指向同一块内存。

:共享内存空间只是共享的数据部分,a.gradb.grad是不同的。

3. torch.no_grad()

torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。

>>> a = torch.tensor([1.0, 2.0], requires_grad=True)
>>> with torch.no_grad():
...     b = n.pow(2).sum()
...
>>> b
tensor(5.)
>>> b.requires_grad
False
>>> c = a.pow(2).sum()
>>> c.requires_grad
True

上面的例子中,当arequires_grad=True时,不使用torch.no_grad()c.requires_gradTrue,使用torch.no_grad()时,b.requires_gradFalse,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True会占用更多的计算资源及存储资源。

4. 总结

requires_grad_()会修改Tensorrequires_grad属性。

detach()会返回一个与计算图分离的新Tensor,新Tensor不会在反向传播中计算梯度,会在特定场合使用。

torch.no_grad()更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。

你可能感兴趣的:(pytorch,python,pytorch,detach,torch.no.grad,requires_grad_)