Efficient Token-Guided Image-Text Retrieval withConsistent Multimodal Contrastive Training

paper: https://arxiv.org/pdf/2306.08789.pdf

code: https://github.com/LCFractal/TGDT

1. 论文核心思想

  • 整合了粗粒度与细粒度检索,利用了二者的优点
  • 新的训练目标 :Consistent Multimodal Contrastive (CMC) loss ,确保模态内和模态间语义一致性
  • 基于混合全局和局部的跨模态相似性两阶段推理方法
  • 效果:检索精度高,推理时间小

1.1 前置知识

Image-Text Retrieval

Image-Text Retrieval involves searching and retrieving relevant images from a dataset based on textual queries or vice versa. This process relies on understanding and matching the semantic content between images and texts. In essence, it requires a system to accurately interpret the context and details within both visual and textual data to find correspondences, enabling applications like enhancing search engines, aiding content discovery, and facilitating multimedia databases management.

图像-文本检索涉及根据文本查询从数据集中搜索和检索相关图像,反之亦然。这个过程依赖于对图像和文本之间语义内容的理解和匹配。从本质上讲,它需要一个系统来准确解释视觉和文本数据中的上下文和细节,以查找对应关系,从而实现增强搜索引擎、协助内容发现和促进多媒体数据库管理等应用。


Efficient Token-Guided Image-Text Retrieval withConsistent Multimodal Contrastive Training_第1张图片

1.2 图像处理

  • Faster R-CNN 产生

    1. 图片中实例的位置及其特征(局部特征)
    2. 图片全局特征
  • 通过transformer encoder 产生cross-modal representations

1.3 文本处理

  • BERT产生单词以及句子级别的特征
  • 通过另一个 transformer encoder 产生linguistic representations

1.4 全局检索和局部检索

  • 全局检索匹配整个图像的特征和句子级别的特征

  • 局部检索在token-level alignment后获得跨模态相似性

  • Token-level alignment 意味着在多模态任务中,对于每个模态(例如文本和图像),模型在处理它们时会保持标记级别的对齐。换句话说,模型会尝试将文本中的每个单词或图像中的每个区域与其他模态中的相应部分进行对齐,以便模型能够理解它们之间的关联。

    例如,在图像描述任务中,模型需要对图像中的每个区域与文本描述中的每个单词进行对齐。这种对齐可以帮助模型理解图像中的不同对象或场景与文本描述中的相应概念之间的对应关系。类似地,在音频描述任务中,模型需要将音频片段中的每个时间步与文本描述中的每个单词进行对齐。

    在多模态任务中,保持标记级别的对齐非常重要,因为它可以帮助模型更好地理解不同模态之间的关联,从而更准确地执行任务,如图像描述、视觉问答或视频理解。

1.5 Consistent Multimodal Contrastive Training (CMC) loss

使用Consistent Multimodal Contrastive Training (CMC) loss 同时训练两个网络(全局和局部)

Efficient Token-Guided Image-Text Retrieval withConsistent Multimodal Contrastive Training_第2张图片

传统方法Multimodal Contrastive Loss :
在这里插入图片描述

也就是通过找到一个锚点图片的最难负样本和相应的文本,通过找到一个锚点文本的最难负样本和相应的图片

缺点: 对于同一模态的样本缺乏限制

改进Consistent Multimodal Contrastive Loss :

在这里插入图片描述

增加了模态内约束。它旨在确保在同一模态(图像与图像或文本与文本)中,不匹配的对也在嵌入空间中分离。变量σ充当松弛变量,以灵活地控制来自不同模态的样本距离之间的间隙,允许一些不一致以适应距离的自然变化。

CMC Loss
在这里插入图片描述

一方面,Lr控制样本之间的距离。另一方面,La保证了匹配样本之间距离的一致性。因此,结合两种损失可以保证多模态样本之间的局部和全局相似度的一致性。

1.6 Two-Stage Inference Method

inference

在深度学习中,“inference”(推断)是指使用已经训练好的模型来对新的、未见过的数据进行预测或分类的过程。在推断阶段,模型已经完成了训练过程,参数已经被学习,并且模型已经具备了对数据进行预测或分类的能力。

Multitask learning

多任务学习是一种机器学习范式,其中模型被训练以同时执行多个任务。与为每个任务训练单独的模型不同,多任务学习利用跨任务共享的信息来提高每个单独任务的性能。当任务相关或共享潜在结构或特征时,这种方法尤其有益。

1.6.1 Training

最终训练损失函数:

在这里插入图片描述

表示总损失是全局和局部相似性的 CMC 损失之和。这种组合损失函数允许模型在端到端可训练网络中跨两种模态(图像和文本)共同学习全局和局部表示。该方法旨在利用多任务学习来约束参数空间,与单独训练它们相比,在这两个任务上都能获得更好的表示和更高的性能。

1.6.2 Inference

为了平衡精度和速度,该文提出了一种两阶段推理过程

  • 在第一阶段,使用全局检索来快速缩小候选样本的范围
  • 在第二阶段,结合了全局和局部信息的混合相似性用于对这些候选者进行重新排名,以获得更准确的最终结果。

