DL模型可视化

    model = Alexnet2fc()
    x = torch.rand(8, 3, 112, 112)
    y = model(x)
    
    # need TensorFlow 2.2 or higher
    # method 1
    # from keras.utils import plot_model
    # plot_model(model, to_file='model.png')

    # method 2
    # from IPython.display import SVG
    # from keras.utils.visualize_util import model_to_dot
    # SVG(model_to_dot(model).create(prog='dot', format='svg'))
    
	from torchviz import make_dot
	# method 3 for not MTL
	# g = make_dot(y)
	# g.render('net_arch', view=False) 
	
	# method 4 for MTL。模型的输出是一个列表,要么拼接
    for i in range(len(y)):
        if i == 0: c = torch.cat((y[0], y[1]), 1)
        elif i >= 2 and i <= len(y)-2:
            c = torch.cat((c, y[i+1]), 1)
    g = make_dot(c)
    g.render('net_arch', view=False)  # 会自动保存为一个 espnet.pdf,第二个参数为True,则会自动打开该PDF文件,为False则不打开

	# method 5 for MTL,要么生成一个元祖,这种方法更推荐!!(模型输出为长度40的列表)
	g = make_dot(tuple((y[i] for i in range(40))),)
    g.render('net_arch', view=False)  # 会自动保存为一个 espnet.pdf,第二个参数为True,则会自动打开该PDF文件,为False则不打开
	

method 1/2: https://bbs.cvmart.net/articles/232/shen-du-xue-xi-de-mo-xing-tiao-shi-xiao-ji-qiao-mei-bie-de-jiu-shi-shi-yong?order_by=created_at&

method 5来自github神人:chesharma @ https://github.com/szagoruyko/pytorchviz/issues/31#issuecomment-662768396

你可能感兴趣的:(#,DL-基础,#,可视化,Pytorch)