关于moco的论文解析有很多了,但是光看解析还是有一些内容不是很懂。于是就想着看看代码。
图片进行两次数据增强,分别成为query和key。每张图片要只和自己是一类,和别的图片不成为一类。
看论文的时候理解错误的地方:
看论文的时候没有理解的地方:
理论上需要保证,但是代码里在第一个epoch是不会重复,在后面的epoch有极低的概率重复,不会影响到训练
开始的时候随机初始化队列,随机结果充满队列
代码:https://github.com/facebookresearch/moco
首先是看它数据准备的部分。由于预训练是无监督的,所以只加载了images,而标签就舍弃了
for i, (images, _) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
if args.gpu is not None:
images[0] = images[0].cuda(args.gpu, non_blocking=True)
images[1] = images[1].cuda(args.gpu, non_blocking=True)
train_loader使用的Pytorch官方的ImageFolder类进行加载,不同的是moco自定义了transform的方法,让transform的结果可以对一张图片进行两次数据增强,一个作为query,另一个作为key。
构造dataset,定义transform
train_dataset = datasets.ImageFolder(
traindir,
moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))
...
#在文件loader.py中
class TwoCropsTransform:
"""Take two random crops of one image as the query and key."""
def __init__(self, base_transform):
self.base_transform = base_transform
def __call__(self, x):
q = self.base_transform(x)
k = self.base_transform(x)
return [q, k]
这样,dataloader就产生了q和k两个结果。
模型结构就是resnet一类经典结构,没什么好说的。他会作为q和k的encoder。模型中比较重要的是它队列的实现方法。理解了队列的实现也就理解了这篇论文。
# create the queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
就是一个循环队列。注册了一个buffer来保存队列的值,队列不参与反向求导。使用torch.randn进行了初始化,即刚开始时队列是随机的。
队列的更新就在deque_and_enque函数里面
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
# gather keys before updating queue
keys = concat_all_gather(keys)
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0 # for simplicity
# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K # move pointer
self.queue_ptr[0] = ptr
每次得到的key会进入队列,并且将最旧的那批数据更新出去。作者在实现代码的时候要求队列的大小可以被batch_size整除,这也是为了更新的方便。
比较重要的一点是怎么样保证队列中不存在query里的样本。可以看到队列的大小默认是65536,也就是说队列中最多存在65536个样本。在第一个epoch时,由于样本全部没有进入过队列,所以第一个epoch是绝对可以保证不重复的。到了第二个epoch时,dataloader会进行shuffle,顺序打乱,此时就并不能保证一定不重复。imagenet有千万级别的图片,抽到和队列中重复的样本概率并不是很高,用这种方法来保证不重复。在实验的过程中如果训练集没有这么大的量级,可以考虑缩小队列的大小。当然这样也就会影响到训练效果。在这两者之间取一个平衡。