Apple
在WWDC2018上演示了CreateML的使用, 主要包括了图像识别, 文本分类等ML
应用.
CreateML
是什么?
Apple
在iOS11开始支持 CoreML
, 通过 mlModel
这种类型的文件, 编写 Swift 代码, 可以实现 ML 的各种功能, 比如图像识别, 文本分类, 语言翻译等. CreateML
就是一种便捷的框架, 用来快速训练出模型的框架.
如何使用 CreateML
?
Mac环境: Xcode10, macOS Mojave, Swift
App运行环境: iOS12
制作一个数字的图片分类器
1. 准备数据:
数据: mnist数字集(训练集60000, 测试集10000).
2. 训练模型:
-
- 新建一个
macOS
的playgroud, 为什么一定要macOS
下的, 因为CreateML
的使用需要调用系统的支持, 比如GPU训练, 这也就是为什么 Apple 推出了Mojave
的原因 .
- 新建一个
-
- 编码, 点击运行
Max iterations: 迭代次数, 默认是10, 可以通过增加epoch, 来提高识别准确率, 但是epoch越多, 可能造成过拟合. 梯度不降反升等问题.
Augmentation: 增强, 下面的几个选项, 勾选的话, 一张图片会经过多次处理, 用于数据不足的情况.
-
- 将我们准备好的训练集拖到虚线框内, 或者点击
Training data
旁的Choose
按钮 选择目前文件.
- 将我们准备好的训练集拖到虚线框内, 或者点击
我尝试过将训练集60000张图片, 训练完大概需要160分钟, 中间还中断过几次, 所以后来采用10000张图片, 训练了30分钟左右, 最终采用1000张, 大概3分钟不到,
3. 评估表现
经过我的多次测试, 发现模型在迭代次数为 30 时, 表现是最好的.(没有进行数据增强, 太耗时间了)
Epoch | Training | Validation | Evaluation |
---|---|---|---|
10 | 85% | 77% | 78% |
15 | 91% | 89% | 84% |
20 | 95% | 92% | 84% |
25 | 99% | 87% | 85% |
30 | 100% | 91% | 85% |
35 | 100% | 92% | 84% |
4. 保存模型
模型大小148KB, 挺小的.
5. 使用模型
核心代码
private func predictByModel(src: UIImage, model: MNIST_LC) -> [String: Double] {
// 获取图片数据
guard let buffer = Help_Utils.pixelBufferFromImage(inputImage: src) else {
fatalError("获取pixelBuffer失败")
}
// 进行预测
guard let result = try? model.prediction(image: buffer) else {
fatalError("Prediction failed!")
}
// 显示结果
return result.classLabelProbs
}
在这里提一句, 像这种Model, 如果做的多了会发现, 同种类型的Model, 他们预测的代码基本相同, 如果遇到没见过的 Model, 直接跟到Model里面, 去看可以调用的方法也是可以的.
如果运行报错, 类似下面这种情况, 那么你需要升级你的手机系统为 iOS12.
运行结果:
我特意选了几个比较刁钻的图片, 结果没有识别出来 :].
小结:
- 机器学习的一般过程: 准备数据 -> 训练模型 -> 评估表现 -> 保存模型 -> 使用模型
- 训练数据和测试数据都是我们已经打好标签的数据, 也就是我们事先已经把不同类型的数字图片分成不同的文件夹. 尽量保证每种分类文件下的文件数量一样.
- 上面介绍的那种方式是通过CreateML UI, 也就是 MLImageClassifierBuilder 来实现的, 如果要使用代码来做的话, 则必须调用 MLImageClassifier, 示例在文末
- 用 CreateML 来实现图片分类, 目前来说还不太成熟, 仅靠 Mac 这一台机器, 算力非常有限, 训练1000张图片在 2m 30s 左右, 训练60000张, 就要花3个小时左右.
CreateML
的应用目前除了 图像识别, 还有自然语言处理, 回归分析, 下篇文章再讲.
借助 MLImageClassifier
代码实现 图片分类
// 使用代码训练模型
// 数据预处理
let rootDir = "/Users/linchuan/Downloads/四种格式的mnist/mnist.zip/"
let trainDir = URL(fileURLWithPath: rootDir).appendingPathComponent("train_lc_1000")
let testDir = URL(fileURLWithPath: rootDir).appendingPathComponent("test_lc_100")
/**
训练参数:
featureExtractor: 特征提取器
validationData: 验证集, 为nil, 则会从训练集中取 5% 来补充
maxIterations: 迭代次数
augmentationOptions: 数据增强, 包括图片翻转, 模糊, 修剪等
*/
let parameters = MLImageClassifier.ModelParameters(featureExtractor: .scenePrint(revision: 1),
validationData: nil,
maxIterations: 30,
augmentationOptions: [])
// 创建模型
let classifier = try MLImageClassifier(trainingData: .labeledDirectories(at: trainDir),
parameters: parameters)
// 训练模型
//...
// 测试模型
let evaluation = classifier.evaluation(on: .labeledDirectories(at: testDir))
// 打印测试结果
print(evaluation.precisionRecall.columnNames)
// 保存模型
let saveDir = URL(fileURLWithPath: rootDir).appendingPathComponent("MNISTClassifier.mlmodel")
let modelMetadata = MLModelMetadata(author: "LC",
shortDescription: "image classifier" ,
license: nil,
version: "1.0",
additional: nil)
try classifier.write(to: saveDir, metadata: modelMetadata)