pytorch函数理解(一)detach,requires_grad和volatile

仅供思考使用,转载于:https://www.jianshu.com/p/f1bd4ff84926

在跑CIN的代码时,将batch_size从10一路降到2,依然每执行sample就爆显存.请教师兄后,发现问题出在这一句上:

在进行sample的时候,不止保存之前的变量fake,而且还保存了fake前所有的梯度.计算图进行累积,那样不管有多大显存都是放不下的.

之后,在self.G(real_x, target_c)[0]后面加上了.detach(),代码就可以顺利运行了.

查阅pytorch的官方文档,上面是这么说的:

关于detach

简单来说,就是创建一个新的tensor,将其从当前的计算图中分离出来.新的tensor与之前的共享data,但是不具有梯度.在任意一个tensor上进行原地操作都会报错(what?)

进行验证发现,v_c是具有梯度的,但是进行detach之后创建的新变量v_c_detached是不具有梯度的.

对v_c_detached进行修改,v_c的data值也会改变.说明他们是共享同一块显存的.

在pytorch中,autograd是由计算图实现的.Variable是autograd的核心数据结构,其构成分为三部分: data(tensor), grad(也是Variable), grad_fn(得到这一节点的直接操作).对于requires_grad为false的节点,是不具有grad的.

计算图

用户自己创建的节点是leaf_node(如图中的abc三个节点),不依赖于其他变量,对于leaf_node不能进行in_place操作.根节点是计算图的最终目标(如图y),通过链式法则可以计算出所有节点相对于根节点的梯度值.这一过程通过调用root.backward()就可以实现.

因此,detach所做的就是,重新声明一个变量,指向原变量的存放位置,但是requires_grad为false.更深入一点的理解是,计算图从detach过的变量这里就断了, 它变成了一个leaf_node.即使之后重新将它的requires_node置为true,它也不会具有梯度.

y=y.detach()后

另一方面,在调用完backward函数之后,非leaf_node的梯度计算完会立刻被清空.这也是为什么在执行backward之前显存占用很大,执行完之后显存占用立刻下降很多的原因.当然,这其中也包含了一些中间结果被存在buffer中,调用结束后也会被释放.

至于另一个参数volatile,如果一个变量的volatile=true,它可以将所有依赖于它的节点全部设为volatile=true,优先级高于requires_grad=true.这样的节点不会进行求导,即使requires_grad为真,也无法进行反向传播.在inference中如果采用这种设置,可以实现一定程度的速度提升,并且节约大概一半显存.

作者:nowherespyfly

链接:https://www.jianshu.com/p/f1bd4ff84926

来源:

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

你可能感兴趣的:(pytorch函数理解(一)detach,requires_grad和volatile)