深度学习:人群密度估计CSRNet(cvpr 2018)论文源代码详解

写在前面

大二上学期,刚从事Deep Learning的学习不久,《CSRNet: Dilated Convolutional Neural Networks for Understanding the Highly Congested Scenes》是我所看的第一篇论文,CSRNet神经网络主要用于高密度人群图片的人数估计。前端直接使用预先训练好的VGG16神经网络,CSRNet采用了单分支的网络结构,以及用空洞卷积来代替池化,在保持很少的参数,实现更大保真的同时,扩大了感受野。基于这几个特点,CSRNet在保证极其优秀的准确率的同时,非常易于训练。

我之前找遍全网,都没有找到一篇完整的对源代码的解析,而读懂源代码是学习路途上不可或缺的一步,所以我自己花了两天的时间,一边查资料,一边细看了CSRNet的pytorch代码,所以打算来写我的第一篇博客,对CSRNet的代码做一个详细的解析,会说到代码里每一句的作用,和引入的各个模块的语法,这也是对自己学习的一个笔记。

代码当中的细节很多,需要花很长的时间去理解,代码里面也有几处是我还有疑问的,我会在下面的详解里说出来,如果有懂的同学可以在评论区和我一起交流: )

作为初学者,也是第一次写博客,难免有谬误,欢迎大佬们来批评指正!

注:本文假设你已经对深度学习,pytorch语法和CSRNet的论文有了一定的理解,网络结构方面的东西我会说的比较简略。

下面附上论文和官方源代码的地址:
论文:论文地址
源代码:源代码地址

那么废话不多说,我们开始代码的分析。

代码分析

可以看到,py代码部分主要分了五个文件,分别是dataset.py,image.py,utils.py,model.py和train.py。下面我们五个文件一个一个说。

dataset.py

dataset文件主要实现了数据集的创建,主要围绕listDataset类展开

Dataset类

Dataset类主要帮助我们完成数据集的创建,它包含在torch.utils.data下
当我们需要创建数据集时,需要创建一个类,继承自Dataset类
它相对于直接读取数据的优势在于可以取batch,取shuffle,实现多线程读取
除了构造函数以外,此类只需要定义两个函数即可使用:len(self),getitem(self)

在构造函数中,应定义数据集要使用的数据,和需要使用到的变量

len(self)函数:
仅需return数据集中数据的个数

getitem(self, index)函数:
函数一般包含一个index参数来表示下标,返回值为数据集中下标为index的数据即可


下面我们看第一段代码

import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from image import *
import torchvision.transforms.functional as F

class listDataset(Dataset):
    def __init__(self, root, shape=None, shuffle=True, transform=None,  train=False, seen=0, batch_size=1, num_workers=4):
        if train:
            root = root *4
        random.shuffle(root)
        
        self.nSamples = len(root)
        self.lines = root
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.num_workers = num_workers

这里是定义class类的第一段,构造函数。接收了很多来自于外界的参数,其中root为用于训练的图片的路径,是列表类型,shuffle表示是否需要将root的顺序打乱,train表示这个数据集是否是用于训练的,batch_size表示了训练批数据的大小,在CSRNet里,batch的大小一律采用了1。num_workers表示了读取batch的线程数,越大读取效率越高,不过对CPU和内存的开销也越大,所以要做好权衡,一般可以设置为电脑CPU的核心数。

在if train后面,作者对列表root使用了*4的操作,在python中,将列表乘一个数字n,返回值是一个新的列表,新的列表中会将原列表中的元素重复n次。

random.shuffle(list)操作可以实现对list的打乱,不过我觉得这里似乎在前面加一个if shuffle似乎更合适

然后就是对一些成员变量的初始化。我们再来看下面一段

    def __len__(self):
        return self.nSamples
    def __getitem__(self, index):
        assert index <= len(self), 'index range error' 
        
        img_path = self.lines[index]
        
        img,target = load_data(img_path,self.train)
        
        #img = 255.0 * F.to_tensor(img)
        
        #img[0,:,:]=img[0,:,:]-92.8207477031
        #img[1,:,:]=img[1,:,:]-95.2757037428
        #img[2,:,:]=img[2,:,:]-104.877445883
        
        if self.transform is not None:
            img = self.transform(img)
        return img,target

