图像语义分割实践(一)数据制作与转换

图像语义分割实践(一)数据制作与转换

语义分割实践过程中,网上常用的公开数据集和模型大多是打包固化好的,只要配置的环境一致就能进行所谓的“复现”。但这种拿来主义很被动,很多人因此陷入邯郸学步、贪多嚼不烂的窘境,最后白白浪费大把时间。因此,我觉得要真正学得深入,就得一点一滴地亲手实现,才能形成自己的一套体系。

第1步:使用 labelme 打标

pip install labelme
在终端运行 labelme 命令,打开图片进行标注,标注效果如下图所示。

labelme-第1步labelme进行图片打标.png

labelme-第2步打标示范图.png

第2步:使用 json2label 函数进行标签转换

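复制文末代码并保存为 genlabel.py,修改 `if __name__=="__main__":` 下的数据路径(json_root、img_root、save_root)与标签映射(label_name_to_value),然后运行 `python genlabel.py`,即可生成分割标签掩码。结果保存在 save_root 下的 JPEGImages、SegmentationClassRaw、SegmentationClassPNG、SegmentationClassVisualization 等子目录中。注意:当 use_label=True 时,原始图片和 json 文件会被移动(而不是复制)到 save_root/Annotations 目录,请提前备份。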

py修改代码数据路径.png

labelme-第3步鱼打标效果图.png

genlabel.py 完整代码:

import shutil,base64,io,os,json,glob,math,warnings
import numpy as np
import PIL
import cv2
from PIL import (ExifTags, Image, ImageOps, ImageDraw, ImageDraw, ImageFont)
from skimage import img_as_ubyte
import tensorflow as tf
import os.path as osp
from tqdm import trange
import matplotlib.pyplot as plt


''' @定制化标签 '''
def shape_to_mask(img_shape, points, shape_type=None,
                  line_width=10, point_size=5):
    """Rasterize a single labelme shape into a boolean mask.

    Args:
        img_shape: image shape; only the first two entries (height, width)
            are used.
        points: sequence of (x, y) pairs as stored in labelme JSON.
        shape_type: 'circle', 'rectangle', 'line', 'linestrip' or 'point';
            any other value (including None) is drawn as a filled polygon.
        line_width: stroke width in pixels for 'line' / 'linestrip'.
        point_size: radius in pixels of the disc drawn for 'point'.

    Returns:
        bool ndarray of shape img_shape[:2], True inside the shape.

    Raises:
        AssertionError: when the number of points does not fit shape_type.
    """
    canvas = PIL.Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
    pen = PIL.ImageDraw.Draw(canvas)
    coords = [tuple(pt) for pt in points]

    if shape_type == 'circle':
        # Two points: the center and any point on the circumference.
        assert len(coords) == 2, 'Shape of shape_type=circle must have 2 points'
        (cx, cy), (px, py) = coords
        radius = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
        pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius],
                    outline=1, fill=1)
    elif shape_type == 'rectangle':
        assert len(coords) == 2, 'Shape of shape_type=rectangle must have 2 points'
        pen.rectangle(coords, outline=1, fill=1)
    elif shape_type in ('line', 'linestrip'):
        # A 'line' is exactly two points; a 'linestrip' may have any count.
        if shape_type == 'line':
            assert len(coords) == 2, 'Shape of shape_type=line must have 2 points'
        pen.line(xy=coords, fill=1, width=line_width)
    elif shape_type == 'point':
        assert len(coords) == 1, 'Shape of shape_type=point must have 1 points'
        cx, cy = coords[0]
        pen.ellipse([cx - point_size, cy - point_size,
                     cx + point_size, cy + point_size],
                    outline=1, fill=1)
    else:
        assert len(coords) > 2, 'Polygon must have points more than 2'
        pen.polygon(xy=coords, outline=1, fill=1)

    # Customization hook: project-specific drawing on `pen` can be added here.
    return np.array(canvas, dtype=bool)

def img_b64_to_arr(img_b64):
    """Decode a base64-encoded image file into a numpy array.

    Args:
        img_b64: base64 text/bytes of an encoded image (e.g. labelme's
            ``imageData`` field).

    Returns:
        ndarray produced by ``np.array`` on the decoded PIL image.
    """
    raw_bytes = base64.b64decode(img_b64)
    return np.array(PIL.Image.open(io.BytesIO(raw_bytes)))


def img_arr_to_b64(img_arr):  # not called in this script
    """Encode an image array as base64 PNG bytes.

    Args:
        img_arr: ndarray accepted by ``PIL.Image.fromarray``.

    Returns:
        bytes: base64 encoding of the PNG-encoded image, in the line-wrapped
        form produced by ``base64.encodebytes``.
    """
    img_pil = PIL.Image.fromarray(img_arr)
    with io.BytesIO() as f:
        img_pil.save(f, format='PNG')
        img_bin = f.getvalue()
    # encodebytes exists on every Python 3 release. The previous
    # encodestring fallback was unreachable, and that function was removed
    # in Python 3.9.
    return base64.encodebytes(img_bin)


def img_data_to_png_data(img_data):  # not called in this script
    """Re-encode raw encoded image bytes (any PIL-readable format) as PNG.

    Args:
        img_data: bytes of an encoded image file.

    Returns:
        bytes of the same image encoded as PNG.
    """
    with io.BytesIO() as src:
        src.write(img_data)
        decoded = PIL.Image.open(src)
        # Encode while `src` is still open: PIL reads image data lazily.
        with io.BytesIO() as dst:
            decoded.save(dst, 'PNG')
            return dst.getvalue()


def apply_exif_orientation(image):  # not called in this script
    """Undo the rotation/mirroring recorded in the EXIF 'Orientation' tag.

    Args:
        image: PIL image. Images without ``_getexif`` or without EXIF data
            are returned unchanged.

    Returns:
        A transposed and/or mirrored PIL image for orientations 2-8. The
        input object itself for orientation 1, a missing tag, or any other
        value.
    """
    try:
        raw_exif = image._getexif()
    except AttributeError:
        raw_exif = None
    if raw_exif is None:
        return image

    # Translate numeric EXIF tag ids into their names.
    named_exif = {}
    for tag_id, tag_value in raw_exif.items():
        if tag_id in PIL.ExifTags.TAGS:
            named_exif[PIL.ExifTags.TAGS[tag_id]] = tag_value

    # EXIF orientation code -> operation restoring the upright image.
    fixes = {
        2: PIL.ImageOps.mirror,  # left-to-right mirror
        3: lambda im: im.transpose(PIL.Image.ROTATE_180),  # rotate 180
        4: PIL.ImageOps.flip,  # top-to-bottom mirror
        5: lambda im: PIL.ImageOps.mirror(
            im.transpose(PIL.Image.ROTATE_270)),  # top-to-left mirror
        6: lambda im: im.transpose(PIL.Image.ROTATE_270),  # rotate 270
        7: lambda im: PIL.ImageOps.mirror(
            im.transpose(PIL.Image.ROTATE_90)),  # top-to-right mirror
        8: lambda im: im.transpose(PIL.Image.ROTATE_90),  # rotate 90
    }
    fix = fixes.get(named_exif.get('Orientation', None))
    return image if fix is None else fix(image)

def polygons_to_mask(img_shape, polygons, shape_type=None):
    """Deprecated alias of shape_to_mask; warns, then delegates to it.

    Args:
        img_shape: image shape; forwarded unchanged.
        polygons: point list, forwarded as ``points``.
        shape_type: forwarded unchanged.

    Returns:
        Whatever shape_to_mask returns for these arguments.
    """
    message = ("The 'polygons_to_mask' function is deprecated, "
               "use 'shape_to_mask' instead.")
    warnings.warn(message)
    return shape_to_mask(img_shape, points=polygons, shape_type=shape_type)


def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
    """Rasterize labelme shapes into a class map (and optionally instances).

    Args:
        img_shape: image shape; only (height, width) = img_shape[:2] is used.
        shapes: labelme shape dicts with 'points', 'label' and an optional
            'shape_type'.
        label_name_to_value: class name -> integer class id.
        type: 'class' builds only the class map. 'instance' additionally
            builds an instance map. In instance mode the class name is the
            part of the label before the first '-', and every distinct full
            label gets its own id in order of first appearance (0 is
            '_background_').

    Returns:
        int32 class map of shape img_shape[:2], or in 'instance' mode the
        tuple (class_map, instance_map).

    Shapes are painted in order, so later shapes overwrite earlier ones
    where they overlap.
    """
    assert type in ['class', 'instance']
    as_instances = type == 'instance'
    hw = img_shape[:2]

    cls = np.zeros(hw, dtype=np.int32)
    if as_instances:
        ins = np.zeros(hw, dtype=np.int32)
        instance_names = ['_background_']

    for shape in shapes:
        pts = shape['points']
        label = shape['label']
        kind = shape.get('shape_type', None)
        if as_instances:
            cls_name = label.split('-')[0]
            if label not in instance_names:
                instance_names.append(label)
            ins_id = instance_names.index(label)
        else:
            cls_name = label
        cls_id = label_name_to_value[cls_name]
        mask = shape_to_mask(hw, pts, kind, line_width=10, point_size=5)
        cls[mask] = cls_id  # paint this shape's class id
        if as_instances:
            ins[mask] = ins_id

    if as_instances:
        return cls, ins
    return cls


def labelme_shapes_to_label(img_shape, shapes):
    """Deprecated: build a label map, auto-assigning ids to label names.

    Ids are assigned in order of first appearance, starting at 1
    ('_background_' is 0).

    Args:
        img_shape: image shape forwarded to shapes_to_label.
        shapes: labelme shape dicts.

    Returns:
        (class_map, label_name_to_value) tuple.
    """
    warnings.warn('labelme_shapes_to_label is deprecated, so please use '
                  'shapes_to_label.')

    name_to_id = {'_background_': 0}
    for shape in shapes:
        name = shape['label']
        if name not in name_to_id:
            name_to_id[name] = len(name_to_id)

    lbl = shapes_to_label(img_shape, shapes, name_to_id)
    return lbl, name_to_id

def masks_to_bboxes(masks):
    """Compute the tight bounding box of every mask in a stack.

    Args:
        masks: bool ndarray of shape (N, H, W).

    Returns:
        float32 ndarray of shape (N, 4). Each row is (y1, x1, y2, x2), where
        y2/x2 are exclusive (one past the last True pixel). An empty stack
        (N == 0) yields shape (0, 4).

    Raises:
        ValueError: if masks is not 3-D, is not bool, or contains a mask
            with no True pixel (it has no bounding box).
    """
    if masks.ndim != 3:
        raise ValueError(
            'masks.ndim must be 3, but it is {}'
            .format(masks.ndim)
        )
    if masks.dtype != bool:
        raise ValueError(
            'masks.dtype must be bool type, but it is {}'
            .format(masks.dtype)
        )
    # Preallocate so an empty stack still has the documented (0, 4) shape.
    bboxes = np.zeros((len(masks), 4), dtype=np.float32)
    for i, mask in enumerate(masks):
        where = np.argwhere(mask)
        if where.size == 0:
            raise ValueError(
                'mask {} is empty; it has no bounding box'.format(i)
            )
        (y1, x1), (y2, x2) = where.min(0), where.max(0) + 1
        bboxes[i] = (y1, x1, y2, x2)
    return bboxes

def label_colormap(N=256):
    """Build the PASCAL VOC-style label colormap.

    The bits of each label id are spread across the R, G and B channels:
    bit 3k+c of the id sets bit (7 - k) of channel c.

    Args:
        N: number of colors (label ids 0 .. N-1).

    Returns:
        float32 ndarray of shape (N, 3) with values in [0, 1].
    """
    ids = np.arange(N)
    channels = np.zeros((N, 3), dtype=np.int64)
    for bit in range(8):
        shift = 7 - bit
        for c in range(3):
            channels[:, c] |= ((ids >> c) & 1) << shift
        ids = ids >> 3
    return channels.astype(np.float32) / 255


def _validate_colormap(colormap, n_labels):
    if colormap is None:
        colormap = label_colormap(n_labels)
    else:
        assert colormap.shape == (colormap.shape[0], 3), \
            'colormap must be sequence of RGB values'
        assert 0 <= colormap.min() and colormap.max() <= 1, \
            'colormap must ranges 0 to 1'
    return colormap


# similar function as skimage.color.label2rgb
def label2rgb(
    lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
):
    """Colorize an integer label map, optionally blended over an image.

    Args:
        lbl: integer ndarray, used directly as an index into the colormap.
            Pixels equal to -1 are treated as unlabeled and painted black.
        img: optional image array. It is converted to a 3-channel grayscale
            image via PIL and alpha-blended under the label colors.
        n_labels: number of colors generated when ``colormap`` is None.
            Defaults to ``lbl.max() + 1`` (at least 1).
        alpha: weight of the label colors in the blend (the image gets
            ``1 - alpha``).
        thresh_suppress: accepted for API compatibility; not used.
        colormap: optional (K, 3) float array with values in [0, 1].

    Returns:
        uint8 ndarray of shape ``lbl.shape + (3,)``.
    """
    if n_labels is None:
        # lbl indexes the colormap directly, so it needs max + 1 entries.
        # Counting distinct values is too small for non-contiguous labels
        # such as {0, 2} and raised IndexError below.
        n_labels = max(int(lbl.max()) + 1, 1) if lbl.size else 1

    colormap = _validate_colormap(colormap, n_labels)
    colormap = (colormap * 255).astype(np.uint8)

    lbl_viz = colormap[lbl]
    lbl_viz[lbl == -1] = (0, 0, 0)  # unlabeled

    if img is not None:
        # Grayscale background keeps the label colors readable.
        img_gray = PIL.Image.fromarray(img).convert('LA')
        img_gray = np.asarray(img_gray.convert('RGB'))
        lbl_viz = alpha * lbl_viz + (1 - alpha) * img_gray
        lbl_viz = lbl_viz.astype(np.uint8)
    return lbl_viz


def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
    """Draw pixel-wise label with colorization and label names.

    The colorized label (optionally blended over ``img`` by label2rgb) is
    rendered with matplotlib's Agg backend. A legend entry is added for
    every label value that actually occurs in ``label``, and the figure is
    rasterized back into an array.

    label: ndarray, (H, W)
        Pixel-wise labels to colorize.
    img: ndarray, (H, W, 3), optional
        Image on which the colorized label will be drawn.
    label_names: sequence (len() is taken), optional
        Entry i is the name of label value i. Defaults to the strings
        '0' .. str(label.max()).
    colormap: ndarray, (K, 3), optional
        RGB colors in [0, 1]; generated with label_colormap when None.
    **kwargs:
        Forwarded to label2rgb (e.g. ``alpha``).

    Returns: uint8 ndarray, (H, W, 3)
        RGB rendering resized back to the label's (H, W).
    """
    # Render off-screen with Agg; the caller's backend is restored below.
    backend_org = plt.rcParams['backend']
    plt.switch_backend('agg')
    # Remove margins and ticks so the image fills the figure.
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0,
                        wspace=0, hspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    if label_names is None:
        label_names = [str(l) for l in range(label.max() + 1)]
    colormap = _validate_colormap(colormap, len(label_names))
    label_viz = label2rgb(
        label, img, n_labels=len(label_names), colormap=colormap, **kwargs
    )
    plt.imshow(label_viz)
    plt.axis('off')
    plt_handlers = []
    plt_titles = []
    for label_value, label_name in enumerate(label_names):
        # Only list labels that are present in this label map.
        if label_value not in label:
            continue
        fc = colormap[label_value]
        p = plt.Rectangle((0, 0), 1, 1, fc=fc)
        plt_handlers.append(p)
        plt_titles.append('{value}: {name}'
                          .format(value=label_value, name=label_name))
    plt.legend(plt_handlers, plt_titles, loc='lower right', framealpha=.5)
    # Save the figure to an in-memory image, then tear the figure down.
    f = io.BytesIO()
    plt.savefig(f, bbox_inches='tight', pad_inches=0)
    plt.cla()
    plt.close()
    plt.switch_backend(backend_org)
    # savefig's pixel size differs from the label's; resize back to (W, H).
    out_size = (label_viz.shape[1], label_viz.shape[0])
    out = PIL.Image.open(f).resize(out_size, PIL.Image.BILINEAR).convert('RGB')
    out = np.asarray(out)
    return out

def draw_instances(
    image=None,
    bboxes=None,
    labels=None,
    masks=None,
    captions=None,
):
    """Draw colored bounding boxes with captions onto an image.

    Args:
        image: ndarray accepted by PIL.Image.fromarray (required).
        bboxes: iterable of boxes, each unpacked as (xmin, ymin, xmax, ymax)
            (required). NOTE(review): masks_to_bboxes in this file returns
            (y1, x1, y2, x2); reorder its rows before passing them here.
        labels: iterable of integer label ids that select box colors from
            label_colormap(255) (required).
        masks: must be None; mask drawing is not implemented.
        captions: iterable of strings drawn at each box's top-left corner
            (required).

    Returns:
        ndarray of the annotated image.
    """
    # TODO(wkentaro): support masks; for now validate the required inputs.
    assert image is not None
    assert bboxes is not None
    assert labels is not None
    assert masks is None
    assert captions is not None

    # The module header only binds `plt` (import matplotlib.pyplot as plt),
    # so the bare name `matplotlib` was undefined here; import it locally.
    import matplotlib

    viz = PIL.Image.fromarray(image)
    draw = PIL.ImageDraw.ImageDraw(viz)

    # Use the DejaVu font bundled with matplotlib so no system font is needed.
    font_path = osp.join(
        osp.dirname(matplotlib.__file__),
        'mpl-data/fonts/ttf/DejaVuSans.ttf'
    )
    font = PIL.ImageFont.truetype(font_path)

    colormap = label_colormap(255)
    for bbox, label, caption in zip(bboxes, labels, captions):
        color = colormap[label]
        color = tuple((color * 255).astype(np.uint8).tolist())
        xmin, ymin, xmax, ymax = bbox
        draw.rectangle((xmin, ymin, xmax, ymax), outline=color)
        draw.text((xmin, ymin), caption, font=font)

    return np.asarray(viz)

def lblsave(filename, lbl):
    """Save an integer label map as a color-palette PNG (VOC style).

    Args:
        filename: output path. '.png' is appended unless the extension is
            already exactly '.png'.
        lbl: integer ndarray whose values must lie in [-1, 254]. Values are
            cast to uint8, so -1 is stored as palette index 255.

    Raises:
        ValueError: if any value falls outside [-1, 254].
    """
    if os.path.splitext(filename)[1] != '.png':
        filename += '.png'
    # Accepted label range is [-1, 254]; every value then fits in one
    # palette index after the uint8 cast.
    if not (lbl.min() >= -1 and lbl.max() < 255):
        raise ValueError(
            '[%s] Cannot save the pixel-wise class label as PNG. '
            'Please consider using the .npy format.' % filename
        )
    # fromarray infers mode 'L' for 2-D uint8 data, and putpalette switches
    # the image to 'P'. This avoids fromarray's ``mode=`` argument, which
    # recent Pillow releases deprecate.
    lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8))
    colormap = label_colormap(255)
    lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
    lbl_pil.save(filename)


def _list_images(img_root):
    """Return the image files directly inside img_root, sorted by path.

    Extensions are matched case-insensitively against jpg/jpeg/bmp/png.
    """
    image_exts = ('.jpg', '.jpeg', '.bmp', '.png')
    return sorted(
        path for path in glob.glob(os.path.join(img_root, '*'))
        if os.path.splitext(path)[1].lower() in image_exts
    )


def _load_image_array(data, img_dir):
    """Decode the image belonging to one labelme JSON into an ndarray.

    The embedded base64 ``imageData`` is used when present; otherwise the
    file at img_dir is read. The EXIF orientation is applied before
    conversion. NOTE(review): this assumes labelme applied the same EXIF
    rotation when the image was annotated (recent labelme versions do);
    confirm for your labelme version.
    """
    image_b64 = data.get('imageData')
    if image_b64:
        raw = base64.b64decode(image_b64)
    else:
        with open(img_dir, 'rb') as f:
            raw = f.read()
    image = PIL.Image.open(io.BytesIO(raw))
    return np.array(apply_exif_orientation(image))


def _label_names_by_value(label_name_to_value):
    """Return a list whose index v holds the name mapped to label value v.

    Values with no name (gaps in the mapping) get an empty string.
    """
    names = [''] * (max(label_name_to_value.values(), default=0) + 1)
    for name, value in label_name_to_value.items():
        names[value] = name
    return names


def json2label(json_root="dataset\\2021-08-30", img_root="dataset\\2021-08-30", save_root="dataset\\2021-08-30",
               label_name_to_value = {'_background_': 0, "object":1, "flaw":2},
               use_label=True, use_color=True, use_jpeg=True, use_visual=True,
               is_show_time=1):
    """Convert labelme JSON annotations into semantic-segmentation labels.

    Every image directly inside img_root that has a same-stem ``.json`` in
    json_root is processed. Images without a JSON are reported with
    "no_exists" and skipped.

    Args:
        json_root: folder with the labelme ``.json`` files.
        img_root: folder with the images (.jpg/.jpeg/.bmp/.png, any case).
        save_root: output root (sub-folders are listed below).
        label_name_to_value: class name -> integer id. Unknown names found
            in the JSON files get the next free id. The caller's dict is
            not modified.
        use_label: write SegmentationClassRaw/<stem>.png (uint8 class ids)
            and MOVE the source image and JSON into Annotations/.
        use_color: write SegmentationClassPNG/<stem>.png (palette PNG).
        use_jpeg: write JPEGImages/<stem>.jpeg.
        use_visual: write SegmentationClassVisualization/<stem>.png.
        is_show_time: number of leading samples to preview with plt.show().
    """
    # Copy: labels discovered below must not leak into the caller's dict or
    # into the shared default argument on later calls.
    label_name_to_value = dict(label_name_to_value)

    ####### default output folders ########
    save_img = os.path.join(save_root, "JPEGImages")
    save_ano = os.path.join(save_root, "Annotations")
    save_label = os.path.join(save_root, "SegmentationClassRaw")
    save_color = os.path.join(save_root, "SegmentationClassPNG")
    save_visual = os.path.join(save_root, "SegmentationClassVisualization")
    for enabled, folder in ((use_jpeg, save_img), (use_label, save_ano),
                            (use_color, save_color), (use_label, save_label),
                            (use_visual, save_visual)):
        if enabled:
            os.makedirs(folder, exist_ok=True)

    # An explicit extension filter replaces the old glob "*[jpg,...]". That
    # pattern was a one-character class matching any file whose last
    # character was one of those letters.
    img_dirs = _list_images(img_root)
    print("当前root下有图片个数: ", len(img_dirs))
    for j in trange(len(img_dirs)):
        img_dir = img_dirs[j]
        name = os.path.basename(img_dir)
        first = os.path.splitext(name)[0]
        json_dir = os.path.join(json_root, first + ".json")

        if not os.path.isfile(img_dir):
            print("no_exists: ", j, img_dir)
            continue
        if not os.path.isfile(json_dir):
            print("no_exists: ", j, json_dir)
            continue

        # Read the JSON as UTF-8 so non-ASCII label names decode the same on
        # every platform; the `with` block closes the file.
        with open(json_dir, encoding='utf-8') as f:
            data = json.load(f)
        img = _load_image_array(data, img_dir)

        # Register unseen label names with the next free id. max + 1 (rather
        # than len) cannot collide with user-supplied ids when the mapping
        # has gaps.
        for shape in data['shapes']:
            if shape['label'] not in label_name_to_value:
                label_name_to_value[shape['label']] = (
                    max(label_name_to_value.values(), default=-1) + 1)

        if sorted(label_name_to_value.values()) != list(range(len(label_name_to_value))):
            print("assert label_values")

        # lbl holds one class id per pixel (0 = background).
        lbl = shapes_to_label(img.shape, data['shapes'], label_name_to_value)

        # draw_label expects names indexed by label value. The old captions
        # list ("1: fish", ...) gave legends like "1: 1: fish" and mismatched
        # names when the mapping was not ordered by value.
        lbl_viz = draw_label(lbl, img, _label_names_by_value(label_name_to_value))

        if is_show_time > 0:
            is_show_time -= 1
            plt.figure(figsize=(20, 20))
            plt.subplot(121)
            plt.imshow(lbl)
            plt.subplot(122)
            plt.imshow(lbl_viz)
            plt.show()

        # Copy of the image in a uniform format.
        if use_jpeg:
            img_pil = PIL.Image.fromarray(img)
            if img_pil.mode not in ('RGB', 'L'):
                # JPEG cannot store alpha (e.g. RGBA PNG inputs).
                img_pil = img_pil.convert('RGB')
            img_pil.save(os.path.join(save_img, first + ".jpeg"))

        # Raw class-id label image.
        if use_label:
            PIL.Image.fromarray(lbl.astype(np.uint8)).save(
                os.path.join(save_label, first + ".png"))

        # Palette-colored label image.
        if use_color:
            lblsave(os.path.join(save_color, first + ".png"), lbl)

        # Label overlay visualization.
        if use_visual:
            PIL.Image.fromarray(lbl_viz).save(os.path.join(save_visual, first + ".png"))

        # Move (not copy) the source image and JSON into Annotations.
        if use_label:
            if os.path.exists(img_dir):
                shutil.move(img_dir, os.path.join(save_ano, name))
            if os.path.exists(json_dir):
                shutil.move(json_dir, os.path.join(save_ano, first + ".json"))

if __name__=="__main__":
    # Example run: convert the labelme annotations of a fish dataset.
    # Edit the paths and the label mapping below for your own data.
    '''
    说明:
    函数 shapes_to_label  可以修改线宽,支持 矩形,点
    函数 shape_to_mask    标签扩充放大, 使用膨胀'''
    # English version of the note above: the stroke width / point radius
    # used for rasterizing can be changed where shapes_to_label calls
    # shape_to_mask (it passes line_width=10, point_size=5); rectangle and
    # point shapes are supported.
    # NOTE(review): the note also mentions enlarging labels with dilation,
    # but no dilation is performed anywhere in this file.
    json2label(json_root="data_fishs/Annotations", 
               img_root ="data_fishs/JPEGImages",
               save_root="data_fishs/outputs",
               label_name_to_value = {'_background_': 0, "fish":1, "head":2},
               use_label=True,   # SegmentationClassRaw; moves sources into outputs/Annotations
               use_color=True,   # SegmentationClassPNG (palette PNG)
               use_jpeg=True,    # JPEGImages copies
               use_visual=True,  # SegmentationClassVisualization overlays
               is_show_time=1)

图像语义分割实践

你可能感兴趣的:(图像语义分割实践(一)数据制作与转换)