MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例

MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例

此案例中,我们将通过 pyTorch Lightning 对 MobileNetV2 预训练模型进行迁移学习,对象是 LEGO Minifigures 数据集。提供源代码。

运行环境:

  • 平台:Win10。
  • IDE:Visual Studio Code
  • 建议预装:Anaconda3
  • 代码

觉得写的可以的话点个赞,收藏,加关注哦。


文章目录

  • MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例
    • 1 关于 LEGO Minifigures 数据集
    • 2 代码运行
    • 3 数据读取与处理
      • 3.1 pl.LightningDataModule
      • 3.2 数据预处理
      • 3.3 数据增强(仅对于训练集)
    • 4 模型训练/验证/测试/优化类
      • 4.1 pl.LightningModule
      • 4.2 模型导入
      • 4.3 模型训练
      • 4.4 模型验证
      • 4.5 模型测试
      • 4.5 模型优化
    • 5 迁移学习
      • 5.1 实例化数据读取处理模块以及训练/验证/测试/优化类
      • 5.2 迁移学习的训练,验证,以及测试
    • 6 结果
      • 6.1 训练/验证/测试 Loss 以及 Accuracy 值
      • 6.2 通过 ModelCheckpoint 文件导入最优模型并进行模型推理
      • 6.3 找出误识别的图片


1 关于 LEGO Minifigures 数据集

该数据集包含各种乐高人仔的图片。 数据集中的每个人仔都有几张不同姿势和不同环境的图像。

目前,它包含来自乐高套装的 28 个人物(总共 300 多张图像):Yoda’s Hut, Spider Mech vs. Venom, General Grievous’ Combat Speeder, Kylo Ren’s Shuttle™ Microfighter, AT-ST™ Raider from The Mandalorian, Molten Man Battle, Aragog’s Lair, Black Widow’s Helicopter Chase, Captain America: Outriders Attack, Pteranodon Chase, Iron Man Hall of Armor。

该数据集的GITHUB链接。Kaggle中也有其数据集可以直接下载,Kaggle链接。

这里我们已经下载了这个数据集,在archive文件夹下。需要注意:

  • index.csv:包含了训练集/验证集(一共361个数据)图片的位置,以及对应的标签(所以这个csv文件有2列);
  • test.csv:包含了测试集(一共76个数据)图片位置,以及对应的标签;
  • metadata.csv:包含了这个数据集的37个标签。

2 代码运行

我们建议在运行代码前,我们先在本地建一个虚拟环境。如果我们在本地安装了Anaconda,那么可以使用conda create -n your_env_name python=X.X(2.7、3.6等)命令创建python版本为X.X、名字为your_env_name的虚拟环境。

这里我们输入conda create -n mlFlowEx python=3.8.2

安装完默认的依赖后,我们进入虚拟环境:conda activate mlFlowEx。注意,如果需要退出,则输入conda deactivate。另外,如果Terminal没有成功切换到虚拟环境,可以尝试conda init powershell,然后重启terminal。

然后,我们在虚拟环境中下载好相关依赖:pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

如果我们需要在jupyter notebook中运行代码,这里需要添加名为mlFlowEx_env的kernal:

python -m ipykernel install --user --name mlFlowEx_env

这里有两种运行代码的方式:

  • 我们可以直接在主路径下运行:python LegoCharacterRecognition.py
  • 我们也可以在lego-minifigures-pytorch-lightning-tutorial.ipynb逐条运行代码;

3 数据读取与处理

首先,我们需要读取数据,拆分成训练集、验证集、测试集,并且进行一些数据预处理,必须数据增强等。

3.1 pl.LightningDataModule

pl.LightningDataModule 标准化了训练、验证、测试集的拆分、数据准备和转换。主要优点是一致的数据拆分、数据准备和跨模型转换,一个例子如下:

