[pytorch]FixMatch代码详解-数据加载

FixMatch代码详解-数据加载

  • 原文及代码
    • 原文
    • 代码
  • 数据加载分析
    • 数据集
    • 训练执行文件 train.py
    • Dataset对象 cifar.py

原文及代码

最近想使用Fixmatch来实现办监督学习,找了看官方代码是tensorflow1.0的版本…于是转战pytorch.但pytorch我刚刚入门,我看了看代码的解读也不多,这里我就记录一下自己一点点琢磨的东西吧,希望大家可以共同学习.

这一篇主要是关于代码中的数据加载部分的.

原文

FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence.
这里还有一个译制版的很方便阅读
FixMatch:通过一致性和置信度简化半监督学习

代码

pytorch的代码有很多版本,我选择了比较简单的一个:
unofficial PyTorch implementation of FixMatch

其他版本的代码:
TorchSSL

同时,其他大佬写的文章也给我了很多参考:
FixMatch文章解读+算法流程+核心代码详解
2021 CVPR《Meta Pseudo Labels 》 PyTorch复现

数据加载分析

如果想要将fixmatch应用到自己的程序上,只需要修改数据部分的代码,所以我先从这一部分分析。所有的参数我都默认使用作者给出的例子:

python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5

数据集

原文中使用的是CIFAR-10,CIFAR-10 数据集由 10 个类别的 60000 个 32x32 彩色图像组成,每个类别包含 6000 个图像。有 50000 个训练图像和 10000 个测试图像。
这里作者想使用4000张带标签照片(每个类400张)来进行训练。

训练执行文件 train.py

我们从主文件开始看,一步步分析每个函数的作用。
首先,有一些参数是数据加载比较重要的。

parser.add_argument('--dataset', default='cifar10', type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset name') //选择哪个数据集
parser.add_argument('--num-labeled', type=int, default=4000,
                        help='number of labeled data')    //多少张带标签的图片
parser.add_argument("--expand-labels", action="store_true",
                        help="expand labels to fit eval steps") 
                        //数据扩展来适应每个step的数据产生,一会会详细看这一点
parser.add_argument('--total-steps', default=2**20, type=int,
                        help='number of total steps to run') 
                        //这里使用总步数来进行训练(epoch可以由此计算出)
parser.add_argument('--eval-step', default=1024, type=int,
                        help='number of eval steps to run') //每个epoch中的步数
parser.add_argument('--batch-size', default=64, type=int,
                        help='train batchsize') //batch size
parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
                        // 原文中的超参数μ    

首先,我们看看dataset是怎样产生的,有了dataset类,我们才能创建DataLoader对象。
这里提一下Pytorch读取数据流程:Pytorch DataLoader详解
[pytorch]FixMatch代码详解-数据加载_第1张图片
在代码中,Dataset是这样产生的:

 labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](args, './data')

然后,跳到下一小节看对DATASET_GETTERS函数的分析。Dataset对象 cifar.py

通过dataset类,我们产生dataloader对象:

train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler

labeled_trainloader = DataLoader(
    labeled_dataset,
    sampler=train_sampler(labeled_dataset),
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True)

unlabeled_trainloader = DataLoader(
    unlabeled_dataset,
    sampler=train_sampler(unlabeled_dataset),
    batch_size=args.batch_size*args.mu,
    num_workers=args.num_workers,
    drop_last=True)

test_loader = DataLoader(
    test_dataset,
    sampler=SequentialSampler(test_dataset),
    batch_size=args.batch_size,
    num_workers=args.num_workers)

关于dataloader的参数:
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
[pytorch]FixMatch代码详解-数据加载_第2张图片
数据加载部分差不多就结束了,最后我们再看看在循环中是如何调用这些数据的吧。首先,作者使用了x, y = next(iter(training_loader))结构,其原理:

[pytorch]FixMatch代码详解-数据加载_第3张图片
作者的代码如下,当迭代到最后的轮次时会报错,所以加上except开始新一轮的迭代.

labeled_iter = iter(labeled_trainloader)
for batch_idx in range(1024):
    try:
        inputs_x, targets_x = labeled_iter.next()
        print(targets_x.shape[0])
    except:
        labeled_iter = iter(labeled_trainloader)
        inputs_x, targets_x = labeled_iter.next()
        print(targets_x.shape[0])

Dataset对象 cifar.py


在dataset文件夹中的cifar.py文件中定义了dataset类。
首先,我们使用get_cifar10函数:

DATASET_GETTERS = {'cifar10': get_cifar10,
                   'cifar100': get_cifar100}
