MXNet半精度(FP16)训练

MXNet半精度训练

1.先决条件

  • Volta range of Nvidia GPUs (e.g. AWS P3 instance)
  • CUDA 9 or higher
  • cuDNN v7 or higher

2.使用Gluon API训练和前向推理

2.1训练

使用cast将网络设置为float16精度进行训练

net.cast('float16')

data = data.astype('float16', copy=False)

optimizer = mx.optimizer.create('sgd', multi_precision=True, lr=0.01)

2.fine tuning

import numpy as np
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model


pretrained_net = get_model(name='resnet50_v2', ctx=mx.cpu(),
                           pretrained=True, classes=1000)
pretrained_net.cast('float16')
net = get_model(name='resnet50_v2', ctx=mx.cpu(),
                pretrained=False, classes=101)
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())
net.features = pretrained_net.features
net.cast('float16')

2使用Symbolic API在float16中训练网络

涉及以下步骤。

  1. 在网络的开头添加一个图层,将数据转换为float16。这将确保在float16中计算所有以下层。
  2. 建议在softmax之前将层的输出转换为float32,以便softmax计算在float32中完成。这是因为softmax涉及大幅减少,并且有助于将其保留在float32中以获得更精确的答案。
  3. 建议使用优化器的多精度模式进行更精确的重量更新。
    以下是创建优化器时如何启用此模式。
optimizer = mx.optimizer.create('sgd', multi_precision=True, lr=0.01)

示例

data = mx.sym.Variable(name="data")

if dtype == 'float16':
    data = mx.sym.Cast(data=data, dtype=np.float16)

# ... the rest of the network
net_out = net(data)

if dtype == 'float16':
    net_out = mx.sym.Cast(data=net_out, dtype=np.float32)

output = mx.sym.SoftmaxOutput(data=net_out, name='softmax')

http://mxnet.incubator.apache.org/versions/master/faq/float16.html?highlight=distributed training
加速安装
pip install mxnet-cu80 -i https://pypi.tuna.tsinghua.edu.cn/simple/

MXNet分布式训练
https://github.com/MaJun-cn/DistributedTraining_inMXNet
https://github.com/apache/incubator-mxnet/tree/master/example/distributed_training

conda清理没用的安装包

conda clean -p      //删除没有用的包
conda clean -t      //tar打包
conda clean -y -all //删除所有的安装包及cache

pytorch 分布式训练
分布式例子
https://github.com/pytorch/examples/tree/master/imagenet
https://pytorch.org/tutorials/intermediate/dist_tuto.html
https://pytorch.org/tutorials/intermediate/ddp_tutorial.html?highlight=distributed training
https://pytorch.org/docs/stable/distributed_deprecated.html?highlight=distributed training

你可能感兴趣的:(深度学习)