这一段是实现__len__()和__getitem__()的。

在构造函数中,定义了self.nSamples为root列表的长度,所以在__len__()方法中只需要return self.nSample即可。

getitem()里,先使用了一个断言语句。 assert即断言的意思,assert后面一般会跟一个逻辑式,式子值为True的时候不发生任何事情,式子值为False时会引发异常。这里显然是判断下标是否溢出。

img_path从成员函数self.lines处读取了图片的地址,下一行里调用了load_data()函数。这个函数在image.py文件里被定义,具体的实现我们在image.py代码的解析里说,我们只要知到这个函数读取了图片地址,返回用于训练的图片和图片对应的密度图的numpy变量,被我们的img和target变量接入。

下面判断图片是否需要做变换,最后return图片和图片对应的密度图。这样数据集类就定义好啦!

image.py

image.py主要实现了对图片的处理和变换。

import random
import os
from PIL import Image,ImageFilter,ImageDraw
import numpy as np
import h5py
from PIL import ImageStat
import cv2

def load_data(img_path,train = True):
    gt_path = img_path.replace('.jpg','.h5').replace('images','ground_truth')
    img = Image.open(img_path).convert('RGB')
    gt_file = h5py.File(gt_path)
    target = np.asarray(gt_file['density'])
    if False:
        crop_size = (img.size[0]/2,img.size[1]/2)
        if random.randint(0,9)<= -1:
            
            
            dx = int(random.randint(0,1)*img.size[0]*1./2)
            dy = int(random.randint(0,1)*img.size[1]*1./2)
        else:
            dx = int(random.random()*img.size[0]*1./2)
            dy = int(random.random()*img.size[1]*1./2)
        
 
        img = img.crop((dx,dy,crop_size[0]+dx,crop_size[1]+dy))
        target = target[dy:crop_size[1]+dy,dx:crop_size[0]+dx]
        
                
        if random.random()>0.8:
            target = np.fliplr(target)
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
    
    
    target = cv2.resize(target,(target.shape[1]/8,target.shape[0]/8),interpolation = cv2.INTER_CUBIC)*64
    
    
    return img,target

这个文件主要是围绕load_data()这个函数展开,我们已经知道,这个函数在上一个文件里被listDataset类调用。

这一个文件里用到了h5py,opencv和PIL这三个模块,我们先讲一讲h5py模块

h5py模块

h5py文件是用来存放dataset或group的容器,我们这里只讲一下dataset的创建方法
h5py文件很像python中的字典,有键(key)和值(value)

生成一个h5py文件:
f = h5py.File(‘name.hdf5’, ‘w’)
当前文件夹下就会生成一个名为name.hdf5的文件
在程序中,这个文件会以名为f的对象的方式被打开

创建一个dataset:
d1 = f.create_dataset(‘key_name1’, (20, ), ‘i’),
'key_name1’即为键名,20是长度,'i’是数据类型,里面的数据会默认为0
然后可以给d1[i]赋值
或者,可以用f[‘key_name2’] = np.arange(15)直接定义dataset同时赋值
创建之后,
for key in f.keys()
f[key].name即可返回键名
f[key].value即可返回键对应的值,例如在这里会返回一个长度为20的数组


回到代码本身,我们可以看到作者先用replace()函数对输入的地址名做了修改。
使用str.replace(‘被替换的字符’, ‘替换的字符’(, 最大替换次数)),返回值即为被替换之后的字符串
这里显然是想生成一个h5py文件的文件名

下一行,通过PIL模块中的Image.open()函数读取了输入的地址所对应的图片

gt_file = h5py.File(gt_path)这一句,为图片建立了h5py文件,并且在程序中打开为gt_file。然后通过target = np.asarray(gt_file[‘density’])将文件里图片对应的密度图读取,并转化为numpy格式。

后面的这个if False是令我迷惑的地方,还有这个判断语句下的if random.randint(0,9)<= -1也一定是假的,我一开始怀疑是作者想把这一段代码注释掉,但是这里面的代码好像是有其作用的,这里对图像的切割处理在论文里也有提到过,所以这一段究竟是作者打代码时的疏忽还是另有玄机呢?我上网查了很久也没有找到相关的解读,如果有懂的小伙伴可以在评论区告诉我。

