学习资料:
MOCO论文精读-哔哩哔哩
动量对比学习方法做无监督表征学习。
使用动量的方式更新编码器,使用队列的方法存储memory bank这个字典,从而获得一个又大又一致的字典。
把memory bank这个字典当做一个队列进行维护
加权移动平均。让当前输出不完全依赖于当前输入,而是也取决于前一时刻的输出。
y t = m ∗ y t − 1 + ( 1 − m ) ∗ x t y_t = m*y_{t-1}+(1-m)*x_t yt=m∗yt−1+(1−m)∗xt
当 m m m趋近于1的时候, y t y_t yt改变的非常缓慢,因为它基本就相当于前一刻的输出, x t x_t xt基本不起作用。
当 m m m趋近于0的时候, y t y_t yt改变的非常迅速,因为前一刻的输出基本不起作用,基本取决于当前输入 x t x_t xt。
moco利用动量的特性,缓慢更新编码器。让字典中的特征尽可能保持一致。
原图x1经过数据增广后产生的两张图片互为正样本对,数据集中除了x1以外的其他图片都可以为负样本。
把对比学习当成动态字典。
These methods can be thought of as building dynamic dictionaries. The “keys” (tokens) in the dictionary are sampled from data (e.g., images or patches) and are represented by an encoder network. Unsupervised learning trains encoders to perform dictionary look-up: an encoded “query” should be similar to its matching key and dissimilar to others. Learning is formulated as minimizing a contrastive loss.
在论文中,作者使用:
使用队列可以使字典变大,但它也使得通过反向传播(梯度应该传播到队列中的所有样本)更新编码器变得难以处理。
一个解决方案是直接把 f q f_q fq的参数直接复制给 f k f_k fk,而不进行梯度更新,但是这样效果很差,因为这样就会失去 f k f_k fk编码器对特征提取的一致性,由于编码器的参数一直在改变,所以一开始提取的特征就会和后面提取的特征很不一致,所以对 f k f_k fk编码器采用动量更新。
动量更新通过逐步调整编码器参数,使 f k f_k fk编码器逐渐学习到 f q f_q fq编码器的知识。这种更新策略有助于提高特征表示的一致性和稳定性,使得编码器能够更好地捕捉数据的结构和语义信息。
一个正样本对:anchor和positive,anchor为q(anchor经过编码器后生成的特征向量),在字典中与q唯一配对的key成为 k + k_+ k+(positive经过编码器后生成的特征向量)
对于目标函数来说,当 q q q与它的 k + k_+ k+相似而与所有其他key(q的负样本)不相似时,它的值应该很低。
对于损失函数来说,如果使用交叉熵作为损失函数,那么公式中的k值就会非常大,因为k代表的是类别个数,在个体判别领域,一张图片代表一个类,那么k值就等于数据集的大小,这样计算量过大。
为解决交叉熵损失函数中对类别求概率时k值过大的问题,提出了NCE loss(noise contractive )
NCE loss(noise contractive estimation) 可以把问题转换为只有以下两个类别的二分类问题:
其实NCE loss就类似于一个交叉熵损失函数,做了一个k+1类的分类任务,目的就是把q分类为 k + k_+ k+类。
MOCO详解-知乎
更新的有两个:encoder和队列
encode_p是根据梯度回传进行更新的,encode_k是通过动量进行缓慢 更新的。
# f_q, f_k: encoder networks for query and key
# queue: dictionary as a queue of K keys (CxK)
# m: momentum
# t: temperature
f_k.params = f_q.params # initialize
for x in loader: # load a minibatch x with N samples
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version
q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC
k = k.detach() # no gradient to keys
# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)
# SGD update: query network
loss.backward() update(f_q.params)
# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params
# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch