有3个结构相同但是weights不同的model组成一个list,models=[model1,model2,model3],还有一个中心模型fl_model,这四个模型的结构和超参数都相同。
需要进行这样一种操作:平均models里面三个模型的weights,把平均之后的weights"赋值"给fl_model的weights。
在tensorflow里可以直接用model.get_weights()和model.set_weights()来做,比较直观和方便。感觉pytorch里面稍微复杂一些。进行上述操作的代码如下:
worker_state_dict=[x.state_dict() for x in models]
weight_keys=list(worker_state_dict[0].keys())
fed_state_dict=collections.OrderedDict()
for key in weight_keys:
key_sum=0
for i in range(len(models)):
key_sum=key_sum+worker_state_dict[i][key]
fed_state_dict[key]=key_sum/len(models)
#### update fed weights to fl model
fl_model.load_state_dict(fed_state_dict)