# 百度网盘提取码: lala (Baidu Netdisk extraction code for the dataset)
import tensorflow as tf
import random
import numpy as np
import os
import glob
# 环境变量的配置
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Image loading
def load_image(path, label):
    """Read the JPEG at *path*, decode to RGB, resize to 256x256 and scale to [0, 1].

    Returns the (image, label) pair unchanged in label, suitable for Dataset.map.
    """
    raw = tf.io.read_file(path)
    img = tf.image.decode_jpeg(raw, channels=3)
    img = tf.image.resize(img, [256, 256])
    # Cast before dividing so the result is float32 in [0, 1].
    img = tf.cast(img, tf.float32) / 255
    return img, label
# Build the label dictionaries
def make_data_dict():
    """Build label<->index mappings from dataset/birds_train/<id>.<name>/*.jpg.

    Returns:
        (label_to_index, index_to_label): dicts mapping class name -> int index
        and back. Indices follow the sorted order of the unique class names.
    """
    # Original code split on '\\' with a hard-coded index, which only works on
    # Windows; os.path makes the label extraction separator-agnostic.
    pattern = os.path.join('dataset', 'birds_train', '*', '*.jpg')
    all_image_path = glob.glob(pattern)
    # Class directories look like "001.BirdName"; the label is the part after the dot.
    all_label_name = [os.path.basename(os.path.dirname(p)).split('.')[1]
                      for p in all_image_path]
    label_name = np.unique(all_label_name)  # sorted unique class names
    label_to_index = {name: i for i, name in enumerate(label_name)}
    index_to_label = {i: name for name, i in label_to_index.items()}
    return label_to_index, index_to_label
# Build the datasets
def make_data():
    """Build shuffled train/test tf.data pipelines (80/20 split, batch size 8).

    Returns:
        (train_dataset, test_dataset, train_count, test_count). The train
        dataset repeats indefinitely; drive it with steps_per_epoch in fit().
    """
    # Portable glob/label extraction (the original '\\'-split was Windows-only).
    pattern = os.path.join('dataset', 'birds_train', '*', '*.jpg')
    all_image_path = glob.glob(pattern)
    random.shuffle(all_image_path)
    all_label_name = [os.path.basename(os.path.dirname(p)).split('.')[1]
                      for p in all_image_path]
    label_to_index, index_to_label = make_data_dict()
    all_image_label = [label_to_index.get(name) for name in all_label_name]
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    dataset = tf.data.Dataset.from_tensor_slices((all_image_path, all_image_label))
    # AUTOTUNE was previously assigned but unused; wire it into the pipeline
    # so decoding runs in parallel and batches are prefetched.
    dataset = dataset.map(load_image, num_parallel_calls=AUTOTUNE)
    data_count = len(all_image_path)
    test_count = int(data_count * 0.2)  # 20% held out for validation
    train_count = data_count - test_count
    train_dataset = dataset.skip(test_count)
    test_dataset = dataset.take(test_count)
    train_dataset = (train_dataset.repeat()
                     .shuffle(300)
                     .batch(batch_size=8)
                     .prefetch(AUTOTUNE))
    test_dataset = test_dataset.batch(batch_size=8).prefetch(AUTOTUNE)
    return train_dataset, test_dataset, train_count, test_count
if __name__ == '__main__':
    # Smoke test: build the pipelines and show the training dataset's spec.
    train_data, test_data, train_num, test_num = make_data()
    print(train_data)
import tensorflow as tf
import os
# Environment variable configuration (must be set before TF initializes devices).
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'  # expose XLA devices to TensorFlow
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # grow GPU memory on demand instead of grabbing it all
# Model construction
def make_model():
    """Build a VGG-style CNN for 256x256x3 inputs with 200 output classes.

    Architecture: four Conv-BN pairs-plus-pool stages (64/128/256/512 filters),
    three more 512-filter Conv-BN layers, global average pooling, a 1024-unit
    dense layer with BN, and a final 200-unit logits layer (no softmax).
    Prints the summary and returns the model.
    """
    layers = tf.keras.layers

    def conv_bn(filters, **kwargs):
        # One 3x3 ReLU convolution followed by batch normalization.
        return [layers.Conv2D(filters, (3, 3), activation='relu', **kwargs),
                layers.BatchNormalization()]

    stack = conv_bn(64, input_shape=(256, 256, 3)) + conv_bn(64)
    stack.append(layers.MaxPooling2D())
    for filters in (128, 256, 512):
        stack += conv_bn(filters) + conv_bn(filters)
        stack.append(layers.MaxPooling2D())
    # Final convolutional trio, no pooling before global average pooling.
    stack += conv_bn(512) + conv_bn(512) + conv_bn(512)
    stack += [
        layers.GlobalAveragePooling2D(),
        layers.Dense(1024, activation='relu'),
        layers.BatchNormalization(),
        layers.Dense(200),  # raw logits; pair with from_logits=True loss
    ]
    model = tf.keras.Sequential(stack)
    model.summary()
    return model
if __name__ == '__main__':
    # Build the model once so its summary is printed.
    make_model()
import tensorflow as tf
import os
from data_loader import make_data_dict, make_data
from model import make_model
# Environment variable configuration (must be set before TF initializes devices).
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'  # expose XLA devices to TensorFlow
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # grow GPU memory on demand instead of grabbing it all
# Load the data
train_dataset, test_dataset, train_count, test_count = make_data()
# Build the model
model = make_model()
# Compile: the last Dense layer emits raw logits, hence from_logits=True.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['acc'],
)
# Train the model
BATCH_SIZE = 8  # must match the batch size used inside data_loader.make_data
steps_per_epoch = train_count // BATCH_SIZE
validation_steps = test_count // BATCH_SIZE
# NOTE(review): the original passed workers=6, but Model.fit ignores
# workers for tf.data input (it only applies to Sequence/generator input),
# so it is dropped here.
history = model.fit(
    train_dataset,
    epochs=100,
    steps_per_epoch=steps_per_epoch,
    validation_data=test_dataset,
    validation_steps=validation_steps,
)
# Save the model; create the target directory first so save() cannot fail
# with a missing-path error.
os.makedirs('model_data', exist_ok=True)
model.save(r'model_data/birds_model.h5')