使用PyTorch Lightning Flash和FiftyOne快速建立计算机视觉模型

使用PyTorch Lightning Flash和FiftyOne快速建立计算机视觉模型_第1张图片

近年来,开源工具取得了重大进展,满足了许多与端到端平台服务相同的需求。

从模型体系结构开发到数据集管理,再到模型训练和部署,它们都非常有用。通过充分挖掘,你可以找到一个能够支持数据和模型生命周期大部分部分的开源工具。

工具之间的紧密集成是实现近乎无缝工作流的最佳方式。本文深入探讨了模型原型和训练框架PyTorch Lightning Flash与数据集可视化和模型分析工具FiftyOne之间的新集成。

Lightning Flash是一个建立在PyTorch Lighting之上的新框架,它提供了一系列任务,用于快速原型制作、基线、微调,以及通过深度学习解决商业和科学问题。

尽管Flash很容易学习,无论你拥有多少深度学习经验,你都可以使用Lightning和PyTorch修改现有任务,以找到适合你的抽象级别。为了进一步加快速度,Flash代码具有可扩展性,内置支持任何硬件上的分布式训练和推理。

Flash使训练你的第一个模型变得非常容易,但要继续改进它,你需要了解你的模型的性能以及如何改进它。

FiftyOne是由Voxel51开发的用于构建高质量数据集和计算机视觉模型的开源工具。它提供了用于优化数据集分析管道的构建块,允许你亲自操作数据,包括可视化复杂标签、评估模型、探索感兴趣的场景、识别故障模式、查找注释错误、管理训练数据集,等等。

使用Flash+FiftyOne,你可以加载以下所有计算机视觉任务的数据集、训练模型和分析结果:

  • 图像分类

  • 图像目标检测

  • 图像语义分割

  • 视频分类

  • 嵌入可视化

概述

Flash和FiftyOne之间的紧密集成允许你执行端到端的工作流程,即加载数据集、在数据集上训练模型以及可视化/分析其预测,所有这些都只需几个简单的代码块

将FiftyOne数据集加载到Flash中

虽然使用FiftyOne开发数据集一直很容易,但与PyTorch Lightning Flash的集成现在允许你将这些数据集直接加载到Flash和训练任务中,从而使其正常工作。

from flash.image import ImageClassificationData

import fiftyone as fo

train_dataset = fo.Dataset.from_dir(
    "/path/to/train",
    fo.types.ImageClassificationDirectoryTree,
    label_field="ground_truth",
)

val_dataset = fo.Dataset.from_dir(
    "/path/to/val",
    fo.types.ImageClassificationDirectoryTree,
    label_field="ground_truth",
)

datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    label_field="ground_truth",
)

训练Flash任务

Flash提供了为你的任务获取模型所需的工具,并开始在你的数据上对其进行微调,尽可能少的代码,最重要的是,无需成为该领域的专家。

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1.下载数据
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2.加载数据
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3.构建模型
model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")

# 4.创建一个Trainer
trainer = flash.Trainer(max_epochs=1)

# 5.微调模型
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6.保存
trainer.save_checkpoint("image_classification_model.pt")

用FiftyOne可视化Flash预测

由于现代数据集的复杂性和规模,图像和视频数据的可视化一直是一个挑战。

FiftyOne旨在为你的数据集和标签(包括注释和模型预测)提供一个用户友好的视图,现在Flash模型只需另外一行代码即可访问该视图。

from flash import Trainer
from flash.core.classification import FiftyOneLabels
from flash.core.integrations.fiftyone import visualize
from flash.video import VideoClassificationData, VideoClassifier

classifier = VideoClassifier.load_from_checkpoint(...)

# 选项1:使用Trainer和数据模块生成预测
datamodule = VideoClassificationData.from_folders(
    predict_folder="/path/to/folder",
    ...
)
trainer = Trainer()
classifier.serializer = FiftyOneLabels(return_filepath=True)
predictions = trainer.predict(classifier, datamodule=datamodule)

session = visualize(predictions) # 启动FiftyOne

# 选项2:使用文件路径从模型生成预测
filepaths = ["list", "of", "filepaths"]
predictions = classifier.predict(filepaths)
classifier.serializer = FiftyOneLabels()

session = visualize(predictions, filepaths=filepaths) # 启动FiftyOne

示例工作流

安装程序

为了遵循本文中的示例,你需要安装相关的软件包。首先,你需要安装PyTorch Lightning Flash和FiftyOne。

pip install fiftyone lightning-flash

