Pytorch学习笔记-第六章

Pytorch学习笔记-第六章猫狗大战

  • 数据处理和加载
  • 模型定义
  • 训练和测试
    • 训练
    • 测试
  • 过程可视化
  • 工程思想

记录一下个人学习和使用Pytorch中的一些问题。强烈推荐 《深度学习框架PyTorch:入门与实战》.写的非常好而且作者也十分用心,大家都可以看一看,本文记录按照该书学习Pytorch时在第六章猫狗大战遇到的一些问题。

数据处理和加载

这部分实践基于kaggle上一个经典比赛猫狗大战,是一个传统的二分类问题,其训练集包含25000张图片,均放置在同一文件夹下,命名格式为..jpg, 如cat.10000.jpg、dog.100.jpg,测试集包含12500张图片,命名为.jpg,如1000.jpg。参赛者需根据训练集的图片训练模型,并在测试集上进行预测,输出它是狗的概率。
关于数据加载的相关操作,其基本原理就是使用Dataset进行数据集的封装,再使用Dataloader实现数据并行加载。将文件读取等费时操作放在__getitem__函数中,利用多进程加速。避免一次性将所有图片都读进内存,不仅费时也会占用较大内存,而且不易进行数据增强等操作。

#首先要实现一个dataset,主要是实现__getitem__和__len__
class DogCat(data.Dataset):
	def __init__(self, root, transforms=None, train=True, test=False):
	#大部分操作在初始化里面完成:根据路径读入目录内数据,图片裁剪格式,数据标准化,数据增强等。
	def __getitem__(self, index):
	#返回一张图片的数据和标签,对于测试集,没有label,返回图片id,如1000.jpg返回1000
	def __len__(self):
	#返回数据库里图片个数


#dataset完成之后就可以进一步封装成dataloader供以后续模型使用
train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
                        batch_size = opt.batch_size,
                        shuffle = True,
                        num_workers = opt.num_workers)

模型定义

因为这个实战包含了3个网络的实现,为了让避免重复一些公共函数,所以先把基础的nn.moudle再往上封装一层。

class BasicModule(t.nn.Module):
   '''
   封装了nn.Module,主要提供save和load两个方法
   '''

   def __init__(self,opt=None):
       super(BasicModule,self).__init__()
       self.model_name = str(type(self)) # 模型的默认名字

   def load(self, path):
       '''
       可加载指定路径的模型
       '''
       self.load_state_dict(t.load(path))

   def save(self, name=None):
       '''
       保存模型,默认使用“模型名字+时间”作为文件名,
       如AlexNet_0710_23:57:29.pth
       '''
       if name is None:
           prefix = 'checkpoints/' + self.model_name + '_'
           name = time.strftime(prefix + '%m%d_%H:%M:%S.pth')
       t.save(self.state_dict(), name)
       return name

然后实现具体的模型时集成这个基类就免去了保存加载函数的实现。
注意一下python语言特性里过于包的事情。如果一个目录里有了__init__.py这个文件,那么就可以用import掉包方法调用这个目录里面的py文件。比如:

#再model文件夹的__init__.py文件加入下面两行
from .AlexNet import AlexNet
from .ResNet34 import ResNet34
#那么主函数里面就可以用
from models import AlexNet
#或者
import models
model = models.AlexNet()
#再或者
import models
model = getattr(models, 'AlexNet')()
#最后一种方法是可以用传递字符串的方法直接选择调用的模型比较方便。

使用模型时有一点注意事项:

  • 尽量使用nn.Sequential(比如AlexNet)以为这样可以自动识别子模型结构以及参数。
  • 将经常使用的结构封装成子Module(比如GoogLeNet的Inception结构,ResNet的Residual Block结构)为了代码重用和便捷封装性。
  • 将重复且有规律性的结构,用函数生成(比如VGG的多种变体,ResNet多种变体都是由多个重复卷积层组成)会更加简洁,可读性好。

训练和测试

训练

利用上面模型和数据部分代码,生成一个模型实例和dataloader实例,再生成一个目标函数和优化器。为了顺利可视化训练结果,每个epoch记得保存一下结果。同时在利用验证集验证时记得开启模型的eval模式,然后重新训练时开启train,两个模式在bn层和dropout上的处理不同。

测试

利用训练得到的模型,重新生成一个测试集的dataloader,然后输入到模型,把输出层的结果整理成需要的格式保存下即可。

过程可视化

在Visdom的基础上,加工出几个更方便使用的函数,比如重新设置vis的参数,一次性画多个点,预设定好不同数值线条的样式。封装在单独的工具文件中,方便实际使用和根据需求修改。
有不需修改,调用原生接口函数的话可以借助如下方法传递过去。

def __getattr__(self, name):
       '''
       self.function 等价于self.vis.function
       自定义的plot,image,log,plot_many等除外
       '''
       return getattr(self.vis, name)

工程思想

因为同样的模型用于不同的数据集或者不同设备不同人的使用习惯之类因素,很多参数都是经常调整的,比如文件路径,学习率,batchsize,有些模型的具体工程实现中这些参数是运行模型时命令输入的。但是我很认同作者的习惯,单独的用一个config文件保存,十分方便我们去了解模型哪些参数时可设置的,也很快的就可以去更改,如果碰过命令行输入或者时哪一个py文件设置的变量甚至硬编码到代码里面的参数是相当令人头疼的。
同时像上面过程可视化里面的做法,把自己加工的用于项目的小工具函数单独集中在一个文件中,而不是在用的地方再def一个函数是很好的一个做法,让别人再阅读代码理解实现的时候很容易发现你用了一个自定义函数,而且知道作用是什么,也可以让人根据自己需求更改。

你可能感兴趣的:(DL)