在前一篇文章【深度域适配】一、DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文Unsupervised Domain Adaptation by Backpropagation中MNIST和MNIST-M数据集的迁移训练实验。
该项目的github地址为:DANN-MNIST
为了利用DANN实现MNIST和MNIST-M数据集的迁移训练,我们首先需要获取MNIST和MNIST-M数据集。其中MNIST数据集很容易获取,官网下载链接为:MNIST。需要下载的是下图中蓝色标出的4个文件。
对于MNIST数据集的加载,TensorFlow 1.x 已经提供了相应的读取接口(tensorflow.examples.tutorials.mnist,该模块在TensorFlow 2.x中已不再提供),因此我们不需要自行编写。读取MNIST数据集的代码如下:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(os.path.abspath('./dataset/mnist'), one_hot=True)
# Process MNIST: binarize every pixel to {0, 255} and replicate the gray
# channel three times so the images share MNIST-M's (N, 28, 28, 3) layout.
# reshape(-1, ...) instead of hard-coded 55000 / 10000 so the snippet does not
# depend on the exact split sizes returned by the loader.
mnist_train = (mnist.train.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
mnist_train = np.concatenate([mnist_train] * 3, 3)
mnist_test = (mnist.test.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
mnist_test = np.concatenate([mnist_test] * 3, 3)
MNIST-M数据集由MNIST数字与BSDS500数据集中的随机色块混合而成:先把MNIST数字二值化并扩展为3通道,再从随机选取的一张BSDS500图像中随机裁剪一个28×28的图像块,取两者之差的绝对值作为合成图像。因此要想生成MNIST-M数据集,请首先下载BSDS500数据集。BSDS500数据集的官方下载地址为:BSDS500。以下是BSDS500数据集官方网址的相关截图,点击下图中蓝框中的链接即可下载数据。
下载好BSDS500数据集后,我们必须根据MNIST和BSDS500数据集来生成MNIST-M数据集,生成数据集的脚本create_mnistm.py如下:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tarfile
import os
import pickle as pkl
import numpy as np
import skimage
import skimage.io
import skimage.transform
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('./dataset/mnist')
BST_PATH = os.path.abspath('./dataset/BSR_bsds500.tgz')
# Seeded generator used to pick background images reproducibly.
rand = np.random.RandomState(42)

background_data = []
# Open the archive in a context manager so it is closed once all training
# images have been decoded into memory.
with tarfile.open(BST_PATH) as tar:
    train_files = [name for name in tar.getnames()
                   if name.startswith('BSR/BSDS500/data/images/train/')]
    print('Loading BSR training images')
    for name in train_files:
        fp = tar.extractfile(name)
        if fp is None:  # extractfile returns None for members without file data
            continue
        try:
            bg_img = skimage.io.imread(fp)
        except Exception as exc:  # best effort: skip members that are not images
            print('Skipping %s: %s' % (name, exc))
            continue
        background_data.append(bg_img)

if not background_data:
    raise RuntimeError('No background images could be loaded from %s' % BST_PATH)
def compose_image(digit, background):
    """Difference-blend a digit and a random patch from a background image.

    :param digit: (dh, dw, C) digit image of any numeric dtype
    :param background: (H, W, C) background image with H >= dh and W >= dw
    :return: uint8 array shaped like ``digit`` holding ``|patch - digit|``
    """
    h, w, _ = background.shape
    dh, dw, _ = digit.shape
    # randint's upper bound is exclusive: "+ 1" lets the patch touch the
    # bottom/right border too, and keeps equal-sized images from raising.
    y = np.random.randint(0, h - dh + 1)
    x = np.random.randint(0, w - dw + 1)
    bg = background[y:y + dh, x:x + dw]
    # Subtract in floating point so uint8 inputs cannot wrap around.
    return np.abs(bg.astype(np.float32) - digit.astype(np.float32)).astype(np.uint8)
def mnist_to_img(x):
    """Convert one MNIST digit into a binary 28x28 RGB image.

    :param x: array with 784 pixel values; any value > 0 counts as foreground
    :return: float32 array of shape (28, 28, 3) whose values are 0 or 255
    """
    mask = (x > 0).reshape(28, 28, 1)
    gray = mask.astype(np.float32) * 255
    return np.repeat(gray, 3, axis=2)
def create_mnistm(X):
    """
    Given an array of MNIST digits, blend random background patches to
    build the MNIST-M dataset as described in
    http://jmlr.org/papers/volume17/15-239/15-239.pdf

    :param X: array of MNIST digits, one flattened digit per row
    :return: uint8 array of shape (len(X), 28, 28, 3)
    """
    X_ = np.zeros([X.shape[0], 28, 28, 3], np.uint8)
    for i in range(X.shape[0]):
        if i % 1000 == 0:
            print('Processing example', i)
        # Index into the list rather than calling rand.choice(background_data):
        # choice() first converts its argument to an ndarray, which fails on
        # recent NumPy when the background images do not all share one shape.
        # RandomState.choice(a) draws its index with randint(0, len(a)), so
        # the same seed still selects the same images.
        bg_img = background_data[rand.randint(len(background_data))]
        d = mnist_to_img(X[i])
        d = compose_image(d, bg_img)
        X_[i] = d
    return X_
# Build every MNIST-M split from the matching MNIST split, in a fixed order.
mnistm_splits = {}
for split_key, banner, split_images in (
        ('train', 'Building train set...', mnist.train.images),
        ('test', 'Building test set...', mnist.test.images),
        ('valid', 'Building validation set...', mnist.validation.images)):
    print(banner)
    mnistm_splits[split_key] = create_mnistm(split_images)

# Persist all splits in a single pickle under ./dataset/mnistm.
mnistm_dir = os.path.abspath("./dataset/mnistm")
if not os.path.exists(mnistm_dir):
    os.mkdir(mnistm_dir)
pkl_path = os.path.join(mnistm_dir, 'mnistm_data.pkl')
with open(pkl_path, 'wb') as pkl_file:
    pkl.dump(mnistm_splits, pkl_file, pkl.HIGHEST_PROTOCOL)
由于整个DANN-MNIST网络的训练过程中涉及到很多超参数,因此为了整个项目的编程方便,我们利用面向对象的思想将所有的超参数放置到一个类中,即参数配置类config。这个参数配置类config的代码如下:
# -*- coding: utf-8 -*-
# @Time : 2020/2/15 15:05
# @Author : Dai PuWei
# @Email : [email protected]
# @File : config.py
# @Software: PyCharm
import os
import cv2
import numpy as np
class config(object):
    """Container for every hyper-parameter and output path of the project.

    Defaults come from ``__defualt_dict__`` (name kept for compatibility) and
    can be overridden with keyword arguments, either at construction or later
    through ``set`` (train.py sets ``pixel_mean`` this way).
    """

    __defualt_dict__ = {
        "pre_model_path": None,
        "checkpoints_dir": os.path.abspath("./checkpoints"),
        "logs_dir": os.path.abspath("./logs"),
        "config_dir": os.path.abspath("./config"),
        "dataset_dir": os.path.abspath("./dataset"),
        # "dataset_dir": os.path.abspath("/input0"),  # alternative dataset location
        "result_dir": os.path.abspath("./result"),
        "image_input_shape": (28, 28, 3),
        "image_size": 28,
        "init_learning_rate": 1e-2,
        "momentum_rate": 0.9,
        "batch_size": 64,
        "epoch": 500,
    }

    def __init__(self, **kwargs):
        """
        Initialise the configuration and create the output directories.

        :param kwargs: values overriding the defaults
        """
        # Defaults first, then the caller's overrides.
        self.__dict__.update(self.__defualt_dict__)
        self.__dict__.update(kwargs)
        for path in (self.checkpoints_dir, self.logs_dir, self.result_dir):
            self._ensure_dir(path)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and any missing parents) if it does not exist."""
        if not os.path.isdir(path):
            os.makedirs(path)

    def set(self, **kwargs):
        """
        Update configuration values.

        :param kwargs: values to set or override
        :return: None
        """
        self.__dict__.update(kwargs)

    def save_config(self, time):
        """
        Move the output directories into a per-run ``time`` sub-directory and
        append every configuration value to ``<config_dir>/<time>/config.txt``.

        :param time: timestamp string naming the run
        :return: None
        """
        for attr in ("checkpoints_dir", "logs_dir", "config_dir", "result_dir"):
            run_dir = os.path.join(getattr(self, attr), time)
            setattr(self, attr, run_dir)
            self._ensure_dir(run_dir)

        config_txt_path = os.path.join(self.config_dir, "config.txt")
        with open(config_txt_path, 'a') as f:
            for key, value in self.__dict__.items():
                # The directory attributes already include the time suffix, and
                # values are not necessarily strings (None, tuples, arrays), so
                # format them instead of concatenating.
                f.write("%s: %s\n" % (key, value))
在DANN中比较重要的模块就是梯度反转层(Gradient Reversal Layer, GRL)的实现。GRL的tf1.0代码实现如下:
# -*- coding: utf-8 -*-
# @Time : 2020/2/14 20:59
# @Author : Dai PuWei
# @Email : [email protected]
# @File : GRL.py
# @Software: PyCharm
import tensorflow as tf
from tensorflow.python.framework import ops
class GradientReversalLayer(object):
    """Gradient reversal layer (GRL) for TF1 graph mode.

    The forward pass is ``tf.identity``; in the backward pass the incoming
    gradient is negated and multiplied by ``l``.
    """

    # Gradient functions are registered in a process-wide registry, so the
    # counter used to build unique names must be shared by all instances;
    # a per-instance counter makes a second instance re-register
    # "FlipGradient0" and fail.
    _total_calls = 0

    def __init__(self):
        # Number of times this particular instance has been applied.
        self.num_calls = 0

    def __call__(self, x, l=1.0):
        """Return a tensor equal to ``x`` whose gradient is ``-l * grad``.

        :param x: input tensor
        :param l: reversal coefficient (python float or scalar tensor)
        :return: identity of ``x`` with the reversed gradient
        """
        grad_name = "FlipGradient%d" % GradientReversalLayer._total_calls

        @ops.RegisterGradient(grad_name)
        def _flip_gradients(op, grad):
            return [tf.negative(grad) * l]

        g = tf.get_default_graph()
        # Make this Identity op use the gradient registered above.
        with g.gradient_override_map({"Identity": grad_name}):
            y = tf.identity(x)

        GradientReversalLayer._total_calls += 1
        self.num_calls += 1
        return y


GRL = GradientReversalLayer()
在上述代码中,@ops.RegisterGradient(grad_name)装饰_flip_gradients(op, grad)函数,把它注册为名为grad_name的自定义梯度函数:将传入的梯度取反并乘以系数l。gradient_override_map函数用于让指定类型的op在求梯度时改用自定义的梯度函数,它的参数是一个字典:对key所表示的op类型(这里是Identity),改用value所对应的已注册梯度函数来计算梯度。因此,tf.identity(x)在前向传播时原样输出x,而在反向传播时梯度被取反并乘以l。
DANN论文Unsupervised Domain Adaptation by Backpropagation中给出MNIST和MNIST-M数据集的迁移训练实验的网络,网络架构图如下图所示。
接下来,我们将利用TensorFlow 1.14.0来搭建整个DANN-MNIST网络,并使用面向对象的思想进行编程。DANN-MNIST类的代码如下:
# -*- coding: utf-8 -*-
# @Time : 2020/2/14 20:27
# @Author : Dai PuWei
# @Email : [email protected]
# @File : MNIST2MNIST_M.py
# @Software: PyCharm
import os
import cv2
import datetime
import numpy as np
import tensorflow as tf
from tensorflow import keras as K
from tensorflow.train import MomentumOptimizer
from utils.utils import plot_loss
from utils.utils import plot_accuracy
from utils.utils import AverageMeter
from utils.utils import make_summary
from model.GRL import GRL
from utils.utils import grl_lambda_schedule
from utils.utils import learning_rate_schedule
class MNIST2MNIST_M_DANN(object):
    """Domain-adversarial network (DANN) for MNIST -> MNIST-M adaptation (TF1 graph mode).

    Source and target images are concatenated along the batch axis and passed
    through one shared feature extractor. The source half of the features is
    fed to the digit classifier; the whole batch is fed, through the gradient
    reversal layer (GRL), to the domain classifier.
    """

    def __init__(self, config):
        """
        Build the graph, the losses, the accuracy op, the optimizer and the savers.

        :param config: configuration object; this class reads ``image_input_shape``,
            ``pixel_mean``, ``momentum_rate``, ``init_learning_rate``, ``batch_size``,
            ``epoch``, ``checkpoints_dir``, ``logs_dir`` and ``result_dir`` from it
        """
        self.cfg = config
        # Placeholders fed at run time.
        self.grl_lambd = tf.placeholder(tf.float32, [])  # GRL coefficient
        self.learning_rate = tf.placeholder(tf.float32, [])  # learning rate
        self.source_image_labels = tf.placeholder(tf.float32, shape=(None, 10))
        self.domain_labels = tf.placeholder(tf.float32, shape=(None, 2))
        # Build the network; sets self.image_cls / self.domain_cls (raw logits).
        self.build_DANN()
        # The *_with_logits losses apply softmax / sigmoid themselves, so the
        # heads output unscaled logits (no softmax in their last Dense layer).
        self.image_cls_loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.source_image_labels,
                                                    logits=self.image_cls))
        self.domain_cls_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=self.domain_labels,
                                                    logits=self.domain_cls))
        self.loss = self.image_cls_loss + self.domain_cls_loss
        # Digit-class probabilities for callers that want probabilities.
        self.image_cls_prob = tf.nn.softmax(self.image_cls)
        # Accuracy on the labelled source half of the batch.
        correct_label_pred = tf.equal(tf.argmax(self.source_image_labels, 1),
                                      tf.argmax(self.image_cls, 1))
        self.acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))
        # Step counter and optimizer.
        self.global_step = tf.Variable(tf.constant(0), trainable=False)
        self.optimizer = MomentumOptimizer(self.learning_rate, momentum=self.cfg.momentum_rate)
        self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
        # A Saver built without var_list covers the variables that exist when
        # it is constructed; building it after the optimizer makes checkpoints
        # also contain global_step and the momentum slots.
        self.saver_save = tf.train.Saver(max_to_keep=100)
        # Loads only the network weights into *this* graph; works with
        # checkpoints from saver_save and with weight-only checkpoints.
        self.saver_restore = tf.train.Saver(var_list=tf.trainable_variables())

    def featur_extractor(self, image_input, name):
        """
        Shared feature extractor: two Conv(5x5, ReLU) + MaxPool(2x2) stages.

        :param image_input: image tensor
        :param name: name of the output (last pooling) layer
        :return: feature tensor
        """
        x = K.layers.Conv2D(filters=32, kernel_size=5,
                            kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                            bias_initializer=K.initializers.Constant(value=0.1),
                            activation='relu')(image_input)
        x = K.layers.MaxPool2D(pool_size=(2, 2), strides=2)(x)
        x = K.layers.Conv2D(filters=48, kernel_size=5,
                            kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                            bias_initializer=K.initializers.Constant(value=0.1),
                            activation='relu')(x)
        x = K.layers.MaxPool2D(pool_size=(2, 2), strides=2, name=name)(x)
        return x

    def build_image_classify_model(self, image_classify_feature):
        """
        Digit classifier head: Flatten -> Dense(100, ReLU) -> Dense(10).

        :param image_classify_feature: source-domain feature tensor
        :return: unscaled logits with 10 columns
        """
        x = K.layers.Lambda(lambda x: x, name="image_classify_feature")(image_classify_feature)
        x = K.layers.Flatten()(x)
        x = K.layers.Dense(100, kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                           bias_initializer=K.initializers.Constant(value=0.1),
                           activation='relu')(x)
        # Linear output layer: the loss applies the softmax.
        x = K.layers.Dense(10, kernel_initializer=K.initializers.TruncatedNormal(stddev=0.1),
                           bias_initializer=K.initializers.Constant(value=0.1),
                           name="image_classify_pred")(x)
        return x

    def build_domain_classify_model(self, domain_classify_feature):
        """
        Domain classifier head: GRL -> Flatten -> Dense(100, ReLU) -> Dense(2).

        :param domain_classify_feature: feature tensor of the whole (source + target) batch
        :return: unscaled logits with 2 columns
        """
        # Identity forward; gradient multiplied by -grl_lambd backward.
        x = GRL(domain_classify_feature, self.grl_lambd)
        x = K.layers.Flatten()(x)
        x = K.layers.Dense(100, kernel_initializer=K.initializers.TruncatedNormal(stddev=0.01),
                           bias_initializer=K.initializers.Constant(value=0.1),
                           activation='relu')(x)
        # Linear output layer: the loss applies the sigmoid.
        x = K.layers.Dense(2, kernel_initializer=K.initializers.TruncatedNormal(stddev=0.01),
                           bias_initializer=K.initializers.Constant(value=0.1),
                           name="domain_classify_pred")(x)
        return x

    def build_DANN(self):
        """
        Build the DANN graph and set ``self.image_cls`` / ``self.domain_cls``.

        :return: None
        """
        # Source and target inputs, concatenated along the batch axis.
        self.source_image_input = K.layers.Input(shape=self.cfg.image_input_shape, name="source_image_input")
        self.target_image_input = K.layers.Input(shape=self.cfg.image_input_shape, name="target_image_input")
        self.image_input = K.layers.Concatenate(axis=0, name="image_input")(
            [self.source_image_input, self.target_image_input])
        self.image_input = (self.image_input - self.cfg.pixel_mean) / 255.0

        # Features shared by the digit and the domain classifiers.
        share_feature = self.featur_extractor(self.image_input, "image_feature")

        # Split the shared features into two equal halves: source first, target second.
        source_feature, target_feature = \
            K.layers.Lambda(tf.split, arguments={'axis': 0, 'num_or_size_splits': 2})(share_feature)
        source_feature = K.layers.Lambda(lambda x: x, name="source_feature")(source_feature)

        self.image_cls = self.build_image_classify_model(source_feature)
        self.domain_cls = self.build_domain_classify_model(share_feature)

    def eval_on_val_dataset(self, sess, val_datagen, val_batch_num, ep):
        """
        Evaluate the model on the validation (target-domain) data.

        The same target batch is fed to both inputs, so every row of the
        concatenated batch is labelled as target domain.

        :param sess: active tf.Session
        :param val_datagen: generator yielding (images, one-hot labels)
        :param val_batch_num: number of validation batches
        :param ep: epoch number used as the TensorBoard step
        :return: (loss, image_cls_loss, domain_cls_loss, accuracy) averaged over batches
        """
        epoch_loss_avg = AverageMeter()
        epoch_image_cls_loss_avg = AverageMeter()
        epoch_domain_cls_loss_avg = AverageMeter()
        epoch_accuracy = AverageMeter()
        for _ in range(val_batch_num):
            batch_mnist_m_image_data, batch_mnist_m_labels = next(val_datagen)
            # Two copies of the batch are concatenated -> 2 * batch rows, all target.
            batch_domain_labels = np.tile([0., 1.], [2 * len(batch_mnist_m_image_data), 1])
            val_loss, val_image_cls_loss, val_domain_cls_loss, val_acc = \
                sess.run([self.loss, self.image_cls_loss, self.domain_cls_loss, self.acc],
                         feed_dict={self.source_image_input: batch_mnist_m_image_data,
                                    self.target_image_input: batch_mnist_m_image_data,
                                    self.source_image_labels: batch_mnist_m_labels,
                                    self.domain_labels: batch_domain_labels})
            epoch_loss_avg.update(val_loss, 1)
            epoch_image_cls_loss_avg.update(val_image_cls_loss, 1)
            epoch_domain_cls_loss_avg.update(val_domain_cls_loss, 1)
            epoch_accuracy.update(val_acc, 1)

        self.writer.add_summary(make_summary('val/val_loss', epoch_loss_avg.average), global_step=ep)
        self.writer.add_summary(make_summary('val/val_image_cls_loss', epoch_image_cls_loss_avg.average),
                                global_step=ep)
        self.writer.add_summary(make_summary('val/val_domain_cls_loss', epoch_domain_cls_loss_avg.average),
                                global_step=ep)
        self.writer.add_summary(make_summary('accuracy/val_accuracy', epoch_accuracy.average), global_step=ep)
        return epoch_loss_avg.average, epoch_image_cls_loss_avg.average, \
            epoch_domain_cls_loss_avg.average, epoch_accuracy.average

    def train(self, train_source_datagen, train_target_datagen, val_datagen, pixel_mean, interval,
              train_iter_num, val_iter_num, pre_model_path=None):
        """
        Train DANN, validating every ``interval`` epochs.

        :param train_source_datagen: generator yielding (images, one-hot labels) from the source domain
        :param train_target_datagen: generator yielding (images, labels) from the target domain;
            the target labels are not used
        :param val_datagen: generator yielding (images, one-hot labels) for validation
        :param pixel_mean: not used here; the graph normalises with ``self.cfg.pixel_mean``
        :param interval: validate every ``interval`` epochs
        :param train_iter_num: training iterations per epoch
        :param val_iter_num: batches per validation pass
        :param pre_model_path: optional checkpoint prefix (path ending in ".ckpt") whose
            network weights are loaded before training
        """
        # Per-run output directories, suffixed with a timestamp.
        time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        checkpoint_dir = os.path.join(self.cfg.checkpoints_dir, time)
        if not os.path.exists(checkpoint_dir):
            os.mkdir(checkpoint_dir)
        log_dir = os.path.join(self.cfg.logs_dir, time)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        result_dir = os.path.join(self.cfg.result_dir, time)
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)
        self.cfg.save_config(time)

        # Per-epoch training curves.
        train_loss_results = []
        train_image_cls_loss_results = []
        train_domain_cls_loss_results = []
        train_accuracy_results = []
        # Validation curves (one entry per validation pass) and best accuracy.
        val_ep = []
        val_loss_results = []
        val_image_cls_loss_results = []
        val_domain_cls_loss_results = []
        val_accuracy_results = []
        val_acc_max = 0

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if pre_model_path is not None:
                # Restore into the graph built in __init__; importing the meta
                # graph would instead add a second, unused copy of the model.
                self.saver_restore.restore(sess, pre_model_path)
                print("restore model from : %s" % (pre_model_path))

            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(log_dir, sess.graph)

            print('\n----------- start to train -----------\n')
            total_global_step = self.cfg.epoch * train_iter_num
            for ep in np.arange(self.cfg.epoch):
                epoch_loss_avg = AverageMeter()
                epoch_image_cls_loss_avg = AverageMeter()
                epoch_domain_cls_loss_avg = AverageMeter()
                epoch_accuracy = AverageMeter()
                progbar = K.utils.Progbar(train_iter_num)
                print('Epoch {}/{}'.format(ep + 1, self.cfg.epoch))
                for i in np.arange(1, train_iter_num + 1):
                    batch_mnist_image_data, batch_mnist_labels = next(train_source_datagen)
                    batch_mnist_m_image_data, _ = next(train_target_datagen)
                    # Domain labels follow the concatenation order:
                    # [1, 0] for source rows, then [0, 1] for target rows.
                    batch_domain_labels = np.vstack([
                        np.tile([1., 0.], [len(batch_mnist_image_data), 1]),
                        np.tile([0., 1.], [len(batch_mnist_m_image_data), 1])])

                    # ep is 0-based and i is 1-based, so global_step runs from 1 to
                    # total_global_step and process from 1 / total_global_step to 1.
                    global_step = ep * train_iter_num + i
                    process = global_step * 1.0 / total_global_step
                    lr_value = learning_rate_schedule(process, self.cfg.init_learning_rate)
                    grl_lambda = grl_lambda_schedule(process)

                    _, train_loss, train_image_cls_loss, train_domain_cls_loss, train_acc = \
                        sess.run([self.train_op, self.loss, self.image_cls_loss,
                                  self.domain_cls_loss, self.acc],
                                 feed_dict={self.source_image_input: batch_mnist_image_data,
                                            self.target_image_input: batch_mnist_m_image_data,
                                            self.source_image_labels: batch_mnist_labels,
                                            self.domain_labels: batch_domain_labels,
                                            self.learning_rate: lr_value,
                                            self.grl_lambd: grl_lambda})
                    self.writer.add_summary(make_summary('learning_rate', lr_value), global_step=global_step)

                    epoch_loss_avg.update(train_loss, 1)
                    epoch_image_cls_loss_avg.update(train_image_cls_loss, 1)
                    epoch_domain_cls_loss_avg.update(train_domain_cls_loss, 1)
                    epoch_accuracy.update(train_acc, 1)

                    progbar.update(i, [('train_image_cls_loss', train_image_cls_loss),
                                       ('train_domain_cls_loss', train_domain_cls_loss),
                                       ('train_loss', train_loss),
                                       ("train_acc", train_acc)])

                train_loss_results.append(epoch_loss_avg.average)
                train_image_cls_loss_results.append(epoch_image_cls_loss_avg.average)
                train_domain_cls_loss_results.append(epoch_domain_cls_loss_avg.average)
                train_accuracy_results.append(epoch_accuracy.average)

                self.writer.add_summary(make_summary('train/train_loss', epoch_loss_avg.average),
                                        global_step=ep + 1)
                self.writer.add_summary(make_summary('train/train_image_cls_loss',
                                                     epoch_image_cls_loss_avg.average),
                                        global_step=ep + 1)
                self.writer.add_summary(make_summary('train/train_domain_cls_loss',
                                                     epoch_domain_cls_loss_avg.average),
                                        global_step=ep + 1)
                self.writer.add_summary(make_summary('accuracy/train_accuracy', epoch_accuracy.average),
                                        global_step=ep + 1)

                if (ep + 1) % interval == 0:
                    val_ep.append(ep)
                    val_loss, val_image_cls_loss, val_domain_cls_loss, val_accuracy = \
                        self.eval_on_val_dataset(sess, val_datagen, val_iter_num, ep + 1)
                    val_loss_results.append(val_loss)
                    val_image_cls_loss_results.append(val_image_cls_loss)
                    val_domain_cls_loss_results.append(val_domain_cls_loss)
                    val_accuracy_results.append(val_accuracy)
                    val_str = "Epoch{:03d}_val_image_cls_loss{:.3f}_val_domain_cls_loss{:.3f}_val_loss{:.3f}" \
                              "_val_accuracy{:.3%}".format(ep + 1, val_image_cls_loss, val_domain_cls_loss,
                                                           val_loss, val_accuracy)
                    print(val_str)
                    if val_accuracy > val_acc_max:  # best validation accuracy so far: checkpoint
                        val_acc_max = val_accuracy
                        self.saver_save.save(sess, os.path.join(checkpoint_dir, val_str + ".ckpt"))

            # Curves; skipped when there is nothing to plot (e.g. epoch < interval).
            if train_loss_results:
                path = os.path.join(result_dir, "train_loss.jpg")
                plot_loss(np.arange(1, len(train_loss_results) + 1),
                          [np.array(train_loss_results), np.array(train_image_cls_loss_results),
                           np.array(train_domain_cls_loss_results)],
                          path, "train")
            if val_ep:
                path = os.path.join(result_dir, "val_loss.jpg")
                plot_loss(np.array(val_ep) + 1,
                          [np.array(val_loss_results), np.array(val_image_cls_loss_results),
                           np.array(val_domain_cls_loss_results)],
                          path, "val")
                # Training accuracy of the epochs where validation ran.
                train_acc = np.array(train_accuracy_results)[np.array(val_ep)]
                path = os.path.join(result_dir, "accuracy.jpg")
                plot_accuracy(np.array(val_ep) + 1, [train_acc, val_accuracy_results], path)

            model_path = os.path.join(checkpoint_dir, "trained_model.ckpt")
            self.saver_save.save(sess, model_path)
            self.writer.close()
            print("Train model finshed. The model is saved in : ", model_path)
            print('\n----------- end to train -----------\n')

    def test_image(self, image_path, model_path):
        """
        Predict the digit shown in one image file.

        :param image_path: image path, read with cv2 (must match ``cfg.image_input_shape``)
        :param model_path: checkpoint prefix (path ending in ".ckpt")
        :return: predicted digit
        """
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise IOError("cannot read image: %s" % image_path)
        # NOTE(review): cv2 decodes to BGR, whereas create_mnistm.py decoded the
        # backgrounds with skimage (RGB); confirm whether a BGR->RGB conversion is wanted.
        image = np.expand_dims(image, axis=0)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self.saver_restore.restore(sess, model_path)
            # The graph concatenates both inputs and splits the features in half,
            # so the target input must be fed too; feeding the same image keeps
            # the halves equal and the source half equal to ``image``.
            img_cls_pred = sess.run(self.image_cls,
                                    feed_dict={self.source_image_input: image,
                                               self.target_image_input: image})
        # The one-hot class index is the digit itself, so no offset is added.
        pred_label = int(np.argmax(img_cls_pred[0]))
        print("%s is %d" % (image_path, pred_label))
        return pred_label

    def test_batch_images(self, image_paths, model_path):
        """
        Predict the digits shown in a batch of image files.

        :param image_paths: list of image paths, all of shape ``cfg.image_input_shape``
        :param model_path: checkpoint prefix (path ending in ".ckpt")
        :return: array with one predicted digit per image
        """
        images = []
        for image_path in image_paths:
            image = cv2.imread(image_path)
            if image is None:  # cv2.imread signals failure by returning None
                raise IOError("cannot read image: %s" % image_path)
            images.append(image)
        images = np.array(images)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            self.saver_restore.restore(sess, model_path)
            # See test_image: both inputs must be fed.
            img_cls_pred = sess.run(self.image_cls,
                                    feed_dict={self.source_image_input: images,
                                               self.target_image_input: images})
        # One row of logits per image; argmax over the class axis.
        pred_labels = np.argmax(img_cls_pred, axis=1)
        for image_path, label in zip(image_paths, pred_labels):
            print("%s is %d" % (image_path, label))
        return pred_labels
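在训练过程中,需要各种小工具函数来辅助训练。例如学习率 $\mu_p$ 和GRL参数 $\lambda_p$ 都随训练进度 $p\in[0,1]$ 变化:$\mu_p = \dfrac{\mu_0}{(1+\alpha p)^{\beta}}$,$\lambda_p = \dfrac{2}{1+\exp(-\gamma p)} - 1$,代码中默认 $\mu_0=0.01$、$\alpha=10$、$\beta=0.75$、$\gamma=10$。此外还需要数据集生成器的定义以及各种结果绘制函数。工具脚本utils.py如下: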
# -*- coding: utf-8 -*-
# @Time : 2020/2/15 16:10
# @Author : Dai PuWei
# @Email : [email protected]
# @File : utils.py
# @Software: PyCharm
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.core.framework import summary_pb2
class AverageMeter(object):
    """Keep the most recent value and the running weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.average = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.average = self.sum / float(self.count)
def make_summary(name, val):
    """Wrap a single scalar in a ``Summary`` protobuf for ``FileWriter.add_summary``.

    :param name: TensorBoard tag
    :param val: scalar value
    """
    scalar_entry = summary_pb2.Summary.Value(tag=name, simple_value=val)
    return summary_pb2.Summary(value=[scalar_entry])
def plot_accuracy(x, y, path):
    """
    Plot training vs. validation accuracy and save the figure.

    :param x: x coordinates (epoch numbers); must be non-empty
    :param y: pair ``(train_accuracy, val_accuracy)`` of sequences aligned with ``x``
    :param path: output image path
    """
    legend_array = ["train_acc", "val_acc"]
    train_accuracy, val_accuracy = y
    # Draw on a dedicated figure so any figure the caller has open is neither
    # drawn on nor closed.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(x, train_accuracy, 'r-')
    ax.plot(x, val_accuracy, 'b--')
    ax.grid(True)
    ax.set_xlim(0, x[-1] + 2)
    ax.set_xlabel("epoch")
    ax.set_ylabel("accuracy")
    ax.legend(legend_array, loc="best")
    fig.savefig(path)
    plt.close(fig)
def plot_loss(x, y, path, mode="train"):
    """
    Plot the total, image-classification and domain-classification losses.

    :param x: x coordinates (epoch numbers); must be non-empty
    :param y: triple ``(loss, image_cls_loss, domain_cls_loss)`` of arrays aligned with ``x``
    :param path: output image path
    :param mode: "train" labels the curves as training losses; any other value
        labels them as validation losses
    """
    if mode == "train":
        legend_array = ["train_loss", "train_image_cls_loss", "train_domain_cls_loss"]
    else:
        legend_array = ["val_loss", "val_image_cls_loss", "val_domain_cls_loss"]
    loss_results, image_cls_loss_results, domain_cls_loss_results = y
    # Lower y limit: 0.1 below the smallest plotted value, but never below 0.
    y_min = max(min(np.min(loss_results), np.min(image_cls_loss_results),
                    np.min(domain_cls_loss_results)) - 0.1, 0)

    # Draw on a dedicated figure so any figure the caller has open is untouched.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(x, loss_results, 'r-')
    ax.plot(x, image_cls_loss_results, 'b--')
    ax.plot(x, domain_cls_loss_results, 'g-.')
    ax.grid(True)
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_xlim(0, x[-1] + 2)
    ax.set_ylim(bottom=y_min)
    ax.legend(legend_array, loc="best")
    fig.savefig(path)
    plt.close(fig)
def learning_rate_schedule(process, init_learning_rate=0.01, alpha=10.0, beta=0.75):
    """
    Annealed learning rate ``init_learning_rate / (1 + alpha * process) ** beta``.

    :param process: training progress ratio in [0, 1]
    :param init_learning_rate: learning rate at process == 0 (default 0.01)
    :param alpha: decay speed (default 10)
    :param beta: decay exponent (default 0.75)
    """
    decay = (1.0 + alpha * process) ** beta
    return init_learning_rate / decay
def grl_lambda_schedule(process, gamma=10.0):
    """
    GRL coefficient ``2 / (1 + exp(-gamma * process)) - 1``, rising from 0 towards 1.

    :param process: training progress ratio in [0, 1]
    :param gamma: steepness of the ramp (default 10)
    """
    denominator = 1.0 + np.exp(-gamma * process)
    return 2.0 / denominator - 1.0
最后是训练DANN的脚本train.py,代码如下:
# -*- coding: utf-8 -*-
# @Time : 2020/2/15 16:36
# @Author : Dai PuWei
# @Email : [email protected]
# @File : train.py
# @Software: PyCharm
import os
import numpy as np
import pickle as pkl
from config.config import config
from tensorflow import keras as K
from model.MNIST2MNIST_M import MNIST2MNIST_M_DANN
from datagenerator.DataGenerator import DataGenerator
from tensorflow.examples.tutorials.mnist import input_data
def shuffle_aligned_list(data):
    """Reorder every array in ``data`` with one shared random permutation.

    :param data: list of arrays with the same length along axis 0
    :return: new list of permuted arrays, still row-aligned with each other
    """
    order = np.random.permutation(data[0].shape[0])
    return [array[order] for array in data]
def batch_generator(data, batch_size, shuffle=True):
    """Generate batches of data forever.

    Given a list of array-like objects, generate batches of a given
    size by yielding a list of array-like objects corresponding to the
    same slice of each input. After each full pass the data is reshuffled
    (when ``shuffle`` is true) and iteration restarts from the beginning.

    :param data: list of arrays with the same length along axis 0
    :param batch_size: number of rows per batch
    :param shuffle: shuffle the rows (identically for all arrays) at every pass
    """
    if shuffle:
        data = shuffle_aligned_list(data)

    batch_count = 0
    while True:
        # Restart only when the next batch would run past the end; '>' (not
        # '>=') so a batch ending exactly at the last row is still yielded.
        if batch_count * batch_size + batch_size > len(data[0]):
            batch_count = 0
            if shuffle:
                data = shuffle_aligned_list(data)

        start = batch_count * batch_size
        end = start + batch_size
        batch_count += 1
        yield [d[start:end] for d in data]
def run_main():
    """
    Entry point: load MNIST and MNIST-M, build the batch generators and train DANN.
    """
    cfg = config()
    mnist = input_data.read_data_sets(os.path.abspath('./dataset/mnist'), one_hot=True)

    # Binarize MNIST to {0, 255} and replicate it to 3 channels to match
    # MNIST-M; reshape(-1, ...) avoids hard-coding the split sizes.
    mnist_train = (mnist.train.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
    mnist_train = np.concatenate([mnist_train] * 3, 3)
    mnist_test = (mnist.test.images > 0).reshape(-1, 28, 28, 1).astype(np.uint8) * 255
    mnist_test = np.concatenate([mnist_test] * 3, 3)

    # Load MNIST-M as written by create_mnistm.py. pickle can run arbitrary
    # code while loading: only load a file you generated yourself.
    with open(os.path.abspath('./dataset/mnistm/mnistm_data.pkl'), 'rb') as fh:
        mnistm = pkl.load(fh)
    mnistm_train = mnistm['train']
    mnistm_test = mnistm['test']

    # Per-channel mean over both training domains, used to centre the inputs.
    pixel_mean = np.vstack([mnist_train, mnistm_train]).mean((0, 1, 2))
    cfg.set(pixel_mean=pixel_mean)

    # create_mnistm.py builds MNIST-M image i from MNIST image i, so the MNIST
    # labels are reused for MNIST-M. Each training step takes half a batch
    # from each domain.
    train_source_datagen = batch_generator([mnist_train, mnist.train.labels], cfg.batch_size // 2)
    train_target_datagen = batch_generator([mnistm_train, mnist.train.labels], cfg.batch_size // 2)
    val_datagen = batch_generator([mnistm_test, mnist.test.labels], cfg.batch_size)

    # Iterations per epoch (enough to cover the larger domain once) and
    # batches per validation pass.
    train_source_batch_num = len(mnist_train) // (cfg.batch_size // 2)
    train_target_batch_num = len(mnistm_train) // (cfg.batch_size // 2)
    train_iter_num = max(train_source_batch_num, train_target_batch_num)
    val_iter_num = len(mnistm_test) // cfg.batch_size

    interval = 2  # validate every `interval` epochs

    train_num = len(mnist_train) + len(mnistm_train)  # training samples (both domains)
    val_num = len(mnistm_test)  # validation samples
    print("train on %d training samples with batch_size %d ,validation on %d val samples"
          % (train_num, cfg.batch_size, val_num))

    dann = MNIST2MNIST_M_DANN(cfg)
    # pre_model_path = os.path.abspath("./pre_model/trained_model.ckpt")
    pre_model_path = None
    dann.train(train_source_datagen, train_target_datagen, val_datagen, pixel_mean,
               interval, train_iter_num, val_iter_num, pre_model_path)


if __name__ == '__main__':
    run_main()
下面是训练过程中TensorBoard相关指标的走势图。首先是训练误差的走势图,主要包括训练域分类误差、训练图像分类误差和训练总误差。
接下来是验证误差的走势图,主要包括验证域分类误差、验证图像分类误差和验证总误差。
然后是训练过程中学习率的走势图
最后是精度走势图,主要包括训练精度和验证精度。其中训练精度是在源域数据集(即MNIST数据集)上的统计结果,验证精度是在目标域数据集(即MNIST-M数据集)上的统计结果。DANN在训练时没有使用MNIST-M数据集的标签。从图中可以看出,MNIST-M数据集上的精度最终收敛到75.4%,与论文中报告的81.49%相比还有一定差距,但鉴于没有使用任何数据增强和dropout,这个结果可以接受。