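Background: COCO comes with an original grouping of its labels, and we produced a new grouping with spectral clustering. COCO therefore has to be regrouped before it is fed into the network for training.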
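Contents
1. Writing the groups to a file
1.1 Reference for writing
1.2 Writing the grouping result
1.3 Reading the grouping variable back
2. Using the COCO grouping in the program
2.1 Call relationships in the program
2.2 Clustered grouping
2.3 Changing Config
2.4 Redefining the network
2.5 Loading directly in the network
3. Defining the network structure
3.1 Network output and labels
3.2 Reordering a sequence by idx
3.3 Reordering with index_select
3.4 Swapping within a sequence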
Variables can be written to a file with pickle. For reference, this is how another program writes them:
import pickle

# A_B, notA_B, A_notB, notA_notB and names are computed earlier in that program
correlations = {}
correlations.update(pp=A_B)        # P(A|B)
correlations.update(fp=notA_B)     # P(not A|B)
correlations.update(pf=A_notB)     # P(A|not B)
correlations.update(ff=notA_notB)  # P(not A|not B)
with open('sk_spectral_cluster/coco_correlations.pkl', 'wb') as f:
    print("write correlations in sk_spectral_cluster/coco_correlations.pkl")
    pickle.dump(correlations, f)
with open('sk_spectral_cluster/coco_names.pkl', 'wb') as name_file:
    print("write names in sk_spectral_cluster/coco_names.pkl")
    pickle.dump(names, name_file)
In other words, pickle.dump writes the variable directly into the file object f.
After clustering, the result is stored in split_groups. This is a dict in the format shown in the comment below.
import pickle

# ---------- store the split result into a .pkl file ----------
# format: dict {group_id (1-based): [class indices]}, e.g.
# { 1: [1, 2, 3, 4, 5, 7, 9, 10, 11, 12],
#   2: [46, 47, 49, 50, 51],
#   3: [22, 23, 32, 34, 35, 38],
#   ... }
with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'wb') as f:
    print("write cluster result into sk_spectral_cluster/coco_label_cluster_result.pkl")
    pickle.dump(split_groups, f)
The file is named coco_label_cluster_result.pkl.
It can be read back directly from the same path:
import pickle

with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'rb') as f:
    print("loading sk_spectral_cluster/coco_label_cluster_result.pkl")
    split_groups = pickle.load(f)
print("split_groups:", split_groups)
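The sketch below is a self-contained check of the write and read steps above. It round-trips a split_groups dict through pickle. It writes to a temporary directory instead of sk_spectral_cluster/, and it uses only the first three groups from the clustering log as sample data. The helper names save_groups and load_groups are hypothetical.

```python
import os
import pickle
import tempfile

def save_groups(split_groups, path):
    """Serialize the cluster result dict {group_id: [class indices]} with pickle."""
    with open(path, 'wb') as f:
        pickle.dump(split_groups, f)

def load_groups(path):
    """Read the cluster result dict back from disk."""
    with open(path, 'rb') as f:
        return pickle.load(f)

# First three groups from the clustering log, used as sample data
split_groups = {1: [1, 2, 3, 4, 5, 7, 9, 10, 11, 12],
                2: [46, 47, 49, 50, 51],
                3: [22, 23, 32, 34, 35, 38]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'coco_label_cluster_result.pkl')
    save_groups(split_groups, path)
    restored = load_groups(path)

print(restored == split_groups)  # True
```

pickle keeps both the dict keys and their insertion order, so code that iterates over the loaded dict sees the groups in the same order they were written.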
In general_train, COCO2014 is called to load the dataset.
train_dataset = COCO2014(args.data, phase='train', inp_name=Config.INP_NAME, is_grouping=True) # fixme
val_dataset = COCO2014(args.data, phase='val', inp_name=Config.INP_NAME, is_grouping=True) # fixme
The group definitions are passed into the function through Config:
GROUPS = 12
NCLASSES = 80
NCLASSES_PER_GROUP = [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7] # FIXME: to check
GROUP_CHANNELS = 512
CLASS_CHANNELS = 256
These grouping parameters are then passed directly to the model's group arguments:
if Config.MODEL == 'hgat_fc':
    import mymodels.hgat_fc as hgat_fc
    model = hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                            nclasses_per_group=Config.NCLASSES_PER_GROUP,
                            group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
split groups:
group: 1 group element numbers 10 group_elements : [1, 2, 3, 4, 5, 7, 9, 10, 11, 12]
group: 2 group element numbers 5 group_elements : [46, 47, 49, 50, 51]
group: 3 group element numbers 6 group_elements : [22, 23, 32, 34, 35, 38]
group: 4 group element numbers 4 group_elements : [16, 18, 19, 29]
group: 5 group element numbers 5 group_elements : [6, 24, 25, 26, 28]
group: 6 group element numbers 7 group_elements : [15, 57, 59, 65, 73, 74, 77]
group: 7 group element numbers 4 group_elements : [61, 71, 78, 79]
group: 8 group element numbers 7 group_elements : [39, 58, 68, 69, 70, 72, 75]
group: 9 group element numbers 10 group_elements : [0, 8, 13, 14, 17, 20, 21, 33, 36, 37]
group: 10 group element numbers 5 group_elements : [62, 63, 64, 66, 67]
group: 11 group element numbers 10 group_elements : [40, 41, 42, 43, 44, 45, 53, 55, 56, 60]
group: 12 group element numbers 2 group_elements : [30, 31]
group: 13 group element numbers 2 group_elements : [48, 52]
group: 14 group element numbers 2 group_elements : [27, 54]
group: 15 group element numbers 1 group_elements : [76]
Final results,group numbers: 15 max_classes_per_group: 10 probability filter threshold: 0.05
group: 1 group element numbers: 10
group_elements : ['bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'truck', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter']
group: 2 group element numbers: 5
group_elements : ['banana', 'apple', 'orange', 'broccoli', 'carrot']
group: 3 group element numbers: 6
group_elements : ['zebra', 'giraffe', 'sports ball', 'baseball bat', 'baseball glove', 'tennis racket']
group: 4 group element numbers: 4
group_elements : ['dog', 'sheep', 'cow', 'frisbee']
group: 5 group element numbers: 5
group_elements : ['train', 'backpack', 'umbrella', 'handbag', 'suitcase']
group: 6 group element numbers: 7
group_elements : ['cat', 'couch', 'bed', 'remote', 'book', 'clock', 'teddy bear']
group: 7 group element numbers: 4
group_elements : ['toilet', 'sink', 'hair drier', 'toothbrush']
group: 8 group element numbers: 7
group_elements : ['bottle', 'potted plant', 'microwave', 'oven', 'toaster', 'refrigerator', 'vase']
group: 9 group element numbers: 10
group_elements : ['person', 'boat', 'bench', 'bird', 'horse', 'elephant', 'bear', 'kite', 'skateboard', 'surfboard']
group: 10 group element numbers: 5
group_elements : ['tv', 'laptop', 'mouse', 'keyboard', 'cell phone']
group: 11 group element numbers: 10
group_elements : ['wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'pizza', 'cake', 'chair', 'dining table']
group: 12 group element numbers: 2
group_elements : ['skis', 'snowboard']
group: 13 group element numbers: 2
group_elements : ['sandwich', 'hot dog']
group: 14 group element numbers: 2
group_elements : ['tie', 'donut']
group: 15 group element numbers: 1
group_elements : ['scissors']
Update the corresponding Config settings to match the new grouping:
BACKBONE = 'resnet101'
GROUPS = 15
NCLASSES = 80
NCLASSES_PER_GROUP = [10, 5, 6, 4, 5, 7, 4, 7, 10, 5, 10, 2, 2, 2, 1]  # FIXME: to check
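NCLASSES_PER_GROUP is typed in by hand, which is why it carries a FIXME. One option is to derive it from split_groups and check that the groups partition the 80 classes. The sketch below assumes 1-based, contiguous group keys, as in the log above. derive_group_config is a hypothetical helper, and the group lists are copied from the clustering output.

```python
# Clustering result copied from the log above: {group_id (1-based): [class indices]}
SPLIT_GROUPS = {
    1: [1, 2, 3, 4, 5, 7, 9, 10, 11, 12],
    2: [46, 47, 49, 50, 51],
    3: [22, 23, 32, 34, 35, 38],
    4: [16, 18, 19, 29],
    5: [6, 24, 25, 26, 28],
    6: [15, 57, 59, 65, 73, 74, 77],
    7: [61, 71, 78, 79],
    8: [39, 58, 68, 69, 70, 72, 75],
    9: [0, 8, 13, 14, 17, 20, 21, 33, 36, 37],
    10: [62, 63, 64, 66, 67],
    11: [40, 41, 42, 43, 44, 45, 53, 55, 56, 60],
    12: [30, 31],
    13: [48, 52],
    14: [27, 54],
    15: [76],
}

def derive_group_config(split_groups, nclasses=80):
    """Return (GROUPS, NCLASSES_PER_GROUP) and check the groups partition range(nclasses)."""
    keys = sorted(split_groups)
    if keys != list(range(1, len(keys) + 1)):
        raise ValueError("group keys must be 1..GROUPS")
    nclasses_per_group = [len(split_groups[k]) for k in keys]
    members = sorted(c for k in keys for c in split_groups[k])
    if members != list(range(nclasses)):
        raise ValueError("every class must appear in exactly one group")
    return len(keys), nclasses_per_group

GROUPS, NCLASSES_PER_GROUP = derive_group_config(SPLIT_GROUPS)
print(GROUPS)              # 15
print(NCLASSES_PER_GROUP)  # [10, 5, 6, 4, 5, 7, 4, 7, 10, 5, 10, 2, 2, 2, 1]
```

In practice you could print the result once and paste it into Config, or Config could load the .pkl file directly.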
Add the new model to the network construction code:
elif Config.MODEL == 'clustered_hgat_fc':
    import mymodels.clustered_hgat_fc as clustered_hgat_fc
    model = clustered_hgat_fc.HGAT_FC(Config.BACKBONE, groups=Config.GROUPS, nclasses=Config.NCLASSES,
                                      nclasses_per_group=Config.NCLASSES_PER_GROUP,
                                      group_channels=Config.GROUP_CHANNELS, class_channels=Config.CLASS_CHANNELS)
Also add the file clustered_hgat_fc.py to the mymodels folder.
The grouping is loaded directly inside the network. Once loaded, the network uses it to arrange the classes by group.
# module-level imports in clustered_hgat_fc.py
import pickle
import torch

# inside HGAT_FC.__init__:
# fixme ---------- load clustered results
# load groups and group classes
with open('sk_spectral_cluster/coco_label_cluster_result.pkl', 'rb') as f:
    print("loading sk_spectral_cluster/coco_label_cluster_result.pkl")
    split_groups = pickle.load(f)
for key in split_groups:
    print("group:", key, " group element numbers", len(split_groups[key]), " group_elements : ", split_groups[key])
print("groups=len(split_groups) :", len(split_groups))
nclasses_per_group = []
cls_idx_order = []  # grouped position -> original class index
for idx in range(len(split_groups)):  # keys are assumed to be 1..len(split_groups)
    nclasses_per_group.append(len(split_groups[idx + 1]))
    cls_idx_order = cls_idx_order + split_groups[idx + 1]
# use int64 indices (LongTensor), as index_select expects
torch_cls_idx_order = torch.LongTensor(cls_idx_order)
self.torch_cls_idx_order = torch_cls_idx_order
# print("final idx order:", self.torch_cls_idx_order)
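A tensor stored as a plain attribute, such as self.torch_cls_idx_order = ..., is not moved by model.to(device) and is not saved in state_dict. A common alternative is nn.Module.register_buffer. The sketch below uses a hypothetical GroupOrder module and toy groups. It stores both the grouped order and its inverse as buffers, computing the inverse with torch.argsort, which inverts a permutation.

```python
import torch
import torch.nn as nn

class GroupOrder(nn.Module):
    """Hypothetical helper that keeps the class order as buffers,
    so it follows model.to(device) and is saved in state_dict."""

    def __init__(self, split_groups):
        super().__init__()
        keys = sorted(split_groups)  # groups taken in ascending key order
        self.nclasses_per_group = [len(split_groups[k]) for k in keys]
        cls_idx_order = [c for k in keys for c in split_groups[k]]
        # grouped position k -> original class index
        self.register_buffer("cls_idx_order", torch.tensor(cls_idx_order, dtype=torch.long))
        # inverse permutation: original class index -> grouped position
        self.register_buffer("group_pos_of_cls", torch.argsort(self.cls_idx_order))

order = GroupOrder({1: [2, 0], 2: [1]})
print(order.cls_idx_order.tolist())     # [2, 0, 1]
print(order.group_pos_of_cls.tolist())  # [1, 2, 0]
```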
The final output has n_classes entries per sample, one for each label.
x = torch.cat(outside, dim=1) # [B,nclasses,C]
x = torch.cat([self.fcs[i](x[:, i, :]) for i in range(self.nclasses)], dim=1) # [B,nclasses]
return x
https://blog.csdn.net/qq_25037903/article/details/88651166
torch.index_select
Official docs: https://pytorch-cn.readthedocs.io/zh/latest/package_references/torch/#torchindex_select
torch.index_select(input, dim, index, out=None) → Tensor
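Slices the input tensor along dimension dim, takes the entries specified by index (a LongTensor), and returns them in a new tensor. The returned tensor has the same number of dimensions as the original tensor. Its size along dim equals the length of index, and all other dimensions keep their original sizes.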
Note: the returned tensor does not share memory with the original tensor.
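Parameters:
input (Tensor): the input tensor
dim (int): the dimension along which to index
index (LongTensor): a 1-D tensor containing the indices
out (Tensor, optional): the output tensor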
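One way to read it: each value in index gives the position, in the original tensor input, of the corresponding entry of the output tensor out. That is, along dim, out[k] = input[index[k]].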
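For example, given a tensor a and an index tensor idx2, a.index_select(0, idx2) selects the rows of a at the positions given by the values in idx2. It stores them, in that order, in out.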
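Here is a small runnable version of that example. The values of a and idx2 are made up for illustration. It also selects along dim 1 and shows that the result does not share memory with a.

```python
import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
idx2 = torch.tensor([2, 0], dtype=torch.long)

rows = a.index_select(0, idx2)         # rows[k] = a[idx2[k]]
cols = torch.index_select(a, 1, idx2)  # cols[:, k] = a[:, idx2[k]]
print(rows.tolist())  # [[7, 8, 9], [1, 2, 3]]
print(cols.tolist())  # [[3, 1], [6, 4], [9, 7]]

rows[0, 0] = 100       # modifying the result ...
print(a[2, 0].item())  # ... leaves a unchanged: 7
```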
In the code below, x holds the features after grouping: x[:, i, :] is the feature of group i.
count = 0  # running index over all classes, selects class_fcs[count]
outside = []
for i in range(self.groups):
    inside = []
    for j in range(self.nclasses_per_group[i]):
        inside.append(self.class_fcs[count](x[:, i, :]))  # [B,C]
        count += 1
    inside = torch.stack(inside, dim=1)  # [B,N,C]
    inside = self.gat2s[i](inside)  # [B,N,C]
    outside.append(inside)
x = torch.cat(outside, dim=1)  # [B,nclasses,C], classes in grouped order
x = torch.cat([self.fcs[i](x[:, i, :]) for i in range(self.nclasses)], dim=1)  # [B,nclasses]
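To check the tensor shapes of this loop in isolation, here is a toy module that mirrors it. It is a sketch under several assumptions: the GAT layers (gat2s) are replaced by nn.Identity, class_fcs and fcs are plain nn.Linear layers, and ToyGroupedHead is a hypothetical name, not the class in clustered_hgat_fc.py.

```python
import torch
import torch.nn as nn

class ToyGroupedHead(nn.Module):
    """Toy copy of the grouped loop: per-class Linear layers, GAT replaced by Identity."""

    def __init__(self, nclasses_per_group, channels):
        super().__init__()
        self.groups = len(nclasses_per_group)
        self.nclasses_per_group = nclasses_per_group
        self.nclasses = sum(nclasses_per_group)
        self.class_fcs = nn.ModuleList([nn.Linear(channels, channels) for _ in range(self.nclasses)])
        self.gat2s = nn.ModuleList([nn.Identity() for _ in range(self.groups)])  # stand-in for the GAT layers
        self.fcs = nn.ModuleList([nn.Linear(channels, 1) for _ in range(self.nclasses)])

    def forward(self, x):  # x: [B, groups, C]
        count = 0
        outside = []
        for i in range(self.groups):
            inside = []
            for _ in range(self.nclasses_per_group[i]):
                inside.append(self.class_fcs[count](x[:, i, :]))  # [B, C]
                count += 1
            inside = torch.stack(inside, dim=1)     # [B, N_i, C]
            outside.append(self.gat2s[i](inside))   # [B, N_i, C]
        x = torch.cat(outside, dim=1)  # [B, nclasses, C], grouped order
        return torch.cat([self.fcs[k](x[:, k, :]) for k in range(self.nclasses)], dim=1)  # [B, nclasses]

head = ToyGroupedHead([2, 1], channels=4)
logits = head(torch.randn(3, 2, 4))
print(logits.shape)  # torch.Size([3, 3])
```

The output columns are still in grouped order, which is why the next step has to reorder them.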
To get back to the pre-grouping order, x has to be reordered.
For each position in the grouped order, self.torch_cls_idx_order gives the index of that class in the original class order:
torch_cls_idx_order = torch.LongTensor(cls_idx_order)  # int64 indices
self.torch_cls_idx_order = torch_cls_idx_order
Our goal: from the position of each group_class in class_order, recover the original class_order.
cls=x
for idx in range(0,self.nclasses):
    cls[idx]=x[self.cls_idx_order[idx]]
x=cls
index=self.order_cls_idx
x=x.index_select(0,index)
This approach kept raising errors, so it was set aside for now.
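Whatever the exact error message, the attempt above has several problems that would keep it from working as written:

- cls = x makes cls another name for the same tensor, so the in-place writes overwrite values that are still needed.
- dim 0 of x is the batch dimension, not the class dimension.
- self.cls_idx_order and self.order_cls_idx are not set in the loading code shown earlier, which only defines self.torch_cls_idx_order.
- x[self.cls_idx_order[idx]] maps in the wrong direction. Restoring the original order needs the inverse permutation.

Below is a minimal sketch of a vectorized alternative. It assumes the scores are [B, nclasses] in grouped order, and restore_class_order is a hypothetical helper.

```python
import torch

def restore_class_order(x_grouped, cls_idx_order):
    """Reorder scores from grouped order back to the original class order.

    x_grouped: [B, nclasses], column k is grouped position k.
    cls_idx_order[k]: original class index of grouped position k.
    """
    order = torch.as_tensor(cls_idx_order, dtype=torch.long, device=x_grouped.device)
    inverse = torch.argsort(order)             # inverse[c] = grouped position of class c
    return x_grouped.index_select(1, inverse)  # dim 1 is the class dimension

cls_idx_order = [2, 0, 1]                   # toy order: position 0 holds class 2, ...
x_grouped = torch.tensor([[20., 0., 10.]])  # score of class c is 10*c, stored in grouped order
restored = restore_class_order(x_grouped, cls_idx_order)
print(restored.tolist())  # [[0.0, 10.0, 20.0]]

# Equivalent scatter form: write column k of x_grouped into column cls_idx_order[k]
scattered = torch.empty_like(x_grouped)
scattered[:, cls_idx_order] = x_grouped
print(torch.equal(restored, scattered))  # True
```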
The code was changed to:
#fixme change order from order in groups to class order
x = torch.cat([self.fcs[i](x[:, self.idx_in_group_2_cls_idx[i], :]) for i in range(self.nclasses)], dim=1) # [B,nclasses]
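That is, we use the order of x to recover the class order. x is in grouped order, so to recover the original order we need a correspondence between the two.
In cls_idx_order (the split_groups lists concatenated), the index idx is the position in the grouped output, and the value is the original class index.
We want to recover each class's output from the grouped outputs.
Here we know each class and need to find its index in the grouped order, so we need a mapping from value to idx: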
idx_in_group_2_cls_idx = {}  # original class index -> position in grouped order
for idx in range(len(cls_idx_order)):
    idx_in_group_2_cls_idx[cls_idx_order[idx]] = idx
print("idx_in_group_2_cls_idx", idx_in_group_2_cls_idx)
self.idx_in_group_2_cls_idx = idx_in_group_2_cls_idx
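The dictionary idx_in_group_2_cls_idx is exactly the inverse permutation of cls_idx_order, so the loop can be checked against torch.argsort. The sketch below uses a toy order with illustrative values. It also checks that selecting column idx_in_group_2_cls_idx[i] for each class i, as the revised line x[:, self.idx_in_group_2_cls_idx[i], :] does, matches a single index_select.

```python
import torch

cls_idx_order = [2, 0, 1]  # toy grouped order: position -> original class
idx_in_group_2_cls_idx = {}
for idx in range(len(cls_idx_order)):
    idx_in_group_2_cls_idx[cls_idx_order[idx]] = idx  # class -> grouped position

inverse = torch.argsort(torch.tensor(cls_idx_order))
as_list = [idx_in_group_2_cls_idx[c] for c in range(len(cls_idx_order))]
print(as_list, inverse.tolist())  # [1, 2, 0] [1, 2, 0]

# Per-class selection (as in the revised forward) vs. one vectorized index_select
x = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)  # [B, nclasses, C], grouped order
looped = torch.stack([x[:, idx_in_group_2_cls_idx[i], :] for i in range(3)], dim=1)
print(torch.equal(looped, x.index_select(1, inverse)))  # True
```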