unet示例程序代码分析

本代码由ipynb文件转换为py文件,所以有的地方有改动

  1. 首先导入各种包
from __future__ import division, print_function
# 导入 print_function 是为了兼容 python2:这样即使在 python2 下,print 也必须按 python3 的函数形式书写
# division 表示精确除法(真除法),如 3/4=0.75;这两项在 python3 中都是默认行为
# get_ipython().magic('matplotlib inline')
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from tf_unet import image_gen
from tf_unet import unet
from tf_unet import util
  2. 画图的设置以及设置随机数种子
# Make 'gist_earth' the default colormap matplotlib uses when showing images.
plt.rcParams['image.cmap'] = 'gist_earth'
# Seed NumPy's global random generator so the randomly generated data is reproducible.
np.random.seed(98765)
  3. 设置图片尺寸以及建立生成随机数据集的类的一个实例
# Image size arguments handed to the data provider.
nx = 572
ny = 572
# Extra keyword arguments such as cnt=20 are stored by the provider and
# forwarded to create_image_and_label (per the class listing below), where
# cnt is the number of circles drawn into each image.
generator = image_gen.GrayScaleDataProvider(nx, ny, cnt=20)
# Calling the provider with n=1 returns one batch holding a single image
# and the matching label array.
x_test, y_test = generator(1)

展开
3.1 查看GrayScaleDataProvider这个类

class GrayScaleDataProvider(BaseDataProvider):
    """
    Data provider that produces synthetic single-channel images.

    Every keyword argument is stored untouched and forwarded to
    `create_image_and_label` each time a new sample is requested.

    :param nx: first size argument passed to `create_image_and_label`
    :param ny: second size argument passed to `create_image_and_label`
    :param kwargs: extra options for `create_image_and_label`; a truthy
        `rectangles` entry switches the provider to three classes
    """

    channels = 1
    n_class = 2

    def __init__(self, nx, ny, **kwargs):
        super(GrayScaleDataProvider, self).__init__()
        self.nx, self.ny = nx, ny
        self.kwargs = kwargs
        # Rectangles add a third label class, so override the class-level
        # default of two on this instance only.
        if kwargs.get("rectangles", False):
            self.n_class = 3

    def _next_data(self):
        """Generate one fresh (image, label) pair."""
        sample = create_image_and_label(self.nx, self.ny, **self.kwargs)
        return sample

3.1.1 查看父级类BaseDataProvider

class BaseDataProvider(object):
    """
    Abstract base class for DataProvider implementations. Subclasses have to
    override the `_next_data` method, which loads the next data and label array.
    This implementation clips the absolute value of the data with the given
    min/max and rescales the result to [0, 1]. To change this behavior the
    `_process_data` method can be overridden. To enable some post processing
    such as data augmentation the `_post_process` method can be overridden.

    :param a_min: (optional) min value used for clipping
    :param a_max: (optional) max value used for clipping

    """

    channels = 1
    n_class = 2

    def __init__(self, a_min=None, a_max=None):
        # Each bound falls back to its own "no clipping" default. The upper
        # bound must be decided by `a_max` itself, otherwise passing only
        # `a_max` would silently discard it.
        self.a_min = a_min if a_min is not None else -np.inf
        self.a_max = a_max if a_max is not None else np.inf

    def _load_data_and_label(self):
        """
        Fetch one sample via `_next_data`, process it and add a batch axis.

        :returns: tuple `(data, labels)` reshaped to
            `(1, ny, nx, channels)` and `(1, ny, nx, n_class)`
        """
        data, label = self._next_data()

        train_data = self._process_data(data)
        labels = self._process_labels(label)

        train_data, labels = self._post_process(train_data, labels)

        # The first two axes of the processed image give the spatial size.
        ny, nx = train_data.shape[0], train_data.shape[1]

        return (train_data.reshape(1, ny, nx, self.channels),
                labels.reshape(1, ny, nx, self.n_class))

    def _process_labels(self, label):
        """
        Turn a 2-D mask into a two-channel label array when `n_class == 2`.

        Channel 1 holds `label`, channel 0 holds `~label`. For any other
        `n_class` the label is returned unchanged.

        :param label: the label array (assumed to be a boolean mask in the
            two-class case, since `~` is applied to it)
        """
        if self.n_class == 2:
            ny, nx = label.shape[0], label.shape[1]
            labels = np.zeros((ny, nx, self.n_class), dtype=np.float32)
            labels[..., 1] = label
            labels[..., 0] = ~label
            return labels

        return label

    def _process_data(self, data):
        """
        Normalize the data: clip its absolute value to `[a_min, a_max]`,
        shift the minimum to 0 and, unless the maximum is 0, divide by it.

        :param data: the raw data array
        """
        data = np.clip(np.fabs(data), self.a_min, self.a_max)
        data -= np.amin(data)

        # Skip the division for an all-zero array to avoid 0/0.
        peak = np.amax(data)
        if peak != 0:
            data /= peak

        return data

    def _post_process(self, data, labels):
        """
        Post processing hook that can be used for data augmentation

        :param data: the data array
        :param labels: the label array
        """
        return data, labels

    def __call__(self, n):
        """
        Build a batch of `n` samples.

        :param n: number of samples to load
        :returns: tuple `(X, Y)` where `X[i]` is one image and `Y[i]` its
            label array
        """
        # The first sample is loaded up front because its shape determines
        # the size of the batch arrays. Both results already carry a leading
        # batch axis of length 1.
        train_data, labels = self._load_data_and_label()
        rows = train_data.shape[1]
        cols = train_data.shape[2]

        X = np.zeros((n, rows, cols, self.channels))
        Y = np.zeros((n, rows, cols, self.n_class))

        X[0] = train_data
        Y[0] = labels
        for i in range(1, n):
            # Load the next image and its labels.
            train_data, labels = self._load_data_and_label()
            X[i] = train_data
            Y[i] = labels

        # X and Y now hold all n images and their corresponding labels.
        return X, Y