这种双层方法允许快速进行初始过滤,然后进行更精确的选择,从而优化检索任务的速度和准确性。

2. 代码

2.1 训练

  1. 加载相关配置文件与预设参数
  • 指定log文件与对象
  1. 加载数据集
    train_loader, val_loader = data.get_loaders(config, opt.workers)
  1. collate_fn = Collate(config) 这行代码为了提供一个定制的批处理函数,以确保在训练和验证过程中能够正确地处理数据。
  2. 通过 get_paths 函数获取数据集的根路径和相关信息,包括图像路径、注释文件路径以及数据集的拆分情况等。
  3. 使用 get_transform 函数根据数据集的名称和拆分情况获取相应的数据转换(数据预处理)操作。这些转换操作通常包括将图像转换为 PyTorch 张量、归一化等操作。
  4. 调用 get_loader_single 函数来创建训练和验证数据加载器。该函数用于构建单个数据集的数据加载器,并设置了一些参数,如批量大小、是否随机打乱数据等。
  5. 返回创建好的训练和验证数据加载器
  1. get_model(config)函数是一个简单的函数,用于获取模型。创建了一个 BASELINE 类的实例并返回该实例。
  • BASELINE 的模型类是一个 PyTorch 模型,用于图像和文本的联合表示学习任务

这个模型由两个主要组件组成:

  1. JointTextImageTransformerEncoder 类:该类定义了一个联合的文本-图像编码器,它将输入的图像和文本进行编码,并输出它们的联合表示。这个编码器包含了一个文本编码器(EncoderText),一个图像编码器(EncoderImage),以及一些转换器编码层(TransformerEncoderLayer)用于处理编码后的特征。该类的 forward 方法定义了模型的前向传播逻辑,包括将输入的图像和文本进行编码,并将它们的特征通过一系列的转换器编码层进行处理,最后输出得到联合的图像和文本表示。
  2. BASELINE 类:该类是模型的主类,它包含了 JointTextImageTransformerEncoder 类的一个实例,并定义了模型的前向传播逻辑和损失函数。该类的 forward 方法将输入的图像和文本传递给 JointTextImageTransformerEncoder 类进行编码,并根据任务类型计算损失。

此外,模型还包括一些其他的辅助函数和变量,用于初始化模型参数、设置训练模式等。


  • models.text

定义了两个文本编码器:EncoderTextGRUEncoderTextBERT

  1. EncoderTextGRU 类是一个基于GRU的文本编码器。它将输入的文本序列中的单词索引转换为单词嵌入,并通过一个或多个GRU层对单词嵌入进行编码。编码后的输出是文本的表示,其维度为[batch_size, hidden_size]
  2. EncoderTextBERT 类是基于BERT的文本编码器。它可以直接使用预训练的BERT模型来对输入的文本序列进行编码,也可以选择在BERT模型之后添加一个或多个Transformer编码层以进一步处理文本特征。最终输出的文本表示与 EncoderTextGRU 类似,也是一个维度为[batch_size, hidden_size]的张量。

  • models.visual

定义了几个图像编码器,根据配置文件中的参数选择合适的编码器类型。

  1. EncoderImageFull: 这是一个完整的图像编码器,它可以选择使用预训练的CNN模型(如VGG或ResNet)提取图像特征,并将这些特征映射到固定维度的向量空间。它还支持使用Transformer网络对图像特征进行进一步处理。具体来说,它可以选择在CNN之后添加一个Transformer网络层。最终输出的图像表示是一个维度为[batch_size, embed_size]的张量。

  2. EncoderImagePrecomp: 这是一个预先提取的图像特征编码器。它假设输入已经是预先计算好的图像特征,并且直接将这些特征映射到固定维度的向量空间。这种编码器通常用于使用预先提取的图像特征的情况,例如在某些数据集上已经有提前计算好的图像特征。

  3. TransformerPostProcessing: 这是一个基于Transformer的图像编码器,它使用Transformer网络对输入的视觉特征进行处理。它支持在Transformer之后添加池化层以及线性投影层,以便将输出的特征映射到固定维度的向量空间。最终输出的图像表示也是一个维度为[batch_size, embed_size]的张量。


  • 得到特征后根据相似度计算损失

对齐对比损失(Alignment Contrastive Loss)也就是CMC Loss?

对齐对比损失旨在优化模型以确保相对应的图像和文本嵌入彼此靠近,同时使不相关的嵌入远离。在这种损失函数中,对齐的概念不仅仅是简单的匹配。它可能涉及到更复杂的关系,比如序列中的不同元素(图像中的对象或文本中的单词)之间的对齐。这种方法试图捕获图像和文本之间更细粒度的关系,而不是仅仅基于整体的图像-文本相似度。

标准对比损失(Standard Contrastive Loss)

标准对比损失,通常简称为对比损失,是一种更直观的方法,目的是缩小正样本对(即相关的图像和文本对)之间的距离,同时扩大负样本对(不相关的图像和文本对)之间的距离。这种方法通常涉及计算一个相似度矩阵,其中每个元素代表一个图像-文本对的相似度得分,然后通过比较正样本对和负样本对的得分来计算损失。

你可能感兴趣的:(人工智能,算法,深度学习)