上一篇文章的主要任务是记录了论文中超网络实现CL的思想,这篇文章主要任务是分析作者的代码,具体分析论文中超网络的三个部分的实现以及对比实验的算法实现。
超网络的代码分析和对比实验大致按照在论文中出现的先后顺序给出。
Toy Example
的工程目录为:https://github.com/chrhenning/hypercl/tree/master/toy_example
连续学习的任务数量:3个。分别是 1.一次函数,2.二次函数,3.三次函数。
代码路径:hypercl/toy_example/
定义的CL场景: 论文Continual learning with hypernetworks 中的场景1,即连续学习时提供任务 id 的场景。
在三个区间分别定义三个类型的数据,代码如下:
# 自定义数据的映射关系,输入输出之间的三种关系分别对应三个任务。
map_funcs = [lambda x : (x+3.), # 一次函数
lambda x : 2. * np.power(x, 2) - 1, # 二次函数
lambda x : np.power(x-3., 3)] # 三次函数
num_tasks = len(map_funcs) # 3个任务
x_domains = [[-4,-2], [-1,1], [2,4]] # 取相同长度自变量范围,保证三个任务的数据量相同。
std = .05 # 标志差设置为非0,以保证数据有有一定的波动性。
自定义的数据的可视化如下(三个颜色分别对应三个任务的数据):
$ python3 train.py --no_cuda --beta=0.005 --emb_size=2 --n_iter=4001 --lr_hyper=1e-2 --data_random_seed=42
参数说明如下(所有的参数在 /hypercl/toy_example/train_utils.py
中定义):
no_cuda
:不是使用显卡进行训练,数据、主模型、超网络模型、识别任务id模型(mnet, hnet, rnet )就不用加载到显卡内存中;beta
: β \beta β 用于调整对之前学习任务的正则化强度,就是论文中 L o s s t o t a l Loss_{total} Losstotal 计算中的 β \beta β , β \beta β 设置的值越大对之前任务学习到的权重越重视(约束越强),越不容易改变过去的权重;emb_size
:程序中使用 te_dim
表示,指的是任务嵌入向量(task embeddings)的维度;n_iter
:定义每个任务需要迭代的次数;lr_hyper
:超网络的学习率;data_random_seed
:随机生成训练数据时,设置的随机种子;每个任务各自迭代4000次之后,超网络连续学习结果的可视化如下:
图片解读:三个颜色表示连续学习的三个任务,细点表示是实际的函数值,粗点表示是预测的函数值。下图中三个任务的细点所在的线条和粗点所在的线条重合完好,表示接续学习的结果很好,没有发生明显的遗忘现象。
使用 MSE 表示预测与真实值之间的差距,MSE 越小越好。
使用 Current MSE 与 Immediate MSE 之间的差值表示遗忘程度,差值越小越好,负数比正数好。
主网络、超网络、任务id识别网络,三个网络分别在 main_model.py
,hyper_model.py
,task_recognition_model.py
三个文件中被定义,被train_utils.py
中的 _generate_networks()
函数使用 。
主网络mnet
: 网络结构由 train_utils.py
中的 --main_arch
定义参数,在main_model.py
中代码实现。
使用的MLP全连接层,默认的隐藏层是default='10,10'
。即网络结构是 1 10 10 1
的全连接层(一维数据的输入是1
维,输出也是1
维的)。因此可以计算出主网络的weight_shape = [[10, 1], [10], [10, 10], [10], [1, 10], [1]]
,一共有141个权重(含偏置)。
作者的网络的forward 过程是通过遍历weight_shape
中的内容实现的,关键的两行代码如下:
hidden = self._spec_norm(F.linear(hidden, W, bias=b))
# 相乘之后归一化hidden = self._a_fun(hidden)
# 默认使用relu
激活函数部分代码如下:
# main_model.py
# weights 中存放的就是超网络所有的权重。
hidden = x
for l in range(0, len(weights), step_size):
W = weights[l]
if self.has_bias:
b = weights[l+1]
else:
b = None
hidden = self._spec_norm(F.linear(hidden, W, bias=b))
# Only for hidden layers.
if l / step_size + 1 < num_layers:
if self._dropout_rate != -1:
hidden = self._dropout(hidden)
if self._a_fun is not None:
hidden = self._a_fun(hidden)
return hidden
训练过程中的日志如下(在main_model.py
中实现):
Creating an MLP with 141 weights. # 主网络的权重数
141 的可以通过weight_shape
计算得到:141 = 10x1 + 10 + 10x10 + 10 + 1x10 + 1
主网络的 init_weights
的权重和偏置的初始化方式通过在hypercl/utils/torch_utils.py
文件中的init_params
函数来定义,分别是
nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
# 默认的权重初始化方式nn.init.uniform_(bias, -bound, bound)
# 默认的偏置初始化方式超网路hnet
: 网络结构由 train_utils.py
中的 --hnet_arch
定义,默认是default='10,10'
。在hyper_model.py
中实现。
在本次的回归任务中,定义任务嵌入向量 te_dim
的维度是2,即超网络的输入维度是2,所以超网络(前半部分)的结构是: 2 10 10
,最后一层的输出是 10
个数值;
超网络(前半部分)的权重维度如下:
self._hidden_dims = [[10, 2], [10], [10, 10], [10]]
超网络需要输出(主网络的权重)数据的shape是:self.target_shapes = [[10, 1], [10], [10, 10], [10], [1, 10], [1]]
(解释:只有一个数字的[*]
表示偏置维度),所以超网络的输出需要输出的数据数量是 10x1 + 10 + 10x10 + 10 + 1x10 + 1
,这 141 个数据都要通过结构为2 10 10
的超网络输出,作者这里没有采用10
之后直接接141
的全连接方式输出,而是通过10
之后分别接10 10 100 10 10 1
的方式实现。
main_shapes = [[10, 1], [10], [10, 10], [10], [1, 10], [1]] # 主网络的结构就是超网络需要生成的目标
self.target_shapes
分析:
10 10 100 10 10 1
,最后一个10 与各个相连需要的权重(含偏置)个数是 (10
x10+10
)+(10
x10+10
)+(100
x10+100
)+(10
x10+10
)+(10
x10+10
)+(1
x10+1
) = 15512 10 10
)中间需要的权重个数:(10x2+10)+(10x10+10) = 140超网络后半部分结构的权重维度(在hyper_model.py
中定义)为:
self._out_dims = [[10, 10], [10], [10, 10], [10], [100, 10], [100], [10, 10], [10], [10, 10], [10], [1, 10], [1]]
最终的超网络的 Θ \Theta Θ 权重维度由这两部分(超网络的前半部分 和 生成主网络所需的各个部分)组成:
self._theta_shapes = self._hidden_dims + self._out_dims
此外,超网络还需要任务嵌入向量(固定的数值)作为输入,在该回归任务中,连续学习的任务的个数是 3 个,每个任务需要使用一个长度(te_dim
)是 2 的任务嵌入向量,因此任务嵌入向量一共需要 6 个数值(这 6 个值也是需要确定的参数)。
超网络网络结构所需的 Θ \Theta Θ 权重个数( 140 + 1551) + 任务嵌入向量(6个数值),一共需要的权重个数为 1691 + 6 = 1697
任务嵌入向量的生成代码如下(在hyper_model.py
中定义),可以看出所有任务嵌入向量(self._task_embs
)中的元素(data=torch.Tensor(te_dim)
)是随机生成的,最后存放在self._task_embs
中,确定了之后就不能被改变了。
# 任务嵌入向量的生成代码如下:
# Task embeddings.
if no_te_embs:
self._task_embs = None
else:
self._task_embs = nn.ParameterList()
for _ in range(num_tasks):
self._task_embs.append(nn.Parameter(data=torch.Tensor(te_dim),requires_grad=True))
torch.nn.init.normal_(self._task_embs[-1], mean=0., std=1.)
# te_dim 的值是 2
layers
:超网络自己的网络层定义为,layers = [10, 10];
te_dim
:超网络的任务嵌入向量(task embedding)的维度 2 ;
训练过程中的日志如下(与我计算出的结果1697
141
一致,说明我的理解是对的):
超网络的权重数 1697 个 :
Constructed hypernetwork with 1697 parameters (1691 network weights + 6 task embedding weights). # 超网络的权重数 1697
超网络的输出 141 个:
The hypernetwork has a total of 141 outputs. # 超网络的输出 141
Train the network using the task-specific loss plus a regularizer that should weaken catastrophic forgetting.
识别任务id的网络rnet
:
上面的1.1.6
大概理清楚了网络结构的定义、大小,这小节的主要目的是弄清三者之间是如何相互运作的。
根据 1.1.3
中命令中参数use_proximal_alg
默认设置为 True
, 表示 Proximal algorithm(近似算法). In this case, the optimal weight change is searched for via optimization rather than the actual weights. Note, in this case the options “use_sgd_change” and “backprop_dt” have no effect.
当use_proximal_alg
为 True
时: 使用近似算法进行优化,即不直接使用超网路计算超网络的权重,而是先通过下面的 loss \text{loss} loss 公式计算出 Δ θ \Delta\theta Δθ,之后计算出的 Δ θ + θ \Delta\theta + \theta Δθ+θ 就是超网络的权重。总 loss \text{loss} loss 的计算公式:
loss = task_loss ( θ + Δ θ ) + α ∥ Δ θ ∥ 2 + β ∗ ∑ j < task_id ∥ h ( c j , θ ) − h ( c j , θ + Δ θ ) ∥ 2 \text{loss} = \text{task\_loss}(\theta + \Delta\theta) + \alpha \lVert \Delta\theta \rVert^2 + \beta * \sum_{j < \text{task\_id}} \lVert h(c_j, \theta) - h(c_j, \theta + \Delta\theta) \rVert^2 loss=task_loss(θ+Δθ)+α∥Δθ∥2+β∗j<task_id∑∥h(cj,θ)−h(cj,θ+Δθ)∥2
θ + Δ θ \theta+\Delta\theta θ+Δθ :用于超网络加载的权重
对应的代码为(train.py
中):
loss = loss_task + config.alpha * l2_reg + config.beta * cl_reg
当use_proximal_alg
为 False
时: 使用 task-specific loss 加正则项,总loss的计算公式:
loss = task_loss + β ∗ regularizer \text{loss} = \text{task\_loss} + \beta * \text{regularizer} loss=task_loss+β∗regularizer
论文链接:使用超网络实现继续学习_论文:https://arxiv.org/abs/1906.00695
78
代码链接:使用超网络实现继续学习_代码:https://github.com/chrhenning/hypercl