前情提要:NNI是微软开发的调参工具,功能有很多,这里介绍其中的一个分支-模型压缩。
模型压缩流程:
1.模型prune
2.模型speedup
模型prune不多介绍,模型speedup就是根据掩码修改模型的结构,比如说通道剪枝,第N层的输出
通道数由10降到了5,那么第N+1层的输入通道数是不是要变成5呀。要保证剪完枝,网络各层还能衔接起来。代码基本就这三行,不过一般运行会出很多问题,除非speedup的模型非常简单且常规
apply_compression_results(net, masks_file, device)
m_speedup = ModelSpeedup(net, dummy_input, masks_file, device)
m_speedup.speedup_model()
运行后报错:
out_channel = out_shape[1]
IndexError: list index out of range
这样的报错,网上很难搜到解决方案,只能自己研究源代码了,m_speedup.speedup_model()的源代码如下:
def speedup_model(self):
"""
There are basically two steps: first, do mask/shape inference,
second, replace modules.
主要有两个步骤:首先,进行mask/形状推断,
第二,替换模块。
"""
_logger.info("start to speed up the model")
self.initialize_speedup()
training = self.bound_model.training
# 设置到测试模式
self.bound_model.train(False)
# 假设在稀疏传播后修复冲突
# which is more elegent 哪一个更优雅?
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
_logger.info("infer module masks...")
self.infer_modules_masks()
_logger.info('resolve the mask conflict')
# 在更换模型之前加载原始权重(dict形式)
self.bound_model.load_state_dict(self.ori_state_dict)
_logger.info("replace compressed modules...")
# mask冲突应该已经解决了
self.replace_compressed_modules()
self.bound_model.train(training)
_logger.info("speedup done")
报错位置为:
fix_mask_conflict(self.masks, self.bound_model, self.dummy_input)
参数解释:
self.masks: 就是模型的掩码,值由0和1构成,具体如下图所示
self.bound_model:网络模型
self.dummy_input:输入
继续往下研究代码,fix_mask_conflict的源代码如下
def fix_mask_conflict(masks, model, dummy_input, traced=None):
"""
MaskConflict修复通道依赖项和组依赖项的掩码mask冲突。
Parameters
----------
masks : dict/str
A dict object that stores the masks or the path of the mask file
model : torch.nn.Module
model to fix the mask conflict
dummy_input : torch.Tensor/list of tensors/dict of tensors
input example to trace the model
traced : torch._C.torch.jit.TopLevelTracedModule
the traced model of the target model, is this parameter is not None,
目标模型的跟踪模型,该参数不是None,
不使用模型和dummpy_input来获得跟踪图。
"""
if isinstance(masks, str):
# 如果mask是路径 则加载mask
assert os.path.exists(masks)
masks = torch.load(masks)
assert len(masks) > 0, 'Mask tensor cannot be empty'
#如果用户使用模型和伪_输入来跟踪模型,我们应该手动获取跟踪模型,这样,我们只跟踪一次模型,
#GroupMaskConflict和ChannelMaskConflict将重用此跟踪模型。
if traced is None:
assert model is not None and dummy_input is not None
training = model.training
# 需要跟踪eval mode
model.eval()
kw_args = {}
if torch.__version__ >= '1.6.0':
# 只有版本大于1.6.0的Pytork才有严格的选项strict 选项
kw_args['strict'] = False
traced = torch.jit.trace(model, dummy_input, **kw_args)
model.train(training)
#以下几行为修复组合通道mask冲突
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
masks = fix_group_mask.fix_mask()
fix_channel_mask = ChannelMaskConflict(masks, model, dummy_input, traced)
masks = fix_channel_mask.fix_mask()
return masks
报错位置为:
masks = fix_channel_mask.fix_mask()
1.strict(bool,可选):是否严格强制:attr:`state_dict`中的键与此模块返回的键匹配:meth:`torch.nn.Module.state_dict。默认值:`True``
2.torch.jit.trace 可以将现有模型或Python函数转换为TorchScript:class:`ScriptFunction`或:class:`ScriptModule`。您必须提供示例输入,然后我们运行函数,记录对所有张量执行的操作。
继续剥洋葱,fix_mask()源代码为:
def fix_mask(self):
"""
在对具有形状依赖关系的层进行mask推断之前,修复mask冲突。
应在“加速”模块的mask推断之前调用此函数。
仅支持结构化修剪mask。
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced, self.channel_prune_type)
else:
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
······
###########后面还有代码,仅列出一部分
报错位置为:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced, self.channel_prune_type)
继续深入:ChannelDependency是个类,继承自Dependency
class ChannelDependency(Dependency):
def __init__(self, model, dummy_input, traced_model=None, prune_type='Filter'):
"""
该模型分析模型中conv层之间的通道依赖关系。
Parameters
----------
model : torch.nn.Module
要分析的模型
data : torch.Tensor
示例输入数据以跟踪网络架构。
traced_model : torch._C.Graph
如果我们已经有了目标模型的跟踪图,我们就不需要再跟踪模型了。
prune_type: str
此参数表示通道修剪类型:
1)`Filter`修剪卷积层的过滤器以修剪相应的通道
2)Batchnorm`:修剪Batchnorm层中的通道
"""
self.prune_type = prune_type
self.target_types = []
if self.prune_type == 'Filter':
self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
elif self.prune_type == 'Batchnorm':
self.target_types.append('BatchNorm2d')
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
报错位置为:
super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)
继续查看父类Dependency的代码:
class Dependency:
def __init__(self, model=None, dummy_input=None, traced_model=None):
"""
为模型建立图
"""
from nni.common.graph_utils import TorchModuleGraph
# 检查输入是否合法
if traced_model is None:
# 用户应提供模型和虚拟输入以进行跟踪
# 模型或已跟踪的模型
assert model is not None and dummy_input is not None
self.graph = TorchModuleGraph(model, dummy_input, traced_model)
self.model = model
self.dependency = dict()
self.build_dependency()
def build_dependency(self):
raise NotImplementedError
def export(self, filepath):
raise NotImplementedError
报错位置为:
self.build_dependency()
看来是build_dependency()出错了,继续往下看build_dependency()代码:
def build_dependency(self):
"""
为模型中的conv层构建通道依赖关系。
"""
# 在分析数据之前,手动解压缩元组/列表
# 通道依赖性
self.graph.unpack_manually()
for node in self.graph.nodes_py.nodes_op:
parent_layers = []
# 找到包含 aten::add的节点
# 或者 aten::cat 操作
if node.op_type in ADD_TYPES:
parent_layers = self._get_parent_layers(node)
elif node.op_type == CAT_TYPE:
#确定此cat操作是否会引入通道
#依赖关系,我们需要cat的特定输入参数
#操作。为了获得cat操作的输入参数,我们
#需要遍历此NodePyGroup包含的所有cpp_节点,
#因为,TorchModuleGraph合并了重要节点和相邻节点
#不重要的节点(例如,以prim::attr开头的节点)进入
#NodepyGroup。
cat_dim = None
for cnode in node.node_cpps:
if cnode.kind() == CAT_TYPE:
cat_dim = list(cnode.inputs())[1].toIValue()
break
if cat_dim != 1:
parent_layers = self._get_parent_layers(node)
dependency_set = set(parent_layers)
#合并 dependencies
for parent in parent_layers:
if parent in self.dependency:
dependency_set.update(self.dependency[parent])
# 保存 dependencies
for _node in dependency_set:
self.dependency[_node] = dependency_set
报错位置为:
parent_layers = self._get_parent_layers(node)
build_dependency()里涉及到了图,即self.graph。graph里的成员还是挺多的:
成员介绍:
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
继续看 代码self._get_parent_layers
def _get_parent_layers(self, node):
"""
为目标节点查找最近的父conv层。
Parameters
---------
node : torch._C.Node
target node.
Returns
-------
parent_layers: list
nearest father conv/linear layers for the target worknode.
"""
parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type in self.target_types:
# 找到第一个相遇的conv
parent_layers.append(curnode.name)
continue
elif curnode.op_type in RESHAPE_OPS:
if reshape_break_channel_dependency(curnode):
continue
parents = self.graph.find_predecessors(curnode.unique_name)
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)
return parent_layers
报错位置:
if reshape_break_channel_dependency(curnode):
继续看reshape_break_channel_dependency函数代码
def reshape_break_channel_dependency(op_node):
"""
重塑操作(reshape, view, flatten)可能会打破通道依赖性。我们需要检查这些重塑操作的输入参数,
以检查这个重塑节点是否会打破通道依赖性。然而,分析每个重塑函数的输入参数并推断它是否会打破通
道依赖性是很复杂的。所以目前,我们只是检查输入通道和输出通道是否相同,如果是,那么我们可以说
原始的重塑函数不想改变通道的数量,这意味着通道依赖性没有被打破。相比之下,原始的重塑操作想要更
改通道的数量,因此它打破了通道依赖性。
Parameters
----------
opnode: NodePyOP
A Op node of the graph.
Returns
-------
bool
是否这个操作会打破通道依赖
"""
in_shape = op_node.auxiliary['in_shape']
out_shape = op_node.auxiliary['out_shape']
in_channel = in_shape[1]
out_channel = out_shape[1]
return in_channel != out_channel
报错位置为
out_channel = out_shape[1]
根据代码所示,in_shape和out_shape至少有两个数据,打印in_shape
out:[[9], [9]]
打印out_shape:
out:[18]
函数解析:
判断重塑操作是否会打破通道依赖性,我这里弹出的错误是由cat操作引起的。
一般tensor的形状为[b,c,w,h],代码要判断输入输出通道是否相同,所以代码里比较的是第二维度。我这里的cat的对象不是标准tensor形状,而是一个一维数据,所以肯定没有索引为1的数据。
不过还有个问题,即使是标准的tensor[b,c,w,h]执行cat操作:
代码 out_shape[1] 依然会报错,因为out_shape只有一个成员,奇怪了?
解决方式:
一维数据不采用torch.cat拼接!并将reshape_break_channel_dependency中的
in_channel = in_shape[1]
out_channel = out_shape[1]改为:
in_channel = in_shape[0][1]
out_channel = out_shape[1]
ok,这个问题暂时解决,但是 in_channel = in_shape[0][1]又报
'int' object is not subscriptable
这个节点里的内容一会一个样,迎合了A,B又不行了······
这次是view操作导致:
view前,形状[18],view之后[1,18,1,1]
解决方式:
不同的重塑操作,采用不同的提取方式!
总结:
speedup出现的问题基本上都是由网络模型中采取的某种操作导致的,因为speedup编写的通用代码不可能适用于所有情况。总之,调试起来还是有点麻烦的。
后记:
后来在 self.infer_modules_masks()时,又出现了100个错误,实在是解决不了了,换代码!