import tensorflow as tf
import tensorflow.contrib.slim as slim
import rawpy
import numpy as np
import struct
import glob
import os
from PIL import Image
import time

# Camera / dataset identifiers; pick one with `source_id` below.
__sony__ = 0
__huawei__ = 1
__blackberry__ = 2

# Training stages; pick one with `training_stage` below.
__stage_raw2raw__ = 0
__stage_raw2rgb__ = 1
__stage_overall__ = 2

# File-name prefixes used to select the files of one split
# (the loader globs `<current_prefix>*.ARW`).
train_prefix = '0'
valid_prefix = '1'
test_prefix = '2'

# ============ CONFIGURATION ============
USE_GPU = False
if USE_GPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# change this to switch between datasets
source_id = __sony__

# switch between training stages
training_stage = __stage_raw2rgb__

# patch size should be set on running
patch_size = (512, 512)
# patch_size = (2840, 4248)

# switch between training and validation
current_prefix = train_prefix

# model saving settings
max_epoch = 2000
save_epoch_delay = 1
model_dir = './result_raw2raw/'
out_dir = './output_raw2raw/'
log_dir = './log_raw2raw/'

learn_rate = 1e-2
# ============ CONFIGURATION ============

# Per-camera sensor parameters: white (saturation) level, black level and the
# frame size that raw files are reshaped to.
if source_id == __blackberry__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 3024
    WIDTH = 4032
elif source_id == __sony__:
    WHITE_LEVEL = 16383
    BLACK_LEVEL = 512
    HEIGHT = 2848
    WIDTH = 4256
elif source_id == __huawei__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 2976
    WIDTH = 3968
else:
    # Fail here instead of with a NameError on WHITE_LEVEL much later.
    raise ValueError('unknown source_id: %r' % (source_id,))

# NOTE(review): both paths point at the Sony data set regardless of source_id.
if USE_GPU:
    data_dir = '../see_in_the_dark/dataset/Sony_small/'
else:
    data_dir = 'D:/data/Sony_small/'

# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
fixed_size = (128, 128)
num_of_denoise_filter = 3
standard_brightness = 0.1
# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
def has_nan_in_tensor(x):
    """Return True if the numpy array ``x`` contains at least one NaN."""
    # NaN is the only value that compares unequal to itself.
    return bool(np.any(x != x))


def raw_from_file(path):
    """Load one raw Bayer frame as a float32 array of shape [H, W, 1].

    Frames are cropped so that every camera starts with the same 2x2 layout
    (G R / B G, the one ``bayer_to_rgb`` decodes):
      * Sony: first and last columns dropped  -> [HEIGHT, WIDTH - 2, 1]
      * Huawei: first and last rows dropped   -> [HEIGHT - 2, WIDTH, 1]
      * BlackBerry: headerless native-endian uint16 dump, uncropped
        -> [HEIGHT, WIDTH, 1]

    Raises:
        ValueError: if ``source_id`` is not a known camera.
    """
    if source_id == __sony__:
        # astype() copies, so the rawpy handle can be closed right away.
        with rawpy.imread(path) as data:
            raw = data.raw_image_visible.astype(np.float32)
        raw = raw.reshape(HEIGHT, WIDTH)
        h, w = raw.shape
        # cut the strips of left and right borders to shift the Bayer phase
        return np.reshape(raw[:, 1:w - 1], [h, w - 2, 1])
    elif source_id == __huawei__:
        with rawpy.imread(path) as data:
            raw = data.raw_image_visible.astype(np.float32)
        raw = raw.reshape(HEIGHT, WIDTH)
        h, w = raw.shape
        # cut the strips of top and bottom borders to shift the Bayer phase
        return np.reshape(raw[1:h - 1, :], [h - 2, w, 1])
    elif source_id == __blackberry__:
        with open(path, 'rb') as f:
            buf = f.read()
        # Same decoding as struct.unpack('H' * n): native-order uint16 samples.
        raw = np.frombuffer(buf, dtype=np.uint16).astype(np.float32)
        raw = raw.reshape(HEIGHT, WIDTH)
        return np.reshape(raw, [HEIGHT, WIDTH, 1])
    raise ValueError('unknown source_id: %r' % (source_id,))


def rgb_from_file(path):
    """Develop a raw file into a float32 RGB image in [0, 1] using rawpy.

    Uses the camera white balance, full resolution, no auto-brightening and
    16-bit output. The result is cropped exactly like ``raw_from_file`` so the
    RGB image stays pixel-aligned with the Bayer frame.

    Raises:
        NameError: for cameras that rawpy cannot read (BlackBerry).
    """
    if source_id not in (__sony__, __huawei__):
        raise NameError('file type [%d] does not support rawpy!' % source_id)
    with rawpy.imread(path) as raw:
        rgb = np.float32(
            raw.postprocess(
                use_camera_wb=True,
                half_size=False,
                no_auto_bright=True,
                output_bps=16
            )
        ) / 65535.0
    if source_id == __sony__:
        return rgb[:, 1:-1, :]
    return rgb[1:-1, :, :]


def black_level_correction(bayer):
    """Subtract BLACK_LEVEL, scale by 1/(WHITE_LEVEL - BLACK_LEVEL), clamp negatives to 0.

    There is no upper clamp; values above WHITE_LEVEL map above 1.
    """
    with tf.name_scope('black_level_corr'):
        r = 1.0 / (WHITE_LEVEL - BLACK_LEVEL)
        return tf.nn.relu((bayer - BLACK_LEVEL) * r)


def bound(bayer):
    """Clip a tensor element-wise into [0, 1]."""
    return tf.minimum(tf.maximum(bayer, 0), 1)


def bayer_to_rgb(bayer):
    """Collapse each 2x2 Bayer cell of an NHWC [N, H, W, 1] tensor into one RGB pixel.

    The cell layout is G R / B G: R = top-right, B = bottom-left,
    G = mean of the two diagonal greens. Output is [N, H/2, W/2, 3].
    """
    with tf.name_scope('bayer2rgb'):
        filters = np.array([
            [0.0, 1.0, 0.0, 0.0],  # R
            [0.5, 0.0, 0.0, 0.5],  # (G1+G2)/2
            [0.0, 0.0, 1.0, 0.0],  # B
        ], dtype=np.float32).reshape([1, 3, 2, 2]).transpose([2, 3, 0, 1])
        return tf.nn.conv2d(
            bayer,
            filters,
            strides=(1, 2, 2, 1),
            padding='VALID',
            name='bayer_converter'
        )


def demosaic(rgb):
    """Bilinearly resize a half-resolution RGB tensor to ``patch_size``."""
    with tf.name_scope('demosaic'):
        return tf.image.resize_bilinear(rgb, patch_size)


def color_correction(rgb, color_matrix):
    """Apply a 3x3 colour matrix to every pixel (as a 1x1 convolution)."""
    with tf.name_scope('color_corr'):
        filters = tf.reshape(color_matrix, [1, 1, 3, 3])
        return tf.nn.conv2d(rgb, filters, (1, 1, 1, 1), 'SAME', name='output')


def min_max_normalize(rgb):
    """Rescale a tensor to roughly [0, 1] using its global min and max (epsilon-guarded)."""
    _min = tf.reduce_min(rgb)
    _max = tf.reduce_max(rgb)
    return (rgb - _min + 1e-8) / (_max - _min + 1e-8)


def gaussian_norm(rgb):
    """Standardize a tensor to zero mean / unit std over all of its elements.

    A small epsilon keeps a constant input (std == 0) from producing NaN/Inf.
    """
    _mean = tf.reduce_mean(rgb)
    _vari = tf.sqrt(tf.reduce_mean(tf.square(rgb - _mean)))
    return (rgb - _mean) / (_vari + 1e-8)


# not supported on SNPE, so do it on cpu of mobile phone
# in case of negative value, normalize it before power operation
def gamma_correction(rgb, gamma):
    """Min-max normalize ``rgb`` and raise it to the power ``gamma``."""
    with tf.name_scope('gamma_corr'):
        return tf.pow(min_max_normalize(rgb), gamma)


def lrelu(x):
    """Leaky ReLU with slope 0.2 for negative inputs."""
    return tf.maximum(x * 0.2, x)


def network_raw2raw(inputs):
    """Raw-to-raw enhancement net: five dilated 3x3 convs, then a 1x1 conv.

    The dilated convs (rates 1, 2, 4, 8, 16; 32 channels; leaky ReLU) are
    followed by a linear 1x1 conv back to one channel. Variable scopes are
    g_conv1..g_conv5 and g_conv_last.

    Weights use slim's default initializer. The former
    ``weights_initializer=tf.initializers.constant`` passed the *class*,
    which TF instantiates as Constant(0): all-zero weights give zero
    activations and zero gradients, so the net could never learn.
    """
    with tf.name_scope('raw2raw'):
        net = inputs
        for index, rate in enumerate((1, 2, 4, 8, 16), start=1):
            net = slim.conv2d(net, 32, [3, 3], rate=rate, activation_fn=lrelu,
                              scope='g_conv%d' % index)
        net = slim.conv2d(net, 1, [1, 1], rate=1, activation_fn=None,
                          scope='g_conv_last')
        return net


def show(rgb, title):
    """Display an RGB array with values in [0, 1] (clipped to avoid uint8 wrap-around)."""
    im = Image.fromarray(np.uint8(np.clip(rgb, 0.0, 1.0) * 255))
    im.show(title)


def save(rgb, path):
    """Write an RGB array with values in [0, 1] to ``path`` (clipped to avoid uint8 wrap-around)."""
    im = Image.fromarray(np.uint8(np.clip(rgb, 0.0, 1.0) * 255))
    im.save(path)


def concat(ims):
    """Place images side by side (concatenate along the width axis)."""
    return np.concatenate(ims, axis=1)


def get_color_matrix_and_gamma(bayer):
    """Predict a 3x3 colour matrix and a scalar gamma from an NHWC RGB tensor.

    A shared conv trunk feeds two conv + dense heads. The final reshapes to
    [3, 3] and [1] only succeed for a batch of exactly one image. Layer
    creation order is kept as-is because tf.layers auto-names variables
    (conv2d, conv2d_1, ...) and checkpoints depend on those names.
    """
    with tf.name_scope('isp_param_gen'):
        with tf.name_scope('common_extractor'):
            channels = tf.layers.conv2d(bayer, 3, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.tanh(channels)
            channels = tf.layers.conv2d(activations, 5, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.relu(channels)
        with tf.name_scope('color_matrix'):
            channels_cm = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_cm = tf.nn.tanh(channels_cm)
            channels_cm = tf.layers.conv2d(activations_cm, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_cm = tf.reshape(
                channels_cm, [-1, channels_cm.shape[1] * channels_cm.shape[2] * channels_cm.shape[3]])
            color_matrix = tf.reshape(tf.layers.dense(channels_flat_cm, 9), [3, 3])
        with tf.name_scope('gamma'):
            channels_gamma = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_gamma = tf.nn.tanh(channels_gamma)
            channels_gamma = tf.layers.conv2d(activations_gamma, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_gamma = tf.reshape(
                channels_gamma,
                [-1, channels_gamma.shape[1] * channels_gamma.shape[2] * channels_gamma.shape[3]])
            # keep gamma strictly positive
            gamma = tf.reshape(tf.maximum(tf.layers.dense(channels_flat_gamma, 1), 1e-3), [1])
    return color_matrix, gamma


def build_isp_process_flow(bayer, color_matrix, gamma):
    """ISP pipeline: demosaic (resize) -> colour correction -> gamma correction."""
    with tf.name_scope('isp_flow'):
        return gamma_correction(
            color_correction(
                demosaic(
                    bayer
                ),
                color_matrix
            ),
            gamma
        )


# in form of NHWC
def color_normalize(rgb):
    """Divide each pixel by the sum of its channels (floored at 1e-7), keeping only chromaticity."""
    return rgb / tf.expand_dims(tf.maximum(tf.reduce_sum(rgb, axis=3), 1e-7), axis=-1)


def color_loss(rgb_out, rgb_gt):
    """Mean absolute difference between the chromaticities of two NHWC images."""
    return tf.reduce_mean(tf.abs(color_normalize(rgb_out) - color_normalize(rgb_gt)))


# load images from files
gt_files = glob.glob(data_dir + '/long/' + current_prefix + '*.ARW')
in_files = [None] * len(gt_files)
train_ids = [None] * len(gt_files)
gt_raws = [None] * len(train_ids)
gt_rgbs = [None] * len(train_ids)
in_raws = [None] * len(train_ids)

# Reorganize the raw files according to their training id
for i, gt_file in enumerate(gt_files):
    # basename handles '/' and (on Windows) '\\' alike; the old code chose the
    # separator from USE_GPU, which broke on Linux without a GPU.
    train_ids[i] = os.path.basename(gt_file)[1:5]
    # for input files, multiple ones may relate to single ground truth file
    in_files[i] = glob.glob(data_dir + '/short/' + current_prefix + train_ids[i] + '*.ARW')
    in_raws[i] = [None] * len(in_files[i])


def get_gt_file_by_train_id(tid):
    """Return the ground-truth file path at index ``tid``."""
    return gt_files[tid]


def get_in_file_by_train_id_file_id(tid, fid):
    """Return the ``fid``-th input file path belonging to ground truth ``tid``."""
    return in_files[tid][fid]


def _random_aligned_crop(a, b):
    """Crop the same random ``patch_size`` window from two HWC arrays; add a batch axis.

    Offsets are even so the crop keeps the 2x2 Bayer phase ``bayer_to_rgb``
    expects. Every valid offset, up to and including h - ph, can be drawn,
    so a full-frame patch also works.

    Raises:
        ValueError: if the arrays are smaller than ``patch_size``.
    """
    h, w = a.shape[0], a.shape[1]
    ph, pw = patch_size
    if h < ph or w < pw:
        raise ValueError('patch_size %r exceeds image size %r' % (patch_size, (h, w)))
    y = 2 * np.random.randint(0, (h - ph) // 2 + 1)
    x = 2 * np.random.randint(0, (w - pw) // 2 + 1)
    return (
        np.expand_dims(a[y:y + ph, x:x + pw, :], axis=0),
        np.expand_dims(b[y:y + ph, x:x + pw, :], axis=0)
    )


def get_patch_pair_raw_raw(raw_in, raw_gt):
    """Random aligned patch pair (input raw, ground-truth raw), each [1, ph, pw, C]."""
    return _random_aligned_crop(raw_in, raw_gt)


def get_patch_pair_raw_rgb(raw, rgb):
    """Random aligned patch pair (raw, RGB), each [1, ph, pw, C]."""
    return _random_aligned_crop(raw, rgb)


def _require_training_pairs():
    """Raise IOError unless some ground-truth file has at least one input file.

    Without this check the patch generators would spin forever.
    """
    if not any(in_files):
        raise IOError('no training pairs found in %s with prefix %r' % (data_dir, current_prefix))


def get_rand_patch_from_file_raw2rgb():
    """Endlessly yield (input raw patch, ground-truth RGB patch) pairs, shuffled per epoch.

    Files are cached in memory only when USE_GPU is set. On a cache hit the
    cached arrays are used; previously stale arrays from the preceding
    iteration were yielded, producing mismatched pairs.
    """
    _require_training_pairs()
    while True:
        for ind in np.random.permutation(len(train_ids)):
            if not in_files[ind]:
                # ground truth without any input file: nothing to pair with
                continue
            gt_rgb = gt_rgbs[ind]
            if gt_rgb is None:
                # resource not found in cache, load it from disk
                gt_rgb = rgb_from_file(get_gt_file_by_train_id(ind))
            fid = np.random.randint(0, len(in_files[ind]))
            in_raw = in_raws[ind][fid]
            if in_raw is None:
                in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
            # cache them when using GPU on linux server since memory is sufficient
            if USE_GPU:
                gt_rgbs[ind] = gt_rgb
                in_raws[ind][fid] = in_raw
            yield get_patch_pair_raw_rgb(in_raw, gt_rgb)


def get_rand_patch_from_file_raw2raw():
    """Endlessly yield (input raw patch, ground-truth raw patch) pairs, shuffled per epoch.

    Caching behaves as in ``get_rand_patch_from_file_raw2rgb``.
    """
    _require_training_pairs()
    while True:
        for ind in np.random.permutation(len(train_ids)):
            if not in_files[ind]:
                continue
            gt_raw = gt_raws[ind]
            if gt_raw is None:
                # resource not found in cache, load it from disk
                gt_raw = raw_from_file(get_gt_file_by_train_id(ind))
            fid = np.random.randint(0, len(in_files[ind]))
            in_raw = in_raws[ind][fid]
            if in_raw is None:
                in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
            # cache them when using GPU on linux server since memory is sufficient
            if USE_GPU:
                in_raws[ind][fid] = in_raw
                gt_raws[ind] = gt_raw
            yield get_patch_pair_raw_raw(in_raw, gt_raw)


# basic nodes
t_bayer_in = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1], name='input')
t_bayer_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])
t_bayer_std = black_level_correction(t_bayer_in)
t_bayer_gt_std = black_level_correction(t_bayer_gt)
# amplify the input ('short' directory) 300x, clip to 1, then enhance
t_bayer_boosted = network_raw2raw(tf.minimum(300 * t_bayer_std, 1.0))
t_half_rgb = bayer_to_rgb(t_bayer_std)
t_half_rgb_boosted = bayer_to_rgb(bound(t_bayer_boosted))
t_half_rgb_gt = bayer_to_rgb(t_bayer_gt_std)
t_half_rgb_resized = tf.image.resize_bilinear(t_half_rgb, fixed_size)
t_rgb_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])

# ISP nodes: parameters are predicted from the un-boosted input, resized to fixed_size
t_color_matrix, t_gamma = get_color_matrix_and_gamma(t_half_rgb_resized)

# training raw2raw alone: compare standardized half-resolution RGB images
t_err_raw = tf.reduce_mean(tf.abs(gaussian_norm(t_half_rgb_boosted) - gaussian_norm(t_half_rgb_gt)))

# training raw2rgb alone: raw2raw output is frozen, only the ISP head learns
t_half_rgb_freeze = tf.stop_gradient(t_half_rgb_boosted)
t_rgb_freeze = build_isp_process_flow(t_half_rgb_freeze, t_color_matrix, t_gamma)
# colour loss plus a regularizer pulling gamma towards 1/2.5
t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt) + tf.abs(t_gamma[0] - 1.0 / 2.5)
# training overall model: gradients flow through both raw2raw and the ISP head
t_rgb_final = build_isp_process_flow(t_half_rgb_boosted, t_color_matrix, t_gamma)
t_err_overall = color_loss(t_rgb_final, t_rgb_gt)


def clean_no_grad_vars(vs, gs):
    """Drop (variable, gradient) entries whose gradient is None.

    Returns:
        (variables, gradients) as two parallel lists.
    """
    kept = [(v, g) for v, g in zip(vs, gs) if g is not None]
    return [v for v, _ in kept], [g for _, g in kept]


def make_var_grad_pairs(vs, gs):
    """Zip into the (gradient, variable) tuples expected by ``Optimizer.apply_gradients``."""
    return list(zip(gs, vs))


def _restore_available_variables(sess, ckpt_path):
    """Restore every global variable that is stored in ``ckpt_path``.

    Variables absent from the checkpoint (e.g. a sub-graph added after the
    checkpoint was written) are skipped instead of failing the restore.
    They keep whatever value the caller initialized them with.

    Returns:
        The list of restored variables.
    """
    saved_names = set(name for name, _ in tf.train.list_variables(ckpt_path))
    to_restore = [v for v in tf.global_variables() if v.op.name in saved_names]
    if to_restore:
        tf.train.Saver(to_restore).restore(sess, ckpt_path)
    print('Restored %d of %d variables.' % (len(to_restore), len(tf.global_variables())))
    return to_restore


def _train_stage(sess, saver, logger, patches, t_target, t_minimizer, t_loss, t_summary,
                 report_gamma=False):
    """Run one training stage until ``max_epoch`` epochs of patches have been consumed.

    Args:
        patches: generator of (input raw, target) numpy pairs; the target is
            fed into ``t_target``.
        report_gamma: also print 1/gamma predicted by the ISP head.

    A checkpoint is written every ``save_epoch_delay`` epochs and once more
    when training finishes.
    """
    samples_per_epoch = len(train_ids)
    if samples_per_epoch == 0:
        raise IOError('no training data found')
    fetches = [t_minimizer, t_loss, t_summary]
    if report_gamma:
        fetches.append(t_gamma)
    counter = 0
    t_start = time.perf_counter()  # time.clock() no longer exists on Python >= 3.8
    for raw_in, target in patches:
        results = sess.run(fetches, feed_dict={t_bayer_in: raw_in, t_target: target})
        loss, summary = results[1], results[2]
        logger.add_summary(summary, counter)
        epoch = counter // samples_per_epoch
        if report_gamma:
            gamma = results[3]
            print('Epoch# %d Counter# %d Loss= %.7f Gamma=%.6f'
                  % (epoch, counter, loss, 1.0 / gamma[0]))
        else:
            print('Epoch# %d Counter# %d Loss= %.7f' % (epoch, counter, loss))
        counter += 1
        if counter % 100 == 0:
            t_stop = time.perf_counter()
            print('Speed: %.6f' % ((t_stop - t_start) / 100))
            t_start = t_stop
        if counter > max_epoch * samples_per_epoch:
            saver.save(sess, model_dir + '/' + str(epoch))
            print('Training done.')
            # stages II/III used to miss this break and loop (and save) forever
            break
        elif counter % (samples_per_epoch * save_epoch_delay) == 0:
            saver.save(sess, model_dir + '/' + str(epoch))
            print('Model saved.')


def train():
    """Train the stage selected by ``training_stage``, resuming from the latest checkpoint."""
    print('Staged training begins...')
    t_opt = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
    t_minimizer_raw2raw = t_opt.minimize(t_err_raw)
    t_minimizer_raw2rgb = t_opt.minimize(t_err_rgb)
    t_minimizer_overall = t_opt.minimize(t_err_overall)

    t_sum_raw = tf.summary.scalar('raw2raw_loss', t_err_raw)
    t_sum_rgb = tf.summary.scalar('raw2rgb_loss', t_err_rgb)
    t_sum_all = tf.summary.scalar('overall_loss', t_err_overall)

    # The saver used for *saving* always covers the full graph, so newly added
    # sub-graphs are never silently left out of checkpoints.
    saver = tf.train.Saver(tf.global_variables())

    for directory in (log_dir, model_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        logger = tf.summary.FileWriter(log_dir, graph=sess.graph)
        try:
            ckpt_path = tf.train.latest_checkpoint(model_dir)
            if ckpt_path is not None:
                print('Restoring model...')
                _restore_available_variables(sess, ckpt_path)

            # first stage: raw to raw training
            if training_stage == __stage_raw2raw__:
                print('Stage I: train to map input raw into ground truth raw')
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2raw(),
                             t_bayer_gt, t_minimizer_raw2raw, t_err_raw, t_sum_raw)

            # second stage: raw to rgb training
            if training_stage == __stage_raw2rgb__:
                print('Stage II: train to map generated raw into ground truth rgb')
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                             t_rgb_gt, t_minimizer_raw2rgb, t_err_rgb, t_sum_rgb,
                             report_gamma=True)

            # third stage: overall training
            if training_stage == __stage_overall__:
                print('Stage III: train to map input raw into ground truth rgb')
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                             t_rgb_gt, t_minimizer_overall, t_err_overall, t_sum_all)
        finally:
            logger.close()


def test_half_rgb():
    """Save 20 side-by-side (boosted | ground truth) half-resolution RGB patches to ``out_dir``.

    Raises:
        IOError: if ``model_dir`` or a checkpoint inside it is missing.
    """
    print('Testing Half RGB reconstruction...')
    if not os.path.exists(model_dir):
        # was `assert 'path not found!'`, which always passes
        raise IOError('path not found: %s' % model_dir)
    ckpt_path = tf.train.latest_checkpoint(model_dir)
    if ckpt_path is None:
        raise IOError('no checkpoint found in %s' % model_dir)

    with tf.Session() as sess:
        tf.train.Saver(tf.global_variables()).restore(sess, ckpt_path)
        print('Model loaded.')

        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

        patches = get_rand_patch_from_file_raw2raw()
        # zip stops on range exhaustion without drawing a 21st patch
        for counter, (raw_in, raw_gt) in zip(range(20), patches):
            half_rgb_boosted, half_rgb_gt = sess.run(
                [t_half_rgb_boosted, t_half_rgb_gt],
                feed_dict={
                    t_bayer_in: raw_in,
                    t_bayer_gt: raw_gt
                }
            )
            im_cmp = concat((half_rgb_boosted[0], half_rgb_gt[0]))
            save(im_cmp, (out_dir + '/HALF_%04d.jpg') % counter)


if __name__ == '__main__':
    # test_half_rgb()
    train()
1.先说tf.train.Saver()的坑,这个比较严重,其损失是不可挽回的!!!
由于经常需要迁移学习,需要执行图融合的操作,于是,需要先加载一部分子图然后创建另一部分子图,训练完后保存整个模型。
问题是:直接采用tf.train.Saver()的话,等效于saver = tf.train.Saver(tf.global_variables())
在加载子图的时候会报错:因为子图的 checkpoint 文件中找不到新创建的那部分子图中的变量,因此需要显式指定要恢复的变量列表,而不是直接使用 tf.global_variables()。
于是将tf.global_variables()这个替换掉,方案有两种:
1.直接利用name的prefix进行变量过滤,即对tf.global_variables()得到的变量列表中的部分变量根据其v.name进行剔除,剩下的就是需要加载的变量。
2.采用tf.contrib.slim直接获取要加载的变量列表,然而这里出现了一个坑:
slim.get_variables_to_restore(include=include) 中的 include 是一个名字列表,每个名字都作为正则从变量名开头进行匹配(re.match),效果相当于前缀匹配,即:if v.name.startswith('VAR_NAME_PREFIX'): ADD_TO_LIST(ret)
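于是当 include 列表中含有 conv2d 这个前缀时,所有 conv2d_xxx 变量都会被匹配进来;如果列表里同时还有 conv2d_1,那么 conv2d_1/... 就会被两个前缀各匹配一次。而 SLIM 不会对结果去重!!!于是得到的 var_list 中会出现重复的
变量,导致创建 Saver / 加载模型时报错:At least two variables have the same name: conv2d_1/bias !!! 解决办法:在前缀后加上 '/'(例如 'conv2d/')使匹配精确,或者对得到的列表去重后再传给 Saver。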
填坑完毕!
用于加载的 saver 一定要指定要加载的变量列表;但用于保存的 saver 必须另外创建并覆盖全部变量(tf.train.Saver(tf.global_variables())),不能直接复用只包含子图变量的那个 saver,否则不知不觉中辛辛苦苦训练好的新变量(参数)根本没有被保存,训练结束后就永远消失在内存里了~~~~~