文章作者:Tyan
博客:noahsnail.com | CSDN | 简书
Python 3.6.9, Pytorch 1.5.0
Tensor
是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32
。
>>> a = torch.tensor([1.0])
>>> a.data
tensor([1.])
>>> a.grad
>>> a.requires_grad
False
>>> a.dtype
torch.float32
>>> a.item()
1.0
>>> type(a.item())
Tensor
中只有一个数字时,使用torch.Tensor.item()
可以得到一个Python数字。requires_grad
为True
时,表示需要计算Tensor
的梯度。requires_grad=False
可以用来冻结部分网络,只更新另一部分网络的参数。
>>> a = torch.tensor([1.0, 2.0])
>>> b = a.data
>>> id(b)
139808984381768
>>> id(a)
139811772112328
>>> b.grad
>>> a.grad
>>> b[0] = 5.0
>>> b
tensor([5., 2.])
>>> a
tensor([5., 2.])
a.data
返回的是一个新的Tensor
对象b
,a, b
的id
不同,说明二者不是同一个Tensor
,但b
与a
共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b
的元素时,a
的元素也对应修改。
>>> a = torch.tensor([1.0, 2.0])
>>> a.data
tensor([1., 2.])
>>> a.grad
>>> a.requires_grad
False
>>> a.requires_grad_()
tensor([1., 2.], requires_grad=True)
>>> c = a.pow(2).sum()
>>> c.backward()
>>> a.grad
tensor([2., 4.])
>>> b = a.detach()
>>> b.grad
>>> b.requires_grad
False
>>> b
tensor([1., 2.])
>>> b[0] = 6
>>> b
tensor([6., 2.])
>>> a
tensor([6., 2.], requires_grad=True)
requires_grad_()
requires_grad_()
函数会改变Tensor
的requires_grad
属性并返回Tensor
,修改requires_grad
的操作是原位操作(in place)。其默认参数为requires_grad=True
。requires_grad=True
时,自动求导会记录对Tensor
的操作,requires_grad_()
的主要用途是告诉自动求导开始记录对Tensor
的操作。
detach()
detach()
函数会返回一个新的Tensor
对象b
,并且新Tensor
是与当前的计算图分离的,其requires_grad
属性为False
,反向传播时不会计算其梯度。b
与a
共享数据的存储空间,二者指向同一块内存。
注:共享内存空间只是共享的数据部分,a.grad
与b.grad
是不同的。
torch.no_grad()
是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
>>> a = torch.tensor([1.0, 2.0], requires_grad=True)
>>> with torch.no_grad():
... b = n.pow(2).sum()
...
>>> b
tensor(5.)
>>> b.requires_grad
False
>>> c = a.pow(2).sum()
>>> c.requires_grad
True
上面的例子中,当a
的requires_grad=True
时,不使用torch.no_grad()
,c.requires_grad
为True
,使用torch.no_grad()
时,b.requires_grad
为False
,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True
会占用更多的计算资源及存储资源。
requires_grad_()
会修改Tensor
的requires_grad
属性。
detach()
会返回一个与计算图分离的新Tensor
,新Tensor
不会在反向传播中计算梯度,会在特定场合使用。
torch.no_grad()
更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。