VQA Study Notes (1): CNN-LSTM

I'm new to VQA, so corrections are welcome if anything here is wrong.

What is MMF? The official README describes it like this:

MMF is a modular framework for vision and language multimodal research from Facebook AI Research. MMF contains reference implementations of state-of-the-art vision and language models and has powered multiple research projects at Facebook AI Research. See full list of projects inside or built on MMF here.

MMF contains implementations of many of the basic VQA models; working through their code is a quick way to understand how VQA has developed and the techniques it builds on.

Today we start from CNN-LSTM, the simplest of these models, and use it to learn the basic skeleton MMF uses to build a model.

[Figure 1: CNN-LSTM model architecture]

The figure above shows the basic idea of the model. Note that the MMF code only classifies the fused features; it does not run an RNN decoder.

# Copyright (c) Facebook, Inc. and its affiliates.

from copy import deepcopy

import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.modules.layers import ClassifierLayer, ConvNet, Flatten
from torch import nn


_TEMPLATES = {
    "question_vocab_size": "{}_text_vocab_size",
    "number_of_answers": "{}_num_final_outputs",
}

_CONSTANTS = {"hidden_state_warning": "hidden state (final) should have 1st dim as 2"}


@registry.register_model("cnn_lstm")
class CNNLSTM(BaseModel):
    """CNNLSTM is a simple model for vision and language tasks. CNNLSTM is supposed
    to act as a baseline to test out your stuff without any complex functionality.
    Passes image through a CNN, and text through an LSTM and fuses them using
    concatenation. Then, it finally passes the fused representation from a MLP to
    generate scores for each of the possible answers.

    Args:
        config (DictConfig): Configuration node containing all of the necessary
                             config required to initialize CNNLSTM.

    Inputs: sample_list (SampleList)
        - **sample_list** should contain image attribute for image, text for
          question split into word indices, targets for answer scores
    """

    def __init__(self, config):
        super().__init__(config)
        self._global_config = registry.get("config")
        self._datasets = self._global_config.datasets.split(",")

    @classmethod
    def config_path(cls):
        return "configs/models/cnn_lstm/defaults.yaml"

    def build(self):
        assert len(self._datasets) > 0
        num_question_choices = registry.get(
            _TEMPLATES["question_vocab_size"].format(self._datasets[0])
        )
        num_answer_choices = registry.get(
            _TEMPLATES["number_of_answers"].format(self._datasets[0])
        )

        self.text_embedding = nn.Embedding(
            num_question_choices, self.config.text_embedding.embedding_dim
        )
        self.lstm = nn.LSTM(**self.config.lstm)

        layers_config = self.config.cnn.layers
        conv_layers = []
        for i in range(len(layers_config.input_dims)):
            conv_layers.append(
                ConvNet(
                    layers_config.input_dims[i],
                    layers_config.output_dims[i],
                    kernel_size=layers_config.kernel_sizes[i],
                )
            )
        conv_layers.append(Flatten())
        self.cnn = nn.Sequential(*conv_layers)

        # As we generate output dim dynamically, we need to copy the config
        # to update it
        classifier_config = deepcopy(self.config.classifier)
        classifier_config.params.out_dim = num_answer_choices
        self.classifier = ClassifierLayer(
            classifier_config.type, **classifier_config.params
        )

    def forward(self, sample_list):
        self.lstm.flatten_parameters()

        question = sample_list.text
        image = sample_list.image

        # Get (h_n, c_n), last hidden and cell state
        _, hidden = self.lstm(self.text_embedding(question))
        # X x B x H => B x X x H where X = num_layers * num_directions
        hidden = hidden[0].transpose(0, 1)

        # X should be 2 so we can merge in that dimension
        assert hidden.size(1) == 2, _CONSTANTS["hidden_state_warning"]

        hidden = torch.cat([hidden[:, 0, :], hidden[:, 1, :]], dim=-1)
        image = self.cnn(image)

        # Fuse into single dimension
        fused = torch.cat([hidden, image], dim=-1)
        scores = self.classifier(fused)

        return {"scores": scores}

The class above inherits from BaseModel; every model class in MMF must inherit from BaseModel.

When the class is defined, the decorator registers the model under a name:

@registry.register_model("cnn_lstm")

The registration logic lives in the corresponding classmethod of the Registry class:

    @classmethod
    def register_model(cls, name):
        r"""Register a model to registry with key 'name'

        Args:
            name: Key with which the model will be registered.

        Usage::

            from mmf.common.registry import registry
            from mmf.models.base_model import BaseModel

            @registry.register_model("pythia")
            class Pythia(BaseModel):
                ...
        """

        def wrap(func):
            from mmf.models.base_model import BaseModel

            assert issubclass(
                func, BaseModel
            ), "All models must inherit BaseModel class"
            cls.mapping["model_name_mapping"][name] = func
            return func

        return wrap
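Once registered, the model class can be looked up by its key. A minimal sketch of that lookup, assuming registry.get_model_class is available (it simply reads the same model_name_mapping dict populated above):

from mmf.common.registry import registry

# Importing the models package should trigger the @registry.register_model
# decorators, which fill mapping["model_name_mapping"].
import mmf.models

model_cls = registry.get_model_class("cnn_lstm")
print(model_cls)  # expected: <class 'mmf.models.cnn_lstm.CNNLSTM'>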

The model's default configuration lives in configs/models/cnn_lstm/defaults.yaml:

model_config:
  cnn_lstm:
    losses:
    - type: logit_bce
    text_embedding:
      embedding_dim: 20
    lstm:
      input_size: 20
      hidden_size: 50
      bidirectional: true
      batch_first: true
    cnn:
      layers:
        input_dims: [3, 64, 128, 128, 64, 64]
        output_dims: [64, 128, 128, 64, 64, 10]
        kernel_sizes: [7, 5, 5, 5, 5, 1]
    classifier:
      type: mlp
      params:
        in_dim: 450
        out_dim: 2

These settings are walked through in detail below.
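Inside the model, self.config is not the whole file but the model's own sub-node (model_config.cnn_lstm). A quick way to peek at these values, assuming the YAML path above and that the file is a plain OmegaConf config (the DictConfig type hint in the docstring suggests MMF uses OmegaConf):

from omegaconf import OmegaConf

conf = OmegaConf.load("configs/models/cnn_lstm/defaults.yaml")
model_conf = conf.model_config.cnn_lstm  # what CNNLSTM sees as self.config
print(model_conf.lstm.hidden_size)          # 50
print(model_conf.classifier.params.in_dim)  # 450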

First, a 20-dimensional word embedding is created:

self.text_embedding = nn.Embedding(
    num_question_choices, self.config.text_embedding.embedding_dim
)
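With the defaults above, every word index is mapped to a 20-dimensional vector. A quick shape check (the vocabulary size of 1000, batch size 8, and question length 10 are made-up numbers for illustration):

import torch
from torch import nn

embedding = nn.Embedding(num_embeddings=1000, embedding_dim=20)
question = torch.randint(0, 1000, (8, 10))  # (batch, seq_len) word indices
print(embedding(question).shape)            # torch.Size([8, 10, 20])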

Next, the LSTM module is created, with a hidden size of 50:

self.lstm = nn.LSTM(**self.config.lstm)
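With the default config, this expands to the call below. Because the LSTM is bidirectional with one layer, h_n stacks the final forward and backward hidden states, which is exactly the size-2 dimension the forward pass asserts on later (batch size 8 and length 10 are again made-up numbers):

import torch
from torch import nn

lstm = nn.LSTM(input_size=20, hidden_size=50, bidirectional=True, batch_first=True)
output, (h_n, c_n) = lstm(torch.randn(8, 10, 20))  # an embedded question
print(output.shape)  # torch.Size([8, 10, 100]) -- both directions concatenated
print(h_n.shape)     # torch.Size([2, 8, 50])   -- (num_layers * num_directions, batch, hidden)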

Then the CNN module is built. The channel counts and kernel sizes of each layer are defined in the config; there are six convolutional layers in total.

layers_config = self.config.cnn.layers
conv_layers = []
for i in range(len(layers_config.input_dims)):
    conv_layers.append(
        ConvNet(
            layers_config.input_dims[i],
            layers_config.output_dims[i],
            kernel_size=layers_config.kernel_sizes[i],
        )
    )
conv_layers.append(Flatten())
self.cnn = nn.Sequential(*conv_layers)

ConvNet is defined in modules/layers.py.

Each ConvNet block is a convolution (with a leaky-ReLU activation) followed by max pooling and batch norm; after all the conv layers, a Flatten collapses the feature map into a vector. A shape trace follows the code below.

class ConvNet(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding_size="same",
        pool_stride=2,
        batch_norm=True,
    ):
        super().__init__()

        if padding_size == "same":
            padding_size = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding=padding_size
        )
        self.max_pool2d = nn.MaxPool2d(pool_stride, stride=pool_stride)
        self.batch_norm = batch_norm

        if self.batch_norm:
            self.batch_norm_2d = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.max_pool2d(nn.functional.leaky_relu(self.conv(x)))

        if self.batch_norm:
            x = self.batch_norm_2d(x)

        return x


class Flatten(nn.Module):
    def forward(self, input):
        if input.dim() > 1:
            input = input.view(input.size(0), -1)

        return input
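To see where the classifier's in_dim of 450 comes from, here is a shape trace through the six blocks with the default channel counts and kernel sizes. The 3 x 320 x 480 input size is an assumption (CLEVR-sized images; the real resolution depends on the dataset's image processor), and each stride-2 max pool roughly halves the spatial size:

import torch
from mmf.modules.layers import ConvNet, Flatten
from torch import nn

cnn = nn.Sequential(
    ConvNet(3, 64, kernel_size=7),     # 320 x 480 -> 160 x 240
    ConvNet(64, 128, kernel_size=5),   # -> 80 x 120
    ConvNet(128, 128, kernel_size=5),  # -> 40 x 60
    ConvNet(128, 64, kernel_size=5),   # -> 20 x 30
    ConvNet(64, 64, kernel_size=5),    # -> 10 x 15
    ConvNet(64, 10, kernel_size=1),    # -> 5 x 7
    Flatten(),                         # -> 10 * 5 * 7 = 350
)
x = torch.randn(8, 3, 320, 480)  # assumed CLEVR-sized batch
print(cnn(x).shape)  # torch.Size([8, 350])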

 

In the forward pass:

    def forward(self, sample_list):
        self.lstm.flatten_parameters()

        question = sample_list.text
        image = sample_list.image

        # Get (h_n, c_n), last hidden and cell state
        _, hidden = self.lstm(self.text_embedding(question))
        # X x B x H => B x X x H where X = num_layers * num_directions
        hidden = hidden[0].transpose(0, 1)

        # X should be 2 so we can merge in that dimension
        assert hidden.size(1) == 2, _CONSTANTS["hidden_state_warning"]

        hidden = torch.cat([hidden[:, 0, :], hidden[:, 1, :]], dim=-1)
        image = self.cnn(image)

        # Fuse into single dimension
        fused = torch.cat([hidden, image], dim=-1)
        scores = self.classifier(fused)

        return {"scores": scores}

At this point hidden and image are both 2-D tensors: the first dimension is batch_size, and the second holds the feature vector produced by the LSTM and the CNN respectively.

The LSTM's final hidden states and the CNN's output features are concatenated and fed through the MLP classifier, which outputs a score for each candidate answer. (The out_dim of 2 in the YAML is only a placeholder; build overwrites it with the actual number of answers.)
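Putting the numbers together under the same assumptions: the bidirectional LSTM contributes 2 x 50 = 100 features, the CNN contributes 350, and concatenation yields the 450 that classifier.params.in_dim expects. A toy check:

import torch

hidden = torch.randn(8, 100)  # 2 directions x 50 hidden units
image = torch.randn(8, 350)   # flattened CNN features (10 x 5 x 7 in the earlier trace)
fused = torch.cat([hidden, image], dim=-1)
print(fused.shape)  # torch.Size([8, 450]) -- matches classifier.params.in_dim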

This is the simplest network in VQA, yet every complex network evolves step by step from simple beginnings. Rome wasn't built in a day!

 
