# Keras-VGG16预训练模型-Cats and Dogs数据集
# (Keras VGG16 pretrained model on the Cats and Dogs dataset)

from keras.applications import VGG16
import numpy as np
import os
from keras.preprocessing.image import ImageDataGenerator
from keras import Sequential
from keras import layers
from keras import optimizers
import matplotlib.pyplot as plt

# Pretrained VGG16 convolutional base with ImageNet weights. include_top=False
# omits VGG16's own fully connected classifier so a new head can be stacked on
# top; the base is configured for 150x150 inputs with 3 channels.
conv_base = VGG16(weights="imagenet", include_top=False,
                  input_shape=(150, 150, 3))
# --- Model: VGG16 convolutional base followed by a small dense classifier ---
model = Sequential([
    conv_base,
    layers.Flatten(),
    layers.Dense(256, activation="relu"),
    layers.Dense(1, activation="sigmoid"),  # one sigmoid unit: binary output
])

# Freeze the convolutional base so training only updates the Dense head.
# Keras applies `trainable` changes when the model is compiled, so this has to
# happen before model.compile().
# Equivalent alternative: model.layers[0].trainable = False
conv_base.trainable = False
# Dataset root; the train/validation/test splits are subdirectories of it.
# Deriving them from base_dir keeps a single place to change when the data
# lives elsewhere (previously the root was repeated in every path literal).
base_dir = "D:/cats_and_dogs_small"
train_dir = os.path.join(base_dir, "train")
validation_dir = os.path.join(base_dir, "validation")
test_dir = os.path.join(base_dir, "test")

# Pixel scaling shared by every split: raw pixel values are multiplied by 1/255.
RESCALE = 1. / 255

# Random augmentation, applied to the training split only.
augmentation = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

train_datagen = ImageDataGenerator(rescale=RESCALE, **augmentation)
validation_datagen = ImageDataGenerator(rescale=RESCALE)
test_datagen = ImageDataGenerator(rescale=RESCALE)

# All splits are read identically: images resized to 150x150 (the conv base's
# input size), batches of 20, and binary labels (class_mode="binary"), which
# flow_from_directory derives from each split's class subdirectories.
flow_options = dict(target_size=(150, 150), batch_size=20, class_mode="binary")

train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir, **flow_options)
# NOTE(review): test_generator is not consumed anywhere else in this script.
test_generator = test_datagen.flow_from_directory(test_dir, **flow_options)

# RMSprop with learning rate 2e-5. It is passed positionally because the
# keyword is `lr` in older Keras and `learning_rate` in newer releases.
rmsprop = optimizers.RMSprop(2e-5)

# Binary cross-entropy pairs with the single sigmoid output unit. The metric
# name "acc" also fixes the history keys ("acc"/"val_acc") read after training.
model.compile(optimizer=rmsprop, loss="binary_crossentropy", metrics=["acc"])

# Train the dense head on augmented batches for 10 epochs, evaluating on the
# validation split after each epoch.
#
# Steps are derived from the generators instead of being hard-coded: len() of
# a Keras directory iterator is its number of batches, ceil(samples /
# batch_size), so each epoch is exactly one pass over each split. The former
# constants (100 and 50 steps at batch_size=20) only matched a split of
# exactly 2000 training and 1000 validation images.
#
# NOTE: fit_generator is deprecated in TF2-era Keras, where model.fit accepts
# generators directly; it is kept for compatibility with standalone Keras 2.x,
# which this script imports from.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

# Per-epoch metrics recorded during training. The "acc"/"val_acc" keys follow
# the metric name passed to model.compile.
acc = history.history["acc"]
val_acc = history.history["val_acc"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]

# 1-based epoch numbers for the x axis.
epochs = range(1, len(acc) + 1)

# Accuracy curves: dots for training, solid line for validation.
plt.plot(epochs, acc, "bo", label="Training acc")
plt.plot(epochs, val_acc, "b", label="Validation acc")
plt.title("Training and validation acc")
plt.legend()

# Loss curves in a separate figure, using the same styling.
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()

# Display both figures.
plt.show()

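# 你可能感兴趣的:(Keras-VGG16预训练模型-Cats and Dogs数据集)
# (Blog footer: "You may also be interested in: Keras VGG16 pretrained model on the Cats and Dogs dataset")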