最后使用了opencv模块对图像进行了压缩处理,我们来简单说一说这里用到的opencv模块:

opencv模块

opencv是用来处理图像的模块,用import cv2的方式导入
cv2.resize用法:
cv2.resize(src, (dst_w, dst_h), interpolation)
需要是numpy类型的矩阵
src为原图像,dst_w为目标宽度,dst_h为目标高度,interpolation为变换类型,共有5种
cv2.INTER_NEAREST - 最近邻插值法
cv2. INTER_LINEAR - 双线性插值法(默认)
cv2.INTER_AREA - 基于局部像素的重采样(resampling using pixel area relation)。对于图像抽取(image decimation)来说,这可能是一个更好的方法。但如果是放大图像时,它和最近邻法的效果类似。
cv2.INTER_CUBIC - 基于4x4像素邻域的3次插值法
cv2.INTER_LANCZOS4 - 基于8x8像素邻域的Lanczos插值


在代码中,对target采取了宽高取1/8的操作因为CSRNet的输出结果就是原图大小的1/8,采用了INTER_CUBIC变换法,并在最后乘了64以保证图片像素之和不变。

好了,image.py文件我们就说完了,我们接下来看utils.py文件

utils.py

这个文件里实现的功能比较简单,utils是工具的意思,这里定义的函数是用来网络参数的保存的。我们来看源代码:

import h5py
import torch
import shutil

def save_net(fname, net):
    with h5py.File(fname, 'w') as h5f:
        for k, v in net.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())
def load_net(fname, net):
    with h5py.File(fname, 'r') as h5f:
        for k, v in net.state_dict().items():        
            param = torch.from_numpy(np.asarray(h5f[k]))         
            v.copy_(param)
            
def save_checkpoint(state, is_best,task_id, filename='checkpoint.pth.tar'):
    torch.save(state, task_id+filename)
    if is_best:
        shutil.copyfile(task_id+filename, task_id+'model_best.pth.tar')            

这里即是utils.py里实现的三个函数,其中前两个函数似乎一直没有用到,所以我们只讲第三个:save_checkpoint()函数

一共有四个参数,state是需要保存的数据,is_best是一个bool量,表示此次保存的是否是目前最优的参数,task_id表示任务名,最后的filename设置了缺省值,表示了保存参数的文件名的后一半和后缀。

首先用torch.save()函数对传进来的参数进行保存,名字是任务名默认的后缀。若此次传入的参数是目前最优的,那么就使用shutil模块对此次文件进行拷贝,在后半截里加上一个best ^ ^

shutil模块主要作用是对文件的操作和管理
shutil.copyfile(filename1, filename2)可以完成文件1到文件2的拷贝,文件2无需存在

好啦,那我们的utils.py文件也介绍完了,下面还剩下两个文件

model.py

model.py文件主要实现了网络模型的定义,这个文件的代码我们分三部分呈现出来。
下面是第一部分。

import torch.nn as nn
import torch
from torchvision import models
from utils import save_net,load_net

class CSRNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        self.seen = 0
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat  = [512, 512, 512,256,128,64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat,in_channels = 512,dilation = True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            mod = models.vgg16(pretrained = True)
            self._initialize_weights()
            for i in xrange(len(self.frontend.state_dict().items())):
                self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:]

这里是网络模型类的构造函数部分。只有一个参数,就是load_weights,默认值是False,事实上,后面在使用CSRNet类的时候,一直是以默认的False来的。

网络的两个比较重要的列表是self.frontend_feat和self.backend_feat,它们记录了网络前端和后端的结构。数字表示这一层的特征平面的数量,‘M’表示这一层是一个MaxPooling池化层,frontend_feat表示的即是VGG16网络的前端,而back_end表示的是采用空洞卷积的后端,经过作者的消融实验验证,所有层的空洞率都采用2的时候可以达到最好的效果。

前端和后端都使用了make_layers()函数自定义实现,自定义函数我们马上会说到。最后的output层采用了1*1卷积核实现了特征平面数量向1的转变。

下面是VGG16预训练参数加载的过程,我们先讲一下torchvision.model模块的使用:

