基于pytorch框架下的一个简单的train与test代码

个人认为最大的差异有两点:
1,train需要使用到batch概念,test并不需要使用到batch概念。
2,train需要使用逆向传播,在逆向传播中更改W权重数值,test并不需要逆向传播。

def train(dataloader,model,loss_fn,optimizer):
      size=len(dataloader.dataset)
      for batch,(X,y) in enumerate(dataloader):
            X,y=X.to(device)
            y.to(device)
            pred=model(X)
            loss=loss_fn(pred,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()




def test(dataloader,model,loss_fn):
  size=len(dataloader.dataset)
  num_batches=len(dataloader)
  model.eval()
  test_loss,correct=0,0
  with torch.no_grad():
      for X,y in dataloader:
             X=x.to(device)
             y=y.to(device)
             pred=model(X)
#测试需要损失函数,并不需要逆向传播
             test_loss+=loss_fn(pred,y).item()
             correct+=(pred.argmax(1)==y).type(torch.float).sum,item()
    test_loss=/=num_batches
    correct/=size

你可能感兴趣的:(人工智能,pytorch入门)