Visualizing the output of every intermediate layer of a custom CNN in PyTorch

Define a LayerActivations class.

The Python code is as follows:

class LayerActivations:
    features = None

    def __init__(self, model, layer_num):
        # Register a forward hook on the requested layer; it fires on the next forward pass.
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Keep a CPU copy of the layer's output so it can be plotted later.
        self.features = output.cpu()

    def remove(self):
        # Detach the hook once the activation has been captured.
        self.hook.remove()

Pass in your own CNN model and the index of the layer you want to visualize; register_forward_hook (a built-in PyTorch function) captures that layer's output during the forward pass.
You can print the CNN's features container to see which layers are available:

print(cnn.features)
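For example, here is a minimal sketch of how the hook is attached and read back; the nn.Sequential layers and the input size are placeholders, not the author's actual model:

import torch
import torch.nn as nn

# Any module whose layers live in an indexable container (e.g. nn.Sequential)
# works with LayerActivations.
features = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

hook = LayerActivations(features, 0)      # capture the output of layer 0
_ = features(torch.randn(1, 1, 28, 28))   # dummy forward pass triggers the hook
print(hook.features.shape)                # e.g. torch.Size([1, 8, 28, 28])
hook.remove()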

Load a sample image (in fact, one batch of images):

img = next(iter(train_data))[0]

train_data is a DataLoader.
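For reference, one way such a loader might be built is sketched below; the random tensor dataset, batch size, and image shape are assumptions, and any dataset works the same way:

import torch
from torch.utils.data import DataLoader, TensorDataset

# A stand-in dataset: 100 single-channel 28x28 images with dummy labels.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
train_data = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

img = next(iter(train_data))[0]  # first element of the batch tuple: the image tensor
print(img.shape)                 # torch.Size([32, 1, 28, 28])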

Visualize the outputs:

fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0, right=1, bottom=0, top=0.8, hspace=0, wspace=0.2)
for i in range(len(cnn.features)):
    conv_out = LayerActivations(cnn.features, i)  # hook layer i
    # o = cnn(img.cuda())  # use this line if the model lives on the GPU
    o = cnn(img)           # Variable wrappers are no longer needed in modern PyTorch
    conv_out.remove()      # detach the hook once the activation is captured
    act = conv_out.features
    for j in range(1):     # plot only the first channel of each layer
        # The subplot grid (here 6 x 5) must have at least len(cnn.features) cells.
        ax = fig.add_subplot(6, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(act[0][j].detach().numpy(), cmap="gray")

After setting up the figure, loop over every layer of the CNN and export its feature map.
Before doing this, load the trained model you saved earlier.

Loading the model:

PATH = './cnn_model.pth'
# torch.save(cnn.state_dict(), PATH)
# cnn = CNNClass(*args, **kwargs)
cnn = CNN()
cnn.load_state_dict(torch.load(PATH))

PATH stores the weights of each layer of the CNN (its state_dict), not the model class itself, so you must first instantiate your own CNN definition and then load the weights into it.
The full code follows (the training and testing procedures, as well as the custom model definition, are omitted):
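Since the visualization loop indexes cnn.features, the omitted model class is assumed to expose its convolutional layers in an nn.Sequential named features. A minimal sketch of such a class is shown below; the layer sizes and the 28x28 input assumption are placeholders:

import torch.nn as nn

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Convolutional part, indexable as self.features[i]
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier head, assuming 28x28 inputs downsampled twice to 7x7
        self.classifier = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x.flatten(1))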

Full code:

import torch
import matplotlib.pyplot as plt

PATH = './cnn_model.pth'
# torch.save(cnn.state_dict(), PATH)
# cnn = CNNClass(*args, **kwargs)
cnn = CNN()
cnn.load_state_dict(torch.load(PATH))
cnn.eval()


class LayerActivations:
    features = None

    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output.cpu()

    def remove(self):
        self.hook.remove()


img = next(iter(train_data))[0]
fig = plt.figure(figsize=(10, 10))
fig.subplots_adjust(left=0, right=1, bottom=0, top=0.8, hspace=0, wspace=0.2)
for i in range(len(cnn.features)):
    conv_out = LayerActivations(cnn.features, i)
    # o = cnn(img.cuda())  # use this line if the model lives on the GPU
    o = cnn(img)           # Variable wrappers are no longer needed in modern PyTorch
    conv_out.remove()      # detach the hook once the activation is captured
    act = conv_out.features
    for j in range(1):     # plot only the first channel of each layer
        ax = fig.add_subplot(5, 3, i + 1, xticks=[], yticks=[])
        ax.imshow(act[0][j].detach().numpy(), cmap="gray")

plt.savefig('./out/cnn_layer_visiual.png', dpi=600)
plt.show()
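If you would rather inspect several channels of a single layer instead of one channel per layer, the same hook can be reused. A sketch under the same assumptions (the layer index and channel count here are arbitrary):

layer_idx = 0
conv_out = LayerActivations(cnn.features, layer_idx)
_ = cnn(img)
conv_out.remove()
act = conv_out.features  # shape: [batch, channels, H, W]

n_channels = min(16, act.shape[1])
fig = plt.figure(figsize=(8, 8))
for j in range(n_channels):
    ax = fig.add_subplot(4, 4, j + 1, xticks=[], yticks=[])
    ax.imshow(act[0][j].detach().numpy(), cmap="gray")
plt.show()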
