Caffe Study Notes (11): Multi-Task Learning, Generating an HDF5Data Dataset

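I recently started looking into multi-task learning (MTL), and here are my first notes.
This post covers how to build the dataset. The HDF5Data layer type is used to handle multi-label data and is defined in the network as follows: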

layer {
  name: "data"
  type: "HDF5Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  hdf5_data_param {
    source: "list_train.txt"
    batch_size: 1000
    shuffle: true
  }
}

Datasets for the HDF5Data layer are stored as .h5 files.
list_train.txt holds the list of training dataset files, and list_val.txt does the same for the validation set:
[Figure 1: contents of list_train.txt]
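As a rough illustration of how an .h5 file and its list file fit together, here is a minimal MATLAB sketch. The dummy data, the array sizes, and the file name train_1.h5 are assumptions made for the example and are not from the original post. The dataset names /data and /label are chosen to match the two top names of the layer above.

```matlab
% Dummy data standing in for real images and labels (assumption).
num = 10; width = 32; height = 32; channels = 3; num_labels = 6;
data  = rand(width, height, channels, num, 'single');
label = single(randi([0 1], num_labels, num));

h5_name = 'train_1.h5';                 % hypothetical file name
if exist(h5_name, 'file')
    delete(h5_name);                    % h5create fails if the dataset already exists
end

% MATLAB stores arrays column-major and Caffe reads them row-major, so an
% array written as W x H x C x N here is seen by Caffe as N x C x H x W.
h5create(h5_name, '/data',  size(data),  'Datatype', 'single');
h5create(h5_name, '/label', size(label), 'Datatype', 'single');
h5write(h5_name, '/data',  data);
h5write(h5_name, '/label', label);

% The list file simply names one .h5 file per line.
fid = fopen('list_train.txt', 'w');
fprintf(fid, '%s\n', h5_name);
fclose(fid);
```

With several .h5 files, each one would get its own line in list_train.txt.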

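Caffe requires that a single HDF5 file be no larger than 2 GB, so if you have a lot of data you need to generate several HDF5 files.
My raw data is 40 MB, and the generated dataset is 800 MB.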
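To get a feel for how many files are needed, the arithmetic can be sketched as below. The sample count, image size, and file naming are hypothetical values picked for the example. In practice you would probably leave some headroom below the limit for the labels and HDF5 overhead.

```matlab
% Hypothetical dataset size; replace with your own numbers.
num_samples     = 50000;
sample_dims     = [64 64 3];      % width, height, channels
bytes_per_value = 4;              % single precision
max_bytes       = 2 * 1024^3;     % the 2 GB limit mentioned above

bytes_per_sample = prod(sample_dims) * bytes_per_value;
samples_per_file = floor(max_bytes / bytes_per_sample);
num_files        = ceil(num_samples / samples_per_file);

% Print which sample range would go into each file.
for k = 1:num_files
    first = (k - 1) * samples_per_file + 1;
    last  = min(k * samples_per_file, num_samples);
    fprintf('train_%d.h5: samples %d to %d\n', k, first, last);
end
```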

MATLAB code for reference: (link in the original post)

Creating the train.txt and val.txt files:
I use six labels per image, as shown:
[Figure 2: the six labels]

The code performs the following tasks:
1. Write the whole dataset to trainval.txt
[Figure 3: trainval.txt]
2. Shuffle the image entries in trainval.txt to produce trainval_re.txt
[Figure 4: trainval_re.txt]
3. Generate train.txt and val.txt from trainval_re.txt
[Figures 5 and 6: the generated files]

Below is the MATLAB code that generates these files, adapted from existing code (linked in the original post):

% Create train.txt and val.txt for multi-task learning (MTL) with Caffe.

clc;
clear;
%% Write trainval.txt and labels.txt in directory order
maindir='images/'; % folder containing one subfolder of images per class
wf = fopen('trainval.txt','w');
lbf=fopen('labels.txt','w');
train_percent=0.9; % fraction of images used for training; val_percent=1-train_percent

subdir = dir(maindir);
ii=-1; % class index, incremented before use so the first class is 0
numoffile=0; % total number of images written
for i = 1:length(subdir) % loop over the class subfolders
    if subdir(i).isdir && ~strcmp(subdir(i).name ,'.') && ~strcmp(subdir(i).name,'..')
        % initialize the label for each key point
        a1 = 0;
        a2 = 0;
        a3 = 0;
        a4 = 0;
        a5 = 0;

        ii=ii+1;
        label = subdir(i).name;
        switch ii
            case 0
                a1 = 1;
                a2 = 1;
                a3 = 1;
                a4 = 1;
                a5 = 1;
            case 1
                a2 = 1;
                a3 = 1;
            case 2
                a1 = 1;
                a2 = 1;
                a3 = 1;
            case 3
                a1 = 1;
                a2 = 1;
                a3 = 1;
                a4 = 1;
                a5 = 1;
            case 4 % all five attribute labels stay 0
            case 5
                a2 = 1;
                a5 = 1;
            case 6
                a2 = 1;
            case 7
                a2 = 1;
                a3 = 1;
                a4 = 1;
                a5 = 1;
            case 8
                a3 = 1;
                a4 = 1;
                a5 = 1;
            case 9
                a1 = 1;
            case 10
                a5 = 1;
        end

        fprintf(lbf,'%s: %d %d %d %d %d %d\n', label, ii, a1, a2, a3, a4, a5);
        label=strcat(label,'/');
        subsubdir = dir(strcat(maindir,label));
        for j=1:length(subsubdir)
            if ~strcmp(subsubdir(j).name ,'.') && ~strcmp(subsubdir(j).name,'..')
                fprintf(wf,'./%s%s%s %d %d %d %d %d %d\n','images/', label, subsubdir(j).name, ii, a1, a2, a3, a4, a5);
                numoffile=numoffile+1;
                fprintf('the label is %d, and the image order is %d\n',ii,j-2);
            end
        end
    end
end
fclose(wf);  
fclose(lbf);  

%%  
% Shuffle the lines of trainval.txt
file=cell(1,numoffile);  
fin=fopen('trainval.txt','r');  
i=1;  
while ~feof(fin)  
    tline=fgetl(fin);  
    file{i}=tline;  
    i=i+1;  
end  
fclose(fin);  

fprintf('\ntrainval.txt has %d rows, shuffling their order...\n',numoffile);
pause(1);  
rep=randperm(numoffile);  
fout=fopen('trainval_re.txt','w');  
for i=1:numoffile  
    fprintf(fout,'%s\n',file{rep(i)});  
end  
fprintf('trainval_re.txt holds the shuffled list.\n');
fclose(fout);  

%%
% Create train.txt and val.txt from trainval_re.txt
fprintf('creating train.txt and val.txt...\n');
pause(1);
train_file=fopen('train.txt','w');
val_file=fopen('val.txt','w');
trainvalfile=fopen('trainval_re.txt','r');

% randomly pick the line numbers that go to train.txt; the rest go to val.txt
num_train=sort(randperm(numoffile,floor(numoffile*train_percent)));
i=1;
while ~feof(trainvalfile)
    tline=fgetl(trainvalfile);
    if ismember(i,num_train)
        fprintf(train_file,'%s\n',tline);
    else
        fprintf(val_file,'%s\n',tline);
    end
    i=i+1;
end
fclose(train_file);
fclose(val_file);
fclose(trainvalfile);
fprintf('the total number of images is %d\n',numoffile);
fprintf('Done!\n');
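
A natural next step is to read a list such as train.txt back into MATLAB before packing the images and labels into .h5 files. The sketch below is one possible way to do that, not the original author's code. The file demo_train.txt and its two entries are made up for the example, and the parsing assumes that image paths contain no spaces.

```matlab
% Write a tiny demo list in the same format as train.txt (hypothetical content).
demo_name = 'demo_train.txt';
fid = fopen(demo_name, 'w');
fprintf(fid, './images/a/001.jpg 0 1 1 1 1 1\n');
fprintf(fid, './images/b/002.jpg 1 0 1 1 0 0\n');
fclose(fid);

% Each line is: image path, class index, then the five attribute labels.
fid = fopen(demo_name, 'r');
cols = textscan(fid, '%s %f %f %f %f %f %f');
fclose(fid);

paths  = cols{1};                    % N x 1 cell array of image paths
labels = single([cols{2:7}])';       % 6 x N matrix, one column per image

fprintf('%d images, label matrix is %d x %d\n', ...
    numel(paths), size(labels, 1), size(labels, 2));
```

Storing the labels as 6 x N means that, after the dimension reversal between MATLAB and Caffe, the label blob would be read as N x 6.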
