mmsegmentation使用.csv文件读取数据

自定义数据读取方式

将原本用mmcv.imfrombytes按固定目录结构读取图片的方式，改为用mmcv.imread按csv文件中列出的路径读取图片

定义数据读取方式

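在/mmsegmentation-master/mmseg/datasets路径下新建一个mydata.py文件（文件开头需要导入mmcv、pandas（import pandas as pd）、torch.utils.data中的Dataset，以及同目录下builder中的DATASETS和pipelines中的Compose）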

@DATASETS.register_module()
class MyDataset(Dataset):
    """Segmentation dataset whose samples are listed in a CSV file.

    The CSV file is parsed with ``;`` as the separator and without a header
    row. Column 0 of every row is the image path and column 1 is the path of
    the matching ground-truth segmentation map; further columns are ignored.

    Args:
        pipeline (list): Transform configs handed to ``Compose``.
        test_mode (bool): When True, ``__getitem__`` dispatches to
            ``prepare_test_img`` instead of ``prepare_train_img``.
            Defaults to False.
        csv_file (str): Path of the CSV file listing the samples.
            Defaults to 'tmp.csv'.
    """

    CLASSES = ('gt',)
    PALETTE = [[255, 255, 255],]

    def __init__(self,
                 pipeline,
                 test_mode=False,
                 csv_file='tmp.csv'):
        self.pipeline = Compose(pipeline)
        self.test_mode = test_mode
        self.csv_file = csv_file
        self.img_infos = self.load_img_annotations_fromcsv(
            csv_file=self.csv_file)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.img_infos)

    def load_img_annotations_fromcsv(self, csv_file):
        """Read the image and ground-truth paths from the CSV file.

        Args:
            csv_file (str): Path of the ``;``-separated, header-less CSV file.

        Returns:
            list[dict]: One dict per CSV row with the keys ``filename``
            (column 0) and ``ann`` (column 1).
        """
        frames = pd.read_csv(csv_file, sep=";", header=None)
        # Walk the two columns in lockstep instead of doing a scalar
        # ``iloc`` lookup for every cell.
        return [
            dict(filename=filename, ann=ann)
            for filename, ann in zip(frames.iloc[:, 0], frames.iloc[:, 1])
        ]

    def get_ann_info(self, idx):
        """Return the ground-truth path stored for sample ``idx``."""
        return self.img_infos[idx]['ann']

    def __getitem__(self, idx):
        """Return the pipeline output for sample ``idx``."""
        if self.test_mode:
            return self.prepare_test_img(idx)
        return self.prepare_train_img(idx)

    def _build_results(self, idx):
        """Assemble the initial ``results`` dict fed into the pipeline."""
        results = dict(
            img_info=self.img_infos[idx],
            ann_info=self.get_ann_info(idx))
        results['seg_fields'] = []
        return results

    def prepare_train_img(self, idx):
        """Run the pipeline on sample ``idx`` (training mode).

        Args:
            idx (int): Index of the sample.

        Returns:
            The value returned by the pipeline for this sample.
        """
        return self.pipeline(self._build_results(idx))

    def prepare_test_img(self, idx):
        """Run the pipeline on sample ``idx`` (test mode).

        The input passed to the pipeline is built exactly as in
        ``prepare_train_img``, including ``ann_info``.

        Args:
            idx (int): Index of the sample.

        Returns:
            The value returned by the pipeline for this sample.
        """
        return self.pipeline(self._build_results(idx))

    def get_gt_seg_map_by_idx(self, index):
        """Get one ground truth segmentation map for evaluation.

        Args:
            index (int): Index of the sample.

        Returns:
            The array returned by ``mmcv.imread`` for the ground-truth path.
        """
        # Flag -1 is OpenCV's IMREAD_UNCHANGED value, so the label image is
        # presumably loaded without colour conversion -- confirm for
        # non-cv2 mmcv backends.
        return mmcv.imread(self.get_ann_info(index), -1)

    def get_gt_seg_maps(self, efficient_test=None):
        """Get ground truth segmentation maps for evaluation.

        Args:
            efficient_test: Not used by this implementation.

        Yields:
            The ground-truth segmentation map of each sample, in index order.
        """
        for idx in range(len(self)):
            yield self.get_gt_seg_map_by_idx(idx)

在/mmsegmentation-master/mmseg/datasets/pipelines/loading.py中新增相应的数据加载类

@PIPELINES.register_module()
class LoadImageFromCSV(object):
    """Load the image whose path was taken from the CSV file.

    Required key is "img_info" (a dict that must contain the key
    "filename"). Added or updated keys are "filename", "ori_filename", "img",
    "img_shape", "ori_shape" (same as `img_shape`), "pad_shape" (same as
    `img_shape`), "scale_factor" (1.0) and "img_norm_cfg" (means=0 and
    stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. Defaults to False, in which case the array is kept
            as returned by ``mmcv.imread``.
    """

    def __init__(self, to_float32=False):
        self.to_float32 = to_float32

    def __call__(self, results):
        """Read the image and fill in the image meta information.

        Args:
            results (dict): Result dict holding ``img_info['filename']``.

        Returns:
            dict: The same dict, extended with the image and its meta keys.
        """
        path = results['img_info']['filename']
        image = mmcv.imread(path)
        if self.to_float32:
            image = image.astype(np.float32)

        shape = image.shape
        # An array without a channel axis counts as single-channel.
        channels = shape[2] if len(shape) >= 3 else 1

        results.update(
            filename=path,
            ori_filename=path,
            img=image,
            img_shape=shape,
            ori_shape=shape,
            # Initial values for the default meta_keys.
            pad_shape=shape,
            scale_factor=1.0,
            img_norm_cfg=dict(
                mean=np.zeros(channels, dtype=np.float32),
                std=np.ones(channels, dtype=np.float32),
                to_rgb=False),
        )
        return results


@PIPELINES.register_module()
class LoadAnnotationsFromCSV(object):
    """Load the semantic segmentation map referenced by ``ann_info``.

    Required keys are "ann_info" (the path handed to ``mmcv.imread``) and
    "seg_fields" (a list). The added or updated key is "gt_semantic_seg",
    whose name is also appended to "seg_fields".
    """

    def __call__(self, results):
        """Read the ground-truth map and register it in ``results``.

        Args:
            results (dict): Result dict holding ``ann_info`` and
                ``seg_fields``.

        Returns:
            dict: The same dict with ``gt_semantic_seg`` filled in.
        """
        seg_key = 'gt_semantic_seg'
        # Flag -1 is passed straight through to mmcv.imread.
        results[seg_key] = mmcv.imread(results['ann_info'], -1)
        results['seg_fields'].append(seg_key)
        return results

在init文件中增加MyDataset

同时在/mmsegmentation-master/mmseg/datasets中的__init__.py里添加刚刚新建的数据集

...
from .mydata import MyDataset

# `...` stands for the names this package already exports. The new entry must
# be exactly 'MyDataset' (no trailing space), otherwise a star-import of the
# package fails because no attribute with that name exists.
__all__ = [
    ..., 'MyDataset',
]

在config文件的数据处理流程(pipeline)中，把读取图片和标注的步骤改为上面新增的两个类

# Pipeline entries that select the two CSV-based loaders defined above.
# NOTE(review): presumably these replace the stock image/annotation loading
# steps at the start of the pipeline list -- confirm against your config.
dict(type='LoadImageFromCSV'), 
dict(type='LoadAnnotationsFromCSV'),

至此大功告成!!!

你可能感兴趣的:(mmsegmentation,python,numpy)