一直以来,深度神经网络作为一种功能强大的“黑盒”,被认为可解释性较弱。目前,常用的一种典型可解释性分析方法是就是可视化方法。
本文整理了深度神经网络训练过程中常用的可视化技巧,便于对训练过程进行分析和检查。
以resnet18为例,提取第一层的卷积核(7x7)进行可视化,可以看出大多提取的是边缘、角点之类的底层视觉特征。
在全连接层前的卷积层采用的是3x3卷积核,表达高层语义信息,更加抽象:
这里采用torchvision.utils.make_grid对卷积核进行网格化显示,图像网格的列数由nrow参数确定。
卷积核的可视化代码参考1进行修改:
def plot_conv(writer,model):
for name,param in model.named_parameters():
if 'conv' in name and 'weight' in name:
in_channels = param.size()[1] # 输入通道
out_channels = param.size()[0] # 输出通道
k_w, k_h = param.size()[3], param.size()[2] # 卷积核的尺寸
kernel_all = param.view(-1, 1, k_w, k_h) # 每个通道的卷积核
kernel_grid = torchvision.utils.make_grid(kernel_all, normalize=True, scale_each=True, nrow=in_channels)
writer.add_image(f'{name}_all', kernel_grid, global_step=0)
利用直方图可以对每一层参数的分布进行直观展示,便于分析模型参数的学习情况。
代码示例如下:
def plot_param_hist(writer,model):
for name, param in model.named_parameters():
writer.add_histogram(f"{name}", param, 0)
输入图像经过第一个卷积层的激活映射:
从pytorch模型中获取指定层的权重和激活的代码如下,参考facebook的工程2:
class GetWeightAndActivation:
"""
A class used to get weights and activations from specified layers from a Pytorch model.
"""
def __init__(self, model, layers):
"""
Args:
model (nn.Module): the model containing layers to obtain weights and activations from.
layers (list of strings): a list of layer names to obtain weights and activations from.
Names are hierarchical, separated by /. For example, If a layer follow a path
"s1" ---> "pathway0_stem" ---> "conv", the layer path is "s1/pathway0_stem/conv".
"""
self.model = model
self.hooks = {}
self.layers_names = layers
# eval mode
self.model.eval()
self._register_hooks()
def _get_layer(self, layer_name):
"""
Return a layer (nn.Module Object) given a hierarchical layer name, separated by /.
Args:
layer_name (str): the name of the layer.
"""
layer_ls = layer_name.split("/")
prev_module = self.model
for layer in layer_ls:
prev_module = prev_module._modules[layer]
return prev_module
def _register_single_hook(self, layer_name):
"""
Register hook to a layer, given layer_name, to obtain activations.
Args:
layer_name (str): name of the layer.
"""
def hook_fn(module, input, output):
self.hooks[layer_name] = output.clone().detach()
layer = get_layer(self.model, layer_name)
layer.register_forward_hook(hook_fn)
def _register_hooks(self):
"""
Register hooks to layers in `self.layers_names`.
"""
for layer_name in self.layers_names:
self._register_single_hook(layer_name)
def get_activations(self, input, bboxes=None):
"""
Obtain all activations from layers that we register hooks for.
Args:
input (tensors, list of tensors): the model input.
bboxes (Optional): Bouding boxes data that might be required
by the model.
Returns:
activation_dict (Python dictionary): a dictionary of the pair
{layer_name: list of activations}, where activations are outputs returned
by the layer.
"""
input_clone = [inp.clone() for inp in input]
if bboxes is not None:
preds = self.model(input_clone, bboxes)
else:
preds = self.model(input_clone)
activation_dict = {}
for layer_name, hook in self.hooks.items():
# list of activations for each instance.
activation_dict[layer_name] = hook
return activation_dict, preds
def get_weights(self):
"""
Returns weights from registered layers.
Returns:
weights (Python dictionary): a dictionary of the pair
{layer_name: weight}, where weight is the weight tensor.
"""
weights = {}
for layer in self.layers_names:
cur_layer = get_layer(self.model, layer)
if hasattr(cur_layer, "weight"):
weights[layer] = cur_layer.weight.clone().detach()
else:
logger.error(
"Layer {} does not have weight attribute.".format(layer)
)
return weights
对给定输入进行测试,输出指定层的激活映射,并绘制在tensorboard中:
# 模型测试,避免改变权重
model.eval()
# Set up writer for logging to Tensorboard format.
writer = tb.TensorboardWriter(cfg)
# 注册指定层的激活hook
layer_ls=["conv1","layer1/1/conv2","layer2/1/conv2","layer3/1/conv2","layer4/1/conv2"]
model_vis = GetWeightAndActivation(model, layer_ls)
# 给定一个输入,获取指定层的激活映射
activations, preds = model_vis.get_activations(inputs)
# 绘制激活映射(如画在tensorboard中)
plot_weights_and_activations(writer,activations,tag="Input {}/Activations: ".format(0))
本文整理了深度神经网络常用的局部可视化代码,对卷积核、权重和激活映射进行可视化,便于对训练过程进行分析和检查。有需要的朋友可以马住收藏。
https://zhuanlan.zhihu.com/p/54947519 ↩︎
https://github.com/facebookresearch/SlowFast ↩︎