pytorch查看网络结构

请参考:http://www.freesion.com/article/340667237/

想在终端将我们自己的网络结构保存为pdf文件,就用下面这种方法:

import torch
from torchvision.models import AlexNet
from torchviz import make_dot

x=torch.rand(8,1,224,224).cuda()
model=torch.load('/home/resnet_mnist.pth')
y=model(x)
g = make_dot(y)
g.render('./weights/espnet_model', view=False) 

注意,这里的输入数据要是这种形式才行:x=torch.rand(8,1,224,224).cuda() #如果模型是用GPU训练的就加.cuda()

你可能感兴趣的:(机器学习,深度学习,jupyter)