torchvision.model中为我们预先保存了一些常见的网络,如vgg, alexnet等
例如我们可以使用mod = torchvision.models.vgg16(pretrained = True)的方式来直接创造一个vgg16网络对象
当pretrained = True时,vgg16网络预先训练好的参数也会一并加载到网络对象当中
通过这种方式,我们就可以实现网络前端使用预先训练好的网络的一部分,这为我们的训练节省了大量时间

有了预先训练好的网络,我们只需要将参数从训练好的网络向我们自己的网络里面拷贝就好了。先调用self._initialize()方法进行手动初始化,然后进行拷贝。那么我们是如何进行参数的拷贝的呢?

net.state_dict()是网络全部参数的字典
里面的键是网络各层参数的名字,值是封装好参数的Tensor
dict.items()语句会返回一个可遍历的(键,值)元组
那么通过遍历这个元组的方式,即通过逐个的i访问mod.state_dict().items()[i][1].data[:]即可遍历完所有的参数,完成拷贝。

构造函数看完,我们接下来看文件的第二部分:

    def forward(self,x):
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

这一部分是对网络forward的定义和初始化的方法。

forward比较简单,这里就不再赘述,我们来补充一点网络初始化的知识。

torch的网络一般会自动初始化,不过torch也提供了手动给参数初始化的方法,包含在torch.nn.init下
均匀分布:nn.init.uniform(tensor, a=0, b=1)
正态分部:nn.init.normal_(tensor, mean=0, std=1)
初始化为常数:nn.init.normal_(tensor, val)

代码中正是以这样的方法实现对参数的初始化的,net.modules()是网络各层的列表,我们使用m来遍历列表中的元素,判断m的层类型,然后分别使用init下的函数来完成初始化。

初始化方法看完,我们来看这个文件的最后一段:

def make_layers(cfg, in_channels = 3,batch_norm=False,dilation = False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate,dilation = d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)                

这一块代码是网络结构自定义的函数,是写在类的定义之外的。首先,通过dilation与否判断空洞率的大小,若是False则是网络前端,空洞率为1,若是True则是网络后端,空洞率为2。

接下来遍历传入的frontend_feat和backend_feat,以此来确定网络层的类型,若是’M’则是MaxPooling层;若是数字,则是卷积层,一套卷积层包括了Conv2d层,BatchNorm层和ReLU层,遍历完毕后,即可完成对网络前端和后端的创建。

OK,我们终于把第四个文件也给看完了,我们只剩下最后一个重头戏,也就是train.py文件了。

train.py

train.py文件比较长,所以我们也分几块来说。我们先看第一部分代码:

import sys
import os

import warnings

from model import CSRNet

from utils import save_checkpoint

import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms

import numpy as np
import argparse
import json
import cv2
import dataset
import time

parser = argparse.ArgumentParser(description='PyTorch CSRNet')

parser.add_argument('train_json', metavar='TRAIN',
                    help='path to train json')
parser.add_argument('test_json', metavar='TEST',
                    help='path to test json')

parser.add_argument('--pre', '-p', metavar='PRETRAINED', default=None,type=str,
                    help='path to the pretrained model')

parser.add_argument('gpu',metavar='GPU', type=str,
                    help='GPU id to use.')

parser.add_argument('task',metavar='TASK', type=str,
                    help='task id to use.')

第一段的代码是模块的导入和对argparse模块的使用。可能有些小伙伴不熟悉argparse模块,我们先来简单说一说这个模块:

argparse模块

argparse模块可以让python在命令行启动的时候接收参数
首先,使用parse = argparse.ArgumentParser()来创建一个解析对象
然后可以通过parse.add_argument()函数来增加命令行参数
函数的第一第二个参数一般是参数的名字,可以是一个,也可以是两个
可以在长名字的前面加上–(两个杠),短名字的前面加上-(一个杠),让该参数变为可选参数
metavar参数会改变显示出来的名字
help参数会在命令行打出-h或–help的时候显示出来
default参数可以给该参数设置缺省值
type参数可以给该参数指定类型,默认为str
在命令行中运行该文件的时候,后面应加上参数名 参数 参数名 参数 ……,这样程序就会接收到参数
添加完参数之后,应让args = parse.parse_args,然后使用args.即可返回程序接收的参数值
在这之后,我们依旧可以通过args. = … 来定义新的参数


