gluon网络可视化

  1,利用print(net),就能看到net网络每一层的具体参数。
  
  2,利用mx.viz.plot_network(net).view()。这个方法能生成pdf格式的网络结构图,不过要先把网络转成符号式(sym)编程才行。如何转见这里。

  举个栗子:

import mxnet as mx
from mxnet.gluon import nn


class Net(nn.HybridBlock):
    """Two-input-pair network built from a 2D (rgb) and a 3D (tdf) conv stream.

    Each (rgb, tdf) pair is passed through ``CNNBlock``; the two resulting
    feature vectors are concatenated and fed to a two-layer prediction head
    whose last layer has 2 output units. Both pairs go through the *same*
    layer instances, so the feature extractor weights are shared.
    """

    def __init__(self, **kwargs):
        super(Net, self).__init__(**kwargs)
        with self.name_scope():
            # 2D convolution stream applied to the rgb input.
            self.rgb_conv1 = nn.Conv2D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')
            self.rgb_conv2 = nn.Conv2D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')

            # 3D convolution stream applied to the tdf input.
            self.tdf_conv1 = nn.Conv3D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')
            self.tdf_conv2 = nn.Conv3D(channels=128, kernel_size=3, strides=1, padding=1, activation='relu')

            # Fully connected layers that fuse the two flattened streams.
            self.dense_fusion1 = nn.Dense(1024, activation='relu')
            self.dense_fusion2 = nn.Dense(2048, activation='relu')
            self.dense_fusion3 = nn.Dense(512, activation='relu')

            # Prediction head applied to the concatenated pair features.
            self.dense_prediction1 = nn.Dense(2048, activation='relu')
            self.dense_prediction2 = nn.Dense(2)

            self.maxpool2D = nn.MaxPool2D(pool_size=3, strides=2)
            self.maxpool3D = nn.MaxPool3D(pool_size=3, strides=2)

            # Registered once here instead of being re-created on every
            # CNNBlock call, so it shows up as a child of the network.
            self.flatten = nn.Flatten()

    def CNNBlock(self, F, rgb, tdf):
        """Extract one fused feature vector from an (rgb, tdf) input pair.

        Args:
            F: Operator namespace to use (``mx.nd`` or ``mx.sym``).
            rgb: Input for the 2D convolution stream.
            tdf: Input for the 3D convolution stream.

        Returns:
            Output of ``dense_fusion3`` (512 units) for the given pair.
        """
        # conv -> pool -> conv on each stream.
        rgb_conv = self.rgb_conv1(rgb)
        rgb_conv = self.maxpool2D(rgb_conv)
        rgb_conv = self.rgb_conv2(rgb_conv)

        tdf_conv = self.tdf_conv1(tdf)
        tdf_conv = self.maxpool3D(tdf_conv)
        tdf_conv = self.tdf_conv2(tdf_conv)

        # Flatten both streams so they can be joined along axis 1.
        rgb_conv = self.flatten(rgb_conv)
        tdf_conv = self.flatten(tdf_conv)

        fc = F.concat(rgb_conv, tdf_conv, dim=1)
        fc = self.dense_fusion1(fc)
        fc = self.dense_fusion2(fc)
        fc = self.dense_fusion3(fc)

        return fc

    def hybrid_forward(self, F, rgb1, tdf1, rgb2, tdf2):
        """Run both input pairs through ``CNNBlock`` and predict from the result.

        Args:
            F: Operator namespace supplied by Gluon (``mx.nd`` or ``mx.sym``).
            rgb1, tdf1: First (rgb, tdf) input pair.
            rgb2, tdf2: Second (rgb, tdf) input pair.

        Returns:
            Output of ``dense_prediction2`` (2 units).
        """
        out1 = self.CNNBlock(F, rgb1, tdf1)
        out2 = self.CNNBlock(F, rgb2, tdf2)

        out = F.concat(out1, out2, dim=1)
        out = self.dense_prediction2(self.dense_prediction1(out))
        return out

    def getFeature(self, img, depth):
        """Return the fused ``CNNBlock`` feature for a single (img, depth) pair.

        Args:
            img: Input for the 2D convolution stream.
            depth: Input for the 3D convolution stream.

        Returns:
            Output of ``CNNBlock`` for the given pair.
        """
        # CNNBlock needs the operator namespace as its first argument; the
        # original call omitted it and raised TypeError. Pick the namespace
        # from the input type, as Gluon does when dispatching forward calls.
        F = mx.nd if isinstance(img, mx.nd.NDArray) else mx.sym
        return self.CNNBlock(F, img, depth)

# Build the example network defined above.
net = Net()

# Method 1: print the network to list its layers and their settings.
print(net)

# Method 2: call the network on four symbolic placeholders (one per
# hybrid_forward input) and hand the resulting symbol to the plotter.
input_names = ("data1", "data2", "data3", "data4")
placeholders = [mx.sym.var(name) for name in input_names]
graph = mx.viz.plot_network(net(*placeholders))
graph.view()

你可能感兴趣的:(mxnet,gluon,网络可视化)