class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
    def train_dataloader(self):
        train_split = Dataset(...)
        return DataLoader(train_split)
    def val_dataloader(self):
        val_split = Dataset(...)
        return DataLoader(val_split)
    def test_dataloader(self):
        test_split = Dataset(...)
        return DataLoader(test_split)
    def teardown(self):
        # clean up after fit or test
        # called on every process in DDP

此项目代码:

class LEGOMinifiguresDataModule(pl.LightningDataModule):
    
    def __init__(
        self, 
        train_batch_size, 
        valid_batch_size, 
        test_batch_size, 
        image_size, 
        base_dir,
        train_augmentations=None
    ):
        """
        Initialization of inherited lightning data module
        """
        super().__init__()
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.test_batch_size = test_batch_size
        self.image_size = image_size
        self.base_dir = base_dir
        self.train_augmentations=train_augmentations
        
    def setup(self, stage):
        """
        Load the data, parse it and split the data into train, test, validation data
        """
        # Load train dataset, test dataset
        self.df_train = pd.read_csv(os.path.join(self.base_dir, 'index.csv'))
        self.df_test = pd.read_csv(os.path.join(self.base_dir, 'test.csv'))
        # Split train dataset into train dataset and validation dataset
        X, y = df_train.path, df_train.class_id
        train_paths, valid_paths, y_train, y_valid = train_test_split(X, y, random_state=0)
        # Store train/validation/test dataset path and label
        self.train_targets = y_train - 1
        self.train_paths = list(map(lambda x: os.path.join(self.base_dir, x), train_paths))
        self.valid_targets = y_valid - 1
        self.valid_paths = list(map(lambda x: os.path.join(self.base_dir, x), valid_paths))
        tmp_test = self.df_test
        test_paths = tmp_test.path
        self.test_targets = tmp_test.class_id - 1
        self.test_paths = list(map(lambda x: os.path.join(self.base_dir, x), test_paths))
        
    def train_dataloader(self):
        """
        :return: output - Train data loader for the given input
        """
        train_data_retriever = DataRetriever(
            self.train_paths, 
            self.train_targets, 
            image_size=self.image_size,
            transforms=self.train_augmentations
        )
        train_loader = torch_data.DataLoader(
            train_data_retriever,
            batch_size=self.train_batch_size,
            shuffle=True,
        )
        return train_loader
    
    def val_dataloader(self):
        """
        :return: output - Validation data loader for the given input
        """
        valid_data_retriever = DataRetriever(
            self.valid_paths, 
            self.valid_targets, 
            image_size=self.image_size,
        )
        valid_loader = torch_data.DataLoader(
            valid_data_retriever, 
            batch_size=self.valid_batch_size,
            shuffle=True,
        )
        return valid_loader
    
    def test_dataloader(self):
        """
        :return: output - Test data loader for the given input
        """
        test_data_retriever = DataRetriever(
            self.test_paths, 
            self.test_targets, 
            image_size=self.image_size,
        )
        test_loader = torch_data.DataLoader(
            test_data_retriever, 
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_loader

LEGOMinifiguresDataModule类继承于pl.LightningDataModule,大体上可以说由两部分组成。第一部分是读取数据集,并且拆分成训练数据集/验证数据集/测试数据集(setup函数)。第二部分是对于这三个数据集的dataloader(train_dataloaderval_dataloadertest_dataloader函数)。

3.2 数据预处理

对于train_dataloaderval_dataloadertest_dataloader函数,我们需要先对各自输入的数据集进行一些预处理,见DataRetriever类:

class DataRetriever(torch_data.Dataset):
    def __init__(
        self, 
        paths, 
        targets, 
        image_size=(224, 224),
        transforms=None
    ):
        self.paths = list(paths)
        self.targets = list(targets)
        self.image_size = image_size
        self.transforms = transforms
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]
            ),
        ])
          
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self, index):
        img = cv2.imread(self.paths[index])
        # 我们需要调整图像大小到与MoblieNetV2的输入尺寸匹配
        img = cv2.resize(img, self.image_size)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.transforms:
            img = self.transforms(image=img)['image']
            
        img = self.preprocess(img)
        
        y = torch.tensor(self.targets[index], dtype=torch.long)
            
        return {'X': img, 'y': y}