def get_cifar10(args, root):
    transform_labeled = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32,
                              padding=int(32*0.125),
                              padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
    ])
    base_dataset = datasets.CIFAR10(root, train=True, download=True)

    train_labeled_idxs, train_unlabeled_idxs = x_u_split(
        args, base_dataset.targets)

    train_labeled_dataset = CIFAR10SSL(
        root, train_labeled_idxs, train=True,
        transform=transform_labeled)

    train_unlabeled_dataset = CIFAR10SSL(
        root, train_unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))

    test_dataset = datasets.CIFAR10(
        root, train=False, transform=transform_val, download=False)

    return train_labeled_dataset, train_unlabeled_dataset, test_dataset

这个函数还是比较复杂的,我们一点点看.

base_dataset = datasets.CIFAR10(root, train=True, download=True)

这是使用cifar数据的常用方法,我们可以查看其返回的对象

for (image, target)  in base_dataset:
    image.show()
    print(target)
print(len(base_dataset.targets)) //50000

之后,使用x_u_split函数将带标签与不带标签的数据索引分开:

train_labeled_idxs, train_unlabeled_idxs = x_u_split(
        args, base_dataset.targets)
def x_u_split(args, labels):
    label_per_class = args.num_labeled // args.num_classes
    labels = np.array(labels)
    labeled_idx = []
    # unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
    unlabeled_idx = np.array(range(len(labels)))
    for i in range(args.num_classes):
        idx = np.where(labels == i)[0]
        idx = np.random.choice(idx, label_per_class, False)
        labeled_idx.extend(idx)
    labeled_idx = np.array(labeled_idx)
    assert len(labeled_idx) == args.num_labeled

    if args.expand_labels or args.num_labeled < args.batch_size:
        num_expand_x = math.ceil(
            args.batch_size * args.eval_step / args.num_labeled)
        labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
    np.random.shuffle(labeled_idx)
    return labeled_idx, unlabeled_idx

每个类带标签数据的个数是均衡的,每个类带标签的数据个数 = 带标签数据总个数//类数.
所以,使用一个循环(10个类):
对于每一个类,找出他们在总数据(labels)中的数据索引,并用random.choice随机选择label_per_class个数据,将他们加入到带标签的数据索引labeled_idx中。
对于不带标签的数据,原文作者使用了所有的数据(包含带标签的数据),所以他的索引为全部数据的索引。
需要注意的一个点是,args.expand_labels参数作者默认为true的,所以我们要进行数据重复.
这里重复的次数num_expand_x为 64(batch_size )* 1024(eval_step)/ 4000 (num_labeled)=17次
所以带标签的数据为 68000个(每个索引都重复了17次)。

    train_labeled_dataset = CIFAR10SSL(
        root, train_labeled_idxs, train=True,
        transform=transform_labeled)

    train_unlabeled_dataset = CIFAR10SSL(
        root, train_unlabeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
    test_dataset = datasets.CIFAR100(
        root, train=False, transform=transform_val, download=False)

然后,使用继承CIFAR10的CIFAR10SSL类产生dataset对象。验证集直接使用就行,只需要将数据转化为tensor。

class CIFAR10SSL(datasets.CIFAR10):
    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

根据索引返回对应的img和target,用transform参数控制强弱变。
对无标记样本进行扩增(Augment),扩增分为强扩增和弱扩增,弱扩增使用标准的旋转和移位;强扩增使用RandAugment和CTAugment两种算法。
FixMatch 中使用的增强方法
FixMatch 利用了两种增强: “弱"和"强”.

  • 弱增强是一种标准的翻转和移位增强策略. 例如在数据集上以 50% 的概率随机水平翻转图像, 并且在垂直和水平方向上随机平移.
  • 对于"强"增强, 文中尝试了两种基于 AutoAugment 的方法, 然后是 Cutout. AutoAugment 使用强化学习来查找包含来自 Python Imaging Library 的转换的增强策略. 这需要标记数据来学习增强策略, 这使得在可用标记数据有限的 SSL 设置中使用存在问题. 因此, 使用不需要利用标记数据学习增强策略的 AutoAugment 变体, 例如 RandAugment 和 CTAugment. RandAugment 和 CTAugment 都没有使用学习策略, 而是为每个样本随机选择转换. 对于 RandAugment, 控制所有失真严重程度的幅度是从预定义的范围内随机采样的. 具有随机幅度的 RandAugment 也被用于 UDA. 而对于 CTAugment, 单个变换的幅度是即时学习的.

TransformFixMatch类这里我不再分析。文件randaugment.py中的函数操作可以看链接

OK,现在我们已经得到了带标签的数据集,不带标签的数据集,验证集的dataset数据,然后我们回到主文件。Dataset对象 cifar.py

你可能感兴趣的:(半监督学习,pytorch,深度学习,python)