PyTorch 中的钩子（Hook）有何作用，以及如何查看模型中间结果

#coding=UTF-8


import torch
import caffe
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
# caffemodel.
# model.pkl

import torchvision.models as models

alexnet = models.alexnet()

# load_state_dict() returns a NamedTuple (missing_keys, unexpected_keys), not
# the weights, so read the parameters back with state_dict() afterwards.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only box.
alexnet.load_state_dict(torch.load('alexnet-owt-4df8aa71.pth', map_location='cpu'))
params = alexnet.state_dict()

print('The architecture of alexnet: ')
#print(alexnet)

imgSize = [224,224]  # [height, width] -- the same order the input batch uses

# convert('RGB') guarantees an (H, W, 3) array even for grayscale, palette or
# RGBA files; the context manager closes the underlying file handle.
with Image.open('cat.jpg') as src:
    # PIL's resize() takes (width, height)
    res_img = src.convert('RGB').resize((imgSize[1], imgSize[0]))
img = np.double(res_img)
# NOTE(review): the RGB->BGR swap and the missing mean/std normalisation look
# like leftovers from a Caffe pipeline; torchvision's pretrained models are
# documented to expect normalised RGB input -- confirm before trusting the
# activation values.
img = img[:,:,(2,1,0)] # rgb 2 bgr
img = np.transpose(img, (2,0,1)) # h * w * c ==> c * h * w

print(img.shape)
#plt.imshow(img)
#plt.show()

# Debug helper: uncomment to list the parameter names and inspect conv1's weights.
# for k, v in params.items():
#     print(k)
#
# a = params['features.0.weight']
# print(a.shape)
# print(type(a))
# print(a)