我们先读取图像(cv2.imread(self.paths[index])),然后调整图像大小到与MoblieNetV2的输入尺寸匹配(cv2.resize(img, self.image_size)),将图像格式从BGR转换到RGB(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))。对于训练集,我们会对其进行数据增强(self.transforms(image=img)['image'])。然后,我们对于所有数据集需要进行预处理(self.preprocess(img)):

self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225]
            ),
        ])

最后,__getitem__返回图像和标签:{'X': img, 'y': y}

3.3 数据增强(仅对于训练集)

和验证集/测试集略有不同,训练集的train_dataloader需要对于数据进行数据增强。

def get_train_transforms():
    return A.Compose(
        [
            A.Rotate(limit=30, border_mode=cv2.BORDER_REPLICATE, p=0.5),
            A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=0, p=0.25),
            A.Cutout(num_holes=8, max_h_size=25, max_w_size=25, fill_value=255, p=0.25),
            A.HorizontalFlip(p=0.5),
            A.RandomContrast(limit=(-0.3, 0.3), p=0.5),
            A.RandomBrightness(limit=(-0.4, 0.4), p=0.5),
            A.Blur(p=0.25),
        ], 
        p=1.0
    )

4 模型训练/验证/测试/优化类

pyTorch Lightning的一大优点就是将机器学习/深度学习流程化,标准化。与模型训练/验证/测试相关的代码都写在LitModel类中,继承于pl.LightningModule

4.1 pl.LightningModule

对于 PyTorch Lightning,有两个函数是至关重要,一个是pl.LightningModule,一个是pl.LightningDataModule。前者的包含了训练/验证/预测/优化的所有模块,后者则是数据集读取模块。我们通过PyTorch Lightning进行模型训练的时候,通常会继承这两个类。目前我对 PyTorch Lightning 不是很了解,所以这里我作为一个初学者的角度,针对这个案例进行一些相关的解读。

关于pl.LightningModule,和我们这个案例相关的函数包括:

  • forward,作用和torch.nn.Module.forward()一样,这里我们不再赘述;
  • training_step,我们计算并返回训练损失和一些额外的metrics。
  • validation_step,我们计算并返回验证损失和一些额外的metrics。
  • test_step,我们计算并返回测试损失和一些额外的metrics。
  • validation_epoch_end,在验证epoch结束后,计算这个epoch的平均验证accuracy。
  • test_epoch_end,在测试epoch结束后,计算计算这个epoch的平均测试accuracy。
  • configure_optimizers,选择要在优化中使用的优化器和学习率调度器。

此网页有详细的描述,这里不再赘述。

4.2 模型导入

这里我们基于 MobileNetV2 的预训练模型,对 LEGO Minifigures 数据集进行迁移训练。所以,首先我们需要导入预训练模型,并且加上最后一层全连接层,组成一个完整的神经网络。代码位于LitModel类的__init__函数中:

def __init__(self, n_classes):
    super().__init__()
    self.net = torch.hub.load(
        'pytorch/vision:v0.6.0', 
        'mobilenet_v2', 
        pretrained=True
    )
    self.net.classifier = torch_nn.Linear(
        in_features=1280, 
        out_features=n_classes, 
        bias=True
    )
    self.save_hyperparameters()

def forward(self, x):
    x = self.net(x)
    return x

4.3 模型训练

模型训练相关的代码位于LitModel类的training_step函数中。模型训练的输入batch可以理解为LEGOMinifiguresDataModuletrain_dataloader的输出(我们可以这么理解)。相关代码如下:

