<一>、运行环境：PyCharm、conda、MindSpore
conda 的安装可参考 CSDN 博客文章《Windows|anaconda 安装 MindSpore 教程》（作者：白白白）。
<二>、代码
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.common import dtype as mstype
import os
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
import mindspore.nn as nn
from mindspore.common.initializer import Normal
#创建数据集
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the MNIST input pipeline used for training and evaluation.

    Args:
        data_path: Directory handed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch; an incomplete trailing
            batch is dropped.
        repeat_size: Count passed to ``repeat`` on the batched dataset.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The dataset after label casting, image transforms, shuffling,
        batching and repeating.
    """
    dataset = ds.MnistDataset(data_path)

    # Image transforms, listed in the order they are applied below.
    to_32x32 = CV.Resize((32, 32), interpolation=Inter.LINEAR)  # bilinear resize
    scale_pixels = CV.Rescale(1.0 / 255.0, 0.0)  # x * (1 / 255) + 0
    # x * (1 / 0.3081) - 0.1307 / 0.3081, i.e. (x - 0.1307) / 0.3081
    standardize = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
    to_chw = CV.HWC2CHW()  # height-width-channel -> channel-height-width

    # Labels are cast to int32.
    dataset = dataset.map(operations=C.TypeCast(mstype.int32),
                          input_columns="label",
                          num_parallel_workers=num_parallel_workers)

    # One map call per image transform, keeping the original sequence.
    for image_op in (to_32x32, scale_pixels, standardize, to_chw):
        dataset = dataset.map(operations=image_op,
                              input_columns="image",
                              num_parallel_workers=num_parallel_workers)

    # Shuffle with a 10000-sample buffer, then batch, then repeat.
    return (dataset
            .shuffle(buffer_size=10000)
            .batch(batch_size, drop_remainder=True)
            .repeat(repeat_size))
#定义网络模型
class LeNet5(nn.Cell):
    """LeNet-5 style convolutional network.

    Args:
        num_class (int): Output size of the last dense layer. Default: 10.
        num_channel (int): Input channels of the first convolution.
            Default: 1.
        include_top (bool): If True, the flatten layer and the three dense
            layers are created and applied. If False, ``construct`` returns
            the output of the second pooling step. Default: True.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> LeNet5(num_class=10)
    """

    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super().__init__()
        # Conv2d arguments: (in_channels, out_channels, kernel_size, pad_mode).
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # The activation and pooling cells are shared by both conv stages.
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            # Flatten keeps axis 0 (the batch axis) and collapses the rest.
            self.flatten = nn.Flatten()
            # 16 * 5 * 5 is the flattened size produced by a 32x32 input
            # after the two valid 5x5 convolutions and 2x2 poolings above.
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

    def construct(self, x):
        """Run the forward pass and return the network output."""
        # Two conv -> relu -> max-pool stages.
        features = self.max_pool2d(self.relu(self.conv1(x)))
        features = self.max_pool2d(self.relu(self.conv2(features)))
        if not self.include_top:
            return features
        # Classifier head: flatten, two dense+relu layers, final dense layer.
        hidden = self.flatten(features)
        hidden = self.relu(self.fc1(hidden))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
def train_net(model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train ``model`` on the dataset stored under ``<data_path>/train``.

    Args:
        model: Object whose ``train`` method is called.
        epoch_size: First positional argument of ``model.train``.
        data_path: Dataset root; the ``train`` sub-directory is used.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback placed first in the callback list.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    train_dir = os.path.join(data_path, "train")
    # The training pipeline uses a fixed batch size of 32.
    ds_train = create_dataset(train_dir, batch_size=32, repeat_size=repeat_size)
    # LossMonitor is constructed with 125 as its argument.
    callbacks = [ckpoint_cb, LossMonitor(125)]
    model.train(epoch_size, ds_train,
                callbacks=callbacks,
                dataset_sink_mode=sink_mode)
def test_net(network, model, data_path):
    """Evaluate ``model`` on ``<data_path>/test`` and print the result.

    Args:
        network: Not used by this function; kept in the signature.
        model: Object whose ``eval`` method is called.
        data_path: Dataset root; the ``test`` sub-directory is used.
    """
    eval_dir = os.path.join(data_path, "test")
    # Evaluation uses create_dataset's default arguments.
    ds_eval = create_dataset(eval_dir)
    metrics = model.eval(ds_eval, dataset_sink_mode=False)
    report = "{}".format(metrics)
    print(report)
if __name__ == '__main__':
    # Run configuration.
    data_root = r"E:\lenet\datasets\MNIST_Data"  # dataset root; adjust to the local setup
    num_epochs = 1
    repeat_count = 1

    # Network, loss function and optimizer.
    lenet = LeNet5()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    optimizer = nn.Momentum(lenet.trainable_params(), learning_rate=0.01, momentum=0.9)

    # Checkpointing: save every 125 steps and keep at most 15 checkpoint files.
    ckpt_config = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=ckpt_config)

    # Wrap everything in a Model that reports accuracy on evaluation.
    model = Model(lenet, loss_fn, optimizer, metrics={"Accuracy": Accuracy()})

    # Train first, then evaluate on the test split.
    train_net(
        model=model,
        epoch_size=num_epochs,
        data_path=data_root,
        repeat_size=repeat_count,
        ckpoint_cb=ckpt_callback,
        sink_mode=False,
    )
    test_net(lenet, model, data_root)
最终运行结果: