多模态学习中四种常用的跨模态特征融合方法定义与PyTorch实现

本文共介绍四种方法,分别是SumFusion、ConcatFusion、FiLM以及GatedFusion

FiLM 参考论文:FiLM: Visual Reasoning with a General Conditioning Layer(https://arxiv.org/pdf/1709.07871.pdf)

GatedFusion 参考论文:Efficient Large-Scale Multi-Modal Classification(https://arxiv.org/pdf/1802.02892.pdf)

import torch
import torch.nn as nn

#------------------------------------------#
# SumFusion: project each modality with its own
# linear layer, then add the two projections.
#------------------------------------------#
class SumFusion(nn.Module):
    """Fuse two feature tensors by summing per-modality linear projections.

    Args:
        input_dim: feature size of each input (``x`` and ``y`` alike).
        output_dim: feature size of the fused output.

    ``forward(x, y)`` returns the triple ``(x, y, output)``: the inputs are
    passed through untouched, and ``output = fc_x(x) + fc_y(y)``.
    """

    def __init__(self, input_dim=512, output_dim=100):
        super().__init__()
        # One independent projection per modality (x first, then y).
        self.fc_x = nn.Linear(input_dim, output_dim)
        self.fc_y = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        projected_x = self.fc_x(x)
        projected_y = self.fc_y(y)
        fused = projected_x + projected_y
        return x, y, fused

#------------------------------------------#
# ConcatFusion: concatenate the two features along
# dim 1, then feed the result to a single linear layer.
#------------------------------------------#
class ConcatFusion(nn.Module):
    """Fuse two feature tensors by concatenation followed by one linear layer.

    Args:
        input_dim: feature size of the concatenated tensor, i.e. the sum of
            the feature sizes of ``x`` and ``y``.
        output_dim: feature size of the fused output.

    ``forward(x, y)`` returns ``(x, y, output)`` with
    ``output = fc_out(cat([x, y], dim=1))``.
    """

    def __init__(self, input_dim=1024, output_dim=100):
        super().__init__()
        self.fc_out = nn.Linear(input_dim, output_dim)

    def forward(self, x, y):
        joined = torch.cat([x, y], dim=1)
        return x, y, self.fc_out(joined)

#------------------------------------------#
# FiLM fusion: a conditioning linear layer (fc) produces
# gamma/beta from one modality, the other modality is
# modulated as gamma * features + beta, and an output
# linear layer (fc_out) projects the result.
#------------------------------------------#
class FiLM(nn.Module):
    """
    FiLM: Visual Reasoning with a General Conditioning Layer,
    https://arxiv.org/pdf/1709.07871.pdf.

    Args:
        input_dim: feature size of the conditioning modality (input of ``fc``).
        dim: feature size of the modulated modality; also the width of each
            of ``gamma`` and ``beta``.
        output_dim: feature size of the fused output.
        x_film: if True, ``x`` conditions ``y``; otherwise ``y`` conditions ``x``.

    ``forward(x, y)`` returns ``(x, y, output)``.
    """
    def __init__(self, input_dim=512, dim=512, output_dim=100, x_film=True):
        super().__init__()
        # Width of each of gamma and beta. `fc` emits 2 * dim features, so
        # the split size must be `dim`. Using `input_dim` here broke the
        # split/unpack whenever input_dim != dim.
        self.dim = dim
        self.fc = nn.Linear(input_dim, 2 * dim)
        self.fc_out = nn.Linear(dim, output_dim)
        self.x_film = x_film

    def forward(self, x, y):
        """Modulate one modality with gamma/beta derived from the other, then project."""
        if self.x_film:
            film, to_be_film = x, y
        else:
            film, to_be_film = y, x

        # Split along the feature (last) axis, where nn.Linear writes its
        # output. For 2-D inputs this is identical to splitting along dim 1.
        gamma, beta = torch.split(self.fc(film), self.dim, dim=-1)

        output = gamma * to_be_film + beta
        output = self.fc_out(output)

        return x, y, output

#------------------------------------------#
# GatedFusion: project both modalities, derive a
# sigmoid gate from one projection, multiply it into
# the other, then apply an output linear layer.
#------------------------------------------#
class GatedFusion(nn.Module):
    """
    Efficient Large-Scale Multi-Modal Classification,
    https://arxiv.org/pdf/1802.02892.pdf.

    Args:
        input_dim: feature size of each input (``x`` and ``y`` alike).
        dim: width of the intermediate projections and of the gate.
        output_dim: feature size of the fused output.
        x_gate: if True the gate is computed from ``x`` and applied to ``y``;
            otherwise the roles are swapped.

    ``forward(x, y)`` returns ``(fc_x(x), fc_y(y), output)``. Unlike the other
    fusion modules, the first two entries are the projections, not the raw inputs.
    """

    def __init__(self, input_dim=512, dim=512, output_dim=100, x_gate=True):
        super().__init__()
        self.fc_x = nn.Linear(input_dim, dim)
        self.fc_y = nn.Linear(input_dim, dim)
        self.fc_out = nn.Linear(dim, output_dim)
        # Selects which modality's projection produces the gate.
        self.x_gate = x_gate
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        out_x = self.fc_x(x)
        out_y = self.fc_y(y)

        # Pick (gate source, gated features) according to the configured direction.
        gate_source, gated = (out_x, out_y) if self.x_gate else (out_y, out_x)
        gate = self.sigmoid(gate_source)
        output = self.fc_out(gate * gated)

        return out_x, out_y, output

你可能感兴趣的:(深度学习,pytorch,人工智能,音视频)