Meta-Learning Series
This post walks through the MAML source code. The authors open-sourced the code for the paper, but it contains very few comments and the control flow is hard to follow at first, so the key parts are annotated here. The core is the construct_model()
function, which builds MAML's entire training process; reading the implementation makes the authors' idea much clearer.
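As a refresher, construct_model() builds the two nested updates from the MAML paper: an inner loop that adapts the shared initialization \(\theta\) to each task on its support set, and an outer loop that updates \(\theta\) using the post-adaptation loss on the query set:

$$\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)$$
$$\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})$$

In the code below, \(\alpha\) is FLAGS.update_lr, \(\beta\) is FLAGS.meta_lr, the inner update produces fast_weights, and the outer update (implemented with Adam rather than plain SGD) is metatrain_op.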
maml.py
""" Code for the MAML algorithm and network definitions. """
from __future__ import print_function
import numpy as np
import sys
import tensorflow as tf
try:
import special_grads
except KeyError as e:
print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
file=sys.stderr)
from tensorflow.python.platform import flags
from utils import mse, xent, conv_block, normalize
FLAGS = flags.FLAGS
class MAML:
def __init__(self, dim_input=1, dim_output=1, test_num_updates=5):
""" must call construct_model() after initializing MAML! """
self.dim_input = dim_input
self.dim_output = dim_output
self.update_lr = FLAGS.update_lr
self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
self.classification = False
self.test_num_updates = test_num_updates
if FLAGS.datasource == 'sinusoid':
self.dim_hidden = [40, 40]
self.loss_func = mse
self.forward = self.forward_fc
self.construct_weights = self.construct_fc_weights
elif FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'miniimagenet':
self.loss_func = xent
self.classification = True
if FLAGS.conv:
self.dim_hidden = FLAGS.num_filters
self.forward = self.forward_conv
self.construct_weights = self.construct_conv_weights
else:
self.dim_hidden = [256, 128, 64, 64]
                self.forward = self.forward_fc
self.construct_weights = self.construct_fc_weights
if FLAGS.datasource == 'miniimagenet':
self.channels = 3
else:
self.channels = 1
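            # inputs are flattened images, e.g. Omniglot 28*28*1 = 784 -> img_size 28, miniImagenet 84*84*3 = 21168 -> img_size 84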
self.img_size = int(np.sqrt(self.dim_input/self.channels))
else:
raise ValueError('Unrecognized data source.')
    # ************************* Builds the training graph; this function is the core of the code *************************
def construct_model(self, input_tensors=None, prefix='metatrain_'):
# a: training data for inner gradient, b: test data for meta gradient
if input_tensors is None:
self.inputa = tf.placeholder(tf.float32)
self.inputb = tf.placeholder(tf.float32)
self.labela = tf.placeholder(tf.float32)
self.labelb = tf.placeholder(tf.float32)
else:
self.inputa = input_tensors['inputa']
self.inputb = input_tensors['inputb']
self.labela = input_tensors['labela']
self.labelb = input_tensors['labelb']
        # Computation graph for the training process
with tf.variable_scope('model', reuse=None) as training_scope:
            # If the training graph has been built before, self.weights already exists, and all tasks share that same set of weights
if 'weights' in dir(self):
training_scope.reuse_variables()
weights = self.weights
else:
# Define the weights
                # First construction: 'weights' is not yet in dir(self), so initialize the weights here
self.weights = weights = self.construct_weights()
# outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
lossesa, outputas, lossesb, outputbs = [], [], [], []
accuraciesa, accuraciesb = [], []
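            # test_num_updates may exceed FLAGS.num_updates, so size the graph for the larger of the two; evaluation can then run more inner steps than training did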
num_updates = max(self.test_num_updates, FLAGS.num_updates)
outputbs = [[]]*num_updates
lossesb = [[]]*num_updates
accuraciesb = [[]]*num_updates
def task_metalearn(inp, reuse=True):
""" Perform gradient descent for one task in the meta-batch.
meta batch 个 task,并行执行 task_metalearn, 每个 task_metalearn 处理一个具体 task 的训练任务
"""
inputa, inputb, labela, labelb = inp
task_outputbs, task_lossesb = [], []
if self.classification:
task_accuraciesb = []
# inputa: [inner_batch, 1], task_outputa: [inner_batch, 1]
task_outputa = self.forward(inputa, weights, reuse=reuse) # only reuse on the first iter
task_lossa = self.loss_func(task_outputa, labela)
grads = tf.gradients(task_lossa, list(weights.values()))
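                # FLAGS.stop_grad treats the inner-loop gradients as constants, dropping the second-order terms from the meta-gradient (first-order MAML approximation)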
if FLAGS.stop_grad:
grads = [tf.stop_gradient(grad) for grad in grads]
                # gradients keyed by weight name: {'w1': g1, 'w2': g2, ...}
gradients = dict(zip(weights.keys(), grads))
                # one inner-loop SGD step: w1 <- w1 - α*g1, w2 <- w2 - α*g2, ... (α = update_lr); the result is the task-specific fast_weights
fast_weights = dict(zip(weights.keys(), [weights[key] - self.update_lr*gradients[key] for key in weights.keys()]))
                # Forward pass on inputb (the task's held-out data) with the updated fast_weights
output = self.forward(inputb, fast_weights, reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output, labelb))
                # The inner loop runs num_updates steps in total; one step was done above, so num_updates-1 remain
for j in range(num_updates - 1):
loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
grads = tf.gradients(loss, list(fast_weights.values()))
if FLAGS.stop_grad:
grads = [tf.stop_gradient(grad) for grad in grads]
gradients = dict(zip(fast_weights.keys(), grads))
fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.update_lr*gradients[key] for key in fast_weights.keys()]))
output = self.forward(inputb, fast_weights, reuse=True)
task_outputbs.append(output)
task_lossesb.append(self.loss_func(output, labelb))
                # inputa is the task's training (support) set, inputb is the task's test (query) set
                # task_outputa is the output of the initial weights on inputa; task_lossa is the loss computed from task_outputa under those weights
                # task_outputbs are the outputs on inputb after each gradient update; task_lossesb are the corresponding losses
task_output = [task_outputa, task_outputbs, task_lossa, task_lossesb]
if self.classification:
task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1), tf.argmax(labela, 1))
for j in range(num_updates):
task_accuraciesb.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1), tf.argmax(labelb, 1)))
task_output.extend([task_accuracya, task_accuraciesb])
return task_output
            if FLAGS.norm != 'None':
# to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)
# [task_outputa, task_outputbs, task_lossa, task_lossesb]
out_dtype = [tf.float32, [tf.float32]*num_updates, tf.float32, [tf.float32]*num_updates]
if self.classification:
out_dtype.extend([tf.float32, [tf.float32]*num_updates])
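            # map_fn slices the leading (task) dimension of inputa/inputb/labela/labelb and runs task_metalearn once per task, up to meta_batch_size tasks in parallel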
result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb), dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
if self.classification:
outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
else:
outputas, outputbs, lossesa, lossesb = result
## Performance & Optimization
        ## Aggregate the per-task losses into the meta-training objectives
if 'train' in prefix:
            # lossesa holds, for each of the meta_batch_size tasks, the pre-update (first forward pass) loss on inputa
self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            # lossesb[j] holds each task's loss on inputb after j+1 inner updates
self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
# after the map_fn
self.outputas, self.outputbs = outputas, outputbs
if self.classification:
self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
self.total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            # pretrain_op minimizes the summed pre-update loss on inputa; this corresponds to ordinary pre-training, as in transfer learning
self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)
            # Meta-training: the objective is the average loss on inputb after the final inner update
if FLAGS.metatrain_iterations > 0:
optimizer = tf.train.AdamOptimizer(self.meta_lr)
                # The meta-objective is total_losses2[num_updates-1]: the mean over tasks of the loss after the last inner update
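                # Differentiating this loss backpropagates through the inner-loop updates, so the meta-gradient contains second-order terms unless FLAGS.stop_grad is set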
self.gvs = gvs = optimizer.compute_gradients(self.total_losses2[FLAGS.num_updates-1])
if FLAGS.datasource == 'miniimagenet':
gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
self.metatrain_op = optimizer.apply_gradients(gvs)
else:
self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
if self.classification:
self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.metaval_total_accuracies2 = total_accuracies2 = [tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
## Summaries
        # total_loss1 is the average pre-update loss across tasks (i.e. before any inner-loop adaptation)
tf.summary.scalar(prefix+'Pre-update loss', total_loss1)
if self.classification:
tf.summary.scalar(prefix+'Pre-update accuracy', total_accuracy1)
for j in range(num_updates):
tf.summary.scalar(prefix+'Post-update loss, step ' + str(j+1), total_losses2[j])
if self.classification:
tf.summary.scalar(prefix+'Post-update accuracy, step ' + str(j+1), total_accuracies2[j])
### Network construction functions (fc networks and conv networks)
    # Build the weights of the fully connected network
def construct_fc_weights(self):
weights = {}
weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
for i in range(1,len(self.dim_hidden)):
weights['w'+str(i+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[i-1], self.dim_hidden[i]], stddev=0.01))
weights['b'+str(i+1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
weights['w'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
weights['b'+str(len(self.dim_hidden)+1)] = tf.Variable(tf.zeros([self.dim_output]))
return weights
    # Forward pass of the fully connected network
def forward_fc(self, inp, weights, reuse=False):
hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
for i in range(1,len(self.dim_hidden)):
hidden = normalize(tf.matmul(hidden, weights['w'+str(i+1)]) + weights['b'+str(i+1)], activation=tf.nn.relu, reuse=reuse, scope=str(i+1))
return tf.matmul(hidden, weights['w'+str(len(self.dim_hidden)+1)]) + weights['b'+str(len(self.dim_hidden)+1)]
    # Build the weights of the convolutional network
def construct_conv_weights(self):
weights = {}
dtype = tf.float32
conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
k = 3
weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden], initializer=conv_initializer, dtype=dtype)
weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
if FLAGS.datasource == 'miniimagenet':
# assumes max pooling
weights['w5'] = tf.get_variable('w5', [self.dim_hidden*5*5, self.dim_output], initializer=fc_initializer)
weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
else:
weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
return weights
    # Forward pass of the convolutional network
def forward_conv(self, inp, weights, reuse=False, scope=''):
# reuse is for the normalization parameters.
channels = self.channels
inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])
hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope+'0')
hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope+'1')
hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope+'2')
hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope+'3')
if FLAGS.datasource == 'miniimagenet':
# last hidden layer is 6x6x64-ish, reshape to a vector
hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
else:
hidden4 = tf.reduce_mean(hidden4, [1, 2])
return tf.matmul(hidden4, weights['w5']) + weights['b5']
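To make the graph construction above concrete, here is a minimal usage sketch (not part of the repository) that builds the sinusoid model and runs one meta-training step. The flag definitions normally live in main.py; the values below are assumptions chosen to mirror its defaults, and the random sine data stands in for the repo's data_generator.

""" Minimal usage sketch; assumes the flags below are normally defined in main.py. """
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import flags
from maml import MAML

FLAGS = flags.FLAGS
flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid, omniglot, or miniimagenet')
flags.DEFINE_string('norm', 'None', 'batch_norm, layer_norm, or None')
flags.DEFINE_float('update_lr', 1e-3, 'inner-loop step size (alpha)')
flags.DEFINE_float('meta_lr', 1e-3, 'outer-loop step size (beta)')
flags.DEFINE_integer('num_updates', 1, 'inner gradient steps during meta-training')
flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks per meta-update')
flags.DEFINE_integer('metatrain_iterations', 1000, 'must be > 0 so metatrain_op is built')
flags.DEFINE_bool('stop_grad', False, 'use the first-order approximation if True')

model = MAML(dim_input=1, dim_output=1, test_num_updates=1)
model.construct_model(prefix='metatrain_')  # inputs default to placeholders

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # one meta-batch of toy regression tasks: [meta_batch_size, K, dim_input]
    xa = np.random.uniform(-5., 5., [25, 10, 1]).astype(np.float32)
    xb = np.random.uniform(-5., 5., [25, 10, 1]).astype(np.float32)
    feed = {model.inputa: xa, model.labela: np.sin(xa),
            model.inputb: xb, model.labelb: np.sin(xb)}
    # metatrain_op applies one outer (meta) update; total_losses2[-1] is the
    # post-update loss on inputb averaged over the meta-batch
    _, loss = sess.run([model.metatrain_op, model.total_losses2[-1]], feed_dict=feed)
    print('post-update meta loss:', loss)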