# Import libraries
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Dropout
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shutil
import tqdm
# Import the dataset.
# NOTE(review): '/G:' is not a usable directory on Windows (':' may not appear
# inside a path component), although the archive path below is a Windows path.
# It was probably meant to be a drive root such as 'G:/'. Left unchanged here;
# confirm the intended location.
CONTENT_DIR = '/G:'
TRAIN_DIR = CONTENT_DIR + '/train'
VALID_DIR = CONTENT_DIR + '/valid'

# Extract only when the training folder is missing. The archive is expected to
# unpack a 'train' folder into CONTENT_DIR, because TRAIN_DIR is listed right
# after. Testing CONTENT_DIR itself would skip extraction whenever that folder
# merely exists, and the listdir below would then fail.
if not os.path.exists(TRAIN_DIR):
    import zipfile
    with zipfile.ZipFile('E:/6.10dogcat/train.zip', 'r') as zipf:
        zipf.extractall(CONTENT_DIR)

# Only loose files directly inside TRAIN_DIR are candidates. After a first run,
# TRAIN_DIR also holds the 'dog'/'cat' class folders created below. Their names
# also start with 'dog'/'cat', and treating them as images made re-runs fail.
img_filenames = [
    fn for fn in os.listdir(TRAIN_DIR)
    if os.path.isfile(os.path.join(TRAIN_DIR, fn))
]
dog_filenames = [fn for fn in img_filenames if fn.startswith('dog')]
cat_filenames = [fn for fn in img_filenames if fn.startswith('cat')]


def _split_class(filenames):
    """Split one class's file names into (train, valid) lists, 90% / 10%.

    Each class is split on its own, so the dog and cat counts need not match.
    A single train_test_split call over both lists requires equal lengths.
    With the fixed random_state, equal-length lists get the same permutation a
    joint call would use.

    An empty list (nothing left to move) yields two empty lists instead of a
    train_test_split error.
    """
    if not filenames:
        return [], []
    return train_test_split(filenames, test_size=0.1, shuffle=True, random_state=42)


# Order is [dog train, dog valid, cat train, cat valid], matching make_dirs.
dataset_filenames = [*_split_class(dog_filenames), *_split_class(cat_filenames)]
make_dirs = [d + a for a in ['/dog', '/cat'] for d in [TRAIN_DIR, VALID_DIR]]
for target_dir, fns in zip(make_dirs, dataset_filenames):
    os.makedirs(target_dir, exist_ok=True)
    for fn in tqdm.tqdm(fns):
        shutil.move(os.path.join(TRAIN_DIR, fn), target_dir)
    print('elements in {}: {}'.format(target_dir, len(os.listdir(target_dir))))
# Data preprocessing.
# Model input size (pixels per side) and batch size. Both must be assigned
# before the loaders below use them.
IMAGE_SHAPE = 224
# NOTE(review): BATCH_SIZE was never assigned in the original script.
# 32 is a placeholder value; tune as needed.
BATCH_SIZE = 32

# Plain 1/255 rescaling generators (pixels in [0, 1]). That input range is fine
# for displaying images but is not what the ImageNet ResNet50 weights expect.
train_generator = ImageDataGenerator(rescale=1./255)
valid_generator = ImageDataGenerator(rescale=1./255)
# Generator used for feature extraction. ResNet50's ImageNet weights expect
# resnet.preprocess_input (RGB->BGR, per-channel mean subtraction on 0-255
# pixels) rather than a 1/255 rescale.
feature_generator = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet.preprocess_input
)

