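This post walks through my understanding of the PyTorch code for BNN (Binarized Neural Networks). It is the first time I have read a full project codebase, so I want to write everything out carefully and get the details straight. I hope it helps others too!
Paper: https://papers.nips.cc/paper/6573-binarized-neural-networks
Code: https://github.com/itayhubara/BinaryNet.pytorch
models: scripts that build the network architectures
    __init__.py: package initialization script
    alexnet.py: PyTorch implementation of AlexNet
    alexnet_binary.py: binarized version of AlexNet
    binarized_modules.py: implementation of the quantization functions
    resnet.py: PyTorch implementation of ResNet
    resnet_binary.py: binarized version of ResNet
    vgg_cifar10.py: PyTorch implementation of VGGNet
    vgg_cifar10_binary.py: binarized version of VGGNet
data.py: data loading script
main_binary.py: training and testing script
main_binary_hinge.py: training and testing script with hinge loss
main_mnist.py: training and testing script for the MNIST dataset
preprocess.py: data preprocessing script
utils.py: logging and bookkeeping helpers
First, the imports in main_binary.py: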
import argparse
import os
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import models
from torch.autograd import Variable
from data import get_dataset
from preprocess import get_transform
from utils import *
from datetime import datetime
from ast import literal_eval
from torchvision.utils import save_image
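I had not used argparse, logging, or ast before, so I looked them up. I will explain each one alongside the code.
argparse is the Python standard-library module for handling command-line arguments. A project like this cannot be run with a single keypress the way small scripts can, so it is worth learning how to define a command-line interface. Here is how the code uses it: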
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
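This collects the name of every attribute of the models package that is lowercase, does not start with "__", and is callable (in practice, the model constructor functions), sorts the names, and stores them in model_names for use when defining the arguments below.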
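To make the filter concrete, here is a minimal sketch that runs the same expression on a stand-in module. The module and every name in it (fake_models, tiny_net, wide_net, TinyNet, default_depth) are made up for illustration and are not part of the repository.

```python
import types

# Hypothetical stand-in for the `models` package; all names below are made up.
fake_models = types.ModuleType("fake_models")
fake_models.tiny_net = lambda **config: ("tiny_net", config)  # lowercase callable: kept
fake_models.wide_net = lambda **config: ("wide_net", config)  # lowercase callable: kept
fake_models.TinyNet = type("TinyNet", (), {})                  # not lowercase: skipped
fake_models.default_depth = 18                                 # not callable: skipped

# Same filter as in main_binary.py, applied to the stand-in module.
model_names = sorted(name for name in fake_models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(fake_models.__dict__[name]))
print(model_names)  # ['tiny_net', 'wide_net']
```

Module attributes such as __name__ and __doc__ are excluded by the startswith("__") check, so only the constructor-like entries survive.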
parser = argparse.ArgumentParser(description='PyTorch ConvNet Training')
Clearly, this line creates an argparse parser object.
parser.add_argument
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
                    help='results dir')
...
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')
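This section is a long run of calls in the same format, all using parser.add_argument:
ArgumentParser.add_argument(name or flags...[, action][, nargs][, const][, default][, type][, choices][, required][, help][, metavar][, dest])
name or flags is the required parameter; it takes either an option flag or a positional argument name. Take '--results_dir' above: when main_binary.py is started with python main_binary.py --results_dir xxx, the value xxx is assigned to results_dir. The remaining calls follow the same pattern, so I will not repeat them. When an argument is not given on the command line, its value is taken from default.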
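As a small runnable sketch of this behavior, the two add_argument calls quoted above can be tried in isolation. Passing a list to parse_args stands in for the real command line; the values 'xxx' and 'model.pth' are placeholders.

```python
import argparse

parser = argparse.ArgumentParser(description='argparse demo')
# Optional argument with a default, mirroring --results_dir from the post.
parser.add_argument('--results_dir', metavar='RESULTS_DIR', default='./results',
                    help='results dir')
# Short and long flag for the same option; the value is None when omitted.
parser.add_argument('-e', '--evaluate', type=str, metavar='FILE',
                    help='evaluate model FILE on validation set')

# A list passed to parse_args replaces sys.argv for the demo.
args_default = parser.parse_args([])
args_custom = parser.parse_args(['--results_dir', 'xxx', '-e', 'model.pth'])

print(args_default.results_dir, args_default.evaluate)  # ./results None
print(args_custom.results_dir, args_custom.evaluate)    # xxx model.pth
```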
def main():
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':  # compare strings with ==, not `is`
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
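If the evaluate argument is given, the results directory is replaced with /tmp, since the output is only temporary. If no save name is given, the current timestamp is used instead. Together these lines build the path where the run's results are saved and create the directory if it does not exist yet.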
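The logging package
logging is a standard module built into Python, used mainly to write run logs. It lets you set the log level, the log file path, log file rotation, and so on. Compared with print, it has these advantages:
1). By setting different log levels, a release build can output only the important messages without showing large amounts of debug output;
2). print sends everything to standard output, which makes it hard to pick out other data there; with logging, the developer decides where messages go and how they are formatted.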
    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')
    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
setup_logging is implemented in utils.py:
def setup_logging(log_file='log.txt'):
    """Setup logging configuration
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s - %(levelname)s - %(message)s",
                        datefmt="%Y-%m-%d %H:%M:%S",
                        filename=log_file,
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
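The effect of setup_logging is that the log file receives everything from DEBUG upward, while the console only shows INFO and above. The sketch below reproduces that split with two in-memory streams instead of a real file and terminal; the logger name bnn_logging_demo is made up, and a named logger is used so the demo does not touch the root logger.

```python
import io
import logging

# Two in-memory streams stand in for log.txt and the console.
file_like = io.StringIO()
console_like = io.StringIO()

logger = logging.getLogger('bnn_logging_demo')  # hypothetical logger name
logger.setLevel(logging.DEBUG)
logger.propagate = False  # keep the demo out of the root logger

file_handler = logging.StreamHandler(file_like)
file_handler.setLevel(logging.DEBUG)  # the "file" receives everything
file_handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))

console_handler = logging.StreamHandler(console_like)
console_handler.setLevel(logging.INFO)  # the "console" skips DEBUG records
console_handler.setFormatter(logging.Formatter('%(message)s'))

logger.addHandler(file_handler)
logger.addHandler(console_handler)

logger.info("saving to %s", "./results/demo")
logger.debug("run arguments: %s", {"lr": 0.1})

file_lines = file_like.getvalue().splitlines()
console_lines = console_like.getvalue().splitlines()
print(file_lines)     # both records, each with its level name
print(console_lines)  # only the INFO record
```

This is why, in main(), the save path (logged with info) shows up on screen, while the full argument dump (logged with debug) only lands in log.txt.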
The ResultsLog class is also implemented in utils.py:
class ResultsLog(object):
    def __init__(self, path='results.csv', plot_path=None):
        self.path = path
        self.plot_path = plot_path or (self.path + '.html')
        self.figures = []
        self.results = None
    def add(self, **kwargs):
        # one row per call, with the keyword names as column names
        df = pd.DataFrame([list(kwargs.values())], columns=list(kwargs.keys()))
        if self.results is None:
            self.results = df
        else:
            # original: self.results.append(df, ignore_index=True);
            # DataFrame.append was removed in pandas 2.0
            self.results = pd.concat([self.results, df], ignore_index=True)
    def save(self, title='Training Results'):
        if len(self.figures) > 0:
            if os.path.isfile(self.plot_path):
                os.remove(self.plot_path)
            output_file(self.plot_path, title=title)
            plot = column(*self.figures)
            save(plot)
            self.figures = []
        self.results.to_csv(self.path, index=False, index_label=False)
    def load(self, path=None):
        path = path or self.path
        if os.path.isfile(path):
            # original: self.results.read_csv(path), which fails because
            # read_csv is a pandas function, not a DataFrame method
            self.results = pd.read_csv(path)
    def show(self):
        if len(self.figures) > 0:
            plot = column(*self.figures)
            show(plot)
    #def plot(self, *kargs, **kwargs):
    #    line = Line(data=self.results, *kargs, **kwargs)
    #    self.figures.append(line)
    def image(self, *kargs, **kwargs):
        fig = figure()
        fig.image(*kargs, **kwargs)
        self.figures.append(fig)
At this point the log records the save path and all of the input arguments.
    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None
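This is the CUDA/GPU setup: when the tensor type is a CUDA type, the code parses the comma-separated GPU list, selects the first GPU, and turns on cudnn.benchmark. I will skip the details for now.
Next, the model is created: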
    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}
    if args.model_config != '':  # compare strings with !=, not `is not`
        model_config = dict(model_config, **literal_eval(args.model_config))
    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
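The ast package
literal_eval above comes from the ast package. Put simply, the ast module helps Python programs work with abstract syntax trees. Its literal_eval() function safely evaluates a string that contains only a Python literal (strings, numbers, tuples, lists, dicts, sets, booleans, None) and returns the resulting value; anything else, such as a function call, raises an error instead of being executed.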
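Here is a short sketch of how a string passed on the command line is turned into extra configuration and merged into model_config. The string value and the 'depth' key are made-up examples, not options taken from the repository.

```python
from ast import literal_eval

# Base configuration, built the same way as in main_binary.py.
model_config = {'input_size': 32, 'dataset': 'cifar10'}

# Hypothetical value for --model_config; 'depth' is a made-up option.
extra = "{'depth': 18, 'dataset': 'cifar100'}"
model_config = dict(model_config, **literal_eval(extra))
print(model_config)
# {'input_size': 32, 'dataset': 'cifar100', 'depth': 18}

# Anything that is not a plain literal is rejected rather than executed.
try:
    literal_eval("__import__('os').getcwd()")
    rejected = False
except (ValueError, SyntaxError):
    rejected = True
print(rejected)  # True
```

Keys from the command-line string override the defaults, which is why 'dataset' changes while 'input_size' is kept.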
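At the end of the block above, model is called. Earlier in the same block, model was defined as:
    model = models.__dict__[args.model]
At that point model is the constructor function looked up by name in the models package, so calling model(**model_config) builds and returns the network instance with that configuration.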
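The two-step pattern (look up a constructor by name, then call it with the configuration) can be sketched without PyTorch. Everything here is hypothetical: tiny_net returns a plain dict where a real constructor would return an nn.Module, and registry plays the role of models.__dict__.

```python
# Hypothetical constructor; a real one would build and return an nn.Module.
def tiny_net(input_size=None, dataset=None, depth=2):
    return {'name': 'tiny_net', 'input_size': input_size,
            'dataset': dataset, 'depth': depth}

registry = {'tiny_net': tiny_net}    # plays the role of models.__dict__

model = registry['tiny_net']         # step 1: look up the constructor by name
is_constructor = callable(model)     # still a function at this point

model_config = {'input_size': 28, 'dataset': 'mnist'}
model = model(**model_config)        # step 2: call it to build the "network"
print(is_constructor, model['dataset'], model['depth'])  # True mnist 2
```

Note that the variable model is reused: it first names the function and then the object that the function returns.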
    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)",
                     args.evaluate, checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(
                checkpoint_file, 'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)",
                         checkpoint_file, checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)
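This part loads existing parameters from a checkpoint, if one is given, and then logs the number of parameters in the model.
Next comes data loading: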
    # Data loading code
    default_transform = {
        'train': get_transform(args.dataset,
                               input_size=args.input_size, augment=True),
        'eval': get_transform(args.dataset,
                              input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', {0: {'optimizer': args.optimizer,
                                           'lr': args.lr,
                                           'momentum': args.momentum,
                                           'weight_decay': args.weight_decay}})
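get_transform is implemented in preprocess.py. I will not go through it here; this is not a data-processing post, after all. Note that both transform and regime are read from the model with getattr, so a model can supply its own input transform and training regime, and the defaults built from the command-line arguments are used otherwise.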
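The regime dictionary above maps an epoch index to optimizer settings. The function that consumes it, adjust_optimizer, is not shown in this post, so the lookup rule in the sketch below (each entry stays in effect until a later epoch overrides it) is an assumption, and settings_for_epoch, PlainModel, and ScheduledModel are made-up names for illustration.

```python
def settings_for_epoch(regime, epoch):
    """Merge every regime entry whose start epoch is <= epoch (assumed rule)."""
    merged = {}
    for start in sorted(regime):
        if start <= epoch:
            merged.update(regime[start])
    return merged

class PlainModel:        # hypothetical model without its own regime
    pass

class ScheduledModel:    # hypothetical model that ships its own regime
    regime = {0: {'optimizer': 'Adam', 'lr': 5e-3},
              30: {'lr': 1e-3},
              60: {'lr': 1e-4}}

# Default built from the command line, as in main_binary.py (values made up).
default_regime = {0: {'optimizer': 'SGD', 'lr': 0.1}}

# getattr falls back to the default only when the attribute is missing.
plain = getattr(PlainModel(), 'regime', default_regime)
scheduled = getattr(ScheduledModel(), 'regime', default_regime)

print(settings_for_epoch(plain, 10))      # {'optimizer': 'SGD', 'lr': 0.1}
print(settings_for_epoch(scheduled, 10))  # {'optimizer': 'Adam', 'lr': 0.005}
print(settings_for_epoch(scheduled, 45))  # {'optimizer': 'Adam', 'lr': 0.001}
```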
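    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)
This part defines the loss function: the model's own criterion if it has one, otherwise nn.CrossEntropyLoss. The criterion and the model are then cast to the tensor type given by args.type.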
    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return
    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)
That covers loading the datasets and creating the optimizer. Now training begins:
    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)
        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)
        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model': args.model,
            'config': args.model_config,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'regime': regime
        }, is_best, path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(epoch + 1, train_loss=train_loss, val_loss=val_loss,
                             train_prec1=train_prec1, val_prec1=val_prec1,
                             train_prec5=train_prec5, val_prec5=val_prec5))
        results.add(epoch=epoch + 1, train_loss=train_loss, val_loss=val_loss,
                    train_error1=100 - train_prec1, val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5, val_error5=100 - val_prec5)
        #results.plot(x='epoch', y=['train_loss', 'val_loss'],
        #             title='Loss', ylabel='loss')
        #results.plot(x='epoch', y=['train_error1', 'val_error1'],
        #             title='Error@1', ylabel='error %')
        #results.plot(x='epoch', y=['train_error5', 'val_error5'],
        #             title='Error@5', ylabel='error %')
        results.save()
The loop is not long and is easy to follow. What matters is these two calls and the functions behind them:
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(
            train_loader, model, criterion, epoch, optimizer)
        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(
            val_loader, model, criterion, epoch)
def train(data_loader, model, criterion, epoch, optimizer):
    # switch to train mode
    model.train()
    return forward(data_loader, model, criterion, epoch,
                   training=True, optimizer=optimizer)
def validate(data_loader, model, criterion, epoch):
    # switch to evaluate mode
    model.eval()
    return forward(data_loader, model, criterion, epoch,
                   training=False, optimizer=None)
Both simply set the model's mode and delegate to forward, so let's look at that function:
def forward(data_loader, model, criterion, epoch=0, training=True, optimizer=None):
    if args.gpus and len(args.gpus) > 1:
        model = torch.nn.DataParallel(model, args.gpus)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
What is AverageMeter()? It is defined in utils.py:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
__optimizers = {
    'SGD': torch.optim.SGD,
    'ASGD': torch.optim.ASGD,
    'Adam': torch.optim.Adam,
    'Adamax': torch.optim.Adamax,
    'Adagrad': torch.optim.Adagrad,
    'Adadelta': torch.optim.Adadelta,
    'Rprop': torch.optim.Rprop,
    'RMSprop': torch.optim.RMSprop
}
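The author's docstring already says it clearly: Computes and stores the average and current value. The __optimizers dictionary shown after the class maps optimizer names to the corresponding torch.optim classes.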
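A quick usage sketch shows why update takes n: the running average is weighted by batch size, so a small final batch does not count as much as a full one. The class is copied from above so the example runs on its own; the loss values are made up.

```python
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val           # most recent value
        self.sum += val * n      # weight the value by the number of samples
        self.count += n
        self.avg = self.sum / self.count

losses = AverageMeter()
losses.update(2.0, n=3)   # a batch of 3 samples with mean loss 2.0
losses.update(5.0, n=1)   # a final batch with a single sample
print(losses.val, losses.avg)  # 5.0 2.75
```

The unweighted mean of the two batch losses would be 3.5; weighting by sample count gives (3 * 2.0 + 1 * 5.0) / 4 = 2.75.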
    end = time.time()
    for i, (inputs, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        if args.gpus is not None:
            # original: target.cuda(async=True); `async` is a reserved word
            # from Python 3.7 on, and PyTorch renamed the argument
            target = target.cuda(non_blocking=True)
        input_var = Variable(inputs.type(args.type), volatile=not training)
        target_var = Variable(target)
        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)
        if type(output) is list:
            output = output[0]
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        # original: loss.data[0]; .item() is the equivalent for a 0-dim tensor
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))
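This code reads batches from data_loader, feeds them to the model, computes the loss and the top-1 and top-5 accuracy, and updates the running averages.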
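The accuracy helper is not shown in this post. As a rough, framework-free sketch of what prec@1 and prec@5 measure, assuming they are top-k accuracy in percent, the function below (topk_accuracy, a made-up name working on plain lists rather than tensors) counts a sample as correct when its true class is among the k highest scores.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Percentage of samples whose true class is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # class indices sorted from highest to lowest score
            ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
            if target in ranked[:k]:
                hits += 1
        results.append(100.0 * hits / len(targets))
    return results

scores = [[0.1, 0.7, 0.2],   # best class 1, true class 1: top-1 hit
          [0.5, 0.3, 0.2],   # best class 0, true class 1: top-2 hit only
          [0.2, 0.3, 0.5],   # best class 2, true class 0: miss for k <= 2
          [0.6, 0.1, 0.3]]   # best class 0, true class 0: top-1 hit
targets = [1, 1, 0, 0]
prec1, prec2 = topk_accuracy(scores, targets, topk=(1, 2))
print(prec1, prec2)  # 50.0 75.0
```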
optimizer.step()
        if training:
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            for p in list(model.parameters()):
                if hasattr(p,'org'):
                    p.data.copy_(p.org)
            optimizer.step()
            for p in list(model.parameters()):
                if hasattr(p,'org'):
                    p.org.copy_(p.data.clamp_(-1,1))
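This step runs backpropagation first, and then we finally reach the part that is specific to BNN (I had almost forgotten why I started this post). After the gradients have been computed with the binarized weights, and before the optimizer update, the parameters are restored to the full-precision values stored in p.org. The optimizer then updates those real-valued weights, and afterwards they are clamped to the interval [-1, 1] and saved back to p.org.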
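The same restore, update, clamp cycle can be sketched on plain Python floats, which makes it easy to see when a binary weight actually flips. This is an illustration only, not the repository's code: latent plays the role of p.org, the gradients are made up, plain SGD is assumed, and sign maps 0 to +1 for simplicity (torch's sign returns 0 there).

```python
def sign(x):
    return 1.0 if x >= 0 else -1.0   # sketch convention for x == 0

def clamp(x, low=-1.0, high=1.0):
    return max(low, min(high, x))

def bnn_update(latent, grads, lr):
    """SGD step on the real-valued weights, followed by a clamp to [-1, 1]."""
    return [clamp(w - lr * g) for w, g in zip(latent, grads)]

latent = [0.9, -0.2, 0.05]             # real-valued weights (role of p.org)
binary = [sign(w) for w in latent]     # what the forward pass would use
grads = [-2.0, 0.5, 1.0]               # made-up gradients from the binary pass

latent = bnn_update(latent, grads, lr=0.1)
print(binary)                          # [1.0, -1.0, 1.0]
print([sign(w) for w in latent])       # [1.0, -1.0, -1.0]
print(latent[0])                       # 1.0 (0.9 + 0.2 was clamped)
```

Only the third weight changes sign, because it was the only one close enough to zero for a single step to push it across. Keeping the real-valued copy is what lets many small gradient steps accumulate before a binary weight flips.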
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            logging.info('{phase} - Epoch: [{0}][{1}/{2}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                         'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                             epoch, i, len(data_loader),
                             phase='TRAINING' if training else 'EVALUATING',
                             batch_time=batch_time,
                             data_time=data_time, loss=losses, top1=top1, top5=top5))
    return losses.avg, top1.avg, top5.avg
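This part records the timing and training statistics and writes them to the log, so I will not go into detail.
That is all for now; I will write the rest tomorrow~
Continuing with the quantization functions in binarized_modules.py: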
def Binarize(tensor,quant_mode='det'):
    if quant_mode=='det':
        return tensor.sign()
    else:
        return tensor.add_(1).div_(2).add_(torch.rand(tensor.size()).add(-0.5)).clamp_(0,1).round().mul_(2).add_(-1)
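The binarization function: it takes a tensor and returns its binarized version. As described in the paper, it supports both deterministic binarization (the sign of each element) and stochastic binarization (each element becomes +1 with probability clip((x+1)/2, 0, 1), and -1 otherwise). Note that the stochastic branch uses in-place operations, so it also overwrites the input tensor.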
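To check the probability claim, here is a scalar sketch in plain Python rather than PyTorch. Adding uniform noise from [-0.5, 0.5) to p = (x+1)/2 and rounding, as the code above does, yields 1 with probability p, which is the same as drawing +1 with probability p directly; the sketch uses the direct form. The function names are made up, and sign maps 0 to +1 here, whereas torch's sign returns 0 for 0.

```python
import random

def binarize_det(x):
    return 1.0 if x >= 0 else -1.0   # sketch convention: 0 maps to +1

def binarize_stoch(x, rng):
    # hard sigmoid: probability of +1, clipped to [0, 1]
    p_plus = max(0.0, min(1.0, (x + 1.0) / 2.0))
    return 1.0 if rng.random() < p_plus else -1.0

rng = random.Random(0)               # fixed seed so the run is repeatable
samples = [binarize_stoch(0.5, rng) for _ in range(10000)]
frac_plus = samples.count(1.0) / len(samples)

print(binarize_det(0.5), binarize_det(-0.1))  # 1.0 -1.0
print(round(frac_plus, 2))                    # close to 0.75, since (0.5 + 1) / 2 = 0.75
```

Values at or beyond the ends of the range are deterministic even in stochastic mode: x = 1 always gives +1 and x = -1 always gives -1.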
def Quantize(tensor,quant_mode='det', params=None, numBits=8):
    tensor.clamp_(-2**(numBits-1),2**(numBits-1))
    if quant_mode=='det':
        tensor=tensor.mul(2**(numBits-1)).round().div(2**(numBits-1))
    else:
        tensor=tensor.mul(2**(numBits-1)).round().add(torch.rand(tensor.size()).add(-0.5)).div(2**(numBits-1))
        quant_fixed(tensor, params)
    return tensor
import torch.nn._functions as tnnf
The quantization function: it rounds a tensor to fixed-point values with a step size of 1/2^(numBits-1).
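A scalar sketch of the deterministic branch makes the step size visible. This is plain Python, not the repository's code; quantize_det is a made-up name, and the clamp bounds simply mirror the ones in the function above.

```python
def quantize_det(x, num_bits=8):
    """Round x to the nearest multiple of 1 / 2**(num_bits - 1)."""
    scale = 2 ** (num_bits - 1)
    x = max(-scale, min(scale, x))   # same clamp bounds as the code above
    return round(x * scale) / scale

print(quantize_det(0.3))                 # 0.296875 (38 / 128)
print(quantize_det(0.3, num_bits=2))     # 0.5 (round(0.6) / 2)
print(quantize_det(1000.0, num_bits=4))  # 8.0 (clamped to 2**3 first)
```

With 8 bits the grid spacing is 1/128, so 0.3 lands on 38/128; with 2 bits the spacing is 1/2 and the same value is pushed all the way to 0.5.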
class BinarizeLinear(nn.Linear):
    def __init__(self, *kargs, **kwargs):
        super(BinarizeLinear, self).__init__(*kargs, **kwargs)
    def forward(self, input):
        if input.size(1) != 784:
            input.data=Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)
        out = nn.functional.linear(input, self.weight)
        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1).expand_as(out)
        return out
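The binarized fully connected layer. The input to the first layer is not binarized (the check input.size(1) != 784 skips the flattened 28x28 MNIST image), the weights are binarized on every forward pass, and the original real-valued weights are kept in self.weight.org.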
class BinarizeConv2d(nn.Conv2d):
    def __init__(self, *kargs, **kwargs):
        super(BinarizeConv2d, self).__init__(*kargs, **kwargs)
    def forward(self, input):
        if input.size(1) != 3:
            input.data = Binarize(input.data)
        if not hasattr(self.weight,'org'):
            self.weight.org=self.weight.data.clone()
        self.weight.data=Binarize(self.weight.org)
        out = nn.functional.conv2d(input, self.weight, None, self.stride,
                                   self.padding, self.dilation, self.groups)
        if not self.bias is None:
            self.bias.org=self.bias.data.clone()
            out += self.bias.view(1, -1, 1, 1).expand_as(out)
        return out
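The binarized convolution layer. Again the first layer's input is not binarized (the check input.size(1) != 3 skips the 3-channel RGB image), the weights are binarized, and the original real-valued weights are kept.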
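What both layers have in common can be sketched with plain lists: binarize the input unless it is the raw first-layer input, binarize a copy of the weights for the computation, and leave the real-valued weights untouched. This is a framework-free illustration without a bias term; binary_linear and all the numbers are made up, and sign maps 0 to +1 for simplicity.

```python
def sign(x):
    return 1.0 if x >= 0 else -1.0   # sketch convention for x == 0

def binary_linear(inputs, latent_weights, binarize_input=True):
    """inputs: list of floats; latent_weights: one row of real weights per output."""
    if binarize_input:
        inputs = [sign(v) for v in inputs]
    # binarize a copy; the real-valued weights stay as they are
    binary_weights = [[sign(w) for w in row] for row in latent_weights]
    outputs = [sum(w * v for w, v in zip(row, inputs)) for row in binary_weights]
    return outputs, binary_weights

latent = [[0.3, -0.8, 0.1],
          [-0.5, 0.2, -0.9]]
pixels = [0.2, 0.9, 0.4]      # raw first-layer input, left as is
hidden = [0.7, -1.3, 0.05]    # later-layer input, binarized to [1, -1, 1]

first, binary = binary_linear(pixels, latent, binarize_input=False)
later, _ = binary_linear(hidden, latent)
print(binary)     # [[1.0, -1.0, 1.0], [-1.0, 1.0, -1.0]]
print(later)      # [3.0, -3.0]
print(latent[0])  # [0.3, -0.8, 0.1]
```

When both inputs and weights are ±1, every product is ±1 and the output is just a signed count, which is what makes binarized layers cheap to compute.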
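There is no need to go through the three parts above in detail. One thing does differ from an ordinary network: the activation function is Hardtanh. Its role is to relax the sign function. The derivative of sign is zero almost everywhere, so without this relaxation no gradient could be backpropagated.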
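The relaxation can be shown with three scalar functions: the forward pass uses sign, while the backward pass borrows the derivative of hardtanh, which is 1 inside [-1, 1] and 0 outside. This is a plain-Python illustration of the idea rather than the repository's code; hardtanh_grad is a made-up helper, and sign maps 0 to +1 here.

```python
def sign(x):
    return 1.0 if x >= 0 else -1.0   # sketch convention for x == 0

def hardtanh(x):
    return max(-1.0, min(1.0, x))

def hardtanh_grad(x):
    """Derivative of hardtanh: 1 inside [-1, 1], 0 outside."""
    return 1.0 if -1.0 <= x <= 1.0 else 0.0

# Columns: input, sign output, hardtanh output, gradient used in place of sign's.
for x in (-2.0, -0.4, 0.0, 0.7, 1.5):
    print(x, sign(x), hardtanh(x), hardtanh_grad(x))
```

Inside [-1, 1] the gradient passes through unchanged, and outside that range it is cut off, so saturated activations stop receiving updates.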
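BNN is a classic, pioneering work in network compression and quantization. Reading the code carefully after the paper taught me a lot. That said, this paper is quite easy to reproduce, and I read the code mainly to learn how a project is organized, so I still need to read much more quantization code. This is only the beginning. Keep going~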