通过对argparse模块的简单讲解,大家应该已经可以大概明白代码中这些语句的用处。没错,train_json用来接收训练数据集的地址,json文件的用处我们待会会讲到,test_json接收测试数据集的地址,pre接收之前训练时保存的数据文件地址,可不填,gpu接收gpu的地址,task用来接收任务名。我们在命令行需要接收的参数就是这些啦。

接下来看main函数的代码:

def main():
    
    global args,best_prec1
    
    best_prec1 = 1e6
    
    args = parser.parse_args()
    args.original_lr = 1e-7
    args.lr = 1e-7
    args.batch_size    = 1
    args.momentum      = 0.95
    args.decay         = 5*1e-4
    args.start_epoch   = 0
    args.epochs = 400
    args.steps         = [-1,1,100,150]
    args.scales        = [1,1,1,1]
    args.workers = 4
    args.seed = time.time()
    args.print_freq = 30
    with open(args.train_json, 'r') as outfile:        
        train_list = json.load(outfile)
    with open(args.test_json, 'r') as outfile:       
        val_list = json.load(outfile)
    
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.cuda.manual_seed(args.seed)
    
    model = CSRNet()
    
    model = model.cuda()
    
    criterion = nn.MSELoss(size_average=False).cuda()
    
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.decay)

    if args.pre:
        if os.path.isfile(args.pre):
            print("=> loading checkpoint '{}'".format(args.pre))
            checkpoint = torch.load(args.pre)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.pre, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.pre))
            
    for epoch in range(args.start_epoch, args.epochs):
        
        adjust_learning_rate(optimizer, epoch)
        
        train(train_list, model, criterion, optimizer, epoch)
        prec1 = validate(val_list, model, criterion)
        
        is_best = prec1 < best_prec1
        best_prec1 = min(prec1, best_prec1)
        print(' * best MAE {mae:.3f} '
              .format(mae=best_prec1))
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.pre,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best,args.task)

首先,在args下设置了所有需要用到的参数,以供后面使用,然后打开数据集所在的json文件,将数据集图片的地址读取到train_list和test_list列表下,关于json模块,我们补充一些简单的知识:

json模块

