# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import torch
from progress.bar import Bar
from models.data_parallel import DataParallel
from utils.utils import AverageMeter
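
# Wrapping the network and its criterion in one Module lets DataParallel
# scatter both the forward pass and the loss computation across GPUs, so
# only per-GPU loss values come back to the host instead of full outputs.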
class ModelWithLoss(torch.nn.Module):
  def __init__(self, model, loss):
    super(ModelWithLoss, self).__init__()
    self.model = model
    self.loss = loss

  def forward(self, batch):
    # self.model (e.g. a DLASeg instance for the DLA backbone) maps the
    # batch's 'input' tensor to a list of prediction dicts, one per stack.
    outputs = self.model(batch['input'])
    # The criterion compares the predictions with the ground truth in the
    # batch and returns the total loss plus its individual terms.
    loss, loss_stats = self.loss(outputs, batch)
    return outputs[-1], loss, loss_stats

class BaseTrainer(object):
  def __init__(self, opt, model, optimizer=None):
    self.opt = opt
    self.optimizer = optimizer
    self.loss_stats, self.loss = self._get_losses(opt)
    # Couple the network with its criterion (see ModelWithLoss above).
    self.model_with_loss = ModelWithLoss(model, self.loss)

  def set_device(self, gpus, chunk_sizes, device):
    if len(gpus) > 1:
      self.model_with_loss = DataParallel(
        self.model_with_loss, device_ids=gpus,
        chunk_sizes=chunk_sizes).to(device)
    else:
      self.model_with_loss = self.model_with_loss.to(device)
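    # Move any optimizer state tensors (e.g. Adam's running moments) to the
    # same device; this matters when resuming from a checkpoint that was
    # saved on a different device.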
    for state in self.optimizer.state.values():
      for k, v in state.items():
        if isinstance(v, torch.Tensor):
          state[k] = v.to(device=device, non_blocking=True)

  def run_epoch(self, phase, epoch, data_loader):
    model_with_loss = self.model_with_loss
    if phase == 'train':
      model_with_loss.train()
    else:
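      # For evaluation, unwrap the DataParallel container so inference runs
      # on a single GPU, and release cached GPU memory left over from training.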
      if len(self.opt.gpus) > 1:
        model_with_loss = self.model_with_loss.module
      model_with_loss.eval()
      torch.cuda.empty_cache()

    opt = self.opt
    results = {}
    data_time, batch_time = AverageMeter(), AverageMeter()
    avg_loss_stats = {l: AverageMeter() for l in self.loss_stats}
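    # opt.num_iters < 0 means run the full loader; a positive value caps the
    # epoch at a fixed number of iterations.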
    num_iters = len(data_loader) if opt.num_iters < 0 else opt.num_iters
    bar = Bar('{}/{}'.format(opt.task, opt.exp_id), max=num_iters)
    end = time.time()
    # Each pass of the loop below pulls one batch from data_loader (the
    # 'input' tensor plus ground truth such as 'hm', 'reg_mask', 'ind' and
    # 'wh'), runs a forward pass to get predictions, computes the loss
    # against the ground truth and, in the train phase, backpropagates and
    # updates the parameters before moving on to the next batch.
    # data_loader is the train_loader built in main.py; its dataset class
    # inherits from both PascalVOC and CTDetDataset, so iterating here
    # invokes their __len__ and __getitem__ magic methods, and each batch is
    # the collated 'ret' dict returned by __getitem__.
    for iter_id, batch in enumerate(data_loader):
      if iter_id >= num_iters:
        break
      data_time.update(time.time() - end)
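      # Move every tensor in the batch to the target device; 'meta' holds
      # non-tensor bookkeeping and stays on the CPU. non_blocking=True makes
      # the copy asynchronous when the DataLoader uses pinned memory.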
      for k in batch:
        if k != 'meta':
          batch[k] = batch[k].to(device=opt.device, non_blocking=True)
      # model_with_loss couples the network and the criterion: the model
      # turns the batch's 'input' into predictions, and the criterion scores
      # them against the ground truth ('hm', 'reg_mask', 'ind', 'wh').
      output, loss, loss_stats = model_with_loss(batch)
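      # Under DataParallel the wrapper returns one loss value per GPU;
      # average them into a single scalar before calling backward().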
      loss = loss.mean()
      if phase == 'train':
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
      batch_time.update(time.time() - end)
      end = time.time()
      # Progress bar shown while training/validating.
      Bar.suffix = '{phase}: [{0}][{1}/{2}]|Tot: {total:} |ETA: {eta:} '.format(
        epoch, iter_id, num_iters, phase=phase,
        total=bar.elapsed_td, eta=bar.eta_td)
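      # Update the running average of each loss term, weighted by batch
      # size, and append the averages to the progress-bar suffix.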
      for l in avg_loss_stats:
        avg_loss_stats[l].update(
          loss_stats[l].mean().item(), batch['input'].size(0))
        Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
      if not opt.hide_data_time:
        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) ' \
          '|Net {bt.avg:.3f}s'.format(dt=data_time, bt=batch_time)
      if opt.print_iter > 0:
        if iter_id % opt.print_iter == 0:
          print('{}/{}| {}'.format(opt.task, opt.exp_id, Bar.suffix))
      else:
        bar.next()
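      # Optional hooks implemented by task-specific subclasses: visual
      # debugging and collecting per-image results for post-epoch evaluation.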
      if opt.debug > 0:
        self.debug(batch, output, iter_id)
      if opt.test:
        self.save_result(output, batch, results)
      del output, loss, loss_stats

    bar.finish()
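    # Report each loss term's epoch average plus wall-clock time in minutes.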
    ret = {k: v.avg for k, v in avg_loss_stats.items()}
    ret['time'] = bar.elapsed_td.total_seconds() / 60.
    return ret, results
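
  # The methods below must be provided by task-specific trainers that
  # subclass BaseTrainer (e.g. the ctdet trainer paired with CTDetDataset).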
  def debug(self, batch, output, iter_id):
    raise NotImplementedError

  def save_result(self, output, batch, results):
    raise NotImplementedError

  def _get_losses(self, opt):
    raise NotImplementedError

  def val(self, epoch, data_loader):
    return self.run_epoch('val', epoch, data_loader)

  def train(self, epoch, data_loader):
    return self.run_epoch('train', epoch, data_loader)
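
# Typical usage (a minimal sketch following main.py in this repo; the
# concrete trainer class and opt fields are assumptions from context):
#   trainer = CtdetTrainer(opt, model, optimizer)
#   trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)
#   for epoch in range(1, opt.num_epochs + 1):
#     log_dict_train, _ = trainer.train(epoch, train_loader)
#     log_dict_val, _ = trainer.val(epoch, val_loader)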