# This file implements four fusion methods: SumFusion, ConcatFusion, FiLM and GatedFusion.
# FiLM follows the paper "FiLM: Visual Reasoning with a General Conditioning Layer".
# GatedFusion follows the paper "Efficient Large-Scale Multi-Modal Classification".
import torch
import torch.nn as nn
#------------------------------------------#
# SumFusion: pass each modality through its own
# fully-connected layer, then add the results.
#------------------------------------------#
class SumFusion(nn.Module):
    """Additive late fusion.

    Each modality is projected with its own linear layer and the two
    projections are summed element-wise to form the fused output.
    """

    def __init__(self, input_dim=512, output_dim=100):
        super(SumFusion, self).__init__()
        # One independent projection per modality; both map
        # input_dim -> output_dim so the results can be added.
        self.fc_x = nn.Linear(input_dim, output_dim)
        self.fc_y = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        # Fuse by summing the two projected feature tensors.
        fused = self.fc_x(x) + self.fc_y(y)
        return x, y, fused
#------------------------------------------#
# ConcatFusion: concatenate the two feature
# tensors, then send the stacked vector through
# a single fully-connected layer.
#------------------------------------------#
class ConcatFusion(nn.Module):
    """Concatenation late fusion.

    The two modality features are stacked along the channel axis
    (dim=1) and projected by a single linear layer, so `input_dim`
    must equal the sum of both feature sizes.
    """

    def __init__(self, input_dim=1024, output_dim=100):
        super(ConcatFusion, self).__init__()
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        stacked = torch.cat((x, y), dim=1)
        return x, y, self.fc_out(stacked)
#------------------------------------------#
# FiLM: one modality is mapped to per-channel
# (gamma, beta) coefficients that modulate the
# other modality via gamma * h + beta.
#------------------------------------------#
class FiLM(nn.Module):
    """
    FiLM: Visual Reasoning with a General Conditioning Layer,
    https://arxiv.org/pdf/1709.07871.pdf.

    The conditioning modality is mapped to per-channel (gamma, beta)
    coefficients that scale and shift the other modality:
    output = gamma * h + beta, followed by a linear classifier.

    Args:
        input_dim: feature size of the conditioning modality.
        dim: feature size of the modality being modulated.
        output_dim: size of the classifier output.
        x_film: if True, x conditions y; otherwise y conditions x.
    """
    def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
        super(FiLM, self).__init__()
        # BUGFIX: the split size must be `dim` — half of the fc output
        # (2 * dim) — not `input_dim`. The original `self.dim = input_dim`
        # only worked by accident when input_dim == dim; otherwise
        # torch.split produced the wrong number/size of chunks.
        self.dim = dim
        # Maps the conditioning features to concatenated [gamma, beta].
        self.fc = nn.Linear(input_dim, 2 * dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_film = x_film

    def forward(self, x, y):
        # Select which modality conditions (film) and which is modulated.
        if self.x_film:
            film = x
            to_be_film = y
        else:
            film = y
            to_be_film = x
        # Split the 2*dim projection into gamma and beta, each of size dim.
        gamma, beta = torch.split(self.fc(film), self.dim, 1)
        output = gamma * to_be_film + beta
        output = self.fc_out(output)
        return x, y, output
#------------------------------------------#
# GatedFusion: one modality produces a sigmoid
# gate that multiplicatively filters the other.
#------------------------------------------#
class GatedFusion(nn.Module):
    """
    Efficient Large-Scale Multi-Modal Classification,
    https://arxiv.org/pdf/1802.02892.pdf.

    One modality's projection is squashed through a sigmoid to form a
    gate that multiplicatively filters the other modality's projection
    before the final classifier.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
        super(GatedFusion, self).__init__()
        self.fc_x = nn.Linear(input_dim, dim)
        self.fc_y = nn.Linear(input_dim, dim)
        self.fc_out = nn.Linear(dim, output_dim)
        # When True, x produces the gate and y is gated; otherwise swapped.
        self.x_gate = x_gate
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        proj_x = self.fc_x(x)
        proj_y = self.fc_y(y)
        if self.x_gate:
            gated = self.sigmoid(proj_x) * proj_y
        else:
            gated = proj_x * self.sigmoid(proj_y)
        return proj_x, proj_y, self.fc_out(gated)