moco代码阅读笔记

关于moco的论文解析有很多了,但是光看解析还是有一些内容不是很懂。于是就想着看看代码。

基本思路

图片进行两次数据增强,分别成为query和key。每张图片要只和自己是一类,和别的图片不成为一类。

看论文的时候理解错误的地方:

  • 以为每次从队列中弹出来一部分作为反例,其实是整个队列都会作为反例的

看论文的时候没有理解的地方:

  • 是不是一定需要保证队列里存在的样本和query的样本是不重复的?怎么保证的?

理论上需要保证,但是代码里在第一个epoch是不会重复,在后面的epoch有极低的概率重复,不会影响到训练

  • 既然需要一个队列,那么第一个epoch的第一个batch的时候,队列里面是空的,这个时候负样本是怎么产生的?

开始的时候随机初始化队列,随机结果充满队列

代码阅读

代码:https://github.com/facebookresearch/moco

数据构造

首先是看它数据准备的部分。由于预训练是无监督的,所以只加载了images,而标签就舍弃了

    for i, (images, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images[0] = images[0].cuda(args.gpu, non_blocking=True)
            images[1] = images[1].cuda(args.gpu, non_blocking=True)

train_loader使用的Pytorch官方的ImageFolder类进行加载,不同的是moco自定义了transform的方法,让transform的结果可以对一张图片进行两次数据增强,一个作为query,另一个作为key。
构造dataset,定义transform

    train_dataset = datasets.ImageFolder(
        traindir,
        moco.loader.TwoCropsTransform(transforms.Compose(augmentation)))
...
#在文件loader.py中
class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]

这样,dataloader就产生了q和k两个结果。

模型结构

模型结构就是resnet一类经典结构,没什么好说的。他会作为q和k的encoder。模型中比较重要的是它队列的实现方法。理解了队列的实现也就理解了这篇论文。

队列构建
       # create the queue
       self.register_buffer("queue", torch.randn(dim, K))
       self.queue = nn.functional.normalize(self.queue, dim=0)
       
       self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

就是一个循环队列。注册了一个buffer来保存队列的值,队列不参与反向求导。使用torch.randn进行了初始化,即刚开始时队列是随机的
队列的更新就在deque_and_enque函数里面

    @torch.no_grad()
   def _dequeue_and_enqueue(self, keys):
       # gather keys before updating queue
       keys = concat_all_gather(keys)

       batch_size = keys.shape[0]

       ptr = int(self.queue_ptr)
       assert self.K % batch_size == 0  # for simplicity

       # replace the keys at ptr (dequeue and enqueue)
       self.queue[:, ptr:ptr + batch_size] = keys.T
       ptr = (ptr + batch_size) % self.K  # move pointer

       self.queue_ptr[0] = ptr

每次得到的key会进入队列,并且将最旧的那批数据更新出去。作者在实现代码的时候要求队列的大小可以被batch_size整除,这也是为了更新的方便。
比较重要的一点是怎么样保证队列中不存在query里的样本。可以看到队列的大小默认是65536,也就是说队列中最多存在65536个样本。在第一个epoch时,由于样本全部没有进入过队列,所以第一个epoch是绝对可以保证不重复的。到了第二个epoch时,dataloader会进行shuffle,顺序打乱,此时就并不能保证一定不重复。imagenet有千万级别的图片,抽到和队列中重复的样本概率并不是很高,用这种方法来保证不重复。在实验的过程中如果训练集没有这么大的量级,可以考虑缩小队列的大小。当然这样也就会影响到训练效果。在这两者之间取一个平衡。

你可能感兴趣的:(深度学习,深度学习)