trainner 写法
(1)正常写法
一般情况下教程会教你这样去写Pytorch 的train 代码:
#准备好训练数据加载器
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
#准备好模型
model = Net()
#优化方法
optimizer = torch.optim.Adam(model.parameters())
#loss 函数
loss_func = torch.nn.CrossEntropyLoss()
##然后开始每一轮的迭代
for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data))))
# evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))
我们看到这样的形式:
- 数据加载器
- 模型初始化
- 优化器
- loss 函数
- 开启循环,进行每一个epoch 的迭代来实现一轮轮的训练
然而:这样写的方式一点都不"优雅"!,感觉像在写脚本代码,又不可复用!
(2)抽象版本
当我们把整个train 的循环过程抽象为一个class的时候,它应该是这样子的结构,"变得些许优雅起来":
class Trainer(object):
def __init__(self, model=None, criterion=None, optimizer=None, dataset=None,USE_CUDA=True):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.dataset = dataset
self.iterations = 0
self.USE_CUDA = USE_CUDA
def run(self, epochs=1):
#每一个epoch 就是一次train的过程
for i in range(1, epochs + 1):
self.train()
def train(self):
#从dataloader 中拿数据
for i, data in enumerate(self.dataset, self.iterations + 1):
batch_input, batch_target = data
input_var = batch_input
target_var = batch_target
if self.USE_CUDA:
input_var = input_var.cuda()
target_var = target_var.cuda()
#每一次前馈就是一次函数闭包操作
def closure():
batch_output = self.model(input_var)
loss = self.criterion(batch_output, target_var)
loss.backward()
return loss
#loss 返回,准备优化
self.optimizer.zero_grad()
self.optimizer.step(closure)
self.iterations += i
这个类好像确实比我们之前的写法"优雅"了很多!
然而, 有没有发现什么问题?
这个类中,当我们需要进行一些操作如每轮"正确率"显示,LOSS日志输出的时候,我们需要每次都修改 train 方法中代码,不能形成 一个个组件的形式;
(3)对trainer 类进行插件化处理
首先要对trainer 进行插件化 写一个接口:
class Trainer(object):
def __init__(self):
self.plugin_queues = {
}
def call_plugins(self):
pass
def register_plugin(self,plugin):
pass
在接口中,我们定义了 插件队列的字典,里面保存其不同时候调用的插件序列;
我们想想在什么时候需要调用插件:
- 在每次获取到数据之后,训练之前 对数据进行不同处理?
- 在完成一次backward操作之后 显示当前loss 或者accuracy
- 在完成每次batch or epoch 后保存模型或者修改学习率?
所以 四种类别的插件是必须的,而且它们的调用顺序也确定下来:
(1) iteration:一般是在完成一个batch 训练之后进行的事件调用序列(一般不改动网络或者优化器,如:计算准确率)调用序列;
(2) batch 在进行batch 训练之前需要进行的事件调用序列
(3)epoch 完成一个epoch 训练之后进行的事件调用序列
(4)update 完成一个batch训练之后进行的事件(涉及到对网络或者优化器的改动,如:学习率的调整)
此外,我们还要对插件的优先序列排个序;
import heapq
class Trainer(object):
def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.dataset = dataset
#总的迭代次数
self.iterations = 0
'''
Trainer的状态,注意这里的状态包含了所有插件提供的状态。初始化为空
'''
self.stats = {}
self.plugin_queues = {
'iteration': [],
'epoch': [],
'batch': [],
'update': [],
}
'''
作者将插件的调用进行了分类:
(1)iteration:一般是在完成一个batch 训练之后进行的事件调用序列(一般不改动网络或者优化器,如:计算准确率)调用序列;
(2)batch 在进行batch 训练之前需要进行的事件调用序列
(3)epoch 完成一个epoch 训练之后进行的事件调用序列
(4)update 完成一个batch训练之后进行的事件(涉及到对网络或者优化器的改动,如:学习率的调整)
iteration 跟update 两种插件调用的时候传入的参数不一样,iteration 会传入batch output,loss 等训练过程中的数据,
而update传入的的model ,方便对网络的修改
'''
def register_plugin(self, plugin):
#注册插件
plugin.register(self)
#插件的触发间隔,一般是这样的形式[(1, 'iteration'), (1, 'epoch')]
intervals = plugin.trigger_interval
if not isinstance(intervals, list):
intervals = [intervals]
for duration, unit in intervals:
#unit 是事件的触发类别
queue = self.plugin_queues[unit]
'''添加事件, 这里的duration就是触发间隔,,以后在调用插件的时候,
会进行更新 duration 决定了比如在第几个iteration or epoch 触发事件。len(queue)这里应当理解为优先级(越小越高)
【在相同duration的情况下决定调用的顺序】,根据加入队列的早晚决定。'''
queue.append((duration, len(queue), plugin))
def call_plugins(self, queue_name, time, *args):
#调用插件
args = (time,) + args
#这里的time 最基本的意思是次数,如(iteration or epoch)
queue = self.plugin_queues[queue_name]
if len(queue) == 0:
return
while queue[0][0] <= time:
'''如果队列第一个事件的duration(也就是触发时间点)小于当前times'''
plugin = queue[0][2]
'''调用相关队列相应的方法,所以如果是继承Plugin类的插件,
必须实现 iteration、batch、epoch和update中的至少一个且名字必须一致。'''
getattr(plugin, queue_name)(*args)
for trigger in plugin.trigger_interval:
if trigger[1] == queue_name:
interval = trigger[0]
'''根据插件的事件触发间隔,来更新事件队列里的事件 duration'''
new_item = (time + interval, queue[0][1], plugin)
heapq.heappushpop(queue, new_item)
'''加入新的事件并弹出最小堆的堆头。最小堆重新排序。'''
def run(self, epochs=1):
for q in self.plugin_queues.values():
'''对四个事件调用序列进行最小堆排序。'''
heapq.heapify(q)
for i in range(1, epochs + 1):
self.train()
#进行每次epoch 的更新
self.call_plugins('epoch', i)
def train(self):
for i, data in enumerate(self.dataset, self.iterations + 1):
batch_input, batch_target = data
#在每次获取batch data 后进行更新
self.call_plugins('batch', i, batch_input, batch_target)
input_var = batch_input
target_var = batch_target
#这里是给后续插件做缓存部分数据,这里是网络输出与loss
plugin_data = [None, None]
def closure():
batch_output = self.model(input_var)
loss = self.criterion(batch_output, target_var)
loss.backward()
if plugin_data[0] is None:
plugin_data[0] = batch_output.data
plugin_data[1] = loss.data
return loss
self.optimizer.zero_grad()
self.optimizer.step(closure)
self.call_plugins('iteration', i, batch_input, batch_target,
*plugin_data)
self.call_plugins('update', i, self.model)
self.iterations += i
(4)插件的实现:
基本类:
class Plugin(object):
def __init__(self, interval=None):
if interval is None:
interval = []
self.trigger_interval = interval
def register(self, trainer):
raise NotImplementedError
主要行为的基本实现:
class Monitor(Plugin):
def __init__(self, running_average=True, epoch_average=True, smoothing=0.7,
precision=None, number_format=None, unit=''):
'''
:param running_average:
:param epoch_average:
:param smoothing:
:param precision:数字输出精度
:param number_format: 数字输出格式
:param unit:
'''
if precision is None:
precision = 4
if number_format is None:
number_format = '.{}f'.format(precision)
#规定了输出格式
number_format = ':' + number_format
'''
在基类 plugin 中,初始化需要传入interval 参数,此处list[(1, 'iteration'), (1, 'epoch')]
代表了插件自身实现的的触发time 跟触发时间
'''
super(Monitor, self).__init__([(1, 'iteration'), (1, 'epoch')])
#是否平滑
self.smoothing = smoothing
#增量计算均值
self.with_running_average = running_average
self.with_epoch_average = epoch_average
#输出日志的格式
self.log_format = number_format
self.log_unit = unit
self.log_epoch_fields = None
self.log_iter_fields = ['{last' + number_format + '}' + unit]
if self.with_running_average:
self.log_iter_fields += [' ({running_avg' + number_format + '}' + unit + ')']
if self.with_epoch_average:
self.log_epoch_fields = ['{epoch_mean' + number_format + '}' + unit]
def register(self, trainer):
self.trainer = trainer
#在此处注册的时候,给train 的stats 注册当前状态,比如log 的格式等
stats = self.trainer.stats.setdefault(self.stat_name, {})
stats['log_format'] = self.log_format
stats['log_unit'] = self.log_unit
stats['log_iter_fields'] = self.log_iter_fields
if self.with_epoch_average:
stats['log_epoch_fields'] = self.log_epoch_fields
if self.with_epoch_average:
stats['epoch_stats'] = (0, 0)
def iteration(self, *args):
#每个iteration 进行的操作
stats = self.trainer.stats.setdefault(self.stat_name, {})
#通过_get_value 方法拿到每个插件的值,放入到stats中
stats['last'] = self._get_value(*args)
if self.with_epoch_average:
stats['epoch_stats'] = tuple(sum(t) for t in
zip(stats['epoch_stats'], (stats['last'], 1)))
if self.with_running_average:
previous_avg = stats.get('running_avg', 0)
stats['running_avg'] = previous_avg * self.smoothing + \
stats['last'] * (1 - self.smoothing)
def epoch(self, idx):
#每个epoch 进行的操作
stats = self.trainer.stats.setdefault(self.stat_name, {})
if self.with_epoch_average:
#如果需要计算每轮epoch 的精度等,需要 总数/轮数
epoch_stats = stats['epoch_stats']
stats['epoch_mean'] = epoch_stats[0] / epoch_stats[1]
stats['epoch_stats'] = (0, 0)
Monitor类实现了每轮epoch ,iteration 应该做什么.通过 self._get_value方法拿到具体实现的插件的值.
那我们来实现一个日志loss类吧:
from .monitor import Monitor
class LossMonitor(Monitor):
stat_name = 'loss'
#该插件的作用为简单记录每次的loss
def _get_value(self, iteration, input, target, output, loss):
return loss.item()
logger :
from collections import defaultdict
from .plugin import Plugin
class Logger(Plugin):
alignment = 4
#不同字段之间的分隔符
separator = '#' * 80
def __init__(self, fields, interval=None):
if interval is None:
interval = [(1, 'iteration'), (1, 'epoch')]
super(Logger, self).__init__(interval)
#需要打印的字段,如loss acc
self.field_widths = defaultdict(lambda: defaultdict(int))
self.fields = list(map(lambda f: f.split('.'), fields))
# 遵循XPath路径的格式。以AccuracyMonitor为例子,如果你想打印所有的状态,
# 那么你只需要令fields=[AccuracyMonitor.stat_name],也就是,['accuracy'],
# 而如果你想只打印AccuracyMonitor的子状态'last',那么你就只需要设置为
# ['accuracy.last'],而这里的split当然就是为了获得[['accuracy', 'last']]
# 这是为了之后的子状态解析(类似XPath路径解析)所使用的。
def _join_results(self, results):
# 这个函数主要是将获得的子状态的结果进行组装。
joined_out = map(lambda i: (i[0], ' '.join(i[1])), results)
joined_fields = map(lambda i: '{}: {}'.format(i[0], i[1]), joined_out)
return '\t'.join(joined_fields)
def log(self, msg):
print(msg)
def register(self, trainer):
self.trainer = trainer
def gather_stats(self):
result = {}
return result
def _align_output(self, field_idx, output):
#对其输出格式
for output_idx, o in enumerate(output):
if len(o) < self.field_widths[field_idx][output_idx]:
num_spaces = self.field_widths[field_idx][output_idx] - len(o)
output[output_idx] += ' ' * num_spaces
else:
self.field_widths[field_idx][output_idx] = len(o)
def _gather_outputs(self, field, log_fields, stat_parent, stat, require_dict=False):
# 这个函数是核心,负责将查找到的最底层的子模块的结果提取出来。
output = []
name = ''
if isinstance(stat, dict):
'''
通过插件的子stat去拿到每一轮的信息,如LOSS等
'''
log_fields = stat.get(log_fields, [])
name = stat.get('log_name', '.'.join(field))
# 找到自定义的输出名称。y有时候我们并不像打印对应的Key出来,所以可以
# 在写插件的时候增加多一个'log_name'的键值对,指定打印的名称。默认为
# field的完整名字。传入的fileds为['accuracy.last']
# 那么经过初始化之后,fileds=[['accuracy',
# 'last']]。所以这里的'.'.join(fields)其实是'accuracy.last'。
# 起到一个还原名称的作用。
for f in log_fields:
output.append(f.format(**stat))
elif not require_dict:
# 在这里的话,如果子模块stat不是字典且require_dict=False
# 那么他就会以父模块的打印格式和打印单位作为输出结果的方式。
name = '.'.join(field)
number_format = stat_parent.get('log_format', '')
unit = stat_parent.get('log_unit', '')
fmt = '{' + number_format + '}' + unit
output.append(fmt.format(stat))
return name, output
def _log_all(self, log_fields, prefix=None, suffix=None, require_dict=False):
results = []
for field_idx, field in enumerate(self.fields):
parent, stat = None, self.trainer.stats
for f in field:
parent, stat = stat, stat[f]
name, output = self._gather_outputs(field, log_fields,
parent, stat, require_dict)
if not output:
continue
self._align_output(field_idx, output)
results.append((name, output))
if not results:
return
output = self._join_results(results)
loginfo = []
if prefix is not None:
loginfo.append(prefix)
loginfo.append("\t")
loginfo.append(output)
if suffix is not None:
loginfo.append("\t")
loginfo.append(suffix)
self.log("".join(loginfo))
def iteration(self, *args):
'''
:param args: ( i, batch_input, batch_target,*plugin_data) 的元祖
:return:
'''
self._log_all('log_iter_fields',prefix="iteration:{}".format(args[0]))
def epoch(self, epoch_idx):
self._log_all('log_epoch_fields',
prefix=self.separator + '\nEpoch summary:',
suffix=self.separator,
require_dict=True)
具体用法:
trainer = Trainer(pretrainModel,criterion,optimizer,train_loader)
trainer.register_plugin(LossMonitor())
trainer.register_plugin(Logger(['loss']))
trainer.run(50)