MXNet - Network Structure Visualization

MXNet network structure visualization

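  • mx.viz.plot_network
    Draws the network structure as a computation graph made of nodes and edges, returned as a graphviz Digraph object.
    Input: the Symbol (the network definition), plus optional node_attrs (Graphviz node attributes) and shape (input shapes, used to label the edges with tensor shapes).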

Prerequisites

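  • Graphviz: the Graphviz software plus the graphviz Python package (e.g. pip install graphviz), which plot_network imports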

Example 1 - Linear matrix factorization network

import mxnet as mx
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')

# Set dummy dimensions
k = 64
max_user = 100
max_item = 50

# user feature lookup
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)

# item feature lookup
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)

# predict by the inner product, which is elementwise product and then sum
net = user * item
net = mx.symbol.sum_axis(data = net, axis = 1)
net = mx.symbol.Flatten(data = net)

# loss layer
net = mx.symbol.LinearRegressionOutput(data = net, label = score)

# Visualize the network (in Jupyter the returned Digraph is displayed inline)
mx.viz.plot_network(net)

Output:
[Figure 1: computation graph of the matrix factorization network]

Example 2 - Multilayer perceptron

import mxnet as mx

# Network definition
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu')
fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

# Visualize the network: view() renders the graph (PDF by default) and opens it
mx.viz.plot_network(mlp).view()

Output:
[Figure 2: computation graph of the multilayer perceptron]

