各层模型结构:
Dog1原图片:
恢复的彩色图片(训练迭代3次后):
迭代10次后的彩色图片:
需要经过多次迭代才能得到较好的效果,上色后的图片与原图片之间的对应关系也才更容易辨别。
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator
import os, numpy as np
from keras.preprocessing import image
# In[1]: 从硬盘中读取猫狗数据图片集并准备将图片进行灰度化
def rgb2gray(rgb):
    """Convert a colour image array to grayscale intensities.

    Every pixel [r, g, b] is mapped to 0.299*r + 0.587*g + 0.114*b.
    Only the first three entries of the last axis are read, so any extra
    channel (e.g. alpha) is ignored. The last axis is consumed by the dot
    product: an input of shape (..., C) produces an output of shape (...).
    """
    luma_weights = [0.299, 0.587, 0.114]
    color_planes = rgb[..., :3]
    return np.dot(color_planes, luma_weights)
# Image geometry shared by the whole script. `channels` is the channel count
# of the grayscale network input, not of the colour images.
img_rows = 64
img_cols = 64
channels = 1

# Read the cat/dog training images from disk, scaling pixels to [0, 1].
# NOTE: batch_size=8000 is not the training batch size. It makes the first
# generator batch large enough to hold the whole training set, so that one
# indexing call below loads every image into memory (presumably the
# directory holds <= 8000 images -- confirm; extra images would be dropped).
datagen_train = ImageDataGenerator(rescale=1. / 255)
generator_train = datagen_train.flow_from_directory(
    r'D:\Cadabra_tools002\course_data\cat-and-dog\training_set',  # one sub-directory per class
    target_size=(img_rows, img_cols),  # resize every image to 64x64
    batch_size=8000,
)

# x_train: the colour images; y: their class labels (not used by the autoencoder).
x_train, y = generator_train[0]

# Grayscale copies with an explicit trailing channel axis, the layout Conv2D expects.
x_train_gray = rgb2gray(x_train).reshape(-1, img_rows, img_cols, channels)
# In[2]: build the encoder
input_shape = (img_rows, img_cols, 1)  # single-channel grayscale input
batch_size = 32  # mini-batch size later passed to autoencoder.fit
kernel_size = 3
# The latent vector has to preserve both object and colour information,
# hence a comparatively wide latent dimension.
latent_dim = 256
layer_filters = [64, 128, 256]

inputs = Input(shape=input_shape, name='encoder_input')
x = inputs
# Each stride-2 'same' convolution halves the spatial size:
# (64, 64, 1) -> (32, 32, 64) -> (16, 16, 128) -> (8, 8, 256).
for filters in layer_filters:
    x = Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=2,
        activation='relu',
        padding='same',
    )(x)

# Keep the pre-flatten feature-map shape so the decoder can invert it.
shape = K.int_shape(x)
x = Flatten()(x)
latent = Dense(latent_dim, name='latent_vector')(x)

encoder = Model(inputs, latent, name='encoder')
encoder.summary()
# In[3]: build the decoder
# Number of colour channels the decoder must reproduce: the channel axis of
# the colour training images (x_train). `channels` is NOT used here because
# it describes the single-channel grayscale input; using it made the decoder
# emit one channel, so the model could never produce colour.
color_channels = x_train.shape[-1]

latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
# Project the latent vector back to the encoder's last feature-map size
# (`shape`, i.e. (None, 8, 8, 256) for 64x64 inputs) and restore that layout.
x = Dense(shape[1] * shape[2] * shape[3])(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
# Mirror the encoder: transposed convolutions with 256, 128, 64 filters.
# Each stride-2 layer doubles the spatial size, 8x8 -> 16x16 -> 32x32 -> 64x64,
# which matches the image size.
for filters in layer_filters[::-1]:
    x = Conv2DTranspose(filters=filters, kernel_size=kernel_size,
                        strides=2, activation='relu',
                        padding='same')(x)
# One output map per colour channel. Sigmoid keeps predictions in [0, 1],
# the same range as the rescaled training targets.
outputs = Conv2DTranspose(filters=color_channels, kernel_size=kernel_size,
                          activation='sigmoid', padding='same',
                          name='decoder_output')(x)
print(K.int_shape(outputs))
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()

# In[4]: assemble and train the autoencoder (grayscale in, colour out)
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
autoencoder = Model(inputs, decoder(encoder(inputs)), name='autoencoder')
autoencoder.summary()
# fit() below receives no validation data, so 'val_loss' never exists and a
# 'val_loss'-monitoring ReduceLROnPlateau would never act. Both callbacks
# therefore monitor the training loss. If it has not improved for 5 epochs,
# the learning rate is multiplied by sqrt(0.1) (~0.316), down to min_lr.
lr_reducer = ReduceLROnPlateau(monitor='loss', factor=np.sqrt(0.1),
                               cooldown=0, patience=5,
                               verbose=1, min_lr=0.5e-6)
model_name = 'colorized_ae+model.{epoch:03d}.h5'
# save_best_only is not set, so a checkpoint is written after every epoch.
checkpoint = ModelCheckpoint(filepath=model_name, monitor='loss', verbose=1)
autoencoder.compile(loss='mse', optimizer='adam')
callbacks = [lr_reducer, checkpoint]
autoencoder.fit(x_train_gray, x_train,
                epochs=10,
                batch_size=batch_size,
                callbacks=callbacks)

# In[5]: show the first 100 colourised training images as a 10x10 grid
# Only the displayed images are run through the network.
x_decoded = autoencoder.predict(x_train_gray[:100])
imgs = x_decoded[:100]
imgs = imgs.reshape((10, 10, img_rows, img_cols, color_channels))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure(dpi=200)
plt.axis('off')
plt.title('Colorized test images are: ')
plt.imshow(imgs, interpolation='none')
plt.show()

# In[6]: load a single image and convert it into a network input
img_path = r'D:\Cadabra_tools002\course_data\cat-and-dog\dog.1.jpg'
img = image.load_img(img_path, target_size=(img_rows, img_cols))
plt.imshow(img)
img_tensor = image.img_to_array(img)
img_tensor = np.expand_dims(img_tensor, axis=0)  # add a batch axis of size 1
img_tensor /= 255.  # same [0, 1] scaling as the training generator
print(img_tensor.shape)
# Convert to grayscale.
img_gray = rgb2gray(img_tensor)
# Add the trailing single-channel axis expected by the encoder input.
img_gray = img_gray.reshape(img_gray.shape[0], img_rows, img_cols, 1)

# In[7]: colourise the single image with the trained network
x_decoded = autoencoder.predict(img_gray)
# Show the grayscale input (top) above the network's colourised output
# (bottom). The gray channel is repeated so both images have the same
# number of channels and can be stacked into one picture.
gray_as_color = np.repeat(img_gray, x_decoded.shape[-1], axis=-1)
imgs = np.concatenate([gray_as_color, x_decoded])
imgs = imgs.reshape((2, 1, img_rows, img_cols, x_decoded.shape[-1]))
imgs = np.vstack([np.hstack(i) for i in imgs])
plt.figure()
plt.axis('off')
plt.title('Input: 1st 1 rows, Decoded: last 1 rows')
plt.imshow(imgs, interpolation='none')
plt.show()
源码下载