对于embeddings可视化工作流,你还需要安装降维软件包umap learn:

pip install umap-learn

一般工作流程

使用这些工具的大多数模型开发工作流遵循相同的一般结构:

  1. 将数据集加载到FiftyOne中

  2. 从数据集创建Flash数据模块

  3. 微调任务

  4. 从模型生成预测

  5. 将预测添加回数据集并将其可视化

图像目标检测

本节展示了使用PyTorch Lightning Flash和FiftyOne之间的集成来训练和评估图像对象检测模型的具体示例。

from itertools import chain

import fiftyone as fo
import fiftyone.zoo as foz

from flash import Trainer
from flash.image import ObjectDetectionData, ObjectDetector
from flash.image.detection.serialization import FiftyOneDetectionLabels

# 1.加载你的FiftyOne数据集
# 这里我们使用视图到一个数据集,但你也可以为每个分割创建不同的数据集
dataset = foz.load_zoo_dataset("quickstart", max_samples=40)
train_dataset = dataset.shuffle(seed=51)[:20]
test_dataset = dataset.shuffle(seed=51)[20:25]
val_dataset = dataset.shuffle(seed=51)[25:30]
predict_dataset = dataset.shuffle(seed=51)[30:40]

# 2. 加载Datamodule
datamodule = ObjectDetectionData.from_fiftyone(
    train_dataset = train_dataset,
    test_dataset = test_dataset,
    val_dataset = val_dataset,
    predict_dataset = predict_dataset,
    label_field = "ground_truth",
    batch_size=4,
    num_workers=4,
)

# 3. 创建模型
model = ObjectDetector(
    model="retinanet",
    num_classes=datamodule.num_classes,
    serializer=FiftyOneDetectionLabels(),
)

# 4. 创建trainer
trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)

# 5. 微调model
trainer.finetune(model, datamodule=datamodule)

# 6. 保存
trainer.save_checkpoint("object_detection_model.pt")

# 7. 生成预测
model = ObjectDetector.load_from_checkpoint(
  "https://flash-weights.s3.amazonaws.com/object_detection_model.pt"
)
model.serializer = FiftyOneDetectionLabels()

predictions = trainer.predict(model, datamodule=datamodule)

predictions = list(chain.from_iterable(predictions)) # 扁平化

# 8. 将预测添加到数据集和分析
predict_dataset.set_values("flash_predictions", predictions)
session = fo.launch_app(view=predict_dataset)

使用PyTorch Lightning Flash和FiftyOne快速建立计算机视觉模型_第2张图片

从这里开始,你现在可以将预测返回到数据集中,并可以运行评估以生成混淆矩阵、PR曲线以及准确率和mAP等指标。

特别是,你能够识别和查看单个真/假阳性/阴性结果,从而了解模型在哪些方面表现良好,在哪些方面表现不佳。基于常见故障模式改进模型是开发更好模型的更可靠的方法。

嵌入可视化

此工作流的独特之处在于它采用预训练模型,并使用它们为数据集中的每个图像生成嵌入向量。

然后,你可以在低维空间中计算这些嵌入的可视化,以找到数据簇。此功能可以为硬样本挖掘、数据预注释、注释样本推荐等带来宝贵的发现。

import numpy as np
import torch

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

import fiftyone as fo
import fiftyone.brain as fob

# 1 下载数据
download_data(
    "https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip"
)

# 2 加载数据到FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 载入模型
embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128)

# 4 生成嵌入
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 可视化
results = fob.compute_visualization(dataset, embeddings=embeddings)

session = fo.launch_app(dataset)

plot = results.visualize(labels="ground_truth.label")
plot.show()

使用PyTorch Lightning Flash和FiftyOne快速建立计算机视觉模型_第3张图片

使用交互式绘图的概念,你可以单击或圈出这些嵌入的区域,并自动更新会话以查看和标记相应的样本。

附加任务

你可以在此处查看其他任务(如分类)的类似工作流:https://voxel51.com/docs/fiftyone/integrations/lightning_flash.html#model-training。

总结

多年来,开源社区取得了令人瞩目的发展,特别是在机器学习领域。

虽然单个工具可以很好地解决特定问题,但正是工具之间的紧密集成导致了强大的工作流。PyTorch Lightning Flash和FiftyOne之间的新集成为开发数据集、训练模型和分析结果提供了一种新的简单方法。

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 woshicver」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

使用PyTorch Lightning Flash和FiftyOne快速建立计算机视觉模型_第4张图片

你可能感兴趣的:(可视化,大数据,python,机器学习,人工智能)