参考链接:pytorch图像分类篇:3.搭建AlexNet并训练花分类数据集
注意:下面模型代码中的部分参数与论文不同。这是因为FashionMNIST数据集的图像是单通道灰度图,分辨率为28×28,而论文中使用的是三通道的224×224彩色图像,所以需要相应地修改一些参数。
model.py
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet-style classifier adapted to single-channel 28x28 images.

    Compared with the original paper, the first convolution takes 1 input
    channel and the layer widths/kernels are reduced so that a 1x28x28 input
    yields a 256x3x3 feature map in front of the fully connected head.

    Args:
        num_classes: number of output logits produced by the last Linear layer.
        init_weights: if True, re-initialise every Conv2d/Linear layer with
            ``_initialize_weights`` after construction.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super().__init__()

        # Convolutional feature extractor. The shape comments assume a 1x28x28
        # input. Layer order (and therefore the Sequential indices used as
        # state_dict keys) must stay fixed so saved checkpoints keep loading.
        conv_stack = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),               # -> 32x28x28
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                    # -> 32x14x14
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),    # -> 64x14x14
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),                    # -> 64x7x7
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),   # -> 128x7x7
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),  # -> 256x7x7
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),  # -> 256x7x7
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                    # -> 256x3x3
        ]
        self.features = nn.Sequential(*conv_stack)

        # Fully connected head with dropout before the first two Linear layers.
        flat_features = 256 * 3 * 3
        head = [
            nn.Dropout(p=0.5),
            nn.Linear(flat_features, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the batch ``x``."""
        feature_map = self.features(x)
        # Flatten everything except the leading batch dimension.
        flat = torch.flatten(feature_map, start_dim=1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Kaiming-normal convs, N(0, 0.01) linears, and zero biases."""
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)  # normal-distribution init
                nn.init.constant_(layer.bias, 0)        # biases start at zero
                continue
            if not isinstance(layer, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)
train.py
参考:PyTorch官网示例
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
#AlexNet模型导入
from AlexNet.model import AlexNet
# Load (downloading if necessary) the FashionMNIST train and test splits,
# converting each image with ToTensor().
train_data = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=ToTensor()
)

batch_size = 64
# Reshuffle the training set every epoch so SGD does not see the same batch
# order each time; the evaluation order does not affect the metrics.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

# 10 FashionMNIST classes; weights are re-initialised by the model itself.
model = AlexNet(num_classes=10, init_weights=True).to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: DataLoader yielding ``(X, y)`` batches; ``len(dataloader.dataset)``
            is only used for progress reporting.
        model: network to optimise; it is switched to training mode here.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer constructed over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    # A previous evaluation pass (model.eval()) would otherwise leave dropout
    # disabled for the whole epoch, so restore training behaviour explicitly.
    model.train()
    # Send each batch to wherever the model's parameters live, instead of
    # relying on a module-level `device` variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)
        # Clear stale gradients, backpropagate, update parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Number of samples processed so far, including this batch.
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item(): >7f} [{current: >5d}/{size: >5d}]")
def _test(dataloader, model):
size = len(dataloader.dataset)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= size
correct /= size
print(f"Test Error: \n Acc: {correct}, Avg loss: {test_loss}")
# Alternate one training epoch with one evaluation pass, then save the weights.
epochs = 5
for epoch_idx in range(1, epochs + 1):
    print(f"Epoch {epoch_idx}\n-----------------")
    train(train_dataloader, model, loss_fn, optimizer)
    _test(test_dataloader, model)
print("Done!")

# Save Models: only the state_dict (the parameters), not the module object.
torch.save(model.state_dict(), "AlexNet-FashionMnist.pth")
print("Saved PyTorch Model State to AlexNet-FashionMnist.pth")
predict.py
import torch
from AlexNet.model import AlexNet
from torchvision import datasets
from torchvision.transforms import ToTensor
# Load the FashionMNIST test split; download=False, so it must already exist
# under ./data (train.py downloads it).
test_data = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=False,
    transform=ToTensor()
)

# Use the first test sample rather than loading and preprocessing an arbitrary image.
x, y = test_data[0]
print(x.shape)
# The model takes batches, so add a leading batch dimension of size 1.
x = x.unsqueeze(0)
print(x.shape)

model = AlexNet(num_classes=10)
# The checkpoint may have been saved from a CUDA model, but this model is on
# the CPU; map_location lets it load on machines without a GPU.
model.load_state_dict(torch.load("AlexNet-FashionMnist.pth", map_location="cpu"))

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()  # disable dropout for inference
print(model)
with torch.no_grad():
    pred = model(x)
    # pred has one row per input; take the highest-scoring class of the only row.
    predicted, actual = classes[pred[0].argmax(0).item()], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')