pytorch打印网络结构

最简单的方法当然可以直接print(net),但是这样网络比较复杂的时候效果不太好,看着比较乱;以前使用caffe的时候有一个网站可以在线生成网络框图,tensorflow可以用tensor board,keras中可以用model.summary()、或者plot_model()。pytorch没有这样的API,但是可以用代码来完成。

(1)安装环境:graphviz

conda install -n pytorch python-graphviz

或:

sudo apt-get install graphviz

或者从官网下载,按此教程。


(2)生成网络结构的代码:

def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert isinstance(params.values()[0], Variable)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '('+(', ').join(['%d' % v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot


(3)打印网络结构:

import torch  
from torch.autograd import Variable  
import torch.nn as nn  
from graphviz import Digraph

class CNN(nn.module):
    def __init__(self):
     ******
     def forward(self,x):
      ******
      return out

*****************************
def make_dot():  #复制上面的代码
*****************************

if __name__ == '__main__':  
    net = CNN()  
    x = Variable(torch.randn(1, 1, 1024,1024))  
    y = net(x)  
    g = make_dot(y)  
    g.view()  
  
    params = list(net.parameters())  
    k = 0  
    for i in params:  
        l = 1  
        print("该层的结构:" + str(list(i.size())))  
        for j in i.size():  
            l *= j  
        print("该层参数和:" + str(l))  
        k = k + l  
    print("总参数数量和:" + str(k))

(4)结果展示(例如这是一个resnet block类型的网络):

pytorch打印网络结构_第1张图片




你可能感兴趣的:(deep-learning,python)