论文的主要贡献在我看来有两个:
已知卷积神经网络在具有足够的标记数据的情况下非常擅长学习输入输出关系。因此,采用端到端的学习方法来预测光流:
给定一个由图像对和光流组成的数据集,我们训练网络以直接从图像中预测x-y的光流场。但是,为此目的,好的架构是什么?
一个简单的选择是将两个输入图像堆叠在一起,并将其输入一个相当通用的网络,让网络自行决定如何处理图像对以提取运动信息。这种仅包含卷积层的架构称为“FlowNetSimple”:
炼丹兄简单讲网络结构:
我们来看refinement部分,其实这个部分跟Unet也有些类似,但是又有独特的光流模型的特性。
class FlowNetS(nn.Module):
    """FlowNetSimple: a convolutional encoder followed by a flow refinement stage.

    The input is fed straight into ``conv1``, which is built for 6 input
    channels (presumably two RGB images stacked along dim 1 -- confirm against
    the data loader).  The encoder (``conv1`` .. ``conv6_1``) is followed by a
    refinement stage that, at each level, concatenates along dim 1:

    * the encoder output of that level (skip connection),
    * the ``deconv`` output coming from the coarser level,
    * the upsampled flow predicted at the coarser level.

    ``conv``, ``deconv``, ``predict_flow`` and ``crop_like`` are helpers defined
    outside this class; the channel counts below (1026, 770, 386, 194) are
    consistent with them producing ``skip + deconv + 2`` channels per level
    (e.g. 512 + 512 + 2 = 1026) -- verify against the helper definitions.
    """

    expansion = 1

    def __init__(self, batchNorm=True):
        """Build all layers and initialise their parameters.

        Args:
            batchNorm: forwarded as the first argument to every ``conv`` helper
                call (presumably toggles batch normalisation there).
        """
        super(FlowNetS, self).__init__()

        self.batchNorm = batchNorm

        # Encoder.
        self.conv1 = conv(self.batchNorm, 6, 64, kernel_size=7, stride=2)
        self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
        self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
        self.conv3_1 = conv(self.batchNorm, 256, 256)
        self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512, 512)
        self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512, 512)
        self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)

        # Refinement: feature upsampling between levels.
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # Refinement: one flow prediction head per level.
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)

        # Refinement: 2-channel transposed convolutions (kernel 4, stride 2,
        # padding 1) that upsample a predicted flow to the next finer level.
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Parameter initialisation for every (transposed) convolution and
        # batch-norm layer registered above.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # 0.1 is the second positional argument of kaiming_normal_
                # (the rectifier negative slope `a`).
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            In training mode, the tuple ``(flow2, flow3, flow4, flow5, flow6)``
            ordered from the finest to the coarsest prediction; otherwise only
            ``flow2``.
        """
        # Encoder.
        out_conv2 = self.conv2(self.conv1(x))
        out_conv3 = self.conv3_1(self.conv3(out_conv2))
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # Level 6 -> 5.  crop_like is applied so the upsampled tensors can be
        # concatenated with the skip connection.
        flow6 = self.predict_flow6(out_conv6)
        flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)

        # Level 5 -> 4.
        flow5 = self.predict_flow5(concat5)
        flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)

        # Level 4 -> 3.
        flow4 = self.predict_flow4(concat4)
        flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)

        # Level 3 -> 2.
        flow3 = self.predict_flow3(concat3)
        flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2)
        concat2 = torch.cat((out_conv2, out_deconv2, flow3_up), 1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2

    def weight_parameters(self):
        """Return every parameter whose qualified name contains ``'weight'``."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return every parameter whose qualified name contains ``'bias'``."""
        return [param for name, param in self.named_parameters() if 'bias' in name]
import torch
import torch.nn.functional as F
def EPE(input_flow, target_flow, sparse=False, mean=True):
    """Compute the endpoint error (EPE) between a predicted and a target flow.

    The per-position error is the L2 norm, taken over dim 1, of
    ``target_flow - input_flow``.

    Args:
        input_flow: predicted flow tensor.
        target_flow: reference flow tensor, compatible in shape with
            ``input_flow``.
        sparse: when true, positions where both ``target_flow[:, 0]`` and
            ``target_flow[:, 1]`` are exactly 0 are treated as invalid and
            left out of the reduction.
        mean: when true, return the mean error over the remaining positions;
            otherwise return the summed error divided by the batch size
            (the size of dim 0).

    Returns:
        A 0-dim tensor holding the reduced error.
    """
    EPE_map = torch.norm(target_flow - input_flow, p=2, dim=1)
    # Taken before masking: boolean indexing below flattens EPE_map.
    batch_size = EPE_map.size(0)
    if sparse:
        # invalid flow is defined with both flow coordinates to be exactly 0
        mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
        EPE_map = EPE_map[~mask]
    if mean:
        return EPE_map.mean()
    else:
        return EPE_map.sum() / batch_size
