mxnet(gluon)学习之路-使用HybridBlock构建网络

简介

mxnet gluon接口提供 Sequential, HybridSequential 通过 add 串联的形式将构建网络,同时也提供HybridBlock通过继承的方式来构建网络,那么他们之间有什么区别呢?
1. Sequential构建的是动态图,即命令式编程形式,这种形式可以很方便的debug。例如:

net = nn.Sequential()
net.add(
    nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),
    nn.MaxPool2D(pool_size=2, strides=2),
    nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),
    nn.MaxPool2D(pool_size=2, strides=2),
    nn.Dense(120, activation='sigmoid'),
    nn.Dense(84, activation='sigmoid'),
    nn.Dense(10)
)

调试的时候,我们可以直接输入:

x = nd.random_normal(shape=(1,1,, 64,64))
net(x)

进行调试,而不用先构建整个图,再进行计算,效率高,同时可以直接输入变量 x,而不用将x先转为符号, x=mx.sym.var(‘data’)再进行调试。
2. HybridSequential 和HybridBlock则可以进行命令式和符号式混合,可以在动态图和静态图之间转换,使用者可以先用imperatvie的方式写网络,debug,最后跑通网络之后,如果网络是一个静态图结构,就可以用net.hybridize()的方式将其转换成静态图,众所周知静态图的运算会比动态图快,所以这是Gluon比PyTorch更好的地方。

使用HybridBlock构建网络

  1. name_scope 使用同一个名称前缀,保证每个变量唯一
  2. hybridize 将动态图转换为静态图提升执行效率
  3. F 根据输入来决定F使用 NDArray 或 Symbol
import mxnet as mx
from mxnet.gluon import loss as gloss, nn
import mxnet.gluon as gluno

class LeNet(gluno.nn.HybridBlock):
    def __init__(self, classes=10,feature_size=2, **kwargs):
        super(LeNet,self).__init__(**kwargs)

        with self.name_scope():
            self.conv1 = nn.Conv2D(channels=20, kernel_size=5, activation='relu')
            self.conv2 = nn.Conv2D(channels=50, kernel_size=5, activation='relu')
            self.maxpool = nn.MaxPool2D(pool_size=2, strides=2)
            self.flat = nn.Flatten()
            self.dense1 = nn.Dense(feature_size)
            self.dense2 = nn.Dense(classes)

    def hybrid_forward(self, F, x, *args, **kwargs):
        print('F: ',F)
        print('X: ',x)

        x = self.maxpool(self.conv1(x))
        x = self.maxpool(self.conv2(x))
        x = self.flat(x)
        ft = self.dense1(x)
        output = self.dense2(ft)
        return output

if __name__ == '__main__':
    net = LeNet()
    net.initialize()
    x = mx.nd.random.normal(shape=(1,1, 64, 64))
    net(x)
    net.hybridize()

你可能感兴趣的:(mxnet-gluon之路)