实体命名识别详解(十四)

        # Generic functions that add training op and initialize session
        self.add_train_op(self.config.lr_method, self.lr, self.loss,
                self.config.clip)
        self.initialize_session() # now self.sess is defined and vars are init

接下来是俩通用函数,可以看到,一个是训练用train,一个初始化session

    def add_train_op(self, lr_method, lr, loss, clip=-1):
        """Defines self.train_op that performs an update on a batch

        Args:
            lr_method: (string) sgd method, for example "adam"
            lr: (tf.placeholder) tf.float32, learning rate
            loss: (tensor) tf.float32 loss to minimize
            clip: (python float) clipping of gradient. If < 0, no clipping

        """
        _lr_m = lr_method.lower() # lower to make sure

        with tf.variable_scope("train_step"):
            if _lr_m == 'adam': # sgd method
                optimizer = tf.train.AdamOptimizer(lr)
            elif _lr_m == 'adagrad':
                optimizer = tf.train.AdagradOptimizer(lr)
            elif _lr_m == 'sgd':
                optimizer = tf.train.GradientDescentOptimizer(lr)
            elif _lr_m == 'rmsprop':
                optimizer = tf.train.RMSPropOptimizer(lr)
            else:
                raise NotImplementedError("Unknown method {}".format(_lr_m))

            if clip > 0: # gradient clipping if clip is positive
                grads, vs     = zip(*optimizer.compute_gradients(loss))
                grads, gnorm  = tf.clip_by_global_norm(grads, clip)
                self.train_op = optimizer.apply_gradients(zip(grads, vs))
            else:
                self.train_op = optimizer.minimize(loss)

这里传入的参数是self、优化函数(lr_method)、学习率(learning_rate)、损失函数(loss)、clip(目前不知道是干什么用的)
先看函数体介绍,,定义self.train_op,这应该是一个变量什么的⑧,对每一个batch进行更新操作,传入四个参数:
lr_method学习方法,通常我们用Adam优化器。
lr学习率。
loss将损失值降到最低(loss to minimize)
clip进行梯度裁剪,如果小于0的话不进行梯度裁剪,稳重我们将clip设置默认值为-1,也就是说不进行裁剪。
看函数体:
首先使用lower()函数,将传入的lr_method全部转换为小写字母,然后在train_step命名域中进行优化器的选择,并将学习率lr传入其中。最后执行优化函数optimizer.minimize(loss)。

    def initialize_session(self):
        """Defines self.sess and initialize the variables"""
        self.logger.info("Initializing tf session")
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver()

初始化session函数,这几个通用操作无需多言了吧。
OK,总的来说,model.build中的这些个初始化操作就初步分析完了。。。心好累P

你可能感兴趣的:(实体命名识别详解(十四))