MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例

本案例解释了如何在Pytorch中使用MLFlow,在 LEGO Minifigures 数据集中的案例,图像分类算法选用MobileNet V2。

  • 解释mlflow.pytorch的具体使用方式;
  • 解释pl.LightningModulepl.LightningDataModule的具体使用方式;
  • 解释mlflow run以及mlflow ui的具体使用方式;
  • 解释如何使用 MobileNetV2 进行迁移学习;
  • 通过 LEGO Minifigures 的案例解释如何在pyTorch中使用mlFlow,以及结果。

运行环境:

  • 平台:Win10。
  • IDE:Visual Studio Code
  • 需要预装:Anaconda3
  • MLFlow当前版本:1.25.1
  • 代码

文章目录

  • MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例
    • 1 关于 MLFlow
    • 2 关于如何在PyTorch中使用MLFlow
    • 3 关于代码的运行
      • 3.1 第一种运行代码方式:本地创建虚拟环境运行
      • 3.2 第二种运行代码方式:mlflow run 指令运行
    • 4 概念解释
      • 4.1 mlflow.pytorch.autolog
      • 4.2 pl.LightningModule
      • 4.3 pl.LightningDataModule
      • 4.4 MLproject 以及 conda.yaml 文件
    • 5 LEGO Minifugures 的图像分类案例
      • 5.1 关于 LEGO Minifigures 数据集
      • 5.2 数据读取与处理
        • 5.2.1 LEGOMinifiguresDataModule
        • 5.2.2 数据预处理
        • 5.2.3 数据增强(仅对于训练集)
      • 5.3 模型训练/验证/测试/优化类
        • 5.3.1 模型导入
        • 5.3.2 模型训练
        • 5.3.3 模型验证
        • 5.3.4 模型测试
        • 5.3.5 模型优化
      • 5.4 迁移学习
        • 5.4.1 实例化数据读取处理模块以及训练/验证/测试/优化类
        • 5.4.2 迁移学习的训练,验证,以及测试
    • 6 结果


1 关于 MLFlow

MLFlow是一个能够覆盖机器学习全流程(从数据准备到模型训练到最终部署)的新平台。它一共有四大模块(如下为官网的原文以及翻译):

  • MLflow Tracking:如何通过API的形式管理实验的参数、代码、结果,并且通过UI的形式做对比。
  • MLflow Projects:以可重用、可复制的形式打包ML代码,以便与其他数据科学家共享或部署到生产环境(MLflow项目)。
  • MLflow Models:管理和部署从各种ML库到各种模型服务和推理平台(MLflow模型)的模型。
  • MLflow Model Registry:提供一个中央模型存储,以协同管理MLflow模型的整个生命周期,包括模型版本控制、阶段转换和注释(MLflow模型注册表)。

在这个系列的前半部分,我们对MLFlow做了详细的介绍,以及每一个模块的案例讲解,这里不再赘述。

2 关于如何在PyTorch中使用MLFlow

mlflow.pytorch模块提供了一个用于记录和加载 PyTorch 模型的 API。

需要注意的是,MLFlow 无法直接和PyTorch一起使用,我们需要先装一下pytorch_lightning,当我们调用 pytorch_lightning.Trainer()fit 方法时会执行自动记录(mlflow.pytorch.autolog)。

3 关于代码的运行

有两种种运行代码的方式,这里我们也会一一列举。如果我们是初学者,建议先尝试第一种方式。

3.1 第一种运行代码方式:本地创建虚拟环境运行

首先,我们在Windows的平台下安装Anaconda3。具体的安装步骤此处略过,参见Anaconda的官方文档。

安装完后,新建虚拟环境。在VSCode,使用conda create -n your_env_name python=X.X(2.7、3.6等)命令创建python版本为X.X、名字为your_env_name的虚拟环境。

这里我们输入conda create -n mlFlowEx python=3.8.2

安装完默认的依赖后,我们进入虚拟环境:conda activate mlFlowEx。注意,如果需要退出,则输入conda deactivate。另外,如果Terminal没有成功切换到虚拟环境,可以尝试conda init powershell,然后重启terminal。

然后,我们在虚拟环境中下载好相关依赖:pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

这个案例的依赖包括:

mlflow==1.25.1
torchvision>=0.9.1
torch>=1.9.0
pytorch-lightning==1.6.1

关于mlflow的版本,1.23.1应该也是可以运行的。

我们将代码下载到本地:git clone https://gitee.com/yichaoyyds/mlflow-ex-pytorch.git。进入文件夹MNIST_pytorch后,我们看到几个.py文件,输入python LegoCharacterRecognition.py运行代码。