json是一种轻量级的数据交换格式。易于人阅读和编写。同时也易于机器解析和生成。它基于JavaScript Programming Language, Standard ECMA-262 3rd Edition - December 1999的一个子集。JSON采用完全独立于语言的文本格式,但是也使用了类似于C语言家族的习惯(包括C, C++, C#, Java, JavaScript, Perl, Python等)。这些特性使JSON成为理想的数据交换语言。

python到json的过程称为序列化(encoding),将python对象转化为json字符串
json到python的过程称为反序列化(decoding),将json字符串格式解码为python对象
json.load()可以从文件中读取json字符串
json.loads()可以从文件中读取,并转化为字典类型
json.dumps()将python中的字典类型转换为字符串类型
json.dump()将json格式字符串写到文件中


读取完图片地址后,我们需要定义gpu
os.environ[‘CUDA_VISIBLE_DEVICES’] = args.gpu
torch.cuda.manual_seed(args.seed)
这两句完成了对gpu的配置,这里需要讲一下os模块里这个语句的作用:

os.environ[‘CUDA_VISIBLE_DEVICES’] = gpu_id可以设置当前要使用的gpu设备
例如os.environ[‘CUDA_VISIBLE_DEVICES’] = '0’即可设置当前使用0号GPU

所以这里就从args.gpu参数取得gpu的id,完成对gpu的设置。

而第二句中用到了torch.cuda模块:
torch.cuda.manual_seed(args.seed)即可为当前GPU设置随机种子,多GPU的时候应该使用
也就是说,这一句是用来应付多GPU的情况的。

我们继续往下看,下面是比较常规的操作了:定义网络类的对象model;将model转移到gpu上;定义误差函数为MSE(均方误差);定义优化器optimizer,采用随机梯度下降法(SGD),提供了学习率,动量,以及衰退率。

在下面的if args.pre:这一段是检查在命令行是否有之前保存的数据路径输入,使用os.path.isfile()这个函数来判断输入的路径是否是有效的文件,如果是,就开始向网络里读入之前保存的数据。使用checkpoint来接入torch.load读取的字典,并把字典里的start_epoch,best_prec1,以及网络的优化器的参数一起拷贝进来。

接下来就开始了我们的epoch循环,一步一步地训练网络。在这个循环的开始,依次调用了adjust_learning_rate()函数,train()函数和validate()函数,这三个函数分别实现了学习率调节,网络前向传播和误差逆传播,准确率检测功能,这三个函数都会在下面的部分讲到。这三部分完成之后我们需要判断一下由validate函数返回的MAE(平均绝对误差)是否是最优,若是最优,则把最优的信息输出到屏幕上。最后的最后,我们在每个epoch的最后存储训练的参数,调用了utils.py里面的save_checkpoint()函数,传进去的第一个参数是一个列表,包含了目前训练到的epoch,网络和优化器的所有参数字典,最优的MAE。(第二个’arch’参数似乎没有被使用到)

其实到这里,我们整个网络的训练步骤已经完全结束了,相信大家已经对CSRNet的训练过程有了比较系统的了解,下面就是几个函数的实现了:

def train(train_list, model, criterion, optimizer, epoch):
    
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    
    
    train_loader = torch.utils.data.DataLoader(
        dataset.listDataset(train_list,
                       shuffle=True,
                       transform=transforms.Compose([
                       transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                   ]), 
                       train=True, 
                       seen=model.seen,
                       batch_size=args.batch_size,
                       num_workers=args.workers),
        batch_size=args.batch_size)
    print('epoch %d, processed %d samples, lr %.10f' % (epoch, epoch * len(train_loader.dataset), args.lr))
    
    model.train()
    end = time.time()
    
    for i,(img, target)in enumerate(train_loader):
        data_time.update(time.time() - end)
        
        img = img.cuda()
        img = Variable(img)
        output = model(img)
 
        target = target.type(torch.FloatTensor).unsqueeze(0).cuda()
        target = Variable(target)
        
        
        loss = criterion(output, target)
        
        losses.update(loss.item(), img.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()    
        
        batch_time.update(time.time() - end)
        end = time.time()
        
        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  .format(
                   epoch, i, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))

这里是train()函数的实现,也是训练的核心部分了,主要完成了训练批数据的定义,网络的前向传播和误差逆传播,这也是训练神经网络的基本算法。一开始,使用AverageMeter类创建了几个统计数据的对象,这个类比较简单,我们放在最后说。

使用torch.utils.data.DataLoader方法进行了训练批数据的创建,这个方法的第一个参数就是数据集,而我们的数据集类已经在dataset.py文件中定义好了,我们只需要传入参数即可完成数据集的创建,可见我们的参数包含了train_list,即为训练图片的地址,打乱shuffle选择了True,对图片进行了transforms的Normalize,然后是batch_size选择1,num_workers是4,这些参数我们在dataset.py的讲解中已经提到了。

后面的model.train()语句我们需要说一下:

net.train()与net.eval():

若训练的模型含BatchNormalize或dropout,那么应当在训练之前使用net.train(),在测试之前使用net,eval()
.train()会保证BN用的是每一批数据的均值和方差,.eval()会保证BN用的是全部数据的均值和方差
而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)

这里显然我们的网络包含了BN,所以在训练之前,我们需要声明一下model.train()。

enumerate(list)的使用:
enumerate(list)会将list的下标和下标对应的值一起返回。这里我们就使用了enumerate来连下标一起读取训练批数据里面的数据。img是图片,而target是真实的密度图。后面是前向和逆向传播的过程:将img转移到gpu上,将img声明为Variable变量,将img传入网络,得出结果,将output与target比较,得出loss,将梯度清零,loss的梯度反向传播,使用optimizer.step完成一次梯度的更新,最后统计一下时间,完成在屏幕上的输出,一切都一气呵成^^。

这里就是train()函数的全部内容了,其实下面的validate()函数也大同小异:

