PaddleX快速实现图像分类训练

飞桨 -PaddleX 是一套更加简明易懂的API,并配套一键下载安装的图形化开发客户端。用PaddleX实现图像分类训练非常快速,代码量也小。
第一步:安装paddlex, 参考《在windows10下安装飞桨2.0.2和PaddleX》
第二步:下载并解压蔬菜分类数据集,用迅雷直接下载

https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gzicon-default.png?t=M276https://links.jianshu.com/go?to=https%3A%2F%2Fbj.bcebos.com%2Fpaddlex%2Fdatasets%2Fvegetables_cls.tar.gz或者用命令:

wget https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz
tar xzvf vegetables_cls.tar.gz

第三步:运行train.py程序,源代码如下所示,训练模型

from paddlex.cls import transforms
import paddlex as pdx 

train_transforms = transforms.Compose([
    transforms.RandomCrop(crop_size=224),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize()
])
eval_transforms = transforms.Compose([
    transforms.ResizeByShort(short_size=256),
    transforms.CenterCrop(crop_size=224),
    transforms.Normalize()
])

train_dataset = pdx.datasets.ImageNet(
    data_dir='vegetables_cls',
    file_list='vegetables_cls/train_list.txt',
    label_list='vegetables_cls/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.ImageNet(
    data_dir='vegetables_cls',
    file_list='vegetables_cls/val_list.txt',
    label_list='vegetables_cls/labels.txt',
    transforms=eval_transforms)

num_classes = len(train_dataset.labels)
model = pdx.cls.MobileNetV3_small_ssld(num_classes=num_classes)

model.train(num_epochs=20,
            train_dataset=train_dataset,
            train_batch_size=32,
            eval_dataset=eval_dataset,
            lr_decay_epochs=[4, 6, 8],
            save_dir='output/mobilenetv3_small_ssld',
            use_vdl=True)

 训练结果如下所示:

PaddleX快速实现图像分类训练_第1张图片

 第四步:运行infer.py程序,源代码如下所示,执行推理计算,获得推理结果

import paddlex as pdx
model = pdx.load_model('output/mobilenetv3_small_ssld/best_model')
result = model.predict('vegetables_cls/bocai/100.jpg')
print("Predict Result: ", result)

 PaddleX快速实现图像分类训练_第2张图片

 

你可能感兴趣的:(paddle,paddlepaddle,paddlex)