Pytorch神经网络中间层——可视化显示

Pytorch神经网络中间层——可视化显示

一、基本步骤

step 1: 加载预训练模型

step 2: 定义显示输出的钩子函数

step 3: 将钩子函数挂载到网络层上

将钩子函数挂载到相应的网络层之后,在测试过程中,会自动执行钩子函数内部的代码。

二、具体实现

from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from utils import *
from network.Network import * 
from utils.load_test_setting import *
from skimage.metrics import peak_signal_noise_ratio
import torch 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 加载网络模型
network = Network(H, W, message_length, noise_layers, device, batch_size, lr, with_diffusion)
EC_path = result_folder + "/models/EC_" + str(model_epoch) + ".pth"
network.load_model_ed(EC_path)  


# 2. 定义钩子函数
#    该函数内实现了64个特征图的显示和存储
def image_hook_fn(module, input, output):
    for i in range(64): 
        plt.subplot(8, 8, i + 1)
        fig = plt.figure(figsize=(1.28, 1.28), dpi=100)
        imageShow = output[0][i].detach().cpu().numpy()
        imageShow = (imageShow + 1) / 2
        plt.imshow(imageShow, cmap='gray')
        filename = os.path.join(result_folder, "images", "image_first_layer_filters", str(i) + ".png")
        fig.savefig(filename, bbox_inches='tight')
    plt.show() 

# 3. 注册钩子函数到目标中间层
#    我想可视化的是网络中的image_first_layer层
image_first_layer = network.encoder_decoder.module.encoder.image_first_layer
image_first_layer.register_forward_hook(image_hook_fn) 

# 读取测试集
test_dataset = MBRSDataset(os.path.join(dataset_path, "test"), H, W)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

# 进行测试流程
print("\nStart Testing : \n\n") 
for i, images in enumerate(test_dataloader):
    image = images.to(device) 
    message = torch.Tensor(np.random.choice([0, 1], (image.shape[0], message_length))).to(device) 
    
    network.encoder_decoder.eval()
    network.discriminator.eval()

    with torch.no_grad(): 
        images, messages = images.to(network.device), message.to(network.device) 
        encoded_images = network.encoder_decoder.module.encoder(images, messages)
        encoded_images = images + (encoded_images - image) * strength_factor
        noised_images = network.encoder_decoder.module.noise([encoded_images, images]) 
        decoded_messages = network.encoder_decoder.module.decoder(noised_images) 

三、可视化结果

我的image_first_layer网络层有64个filter,将每个filter的单独可视化,结果如下所示:
Pytorch神经网络中间层——可视化显示_第1张图片

你可能感兴趣的:(深度学习,pytorch,神经网络,python)