MoCo论文中的Algorithm 1伪代码解读
具体解读了什么东西
论文中提供的伪代码大约如下:
下面我将分步骤介绍这个代码干什么
1.query encoder和key encoder的参数初始化
其实也没表达什么就是一开始大家的参数是一样的:
f_k.params = f_q.params
2.之后就是loader当中取数据
这个也没啥的就是取出来数据的问题:
for x in loader: # load a minibatch x with N samples
3.数据增强
就是代码不是直接将内容输入其中,也会通过数据增强取出内容
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version
4.核心操作
首先我们先理解一下这个N和C是什么?
q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC
N其实是一个batch_size
C是一个输入数据的特征数,每个输入数据是一个1×C的张量
k = k.detach() # no gradient to keys
这个其实就是文章的主要创新点了,因为优化key_encoder是来自于query_encoder的优化。所以自然就不需要前传梯度,也能剩下个内存。
这里是矩阵乘法,理解一下这里的矩阵乘法:
# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
- 1.首先我们应当理解一下这个q和k到底是什么东西,可以看到q和k分别来自于x_q和x_k,我们注意这两个东西其实都来自于x只是作了不同的数据增强罢了。
好了,现在我们应该能判断出来,这里的x和k我们应该认为同一个类别。 - 2.l_pos 现在我们就知道这个东西应该是一个N*1的一组接近1的数值
- 3.我们注意queue是我们存储的之前的batch的内容,所以这个东西和我们当前这个batch的内容应该是没有任何交集的,也就是他们来自于不同的内容,按照对比学习的思想,来自不同事物的内容应该完全不相交。所以他们的相似度应该尽量的低。
- 4.l_neg应当得到一个N*K的一组接近0的数值。
- 5.logits的内容就自然而然出现了,应该为一个N*(K+1)的内容,这些内容应该具有下面的特点:K+1的向量除了第一位接近1之外其他都应该接近0。
- 6.在现在的情况下我们自然而然可以得出一个内容就是,每个(K+1)的张量经过softmax之后,模型都应该判别其为正确。也就是所有的N个张量都是0号分类。
5.交叉熵loss
这里其实不能完全的算成交叉熵损失函数,这个是一个带有热度的交叉熵损失函数。但是其实我们可以将其想成交叉熵函数来简化理解:
# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)
之前我们谈过了,这里的所有内容都应该是第0个分类,所以我们这里直接让所有的分类都是第0分类就完事了。
模型更新
下面是很正常的backward
# SGD update: query network
loss.backward()
update(f_q.params)
然后就是本文核心的动量优化
# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params
其实就是让keyencoder也和queryencoder做相同方向的优化
7.更新字典
首先理解什么是字典,就是和什么比较的问题,这个字典就是我们用来和学习的内容比较的内容。这里其实就是实现了将这个batchsize的内容出队将新的这个batchsize进队。
# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch
MoCo论文中的Algorithm 1伪代码解读
具体解读了什么东西
论文中提供的伪代码大约如下:
下面我将分步骤介绍这个代码干什么
1.query encoder和key encoder的参数初始化
其实也没表达什么就是一开始大家的参数是一样的:
f_k.params = f_q.params
2.之后就是loader当中取数据
这个也没啥的就是取出来数据的问题:
for x in loader: # load a minibatch x with N samples
3.数据增强
就是代码不是直接将内容输入其中,也会通过数据增强取出内容
x_q = aug(x) # a randomly augmented version
x_k = aug(x) # another randomly augmented version
4.核心操作
首先我们先理解一下这个N和C是什么?
q = f_q.forward(x_q) # queries: NxC
k = f_k.forward(x_k) # keys: NxC
N其实是一个batch_size
C是一个输入数据的特征数,每个输入数据是一个1×C的张量
k = k.detach() # no gradient to keys
这个其实就是文章的主要创新点了,因为优化key_encoder是来自于query_encoder的优化。所以自然就不需要前传梯度,也能剩下个内存。
这里是矩阵乘法,理解一下这里的矩阵乘法:
# positive logits: Nx1
l_pos = bmm(q.view(N,1,C), k.view(N,C,1))
# negative logits: NxK
l_neg = mm(q.view(N,C), queue.view(C,K))
# logits: Nx(1+K)
logits = cat([l_pos, l_neg], dim=1)
- 1.首先我们应当理解一下这个q和k到底是什么东西,可以看到q和k分别来自于x_q和x_k,我们注意这两个东西其实都来自于x只是作了不同的数据增强罢了。
好了,现在我们应该能判断出来,这里的x和k我们应该认为同一个类别。 - 2.l_pos 现在我们就知道这个东西应该是一个N*1的一组接近1的数值
- 3.我们注意queue是我们存储的之前的batch的内容,所以这个东西和我们当前这个batch的内容应该是没有任何交集的,也就是他们来自于不同的内容,按照对比学习的思想,来自不同事物的内容应该完全不相交。所以他们的相似度应该尽量的低。
- 4.l_neg应当得到一个N*K的一组接近0的数值。
- 5.logits的内容就自然而然出现了,应该为一个N*(K+1)的内容,这些内容应该具有下面的特点:K+1的向量除了第一位接近1之外其他都应该接近0。
- 6.在现在的情况下我们自然而然可以得出一个内容就是,每个(K+1)的张量经过softmax之后,模型都应该判别其为正确。也就是所有的N个张量都是0号分类。
5.交叉熵loss
这里其实不能完全的算成交叉熵损失函数,这个是一个带有热度的交叉熵损失函数。但是其实我们可以将其想成交叉熵函数来简化理解:
# contrastive loss, Eqn.(1)
labels = zeros(N) # positives are the 0-th
loss = CrossEntropyLoss(logits/t, labels)
之前我们谈过了,这里的所有内容都应该是第0个分类,所以我们这里直接让所有的分类都是第0分类就完事了。
模型更新
下面是很正常的backward
# SGD update: query network
loss.backward()
update(f_q.params)
然后就是本文核心的动量优化
# momentum update: key network
f_k.params = m*f_k.params+(1-m)*f_q.params
其实就是让keyencoder也和queryencoder做相同方向的优化
7.更新字典
首先理解什么是字典,就是和什么比较的问题,这个字典就是我们用来和学习的内容比较的内容。这里其实就是实现了将这个batchsize的内容出队将新的这个batchsize进队。
# update dictionary
enqueue(queue, k) # enqueue the current minibatch
dequeue(queue) # dequeue the earliest minibatch