多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)

文章目录

    • 写在前面
    • 简单的concat
    • TFN融合策略
    • LWF融合策略

论文全称:
《Tensor Fusion Network for Multimodal Sentiment Analysis》
《Efficient Low-rank Multimodal Fusion with Modality-Specific Factors》

写在前面

最近在做一个分类的比赛,想要用上数据中的多模态信息(主要是文本和图像特征),因此探索了一些多模态特征的融合机制,并记录下来。

下文中均以3种不同模态下的特征融合为例。并设A模态特征维度为512,B模态特征维度为1024,C模态特征维度为32

import torch
A = torch.randn(16, 512)
B = torch.randn(16, 1024)
C = torch.randn(16, 32)

简单的concat

concat既是最简单也是最常用的一种方式,直接在特征维度将不同模态特征进行拼接后,再送入后续的推理模块。

fusion_feature = torch.cat([A, B, C], dim=1)

TFN融合策略

原理简述

TFN来自17年EMNLP会议论文《Tensor Fusion Network for Multimodal Sentiment Analysis》,其主要考虑了inter-modalityintar-modality两个方面。也就是要求既能考虑各模态之间的特征融合,也要有效地利用各特定模态的特征。

多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第1张图片
图左为Early Fusion策略,其实就是之前提到的concat方法,图右展现了作者提出的TFN模块(Tensor Fusion Network)。具体做法就是首先对每个模态用1进行维度扩充,然后对不同模态求笛卡尔积

以两个模态为例,对 z v , z l z_v,z_l zv,zl1先扩充一维,得到后的特征再进行outer product(外积,张量积)。可以看到,用1扩充后,即计算了两个模态间的特征相关性,又保留了特定模态的信息。
多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第2张图片
同理,对三个模态求得了笛卡尔积后( [ z a ; 1 ] ⨂ [ z b ; 1 ] ⨂ [ z c ; 1 ] [z_a; 1] \bigotimes [z_b; 1] \bigotimes [z_c; 1] [za;1][zb;1][zc;1]),即计算了两两模态间的特征、三模态间的特征,又保留了各特定模态中的特征(见上图的Tensor Fusion细节)。

n = A.shape[0]
# 用 1 扩充维度
A = torch.cat([A, torch.ones(n, 1)], dim=1)
B = torch.cat([B, torch.ones(n, 1)], dim=1)
C = torch.cat([C, torch.ones(n, 1)], dim=1)
# 计算笛卡尔积
A = A.unsqueeze(2)  # [n, A, 1]
B = B.unsqueeze(1)  # [n, 1, B]
fusion_AB = torch.einsum('nxt, nty->nxy', A, B)  # [n, A, B]
fusion_AB = fusion_AB.flatten(start_dim=1).unsqueeze(1) # [n, AxB, 1]
C = C.unsqueeze(1) # [n, 1, C]
fusion_ABC = torch.einsum('ntx, nty->nxy', fusion_AB, C) # [n, AxB, C]
fusion_ABC = fusion_ABC.flatten(start_dim=1)  # [n, AxBxC]
# A, B, C分别代表原来的特征维度nA,nB,nC加上1

需要注意的是,实际编程实现时并未直接计算得到3-D的笛卡尔积,而是分别两两计算outer product

LWF融合策略

上面提到的TFN对计算了两/三模态间的相关性,也保留了单模态的相关性,但同时也大大地增加了特征维度。增加特征维度从而会影响计算效率以及增加内存消耗,并且TFN所增加的时间/空间复杂度都与输入模态数呈指数增加。并且参数量一多,就容易增加过拟合的风险。

LMF是发表于ACL2017年的工作,针对TFN的上述问题,作者采用了low-rank weight进行多模态融合,降低参数量的同时还提升了计算速度。

建议先看看这篇博客:LWF论文解读

TFN中的融合后的特征Z维度为d1xd2xd3x....dm,其中m表示模态数,i模态特征维度为di。后续要将其送入推理模块中,通常需要降到h维的特征F,此时需要一个维度为(d1xd2xd3x....dm)xh的(M+1阶)权重W进行全连接操作。

全连接操作中,W可以视为h个M阶矩阵,每个矩阵与融合特征Z计算后的结果为F中的一维。

LMF要做的是就是将W分解成M组与各模态相关low-rank因子。按照上述的视角,将W视为h个矩阵,每个特征矩阵Wk如下所示,其中使得分解成立的最小R称为秩(Rank)。
在这里插入图片描述
在LMF中,人为设定固定的秩r,得到每个Wk矩阵了,对特征矩阵进行重新排列,使其变为与模态m相关的特征Wm
在这里插入图片描述
为了更好地理解排列过程,我画了一张图,展示了3个模态时,秩为r,期望维度为h的情况:
多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第3张图片

多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第4张图片

那么对特征变换(Zd维特征)的过程可以拆分为如下过程:
多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第5张图片
多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第6张图片

Z本身也是由不同模态的外积得到的,那么组合起来可得到下式。
多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第7张图片
其中 Λ \Lambda Λ表示像素级点乘。这样分解之后,避免了从各模态特征Zm去建模Z,并且可以扩展到不同数量的模态上,大大降低了时间复杂度。以3模态的融合为例,图例如下:

多模态特征融合机制(含代码):TFN(Tensor Fusion Network)和LMF(Low-rank Multimodal Fusion)_第8张图片
从上图可知,最后的由多模态特征Zm融合成h维特征的过程就变成了:每个模态分别构建r个权重矩阵,融合后对各模态特征进行矩阵乘法,得到一个h维的特征;然后再将各模态得到的h维特征进行像素级乘法即可。代码如下:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

A = torch.randn(16, 512)
B = torch.randn(16, 1024)
C = torch.randn(16, 32)

n = A.shape[0]
A = torch.cat([A, torch.ones(n, 1)], dim=1)
B = torch.cat([B, torch.ones(n, 1)], dim=1)
C = torch.cat([C, torch.ones(n, 1)], dim=1)

# 假设所设秩: R = 4, 期望融合后的特征维度: h = 128
R, h = 4, 128
Wa = Parameter(torch.Tensor(R, A.shape[1], h))
Wb = Parameter(torch.Tensor(R, B.shape[1], h))
Wc = Parameter(torch.Tensor(R, C.shape[1], h))
Wf = Parameter(torch.Tensor(1, R))
bias = Parameter(torch.Tensor(1, h))

# 分解后,并行提取各模态特征
fusion_A = torch.matmul(A, Wa)
fusion_B = torch.matmul(B, Wb)
fusion_C = torch.matmul(C, Wc)

# 利用一个Linear再进行特征融合(融合R维度)
funsion_ABC = fusion_A * fusion_B * fusion_C
funsion_ABC = torch.matmul(Wf, funsion_ABC.permute(1,0,2)).squeeze() + bias

你可能感兴趣的:(快乐ML/DL,深度学习)