3.2 第二种运行代码方式:mlflow run 指令运行

这里我们不需要在本地新建一个虚拟环境(mlflow run 指令会自动新建),我们将代码下载到本地,进入文件夹LegoCharacterRecognition后,我们在terminal运行mlflow run .。注意,最后有一个.不要忘记,这个.的意思是当前文件夹路径。在之前的系列文章中,我们解释过,运行mlflow run之后,系统会去寻找指定文件夹下的MLproject文件。此文件包含了我们需要运行的脚本指令。

4 概念解释

在解释详细代码之前,我们有四个重要的概念(函数)需要解释。

4.1 mlflow.pytorch.autolog

mlflow.pytorch的其他函数我们都可以先不看,只要在train代码之前加上这行,mlflow就可以自动开始运行,包括保存artifactsmetricsparamstag值。有很多值是自动生成的。在结果这个章节中,我们会详细解释。这里我们先解释mlflow.pytorch.autolog函数。

一般在代码中,我们直接使用mlflow.pytorch.autolog()即可,因为函数中的参数都有默认值,我们一般使用默认值就可以。这里我们来详细过一遍其中重要的参数。完整的函数定义如下:

mlflow.pytorch.autolog(log_every_n_epoch=1, log_every_n_step=None, log_models=True, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, registered_model_name=None)
  • log_every_n_epoch: 如果指定,则每 n 个 epoch 记录一次 metric 值。 默认情况下,每个 epoch 后都会记录 metric 值;
  • log_models:如果为 True,则经过训练的模型将记录在 MLflow artifacts 路径下。 如果为 False,则不记录经过训练的模型。注意,这里只会记录一个model,应该是性能最好的那个model。在Microsoft Azure MLOps中,它会记录每一个epoch的model。所以从性能角度上,Azure MLOps工具确实强大,但我们可以根据实际需要进行选择,毕竟Azure MLOps是付费的,而MLFlow是免费的。如果对 Azure MLOps 感兴趣,也可以翻看这个系列的其他文章。
  • disable:如果为 True,则禁用 PyTorch Lightning 自动日志记录集成功能。 如果为 False,则启用。当PyTorch Lightning 在进行模型训练(进行初始化)的时候,Lightning 在后台使用 TensorBoard 记录器,并将日志存储到目录中(默认情况下在 Lightning_logs/ 中)。相关链接。我们可以将这个默认日志功能关闭。这里我们建议开着自动日志功能,因为如果关闭,MLFlow的artifactsmetricsparamstag默认保存的参数就无法保存;

4.2 pl.LightningModule

对于 PyTorch Lightning,有两个函数是至关重要,一个是pl.LightningModule,一个是pl.LightningDataModule。前者的包含了训练/验证/预测/优化的所有模块,后者则是数据集读取模块。我们通过PyTorch Lightning进行模型训练的时候,通常会继承这两个类。目前我对 PyTorch Lightning 不是很了解,所以这里我作为一个初学者的角度,针对这个案例进行一些相关的解读。

关于pl.LightningModule,和我们这个案例相关的函数包括:

  • forward,作用和torch.nn.Module.forward()一样,这里我们不再赘述;
  • training_step,我们计算并返回训练损失和一些额外的metrics。
  • validation_step,我们计算并返回验证损失和一些额外的metrics。
  • test_step,我们计算并返回测试损失和一些额外的metrics。
  • validation_epoch_end,在验证epoch结束后,计算这个epoch的平均验证accuracy。
  • test_epoch_end,在测试epoch结束后,计算计算这个epoch的平均测试accuracy。
  • configure_optimizers,选择要在优化中使用的优化器和学习率调度器。

此网页有详细的描述,这里不再赘述。

4.3 pl.LightningDataModule

pl.LightningDataModule 标准化了训练、验证、测试集的拆分、数据准备和转换。主要优点是一致的数据拆分、数据准备和跨模型转换,一个例子如下:

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
    def train_dataloader(self):
        train_split = Dataset(...)
        return DataLoader(train_split)
    def val_dataloader(self):
        val_split = Dataset(...)
        return DataLoader(val_split)
    def test_dataloader(self):
        test_split = Dataset(...)
        return DataLoader(test_split)
    def teardown(self):
        # clean up after fit or test
        # called on every process in DDP

4.4 MLproject 以及 conda.yaml 文件

如果我们要使用mlflow run指令,那么我们就需要明白MLproject以及conda.yaml文件的作用。