def validate(val_list, model, criterion):
    print ('begin test')
    test_loader = torch.utils.data.DataLoader(
    dataset.listDataset(val_list,
                   shuffle=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
                   ]),  train=False),
    batch_size=args.batch_size)    
    
    model.eval()
    
    mae = 0
    
    for i,(img, target) in enumerate(test_loader):
        img = img.cuda()
        img = Variable(img)
        output = model(img)
        
        mae += abs(output.data.sum()-target.sum().type(torch.FloatTensor).cuda())
        
    mae = mae/len(test_loader)    
    print(' * MAE {mae:.3f} '
              .format(mae=mae))

    return mae    

同train()函数一样,以上来完成测试批数据的建立,只不过这里的shuffle和train改成了False。然后进行model.eval()的声明,作用我们已经说过了。初始化mae。然后同样是读取批数据里的图片,密度图,完成前向传播,只不过这次不需要逆传播,只是需要将输出图片里面的像素值相加,得到网络估计的总人数,与target里的人数相互比较,计算出mae的值,最后将mae的值return回去就可以了。

这就是validate()函数的全部内容了,总的来说还是比train()函数要简单的,我们来继续看adjust_learning_rate()函数,这个函数也比较简单:

def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    
    
    args.lr = args.original_lr
    
    for i in range(len(args.steps)):
        
        scale = args.scales[i] if i < len(args.scales) else 1
        
        
        if epoch >= args.steps[i]:
            args.lr = args.lr * scale
            if epoch == args.steps[i]:
                break
        else:
            break
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr

adjust_learning_rate()函数的功能是实现训练过程中学习率的改变。需要传入的参数有两个,一个是我们的optimizer对象,还有一个是epoch数,epoch数是我们藉以来判断学习率需要降低到一个什么程度的。一开始,我们先把lr归回到一开始的状态。然后我们遍历args.steps列表中的内容,这里作者设置的值是[-1, 1, 100, 150],意思是网络的学习率在100和150epoch的时候需要改变,然后每一次循环都从scales中读取衰减率,若epoch达到了需要改变的数字,就乘上衰减率,然后将optimizer参数列表里的lr做出改变。

这里函数的意思很容易可以看懂,但是这里还是有疑惑的地方,作者把args.scales的值定为了[1,1,1,1],也就是学习率不管怎么乘,都会是不变的,作者在论文也提到了学习率固定在1e-6,所以这里段代码可能只是模板里的一部分,是冗余的。不知道我的看法是否正确,如果有小伙伴同意我的看法或者有自己的想法的记得来评论区和我交流一下哦!

啊!我们代码的最后一段了,是对AverageMeter类的定义:

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count    

这个类也比较简单,是用来封装一些统计量,并对它们进行计算的。这里定义了四个变量,是方差,平均数,和以及数量,构造函数中将它们初始化为0,并且在update方法中提供了计算这些量的方法。

好啦!简单的AverageMeter类就讲完了!最后的最后,

if __name__ == '__main__':
    main() 

执行main()函数!

一点后记

对代码的解析总算是写完了!有一点累人~

一开始看到CSRNet的代码的时候,我的内心是充满恐惧的,一堆没看过的模块,一堆没用过的函数。不过困难必是可以克服的,有什么不会的东西我们都是可以查,可以学的。学习的过程不就是从无到有吗。这次的代码看下来,我最大的感觉就是自己对深度学习的实现的理解上有了质的飞跃,这篇代码虽有令人疑惑的地方,但是它为我们做了一个很好的示范,让我们知道怎么处理图像,怎么实现程序和文件的交接,怎么自己创建一个数据集,怎么无差错地把网络训练出来,用python把这些事情实现出来简直比仅仅懂它们的原理酷多了!

其实学习更像是一段孤独的旅行,在本科学习深度学习更是如此,即使是计算机专业,身边也鲜有一起学深度学习的同伴。遇到的种种困难,往往是没有教科书,没有人教你如何做的,解决问题的道路需要自己慢慢探索。但是我坚信,充满荆棘的道路一定通向光明!

以后我也会继续更新博客,把我学习历程上遇到的困难和解决办法一起总结出来,如果有和我一样,从事深度学习的小伙伴,欢迎来和我一起讨论交流!

写完这篇文章的时候,快值国庆中秋双庆了,到时候也给自己放个小假^^ 祝各位过一个快乐的节日!

你可能感兴趣的:(深度学习,深度学习,pytorch,神经网络,机器学习)