# Data loaders. shuffle=False keeps a fixed file order, so the extracted
# feature rows line up with `.labels` when the head is trained.
train_data = feature_generator.flow_from_directory(
    directory=TRAIN_DIR,
    target_size=(IMAGE_SHAPE, IMAGE_SHAPE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)
valid_data = feature_generator.flow_from_directory(
    directory=VALID_DIR,
    target_size=(IMAGE_SHAPE, IMAGE_SHAPE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=False
)

# ResNet50 backbone without its ImageNet classifier. pooling='avg' applies
# global average pooling, so each image yields a single feature vector.
resnet_model = tf.keras.applications.resnet.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(IMAGE_SHAPE, IMAGE_SHAPE, 3),
    pooling='avg'
)

# Run the backbone once over every image to get the "bottleneck" features that
# the head below trains on. predict() replaces the deprecated
# predict_generator(). Passing the iterator without `steps` consumes all of its
# batches, including the final partial one; the old `n // BATCH_SIZE` steps
# silently dropped up to BATCH_SIZE - 1 images.
train_bottleneck = resnet_model.predict(train_data, verbose=1)
valid_bottleneck = resnet_model.predict(valid_data, verbose=1)
# Build the model: a small classifier head fed by the pooled ResNet50
# features. Its input shape is the backbone's output shape minus the batch axis.
model = tf.keras.models.Sequential()
model.add(Dense(units=256, activation='relu', input_shape=resnet_model.output_shape[1:]))
model.add(Dropout(0.5))
model.add(Dense(units=128, activation='relu'))
model.add(Dropout(0.5))
# Two-way softmax: one probability per class (dog / cat).
model.add(Dense(units=2, activation='softmax'))

# Configure training: optimizer, loss function and the accuracy metric.
# sparse_categorical_crossentropy takes integer class ids as targets, not
# one-hot vectors.
model.compile(
    metrics=['accuracy'],
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
)
# Train the classifier head on the pre-extracted ResNet50 features.
EPOCHS = 30
history = model.fit(
    train_bottleneck,
    # Labels are cut to the number of feature rows, so x and y always have
    # matching lengths.
    train_data.labels[:len(train_bottleneck)],
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(
        valid_bottleneck,
        valid_data.labels[:len(valid_bottleneck)],
    ),
)
# Plot the loss and accuracy curves.
def show_graphs(history):
    """Plot training curves from a Keras History object.

    The left panel shows train/valid accuracy and the right panel shows
    train/valid loss, both read from ``history.history``.
    """
    curves = history.history
    plt.figure(figsize=(12, 8))
    # (train key, validation key, legend location, panel title)
    panels = (
        ('accuracy', 'val_accuracy', 'lower right', 'Accuracy'),
        ('loss', 'val_loss', 'upper left', 'Loss '),
    )
    for position, (train_key, valid_key, legend_loc, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label='train')
        plt.plot(curves[valid_key], label='valid')
        plt.legend(loc=legend_loc)
        plt.title(title)
    plt.show()


show_graphs(history)
IMAGE_SHAPE = 224
# Load one shuffled batch of example images. A dedicated 1/255 generator keeps
# example_x in [0, 1] for imshow. ResNet preprocessing is applied separately,
# right before prediction.
example_data = ImageDataGenerator(rescale=1./255).flow_from_directory(
    directory=TRAIN_DIR,
    target_size=(IMAGE_SHAPE, IMAGE_SHAPE),
    batch_size=BATCH_SIZE,
    class_mode='binary',
    shuffle=True
)
example_x, example_y = next(example_data)
# Map each numeric label back to its class-folder name. This relies on
# class_indices being ordered by index; Keras builds it from the sorted folder
# names, so confirm if that changes.
example_classes = list(example_data.class_indices.keys())
example_y_classes = [example_classes[int(i)] for i in example_y]

# Stock ImageNet classifier with 1000 output classes. `pooling` only applies
# when include_top=False, so it is not passed. predict() does not need
# compile().
resnet_native = tf.keras.applications.resnet.ResNet50(
    include_top=True,
    weights='imagenet',
    input_shape=(IMAGE_SHAPE, IMAGE_SHAPE, 3),
)
# Undo the 1/255 scaling and apply the preprocessing the ImageNet weights
# expect. `example_x * 255.0` is a new array, so example_x stays untouched for
# plotting.
example_pred = resnet_native.predict(
    tf.keras.applications.resnet.preprocess_input(example_x * 255.0)
)

labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
)
with open(labels_path, encoding='utf-8') as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())
# This label file starts with an extra 'background' entry (1001 lines), while
# ResNet50 predicts 1000 classes. Shift the argmax indices by the difference so
# each prediction maps to its own class name rather than the previous one.
label_offset = len(imagenet_labels) - example_pred.shape[-1]
result = imagenet_labels[np.argmax(example_pred, axis=1) + label_offset]
NUM_ROWS = 4
NUM_COLS = 4
NUM_IMAGES = NUM_COLS * NUM_ROWS
# Show a grid of example images, each captioned with the ImageNet prediction
# followed by the true class in parentheses.
plt.figure(figsize=(2*NUM_COLS, 2*NUM_ROWS))
# The batch can hold fewer than NUM_IMAGES images (e.g. when BATCH_SIZE is
# smaller than the grid), so never index past what was loaded.
for i in range(min(NUM_IMAGES, len(example_x))):
    plt.subplot(NUM_ROWS, NUM_COLS, i + 1)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    # Images are RGB, so no colormap applies.
    plt.imshow(example_x[i])
    plt.xlabel('{} \n({})'.format(result[i], example_y_classes[i]))
plt.tight_layout()
plt.show()