首先将 json 文件(如下)经过一系列处理后,保存在 trainset.pth 文件中
# Path where the processed VQA 2.0 json data is cached as a .pth file.
self.path_trainset = osp.join(self.subdir_processed, 'trainset.pth')
# Build the processed training set from the raw VQA 2.0 json files and
# cache the result as trainset.pth.
# NOTE(review): this excerpt is incomplete — `wcounts`, `word_to_wid` and
# `ans_to_aid` are used below but defined by vocabulary-building steps not
# shown here; confirm against the full process() implementation.
def process(self):
dir_ann = osp.join(self.dir_raw, 'annotations')
path_train_ann = osp.join(dir_ann, 'mscoco_train2014_annotations.json')
path_train_ques = osp.join(dir_ann, 'OpenEnded_mscoco_train2014_questions.json')
train_ann = json.load(open(path_train_ann))
train_ques = json.load(open(path_train_ques))
trainset = self.merge_annotations_with_questions(train_ann, train_ques)  # merge the answers file with the questions file
trainset = self.add_image_names(trainset)  # add the image name to each entry
trainset['annotations'] = self.add_answer(trainset['annotations'])  # add an answer field to each annotation
trainset['annotations'] = self.tokenize_answers(trainset['annotations'])  # tokenize the answers
trainset['questions'] = self.tokenize_questions(trainset['questions'], self.nlp)  # tokenize the questions with the nlp tokenizer
trainset['questions'] = self.insert_UNK_token(trainset['questions'], wcounts, self.minwcount)  # presumably replaces words rarer than minwcount with UNK — TODO confirm
trainset['questions'] = self.encode_questions(trainset['questions'], word_to_wid)  # map question words to integer ids
trainset['annotations'] = self.encode_answers(trainset['annotations'], ans_to_aid)  # map answers to integer ids
torch.save(trainset, self.path_trainset)  # cache the processed dataset as trainset.pth
# Load the dataset, (re)building the processed trainset.pth if needed.
# BUG FIX: the original checked that the processed *directory* exists, then
# loaded self.path_trainset — a partially built directory without
# trainset.pth made torch.load crash. Check for the file we actually load.
if not os.path.exists(self.path_trainset):
    self.process()
self.dataset = torch.load(self.path_trainset)
# Attach precomputed Faster R-CNN visual features to an item.
def add_rcnn_to_item(self, item):
    """Attach the precomputed Faster R-CNN features for an image to *item*.

    :param item: dict containing at least 'image_name'; features are loaded
        from ``<dir_rcnn>/<image_name>.pth``
        (e.g. coco/extract/COCO_train2014_*.jpg.pth).
    :return: the same dict with 'visual' (region features), 'coord'
        (region-of-interest boxes), 'norm_coord' (normalized boxes) and
        'nb_regions' (number of regions) added.
    """
    path_rcnn = os.path.join(self.dir_rcnn, '{}.pth'.format(item['image_name']))
    item_rcnn = torch.load(path_rcnn)  # load the precomputed .pth file
    # BUG FIX: removed a leftover debug print(item_rcnn) that dumped the
    # full feature tensors to stdout on every __getitem__ call.
    item['visual'] = item_rcnn['pooled_feat']    # region features
    item['coord'] = item_rcnn['rois']            # region-of-interest boxes
    item['norm_coord'] = item_rcnn['norm_rois']  # normalized ROI boxes
    item['nb_regions'] = item['visual'].size(0)  # number of regions
    return item
def __getitem__(self, index):
    """Build one sample: question tensor, precomputed visual features,
    and — when annotations are available (i.e. not a test split) — the
    encoded answer.
    """
    entry = self.dataset['questions'][index]
    item = {'index': index}
    if self.load_original_annotation:
        item['original_question'] = entry
    item.update(
        question_id=entry['question_id'],                          # question id, e.g. 458752001
        question=torch.LongTensor(entry['question_wids']),         # word-id encoding of the question
        lengths=torch.LongTensor([len(entry['question_wids'])]),   # question length
        image_name=entry['image_name'],                            # e.g. 'COCO_train2014_000000458752.jpg'
    )
    # Attach the Faster R-CNN features (boxes + region features) for the image.
    item = self.add_rcnn_to_item(item)
    # Test splits ship without annotations, hence the guard.
    if 'annotations' not in self.dataset:
        item['is_testdev'] = item['question_id'] in self.is_qid_testdev
        return item
    annotation = self.dataset['annotations'][index]
    if self.load_original_annotation:
        item['original_annotation'] = annotation
    if 'train' in self.split and self.samplingans:
        # Sample one of the annotated answers, weighted by how often
        # each answer was given.
        weights = annotation['answers_count']
        weights = weights / np.sum(weights)
        item['answer_id'] = int(np.random.choice(annotation['answers_id'], p=weights))
    else:
        item['answer_id'] = annotation['answer_id']
    item['class_id'] = torch.LongTensor([item['answer_id']])
    item['answer'] = annotation['answer']
    item['question_type'] = annotation['question_type']
    return item
