This article walks through migrating PyTorch code to MindSpore and running single-machine, single-card training on an Ascend chip. The PyTorch code being migrated implements an image classification task: ResNet50 on CIFAR-10.
Sample code: both the PyTorch and the MindSpore implementation
Dataset: CIFAR-10
Reference: main categories of MindSpore APIs
1. Training Flow Comparison
Because MindSpore's architecture differs from PyTorch's, the training flow and its code implementation also differ between the two frameworks. The figure below illustrates the differences.
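In essence, PyTorch code writes the epoch and step loop explicitly, while MindSpore encapsulates that loop inside the Model.train interface. Schematically (a sketch using names from the code later in this article; trainloader stands in for the PyTorch DataLoader, which is not shown here):

# PyTorch: the epoch/step loop is written out explicitly
for epoch in range(EPOCH_MAX):
    for data, label in trainloader:
        optimizer.zero_grad()
        output = net(data)
        loss_value = criterion(output, label)
        loss_value.backward()
        optimizer.step()

# MindSpore: the same loop lives inside Model.train (see sections 2.6 and 2.7)
model = Model(net, loss, opt)
model.train(EPOCH_MAX, ds_train)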
2. Training Code Implementation
2.1 Parameter Configuration
This part stays basically the same as in PyTorch.
import argparse
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore CIFAR-10 Example')
parser.add_argument('--pre_trained', type=str, default=None,
help='Pretrained checkpoint path')
parser.add_argument('--data_path', type=str, default=None,
help='data_path')
parser.add_argument('--epoch_num', type=int, default=200, help='epoch_num')
parser.add_argument('--checkpoint_max_num', type=int, default=5,
help='Max num of checkpoint')
args = parser.parse_args()
LR_ORI = 0.01
EPOCH_MAX = args.epoch_num
TRAIN_BATCH_SIZE = 128
VAL_BATCH_SIZE = 100
MOMENTUM_ORI = 0.9
WEIGHT_DECAY = 5e-4
CHECKPOINT_MAX_NUM = args.checkpoint_max_num
# Data path
TRAIN_PATH = args.data_path
VAL_PATH = args.data_path
2.2 Runtime Configuration
MindSpore uses context.set_context to configure the information needed at runtime, such as the execution mode, backend, and target hardware. In this example, we use graph mode and run on an Ascend chip.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
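As an aside for debugging during migration (not part of the original script), MindSpore also provides a PyNative mode that executes operators eagerly, similar to PyTorch:

from mindspore import context

# Eager execution, convenient for inspecting intermediate results while
# debugging; switch back to GRAPH_MODE for best performance on Ascend.
context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')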
2.3 Dataset Loading and Processing
The data augmentation operations used on the PyTorch side are: RandomCrop, RandomHorizontalFlip, Normalize, and ToTensor.
Using the API mapping table, find the MindSpore interfaces that correspond to these PyTorch ones and migrate the code. Here the c_transforms interfaces are used: a high-performance image augmentation module implemented on top of C++ OpenCV. Because it operates on HWC images, the pipeline must end with HWC2CHW() to convert the HWC layout into the CHW layout that MindSpore expects.
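For reference, the original PyTorch augmentation pipeline presumably looked roughly like the following (a sketch reconstructed from the transform list above, reusing the same crop, flip, and normalization parameters; not the verbatim source):

import torchvision.transforms as transforms

# Training-time augmentation in PyTorch; ToTensor both rescales pixels
# to [0, 1] and converts HWC to CHW in one step.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

In the MindSpore pipeline, ToTensor's two roles are split across Rescale(1.0 / 255.0, 0.0) and HWC2CHW().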
The migrated MindSpore dataset loading and processing code is as follows:
import mindspore
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
def create_dataset(data_home, do_train, batch_size):
# Define dataset
if do_train:
cifar_ds = ds.Cifar10Dataset(dataset_dir=data_home,
num_parallel_workers=8,
shuffle=True, usage='train')
else:
cifar_ds = ds.Cifar10Dataset(dataset_dir=data_home,
num_parallel_workers=8,
shuffle=False, usage='test')
if do_train:
# Transformation on train data
transform_data = C.Compose([CV.RandomCrop((32, 32), (4, 4, 4, 4)),
CV.RandomHorizontalFlip(),
CV.Rescale(1.0 / 255.0, 0.0),
CV.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
CV.HWC2CHW()])
else:
# Transformation on validation data
transform_data = C.Compose([CV.Rescale(1.0 / 255.0, 0.0),
CV.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
CV.HWC2CHW()])
# Transformation on label
transform_label = C.TypeCast(mindspore.dtype.int32)
# Apply map operations on images
cifar_ds = cifar_ds.map(operations=transform_label, num_parallel_workers=8,
python_multiprocessing=True, input_columns="label")
cifar_ds = cifar_ds.map(operations=transform_data, num_parallel_workers=8,
python_multiprocessing=True, input_columns="image")
cifar_ds = cifar_ds.batch(batch_size, num_parallel_workers=8,
drop_remainder=True)
steps_per_epoch = cifar_ds.get_dataset_size()
return cifar_ds, steps_per_epoch
Once defined, it can be called directly in the main function:
# Create dataset
ds_train, steps_per_epoch_train = create_dataset(TRAIN_PATH, do_train=True, batch_size=TRAIN_BATCH_SIZE)
ds_val, steps_per_epoch_val = create_dataset(VAL_PATH, do_train=False, batch_size=VAL_BATCH_SIZE)
MindSpore already provides good support for the following three dataset-loading scenarios, which can be used for reference (a sketch of the third case follows the list):
1. Loading common datasets
2. Loading datasets in a specific format (MindRecord)
3. Loading custom datasets
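For the third case, which the article only names, a minimal sketch using GeneratorDataset might look like this (the NumPy arrays here are made-up placeholder data, not part of the original example):

import numpy as np
import mindspore.dataset as ds

class MyDataset:
    """A random-access custom dataset: defines __getitem__ and __len__."""
    def __init__(self):
        self.data = np.random.rand(10, 32, 32, 3).astype(np.float32)
        self.label = np.random.randint(0, 10, (10,)).astype(np.int32)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return len(self.data)

# Wrap the Python object into a MindSpore dataset pipeline
dataset = ds.GeneratorDataset(MyDataset(), column_names=["image", "label"],
                              shuffle=True)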
2.4 Network Definition
Analyze the operators contained in the PyTorch network, find their MindSpore counterparts through the API mapping table and the MindSpore API documentation, and build the ResNet network. In MindSpore, network structures are constructed with nn.Cell: the operators to be used are defined in the Cell's __init__ method, they are wired together in the construct method, and the output is returned with return.
Note: the _conv2d and _dense helper functions are defined to keep weight initialization consistent with PyTorch, whose default initializer for Conv2d and Linear draws from a uniform distribution with bound sqrt(1/fan_in); that bound is exactly the scale computed below.
import math
import mindspore
import mindspore.nn as nn
from mindspore.ops import operations as P
EXPANSION = 4
def _conv2d(in_channel, out_channel, kernel_size, stride=1, padding=0):
scale = math.sqrt(1/(in_channel*kernel_size*kernel_size))
if padding == 0:
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
stride=stride, padding=padding, pad_mode='same',
weight_init=mindspore.common.initializer.Uniform(scale=scale))
else:
return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
stride=stride, padding=padding, pad_mode='pad',
weight_init=mindspore.common.initializer.Uniform(scale=scale))
def _dense(in_channel, out_channel):
scale = math.sqrt(1/in_channel)
return nn.Dense(in_channel, out_channel,
weight_init=mindspore.common.initializer.Uniform(scale=scale),
bias_init=mindspore.common.initializer.Uniform(scale=scale))
class ResidualBlock(nn.Cell):
def __init__(self, in_planes, planes, stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = _conv2d(in_planes, planes, kernel_size=1)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = _conv2d(planes, planes, kernel_size=3,
stride=stride, padding=1)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = _conv2d(planes, EXPANSION*planes, kernel_size=1)
self.bn3 = nn.BatchNorm2d(EXPANSION*planes)
self.shortcut = nn.SequentialCell()
if stride != 1 or in_planes != EXPANSION*planes:
self.shortcut = nn.SequentialCell(
_conv2d(in_planes, EXPANSION*planes,
kernel_size=1, stride=stride),
nn.BatchNorm2d(EXPANSION*planes))
self.relu = nn.ReLU()
self.add = P.Add()
def construct(self, x_input):
out = self.relu(self.bn1(self.conv1(x_input)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
identity = self.shortcut(x_input)
out = self.add(out, identity)
out = self.relu(out)
return out
A ResNet contains many repeated structures, so multiple Cell instances can be constructed in a loop and chained with SequentialCell to cut down on duplicated code. As before, the operators are wired together in the construct method and the network output is returned.
The backbone network code is as follows:
class ResNet(nn.Cell):
def __init__(self, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = _conv2d(3, 64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.layer1 = self._make_layer(64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(512, num_blocks[3], stride=2)
self.avgpool2d = nn.AvgPool2d(kernel_size=4, stride=4)
self.reshape = mindspore.ops.Reshape()
self.linear = _dense(2048, num_classes)
def _make_layer(self, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(ResidualBlock(self.in_planes, planes, stride))
self.in_planes = EXPANSION*planes
return nn.SequentialCell(*layers)
def construct(self, x_input):
x_input = self.conv1(x_input)
out = self.relu(self.bn1(x_input))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = self.avgpool2d(out)
out = self.reshape(out, (out.shape[0], 2048))
out = self.linear(out)
return out
def resnet_50():
return ResNet([3, 4, 6, 3])
The figure below shows the difference between PyTorch and MindSpore when defining a small CNN:
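The contrast comes down to nn.Module with forward versus nn.Cell with construct; the two-layer example below is made up for illustration (it is not the article's figure):

import torch.nn
import mindspore.nn as msnn

# PyTorch: subclass nn.Module and wire the ops in forward
class TorchNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# MindSpore: subclass nn.Cell and wire the ops in construct
class MsNet(msnn.Cell):
    def __init__(self):
        super().__init__()
        self.conv = msnn.Conv2d(3, 16, 3, pad_mode='pad', padding=1)
        self.relu = msnn.ReLU()

    def construct(self, x):
        return self.relu(self.conv(x))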
2.5 Defining the Loss Function and Optimizer
The PyTorch loss function and optimizer:
# Define network
net = resnet_50()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)
# Define the loss function
criterion = torch.nn.CrossEntropyLoss()
# Define the optimizer
optimizer = torch.optim.SGD(net.parameters(), LR_ORI, momentum=MOMENTUM_ORI, weight_decay=WEIGHT_DECAY)
The migrated MindSpore loss function and optimizer. torch.nn.CrossEntropyLoss maps to SoftmaxCrossEntropyWithLogits with sparse=True (it takes integer class labels) and reduction='mean':
# Define network
net = resnet_50()
# Define the loss function
loss = nn.loss.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer
opt = nn.SGD(net.trainable_params(), LR_ORI, momentum=MOMENTUM_ORI, weight_decay=WEIGHT_DECAY)
2.6 Building the Model
MindSpore recommends wrapping the network with the mindspore.Model interface, which builds the training flow internally. Pass the defined network, loss function, optimizer, and metrics into Model. To make evaluation convenient, MindSpore provides a variety of metrics such as Accuracy, Precision, Recall, and F1.
Note: amp_level="O3" (mixed precision, with the network cast to float16) is enabled here to exploit the high-performance compute of the Ascend chip.
from mindspore import Model
# Create train model
metrics = {'accuracy': nn.Accuracy(), 'loss': nn.Loss()}
model = Model(net, loss, opt, metrics=metrics, amp_level="O3")
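Once wrapped, the same Model object serves for both training and evaluation. For example, a standalone validation pass (the same call the EvalCallBack below makes each epoch) would be:

# Returns a dict keyed by the metrics defined above,
# e.g. {'accuracy': ..., 'loss': ...}
acc = model.eval(ds_val, dataset_sink_mode=False)
print(acc)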
2.7 Training and Validation
Training and validation are run by calling the Model.train interface, passing in callbacks the built-in ModelCheckpoint and LossMonitor instances together with the custom EvalCallBack and PrintFps instances.
import time
from mindspore.train.callback import Callback, ModelCheckpoint, LossMonitor
from mindspore.train.callback import CheckpointConfig
class EvalCallBack(Callback):
def __init__(self, eval_model, eval_dataset, eval_per_epoch):
self.eval_model = eval_model
self.eval_dataset = eval_dataset
self.eval_per_epoch = eval_per_epoch
def epoch_end(self, run_context):
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
if cur_epoch % self.eval_per_epoch == 0:
acc = self.eval_model.eval(self.eval_dataset,
dataset_sink_mode=False)
print(acc)
class PrintFps(Callback):
def __init__(self, step_num, start_time):
self.step_num = step_num
self.start_time = start_time
self.end_time = time.time()
def epoch_begin(self, run_context):
self.start_time = time.time()
def epoch_end(self, run_context):
self.end_time = time.time()
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
fps = self.step_num / (self.end_time - self.start_time)
print("Epoch:{}, {:.2f}imgs/sec".format(cur_epoch, fps))
# CheckPoint CallBack definition
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch_train,
keep_checkpoint_max=CHECKPOINT_MAX_NUM)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10",
directory="./checkpoint/", config=config_ck)
# Eval CallBack definition
EVAL_PER_EPOCH = 1
eval_cb = EvalCallBack(model, ds_val, EVAL_PER_EPOCH)
train_data_num = steps_per_epoch_train * TRAIN_BATCH_SIZE
# FPS CallBack definition
init_time = time.time()
fps_cb = PrintFps(train_data_num, init_time)
# Train
print("============== Starting Training ==============")
model.train(EPOCH_MAX, ds_train,
callbacks=[LossMonitor(), eval_cb, fps_cb, ckpoint_cb],
dataset_sink_mode=True,
sink_size=steps_per_epoch_train)
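After training finishes, the checkpoints saved under ./checkpoint/ can be reloaded for inference or fine-tuning. A minimal sketch (the checkpoint file name below is illustrative; ModelCheckpoint names files as prefix-epoch_step.ckpt):

from mindspore import load_checkpoint, load_param_into_net

# Rebuild the network and load the saved weights into it
net = resnet_50()
param_dict = load_checkpoint("./checkpoint/train_resnet_cifar10-200_390.ckpt")
load_param_into_net(net, param_dict)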
3. Running
Launch command:
python MindSpore_1P.py --epoch_num=xxx --data_path=xxx
Running the script in a terminal shows the network's output.
The related code can be downloaded from the attachment:
Resnet50_cifar10.rar