最近需要学习处理VQA任务特此记录,这个主要是对论文bottom-up and top-down()和bilinear attention network()中的代码部分的学习记录,目前也并不是很熟悉,如有记录错误的地方还请给位大佬即使指正。
VQA2.0数据集的下载地址:VQA: Visual Question Answering (visualqa.org)
详细的可以看这位博主:VQA_v2数据集预处理_呆呆_kk的博客-CSDN博客_vqa数据集
数据集主要是有三部分:
(1)VQA Annotations:
Training:4437570
Validation:2143540
(2)VQA Input Questions:
Training:443757
Validation:214354
Testing:447793
(3)VQA Input Images:
Training:82783
Validation:40504
Testing:81434
也就是说,每张图片会有5个左右的问题,每个问题会有10个左右的回答。
以bottom-up and top-down这篇论文为例:
VQA任务就是给定一张图片和一个问题,模型要根据给定的输入来进行回答。
很明显,VQA任务的输入有两个(image和question),对于如何提取image的feature,这里就不在赘述, 可以使用CNN提取特征的方式,CNN可以选择Resnet、VGG等骨干网络(去除pooling和fc层)。
对于如何提取question的特征,一般的做法是,由于question本身是文字,需要转换为对应的向量形式,因为模型是不认识文字的,所以要先将每个词转换为与之对应的向量,简单的做法是使用one-hot编码,当然这里你首先是需要建立一个词典(包含你任务中所有可能出现的单词)one-hot编码呢例如():
初始有一个词典:['a','is','he','boy','she','dog','school','man','woman','king','queen']
那出现的单词可以用one-hot编码来表示这个单词,如'he'->[0,0,1,0,0,0,0,0,0,0,0],很明显的缺点就是当dictionary很大时,one-hot编码会特别稀疏,这时候一般使用其它的编码方式,比如说你可以使用预训练的glove向量,这个我们之后放在预处理中再阐述,至于如何建立词典我们也会在预训练中阐述。
当我们将单词转换为对应的向量后,此时我们要整句话对应的向量送往RNN中来提取对应的文本特征,RNN我们一般是会选择LSTM或者GRU(需要讲解LSTM和GRU的我们可以下次再说)。
得到image和question的feature后,就可以通过模型去训练,训练的方式可以使用多标签分类的形式,所以它的损失也就是一个多标签损失,也就可以使用交叉熵损失。
数据预处理部分我们主要是针对bottom-up and top-down和bilinear attention network代码中的数据预处理部分来说明。
贴出相关代码的地址:
Bottom-up:GitHub - hengyuan-hu/bottom-up-attention-vqa: An efficient PyTorch implementation of the winning entry of the 2017 VQA Challenge.
BAN:GitHub - jnhwkim/ban-vqa: Bilinear attention networks for visual question answering
download.sh文件中下载并解压到对应位置,在此工作之前,请先准备足够大的存储空间。
wget -P data http://nlp.stanford.edu/data/glove.6B.zip
unzip data/glove.6B.zip -d data/glove
rm data/glove.6B.zip
解压后会得到4种维度的词向量,可以根据自己的选择来使用:
简单看一下50维的词向量,可以看到每个单词对应一个50维度的向量:
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
# Questions
wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Train_mscoco.zip
unzip data/v2_Questions_Train_mscoco.zip -d data
rm data/v2_Questions_Train_mscoco.zip
wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Val_mscoco.zip
unzip data/v2_Questions_Val_mscoco.zip -d data
rm data/v2_Questions_Val_mscoco.zip
wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Questions_Test_mscoco.zip
unzip data/v2_Questions_Test_mscoco.zip -d data
rm data/v2_Questions_Test_mscoco.zip
# Annotations
wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Train_mscoco.zip
unzip data/v2_Annotations_Train_mscoco.zip -d data
rm data/v2_Annotations_Train_mscoco.zip
wget -P data https://s3.amazonaws.com/cvmlp/vqa/mscoco/vqa/v2_Annotations_Val_mscoco.zip
unzip data/v2_Annotations_Val_mscoco.zip -d data
rm data/v2_Annotations_Val_mscoco.zip
解压后会得到对应的json文件:
其中question的数据结构格式如下:
{
"info" : info,
"task_type" : str,
"data_type": str,
"data_subtype": str,
"questions" : [question],
"license" : license
}
--------------------------取出所关注的question部分
question{
"question_id" : int,
"image_id" : int,
"question" : str
}
annotation的数据格式如下:
{
"info" : info,
"data_type": str,
"data_subtype": str,
"annotations" : [annotation],
"license" : license
}
------------------------取出所关注的annotation部分
annotation{
"question_id" : int,
"image_id" : int,
"question_type" : str,
"answer_type" : str,
"answers" : [answer],
"multiple_choice_answer" : str
}
----------------------具体的answer部分
answer{
"answer_id" : int,
"answer" : str,
"answer_confidence": str
}
这里直接使用的是预训练的feature,就不再需要自己将image送往CNN中来提取特征,预训练是使用Resnet101为骨干网络,使用faster-RCNN的方式在genome上进行预训练提取的特征。
# Image Features # resnet101_faster_rcnn_genome
wget -P data https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip
wget -P data https://imagecaption.blob.core.windows.net/imagecaption/test2014.zip
wget -P data https://imagecaption.blob.core.windows.net/imagecaption/test2015.zip
unzip data/trainval.zip -d data
unzip data/test2014.zip -d data
unzip data/test2015.zip -d data
rm data/trainval.zip
rm data/test2014.zip
rm data/test2015.zip
解压后:
在所有的question、annotation和image feature准备好了之后,下面开始为训练做一些准备工作:
我们首先要建立一个字典,字典包含所有出现的单词,并将其转换为对应的glove词向量,通过加载question文件,取出所关注的“question”部分,再取出“question(str)”(这里有点绕,仔细看question的数据结构就可以理解,就是一层一层的去取我们所需要的部分),就获得了所有的问题,有了所有的问题,就可以取出其中的单词,然后将大写转小写等等操作,再去掉重复的单词,就得到了该任务所有可能出现的单词。
dataroot = '../data' if args.task == 'vqa' else 'data/flickr30k'
dictionary_path = os.path.join(dataroot, 'dictionary.pkl')
d = create_dictionary(dataroot, args.task) # data,vqa #
d.dump_to_file(dictionary_path) # 字典,包含所有单词
d = Dictionary.load_from_file(dictionary_path)
这样会创建一个词典,同时可以通过词典来进行查找“单词到索引(word2idx)”和“索引到单词(idx2word)”:
print(d.idx2word) # 所有单词 ['what', 'is', 'this',..., 'pegasus', 'bebe']
print(len(d.idx2word)) # 共有19901个单词
print(d.word2idx) # {'what': 0, 'is': 1, 'this': 2,...,'pegasus': 19899, 'bebe': 19900}
光得到词典还不行,我们最终需要丢进模型中训练的一定是tensor或者说是向量形式的,文字肯定是不行的,所以需要转换为对应的词向量,其实就是为了进行word embedding做准备:
# 使用glove
glove_file = '../data/glove/glove.6B.%dd.txt' % emb_dim # 加载预训练好的data/glove/glove.6B.300d.txt
weights, word2emb = create_glove_embedding_init(d.idx2word, glove_file)
np.save(os.path.join(dataroot, 'glove6b_init_%dd.npy' % emb_dim), weights)
其中weights就是我们最后得到的预训练好的词典向量,我们有19901个单词,每个单词转换为300维的向量,所以weights.shape为[19901,300]:
print(weights, weights.shape) # [19901,300],从glove中加载出每个词对应的向量
将字典和weights进行保存:
由于考试原因暂时先写到这,准备这段时间慢慢将第三部分写完, 等待更新...
等待更新...
等待更新...
说明:个人比较懒,有些可能懒得打字,如果没有说清楚或者有说错的地方请及时指出,共同学习!