基于伪标签的半监督学习——Pytorch框架识别MNIST数据集

概述

在训练模型的时候,同时使用有标签数据和无标签数据进行训练,利用伪标记的方法给无标签数据赋予伪标签,再将无标签数据当作有标签数据进行训练,即利用无标签数据进行半监督学习。

伪标记

利用模型现有的预测能力,将无标签样本的预测值作为伪标签。例如将MNIST数据集输入到模型中,得到相应的0~9类别得分,将得分最高的类别作为伪标签,该伪标签当作一般标签使用,和原样本计算损失,迭代模型参数。

整体思路

模型需要先进行预训练,即先用少量的有标签数据训练模型,使得模型获得一定的准确率,之后再输入无标签数据和有标签一起训练模型,其中无标签数据的损失权重逐渐增加。整体两个阶段如下:

  1. 有标签数据预训练
  2. 加入无标签数据一起训练

关键代码

在无标签数据训练时,每到Nunlabel_batch_size就开始一个有标签数据“中途插入”批量训练,这个N是可以自己调控的超参数,N越大,无标签数据中有标签数据穿插训练的频率越快。

    for epoch in range(EPOCHS):
        for batch_idx, unlabeled_batch in enumerate(unlabeled_loader):
            # Forward Pass to get the pseudo labels
            x_unlabeled, y_unlabeled = unlabeled_batch[0],unlabeled_batch[1]
  
            output_unlabeled = model(x_unlabeled)
            _, pseudo_labeled = torch.max(output_unlabeled, 1)

            # Now calculate the unlabeled loss using the pseudo label
            output = model(x_unlabeled)
            # alpha = alpha_weight(step, T1, T2, af)
            unlabeled_loss = alpha * F.nll_loss(output, pseudo_labeled)

            # Backpropogate
            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()

            # For every 50 batches train one epoch on labeled data
            if batch_idx % 30 == 0:

                # Normal training procedure
                for batch_idx, label_batch in enumerate(label_loader):
                    X_batch, y_batch = label_batch[0], label_batch[1] 
                    output = model(X_batch)
                    predicted = torch.max(output, 1)[1]
                    labeled_loss = F.nll_loss(output, y_batch)
                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()
                # Now we increment step by 1
                step += 1

重点是理解伪标记的方法以及如何使用伪标签,此处伪标记的方法就是简单的利用模型预测出最有可能的类别作为伪标签,这是一种非常暴力简单的赋予伪标签的方法,当然,实际上这样的半监督学习方法就能够很有效地利用大量无标签的数据。
感兴趣的朋友可以看看这篇文献,整体思路与这篇文献一致。欢迎朋友们交流探讨。
Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks

你可能感兴趣的:(学习,pytorch,深度学习)