PyTorch里eval和no_grad的关系

首先这两者有着本质上区别。

model.eval()是用来告知model内的各个layer采取eval模式工作。这个操作主要是应对诸如dropoutbatchnorm这些在训练模式下需要采取不同操作的特殊layer。训练和测试的时候都可以开启。
torch.no_grad()则是告知自动求导引擎不要进行求导操作。这个操作的意义在于加速计算、节约内存。但是由于没有gradient,也就没有办法进行backward。所以只能在测试的时候开启。

所以在evaluate的时候,需要同时使用两者。

model = ...
dataset = ...
loss_fun = ...

# training
lr=0.001
model.train()
for x,y in dataset:
	model.zero_grad()
	p = model(x)
	l = loss_fun(p, y)
	l.backward()
	for p in model.parameters():
		p.data -= lr*p.grad
	
# evaluating
sum_loss = 0.0
model.eval()
with torch.no_grad():
	for x,y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l
print('total loss:', sum_loss)

另外no_grad还可以作为函数是修饰符来用,从而简化代码。

def train(model, dataset, loss_fun, lr=0.001):
	model.train()
	for x,y in dataset:
		model.zero_grad()
		p = model(x)
		l = loss_fun(p, y)
		l.backward()
		for p in model.parameters():
			p.data -= lr*p.grad
	
@torch.no_grad()
def test(model, dataset, loss_fun):
	sum_loss = 0.0
	model.eval()
	for x,y in dataset:
		p = model(x)
		l = loss_fun(p, y)
		sum_loss += l
	return sum_loss

# main block:
model = ...
dataset = ...
loss_fun = ...

# training
train()
# test
sum_loss = test()
print('total loss:', sum_loss)

参考:
https://pytorch.org/docs/stable/generated/torch.no_grad.html
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

你可能感兴趣的:(Python,机器学习,pytorch)