其中 `__call__()` 方法使得实例对象可以像函数一样被调用
聚焦该方法,第一行调用了 `_load_data_and_label()`

    def _load_data_and_label(self):
        # Load one sample, process it and add a leading batch axis of length 1.
        data, label = self._next_data() # generate a single image and its label
            
        train_data = self._process_data(data) # normalize the grey value of every pixel of the image
        labels = self._process_labels(label) # two-class case: put the boolean label into channel 1 of labels and ~label into channel 0
        
        train_data, labels = self._post_process(train_data, labels) # hook: a no-op in this class (returns its inputs unchanged); the original author notes it is overridden in other modules
        
        nx = train_data.shape[1]
        ny = train_data.shape[0]

        return train_data.reshape(1, ny, nx, self.channels), labels.reshape(1, ny, nx, self.n_class),
# Output: the single image (train_data) and its labels, each with one extra leading dimension

这里又调用了_next_data()方法,但是该方法不在父级类BaseDataProvider中,
在子类GrayScaleDataProvider中有该方法

    def _next_data(self):
        # Delegate to create_image_and_label, forwarding the stored size
        # arguments and every keyword argument given to the constructor.
        return create_image_and_label(self.nx, self.ny, **self.kwargs)

查看create_image_and_label()函数
该函数可以生成单张图像和标签

def create_image_and_label(nx, ny, cnt=10, r_min=5, r_max=50, border=92, sigma=20, rectangles=False):
    """
    Generate one synthetic noisy image together with its label mask.

    Filled circles (and optionally squares) with random position, size and
    grey value are drawn onto a constant background, Gaussian noise is
    added and the image is rescaled to the range [0, 1].

    :param nx: size of the first image axis
    :param ny: size of the second image axis
    :param cnt: number of circles to draw; `cnt // 2` squares are drawn
        when `rectangles` is set
    :param r_min: smallest radius / side length (inclusive)
    :param r_max: largest radius / side length (exclusive)
    :param border: minimum distance of a circle centre from the image edge
    :param sigma: standard deviation of the added Gaussian noise
    :param rectangles: also draw squares and return a three-channel label
    :returns: tuple `(image, label)`; `image` has shape `(nx, ny, 1)`.
        Without rectangles `label` is the `(nx, ny)` boolean circle mask,
        otherwise an `(nx, ny, 3)` boolean array with the channels
        background / circles / rectangles.
    """
    # The builtin `bool` is used as dtype: the `np.bool` alias was removed
    # from NumPy and raises AttributeError on current releases.
    image = np.ones((nx, ny, 1))
    label = np.zeros((nx, ny, 3), dtype=bool)
    mask = np.zeros((nx, ny), dtype=bool)
    for _ in range(cnt):
        # The order of the random draws is kept so seeded runs stay reproducible.
        a = np.random.randint(border, nx - border)
        b = np.random.randint(border, ny - border)
        r = np.random.randint(r_min, r_max)
        h = np.random.randint(1, 255)

        # Open grids centred on (a, b); m marks the pixels inside the circle.
        y, x = np.ogrid[-a:nx - a, -b:ny - b]
        m = x * x + y * y <= r * r
        mask = np.logical_or(mask, m)

        image[m] = h

    label[mask, 1] = 1

    if rectangles:
        mask = np.zeros((nx, ny), dtype=bool)
        for _ in range(cnt // 2):
            a = np.random.randint(nx)
            b = np.random.randint(ny)
            r = np.random.randint(r_min, r_max)
            h = np.random.randint(1, 255)

            # Squares starting near the edge are clipped by the slice.
            m = np.zeros((nx, ny), dtype=bool)
            m[a:a + r, b:b + r] = True
            mask = np.logical_or(mask, m)
            image[m] = h

        label[mask, 2] = 1

        # Background is everything that is neither circle nor rectangle.
        label[..., 0] = ~(np.logical_or(label[..., 1], label[..., 2]))

    image += np.random.normal(scale=sigma, size=image.shape)
    image -= np.amin(image)
    # A perfectly flat image has a maximum of 0 after the shift; skip the
    # division in that case instead of producing NaNs.
    peak = np.amax(image)
    if peak != 0:
        image /= peak

    if rectangles:
        return image, label
    else:
        return image, label[..., 1]
  • 所以,_next_data()方法的作用就是生成单张图片和其对应的标签

你可能感兴趣的:(unet示例程序代码分析)