# model.py
# Define VGG-16 and VGG-19.
import torch
# Layer configuration for each VGG variant: integers are the output channel
# counts of 3x3 convolutions, 'M' marks a 2x2 max-pooling layer.
cfg = {
    'VGG-16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG-19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


# VGG-16 and VGG-19
class VGGNet(torch.nn.Module):
    """VGG-16 / VGG-19 with batch normalization, sized for 32x32 inputs (CIFAR-10).

    Five 2x2 max-pools reduce a 32x32 input to 1x1x512, so the classifier
    is a single Linear(512, num_classes) layer applied to the flattened features.
    """

    def __init__(self, VGG_type, num_classes):
        """Build the network.

        VGG_type: key into the module-level `cfg` dict ('VGG-16' or 'VGG-19').
        num_classes: size of the output logits layer.
        """
        super().__init__()
        self.features = self._make_layers(cfg[VGG_type])
        self.classifier = torch.nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input x."""
        out = self.features(x)
        out = out.view(out.size(0), -1)  # flatten to (batch, 512) for the linear head
        out = self.classifier(out)
        return out

    def _make_layers(self, layer_cfg):
        """Build the convolutional feature extractor from a cfg list.

        Parameter renamed from `cfg` to avoid shadowing the module-level dict.
        """
        layers = []
        in_channels = 3  # RGB input
        for v in layer_cfg:
            if v == 'M':  # MaxPool2d
                layers += [torch.nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [torch.nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                           torch.nn.BatchNorm2d(v),
                           torch.nn.ReLU(inplace=True)]
                in_channels = v
        # Identity pooling (kernel 1, stride 1) is a no-op; kept so the module
        # structure matches existing checkpoints of this architecture.
        layers += [torch.nn.AvgPool2d(kernel_size=1, stride=1)]
        return torch.nn.Sequential(*layers)
# train.py
# import packages
import os
import sys
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from VGGnet.model import VGGNet
# Hyper-parameters
epochs = 300
batch_size = 100
learning_rate = 0.01
num_classes = 10  # CIFAR-10 has 10 classes
# Transform configuration and Data Augmentation.
# Training: pad 32x32 images to 40x40, random horizontal flip, then a random
# 32x32 crop, followed by per-channel normalization to [-1, 1].
transform_train = torchvision.transforms.Compose([torchvision.transforms.Pad(4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.RandomCrop(32),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# Validation: tensor conversion and the same normalization only (no augmentation).
transform_test = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# Load downloaded dataset.
train_dataset = torchvision.datasets.CIFAR10(root='data', download=True, train=True, transform=transform_train)
val_dataset = torchvision.datasets.CIFAR10(root='data', download=True, train=False, transform=transform_test)
# Data Loader.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
# Make model.
net_name = 'VGG-16'
# net_name = 'VGG-19'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGGNet(net_name, num_classes).to(device)
# Loss and optimizer.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
train_num = len(train_dataset)  # sample counts: denominators for accuracy
val_num = len(val_dataset)
train_steps = len(train_loader)  # batch counts: denominators for mean loss
val_steps = len(val_loader)
# Decay the learning rate by 10x after epochs 136 and 185 (scheduler is 0-indexed).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[136, 185], gamma=0.1)
resume = True  # Whether to continue training from the last saved checkpoint.
initepoch = 1  # Default first epoch (1-indexed); overridden when a checkpoint loads.
               # Previously undefined when resume was False, causing a NameError.
if resume:
    if os.path.isfile("VGGnet.pth"):
        print("Resume from checkpoint...")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load("VGGnet.pth", map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Restore the LR scheduler when present so milestone decays stay aligned
        # after a resume (guarded: older checkpoints may lack this key).
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        # checkpoint['epoch'] is the 0-indexed epoch that finished; resume at the
        # next one, expressed 1-indexed.
        initepoch = checkpoint['epoch'] + 2
        print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch'] + 1))
    else:
        print("====>no checkpoint found.")
writer = SummaryWriter("logs")
# Main loop: train one epoch, validate, step the LR scheduler, log to
# TensorBoard, and save a resumable checkpoint every epoch.
for epoch in range(initepoch - 1, epochs):
    # ---- train ----
    print("-------第 {} 轮训练开始-------".format(epoch + 1))
    model.train()
    train_acc = 0.0
    running_loss = 0.0
    train_bar = tqdm(train_loader, file=sys.stdout)
    for step, data in enumerate(train_bar):
        images, labels = data
        optimizer.zero_grad()
        outputs = model(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss)
        _, predict = torch.max(outputs, dim=1)  # predicted class = argmax of logits
        train_acc += torch.eq(predict, labels.to(device)).sum().item()
    train_loss = running_loss / train_steps
    train_accurate = train_acc / train_num
    # ---- val ----
    model.eval()
    val_acc = 0.0
    running_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        val_bar = tqdm(val_loader, file=sys.stdout)
        for step, val_data in enumerate(val_bar):
            val_images, val_labels = val_data
            outputs = model(val_images.to(device))
            loss = loss_function(outputs, val_labels.to(device))
            running_loss += loss.item()
            _, predict = torch.max(outputs, dim=1)
            val_acc += torch.eq(predict, val_labels.to(device)).sum().item()
    val_loss = running_loss / val_steps
    val_accurate = val_acc / val_num
    scheduler.step()  # advance the MultiStepLR milestone schedule once per epoch
    print('[epoch %d] train_loss: %.3f val_loss:%.3f train_accuracy:%.3f val_accuracy: %.3f' %
          (epoch + 1, train_loss, val_loss, train_accurate, val_accurate))
    writer.add_scalars('loss',
                       {'train': train_loss, 'val': val_loss}, global_step=epoch)
    writer.add_scalars('acc',
                       {'train': train_accurate, 'val': val_accurate}, global_step=epoch)
    # Save checkpoint. The scheduler state is included so the LR milestone
    # schedule survives a resume (previously it silently restarted from epoch 0).
    checkpoint = {"model_state_dict": model.state_dict(),
                  "optimizer_state_dict": optimizer.state_dict(),
                  "scheduler_state_dict": scheduler.state_dict(),
                  "epoch": epoch}
    path_checkpoint = "VGGnet.pth"
    torch.save(checkpoint, path_checkpoint)
    print("保存模型成功")
print('Finished Training')
writer.close()
# The script supports resuming training from a saved checkpoint; use TensorBoard on the "logs" directory to view the training curves.