pytorch优化器传入两个网络参数

pytorch优化器传入两个网络的参数

方法一

使用字典传入,还可以指定不同学习率。

optimizer = torch.optim.Adam([
				{'params': model_one.parameters()},
				{'params': model_two.parameters(), 'lr': 1e-4}
				], lr)

方法二

由于params定义如下:

  • params (iterable) – iterable of parameters to optimize or dicts defining parameter groups

所以我们可以通过itertools.chain将两个网络参数连接。

import itertools
......
optimizer = torch.optim.Adam(itertools.chain(model_one.parameters(), model_two.parameters()), lr)

你可能感兴趣的:(pytorch,深度学习,python)