Pytorch 的网络结构

Net 实例化

一个Net,也就是继承自nn.Module的类,当实例化后。本质上就是维护了以下8个字典(OrderedDict):

_parameters
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks
_modules

1._parameters

parameters就是Net的权重参数(比如conv的weight、conv的bias、fc的weight、fc的bias),类型为tensor,用于前向和反向;比如,你针对Net使用cpu()、cuda()等调用的时候,实际上调用的就是parameter这个tensor的cpu()、cuda()等方法;再比如,你保存模型或者重新加载pth文件的时候,针对的都是parameter的操作或者赋值。

2._buffers(不太清楚原理)

该成员值的填充是通过register_buffer API来完成的,通常用来将一些需要持久化的状态(但又不是网络的参数)放到_buffer里;一些极其个别的操作,比如BN,会将running_mean的值放入进来;

3._modules

_modules成员起很重要的桥梁作用,在获取一个net的所有的parameters的时候,是通过递归遍历该net的所有_modules来实现的。

Net 的前向

网络的前向需要通过诸如Net(input)这样的形式来调用,而非Net.forward(input),是因为前者实现了额外的功能:

1. 先执行完所有_forward_pre_hooks里的hooks

2. 在调用Net的forward函数

3. 再执行所有的_forward_hooks中的hooks

4. 执行完所有的_backward_hooks中的hooks

可以看到:

1,_forward_pre_hooks是在网络的forward之前执行的。这些hooks通过网络的register_forward_pre_hook() API来完成注册,通常只有一些Norm操作会定义_forward_pre_hooks。这种hook不能改变input的内容。

2,_forward_hooks是通过register_forward_hook来完成注册的。这些hooks是在forward完之后被调用的,并且不应该改变input和output。目前就是方便自己测试的时候可以用下。

3,_backward_hooks和_forward_hooks类似。

所以总结起来就是,如果你的网络中没有Norm操作,那么使用Net(input)和Net.forward(input)是等价的。

另外,你必须使用Net.eval()操作来将dropout和BN这些op设置为eval模式,否则你将得到不一致的前向返回值。eval()调用会将Net的实例中的training成员设置为False。

Net模型保存和重新加载

如果我们要保存一个训练好哦PyTorch模型的话,会使用下面的API:

cn = Net()
......
torch.save(cn.state_dict(), "your_model_path.pth")

可以看到使用了网络的state_dict() API调用以及torch模块的save调用。一言以蔽之,模型的保存就是先通过state_dict() API的调用获得一个关于网络参数的字典,再通过pickle模块序列化成文件的形式。

而如果我们要load一个pth模型来进行前向的时候,会使用下面的API:

 

cn = Net()

#参数反序列化为python dict
state_dict = torch.load("your_model_path.pth")
#加载训练好的参数
cn.load_state_dict(state_dict)

#变成测试模式,dropout和BN在训练和测试时不一样
#eval()会把模型中的每个module的self.training设置为False 
cn = cn.cuda().eval()

可以看到使用了torch模块的load调用和网络的load_state_dict() API调用。一言以蔽之,模型的重新加载就是先通过torch.load反序列化pickle文件得到一个Dict,然后再使用该Dict去初始化当前网络的state_dict。torch的save和load API在python2中使用的是cPickle,在python3中使用的是pickle。另外需要注意的是,序列化的pth文件会被写入header信息,包括magic number、version信息等。

##关于pkl文件内容显示程序,在另一个博客中自取。

关于模型的保存,我们需要弄清楚以下概念:1, state_dict;2, 序列化一个pth模型用于以后的前向;3, 为之后的再训练保存一个中间的checkpoint;4,将多个模型保存为一个文件;5,用其它模型的参数来初始化当前的网络;6,跨设备的模型的保存和加载。

1.state_dict

conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias

那如果你使用了DataParallel来训练的话:

cn = nn.DataParallel(cn)

那么state_dict中的key将如下所示:

module.conv1.weight
module.conv1.bias
module.conv2.weight
module.conv2.bias
module.fc1.weight
module.fc1.bias
module.fc2.weight
module.fc2.bias
module.fc3.weight
module.fc3.bias

2.序列化中间过程中的checkpoint

这种序列化的目的是为了之后以这个状态为基点重新开始训练。和前述序列化模型的本质不同就在于还需要序列化optimizer的Dict(比如学习率等参数)。传统上,checkpoint文件用.tar作为后缀:

#save
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

#load
model = Net(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()
#model.eval()

3.跨device(cpu/gpu)来save/load模型

比如模型是在GPU上训练的,现在要load到cpu上。或者反之,或者在CPU上训练,在GPU上load。这三种情况下,save的方法是一样的:

torch.save(model.state_dict(), PATH)

而load的方法就不一样了:

###############Save on GPU, Load on CPU #########
device = torch.device('cpu')
model = Net(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

###############Save on GPU, Load on GPU #########
device = torch.device("cuda")
model = Net(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
#确保在输入给网络的tensor上调用input = input.to(device)

###############Save on CPU, Load on GPU #########
device = torch.device("cuda")
model = Net(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
#确保在输入给网络的tensor上调用input = input.to(device)

 

你可能感兴趣的:(深度学习,pytorch)