CreateML使用简介

苹果在去年推出了CoreML机器学习模型,今年在XCode10中提供的CreateML framework,可以创建CoreML模型。

使用CreateML创建CoreML模型时,仅需编写少量的代码。

准备工作

1、XCode10(目前是beta版本)

2、MacOS Mojave(目前也是beta版本)

3、训练数据:在同一个目录下,以文件夹作为分类,各个文件夹下存放对应分类的图片

4、测试数据:和训练数据一样,并且文件夹分类的名称要和训练数据的名称一致

说明:

1、训练数据可以自己准备,也可以从网上找一些,例如:Kaggle Cats and Dogs Dataset(本文是以Pets-100目录下的图片进行的训练)

2、训练数据数量越大,训练的模型越准确,训练的时间也就越长

创建图像分类CoreML模型

1、运行XCode10,创建一个空的playground工程,清除所有代码,然后将下面的代码拷贝在playground中

import CreateMLUI

let builder = MLImageClassifierBuilder()

builder.showInLiveView()

2、切换显示XCode的assistant editor,再点击运行

MLImageClassifierBuilder的Live view

3、此时,XCode的assistant editor中,会显示MLImageClassifierBuilder的live view,将训练数据的目录拖拽进来,XCode便开始训练CoreML模型了

拖拽数据进行训练

4、将训练后的模型,保存到文件

保存ML模型
保存的ML模型

5、应用创建的模型进行预测:将想要预测的图片(或目录)拖拽到模型上,进行预测。例如,将Pets-1000目录拖拽到Live view上,预测的准确率如下

说明:除了在Live view中进行预测外,也可以将保存后的模型导入到app中使用。参见Classifying Images with Vision and Core ML

应用模型进行预测

创建文本分类模型

创建文本分类ML模型,可以使用MLDataTable和MLTextClassifier类。步骤如下:

1、创建一个MLDataTable对象,读取训练数据(可以是JSON或CSV格式、或者Dictionary)

2、创建一个MLTextClassifier对象,使用MLDataTable对象中的数据进行训练

3、通过MLTextClassifier对象的write(to:metadata:)方法,将模型保存到磁盘

csv文件格式示例:

title,author,pageCount,genre

Alice in Wonderland,Lewis Carroll,124,Fantasy

Hamlet,William Shakespeare,98,Drama

Treasure Island,Robert L. Stevenson,280,Adventure

Peter Pan,J. M. Barrie,94,Fantasy

JSON文件格式示例:

[

{

"title": "Alice in Wonderland",

"author": "Lewis Carroll",

"pageCount": 124,

"genre": "Fantasy"

},

{

"title": "Hamlet",

"author": "William Shakespeare",

"pageCount": 98,

"genre": "Drama"

}, ...

]

//Dictionary数据示例

let data: [String: MLDataValueConvertible] = [

    "title": ["Alice in Wonderland", "Hamlet", "Treasure Island", "Peter Pan"],

    "author": ["Lewis Carroll", "William Shakespeare", "Robert L. Stevenson", "J. M. Barrie"],

    "pageCount": [124, 98, 280, 94],

    "genre": ["Fantasy", "Drama", "Adventure", "Fantasy"]

]

let bookTable = try MLDataTable(dictionary: data)

示例代码

在XCode创建一个空的playground工程,在资源中添加训练使用的数据spam-sms.csv,然后将下面的代码粘贴到工程中

import Foundation

import CreateML

//获取csv文件路径

guard let trainingCSV = Bundle.main.url(forResource: "spam-sms", withExtension: "csv") else {

    fatalError()

}

//将csv文件内容加载到MLDataTable中

var spamData = try MLDataTable(contentsOf: trainingCSV)

let (trainingData, testData) = spamData.randomSplit(by: 0.8, seed: 0)

//创建文本分类器,进行训练

//message和label分别对应csv文件中的短信内容列、短信标签列

let predictor = try MLTextClassifier(trainingData: trainingData, textColumn: "message", labelColumn: "label")

//在测试数据集上验证

let metrics = predictor.evaluation(on: testData)

工程示例:创建文本分类ML模型

说明:

使用400条中文短信内容的csv,训练模型时,内存占用十分严重,超过Mac系统的物理内存,训练卡在解析短信的步骤,未能训练出模型。

使用英文短信内容进行训练时,没有内存问题,可以训练出模型。

其它

MLClassifier是一个通用的分类模型,MLRegressor是一个回归模型,给定训练模型(MLDataTable)中的特征列和结果列后,就可以对这两种模型进行训练。

缺点

模型训练好后,如果增加了数据集,必须重新开始训练,即无法在训练好的模型上应用新的数据进行训练。

模型优化

提高训练数据集上的准确率(Training Accuracy)

对于MLImageClassifierBuilder,可以将训练的迭代次数调整成20次

对于自然语言的分类器,可以尝试不同的算法(MLTextClassifier.ModelAlgorithmType)

对于MLClassifier和MLRegressor,则可以尝试选用不同的模型进行训练

提高验证数据集上的准确率(Validation Accuracy)

对于拟合不足的问题,可以通过增加训练数据集来进行优化。例如,对于图像分类器,可以在训练时勾选Augmentation(增加)选项:

Augmentation选项

对于过拟合的问题,则可以尝试减少迭代次数进行优化。

提高测试数据集上的准确率(Evaluation Accuracy)

如果训练数据集、验证数据集上的准确率,高于测试数据集上的准确率,原因通常是训练数据和测试数据存在比较明显的差异导致,这种情况下,可以尝试在训练数据集中使用更多的不同的数据。

你可能感兴趣的:(CreateML使用简介)