def multiscaleEPE(network_output, target_flow, weights=None, sparse=False):
    """Weighted sum of endpoint errors over several prediction scales.

    Args:
        network_output: a single flow tensor, or a tuple/list of flow tensors
            (one per scale).  A single tensor is wrapped into a one-element
            list.
        target_flow: reference flow; it is resized to each output's spatial
            size before the error is computed.
        weights: one weight per entry of ``network_output``.  Defaults to the
            five values ``[0.005, 0.01, 0.02, 0.08, 0.32]``.
        sparse: forwarded to ``EPE``; also selects ``sparse_max_pool`` instead
            of area interpolation for resizing the target.

    Returns:
        The accumulated weighted loss.

    Raises:
        AssertionError: if ``weights`` and ``network_output`` differ in length.
    """

    def one_scale(output, target, sparse):
        # Bring the target to this output's (h, w), then take the
        # batch-normalised summed EPE (mean=False).
        _, _, h, w = output.size()
        if sparse:
            # sparse_max_pool is defined outside this block.
            target_scaled = sparse_max_pool(target, (h, w))
        else:
            target_scaled = F.interpolate(target, (h, w), mode='area')
        return EPE(output, target_scaled, sparse, mean=False)

    if not isinstance(network_output, (tuple, list)):
        network_output = [network_output]
    if weights is None:
        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
    assert len(weights) == len(network_output), \
        "expected one weight per network output"

    loss = 0
    for output, weight in zip(network_output, weights):
        loss += weight * one_scale(output, target_flow, sparse)
    return loss