def training_step(self, batch, batch_idx):
    """
    Training the data as batches and returns training loss on each batch

    :param train_batch: Batch data
    :param batch_idx: Batch indices

    :return: output - Training loss
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    train_loss = torch_F.cross_entropy(y_hat, y)
    train_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.TrainResult(train_loss)
    self.log('train_loss', train_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('train_acc', train_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"loss": train_loss}

我们通过training_step可以计算得到当前循环下 train loss 以及 train accuracy。函数返回 train loss 值。

4.4 模型验证

模型验证相关的代码位于LitModel类的validation_stepvalidation_epoch_end函数中。逻辑与模型训练类似:

def validation_step(self, batch, batch_idx):
    """
    Performs validation of data in batches

    :param batch: Batch data
    :param batch_idx: Batch indices

    :return: output - valid step loss
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    
    valid_loss = torch_F.cross_entropy(y_hat, y)
    valid_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss)
    self.log('valid_loss', valid_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('valid_acc', valid_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"val_step_loss": valid_loss}

def validation_epoch_end(self, outputs):
    """
    Computes average validation accuracy

    :param outputs: outputs after every epoch end

    :return: output - average valid loss
    """
    avg_loss = torch.stack([x["val_step_loss"] for x in outputs]).mean()
    self.log("val_loss", avg_loss, sync_dist=True)

4.5 模型测试

模型验证相关的代码位于LitModel类的test_steptest_epoch_end函数中。逻辑与模型训练类似:

def test_step(self, batch, batch_idx):
    """
    Performs test and computes the accuracy of the model

    :param batch: Batch data
    :param batch_idx: Batch indices

    :return: output - Testing accuracy
    """
    X, y = batch['X'], batch['y']
    y_hat = self(X)
    
    test_loss = torch_F.cross_entropy(y_hat, y)
    test_acc = accuracy(
        y_hat, 
        y, 
        num_classes=self.hparams.n_classes
    )
    
    #result = pl.EvalResult(checkpoint_on=valid_loss, early_stop_on=valid_loss)
    self.log('test_loss', test_loss, prog_bar=True, on_epoch=True, on_step=False)
    self.log('test_acc', test_acc, prog_bar=True, on_epoch=True, on_step=False)
    return {"test_acc": test_acc}

def test_epoch_end(self, outputs):
    """
    Computes average test accuracy score

    :param outputs: outputs after every epoch end

    :return: output - average test loss
    """
    avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
    self.log("avg_test_acc", avg_test_acc)

4.5 模型优化

最后一块就是模型优化,我们可以自行选择优化器。相关代码如下:

    def configure_optimizers(self):
        """
        Initializes the optimizer and learning rate scheduler

        :return: output - Initialized optimizer and scheduler
        """
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            ),
            "monitor": "val_loss",
        }
        return [self.optimizer], [self.scheduler]

5 迁移学习

这里,我们就开始迁移学习的训练,验证,以及最后的测试了。

5.1 实例化数据读取处理模块以及训练/验证/测试/优化类

上面章节介绍了 pytorch lightning 的 LightningDataModule 以及 LightningModule 类如何在我们这个案例中被继承与使用。这里,我们首先需要实例化 LitModel 以及 LEGOMinifiguresDataModule

# 实例化模型训练/验证/测试/优化类
model = LitModel(n_classes=N_CLASSES)
# 实例化数据读取处理模块
data_module = LEGOMinifiguresDataModule(
    train_batch_size=4, 
    valid_batch_size=1, 
    test_batch_size=1,
    image_size=(512, 512), 
    base_dir=BASE_DIR,
    train_augmentations=get_train_transforms()
)

5.2 迁移学习的训练,验证,以及测试

我们定义 EarlyStopping 条件,ModelCheckpoint 的保存机制,以及对 learning rate 的保存跟踪,然后就可以开始训练,验证,以及测试了。代码如下:

# 定义 EarlyStopping Criteria
early_stopping = EarlyStopping(
        monitor='valid_loss',
        mode='min',
        verbose=True,
        patience=3,
    )

# 定义 ModelCheckpoint,我们可以通过导入.ckpt文件模型进行模型继续训练,或者模型推理。
callback_model_checkpoint = ModelCheckpoint(
    dirpath=os.getcwd(),
    filename='sample-{epoch}-{valid_loss:.3f}', 
    save_top_k=1,
    verbose=True,
    monitor='valid_loss', 
    mode='min',
)

