苹果在去年推出了CoreML机器学习模型,今年在XCode10中提供的CreateML framework,可以创建CoreML模型。
使用CreateML创建CoreML模型时,仅需编写少量的代码。
准备工作
1、XCode10(目前是beta版本)
2、MacOS Mojave(目前也是beta版本)
3、训练数据:在同一个目录下,以文件夹作为分类,各个文件夹下存放对应分类的图片
4、测试数据:和训练数据一样,并且文件夹分类的名称要和训练数据的名称一致
说明:
1、训练数据可以自己准备,也可以从网上找一些,例如:Kaggle Cats and Dogs Dataset(本文是以Pets-100目录下的图片进行的训练)
2、训练数据数量越大,训练的模型越准确,训练的时间也就越长
创建图像分类CoreML模型
1、运行XCode10,创建一个空的playground工程,清除所有代码,然后将下面的代码拷贝在playground中
import CreateMLUI
let builder = MLImageClassifierBuilder()
builder.showInLiveView()
2、切换显示XCode的assistant editor,再点击运行
3、此时,XCode的assistant editor中,会显示MLImageClassifierBuilder的live view,将训练数据的目录拖拽进来,XCode便开始训练CoreML模型了
4、将训练后的模型,保存到文件
5、应用创建的模型进行预测:将想要预测的图片(或目录)拖拽到模型上,进行预测。例如,将Pets-1000目录拖拽到Live view上,预测的准确率如下
说明:除了在Live view中进行预测外,也可以将保存后的模型导入到app中使用。参见Classifying Images with Vision and Core ML
创建文本分类模型
创建文本分类ML模型,可以使用MLDataTable和MLTextClassifier类。步骤如下:
1、创建一个MLDataTable对象,读取训练数据(可以是JSON或CSV格式、或者Dictionary)
2、创建一个MLTextClassifier对象,使用MLDataTable对象中的数据进行训练
3、通过MLTextClassifier对象的write(to:metadata:)方法,将模型保存到磁盘
csv文件格式示例:
title,author,pageCount,genre
Alice in Wonderland,Lewis Carroll,124,Fantasy
Hamlet,William Shakespeare,98,Drama
Treasure Island,Robert L. Stevenson,280,Adventure
Peter Pan,J. M. Barrie,94,Fantasy
JSON文件格式示例:
[
{
"title": "Alice in Wonderland",
"author": "Lewis Carroll",
"pageCount": 124,
"genre": "Fantasy"
},
{
"title": "Hamlet",
"author": "William Shakespeare",
"pageCount": 98,
"genre": "Drama"
}, ...
]
//Dictionary数据示例
let data: [String: MLDataValueConvertible] = [
"title": ["Alice in Wonderland", "Hamlet", "Treasure Island", "Peter Pan"],
"author": ["Lewis Carroll", "William Shakespeare", "Robert L. Stevenson", "J. M. Barrie"],
"pageCount": [124, 98, 280, 94],
"genre": ["Fantasy", "Drama", "Adventure", "Fantasy"]
]
let bookTable = try MLDataTable(dictionary: data)
示例代码
在XCode创建一个空的playground工程,在资源中添加训练使用的数据spam-sms.csv,然后将下面的代码粘贴到工程中
import Foundation
import CreateML
//获取csv文件路径
guard let trainingCSV = Bundle.main.url(forResource: "spam-sms", withExtension: "csv") else {
fatalError()
}
//将csv文件内容加载到MLDataTable中
var spamData = try MLDataTable(contentsOf: trainingCSV)
let (trainingData, testData) = spamData.randomSplit(by: 0.8, seed: 0)
//创建文本分类器,进行训练
//message和label分别对应csv文件中的短信内容列、短信标签列
let predictor = try MLTextClassifier(trainingData: trainingData, textColumn: "message", labelColumn: "label")
//在测试数据集上验证
let metrics = predictor.evaluation(on: testData)
说明:
使用400条中文短信内容的csv,训练模型时,内存占用十分严重,超过Mac系统的物理内存,训练卡在解析短信的步骤,未能训练出模型。
使用英文短信内容进行训练时,没有内存问题,可以训练出模型。
其它
MLClassifier是一个通用的分类模型,MLRegressor是一个回归模型,给定训练模型(MLDataTable)中的特征列和结果列后,就可以对这两种模型进行训练。
缺点
模型训练好后,如果增加了数据集,必须重新开始训练,即无法在训练好的模型上应用新的数据进行训练。
模型优化
提高训练数据集上的准确率(Training Accuracy)
对于MLImageClassifierBuilder,可以将训练的迭代次数调整成20次
对于自然语言的分类器,可以尝试不同的算法(MLTextClassifier.ModelAlgorithmType)
对于MLClassifier和MLRegressor,则可以尝试选用不同的模型进行训练
提高验证数据集上的准确率(Validation Accuracy)
对于拟合不足的问题,可以通过增加训练数据集来进行优化。例如,对于图像分类器,可以在训练时勾选Augmentation(增加)选项:
对于过拟合的问题,则可以尝试减少迭代次数进行优化。
提高测试数据集上的准确率(Evaluation Accuracy)
如果训练数据集、验证数据集上的准确率,高于测试数据集上的准确率,原因通常是训练数据和测试数据存在比较明显的差异导致,这种情况下,可以尝试在训练数据集中使用更多的不同的数据。