MLproject文件:

name: lego-minifigure-classification

conda_env: conda.yaml

entry_points:
  main:
    command: |
          python LegoCharacterRecognition.py

当我们执行mlflow run .的时候,相当于系统读取当前文件夹下MLproject中的指令。

conda.yaml文件

channels:
- conda-forge
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2
dependencies:
- python=3.8.2
- pip
- pip:
  - mlflow
  - pytorch-lightning==1.6.1
  - ax-platform
  - torchvision>=0.9.1
  - torch>=1.9.0
  - albumentations==1.1.0
  - matplotlib==3.5.2
  - seaborn==0.11.2
  - -i https://pypi.tuna.tsinghua.edu.cn/simple

包含了这个项目的依赖项。当我们第一次运行mlflow run .的时候,系统会自动新建一个虚拟环境,安装对应的依赖(对应conda.yaml文件),最后运行对应的代码(对应MLproject文件)。

5 LEGO Minifugures 的图像分类案例

5.1 关于 LEGO Minifigures 数据集

该数据集包含各种乐高人仔的图片。 数据集中的每个人仔都有几张不同姿势和不同环境的图像。

目前,它包含来自乐高套装的 28 个人物(总共 300 多张图像):Yoda’s Hut, Spider Mech vs. Venom, General Grievous’ Combat Speeder, Kylo Ren’s Shuttle™ Microfighter, AT-ST™ Raider from The Mandalorian, Molten Man Battle, Aragog’s Lair, Black Widow’s Helicopter Chase, Captain America: Outriders Attack, Pteranodon Chase, Iron Man Hall of Armor。

该数据集的GITHUB链接。Kaggle中也有其数据集可以直接下载,Kaggle链接。

这里我们已经下载了这个数据集,在archive文件夹下。需要注意:

  • index.csv:包含了训练集/验证集(一共361个数据)图片的位置,以及对应的标签(所以这个csv文件有2列);
  • test.csv:包含了测试集(一共76个数据)图片位置,以及对应的标签;
  • metadata.csv:包含了这个数据集的37个标签。

5.2 数据读取与处理

首先,我们需要读取数据,拆分成训练集、验证集、测试集,并且进行一些数据预处理,必须数据增强等。此模块的代码位于LEGOMinifiguresDataRetriever.py

5.2.1 LEGOMinifiguresDataModule

pl.LightningDataModule 标准化了训练、验证、测试集的拆分、数据准备和转换。主要优点是一致的数据拆分、数据准备和跨模型转换。相关代码:

class LEGOMinifiguresDataModule(pl.LightningDataModule):
    
    def __init__(
        self, 
        train_batch_size, 
        valid_batch_size, 
        test_batch_size, 
        image_size, 
        base_dir,
        train_augmentations=None
    ):
        """
        Initialization of inherited lightning data module
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.test_batch_size = test_batch_size
        self.image_size = image_size
        self.base_dir = base_dir
        self.train_augmentations=train_augmentations
        
    def setup(self, stage):
        """
        Load the data, parse it and split the data into train, test, validation data
        """
        # Load train dataset, test dataset
        self.df_train = pd.read_csv(os.path.join(self.base_dir, 'index.csv'))
        self.df_test = pd.read_csv(os.path.join(self.base_dir, 'test.csv'))
        # Split train dataset into train dataset and validation dataset
        X, y = df_train.path, df_train.class_id
        train_paths, valid_paths, y_train, y_valid = train_test_split(X, y, random_state=0)
        # Store train/validation/test dataset path and label
        self.train_targets = y_train - 1
        self.train_paths = list(map(lambda x: os.path.join(self.base_dir, x), train_paths))
        self.valid_targets = y_valid - 1
        self.valid_paths = list(map(lambda x: os.path.join(self.base_dir, x), valid_paths))
        tmp_test = self.df_test
        test_paths = tmp_test.path
        self.test_targets = tmp_test.class_id - 1
        self.test_paths = list(map(lambda x: os.path.join(self.base_dir, x), test_paths))
        
    def train_dataloader(self):
        """
        :return: output - Train data loader for the given input
        """
        train_data_retriever = DataRetriever(
            self.train_paths, 
            self.train_targets, 
            image_size=self.image_size,
            transforms=self.train_augmentations
        )
        train_loader = torch_data.DataLoader(
            train_data_retriever,
            batch_size=self.train_batch_size,
            shuffle=True,
        )
        return train_loader
    
    def val_dataloader(self):
        """
        :return: output - Validation data loader for the given input
        """
        valid_data_retriever = DataRetriever(
            self.valid_paths, 
            self.valid_targets, 
            image_size=self.image_size,
        )
        valid_loader = torch_data.DataLoader(
            valid_data_retriever, 
            batch_size=self.valid_batch_size,
            shuffle=True,
        )
        return valid_loader
    
    def test_dataloader(self):
        """
        :return: output - Test data loader for the given input
        """
        test_data_retriever = DataRetriever(
            self.test_paths, 
            self.test_targets, 
            image_size=self.image_size,
        )
        test_loader = torch_data.DataLoader(
            test_data_retriever, 
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_loader

LEGOMinifiguresDataModule类继承于pl.LightningDataModule,大体上可以说由两部分组成。第一部分是读取数据集,并且拆分成训练数据集/验证数据集/测试数据集(setup函数)。第二部分是对于这三个数据集的dataloader(train_dataloaderval_dataloadertest_dataloader函数)。

5.2.2 数据预处理

对于train_dataloaderval_dataloadertest_dataloader函数,我们需要先对各自输入的数据集进行一些预处理,见DataRetriever类:

class DataRetriever(torch_data.Dataset):
    def __init__(
        self, 
        paths, 
        targets, 
        image_size=(224, 224),
        transforms=None
    ):
        self.paths = list(paths)
        self.targets = list(targets)
        self.image_size = image_size
        self.transforms = transforms
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]
            ),
        ])
          
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self, index):
        img = cv2.imread(self.paths[index])
        # 我们需要调整图像大小到与MoblieNetV2的输入尺寸匹配
        img = cv2.resize(img, self.image_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.transforms:
            img = self.transforms(image=img)['image']
            
        img = self.preprocess(img)
        
        y = torch.tensor(self.targets[index], dtype=torch.long)
            
        return {'X': img, 'y': y}

我们先读取图像(cv2.imread(self.paths[index])),然后调整图像大小到与MoblieNetV2的输入尺寸匹配(cv2.resize(img, self.image_size)),将图像格式从BGR转换到RGB(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))。对于训练集,我们会对其进行数据增强(self.transforms(image=img)['image'])。然后,我们对于所有数据集需要进行预处理(self.preprocess(img)):

self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]
            ),
        ])

最后,__getitem__返回图像和标签:{'X': img, 'y': y}

5.2.3 数据增强(仅对于训练集)

和验证集/测试集略有不同,训练集的train_dataloader需要对于数据进行数据增强。

def get_train_transforms():
    return A.Compose(
        [
            A.Rotate(limit=30, border_mode=cv2.BORDER_REPLICATE, p=0.5),
            A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=0, p=0.25),
            A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=255, p=0.25),
            A.HorizontalFlip(p=0.5),
            A.RandomContrast(limit=(-0.3, 0.3), p=0.5),
            A.RandomBrightness(limit=(-0.4, 0.4), p=0.5),
            A.Blur(p=0.25),
        ], 
        p=1.0
    )

5.3 模型训练/验证/测试/优化类

pyTorch Lightning的一大优点就是将机器学习/深度学习流程化,标准化。与模型训练/验证/测试相关的代码都写在LitModel类中,继承于pl.LightningModule。此模块相关代码见LegoLitModel.py

5.3.1 模型导入

这里我们基于 MobileNetV2 的预训练模型,对 LEGO Minifigures 数据集进行迁移训练。所以,首先我们需要导入预训练模型,并且加上最后一层全连接层,组成一个完整的神经网络。代码位于LitModel类的__init__函数中:

def __init__(self, n_classes):
    super().__init__()
    self.net = torch.hub.load(
        'pytorch/vision:v0.6.0', 
        'mobilenet_v2', 
        pretrained=True
    )
    self.net.classifier = torch_nn.Linear(
        in_features=1280, 
        out_features=n_classes, 
        bias=True
    )
    self.save_hyperparameters()

def forward(self, x):
    x = self.net(x)
    return x

5.3.2 模型训练

模型训练相关的代码位于LitModel类的training_step函数中。模型训练的输入batch可以理解为LEGOMinifiguresDataModuletrain_dataloader的输出(我们可以这么理解)。相关代码如下:

def training_step(self, batch, batch_idx):
    """
    Training the data as batches and returns training loss on each batch

    :param train_batch: Batch data
    :param batch_idx: Batch indices

    :return: output - Training loss
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    train_loss = torch_F.cross_entropy(y_hat, y)
    train_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.TrainResult(train_loss)
    self.log('train_loss', train_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('train_acc', train_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"loss": train_loss}

我们通过training_step可以计算得到当前循环下 train loss 以及 train accuracy。函数返回 train loss 值。

5.3.3 模型验证

模型验证相关的代码位于LitModel类的validation_stepvalidation_epoch_end函数中。逻辑与模型训练类似:

def validation_step(self, batch, batch_idx):
    """
    Performs validation of data in batches

    :param batch: Batch data
    :param batch_idx: Batch indices

    :return: output - valid step loss
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    
    valid_loss = torch_F.cross_entropy(y_hat, y)
    valid_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss)
    self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('valid_acc', valid_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"val_step_loss": valid_loss}

def validation_epoch_end(self, outputs):
    """
    Computes average validation accuracy

    :param outputs: outputs after every epoch end

    :return: output - average valid loss
    """
    avg_loss = torch.stack([x["val_step_loss"] for x in outputs]).mean()
    self.log("val_loss", avg_loss, sync_dist=True)

5.3.4 模型测试

模型验证相关的代码位于LitModel类的test_steptest_epoch_end函数中。逻辑与模型训练类似:

def test_step(self, batch, batch_idx):
    """
    Performs test and computes the accuracy of the model

    :param batch: Batch data
    :param batch_idx: Batch indices

    :return: output - Testing accuracy
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    
    test_loss = torch_F.cross_entropy(y_hat, y)
    test_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss)
    self.log('test_loss', test_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('test_acc', test_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"test_acc": test_acc}

def test_epoch_end(self, outputs):
    """
    Computes average test accuracy score

    :param outputs: outputs after every epoch end

    :return: output - average test loss
    """
    avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
    self.log("avg_test_acc", avg_test_acc)

5.3.5 模型优化

最后一块就是模型优化,我们可以自行选择优化器。相关代码如下:

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]

