MoCo论文中的Algorithm 1伪代码解读

具体解读了什么东西

论文中提供的伪代码大约如下:
MoCo论文中的Algorithm 1伪代码解读_第1张图片

下面我将分步骤介绍这个代码干什么

1.query encoder和key encoder的参数初始化

其实也没表达什么就是一开始大家的参数是一样的:

f_k.params = f_q.params

2.之后就是loader当中取数据

这个也没啥的就是取出来数据的问题:

for x in loader: # load a minibatch x with N samples

3.数据增强

就是代码不是直接将内容输入其中,也会通过数据增强取出内容

x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version

4.核心操作

首先我们先理解一下这个N和C是什么?

q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC

N其实是一个batch_size
C是一个输入数据的特征数,每个输入数据是一个1×C的张量

k = k.detach() # no gradient to keys

这个其实就是文章的主要创新点了,因为优化key_encoder是来自于query_encoder的优化。所以自然就不需要前传梯度,也能剩下个内存。

这里是矩阵乘法,理解一下这里的矩阵乘法:

# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
  • 1.首先我们应当理解一下这个q和k到底是什么东西,可以看到q和k分别来自于x_q和x_k,我们注意这两个东西其实都来自于x只是作了不同的数据增强罢了。
    好了,现在我们应该能判断出来,这里的x和k我们应该认为同一个类别。
  • 2.l_pos 现在我们就知道这个东西应该是一个N*1的一组接近1的数值
  • 3.我们注意queue是我们存储的之前的batch的内容,所以这个东西和我们当前这个batch的内容应该是没有任何交集的,也就是他们来自于不同的内容,按照对比学习的思想,来自不同事物的内容应该完全不相交。所以他们的相似度应该尽量的低。
  • 4.l_neg应当得到一个N*K的一组接近0的数值。
  • 5.logits的内容就自然而然出现了,应该为一个N*(K+1)的内容,这些内容应该具有下面的特点:K+1的向量除了第一位接近1之外其他都应该接近0。
  • 6.在现在的情况下我们自然而然可以得出一个内容就是,每个(K+1)的张量经过softmax之后,模型都应该判别其为正确。也就是所有的N个张量都是0号分类。

5.交叉熵loss

这里其实不能完全的算成交叉熵损失函数,这个是一个带有热度的交叉熵损失函数。但是其实我们可以将其想成交叉熵函数来简化理解:

# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)

之前我们谈过了,这里的所有内容都应该是第0个分类,所以我们这里直接让所有的分类都是第0分类就完事了。

模型更新

下面是很正常的backward

# SGD update: query network
loss.backward()
update(f_q.params)

然后就是本文核心的动量优化

# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params

其实就是让keyencoder也和queryencoder做相同方向的优化

7.更新字典

首先理解什么是字典,就是和什么比较的问题,这个字典就是我们用来和学习的内容比较的内容。这里其实就是实现了将这个batchsize的内容出队将新的这个batchsize进队。


# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch

你可能感兴趣的:(对比学习论文阅读记录,深度学习)