参考：https://github.com/megvii-research/CREStereo/issues/32
用下面的 get_correlation 替换模型代码中对应的方法即可（原实现在 batch size > 1 时会出错）。
粘贴到类中时注意缩进，方法体要与类中其他方法对齐，否则会报错：IndentationError: unexpected indent。
def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
    """Compute a local correlation volume between two feature maps.

    The right feature map is padded, then cut into one (H, W) window per
    displacement inside a ``psize`` search window whose samples are spaced
    by ``dilate``. Each window is multiplied element-wise with the left
    feature map and averaged over the channel axis.

    The upstream version merged the batch and displacement axes into
    ``(N * K, C, H, W)`` and reshaped the result to ``(1, -1, H, W)``,
    which only works for batch size 1. Here the batch axis is kept separate
    and ``left_feature`` is broadcast over the displacement axis instead.

    Args:
        left_feature: tensor of shape (N, C, H, W).
        right_feature: expected to have the same shape as ``left_feature``
            (the reshape below uses ``left_feature``'s N, C, H, W).
        psize: (height, width) of the search window.
        dilate: (y, x) spacing between sampled displacements.

    Returns:
        Tensor of shape (N, K, H, W), where
        K = (2 * (psize[0] // 2) + 1) * (2 * (psize[1] // 2) + 1),
        i.e. 9 for the default 3x3 window.
    """
    N, C, H, W = left_feature.shape
    di_y, di_x = dilate[0], dilate[1]
    pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

    # Pad H and W by the window radius (edge values replicated) so every
    # displacement yields a full (H, W) window.
    right_pad = F.pad(right_feature, pad_width=(
        (0, 0), (0, 0), (pady, pady), (padx, padx)), mode="replicate")

    # Output is (N, C, out_h, out_w, H, W), e.g. (N, C, 3, 3, H, W)
    # for the default 3x3 window.
    right_slid = F.sliding_window(
        right_pad, kernel_size=(H, W), stride=(di_y, di_x))
    right_slid = right_slid.reshape(N, C, -1, H, W)  # N, C, K, H, W
    right_slid = F.transpose(right_slid, (0, 2, 1, 3, 4))  # N, K, C, H, W

    # Broadcast left over the K displacements and average over channels.
    # Keeping N and K as separate axes is what makes batch sizes > 1 work.
    corr_final = F.mean(
        left_feature.reshape(N, 1, C, H, W) * right_slid, axis=2)  # N, K, H, W
    return corr_final