当一个项目的代码不断增加,其中很多的内容诸如全局变量、提示语言等等都有必要放在一个独立的文件,方便变更。这个独立的文件有很多种,可以是init文件、conf文件、xml文件,为了通用性,我选择了xml文件作为自己的配置文件。对于《字符级的CNN文本分类器》一文中,我的xml文件是这样的:
xml version="1.0" encoding="UTF-8" ?>xml是一种非常开放的语言,所有的标签、属性命名都没有相关的规定,什么时候使用子标签,什么时候使用属性全凭作者的习惯,这里我给出一种判断的标准:对于数据的存储,尽量使用子标签,例如上述中hyperparameter(超参)属于我要存的数据,所以我用标签来存。但name和type属于这个数据的一些属性,我使用属性去存储。但其实这样的分类法有时候并不会很清晰。例如电影,一部电影有标题、简介、演员等内容,这些属于这部电影的属性,但从另外一个角度来说,这些又是属于我们要存起来的数据。所以有另外一种标准,我们可以把需要直接使用的内容看成数据,例如这个超参直接赋给程序,电影的标题、简介和演员直接输出到前端页面,而name和type属于对这些数据的描述,我们不会直接使用它们,所以作为属性。在上述文件中,大家还可以看到我在部分内容里面使用了格式字符串“%s%d%f”等,关于这点的意义之一,我在另一篇文章中提到,就是为了可以存储整条文本,方便语言包的制作。具体怎样指定相关的变量,后面会介绍。name="train_text_home">TrainText name="test_text_home">TestText name="checkpoint_home">CheckPoints name="summary_home">Summaries name="log_home">Logs name="model_file">ENCharCNNTextClassification_%s name="summary_file">ENCharCNNTextClassification_%s name="log_file">Logs_of_%s.log name="pre_train.py"> name="encode_success">One-Hot encoding done! Totally %d words have been skipped. name="start_train_file">Start to train from file: %s. name="done_train_file">Finish to train from file: %s. name="open_dir">Open directory: %s. name="char_cnn.py"> name="checkpoint_restore">Checkpoint: %s has been restored. name="checkpoint_restore_fail">No checkpoints being restored. name="display_steps">Total steps: %d, batch cost: %.4f, batch accuracy: %.2f%%, time to use: %d seconds. name="display_test">Accuracy: %.2f%%. name="main.py"> name="done_train">The train has been done! name="done_validation">The validation has been done! name="length0" type="int">1014 name="n_class" type="int">4 name="batch_size" type="int">128 name="learning_rate" type="float">0.01 name="decay_steps" type="int">1000 name="decay_rate" type="float">0.8 name="keep_prob" type="float">0.5 name="grad_clip" type="int">5
有了xml文件,就需要有代码去读,同样为了全局性考虑,我专门写了一个类来读取配置文件,做我程序的初始化工作:
import os import logging import xml.etree.ElementTree class ProjectInitializer: config = None work_dir = os.getcwd() @classmethod def init_my_project(cls, config_file='config.xml'): assert os.path.isfile(config_file) assert (os.path.splitext(config_file)[-1]) == '.xml' cls.config = xml.etree.ElementTree.parse(config_file).getroot() cls._check_dirs() cls._init_logging() @classmethod def get_dir_path(cls, home_name): path = cls.config.find("./directories/directory[@name=\'%s\']" % home_name).text if os.path.isabs(path): return path else: return os.path.join(cls.work_dir, path) @classmethod def get_file_path(cls, home_name, file_name, *args): real_file_name = cls.config.find("./files/file[@name=\'%s\']" % file_name).text if len(args) > 0: real_file_name = real_file_name % args return os.path.join(cls.get_dir_path(home_name), real_file_name) @classmethod def message_about(cls, module_file_path, event_name, *args): module_name = os.path.basename(module_file_path) mess_module = cls.config.find("./messages/module[@name=\'%s\']" % module_name) assert mess_module is not None message = mess_module.find("./message[@name=\'%s\']" % event_name).text if len(args) > 0: message = message % args return message @classmethod def option(cls, opt_name): tag = cls.config.find("./options/option[@name=\'%s\']" % opt_name) value = eval(tag.get('type'))(tag.text) return value @classmethod def hyper_para(cls, para_name): tag = cls.config.find("./hyperparameters/hyperparameter[@name=\'%s\']" % para_name) value = eval(tag.get('type'))(tag.text) return value @classmethod def _check_dirs(cls): for directory in cls.config.findall("./directories/directory"): path = cls.get_dir_path(directory.get('name')) if not os.path.isdir(path): os.mkdir(path) @classmethod def _init_logging(cls): formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s - %(message)s', '%Y %b %d %H:%M:%S') handlers = [ logging.FileHandler(cls.get_file_path('log_home', 'log_file', cls.option('train_name'))), logging.StreamHandler() ] logging.getLogger().setLevel(logging.INFO) for handler in handlers: handler.setFormatter(formatter) logging.getLogger().addHandler(handler)且看我逐行解析:
class ProjectInitializer: config = None work_dir = os.getcwd()类变量config是存储xml解析树的,在import类的时候,因为配置文件的路径不确定,所以先初始化为None,而work_dir则是要获取当前的工作目录作为全局变量,这样的好处是,无论后面我们是否改变了工作目录,我们配置文件指定的路径都可以跟这个路径组合成绝对路径,固定存储的位置。
xml获取相关标签的公式如下:
Syntax | Meaning |
---|---|
tag |
Selects all child elements with the given tag. For example, spam selects all child elements named spam , and spam/egg selects all grandchildren named egg in all children named spam . |
* |
Selects all child elements. For example, */egg selects all grandchildren named egg . |
. |
Selects the current node. This is mostly useful at the beginning of the path, to indicate that it’s a relative path. |
// |
Selects all subelements, on all levels beneath the current element. For example, .//egg selects all egg elements in the entire tree. |
.. |
Selects the parent element. Returns None if the path attempts to reach the ancestors of the start element (the element find was called on). |
[@attrib] |
Selects all elements that have the given attribute. |
[@attrib='value'] |
Selects all elements for which the given attribute has the given value. The value cannot contain quotes. |
[tag] |
Selects all elements that have a child named tag . Only immediate children are supported. |
[tag='text'] |
Selects all elements that have a child named tag whose complete text content, including descendants, equals the given text . |
[position] |
Selects all elements that are located at the given position. The position can be either an integer (1 is the first position), the expression last() (for the last position), or a position relative to the last position (e.g. last()-1 ). |
Predicates (expressions within square brackets) must be preceded by a tag name, an asterisk, or another predicate. position
predicates must be preceded by a tag name.
详细内容请参考xml.etree.ElementTree文档
@classmethod def init_my_project(cls, config_file='config.xml'): assert os.path.isfile(config_file) assert (os.path.splitext(config_file)[-1]) == '.xml' cls.config = xml.etree.ElementTree.parse(config_file).getroot() cls._check_dirs() cls._init_logging()第一个静态方法是用来初始化项目的,会在main函数的第一句被调用。从代码我们可以清楚看到,主要工作就是读取指定的xml文件,检查需要的目录是否存在,以及初始化logging的相关信息。
@classmethod def get_dir_path(cls, home_name): path = cls.config.find("./directories/directory[@name=\'%s\']" % home_name).text if os.path.isabs(path): return path else: return os.path.join(cls.work_dir, path) @classmethod def get_file_path(cls, home_name, file_name, *args): real_file_name = cls.config.find("./files/file[@name=\'%s\']" % file_name).text if len(args) > 0: real_file_name = real_file_name % args return os.path.join(cls.get_dir_path(home_name), real_file_name)这两个方法是用来获取相关的目录和文件路径的,为了便于在win和Linux之间迁移,我配置文件一般不会存路径分割符,而像代码中的使用os.path中的join方法来连接,这样就保证了可迁移性。同时在第二个方法中我们可以看到,函数使用了Python的一个可变参数的特性*args,这个特性允许使用着动态输入若干个参数,例如参数文件中的log文件我们会允许输入一个字符串来区分不同的日志,这样调用者获取log文件路径的时候,只需要把这个字符串按顺序放在函数里面,代码便会替换相应的内容,生成真正的文件名,再组合目录的路径成为绝对路径返回。由于这个参数的个数是可变的,也满足了字符串中替代符个数不定的需要。
@classmethod def message_about(cls, module_file_path, event_name, *args): module_name = os.path.basename(module_file_path) mess_module = cls.config.find("./messages/module[@name=\'%s\']" % module_name) assert mess_module is not None message = mess_module.find("./message[@name=\'%s\']" % event_name).text if len(args) > 0: message = message % args return message这个方法是用来获取提示信息的,由于提示信息非常多,而且一般每个模块有固定的提示信息,很少是各个模块共用的,所以在提示信息保存的xml里面,我增加了一个父标签
(MyInit.message_about(__file__, 'done_train')__file__这个变量存储了当前模块的文件名,这样在不同的模块里面,我们便可以统一使用这样的代码来进行调用,减少了错误的发生:
MyInit.message_about(__file__, 'start_train_file', file_path)上面是带参数的调用。
@classmethod def option(cls, opt_name): tag = cls.config.find("./options/option[@name=\'%s\']" % opt_name) value = eval(tag.get('type'))(tag.text) return value @classmethod def hyper_para(cls, para_name): tag = cls.config.find("./hyperparameters/hyperparameter[@name=\'%s\']" % para_name) value = eval(tag.get('type'))(tag.text) return value这两个都是获取全局变量的方法,其中option和hyperparameter并没有太严格的区分,我凭我个人感觉认为后者对神经网络的定义相关性更大。方法中值得一提的是,tag.text存的都是str变量,这样我们在使用的时候可能会面临变量类型的报错,所以我特意增加了一个type属性,然后使用Python的特殊方法eval执行,例如是int的变量,eval('int)('4')的效果就等同于int('4'),这样返回的全局变量变可以通过xml文件动态设置属性了。
@classmethod def _check_dirs(cls): for directory in cls.config.findall("./directories/directory"): path = cls.get_dir_path(directory.get('name')) if not os.path.isdir(path): os.mkdir(path)该方法是检查相关的目录是否存在,没有就创建,代码比较简单。
@classmethod def _init_logging(cls): formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s - %(message)s', '%Y %b %d %H:%M:%S') handlers = [ logging.FileHandler(cls.get_file_path('log_home', 'log_file', cls.option('train_name'))), logging.StreamHandler() ] logging.getLogger().setLevel(logging.INFO) for handler in handlers: handler.setFormatter(formatter) logging.getLogger().addHandler(handler)最后一个方法是logging的初始化,由于本文主要讲的是xml读取,这部分代码只是简单介绍一下。formatter是用来定义格式的,handlers里面存了需要输出的地方,本代码里面是logfile和屏幕,把它设置到getLogger()获取的root logger里面,就可以对所有的log执行了。
上述类就是用来初始化我的代码的,我做人工智能的编程基本上都是照搬这两个文件,然后修改一下xml配置文件的内容,非常方便。