动手学gluon系列之读取预训练模型----多种方法读取预训练模型进行finetune

本文主要是博主学习gluon时候的一些总结,共勉,如有错误,欢迎指正

gluon主要有3个方法得到预训练模型:

  • gluon自身的model_zoo
  • gluoncv提供的model_zoo
  • mxnet提供的预训练模型(.params ,.json)

下面分别就这三个方面进行介绍


一:读取gluon model_zoo提供的模型,并进行finetune

gluon提供的model主要在gluon.model_zoo.vision下,模型地址:https://mxnet.incubator.apache.org/api/python/gluon/model_zoo.html,你可以根据自己的情况查找对应的模型进行使用。model_zoo提供的模型均为features+output结构

调用方法如下:

一:只修改最终的fc层,进行finetune:

from mxnet import gluon
class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]

finetune_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)

with finetune_net.name_scope():
    finetune_net.output = nn.Dense(class_num)
finetune_net.output.initialize(init=mx.init.Xavier(),ctx=ctx)
finetune_net.hybridize()

二:不仅仅修改最终的fc层,还可以增加几层

下面的方法,首先提取出features,然后构建增加的sequential,最后将两部分通过sequential合并在一起。

from mxnet import gluon
pretrained_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)
pretrained_net_features = pretrained_net.features

class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]
modify_net = nn.HybridSequential(prefix="")
with modify_net.name_scope():
    modify_net.add(nn.Dense(128,activation='relu'),
                nn.Dropout(0.5),
                nn.Dense(class_num))
    modify_net.collect_params().initialize(init=mx.init.Xavier(),ctx=ctx)
    
net = nn.HybridSequential(prefix="")
with modify_net.name_scope():
    net.add(pretrained_net_features)
    net.add(modify_net)
net.hybridize() ## 该语句代表静态图动态图切换。

也可以直接修改features,达到同样的效果,不过记得初始化

from mxnet import gluon
class_num = 3
ctx = [mx.gpu(0),mx.gpu(1)]

finetune_net = gluon.model_zoo.vision.resnet50_v2(pretrained=True)

with finetune_net.name_scope():
    finetune_net.features.add(nn.Dense(128,activation='relu'),
                              nn.Dropout(0.5))
    finetune_net.output = nn.Dense(class_num)
finetune_net.features.initialize(init=mx.init.Xavier(),force_init=False,ctx=ctx)
finetune_net.output.initialize(init=mx.init.Xavier(),ctx=ctx)
finetune_net.hybridize()

二:读取gluoncv model_zoo提供的模型,并进行finetune(推荐)

gluoncv是gluon提供的比较强大的视觉库,其中提供了很多的预训练模型可以使用,链接:https://gluon-cv.mxnet.io/model_zoo/classification.html

使用gluoncv的预训练模型也很方便,跟使用gluon的model_zoo方法基本一致,不同点如下:

from gluoncv.model_zoo import get_model

finetune_net = get_model('ResNet50_v2', pretrained=True)

其他的就跟上面的一致了。
注意,有个gluoncv模型不是feature,output结构的,所以在使用的时候,可以看一下其结构,灵活判断

三、直接读取mxnet模型( .params + .json)

有的时候,我们可能需要利用gluon读取mxnet模型,目前利用gluon读取mxnet模型,只能使用gluon.nn.SymbolBlock()进行读取,如下:

ctx = mx.gpu(0)
sym, arg_params, aux_params = mx.model.load_checkpoint('../model/resnetv1d-101',17) ## model path and model index
internals = sym.get_internals()
net_out = internals['fc1_output']

net = gluon.nn.SymbolBlock(outputs=net_out, inputs=mx.sym.var('data'))

net.load_params(filename='../model/resnetv1d-101-0017.params', ctx=ctx)

如上,我们便读取了mxnet的model,现在我们便可以对net进行操作了,如下代码构建了一个3分类的网络:


class_num = 3
finetune_net = nn.HybridSequential(prefix="")
with finetune_net.name_scope():
    finetune_net.add(net)
    finetune_net.add(nn.Dense(class_num))## 输出3分类
net.hybridize() ## 该语句代表静态图动态图切换。

四、最优雅的方式,重新定义网络,实现任意的操作,:

这种方法最为优雅,也最灵活,你可以采用上面个各个方法读取模型,然后重写forward,实现网络的任意操作

class PretrainedNetwork(gluon.HybridBlock):
    def __init__(self, pretrained_layer, **kwargs):
        super(PretrainedNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.pretrained_layer = pretrained_layer 
            self.fc = nn.HybridSequential()
            self.fc.add(
                        nn.Flatten(),
                        nn.Dense(256, activation = 'relu'),
                        nn.Dropout(rate = 0.5),
                        nn.Dense(128)
                        )
            self.output = nn.Dense(2)

            
    def hybrid_forward(self, F, x):  ## 这里注意F不要忘记。
        x = self.pretrained_layer(x)
        x = self.fc(x)
        out = self.output(x)
     
        return out
        
        
        
### 采用如下得到网络:

from gluoncv.model_zoo import get_model

finetune_net = get_model('ResNet50_v2', pretrained=True)    
net = PretrainedNetwork(pretrained_layer = finetune_net)
net.initialize(forece_reinit = False, init = init.Xavier()) ## 初始化


至此,应该常用的利用预训练模型进行finetune的方法都包含了,如果还有更新,欢迎讨论

你可能感兴趣的:(深度学习,mxnet,一起学深度学习之Gluon)