目前,TensorFlow 的模型参数文件处理起来不如 PyTorch 的参数文件方便,而当前任务中恰好有这样一个需求:要在“某几个参数矩阵中,将特定行的参数复制到某些其他行”。
在 PyTorch 中这还好办,因为其参数文件本质上就是一组 tensor 被一个 OrderedDict 包装起来,属于 Python 的基本数据结构。
同样的事情在 TensorFlow 中处理起来会比较麻烦,于是考虑实现工具类 CheckpointMonitor 来提高处理效率。
- __init__(checkpoint_path):传入 checkpoint 路径
- list_variables():展示当前 checkpoint 中的所有参数及其 shape
- list_target_variables(pattern):同 list_variables,展示筛选后(参数名包含 pattern)的参数列表(图3)
- get_var_data(var_name):获取模型文件中对应参数名的参数,格式为 numpy 数组
- save_model(path, method='tf'):将模型文件存回 TensorFlow 或 PyTorch 格式
- modify_var_name(old_name, new_name):修改参数名
- modify_var_names(rename_func):批量修改参数名
- modify_var_data(var_name, var_data):修改参数的值
import os
# CPU mode: blank out both CUDA variables so no GPU is visible to the
# TensorFlow import that follows.
os.environ.update(CUDA_LAUNCH_BLOCKING="", CUDA_VISIBLE_DEVICES="")
import numpy as np
import tensorflow as tf
from collections import OrderedDict
class CheckpointMonitor(object):
    """Load, inspect, edit and re-save the variables of a TensorFlow checkpoint.

    On construction every variable of the checkpoint is read into memory as a
    numpy array. Names, shapes and values can then be listed, renamed or
    replaced, and the result written out either as a TensorFlow checkpoint or
    as a PyTorch ``.pt`` file (see ``save_model``).

    Attributes:
        checkpoint_path: Checkpoint prefix the variables were loaded from.
        var_name_list: Variable names; this list defines the key order of
            both dicts below.
        var_shape_dict: ``OrderedDict`` mapping name -> shape.
        var_data_dict: ``OrderedDict`` mapping name -> ``np.ndarray`` value.
        dump_path, default_dump_name: Default output directory / file prefix
            used by ``save_model``.
        saver, graph: ``tf.train.Saver`` and ``tf.Graph`` of the most recent
            TensorFlow save (``None`` until the first one).

    CPU mode::

        import os
        os.environ['CUDA_LAUNCH_BLOCKING'] = ""
        os.environ['CUDA_VISIBLE_DEVICES'] = ""
    """

    def __init__(self, checkpoint_path=None):
        """Load every variable of ``checkpoint_path`` into memory.

        Args:
            checkpoint_path: Checkpoint prefix to load. Defaults to a
                hard-coded checkpoint used for testing.
        """
        if checkpoint_path is None:  # default path for testing
            checkpoint_path = '/data/sharedata/model_files/model.ckpt-250042'
        self.saver = None
        self.graph = None
        self.dump_path = './'
        self.checkpoint_path = checkpoint_path
        self.default_dump_name = 'my_modified_model'
        self.var_name_list = []
        self.var_shape_dict = OrderedDict()
        self.var_data_dict = OrderedDict()
        self.init_vars()

    def reload(self, checkpoint_path=None):
        """Discard all state (including edits) and load ``checkpoint_path`` afresh."""
        self.__init__(checkpoint_path=checkpoint_path)

    def init_vars(self, checkpoint_path=None):
        """(Re)load names, shapes and values of all variables from a checkpoint.

        Args:
            checkpoint_path: Checkpoint to read; defaults to
                ``self.checkpoint_path``.
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        self.var_shape_dict = OrderedDict(self.list_variables(checkpoint_path))
        self.var_name_list = list(self.var_shape_dict.keys())
        # Start from an empty dict: merging into the previous one would keep
        # variables of an earlier checkpoint, which save_model() would then
        # write out as well.
        self.var_data_dict = OrderedDict()
        for var_name in self.var_name_list:
            # The path is passed explicitly so get_var_data reads from disk
            # rather than from the (still partially filled) in-memory dict.
            self.var_data_dict[var_name] = self.get_var_data(
                var_name, checkpoint_path)

    def sort_var_dicts(self):
        """Rebuild both dicts so that their key order follows ``var_name_list``."""
        self.var_data_dict = OrderedDict(
            (var_name, self.var_data_dict[var_name])
            for var_name in self.var_name_list)
        self.var_shape_dict = OrderedDict(
            (var_name, self.var_shape_dict[var_name])
            for var_name in self.var_name_list)

    def list_variables(self, checkpoint_path=None):
        """Return ``[(var_name, shape), ...]`` for every variable on disk.

        Args:
            checkpoint_path: Checkpoint to list; defaults to
                ``self.checkpoint_path``.
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_path
        return tf.contrib.framework.list_variables(checkpoint_path)

    def list_target_variables(self, pattern, checkpoint_path=None):
        """Return the ``(name, shape)`` pairs whose name contains ``pattern``.

        Without ``checkpoint_path`` the in-memory (possibly edited) variables
        are searched, falling back to ``self.checkpoint_path`` on disk if
        nothing has been loaded yet. With a path, that checkpoint is listed.
        """
        if checkpoint_path is None:
            if self.var_shape_dict:
                # lazy loading: use what is already in memory
                var_list = self.var_shape_dict.items()
            else:
                # load for cold-booting
                var_list = self.list_variables(self.checkpoint_path)
        else:
            var_list = self.list_variables(checkpoint_path)
        return [(name, shape) for (name, shape) in var_list if pattern in name]

    def get_var_data(self, var_name, checkpoint_path=None):
        """Return the value of ``var_name`` as a numpy array.

        Without ``checkpoint_path`` the in-memory (possibly edited) value is
        returned, or ``None`` if no variable of that name is held. If nothing
        has been loaded yet, or a path is given, the value is read from the
        checkpoint on disk.
        """
        if checkpoint_path is None:
            if self.var_data_dict:
                # lazy loading
                return self.var_data_dict.get(var_name)
            checkpoint_path = self.checkpoint_path
        return tf.contrib.framework.load_variable(checkpoint_path, var_name)

    @staticmethod
    def generate_rename_func(old_name_list, new_name_list):
        """Build a rename function for ``modify_var_names`` from parallel lists.

        ``old_name_list[i]`` is mapped to ``new_name_list[i]``; any other name
        is returned unchanged. If a name appears several times in
        ``old_name_list``, its first occurrence wins.
        """
        def fn(var_name):
            if var_name in old_name_list:
                return new_name_list[old_name_list.index(var_name)]
            return var_name
        return fn

    def modify_var_name(self, old_name, new_name, inplace=True):
        """Rename one variable, keeping its value, shape and position.

        Args:
            old_name: Current name of the variable.
            new_name: New name. Renaming a variable to its own name is a no-op.
            inplace: If True (default), re-sort the dicts immediately so their
                order matches ``var_name_list``.

        Raises:
            ValueError: ``old_name`` is unknown, or ``new_name`` already names
                a different variable (its data would otherwise be lost).
        """
        var_index = self.var_name_list.index(old_name)  # ValueError if unknown
        if new_name == old_name:
            # The copy-then-delete below would otherwise delete the entry.
            return
        if new_name in self.var_name_list:
            raise ValueError(
                "Cannot rename {} to {}: a variable named {} already exists."
                .format(old_name, new_name, new_name))
        self.var_name_list[var_index] = new_name
        self.var_data_dict[new_name] = self.var_data_dict.pop(old_name)
        self.var_shape_dict[new_name] = self.var_shape_dict.pop(old_name)
        if inplace:
            self.sort_var_dicts()

    def modify_var_names(self, rename_func=None):
        """Rename variables in batch: every name ``n`` becomes ``rename_func(n)``.

        All new names are computed first and applied together, so swaps such
        as ``a -> b, b -> a`` work. Order, values and shapes are preserved.

        Args:
            rename_func: Callable mapping an old name to its new name (see
                ``generate_rename_func``). Defaults to the identity.

        Raises:
            ValueError: several variables would end up with the same name;
                nothing is modified in that case.
        """
        if rename_func is None:
            rename_func = lambda _name: _name
        old_names = list(self.var_name_list)
        new_names = [rename_func(name) for name in old_names]

        seen, duplicates = set(), []
        for name in new_names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)
        if duplicates:
            raise ValueError(
                "rename_func maps several variables to the same name(s): {}"
                .format(duplicates))

        for old_name, new_name in zip(old_names, new_names):
            if new_name != old_name:
                print('Re-naming {} to {}.'.format(old_name, new_name))
        self.var_data_dict = OrderedDict(
            (new_name, self.var_data_dict[old_name])
            for old_name, new_name in zip(old_names, new_names))
        self.var_shape_dict = OrderedDict(
            (new_name, self.var_shape_dict[old_name])
            for old_name, new_name in zip(old_names, new_names))
        self.var_name_list = new_names

    def modify_var_data(self, var_name, var_data):
        """Replace the value of an existing variable; its shape may change.

        Raises:
            TypeError: ``var_data`` is not a ``np.ndarray``.
            ValueError: ``var_name`` is not a known variable.
        """
        if not isinstance(var_data, np.ndarray):
            raise TypeError("var_data must be a numpy.ndarray, got {}".format(
                type(var_data).__name__))
        if var_name not in self.var_name_list:
            raise ValueError(
                "Invalid variable name:{}. You can get available variable "
                "names by calling list_variables()".format(var_name))
        self.var_shape_dict[var_name] = list(var_data.shape)
        self.var_data_dict[var_name] = var_data

    def generate_var_dict_for_torch(self, var_list=None):
        """Convert ``(name, np.ndarray)`` pairs into an ``OrderedDict`` of torch tensors.

        Args:
            var_list: Iterable of ``(name, value)`` pairs; defaults to all
                variables held by the monitor, in order.
        """
        # Imported here (as in save_model_as_pt) so that TensorFlow-only use
        # of this class does not require torch to be installed.
        import torch
        if var_list is None:
            var_list = self.var_data_dict.items()
        return OrderedDict(
            (var_name, torch.tensor(var_data)) for var_name, var_data in var_list)

    def generate_var_list_for_saver(self, var_list=None):
        """Create one ``tf.Variable`` per ``(name, value)`` pair in the default graph.

        Each variable is named after its checkpoint name and initialised from
        its numpy value. Call this inside a fresh graph, as
        ``save_model_as_tf`` does. If a name is already taken in the graph,
        TensorFlow uniquifies it (e.g. ``w_1``), and the saved name would no
        longer match.
        """
        if var_list is None:
            var_list = self.var_data_dict.items()
        return [tf.Variable(var_data, name=var_name)
                for var_name, var_data in var_list]

    def save_model(self, new_checkpoint_path=None, model_name=None, method='pt'):
        """Write all held variables to ``<new_checkpoint_path>/<model_name>``.

        Args:
            new_checkpoint_path: Output directory, created if missing;
                defaults to ``self.dump_path``.
            model_name: File prefix; defaults to ``self.default_dump_name``.
            method: ``'pt'``/``'torch'``/``'pytorch'`` for a PyTorch ``.pt``
                file, ``'tf'``/``'ckpt'``/``'tensorflow'`` for a TensorFlow
                checkpoint (case-insensitive).

        Raises:
            KeyError: ``method`` is not one of the names above.
        """
        method_dict = {
            'pt': self.save_model_as_pt,
            'tf': self.save_model_as_tf,
            'ckpt': self.save_model_as_tf,
            'torch': self.save_model_as_pt,
            'pytorch': self.save_model_as_pt,
            'tensorflow': self.save_model_as_tf,
        }
        # Resolve the writer first so that a bad ``method`` fails before any
        # directory is created.
        save_fn = method_dict.get(str(method).lower())
        if save_fn is None:
            raise KeyError("Unknown save method {!r}; expected one of {}".format(
                method, sorted(method_dict)))
        if new_checkpoint_path is None:
            new_checkpoint_path = self.dump_path
        # An empty path means the current directory, which needs no creation.
        if new_checkpoint_path and not os.path.exists(new_checkpoint_path):
            os.makedirs(new_checkpoint_path)
        if model_name is None:
            model_name = self.default_dump_name
        save_fn(os.path.join(new_checkpoint_path, model_name))

    def save_model_as_pt(self, checkpoint_path):
        """Save ``{'model': <name -> tensor>}`` to ``checkpoint_path + '.pt'``."""
        import torch
        var_dict = self.generate_var_dict_for_torch()
        checkpoint = OrderedDict({'model': var_dict})
        torch.save(checkpoint, checkpoint_path + '.pt')
        print("Checkpoint saving finished !\n{}".format(
            checkpoint_path + '.pt'))

    def save_model_as_tf(self, checkpoint_path):
        """Save all held variables as a TensorFlow checkpoint at ``checkpoint_path``."""
        # Use a dedicated graph per save. In the shared default graph a second
        # save would find the names already taken, the new variables would be
        # renamed (``w`` -> ``w_1``), and the checkpoint would be written
        # under those wrong names.
        self.graph = tf.Graph()
        with self.graph.as_default():
            var_list = self.generate_var_list_for_saver()
            # Construct the Saver
            self.saver = tf.train.Saver(var_list=var_list)
            with tf.Session(graph=self.graph) as sess:
                # Necessary! Variables must be initialised from their numpy
                # values before the Saver can read them.
                sess.run(tf.global_variables_initializer())
                self.saver.save(sess, checkpoint_path)
        print("Checkpoint saving finished !\n{}".format(checkpoint_path))