【Keras】基于SegNet和U-Net的遥感图像语义分割(二)

训练U_net模型:unet_train.py

代码:
https://github.com/fuyou123/Segmentation_Unet
【Keras】基于SegNet和U-Net的遥感图像语义分割(二)_第1张图片

U-Net网络训练

1. args = args_parse()

def args_parse():
    # construct the argument parse and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-d", "--data", default="./unet_train/road/",
                    help="training data's path")
    ap.add_argument("-m", "--model", default="Trained_Unet_Model.h5",
                    help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args()) 
    return args
注意:对于命令行参数的设定,参考Argparse的使用。由于我的调试过程是在windows环境下,所以路径设为了默认值,便于操作。

2. model = unet() #模型初始化

调试可能出现的问题:ValueError: Negative dimension size caused by subtracting 2 from 1 for 'max_pooling2d_2/Max...
解决:https://blog.csdn.net/weixin_43723625/article/details/104918997
def unet():
    inputs = Input((3, img_w, img_h))

    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)
    #conv10 = Conv2D(n_label, (1, 1), activation="softmax")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

【Keras】基于SegNet和U-Net的遥感图像语义分割(二)_第2张图片

3. get_train_val()

def get_train_val(val_rate = 0.25):
    train_url = []    
    train_set = []    # 训练数据集
    val_set  = []     # 测试训练集
    # 将待训练的图片路径存入列表
    for pic in os.listdir(filepath + 'src'):
        train_url.append(pic)
    random.shuffle(train_url)   # 将序列的所有元素随机排序
    total_num = len(train_url)
    # 将数据集分为2部分,3/4用于训练,1/4用于检验
    val_num = int(val_rate * total_num)
    for i in range(len(train_url)):
        if i < val_num:
            val_set.append(train_url[i]) 
        else:
            train_set.append(train_url[i])
    return train_set,val_set

【Keras】基于SegNet和U-Net的遥感图像语义分割(二)_第3张图片
由于在训练过程中,为了检验训练中的效果,一般是从原数据集中取出一小部分数据用于检验训练的效果。而不是等全部训练结束后,再检验训练模型。

4. model.fit_generator()

# 使用 Python 生成器逐批生成的数据,按批次训练模型
 H = model.fit_generator(generator=generateData(BS,train_set),steps_per_epoch=train_numb//BS,epochs=EPOCHS,verbose=1,  
                    validation_data=generateValidData(BS,val_set),validation_steps=valid_numb//BS,callbacks=callable,max_q_size=1)  

参数:
(1) generator=generateData(BS,train_set) # BS=16
【Keras】基于SegNet和U-Net的遥感图像语义分割(二)_第4张图片
定义训练时,选取数据的方式是,将数据集中每16个为一组进行训练
(2) steps_per_epoch=train_numb//BS # 整数,表示一次epoch需要训练的组数
(3) epochs=EPOCHS # 迭代次数

(4) validation_data=generateValidData(BS,val_set) # BS=16 , 定义验证集的生成器
(5) validation_steps=valid_numb//BS # 整数,表示一次epoch需要训练的组数

结果:
(调试过程,可以将数据集减少些,同时减少迭代的次数)
在这里插入图片描述
【Keras】基于SegNet和U-Net的遥感图像语义分割(二)_第5张图片

你可能感兴趣的:(图像语义分割)