pytorch实现多个模型的weights平均和修改weights

文章目录

    • 1. 操作说明
    • 2.代码

1. 操作说明

有3个结构相同但是weights不同的model组成一个list,models=[model1,model2,model3],还有一个中心模型fl_model,这四个模型的结构和超参数都相同。

需要进行这样一种操作:平均models里面三个模型的weights,把平均之后的weights"赋值"给fl_model的weights。

2.代码

在tensorflow里可以直接用model.get_weights()和model.set_weights()来做,比较直观和方便。感觉pytorch里面稍微复杂一些。进行上述操作的代码如下:

worker_state_dict=[x.state_dict() for x in models]
weight_keys=list(worker_state_dict[0].keys())
fed_state_dict=collections.OrderedDict()
for key in weight_keys:
    key_sum=0
    for i in range(len(models)):
        key_sum=key_sum+worker_state_dict[i][key]
    fed_state_dict[key]=key_sum/len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)

你可能感兴趣的:(#)