继续更新pytorch的学习
# Define the preprocessing pipeline (PIL image -> float32 tensor scaled to [0, 1])
train_transform = transforms.Compose([
transforms.ToTensor()
])
# Build the datasets and dataloaders (MNIST is downloaded to ./data on first run)
train_data = MNIST('./data', train=True, transform=train_transform, download=True)
valid_data = MNIST('./data', train=False, transform=train_transform, download=True)
# Shuffle only the training set; validation order is kept fixed for reproducible loss
train_iter = DataLoader(train_data, batch_size=128, shuffle=True)
valid_iter = DataLoader(valid_data, batch_size=128, shuffle=False)
# Build the network model
class MNIST_Net(nn.Module):
    """Fully-connected MNIST classifier: 784 -> 128 -> 256 -> 10.

    forward() accepts any tensor with 784 elements per sample
    (e.g. (N, 1, 28, 28) images) and returns raw class logits, which
    is what ``F.cross_entropy`` expects (it applies log-softmax itself).
    """

    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(784, 128)
        self.f2 = nn.Linear(128, 256)
        self.f3 = nn.Linear(256, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) row vectors.
        x = x.view(-1, 784)
        x = F.relu(self.f1(x))
        x = F.relu(self.f2(x))
        # No activation on the output layer: return logits.
        return self.f3(x)
训练阶段调用 model.train()，评估阶段调用 model.eval()，以正确切换 dropout/batchnorm 等层的行为。
验证时用 **with torch.no_grad()** 包裹前向计算：不需要反向传播，可省去梯度记录的开销。
import torch.nn as nn
from torchvision.datasets import MNIST
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
import torch
# Define the preprocessing pipeline (PIL image -> float32 tensor scaled to [0, 1])
train_transform = transforms.Compose([
transforms.ToTensor()
])
# Build the datasets and dataloaders (MNIST is downloaded to ./data on first run)
train_data = MNIST('./data', train=True, transform=train_transform, download=True)
valid_data = MNIST('./data', train=False, transform=train_transform, download=True)
# Shuffle only the training set; validation order is kept fixed for reproducible loss
train_iter = DataLoader(train_data, batch_size=128, shuffle=True)
valid_iter = DataLoader(valid_data, batch_size=128, shuffle=False)
# Build the network model
class MNIST_Net(nn.Module):
    """Fully-connected MNIST classifier: 784 -> 128 -> 256 -> 10.

    forward() accepts any tensor with 784 elements per sample
    (e.g. (N, 1, 28, 28) images) and returns raw class logits, which
    is what ``F.cross_entropy`` expects (it applies log-softmax itself).
    """

    def __init__(self):
        super().__init__()
        self.f1 = nn.Linear(784, 128)
        self.f2 = nn.Linear(128, 256)
        self.f3 = nn.Linear(256, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) row vectors.
        x = x.view(-1, 784)
        x = F.relu(self.f1(x))
        x = F.relu(self.f2(x))
        # No activation on the output layer: return logits.
        return self.f3(x)
def get_model():
    """Build a fresh MNIST_Net together with a plain SGD optimizer (lr=0.001).

    Returns:
        (model, optimizer) ready to be passed to ``fit``.
    """
    model = MNIST_Net()
    return model, optim.SGD(model.parameters(), lr=0.001)
def fit(epochs, model, loss_func, opt, train_iter, valid_iter):
    """Train ``model`` for ``epochs`` passes and print the validation loss per epoch.

    Args:
        epochs: number of full passes over ``train_iter``.
        model: the ``nn.Module`` being optimized.
        loss_func: callable ``loss_func(pred, target)`` returning a scalar loss.
        opt: optimizer driving the parameter updates.
        train_iter: DataLoader yielding training ``(x, y)`` batches.
        valid_iter: DataLoader yielding validation ``(x, y)`` batches.
    """
    for epoch in range(epochs):
        # Training pass: train mode enables dropout/batchnorm training behavior.
        model.train()
        for xb, yb in train_iter:
            loss_batch(model, loss_func, xb, yb, opt)
        # Validation pass: eval mode, and no_grad since no backprop is needed.
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_iter]
            )
        # Size-weighted mean: the last batch may be smaller than the rest.
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:' + str(epoch), "验证集损失:" + str(val_loss))
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss for one batch; update parameters when ``opt`` is given.

    Args:
        model: network to evaluate.
        loss_func: callable ``loss_func(pred, target)`` returning a scalar loss.
        xb: input batch.
        yb: target batch.
        opt: optional optimizer; when provided, a backward/step/zero_grad cycle runs.

    Returns:
        (loss_value, batch_size) so the caller can form a size-weighted mean.
    """
    loss = loss_func(model(xb), yb)
    if opt is not None:
        # Training mode: backprop, update, then clear grads for the next batch.
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(), len(xb)
# Entry point: build model + optimizer and train for 30 epochs with
# cross-entropy loss. Runs at import time (script-style notes file).
model, opt = get_model()
fit(30, model, F.cross_entropy, opt, train_iter, valid_iter)