5.4 迁移学习

这里,我们就开始迁移学习的训练,验证,以及最后的测试了。

5.4.1 实例化数据读取处理模块以及训练/验证/测试/优化类

上面章节介绍了 pytorch lightning 的 LightningDataModule 以及 LightningModule 类如何在我们这个案例中被继承与使用。这里,我们首先需要实例化 LitModel 以及 LEGOMinifiguresDataModule

# 实例化模型训练/验证/测试/优化类
model = LitModel(n_classes=N_CLASSES)
# 实例化数据读取处理模块
data_module = LEGOMinifiguresDataModule(
    train_batch_size=4, 
    valid_batch_size=1, 
    test_batch_size=1,
    image_size=(512, 512), 
    base_dir=BASE_DIR,
    train_augmentations=get_train_transforms()
)

5.4.2 迁移学习的训练,验证,以及测试

我们定义 EarlyStopping 条件,ModelCheckpoint 的保存机制,以及对 learning rate 的保存跟踪,然后就可以开始训练,验证,以及测试了。代码如下:

# 定义 EarlyStopping Criteria
early_stopping = EarlyStopping(
        monitor='valid_loss',
        mode='min',
        verbose=True,
        patience=3,
    )

# 定义 ModelCheckpoint,我们可以通过导入.ckpt文件模型进行模型继续训练,或者模型推理。
callback_model_checkpoint = ModelCheckpoint(
    dirpath=os.getcwd(),
    filename='sample-{epoch}-{valid_loss:.3f}', 
    save_top_k=1,
    verbose=True,
    monitor='valid_loss', 
    mode='min',
)

lr_logger = LearningRateMonitor()
# 实例化Trainer,这里epoch我只设定了12次,由于我这边的运行环境是CPU。
trainer = pl.Trainer(
    gpus=0,
    callbacks=[lr_logger, early_stopping, callback_model_checkpoint],
    checkpoint_callback=True,
    max_epochs=18
)

mlflow.pytorch.autolog()

with mlflow.start_run() as run:
    # 模型训练/验证
    trainer.fit(
        model, 
        data_module,
    )
    # 模型测试
    trainer.test(datamodule=data_module)

6 结果

我们运行完代码后,在terminal中输入mlflow ui。打开网页。

parameters相关参数信息见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第1张图片

Metrics相关参数信息见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第2张图片

train loss 每个 epoch 见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第3张图片

train accuracy 每个 epoch 见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第4张图片

validation loss 每个 epoch 见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第5张图片

validation accuracy 每个 epoch 见下图:

MLOps极致细节:21. MLFlow Pytorch 的使用案例3:MobileNetV2 图像分类案例_第6张图片

你可能感兴趣的:(mlops,pytorch,深度学习,mlflow,mlops)