CREStereo先前记录了大概框架,现在看看有些存在感很强的模块具体的逻辑。
论文阅读
源码
论文
上一篇博客
经过框架大概阅读,列举了这些需要深入分析的模块:
LocalFeatureTransformer, AGCL, BasicUpdateBlock, convex_upsample。注意力应该会和AGCL在一块儿看。
LoFTR阅读
LoFTR参考源码
还是先从代码角度看,不是很理解的看别人的论文阅读吧。直接从forward开始。可以看到这个层单纯的是自注意力和交叉注意力的任意组合,具体组合方式看self.layer是怎么定义的。
for layer, name in zip(self.layers, self.layer_names):
if name == "self":
feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1)
elif name == "cross":
feat0 = layer(feat0, feat1, mask0, mask1)
feat1 = layer(feat1, feat0, mask1, mask0)
else:
raise KeyError
于是看向初始化部分。可以看到套多少个子模块连接是由layer_names决定。这个子模块是由 LoFTREncoderLayer(d_model, nhead, attention)
定义,于是接着把代码往上翻,看这个模块具体内容。
def __init__(self, d_model, nhead, layer_names, attention):
super(LocalFeatureTransformer, self).__init__()
self.d_model = d_model
self.nhead = nhead
self.layer_names = layer_names
encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
self.layers = [
copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
]
self._reset_parameters()
还是从forward开始看,这边就有点难对应了,结合论文来看。本来想顺着说一遍,读着读着发现基本一致,就看这边论文的描述吧。
def forward(self, x, source, x_mask=None, source_mask=None):
bs = x.shape[0]
query, key, value = x, source, source
# multi-head attention
query = F.reshape(
self.q_proj(query), (bs, -1, self.nhead, self.dim)
) # [N, L, (H, D)] (H=8, D=256//8)
key = F.reshape(
self.k_proj(key), (bs, -1, self.nhead, self.dim)
) # [N, S, (H, D)]
value = F.reshape(self.v_proj(value), (bs, -1, self.nhead, self.dim))
message = self.attention(
query, key, value, q_mask=x_mask, kv_mask=source_mask
) # [N, L, (H, D)]
message = self.merge(
F.reshape(message, (bs, -1, self.nhead * self.dim))
) # [N, L, C]
message = self.norm1(message)
# feed-forward network
message = self.mlp(F.concat([x, message], axis=2))
message = self.norm2(message)
return x + message
PS:在LinearAttention里面有个小小的可以优化的小地方,基本对速度没啥帮助哈哈。就是megengine好像没有原生的elu算子,他是这么表示的。我左瞅右瞅,感觉这边加减1可以消掉,elu_feature_map注释掉的部分是我的写法。
PPS:其实这么看下来,感觉transformer架构很特别,但也只是很特别,感觉reshape之类的冗余操作太多有点影响计算速度。是不是可以从一些角度对此进行一些修改?顺带代码中用的是Linear Transformer,我不是很了解一些变体,看到这么一篇文章,是不是可以试试这里面提到的Performer?
def elu(x, alpha=1.0):
return F.maximum(0, x) + F.minimum(0, alpha * (F.exp(x) - 1))
def elu_feature_map(x):
return elu(x) + 1
# return F.relu(x) + F.minimum(1, F.exp(x))
好家伙,这玩意儿好长,不想看。
先从__call__开始阅读。简单来看就是对应上这两种模块。然后对应以下具体内容。
def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
if iter_mode:
corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
else:
corr = self.corr_att_offset(
self.fmap1, self.fmap2, flow, extra_offset, small_patch
)
return corr
首先是corr_iter。self.coords是模块初始化时生成的相当于是点坐标集,加上flow意思就是新的坐标集,然后从右图按坐标集采点。从这边的small_patch就可以看出他的搜索方式是可选择的1D或2D搜索,但是保证是9个点,从而保证计算的一致性。然后将左右特征按照特征通道数4分割,self.get_correlation就是在psize_list[i]范围内,执行9次相关性计算然后拼接出一个相关性结果,最后四个部分拼接成最后输出的结果。
这边为什么要先分出4部分再融合呢?不知道是处于精简计算的关系还是和输入内容有关,之后要更细致的分析。
def corr_iter(self, left_feature, right_feature, flow, small_patch):
coords = self.coords + flow
coords = F.transpose(coords, (0, 2, 3, 1))
right_feature = bilinear_sampler(right_feature, coords)
if small_patch:
psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
else:
psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
N, C, H, W = left_feature.shape
lefts = F.split(left_feature, 4, axis=1)
rights = F.split(right_feature, 4, axis=1)
corrs = []
for i in range(len(psize_list)):
corr = self.get_correlation(
lefts[i], rights[i], psize_list[i], dilate_list[i]
)
corrs.append(corr)
final_corr = F.concat(corrs, axis=1)
return final_corr