这里和之前的simple版本的区别在于:先用相同的网络(类似于孪生网络)分别对两张图片提取特征;然后对得到的两个特征图做论文中提出的correlation(相关)处理,融合成一个特征图;最后再做与simple版本类似的后续处理。
这里直接看模型代码:
class FlowNetC(nn.Module):
    """FlowNetCorr: shared feature extraction, correlation, then refinement.

    ``forward`` splits the input along dim 1 into its first 3 channels and the
    remaining channels, runs both halves through the same ``conv1``..``conv3``
    layers, and merges them with ``correlate``.  The correlation output is
    concatenated with a 1x1 ``conv_redir`` projection of the first half's
    features and passed on to ``conv3_1``.  From there the encoder and the
    refinement stage have the same layout as in FlowNetS.

    ``conv``, ``deconv``, ``predict_flow``, ``crop_like`` and ``correlate`` are
    helpers defined outside this class.  ``conv3_1`` takes 473 input channels,
    which is consistent with 32 (``conv_redir``) + 441 correlation channels --
    verify against ``correlate``.
    """

    expansion = 1

    def __init__(self, batchNorm=True):
        """Build all layers and initialise their parameters.

        Args:
            batchNorm: forwarded as the first argument to every ``conv`` helper
                call (presumably toggles batch normalisation there).
        """
        super(FlowNetC, self).__init__()

        self.batchNorm = batchNorm

        # Feature extractor applied to each image separately (shared weights).
        self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2)
        self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
        self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
        # 1x1 projection of the first image's features, concatenated with the
        # correlation output in forward().
        self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1)

        # Encoder after the correlation step.
        self.conv3_1 = conv(self.batchNorm, 473, 256)
        self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
        self.conv4_1 = conv(self.batchNorm, 512, 512)
        self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
        self.conv5_1 = conv(self.batchNorm, 512, 512)
        self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
        self.conv6_1 = conv(self.batchNorm, 1024, 1024)

        # Refinement: feature upsampling between levels.
        self.deconv5 = deconv(1024, 512)
        self.deconv4 = deconv(1026, 256)
        self.deconv3 = deconv(770, 128)
        self.deconv2 = deconv(386, 64)

        # Refinement: one flow prediction head per level.
        self.predict_flow6 = predict_flow(1024)
        self.predict_flow5 = predict_flow(1026)
        self.predict_flow4 = predict_flow(770)
        self.predict_flow3 = predict_flow(386)
        self.predict_flow2 = predict_flow(194)

        # Refinement: 2-channel transposed convolutions (kernel 4, stride 2,
        # padding 1) that upsample a predicted flow to the next finer level.
        self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
        self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)

        # Parameter initialisation for every (transposed) convolution and
        # batch-norm layer registered above.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                # 0.1 is the second positional argument of kaiming_normal_
                # (the rectifier negative slope `a`).
                kaiming_normal_(m.weight, 0.1)
                if m.bias is not None:
                    constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                constant_(m.weight, 1)
                constant_(m.bias, 0)

    def forward(self, x):
        """Run the network on ``x``.

        Returns:
            In training mode, the tuple ``(flow2, flow3, flow4, flow5, flow6)``
            ordered from the finest to the coarsest prediction; otherwise only
            ``flow2``.
        """
        # Split the stacked input into the two images along dim 1.
        x1 = x[:, :3]
        x2 = x[:, 3:]

        # Same layers (hence same weights) for both images.
        out_conv1a = self.conv1(x1)
        out_conv2a = self.conv2(out_conv1a)
        out_conv3a = self.conv3(out_conv2a)

        out_conv1b = self.conv1(x2)
        out_conv2b = self.conv2(out_conv1b)
        out_conv3b = self.conv3(out_conv2b)

        # Fuse the two feature maps: projected features of image 1 plus the
        # correlation between both feature maps.
        out_conv_redir = self.conv_redir(out_conv3a)
        out_correlation = correlate(out_conv3a, out_conv3b)
        in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)

        # Encoder after fusion.
        out_conv3 = self.conv3_1(in_conv3_1)
        out_conv4 = self.conv4_1(self.conv4(out_conv3))
        out_conv5 = self.conv5_1(self.conv5(out_conv4))
        out_conv6 = self.conv6_1(self.conv6(out_conv5))

        # Level 6 -> 5.  crop_like is applied so the upsampled tensors can be
        # concatenated with the skip connection.
        flow6 = self.predict_flow6(out_conv6)
        flow6_up = crop_like(self.upsampled_flow6_to_5(flow6), out_conv5)
        out_deconv5 = crop_like(self.deconv5(out_conv6), out_conv5)
        concat5 = torch.cat((out_conv5, out_deconv5, flow6_up), 1)

        # Level 5 -> 4.
        flow5 = self.predict_flow5(concat5)
        flow5_up = crop_like(self.upsampled_flow5_to_4(flow5), out_conv4)
        out_deconv4 = crop_like(self.deconv4(concat5), out_conv4)
        concat4 = torch.cat((out_conv4, out_deconv4, flow5_up), 1)

        # Level 4 -> 3.
        flow4 = self.predict_flow4(concat4)
        flow4_up = crop_like(self.upsampled_flow4_to_3(flow4), out_conv3)
        out_deconv3 = crop_like(self.deconv3(concat4), out_conv3)
        concat3 = torch.cat((out_conv3, out_deconv3, flow4_up), 1)

        # Level 3 -> 2.  The level-2 skip connection uses the first image's
        # conv2 features.
        flow3 = self.predict_flow3(concat3)
        flow3_up = crop_like(self.upsampled_flow3_to_2(flow3), out_conv2a)
        out_deconv2 = crop_like(self.deconv2(concat3), out_conv2a)
        concat2 = torch.cat((out_conv2a, out_deconv2, flow3_up), 1)
        flow2 = self.predict_flow2(concat2)

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        else:
            return flow2

    def weight_parameters(self):
        """Return every parameter whose qualified name contains ``'weight'``."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return every parameter whose qualified name contains ``'bias'``."""
        return [param for name, param in self.named_parameters() if 'bias' in name]
里面的关键在这个部分:
# Excerpt from FlowNetC.forward: the step that fuses the two images' features.
# Pass the first image's conv3 features through the `conv_redir` layer.
out_conv_redir = self.conv_redir(out_conv3a)
# Correlate the conv3 features of both images (`correlate` is defined elsewhere).
out_correlation = correlate(out_conv3a,out_conv3b)
# Concatenate both results along dim=1; this is the input to conv3_1.
in_conv3_1 = torch.cat([out_conv_redir, out_correlation], dim=1)
from spatial_correlation_sampler import spatial_correlation_sample
但是 spatial_correlation_sampler 这个库并没有在代码中提供,所以关于这个版本的flownet,我也就此作罢。我猜测这个模块是作者引用别人的代码,应该在github主页有说明,但是我这里上github太卡了,回头有空再补充这个知识点吧。(不过一般也没有什么人看文章哈哈,没人问我的话,那我就忽视这个坑了2333)