最近想使用Fixmatch来实现办监督学习,找了看官方代码是tensorflow1.0的版本…于是转战pytorch.但pytorch我刚刚入门,我看了看代码的解读也不多,这里我就记录一下自己一点点琢磨的东西吧,希望大家可以共同学习.
这一篇主要是关于代码中的数据加载部分的.
FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence.
这里还有一个译制版的很方便阅读
FixMatch:通过一致性和置信度简化半监督学习
pytorch的代码有很多版本,我选择了比较简单的一个:
unofficial PyTorch implementation of FixMatch
其他版本的代码:
TorchSSL
同时,其他大佬写的文章也给我了很多参考:
FixMatch文章解读+算法流程+核心代码详解
2021 CVPR《Meta Pseudo Labels 》 PyTorch复现
如果想要将fixmatch应用到自己的程序上,只需要修改数据部分的代码,所以我先从这一部分分析。所有的参数我都默认使用作者给出的例子:
python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5
原文中使用的是CIFAR-10,CIFAR-10 数据集由 10 个类别的 60000 个 32x32 彩色图像组成,每个类别包含 6000 个图像。有 50000 个训练图像和 10000 个测试图像。
这里作者想使用4000张带标签照片(每个类400张)来进行训练。
我们从主文件开始看,一步步分析每个函数的作用。
首先,有一些参数是数据加载比较重要的。
parser.add_argument('--dataset', default='cifar10', type=str,
choices=['cifar10', 'cifar100'],
help='dataset name') //选择哪个数据集
parser.add_argument('--num-labeled', type=int, default=4000,
help='number of labeled data') //多少张带标签的图片
parser.add_argument("--expand-labels", action="store_true",
help="expand labels to fit eval steps")
//数据扩展来适应每个step的数据产生,一会会详细看这一点
parser.add_argument('--total-steps', default=2**20, type=int,
help='number of total steps to run')
//这里使用总步数来进行训练(epoch可以由此计算出)
parser.add_argument('--eval-step', default=1024, type=int,
help='number of eval steps to run') //每个epoch中的步数
parser.add_argument('--batch-size', default=64, type=int,
help='train batchsize') //batch size
parser.add_argument('--mu', default=7, type=int,
help='coefficient of unlabeled batch size')
// 原文中的超参数μ
首先,我们看看dataset是怎样产生的,有了dataset类,我们才能创建DataLoader对象。
这里提一下Pytorch读取数据流程:Pytorch DataLoader详解
在代码中,Dataset是这样产生的:
labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](args, './data')
然后,跳到下一小节看对DATASET_GETTERS函数的分析。Dataset对象 cifar.py
通过dataset类,我们产生dataloader对象:
train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler
labeled_trainloader = DataLoader(
labeled_dataset,
sampler=train_sampler(labeled_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers,
drop_last=True)
unlabeled_trainloader = DataLoader(
unlabeled_dataset,
sampler=train_sampler(unlabeled_dataset),
batch_size=args.batch_size*args.mu,
num_workers=args.num_workers,
drop_last=True)
test_loader = DataLoader(
test_dataset,
sampler=SequentialSampler(test_dataset),
batch_size=args.batch_size,
num_workers=args.num_workers)
关于dataloader的参数:
一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系
数据加载部分差不多就结束了,最后我们再看看在循环中是如何调用这些数据的吧。首先,作者使用了x, y = next(iter(training_loader))结构,其原理:
作者的代码如下,当迭代到最后的轮次时会报错,所以加上except开始新一轮的迭代.
labeled_iter = iter(labeled_trainloader)
for batch_idx in range(1024):
try:
inputs_x, targets_x = labeled_iter.next()
print(targets_x.shape[0])
except:
labeled_iter = iter(labeled_trainloader)
inputs_x, targets_x = labeled_iter.next()
print(targets_x.shape[0])
在dataset文件夹中的cifar.py文件中定义了dataset类。
首先,我们使用get_cifar10函数:
DATASET_GETTERS = {'cifar10': get_cifar10,
'cifar100': get_cifar100}
def get_cifar10(args, root):
transform_labeled = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(size=32,
padding=int(32*0.125),
padding_mode='reflect'),
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
transform_val = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
base_dataset = datasets.CIFAR10(root, train=True, download=True)
train_labeled_idxs, train_unlabeled_idxs = x_u_split(
args, base_dataset.targets)
train_labeled_dataset = CIFAR10SSL(
root, train_labeled_idxs, train=True,
transform=transform_labeled)
train_unlabeled_dataset = CIFAR10SSL(
root, train_unlabeled_idxs, train=True,
transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
test_dataset = datasets.CIFAR10(
root, train=False, transform=transform_val, download=False)
return train_labeled_dataset, train_unlabeled_dataset, test_dataset
这个函数还是比较复杂的,我们一点点看.
base_dataset = datasets.CIFAR10(root, train=True, download=True)
这是使用cifar数据的常用方法,我们可以查看其返回的对象
for (image, target) in base_dataset:
image.show()
print(target)
print(len(base_dataset.targets)) //50000
之后,使用x_u_split函数将带标签与不带标签的数据索引分开:
train_labeled_idxs, train_unlabeled_idxs = x_u_split(
args, base_dataset.targets)
def x_u_split(args, labels):
label_per_class = args.num_labeled // args.num_classes
labels = np.array(labels)
labeled_idx = []
# unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
unlabeled_idx = np.array(range(len(labels)))
for i in range(args.num_classes):
idx = np.where(labels == i)[0]
idx = np.random.choice(idx, label_per_class, False)
labeled_idx.extend(idx)
labeled_idx = np.array(labeled_idx)
assert len(labeled_idx) == args.num_labeled
if args.expand_labels or args.num_labeled < args.batch_size:
num_expand_x = math.ceil(
args.batch_size * args.eval_step / args.num_labeled)
labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
np.random.shuffle(labeled_idx)
return labeled_idx, unlabeled_idx
每个类带标签数据的个数是均衡的,每个类带标签的数据个数 = 带标签数据总个数//类数.
所以,使用一个循环(10个类):
对于每一个类,找出他们在总数据(labels)中的数据索引,并用random.choice随机选择label_per_class个数据,将他们加入到带标签的数据索引labeled_idx中。
对于不带标签的数据,原文作者使用了所有的数据(包含带标签的数据),所以他的索引为全部数据的索引。
需要注意的一个点是,args.expand_labels参数作者默认为true的,所以我们要进行数据重复.
这里重复的次数num_expand_x为 64(batch_size )* 1024(eval_step)/ 4000 (num_labeled)=17次
所以带标签的数据为 68000个(每个索引都重复了17次)。
train_labeled_dataset = CIFAR10SSL(
root, train_labeled_idxs, train=True,
transform=transform_labeled)
train_unlabeled_dataset = CIFAR10SSL(
root, train_unlabeled_idxs, train=True,
transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
test_dataset = datasets.CIFAR100(
root, train=False, transform=transform_val, download=False)
然后,使用继承CIFAR10的CIFAR10SSL类产生dataset对象。验证集直接使用就行,只需要将数据转化为tensor。
class CIFAR10SSL(datasets.CIFAR10):
def __init__(self, root, indexs, train=True,
transform=None, target_transform=None,
download=False):
super().__init__(root, train=train,
transform=transform,
target_transform=target_transform,
download=download)
if indexs is not None:
self.data = self.data[indexs]
self.targets = np.array(self.targets)[indexs]
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
根据索引返回对应的img和target,用transform参数控制强弱变。
对无标记样本进行扩增(Augment),扩增分为强扩增和弱扩增,弱扩增使用标准的旋转和移位;强扩增使用RandAugment和CTAugment两种算法。
FixMatch 中使用的增强方法
FixMatch 利用了两种增强: “弱"和"强”.
TransformFixMatch类这里我不再分析。文件randaugment.py中的函数操作可以看链接
OK,现在我们已经得到了带标签的数据集,不带标签的数据集,验证集的dataset数据,然后我们回到主文件。Dataset对象 cifar.py