def vis_square(data, *, show=True):
    """Tile a stack of 2-D maps into one roughly square grid image and draw it.

    The whole stack is min-max normalised to [0, 1] using the global min/max
    of ``data`` (not per map). Each map gets a 1-pixel white border on its
    bottom and right edges. The maps are laid out row-major in an ``n x n``
    grid with ``n = ceil(sqrt(count))``, and unused cells are filled with
    white (1.0).

    Args:
        data: array of shape ``(count, height, width)`` or
            ``(count, height, width, 3)``.
        show: when True (the default, the original behaviour) draw the grid
            on the current matplotlib axes with ``plt.imshow`` and hide the
            axes. Pass False to only compute the grid.

    Returns:
        The tiled grid, shape ``(n * (height + 1), n * (width + 1))`` plus
        any trailing channel axis.

    Raises:
        ValueError: if ``data`` is empty or has fewer than 3 dimensions.
    """
    data = np.asarray(data)
    if data.ndim < 3:
        raise ValueError('vis_square() expects (n, height, width[, 3]) data, '
                         'got shape %s' % (data.shape,))
    if data.size == 0:
        raise ValueError('vis_square() needs at least one non-empty map')

    # normalize data for display; a constant input would divide by zero and
    # produce an all-NaN grid, so map it to zeros instead
    lo, hi = data.min(), data.max()
    span = hi - lo
    if span == 0:
        data = np.zeros(data.shape, dtype=float)
    else:
        data = (data - lo) / span

    # force the number of filters to be square
    count = data.shape[0]
    n = int(np.ceil(np.sqrt(count)))
    padding = (((0, n ** 2 - count),   # fill the rest of the grid with blank maps
                (0, 1), (0, 1))        # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

    # tile the filters into an image: (n*n, h, w, ...) -> (n, h, n, w, ...) -> (n*h, n*w, ...)
    data = data.reshape((n, n) + data.shape[1:]).transpose(
        (0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    if show:
        plt.imshow(data)
        plt.axis('off')
    return data

# Wrap the single (3, H, W) image as a (1, 3, H, W) float32 batch.
data_arr = np.zeros(shape=(1,3,imgSize[0],imgSize[1]),dtype=np.float32)
data_arr[0,...] = img
# torch.autograd.Variable has been a deprecated no-op wrapper since PyTorch 0.4
# (plain tensors take part in autograd directly). data_arr is already float32,
# so from_numpy() yields a CPU FloatTensor without an extra .type() call.
input_data = torch.from_numpy(data_arr)

# Filled by the forward / backward hooks registered on conv1.
feat_result  = []
grad_result = []

def get_features_hook(self,input,output):
    """Forward hook: print what the hooked layer receives/produces and keep its output.

    PyTorch calls a forward hook as ``hook(module, input, output)`` after the
    module's forward pass, so ``self`` here is the hooked module.

    Args:
        self: the module the hook is registered on.
        input: tuple of the positional tensors passed to the module.
        output: the module's output tensor.

    Side effect: appends a NumPy snapshot of ``output`` to the module-level
    ``feat_result`` list.
    """
    # number of positional inputs (input is a tuple)
    print('len(input): ',len(input))
    # output is a single tensor here, so len() is its first (batch) dimension
    print('len(output): ',len(output))
    print('###################################')
    print(input[0].shape) # torch.Size([1, 3, 224, 224])

    print('###################################')
    print(output[0].shape) # torch.Size([64, 55, 55]) -- first sample in the batch

    # .numpy() on a CPU tensor shares its memory, and torchvision's AlexNet
    # follows conv1 with ReLU(inplace=True), which would later overwrite the
    # stored values; detach() and copy() keep the raw conv output.
    feat_result.append(output.detach().cpu().numpy().copy())

def get_grads_hook(self,input_grad, output_grad):
    """Backward hook: print the gradients seen by the hooked layer and keep grad_output[0].

    Registered with the legacy ``Module.register_backward_hook``, which PyTorch
    calls as ``hook(module, grad_input, grad_output)``.

    Args:
        self: the module the hook is registered on.
        input_grad: tuple of gradients. For conv1 the author observed
            (None for x because the input does not require grad, grad of the
            weights, grad of the bias). The legacy hook reports the gradients
            of the layer's last autograd node, so the exact contents are not
            guaranteed.
        output_grad: tuple of gradients w.r.t. the module's output.

    Side effect: appends a NumPy snapshot of ``output_grad[0]`` to the
    module-level ``grad_result`` list.
    """
    # number of gradient entries on each side
    print('len(input): ', len(input_grad))
    print('len(output): ', len(output_grad))

    print('###################################')
    # Observed for conv1:
    #   None                         (x)
    #   torch.Size([64, 3, 11, 11])  (weights)
    #   torch.Size([64])             (bias)
    # Print every entry instead of indexing fixed positions, so the hook does
    # not raise IndexError/AttributeError when the tuple looks different.
    for grad in input_grad:
        print(grad if grad is None else grad.shape)

    print('###################################')
    print(output_grad[0].shape) # torch.Size([1, 64, 55, 55]) for x
    print('###################################')

    # .numpy() on a CPU tensor shares its memory; copy() keeps an independent snapshot
    grad_result.append(output_grad[0].detach().cpu().numpy().copy())

# Hook conv1 (features[0]): the forward hook captures its output, the backward
# hook the gradient flowing back into that output.
# NOTE(review): register_backward_hook is deprecated in favour of
# register_full_backward_hook, but the full variant reports only gradients
# w.r.t. the module's inputs (None here, which get_grads_hook does not expect)
# and may clash with an in-place op applied to the hooked output -- verify
# before switching.
handle_feat = alexnet.features[0].register_forward_hook(get_features_hook) # conv1
handle_grad = alexnet.features[0].register_backward_hook(get_grads_hook)

num_class = 1000  # ImageNet class count; the backward seed below follows the output shape
try:
    a  = alexnet(input_data)
    print('a.shape: ', a.shape)
    # Seed backward with all ones, i.e. differentiate the sum of all logits;
    # ones_like matches the output shape whatever the class count.
    a.backward(torch.ones_like(a))
finally:
    #### remove handle -- even if forward/backward raised
    handle_feat.remove()
    handle_grad.remove()

feat1 = feat_result[0]
grad1 = grad_result[0]

vis_square(feat1[0,...])
#plt.show()
plt.savefig('feat_visual.png')
# Close the figure so the gradient grid is drawn on a fresh one instead of
# being layered on top of the feature grid.
plt.close()
vis_square(grad1[0,...])
#plt.show()
plt.savefig('grad_x.png')
plt.close()

print('save feature and gradx over ...')

第一层卷积（alexnet.features[0]）的输出特征图（保存为 feat_visual.png）：
[图 1：conv1 输出特征图（feat_visual.png）]

第一层卷积输出的梯度（backward hook 中的 grad_output[0]，即以全 1 向量执行 a.backward() 时，所有 logits 之和对 conv1 输出的梯度，保存为 grad_x.png）：

[图 2：conv1 输出的梯度（grad_x.png）]

摘自:
1.https://www.zhihu.com/question/61044004【pytorch中的钩子(Hook)有何作用】
2.https://blog.csdn.net/manong_wxd/article/details/78720119【PyTorch学习总结(一)——查看模型中间结果】

你可能感兴趣的:(pytorch中的钩子(Hook)有何作用 和 查看模型中间结果)