import torchvision
from torchvision import transforms
transform = transforms.Compose([
    transforms.ToTensor()
])
mnist_train_dataset = torchvision.datasets.MNIST(
    root='./', train=True, transform=transform, download=True)
mnist_test_dataset = torchvision.datasets.MNIST(
    root='./', train=False, transform=transform, download=True)
from torch.utils.data import DataLoader

batch_size = 64
train_dl = DataLoader(mnist_train_dataset, batch_size, shuffle=True)
test_dl = DataLoader(mnist_test_dataset, batch_size, shuffle=False)  # no need to shuffle the evaluation set
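As a quick sanity check (not in the original), one batch from the loader should contain 64 single-channel 28x28 images:

x_batch, y_batch = next(iter(train_dl))
print(x_batch.shape)  # torch.Size([64, 1, 28, 28])
print(y_batch.shape)  # torch.Size([64])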
import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module('conv1', nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2))
model.add_module('relu1', nn.ReLU())
model.add_module('pool1', nn.MaxPool2d(kernel_size=2))
model.add_module('conv2', nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2))
model.add_module('relu2', nn.ReLU())
model.add_module('pool2', nn.MaxPool2d(kernel_size=2))
model.add_module('flatten', nn.Flatten())
x = torch.ones((4, 1, 28, 28))  # dummy batch: four 28x28 single-channel images
print(model(x).shape)
Output:
torch.Size([4, 3136])
padding=2 here gives 'same' padding for the 5x5 kernels, i.e. the feature map's height & width are unchanged before and after each convolution; only the 2x2 max-pooling layers halve the spatial size.
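The arithmetic behind the 3136 flattened features can be verified with the standard output-size formula (an illustrative check, not from the book):

# output size = (input + 2*padding - kernel) // stride + 1
h = (28 + 2*2 - 5) // 1 + 1  # conv1 with padding=2 -> 28 (unchanged)
h = h // 2                   # pool1 -> 14
h = (h + 2*2 - 5) // 1 + 1   # conv2 -> 14 (unchanged)
h = h // 2                   # pool2 -> 7
print(64 * h * h)            # 3136 flattened features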
PyTorch expects inputs in NCHW layout. N: batch size. C: number of channels. H: feature-map height. W: feature-map width.
(TensorFlow's default input layout is NHWC.)
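For illustration (not part of the original code), an NHWC tensor can be rearranged into NCHW with permute:

x_nhwc = torch.ones((4, 28, 28, 1))  # NHWC, as TensorFlow would lay it out
x_nchw = x_nhwc.permute(0, 3, 1, 2)  # reorder dimensions to NCHW for PyTorch
print(x_nchw.shape)                  # torch.Size([4, 1, 28, 28])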
Next come the fully connected layers:
model.add_module('fc1', nn.Linear(3136, 1024))
model.add_module('relu3', nn.ReLU())
model.add_module('dropout', nn.Dropout(p=0.5))
model.add_module('fc2', nn.Linear(1024, 10))
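Rerunning the earlier shape check on the completed model should now give one logit per class (a sanity check I added, not in the original):

print(model(x).shape)  # torch.Size([4, 10]): ten class logits per image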
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
nn.CrossEntropyLoss() applies the softmax internally (it combines LogSoftmax and NLLLoss), which is why the model's final layer has no softmax.
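This equivalence can be checked directly: nn.CrossEntropyLoss on raw logits matches nn.LogSoftmax followed by nn.NLLLoss (an illustrative snippet, not from the book):

logits = torch.randn(4, 10)
targets = torch.tensor([1, 3, 5, 7])
ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print(torch.allclose(ce, nll))  # True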
def train(model, num_epochs, train_dl, test_dl):
    loss_hist_train = [0] * num_epochs
    accuracy_hist_train = [0] * num_epochs
    loss_hist_test = [0] * num_epochs
    accuracy_hist_test = [0] * num_epochs
    for epoch in range(num_epochs):
        model.train()  # enable dropout
        for x_batch, y_batch in train_dl:
            pred = model(x_batch)
            loss = loss_fn(pred, y_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # accumulate loss and correct predictions, weighted by batch size
            loss_hist_train[epoch] += loss.item() * y_batch.size(0)
            is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
            accuracy_hist_train[epoch] += is_correct.sum().item()
        loss_hist_train[epoch] /= len(train_dl.dataset)
        accuracy_hist_train[epoch] /= len(train_dl.dataset)
        model.eval()  # disable dropout for evaluation
        with torch.no_grad():
            for x_batch, y_batch in test_dl:
                pred = model(x_batch)
                loss = loss_fn(pred, y_batch)
                loss_hist_test[epoch] += loss.item() * y_batch.size(0)
                is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
                accuracy_hist_test[epoch] += is_correct.sum().item()
        loss_hist_test[epoch] /= len(test_dl.dataset)
        accuracy_hist_test[epoch] /= len(test_dl.dataset)
        print(f'Epoch {epoch+1} accuracy: {accuracy_hist_train[epoch]:.4f} '
              f'test_accuracy: {accuracy_hist_test[epoch]:.4f}')
    return loss_hist_train, loss_hist_test, accuracy_hist_train, accuracy_hist_test
num_epochs = 20
hist = train(model, num_epochs, train_dl, test_dl)
Output:
Epoch 1 accuracy: 0.9868 test_accuracy: 0.9925
Epoch 2 accuracy: 0.9902 test_accuracy: 0.9914
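Once training finishes, class probabilities can be recovered at inference time by applying softmax to the logits ourselves, since the model omits that layer. A minimal sketch (not from the book; the variable names are mine):

model.eval()
with torch.no_grad():
    x_batch, y_batch = next(iter(test_dl))
    probs = torch.softmax(model(x_batch), dim=1)  # logits -> probabilities
    print(probs[0].argmax().item())               # predicted digit for the first test image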
Reference: Machine Learning with PyTorch and Scikit-Learn by Sebastian Raschka.