lr_logger = LearningRateMonitor()
# 实例化Trainer,这里epoch我只设定了12次,由于我这边的运行环境是CPU。
trainer = pl.Trainer(
    gpus=0,
    callbacks=[lr_logger, early_stopping, callback_model_checkpoint],
    checkpoint_callback=True,
    max_epochs=12
)
# 模型训练/验证
trainer.fit(
    model, 
    data_module,
)
# 模型测试
trainer.test(datamodule=data_module)

6 结果

6.1 训练/验证/测试 Loss 以及 Accuracy 值

由于我这边用CPU进行的训练,所以epoch的数量设置的比较小,如果是GPU的同学,建议设置50。运行的Terminal记录如下:

Using cache found in C:\Users\XXX/.cache\torch\hub\pytorch_vision_v0.6.0
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type        | Params
-------------------------------------
0 | net  | MobileNetV2 | 2.3 M 
-------------------------------------
2.3 M     Trainable params
0         Non-trainable params
2.3 M     Total params
9.085     Total estimated model params size (MB)
Epoch 0:   8%|▊         | 13/159 [00:43<08:05,  3.32s/it, loss=3.67, v_num=13]Metric valid_loss improved. New best score: 3.316
Epoch 0: 100%|██████████| 159/159 [03:52<00:00,  1.46s/it, loss=3.51, v_num=13, valid_loss=3.320, valid_acc=0.143, train_loss=3.600, train_acc=0.0741]Epoch 0, global step 68: 'valid_loss' reached 3.31555 (best 3.31555), saving model to 'XXX\\sample-epoch=0-valid_loss=3.316.ckpt' as top 1
Epoch 1: 100%|██████████| 159/159 [07:42<00:00,  2.91s/it, loss=3, v_num=13, valid_loss=2.820, valid_acc=0.308, train_loss=3.600, train_acc=0.0741]     Metric valid_loss improved by 0.494 >= min_delta = 0.0. New best score: 2.822
Epoch 1: 100%|██████████| 159/159 [07:42<00:00,  2.91s/it, loss=3, v_num=13, valid_loss=2.820, valid_acc=0.308, train_loss=3.100, train_acc=0.215] Epoch 1, global step 136: 'valid_loss' reached 2.82201 (best 2.82201), saving model to 'XXX\\sample-epoch=1-valid_loss=2.822.ckpt' as top 1
...
Epoch 11: 100%|██████████| 159/159 [48:15<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.670, train_acc=0.944]    Metric valid_loss improved by 0.081 >= min_delta = 0.0. New best score: 0.391
Epoch 11: 100%|██████████| 159/159 [48:15<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.548, train_acc=0.978]Epoch 11, global step 816: 'valid_loss' reached 0.39098 (best 0.39098), saving model to 'XXX\\sample-epoch=11-valid_loss=0.391.ckpt' as top 1
Epoch 11: 100%|██████████| 159/159 [48:16<00:00, 18.21s/it, loss=0.569, v_num=13, valid_loss=0.391, valid_acc=0.945, train_loss=0.548, train_acc=0.978]
Restoring states from the checkpoint path at D:\yichao\learning\courses\MLOps\mlflow-ex\ex\8_LegoCharacterRecognition\sample-epoch=11-valid_loss=0.391.ckpt
Loaded model weights from checkpoint at D:\yichao\learning\courses\MLOps\mlflow-ex\ex\8_LegoCharacterRecognition\sample-epoch=11-valid_loss=0.391.ckpt
Testing DataLoader 0: 100%|██████████| 76/76 [00:20<00:00,  3.66it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      avg_test_acc          0.8684210777282715
        test_acc            0.8684210777282715
        test_loss           0.5233738422393799
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[{'test_loss': 0.5233738422393799,
  'test_acc': 0.8684210777282715,
  'avg_test_acc': 0.8684210777282715}]

我们可以看到,当我们最终得到的结果:

  • loss=0.569
  • valid_loss=0.391
  • valid_acc=0.945
  • train_loss=0.548
  • train_acc=0.978
  • test_loss=0.5233738422393799
  • test_acc=0.8684210777282715
  • avg_test_acc=0.8684210777282715

6.2 通过 ModelCheckpoint 文件导入最优模型并进行模型推理

这里我们尝试读取训练过程中保存的 checkpoint 文件:sample-epoch=11-valid_loss=0.391.ckpt,然后用其进行模型推理:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 读取 checkpoint 文件。注意,这个案例中这个 checkpoint 保存在当前路径下。
best_model_path = callback_model_checkpoint.best_model_path

model = LitModel.load_from_checkpoint(
    checkpoint_path=best_model_path
)
model = model.to(device)
model.freeze()

# 模型推理
y_pred = []
y_gt = []
for ind, batch in enumerate(data_module.test_dataloader()):
    pred_probs = model(batch['X'])
    y_pred.extend(pred_probs.argmax(axis=-1).cpu().numpy())
    y_gt.extend(batch['y'])
    
# Calculate needed metrics
print(f'Accuracy score on test data:\t{sk_metrics.accuracy_score(y_gt, y_pred)}')
print(f'Macro F1 score on test data:\t{sk_metrics.f1_score(y_gt, y_pred, average="macro")}')

我们得到的结果如下:

Accuracy score on test data: 0.868421052631579
Macro F1 score on test data: 0.8436293436293436

需要注意,由于这里我设置epoch的值只有12,这么模型完全可以精度再高一些,但由于我这边只有CPU,所以就不花时间等结果了。

6.3 找出误识别的图片

最后,我们可以更进一步,找出哪些误识别的图片(由于我们这个数据集比较小)。

首先,让我们来看看整个数据集的 confusion matrix 出来的效果:

# Load metadata to get classes people-friendly names
labels = df_metadata['minifigure_name'].tolist()

# Calculate confusion matrix
confusion_matrix = sk_metrics.confusion_matrix(y_gt, y_pred)
df_confusion_matrix = pd.DataFrame(confusion_matrix, index=labels, columns=labels)

# Show confusion matrix
plt.figure(figsize=(12, 12))
sn.heatmap(df_confusion_matrix, annot=True, cbar=False, cmap='Oranges', linewidths=1, linecolor='black')
plt.xlabel('Predicted labels', fontsize=15)
plt.xticks(fontsize=12)
plt.ylabel('True labels', fontsize=15)
plt.yticks(fontsize=12);

MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例_第1张图片

最后,我们将那些错误识别的图片都显示出来:

error_images = []
error_label = []
error_pred = []
error_prob = []
for batch in data_module.test_dataloader():
    _X_test, _y_test = batch['X'], batch['y']
    pred = torch.softmax(model(_X_test), axis=-1).cpu().numpy()
    pred_class = pred.argmax(axis=-1)
    if pred_class != _y_test.cpu().numpy():
        error_images.extend(_X_test)
        error_label.extend(_y_test)
        error_pred.extend(pred_class)
        error_prob.extend(pred.max(axis=-1))

def denormalize_image(image):
    return image * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]

plt.figure(figsize=(16, 16))
w_size = int(len(error_images) ** 0.5)
h_size = math.ceil(len(error_images) / w_size)
for ind, image in enumerate(error_images):
    plt.subplot(h_size, w_size, ind + 1)
    plt.imshow(denormalize_image(image.permute(1, 2, 0).numpy()))
    pred_label = labels[error_pred[ind]]
    pred_prob = error_prob[ind]
    true_label = labels[error_label[ind]]
    plt.title(f'predict: {pred_label} ({pred_prob:.2f}) true: {true_label}')
    plt.axis('off')

结果如下:

MobileNetV2 pyTorch Lightning LEGO Minifigures 图像分类案例_第2张图片

你可能感兴趣的:(deep,learning,pytorch,分类,深度学习)