Reading .mat data: how to use scipy.io.loadmat

Many datasets ship their annotations as .mat files. The loadmat and savemat functions in the scipy.io module let Python read and write such .mat data.

scipy.io.loadmat(file_name, mdict=None, appendmat=True, **kwargs)

scipy.io.savemat(file_name, mdict, appendmat=True, format='5', long_field_names=False, do_compression=False, oned_as='row')
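
For a quick self-contained illustration (the file name demo.mat and the variable name boxes are made up here), savemat writes a dict of NumPy arrays and loadmat reads it back; besides the saved variables, the dict returned by loadmat also contains the metadata keys __header__, __version__ and __globals__:

import numpy as np
from scipy.io import loadmat, savemat

# Save a dict of arrays; appendmat=True appends the '.mat' suffix to 'demo'
savemat('demo', {'boxes': np.arange(8).reshape(4, 2)}, appendmat=True)

# Load it back: the result maps variable names to numpy arrays, plus the
# metadata keys '__header__', '__version__' and '__globals__'
data = loadmat('demo.mat')
print(data['boxes'].shape)  # (4, 2)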

Example:
A re-id dataset has the following directory structure:

.
├── dataset
│   ├── annotation
│   │   ├── Images.mat
│   │   ├── Person.mat
│   │   ├── pool.mat
│   │   └── test
│   │       ├── subset
│   │       │   ├── Occlusion.mat
│   │       │   └── Resolution.mat
│   │       └── train_test
│   │           ├── Train.mat
│   │           ├── TestG2000.mat
│   │           ├── ...
│   │           └── TestG50.mat
│   ├── Image
│   │   └── SSM
│   │       ├── s10000.jpg
│   │       ├── ...
│   │       └── s9.jpg
│   └── README.txt
└── out

The methods below are from the dataset class that parses these annotation files with loadmat; they rely on from scipy.io import loadmat, import numpy as np, import os.path as osp, from scipy.sparse import csr_matrix, and the codebase's own pickle/unpickle helpers:
    def _load_image_set_index(self):
        """
        Load the indexes for the specific subset (train / test).
        The index is just the image file name.
        """
        # Read the test image pool
        test = loadmat(osp.join(self._root_dir, 'annotation', 'pool.mat'))  # load the .mat file at the given path
        test = test['pool'].squeeze()  # drop the singleton dimension, giving a numpy array of length 6978
        test = [str(a[0]) for a in test]  # list comprehension: numpy array -> list of image file names
        if self._image_set == 'test': return test  # for the test split, return it directly
        # Read all images
        all_imgs = loadmat(osp.join(self._root_dir, 'annotation', 'Images.mat'))  # load the .mat file at the given path
        all_imgs = all_imgs['Img'].squeeze()  # drop the singleton dimension, giving a numpy array of length 18184
        all_imgs = [str(a[0][0]) for a in all_imgs]  # elements are image file names
        # training
        return list(set(all_imgs) - set(test))  # the training set is the complement of the test set
    def _load_probes(self):
        """
        Load the list of (img, roi) for probes. For test split, it's defined
        by the protocol. For training split, will randomly choose some samples
        from the gallery as probes.
        """
        protoc = loadmat(osp.join(self._root_dir,
            'annotation/test/train_test/TestG50.mat'))['TestG50'].squeeze()  # 2900 query persons, each with a gallery of size 50
        probes = []
        for item in protoc['Query']:  # the protocol is split into Query and Gallery; each probe has one query image
            im_name = osp.join(self._data_path, str(item['imname'][0,0][0]))  # path of this probe's query image
            roi = item['idlocate'][0,0][0].astype(np.int32)  # region of the query person in the image
            roi[2:] += roi[:2]  # xmin,ymin,width,height -> xmin,ymin,xmax,ymax
            probes.append((im_name, roi))
        return probes  # 2900 probes; an ordered list keeps the correspondence with person ids
    def gt_roidb(self):
        cache_file = osp.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if osp.isfile(cache_file):
            roidb = unpickle(cache_file)
            return roidb

        # Load all images and build a dict from image to boxes
        all_imgs = loadmat(osp.join(self._root_dir, 'annotation', 'Images.mat'))
        all_imgs = all_imgs['Img'].squeeze()
        name_to_boxes = {}
        name_to_pids = {}
        for im_name, __, boxes in all_imgs:
            im_name = str(im_name[0])
            boxes = np.asarray([b[0] for b in boxes[0]])  # reshape only; contents unchanged
            boxes = boxes.reshape(boxes.shape[0], 4)  # reshape only; contents unchanged
            valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0]  # width and height must be positive
            assert valid_index.size > 0, \
                'Warning: {} has no valid boxes.'.format(im_name)
            boxes = boxes[valid_index]  # drop boxes with non-positive width or height
            name_to_boxes[im_name] = boxes.astype(np.int32)
            name_to_pids[im_name] = -1 * np.ones(boxes.shape[0], dtype=np.int32)  # initialize all pids to -1

        def _set_box_pid(boxes, box, pids, pid):
            for i in xrange(boxes.shape[0]):
                if np.all(boxes[i] == box):
                    pids[i] = pid
                    return
            print 'Warning: person {} box {} cannot find in Images'.format(pid, box)

        # Load all the train / test persons and label their pids from 0 to N-1
        # Assign pid = -1 for unlabeled background people
        if self._image_set == 'train':
            train = loadmat(osp.join(self._root_dir,
                                     'annotation/test/train_test/Train.mat'))  # 5532 query persons for training
            train = train['Train'].squeeze()
            for index, item in enumerate(train):
                scenes = item[0, 0][2].squeeze()  # info for all images in which this person appears
                for im_name, box, __ in scenes:
                    im_name = str(im_name[0])  # image name
                    box = box.squeeze().astype(np.int32)
                    # Match this training box against the boxes annotated for the image; when they
                    # agree, assign a pid (not a real person id, just the enumeration index).
                    # Boxes that never appear in the training annotations keep pid = -1.
                    _set_box_pid(name_to_boxes[im_name], box,
                                 name_to_pids[im_name], index)
        else:
            test = loadmat(osp.join(self._root_dir,
                                    'annotation/test/train_test/TestG50.mat'))
            test = test['TestG50'].squeeze()
            for index, item in enumerate(test):
                # query
                im_name = str(item['Query'][0,0][0][0])
                box = item['Query'][0,0][1].squeeze().astype(np.int32)
                _set_box_pid(name_to_boxes[im_name], box,
                             name_to_pids[im_name], index)
                # gallery
                gallery = item['Gallery'].squeeze()
                for im_name, box, __ in gallery:
                    im_name = str(im_name[0])
                    if box.size == 0: break
                    box = box.squeeze().astype(np.int32)
                    _set_box_pid(name_to_boxes[im_name], box,
                                 name_to_pids[im_name], index) 

        # Construct the gt_roidb
        gt_roidb = []
        for im_name in self.image_index:  # iterate over the train/test image names
            boxes = name_to_boxes[im_name]  # get boxes
            boxes[:, 2] += boxes[:, 0]  # width to xmax
            boxes[:, 3] += boxes[:, 1]  # height to ymax
            pids = name_to_pids[im_name]  # get pids
            num_objs = len(boxes)  # number of boxes
            gt_classes = np.ones((num_objs), dtype=np.int32)  # all boxes are class 1 (person)
            overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
            overlaps[:, 1] = 1.0
            overlaps = csr_matrix(overlaps)
            gt_roidb.append({
                'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_overlaps': overlaps,
                'gt_pids': pids,
                'flipped': False})

        pickle(gt_roidb, cache_file)
        print "wrote gt roidb to {}".format(cache_file)

        return gt_roidb
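
When the internal layout of a .mat annotation file is not documented, it helps to print the top-level keys and the field names of the MATLAB struct arrays before indexing into them the way the methods above do. A minimal sketch, assuming the same annotation layout as this dataset (root_dir is a placeholder path):

import os.path as osp
from scipy.io import loadmat

root_dir = '/path/to/dataset'  # placeholder dataset root

mat = loadmat(osp.join(root_dir, 'annotation', 'Images.mat'))
print(mat.keys())            # saved variable names plus '__header__', '__version__', '__globals__'

imgs = mat['Img'].squeeze()  # drop the singleton dimension of the struct array
print(imgs.shape)            # one record per image
print(imgs.dtype.names)      # field names of the MATLAB struct (dataset-specific)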
