所有代码已上传到本人github repository:https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
如果觉得有用,请点个star哟!
下列代码均在pytorch1.4版本中测试过,确认正确无误。
ImageNet是一个用于图像分类的超大数据集,它的官方网站是:http://www.image-net.org/ 。在cv领域,使用模型在ImageNet上的预训练参数来训练其他任务已经是一种普遍的做法。本文的目的是从零开始介绍如何在ImageNet上训练模型,就以最常用的ResNet50为例。
由于ImageNet数据集每年都会更新,通常我们指的ImageNet数据集是ILSVRC2012,该数据集共有1000个类,120万张训练集图片和5万张验证集图片。你可以在官方网站下载该数据集,也可以从我的百度云下载:链接:https://pan.baidu.com/s/1ROYJwexTvXN9bCzuyAYSuw 提取码:yn5z 。
下载后解压,文件夹组织结构如下:
ILSVRC2012
|
|--------train--------1000个子类文件夹
|
|--------val--------1000个子类文件夹
这样数据集就处理好了。
为了简便起见,我们使用pytorch官方提供的ResNet实现,但在加载模型处稍作修改。pytorch官方提供了ResNet的预训练模型,但该模型同时保存了模型结构和模型参数。官方模型的点数如下,也可以在这里查到:https://github.com/facebookarchive/fb.resnet.torch 。注意输入均为224x224。
Network | Top-1 error | Top-5 error |
---|---|---|
ResNet-18 | 30.43 | 10.76 |
ResNet-34 | 26.73 | 8.74 |
ResNet-50 | 24.01 | 7.02 |
ResNet-101 | 22.44 | 6.21 |
ResNet-152 | 22.16 | 6.16 |
ResNet-200 | 21.66 | 5.79 |
经过修改后的ResNet实现代码如下。注意我没有修改ResNet的网络结构,仅仅增加了ResNet34_half和ResNet50_half(即ResNet34和ResNet50 channel数减半)。另外,由于下面的训练中我保存的pth文件将只保存模型参数而不保存模型结构,所以对加载模型部分进行了修改,这样我们可以从本地加载训练好的pth模型参数,修改网络结构也很方便。当然,在ImageNet上训练ResNet时并不需要加载预训练模型。
"""
Deep Residual Learning for Image Recognition
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
import os
import sys
BASE_DIR = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(
os.path.abspath(__file__)))))
sys.path.append(BASE_DIR)
from public.path import pretrained_models_path
import torch
import torch.nn as nn
# Names exported by `from <this module> import *`: the ResNet base class
# plus every model constructor helper defined below.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34_half',
    'resnet34',
    'resnet50_half',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]
# Local file paths (not download URLs, despite the torchvision-style name)
# of pretrained weight files, keyed by model name. Each .pth stores only a
# state_dict (parameters, no model structure). The value 'empty' marks
# models for which no pretrained weights are provided.
model_urls = {
    'resnet18':
    '{}/resnet/resnet18-epoch100-acc70.316.pth'.format(pretrained_models_path),
    'resnet34_half':
    '{}/resnet/resnet34_half-epoch100-acc67.472.pth'.format(
        pretrained_models_path),
    'resnet34':
    '{}/resnet/resnet34-epoch100-acc73.736.pth'.format(pretrained_models_path),
    'resnet50_half':
    '{}/resnet/resnet50_half-epoch100-acc72.066.pth'.format(
        pretrained_models_path),
    'resnet50':
    '{}/resnet/resnet50-epoch100-acc76.512.pth'.format(pretrained_models_path),
    'resnet101':
    '{}/resnet/resnet101-epoch100-acc77.724.pth'.format(
        pretrained_models_path),
    'resnet152':
    '{}/resnet/resnet152-epoch100-acc78.564.pth'.format(
        pretrained_models_path),
    'resnext50_32x4d':
    'empty',
    'resnext101_32x8d':
    'empty',
    'wide_resnet50_2':
    'empty',
    'wide_resnet101_2':
    'empty',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation,
    so spatial size is preserved whenever stride == 1 (BN supplies the bias).
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, groups=groups,
                     bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution, used for channel
    projection and for the strided shortcut path."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
    """Configurable ResNet backbone for ImageNet classification.

    Torchvision's implementation with one modification: the stem width
    `inplanes` is a constructor argument so that half-width variants
    (resnet34_half / resnet50_half) can be built. Stage widths are derived
    as [inplanes, inplanes*2, inplanes*4, inplanes*8].

    Args:
        block: residual block class (BasicBlock or Bottleneck).
        layers (list[int]): number of residual blocks in each of the 4 stages.
        inplanes (int): number of channels produced by the stem conv.
        num_classes (int): output size of the final fully-connected layer.
        zero_init_residual (bool): zero-init the last BN weight of each
            residual block (see comment below).
        groups (int): groups for grouped convolution (ResNeXt variants).
        width_per_group (int): base width per group (ResNeXt / wide variants).
        replace_stride_with_dilation (list[bool] | None): for each of the
            last 3 stages, replace stride-2 downsampling with dilation.
        norm_layer: normalization layer class; defaults to nn.BatchNorm2d.
    """
    def __init__(self,
                 block,
                 layers,
                 inplanes=64,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1,
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        # self.inplanes is mutated by _make_layer as the stages are built;
        # it always holds the input width of the next block to create.
        self.inplanes = inplanes
        self.interplanes = [
            self.inplanes, self.inplanes * 2, self.inplanes * 4,
            self.inplanes * 8
        ]
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(
                                 replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max pool,
        # i.e. 4x spatial downsampling before the residual stages.
        self.conv1 = nn.Conv2d(3,
                               self.inplanes,
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, self.interplanes[0], layers[0])
        self.layer2 = self._make_layer(block,
                                       self.interplanes[1],
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       self.interplanes[2],
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       self.interplanes[3],
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.interplanes[3] * block.expansion, num_classes)
        # Standard He init for convs; BN/GN affine params start at identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual
        # block behaves like an identity. This improves the model by
        # 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of `blocks` residual blocks.

        Note: mutates self.inplanes (input width of the next stage) and
        self.dilation, so the four stages must be created in order.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation: spatial size is kept.
            self.dilation *= stride
            stride = 1
        # A projection shortcut is needed whenever the first block changes
        # the spatial size or the channel count of its input.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )
        layers = []
        # Only the first block of a stage downsamples / projects.
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a ResNet and optionally load a locally stored state_dict.

    Unlike torchvision, `model_urls[arch]` is a local file path and the
    checkpoint holds parameters only, so it is loaded directly as a
    state_dict (mapped to CPU). `progress` is kept for interface
    compatibility with torchvision.
    """
    model = ResNet(block, layers, **kwargs)
    # only load state_dict()
    if pretrained:
        state_dict = torch.load(model_urls[arch],
                                map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)
