module() takes at most 2 arguments (3 given)

出错代码

from torch.optim import optimizer

class ScaffoldOptimizer(optimizer):
    def __init__(self, params, lr):
        super(ScaffoldOptimizer, self).__init__(params, lr)
        self.lr = lr
        self.params = params

    def step(self, server_controls, client_controls):
        for k, v in self.params:
            # w = w - lr * (w.grad + c - ci)
            v.data = v.data - self.lr * (v.grad.data + server_controls[k] - client_controls[k])

解决

应该继承Optimizer:

from torch.optim import Optimizer

你可能感兴趣的:(PyTorch,pytorch,FedAvg,联邦学习)