做东西,最重要的就是动手了,所以这篇文章动手跑了一个fcn32s和fcn8s以及deeplab v3+的例子,这个例子的数据集选用自动驾驶相关竞赛的kitti数据集, FCN8s在训练过程中用tensorflow2.0自带的评估能达到91%精确率, deeplab v3+能达到97%的准确率。这篇文章适合入门级选手,在文章中不再讲述fcn的结构,直接百度就可以搜到。
文章使用的是tensorflow2.0框架,该框架集成了keras,在模型的训练方面极其简洁,不像tf1.x那么复杂,综合其他深度学习框架,发现这个是最适合新手使用的一种。
文章中用到的库函数,参数等均可在tensorflow2.0 api中查找到。
文章的代码在github可以获取,地址:https://github.com/fengshilin/tf2.0-FCN
文章的结构如下:
tensorflow模型的输入维度为[batch, h, w, c]分别表示批量,图片的长,宽,通道数。之所以要加上batch(一批同时训练多少个样本),是因为python运算中,把循环变为矩阵的运算,速度会快很多。具体可以参看吴恩达的这一节课。
在语义分割中,数据集的训练集包括影像与label,一般影像是三通道的影像,label形式比较多,有的是三通道影像,有的是单通道的预测值。
模型输入的image需要转成tensor,格式为float32;label需要转为0,1,2的格式,表示该像素值属于第几类。
由于模型学习后输出的结果是各个类别的概率,比如分三类,则输出的结果为[0.1,0.2,0.7],表示属于第一类的概率为0.1,第二类的概率为0.2, 第三类的概率为0.7,所以label需要将像素值转换为类别值0,1,2,使得模型在训练时将输出向0,1,2靠拢。
注:上述label的预处理与选择的loss函数有关,以下举了两个例子。
在kitti的数据集中,主要分两类,一类是背景,一类是道路,且我们选择了softmax_cross_entropy_with_logits作为损失函数,所以需要将label的像素值转为[0, 1]或者[1, 0],若当前像素为背景,则像素值取[1, 0],若当前像素值为道路,则取值[0,1]。
先生成图片和label的列表。
import tensorflow as tf
import cv2
import os
import scipy
import numpy as np
train_dir = os.path.join( "data", "train", "img")+"/" # os.path.join是做连接符,等同于data/train/img
train_label_dir = os.path.join( "data", "train", "label")+"/"
train_list_dir = os.listdir(train_dir) # 列出train-image目录下的图片名字, 形如["um_000001.png", ...]
train_list_dir.sort() # sort是排序,使得train_list与train_label_list一一对应
train_label_list_dir = os.listdir(train_label_dir) # 列出train-label目录下的图片名字,形如["um_road_000001.png", ...]
train_label_list_dir.sort()
assert len(train_list_dir)==len(train_label_list_dir), "训练的图片与标签数量不一致"
train_filenames = [train_dir + filename for filename in train_list_dir] # 生成图片路径,形如["data/train/img/um_000001.png"]
train_label_filenames = [train_label_dir +
filename for filename in train_label_list_dir] # 生成label路径,形如["data/train/img/um_road_000001.png"]
生成迭代器,以便生成tf中dataset类型的数据,直接作为模型的输入做训练,所以我们只要生成一个dataset就可以了,生成dataset有很多方法,这里我们用generator。
对数据做预处理,需要注意:
1.数据集的图片大小都不一致,需要resize到一致的大小(160,576), 最好长宽都是32的倍数,因为fcn会把长宽缩小到原图的1/32,然后再做上采样。
2.对影像做预处理是可选的,不一定按照代码中的来做。
3.对label的处理,先把等于背景色的点设置为True,再与False做concatenate, 得到[True, False]这样的一组结果。也就是说,如果某个点的像素值为背景色,则我们将该点的值设置为[True, False], 若不是背景色,则设置为[False, True],如果感兴趣可以自己打印出来看一下每一步的值。
def train_generator():
"""训练集生成器"""
# 对之前生成的路径列表用zip打包,生成形如[("data/train/img/um_000001.png", "data/train/img/um_road_000001.png"), ....]的列表,然后遍历列表,取出一一对应的图片与label。
for train_file_name, train_label_filename in zip(train_filenames, train_label_filenames):
image, label = handle_data(train_file_name, train_label_filename)
# 这里第一次取的image, label = "data/train/img/um_000001.png", "data/train/img/um_road_000001.png", 第二次往后迭代。用yield返回,这是python的generator的用法
yield image, label
def test_generator():
"""测试集生成器"""
for test_filename in test_filenames:
image = handle_data(test_filename)
yield image
def handle_data(train_filenames, train_label_filenames=None):
"""对数据做处理"""
image = scipy.misc.imresize(
scipy.misc.imread(train_filenames), image_shape) # 因为数据的size都不一样,所以需要统一resize到我们约定的size(160,576)
# 对影像的处理,去除阴影,这一步可以不做,只是效果会差一些。
image_yuv = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)
image_yuv[:, :, 0] = cv2.equalizeHist(image_yuv[:, :, 0])
image = cv2.cvtColor(image_yuv, cv2.COLOR_YUV2RGB)
对label做处理
if train_label_filenames is not None:
gt_image = scipy.misc.imresize(
scipy.misc.imread(train_label_filenames), image_shape)
background_color = np.array([255, 0, 0])
gt_bg = np.all(gt_image == background_color, axis=2)
gt_bg = gt_bg.reshape(*gt_bg.shape, 1)
gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2)
return np.array(image), gt_image
else:
return np.array(image)
生成dataset, 这里用到tf.data.Dataset库的方法from_generator, 除了from_generator还有很多方法可以生成dataset。生成dataset后可以直接将该dataset放入模型训练。
在生成dataset后可以对dataset做映射,有map、shauffle、batch等。
from_generator的参数为
train_dataset = tf.data.Dataset.from_generator(
train_generator, (tf.float32, tf.float32), (tf.TensorShape([None, None, None]), tf.TensorShape([None, None, None])))
train_dataset = train_dataset.shuffle(buffer_size=len(train_filenames)) # 打乱数据的顺序,即不按照之前sort的顺序读取
train_dataset = train_dataset.batch(batch_size) # 设置批量,同时训练的数据量。tensorflow模型的输入维度为[batcn, h, w, c]
至此,数据的预处理已经完成,后续涉及模型的引用与微调,以及模型的保存与读取。