AverageMeter()的作用与用法

train函数是模型训练的入口。首先一些变量的更新采用自定义的AverageMeter类来管理,后面会介绍该类的定义。然后model.train()是设置为训练模式。 for i, (input, target) in enumerate(train_loader) 是数据迭代读取的循环函数,具体而言,当执行enumerate(train_loader)的时候,是先调用DataLoader类的__iter__方法,该方法里面再调用DataLoaderIter类的初始化操作__init__。而当执行for循环操作时,调用DataLoaderIter类的__next__方法,在该方法中通过self.collate_fn接口读取self.dataset数据时就会调用TSNDataSet类的__getitem__方法,从而完成数据的迭代读取。读取到数据后就将数据从Tensor转换成Variable格式,然后执行模型的前向计算:output = model(input_var),得到的output就是batch size*class维度的Variable;损失函数计算: loss = criterion(output, target_var);准确率计算: prec1, prec5 = accuracy(output.data, target, topk=(1,5));模型参数更新等等。其中loss.backward()是损失回传, optimizer.step()是模型参数更新。

在train函数中采用自定义的AverageMeter类来管理一些变量的更新。在初始化的时候就调用的重置方法reset。当调用该类对象的update方法的时候就会进行变量更新,当要读取某个变量的时候,可以通过对象.属性的方式来读取,比如在train函数中的top1.val读取top1准确率。

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

你可能感兴趣的:(ReID行人重识别,Pytorch)