pytorch 打印网络回传梯度

需求:打印梯度,检查网络学习情况

net = your_network().cuda()
def train():
	...
	outputs = net(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
	for name, parms in net.named_parameters():	
		print('-->name:', name, '-->grad_requirs:',parms.requires_grad, \
		 ' -->grad_value:',parms.grad)
	...

打印结果如下: name表示网络参数的名字; parms.requires_grad 表示该参数是否可学习,是不是frozen的; parm.grad 打印该参数的梯度值。
pytorch 打印网络回传梯度_第1张图片

你可能感兴趣的:(python)