TensorFlow 固定部分参数训练，只训练部分参数

    def var_filter(var_list, last_layers=(0,),
                   filter_keywords=('fine_tune', 'layer_11', 'layer_10',
                                    'layer_9', 'layer_8')):
        """Yield the variables whose name contains one of the selected keywords.

        Used to pick the subset of variables that should be trained while the
        remaining variables stay frozen.

        Args:
            var_list: iterable of variables; only each item's ``.name``
                attribute is read.
            last_layers: indices into ``filter_keywords`` selecting which
                keywords are active. Any iterable of ints works, including a
                one-shot iterator. With the default keyword list,
                ``range(n)`` selects ``'fine_tune'`` plus the first ``n - 1``
                of ``'layer_11'``, ``'layer_10'``, ``'layer_9'``, ``'layer_8'``.
            filter_keywords: candidate name substrings, indexed by
                ``last_layers``. Defaults to the original hard-coded list.

        Yields:
            Each variable from ``var_list``, in order and at most once, whose
            ``.name`` contains at least one selected keyword.

        Raises:
            IndexError: if an index in ``last_layers`` is out of range for
                ``filter_keywords`` (raised on the first ``next()``).
        """
        # Resolve the keywords once up front. This avoids re-indexing for every
        # variable, and it also works when ``last_layers`` is a one-shot
        # iterator. Re-iterating it per variable would exhaust it after the
        # first variable and silently skip all the others.
        selected = [filter_keywords[layer] for layer in last_layers]
        for var in var_list:
            # Plain substring match, as before. A variable is yielded at most
            # once even if several keywords match its name.
            if any(kw in var.name for kw in selected):
                yield var
                
    def set_optimizer(self, n):
        """Build ``self.train_op`` so that only a selected subset of variables is trained.

        The variables from ``tf.trainable_variables()`` are filtered through
        ``var_filter`` using keyword indices ``range(n)``. Only the filtered
        variables are passed to the optimizer as ``var_list``, so every other
        trainable variable is left unchanged by the resulting op.

        With ``var_filter``'s default keyword list in this file, ``n == 1``
        trains only ``'fine_tune'`` variables. Each further increment also
        unfreezes ``'layer_11'``, then ``'layer_10'``, and so on.

        NOTE(review): assumes ``self.optim`` is a TF1-style optimizer
        (``tf.compat.v1.train.Optimizer``) and that ``self.loss`` /
        ``self.global_step`` were created earlier. Confirm against the class.

        Args:
            n: number of leading keywords of ``var_filter`` to unfreeze.
        """
        candidates = tf.trainable_variables()
        selected_indices = range(n)
        chosen = list(var_filter(candidates, last_layers=selected_indices))
        self.train_op = self.optim.minimize(
            self.loss,
            global_step=self.global_step,
            var_list=chosen,
        )

你可能感兴趣的：TensorFlow