整个 item 字典中的键有:
{
index :索引,
question_id:问题id,458752001
question:问题, tensor([4321, 2932, 1997, 3968, 2286, 2878])
lengths:问题长度, tensor([6]),
image_name:图像名,'COCO_train2014_000000458752.jpg'
visual:图像特征,
coord:感兴趣区域位置信息,
norm_coord:感兴趣区域位置信息标准化,
nb_regions:区域数,36
answer_id:答案id,382
class_id:分类id, tensor([382])
answer:答案,'pitcher'
question_type:问题类型 'what'
}
如下是一个 item 的具体数据:
{ 'index': 1,
'question_id': 458752001,
'question': tensor([4321, 2932, 1997, 3968, 2286, 2878]), 'lengths': tensor([6]),
'image_name': 'COCO_train2014_000000458752.jpg',
'visual': tensor([[0.0000, 0.0000, 0.0231, ..., 0.0000, 0.0281, 1.5262],
[0.0000, 0.0169, 0.0587, ..., 0.0000, 0.0064, 1.1313],
[0.3978, 0.0000, 0.0000, ..., 0.0000, 0.1113, 3.8770],
...,
[0.0326, 0.0000, 0.0000, ..., 0.0799, 2.7793, 1.2371],
[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.3857],
[0.0084, 0.8026, 0.0966, ..., 0.0000, 0.7668, 0.0798]]),
'coord': tensor([[282.9814, 302.8545, 372.2248, 468.0808],
[311.6291, 333.7408, 359.6907, 358.5282],
[215.0726, 172.1102, 352.8074, 407.3577],
[285.7687, 189.5694, 329.6428, 231.0218],
[274.9748, 160.3990, 318.7502, 208.2673],
[241.7279, 302.4235, 286.6740, 342.9375],
[241.9454, 230.9683, 355.9464, 350.7243],
[ 0.0000, 0.0000, 383.6760, 425.9977],
[372.8926, 360.4357, 629.3776, 401.7158],
[348.9116, 45.5669, 639.2000, 479.2000],
[391.1129, 149.4865, 610.5438, 353.4078],
[ 28.7101, 178.1590, 235.1222, 398.0370],
[353.4412, 420.1210, 381.8891, 442.3383],
[249.2581, 316.8993, 319.8651, 357.7612],
[ 0.0000, 0.0000, 487.8425, 174.9353],
[177.7185, 63.3202, 472.4503, 479.2000],
[ 6.3949, 120.7137, 639.2000, 479.2000],
[ 88.5412, 209.1507, 558.1020, 479.2000],
[254.3652, 164.8777, 332.2064, 206.6262],
[189.8386, 128.2660, 426.3345, 479.2000],
[237.2819, 281.1407, 411.9520, 479.2000],
[ 20.9822, 370.2453, 301.8062, 402.6493],
[312.2184, 263.2010, 344.6071, 296.3635],
[265.3174, 229.0582, 374.0845, 349.4183],
[257.8582, 154.3860, 341.2274, 235.9603],
[108.7576, 342.0231, 573.0241, 455.5504],
[ 57.4191, 0.0000, 617.8732, 117.8669],
[234.5487, 271.3556, 268.1855, 318.8475],
[323.6842, 0.0000, 639.2000, 145.7849],
[263.1414, 249.6308, 396.5386, 479.2000],
[310.9734, 257.4800, 349.8676, 292.9267],
[349.4448, 423.4623, 388.5869, 452.1093],
[269.6038, 153.9579, 300.1459, 188.4087],
[162.7299, 0.0000, 639.2000, 230.7880],
[286.1820, 371.8325, 346.1609, 479.2000],
[168.0096, 305.8445, 479.0755, 479.2000]]),
'norm_coord': tensor([[0.4422, 0.6309, 0.5816, 0.9752],
[0.4869, 0.6953, 0.5620, 0.7469],
[0.3361, 0.3586, 0.5513, 0.8487],
[0.4465, 0.3949, 0.5151, 0.4813],
[0.4296, 0.3342, 0.4980, 0.4339],
[0.3777, 0.6300, 0.4479, 0.7145],
[0.3780, 0.4812, 0.5562, 0.7307],
[0.0000, 0.0000, 0.5995, 0.8875],
[0.5826, 0.7509, 0.9834, 0.8369],
[0.5452, 0.0949, 0.9988, 0.9983],
[0.6111, 0.3114, 0.9540, 0.7363],
[0.0449, 0.3712, 0.3674, 0.8292],
[0.5523, 0.8753, 0.5967, 0.9215],
[0.3895, 0.6602, 0.4998, 0.7453],
[0.0000, 0.0000, 0.7623, 0.3644],
[0.2777, 0.1319, 0.7382, 0.9983],
[0.0100, 0.2515, 0.9988, 0.9983],
[0.1383, 0.4357, 0.8720, 0.9983],
[0.3974, 0.3435, 0.5191, 0.4305],
[0.2966, 0.2672, 0.6661, 0.9983],
[0.3708, 0.5857, 0.6437, 0.9983],
[0.0328, 0.7713, 0.4716, 0.8389],
[0.4878, 0.5483, 0.5384, 0.6174],
[0.4146, 0.4772, 0.5845, 0.7280],
[0.4029, 0.3216, 0.5332, 0.4916],
[0.1699, 0.7125, 0.8954, 0.9491],
[0.0897, 0.0000, 0.9654, 0.2456],
[0.3665, 0.5653, 0.4190, 0.6643],
[0.5058, 0.0000, 0.9988, 0.3037],
[0.4112, 0.5201, 0.6196, 0.9983],
[0.4859, 0.5364, 0.5467, 0.6103],
[0.5460, 0.8822, 0.6072, 0.9419],
[0.4213, 0.3207, 0.4690, 0.3925],
[0.2543, 0.0000, 0.9988, 0.4808],
[0.4472, 0.7747, 0.5409, 0.9983],
[0.2625, 0.6372, 0.7486, 0.9983]]),
'nb_regions': 36,
'answer_id': 382,
'class_id': tensor([382]),
'answer': 'pitcher',
'question_type': 'what'}