We use a CNN (convolutional neural network) to classify satellite imagery. The dataset has two classes: lake and airplane.
Import the packages
import tensorflow as tf
import numpy as np
import pathlib
import matplotlib.pyplot as plt
Check the TensorFlow version
print('Tensorflow version: {}'.format(tf.__version__))
Tensorflow version: 2.3
Get the data directory path
data_dir = './2_class'
Build a path object with pathlib
data_root = pathlib.Path(data_dir)
Iterate over the directory to view its entries
for item in data_root.iterdir():
    print(item)
/2_class/airplane
/2_class/lake
Use the glob method with a wildcard pattern to collect every file inside the class subdirectories
all_image_path = list(data_root.glob('*/*'))
Number of files collected
len(all_image_path)
1400
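As an optional sanity check (a small sketch, not part of the original walkthrough), we can count how many images each class contributes by grouping on the parent directory name:
import collections
collections.Counter(p.parent.name for p in all_image_path)
# e.g. Counter({'airplane': 700, 'lake': 700}) if the two classes are balanced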
Slice to view the first 2 file paths
all_image_path[:2]
[PosixPath('/2_class/airplane/airplane_446.jpg'),
PosixPath('/2_class/airplane/airplane_550.jpg')]
Slice to view the last 3 file paths
all_image_path[-3:]
[PosixPath('/2_class/lake/lake_360.jpg'),
PosixPath('/2_class/lake/lake_375.jpg'),
PosixPath('/2_class/lake/lake_092.jpg')]
Convert the path objects to strings
all_image_path = [str(path) for path in all_image_path]
Show a few of the converted paths
all_image_path[10:12]
['/2_class/airplane/airplane_344.jpg',
'/2_class/airplane/airplane_599.jpg']
Each element has changed from a PosixPath object such as PosixPath('/2_class/airplane/airplane_599.jpg') to a plain string such as '/2_class/airplane/airplane_599.jpg'
Shuffle the image paths
import random
random.shuffle(all_image_path)
Store the image count
image_count = len(all_image_path)
image_count
Extract the class names (target values)
label_names = sorted(item.name for item in data_root.glob('*/'))
label_names
['airplane', 'lake']
Build a dictionary that maps the class names 'airplane' and 'lake' to the integer targets 0 and 1 used in training
label_to_index = dict((name,index) for index,name in enumerate(label_names))
label_to_index
{'airplane': 0, 'lake': 1}
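The class name of each image is simply the name of its parent directory, for example: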
pathlib.Path('/2_class/airplane/airplane_240.jpg').parent.name
'airplane'
Convert every image path to its integer label
all_image_labels = [label_to_index[pathlib.Path(p).parent.name] for p in all_image_path]
all_image_labels[:5]
[0, 0, 0, 1, 1]
Build the reverse mapping from the labels 0 and 1 back to 'airplane' and 'lake'
index_to_label = dict((v, k) for k, v in label_to_index.items())
index_to_label
{0: 'airplane', 1: 'lake'}
Display some images together with their labels
import IPython.display as display
for n in range(3):  # randomly display 3 images
    image_index = random.choice(range(len(all_image_path)))
    display.display(display.Image(all_image_path[image_index]))
    print(index_to_label[all_image_labels[image_index]])
    print()
Take the first image path from the list
img_path = all_image_path[0]
img_path
'dataset/2_class/lake/lake_700.jpg'
Read this first image with tf.io
img_raw = tf.io.read_file(img_path)
img_raw
Decode the image
img_tensor = tf.image.decode_image(img_raw)
img_tensor.shape
TensorShape([256, 256, 3])
Image data type
img_tensor.dtype
tf.uint8
Define a single function that chains the previous steps: read the file, decode it, resize it, cast the type, and finally normalize
def load_preprocess_image(img_path):
    img_raw = tf.io.read_file(img_path)
    img_tensor = tf.image.decode_jpeg(img_raw, channels=3)
    img_tensor = tf.image.resize(img_tensor, (256, 256))
    img_tensor = tf.cast(img_tensor, tf.float32)
    img = img_tensor / 255
    return img
Test it on the image at index 500
image_path = all_image_path[500]
plt.imshow(load_preprocess_image(image_path))
Build the image dataset
path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
image_dataset = path_ds.map(load_preprocess_image)
Build the label dataset; the labels do not need the custom map function
label_dataset = tf.data.Dataset.from_tensor_slices(all_image_labels)
Inspect the two datasets
image_dataset
label_dataset
Zip image_dataset and label_dataset together (this step can also be skipped; see the sketch below)
dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
dataset
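As noted above, the zip step can also be skipped. A minimal sketch of the alternative (the wrapper name load_preprocess_image_with_label is just for illustration): build one dataset from the path and label lists together, then map a function that preprocesses the image while passing the label through.
def load_preprocess_image_with_label(img_path, label):
    return load_preprocess_image(img_path), label

dataset_alt = tf.data.Dataset.from_tensor_slices((all_image_path, all_image_labels))
dataset_alt = dataset_alt.map(load_preprocess_image_with_label)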
Decide how many samples go to the training and test sets
test_count = int(image_count*0.2)
train_count = image_count - test_count
test_count, train_count
(280, 1120)
Create the training and test sets
train_dataset = dataset.skip(test_count)
test_dataset = dataset.take(test_count)
Set the batch size
BATCH_SIZE = 32
Configure the training input pipeline
train_dataset = train_dataset.repeat().shuffle(buffer_size=train_count).batch(BATCH_SIZE)
Configure the test input pipeline
test_dataset = test_dataset.batch(BATCH_SIZE)
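Optionally (not part of the original pipeline; a minimal sketch), prefetching lets the input pipeline prepare the next batch while the model trains on the current one:
AUTOTUNE = tf.data.experimental.AUTOTUNE  # tf.data.AUTOTUNE in newer TF versions
train_dataset = train_dataset.prefetch(AUTOTUNE)
test_dataset = test_dataset.prefetch(AUTOTUNE)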
Inspect the training and test datasets
test_dataset
train_dataset
Build the model
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(64, (3, 3),input_shape=(256, 256, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Conv2D(128, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Conv2D(256, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(512, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPooling2D())
model.add(tf.keras.layers.Conv2D(1024, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(1024, activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
Model summary
model.summary()
Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])
Set the steps per epoch and the validation steps
steps_per_epoch = train_count // BATCH_SIZE
validation_step = test_count // BATCH_SIZE
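With train_count = 1120 and test_count = 280, this gives steps_per_epoch = 35 and validation_step = 8.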
Train the model
history = model.fit(train_dataset, epochs=30, steps_per_epoch=steps_per_epoch, validation_data=test_dataset, validation_steps=validation_step)
Plot the training results
plt.plot(history.epoch, history.history.get('loss'), label='loss')
plt.plot(history.epoch, history.history.get('val_loss'), label='val_loss')
plt.legend()
plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
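After training, we can evaluate on the test set and map a single prediction back to a class name with index_to_label; a minimal sketch (the 0.5 threshold matches the single sigmoid output, where class 1 is 'lake'):
test_loss, test_acc = model.evaluate(test_dataset)
img = load_preprocess_image(all_image_path[0])
pred = model.predict(tf.expand_dims(img, axis=0))  # shape (1, 1): probability of class 1
index_to_label[int(pred[0][0] > 0.5)]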
Reference:
https://study.163.com/course/introduction/1004573006.htm