def resnet34_half(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 with all channel counts halved (stem width 32), based on
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Force the half-width stem regardless of caller-supplied kwargs.
    kwargs['inplanes'] = 32
    return _resnet('resnet34_half', BasicBlock, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
def resnet50_half(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 with all channel counts halved (stem width 32), based on
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Force the half-width stem regardless of caller-supplied kwargs.
    kwargs['inplanes'] = 32
    return _resnet('resnet50_half', Bottleneck, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
                   progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 4 in every bottleneck's 3x3 convolution.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 8 in every bottleneck's 3x3 convolution.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
                   progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
                   progress, **kwargs)
训练代码根据pytorch官方给出的在ImageNet上的训练代码修改而来:https://github.com/pytorch/examples/tree/master/imagenet 。首先去除分布式训练部分的代码,使得代码只支持在nn.parallel模式下训练。为了更好的进行训练和查看训练结果,我将所有超参数写入一个单独的config.py文件,而训练过程写入train.py文件。同时,在训练时会自动生成log,方便训练完成后查看训练中途的acc、loss等结果。最后,我们将config.py文件和train.py文件放到同一个文件夹下,只需要运行下面一行指令即可开始训练。
python3 train.py
config.py文件如下(以ResNet50为例,实际上ResNet系列网络的训练超参数都一样):
import os
import sys
BASE_DIR = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
from public.path import ILSVRC2012_path
import torchvision.transforms as transforms
import torchvision.datasets as datasets
class Config(object):
    """All hyperparameters and dataset objects for one training experiment."""
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path (None = train instead of evaluate)
    train_dataset_path = os.path.join(ILSVRC2012_path, 'train')
    val_dataset_path = os.path.join(ILSVRC2012_path, 'val')
    network = "resnet50"  # key into the models module
    pretrained = False
    num_classes = 1000
    seed = 0
    input_image_size = 224
    # Validation images are resized to input_image_size * scale (= 256)
    # before the 224x224 center crop — the standard ImageNet eval protocol.
    scale = 256 / 224
    # NOTE(review): both datasets are built at class-definition (import)
    # time, so importing this module requires ILSVRC2012_path to exist.
    train_dataset = datasets.ImageFolder(
        train_dataset_path,
        transforms.Compose([
            transforms.RandomResizedCrop(input_image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # ImageNet per-channel mean/std statistics.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))
    val_dataset = datasets.ImageFolder(
        val_dataset_path,
        transforms.Compose([
            transforms.Resize(int(input_image_size * scale)),
            transforms.CenterCrop(input_image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]))
    milestones = [30, 60, 90]  # epochs at which lr is multiplied by 0.1
    epochs = 100
    batch_size = 256
    accumulation_steps = 1  # gradient accumulation steps (1 = disabled)
    lr = 0.1
    weight_decay = 1e-4
    momentum = 0.9
    num_workers = 8
    print_interval = 100  # log every N training iterations
    apex = False  # use NVIDIA apex mixed-precision training
apex即NVIDIA提供的混合精度训练,好处是可以在训练出同样性能表现的模型的情况下降低25%-30%的训练显存占用,但训练速度会稍微变慢一些。
train.py文件如下(ResNet系列网络的训练设置都一样):
import sys
import os
import argparse
import random
import time
import warnings
BASE_DIR = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')
from apex import amp
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from thop import profile
from thop import clever_format
from torch.utils.data import DataLoader
from config import Config
from public.imagenet import models
from public.imagenet.utils import DataPrefetcher, get_logger, AverageMeter, accuracy
def _str2bool(value):
    """Convert a command-line string to bool.

    Fix: argparse's ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options could never be disabled from the
    command line.
    """
    if isinstance(value, bool):
        return value
    return value.lower() in ('true', '1', 'yes', 'y')


def parse_args():
    """Build the CLI parser.

    Every option defaults to the corresponding value in Config, so running
    with no flags reproduces the config file exactly.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--momentum',
                        type=float,
                        default=Config.momentum,
                        help='momentum')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=Config.weight_decay,
                        help='weight decay')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    # Fix: was type=list, which split the argument string into single
    # characters; `--milestones 30 60 90` now parses into a list of ints.
    parser.add_argument('--milestones',
                        type=int,
                        nargs='+',
                        default=Config.milestones,
                        help='optimizer milestones')
    parser.add_argument('--accumulation_steps',
                        type=int,
                        default=Config.accumulation_steps,
                        help='gradient accumulation steps')
    parser.add_argument('--pretrained',
                        type=_str2bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    # Fix: was type=bool although this value is an iteration count used
    # with the modulo operator in train().
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=_str2bool,
                        default=Config.apex,
                        help='use apex or not')
    return parser.parse_args()
def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    """Run one training epoch; return (top1_avg, top5_avg, loss_avg)."""
    top1 = AverageMeter()
    top5 = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    model.train()
    # Number of full batches per epoch (the trailing partial batch is not
    # counted here; this value is only used for log formatting).
    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1
    while inputs is not None:
        # NOTE(review): the prefetcher appears to move the tensors to GPU
        # already, so these .cuda() calls are presumably no-ops — confirm.
        inputs, labels = inputs.cuda(), labels.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Scale the loss so gradients accumulated over `accumulation_steps`
        # mini-batches sum to one effective batch's gradient.
        loss = loss / args.accumulation_steps
        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        # Step/clear the optimizer only every `accumulation_steps` iters.
        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        # NOTE(review): this records the scaled loss, so logged loss values
        # are divided by accumulation_steps when accumulation is enabled.
        losses.update(loss.item(), inputs.size(0))
        inputs, labels = prefetcher.next()
        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}"
            )
        iter_index += 1
    # Decay the lr once per finished epoch; stepping at epoch end keeps the
    # schedule consistent when resuming from latest.pth after a crash.
    scheduler.step()
    return top1.avg, top5.avg, losses.avg
def validate(val_loader, model, args):
    """Evaluate `model` on `val_loader`.

    Returns:
        tuple: (top1_avg, top5_avg, throughput) where throughput is in
        samples per second over the whole validation pass.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        for inputs, labels in val_loader:
            data_time.update(time.time() - end)
            inputs, labels = inputs.cuda(), labels.cuda()
            outputs = model(inputs)
            acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
    # Fix: the old formula used only the *last* batch's size, which is
    # usually a smaller partial batch and under-reports throughput. Use
    # total samples / total elapsed time instead (meter counts/sums).
    throughput = top1.count / batch_time.sum if batch_time.sum > 0 else 0.0
    return top1.avg, top5.avg, throughput
def main(logger, args):
    """Full pipeline: build loaders, model, optimizer and scheduler, then
    either evaluate a checkpoint (args.evaluate) or train for args.epochs
    epochs, saving a resumable checkpoint after every epoch."""
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")
    # Seed Python and torch RNGs for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")
    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()
    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(Config.val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    logger.info('finish loading data')
    logger.info(f"creating model '{args.network}'")
    model = models.__dict__[args.network](**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })
    # Report FLOPs/params on a single dummy image (CPU, before .cuda()).
    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")
    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")
    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.1)
    # apex must wrap model/optimizer before DataParallel.
    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = nn.DataParallel(model)
    # Evaluation-only mode: load the given checkpoint, validate, and exit.
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            # Fix: the message used to interpolate args.resume although
            # this branch checks args.evaluate.
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )
        return
    start_epoch = 1
    # resume training from latest.pth if it exists
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc1']}%")
    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)
    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        acc1, acc5, losses = train(train_loader, model, criterion, optimizer,
                                   scheduler, epoch, logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, losses: {losses:.2f}"
        )
        acc1, acc5, throughput = validate(val_loader, model, args)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {acc1:.2f}%, top5 acc: {acc5:.2f}%, throughput: {throughput:.2f}sample/s"
        )
        # Always keep a resumable checkpoint (model/optimizer/scheduler).
        torch.save(
            {
                'epoch': epoch,
                'acc1': acc1,
                'loss': losses,
                'lr': scheduler.get_lr()[0],
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))
        # After the final epoch, also save bare parameters (unwrapped from
        # DataParallel) for later use as a pretrained model.
        if epoch == args.epochs:
            torch.save(
                model.module.state_dict(),
                os.path.join(
                    args.checkpoints,
                    "{}-epoch{}-acc{}.pth".format(args.network, epoch, acc1)))
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")
if __name__ == '__main__':
    # Entry point: parse CLI arguments (defaults come from Config), set up
    # file logging under args.log, then run the training pipeline.
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)
在训练ResNet网络时,我们共训练100个epoch,采用SGD优化器,momentum=0.9,weight decay=1e-4。对于学习率,采用multistep的方式进行衰减,初始lr设为0.1,在30、60、90个epoch均将lr除以10。保存100个epoch时的模型参数。对于resnet,不使用warm up。训练时的log可以在训练实验的目录/log/main.info.log中找到。
为什么scheduler.step()放在每个epoch训练结束以后而不是放在每个epoch训练开始时?
scheduler.step()的作用是根据当前epoch数衰减学习率。由于训练可能会被中断,上面的代码能够自动根据latest.pth读取之前保存的模型,从中断处继续开始训练。如果scheduler.step()放在每个epoch训练开始时,那么当训练到这个epoch的一半后被中断时,latest.pth并不会更新(因为要跑完一个epoch才会更新),此时scheduler.step()就被多更新了一次,所以scheduler.step()必须放在每个epoch的末尾。
为什么每一个单独的实验都需要自己的config.py文件和train.py文件?
超参数全部写在config.py文件中,这主要是为了方便查找当时训练时的超参数设置。train.py不复用的原因是在做实验时,我们往往会尝试一些不太常用且仅使用几次的训练方式,这个时候往往要对train.py文件做较大的改动,每个实验都有独立的train.py文件改动时就很方便了。
train.py中还包含了DataPrefetcher, get_logger, AverageMeter, accuracy函数,代码如下:
import os
import torch
import logging
from logging.handlers import TimedRotatingFileHandler
def get_logger(name, log_dir='log'):
    """Create (or fetch) a logger that writes INFO and ERROR records to
    daily-rotating files under `log_dir`.

    Args:
        name(str): name of logger; also used in the log file names.
        log_dir(str): path of log; created if it does not exist.

    Returns:
        logging.Logger: the configured logger.
    """
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    # Fix: logging.getLogger returns the same object for the same name, so
    # calling this function twice used to attach duplicate handlers and
    # write every record multiple times. Configure handlers only once.
    if logger.handlers:
        return logger
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    info_name = os.path.join(log_dir, '{}.info.log'.format(name))
    info_handler = TimedRotatingFileHandler(info_name,
                                            when='D',
                                            encoding='utf-8')
    info_handler.setLevel(logging.INFO)
    info_handler.setFormatter(formatter)
    error_name = os.path.join(log_dir, '{}.error.log'.format(name))
    error_handler = TimedRotatingFileHandler(error_name,
                                             when='D',
                                             encoding='utf-8')
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(formatter)
    logger.addHandler(info_handler)
    logger.addHandler(error_handler)
    return logger
class DataPrefetcher():
    """Overlap host-to-GPU copies with compute: while the current batch is
    in use, the next batch is copied to the GPU on a side CUDA stream.

    A fresh instance consumes one pass over `loader`; create a new one for
    each epoch.
    """
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Start fetching the first batch immediately.
        self.preload()

    def preload(self):
        """Fetch the next (input, target) pair and begin its asynchronous
        copy to the GPU on the side stream. On exhaustion both become None."""
        try:
            sample = next(self.loader)
            self.next_input, self.next_target = sample
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # Inputs are cast to float32 on the GPU (e.g. from uint8).
            self.next_input = self.next_input.float()

    def next(self):
        """Return the prefetched (input, target) batch — (None, None) once
        the loader is exhausted — and start prefetching the next one."""
        # Make the consumer stream wait until the prefetch copy finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target
class AverageMeter(object):
    """Track a running quantity: last value, sum, count and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy (in percent) for each k in `topk`.

    Args:
        output: logits of shape (batch, num_classes).
        target: ground-truth class indices of shape (batch,).
        topk (tuple[int]): which top-k accuracies to compute.

    Returns:
        list: one 1-element tensor per requested k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # Fix: use reshape, not view — `correct[:k]` is non-contiguous
            # after the transpose above, and .view() raises on newer
            # PyTorch versions.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
DataPrefetcher类:在这一次训练迭代时预读取下一个batch的数据,可以使训练过程加快。
get_logger函数:用于记录训练中的log信息。
AverageMeter类:用于计算每个epoch中各个指标的平均值。
accuracy函数:用于计算模型在验证集上的性能表现。
根据上面这套代码训练的ResNet结果如下:
Network | Top-1 error |
---|---|
ResNet-18 | 29.684 |
ResNet-34-half | 32.528 |
ResNet-34 | 26.264 |
ResNet-50-half | 27.934 |
ResNet-50 | 23.488 |
ResNet-101 | 22.276 |
ResNet-152 | 21.436 |
可以看到基本上所有的ResNet点数都要比官方模型高0.5个百分点左右。