How to Write a Federated Learning Training Framework: A PyTorch Implementation

Federated learning framework implementation

The federated learning training process consists of two parts: a server and clients.
Each client trains a model on its local data and uploads it to the server; the server aggregates the models uploaded by the clients and then distributes a new global model for the next round. The principle is simple, so let's start writing the code.
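Before building the real framework, the whole idea of one communication round can be boiled down to a few lines. The following is only a toy illustration with scalar "parameters", not the repo's code:

# Toy FedAvg round: the server adds the weighted sum of the client diffs to the global model.
global_model = {"w": 0.0}                              # current global parameter
client_diffs = [{"w": 0.3}, {"w": 0.1}, {"w": 0.2}]    # parameter diffs returned by 3 clients
lam = 1.0 / len(client_diffs)                          # equal weights -> plain averaging
for diff in client_diffs:
    global_model["w"] += lam * diff["w"]
print(global_model["w"])   # ~0.2, the average of the three client diffs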

1. The client:

The client side is simple: all we need to do is take the global model, train it on local data, and then return the difference between the new and old model parameters.
Here I recommend submitting parameter updates via named_parameters, because a PyTorch model's state_dict also contains many buffers that only hold temporary values, and there is little point in uploading all of them.
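To see the difference, compare named_parameters() with state_dict() on a small model containing BatchNorm. This snippet is only an illustration and is not part of the framework:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

# trainable parameters only: the convolution and BatchNorm weights and biases
print([name for name, _ in model.named_parameters()])
# ['0.weight', '0.bias', '1.weight', '1.bias']

# state_dict additionally contains buffers that only store temporary statistics,
# e.g. '1.running_mean', '1.running_var', '1.num_batches_tracked'
print(list(model.state_dict().keys()))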
We first create a dictionary pre_model = {} to save the previous model values. It serves two purposes:

  1. Save the previous model values so that the difference between the new and old parameters can be computed, i.e. the locally trained parameters minus the parameters before training.
  2. Restore the global model after local training.

Let me explain point 2. Since everything runs on a single machine, we never actually hand each client its own model. If we did, we would have to deepcopy the model for every client; each copy is not small, so with ten or twenty clients 16 GB of memory would already be full, and training would not be noticeably faster either. Instead, the clients and the server share a single model: after each client finishes its local training it restores the global model to the parameter values it had before training, and the server updates this shared global model only after all clients have finished the current round.

Now for the code. The model-training part is exactly the same as centralized machine-learning training; the only extra steps are resetting the model after training and returning the diff.

class Client(object):

	client_id = 0

	def __init__(
		self,
		batch_size,
		lr,
		momentum,
		model_parameter,
		local_epochs,
		model,
		train_dataset,
	) -> None:

		Client.client_id += 1
		self.client_id = Client.client_id
		self.batch_size = batch_size
		self.lr = lr
		self.momentum = momentum
		self.model_parameter = model_parameter
		self.local_epochs = local_epochs
		self.local_model = model
		self.train_dataset = train_dataset
		self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

	def local_train(self):

		# record the previous model parameters
		# 1. calculate diff
		# 2. restoring the global model
		pre_model = {}
		if self.model_parameter == "all":
			for name, param in self.local_model.state_dict().items():
				pre_model[name] = param.clone()
		else:
			for name, param in self.local_model.named_parameters():
				# detach so the saved copy (and the diff computed from it) stays out of the autograd graph
				pre_model[name] = param.detach().clone()

		optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.lr, momentum=self.momentum)
		epoch = self.local_epochs
		self.local_model.train()

		for _ in range(epoch):
			for _, batch in enumerate(self.train_loader):
				data, target = batch

				if torch.cuda.is_available():
					data = data.cuda()
					target = target.cuda()
			
				optimizer.zero_grad()
				output = self.local_model(data)
				loss = torch.nn.functional.cross_entropy(output, target)
				loss.backward()
				optimizer.step()
				
		print(f"{self.client_id} complete!")

		# record the differences between the local model and the global model
		diff = {}

		for name, param in pre_model.items():
			diff[name] = self.local_model.state_dict()[name] - param

		# restore the global model to its pre-training parameters
		# (assigning to state_dict()[name] would only modify the temporary dict, not the model)
		with torch.no_grad():
			for name, param in pre_model.items():
				self.local_model.state_dict()[name].copy_(param)

		return diff

This code will still be reworked later; the loss function and optimizer used here will eventually be rewritten so that they can be configured.
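As a rough idea of what that could look like (this is only my own sketch, not the repo's current code), the client could receive a loss function and an optimizer factory instead of hard-coding cross-entropy and SGD:

import torch

# assumption: these two objects would be passed into Client.__init__ in a later version
loss_fn = torch.nn.functional.cross_entropy                  # could be swapped for nll_loss, mse_loss, ...
optimizer_factory = lambda params: torch.optim.Adam(params, lr=1e-3)

# inside local_train() one would then write:
#     optimizer = optimizer_factory(self.local_model.parameters())
#     loss = loss_fn(output, target)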

2. The server:

The server consists of an aggregation part and a model-evaluation part. Aggregation is simple; so far only FedAvg is implemented, i.e. plain averaged aggregation. clients_diff is a list whose elements are the diff dictionaries returned by the clients; we iterate over this list, sum up every client's parameter differences, multiply the sum by a weight, and add it onto the global model to obtain the result of this round.
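Written out per parameter tensor, the update applied by model_aggregation is

    w_global = w_global + lamda * (diff_1 + diff_2 + ... + diff_k)

where the sum runs over the k selected clients; choosing lamda = 1/k makes this equivalent to plain averaging of the locally trained models.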

class Server(Model):
	
	def __init__(
		self,
		model_name,
		batch_size,
		lamda,
		eval_dataset
	):
		super().__init__(model_name, eval_dataset)
		self.global_model = self.model
		self.lamda = lamda
		self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

	def model_aggregation(self, clients_diff):

		weight_accumulator = {}
		for name, params in clients_diff[0].items():
			weight_accumulator[name] = torch.zeros_like(params)

		for _, client_diff in enumerate(clients_diff):
			for name, params in client_diff.items():
				weight_accumulator[name].add_(params)

		# apply the accumulated update, weighted by lamda, to the global model
		global_state = self.global_model.state_dict()
		for name, params in weight_accumulator.items():
			update_per_layer = params * self.lamda

			if global_state[name].type() != update_per_layer.type():
				global_state[name].add_(update_per_layer.to(torch.int64))
			else:
				global_state[name].add_(update_per_layer)

	def model_eval(self):
		self.global_model.eval()
		total_loss = 0.0
		correct = 0
		dataset_size = 0

		for batch_id, batch in enumerate(self.eval_loader):
			data, target = batch 
			dataset_size += data.size()[0]
			
			if torch.cuda.is_available():
				data = data.cuda()
				target = target.cuda()
				
			output = self.global_model(data)
			total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
			pred = output.data.max(1)[1]
			correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

		acc = 100.0 * (float(correct) / float(dataset_size))
		total_l = total_loss / dataset_size
		return acc, total_l

3. The main function

The main function simply instantiates one server and a group of clients. The most important part of the main function is the partitioning of the dataset; two methods are currently implemented: an even (IID) split and a Dirichlet split.
For details on the Dirichlet split, see: Dirichlet distribution.
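The distribution module belongs to the repo and is not shown in this post. As a rough sketch of what the two splits could look like (my own version, the repo's implementation may differ):

import numpy as np

def split_iid(n_clients, data_len):
    # shuffle all sample indices and cut them into n_clients roughly equal parts
    return np.array_split(np.random.permutation(data_len), n_clients)

def dirichlet_split_noniid(train_labels, alpha, n_clients):
    # for every class, draw a proportion vector over clients from Dir(alpha)
    # and hand that class's samples out according to those proportions
    train_labels = np.asarray(train_labels)
    n_classes = int(train_labels.max()) + 1
    label_distribution = np.random.dirichlet([alpha] * n_clients, n_classes)  # (n_classes, n_clients)

    client_idcs = [[] for _ in range(n_clients)]
    for c in range(n_classes):
        class_idcs = np.where(train_labels == c)[0]
        np.random.shuffle(class_idcs)
        # cumulative proportions -> split points inside this class
        split_points = (np.cumsum(label_distribution[c])[:-1] * len(class_idcs)).astype(int)
        for i, idcs in enumerate(np.split(class_idcs, split_points)):
            client_idcs[i].extend(idcs.tolist())

    return [np.array(idcs) for idcs in client_idcs]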

# Client and Server are the classes shown above; datasets and distribution are modules from the repo
import json
import random

from torch.utils.data import Subset

import datasets
import distribution

if __name__ == '__main__':

    # load the configure file
    with open('./conf.json', 'r') as f:
        conf = json.load(f)

    # load dataset
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["dataset"])

    server = Server(batch_size=conf["batch_size"],
                    lamda=conf["lambda"],
                    model_name=conf["model_name"],
                    eval_dataset=eval_datasets)

    # total clients array
    clients = []

    if conf["data_distribution"] == 'iid':
        n_clients = conf["num_models"]
        data_len = len(train_datasets)
        subset_indices = distribution.split_iid(n_clients, data_len)
        for idx in subset_indices:
            subset_dataset = Subset(train_datasets, idx)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))

    elif conf["data_distribution"] == 'dirichlet':
        n_clients = conf["num_models"]
        dirichlet_alpha = conf["dirichlet_alpha"]
        train_labels = train_datasets.targets
        # return an array: every client's index
        client_idcs = distribution.dirichlet_split_noniid(train_labels, alpha=dirichlet_alpha, n_clients=n_clients)

        for c, subset_indices in enumerate(client_idcs):
            subset_dataset = Subset(train_datasets, subset_indices)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))

    accuracy = []
    losses = []

    for e in range(conf["global_epochs"]):

        # randomly sample k clients for this round
        candidates = random.sample(clients, conf["k"])

        # clients_weight records the diff returned by every selected client
        clients_weight = []

        for _, c in enumerate(candidates):
            diff = c.local_train()
            clients_weight.append(diff)

        server.model_aggregation(clients_diff=clients_weight)

        acc, loss = server.model_eval()
        accuracy.append(acc)
        losses.append(loss)

        print(f"Epoch {e:d}, acc: {acc:f}, loss: {loss:f}\n")

4. Models and datasets

These two parts directly use the datasets and models that come with torch; if you want to use your own model and dataset, write them just as you normally would in PyTorch.
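For example, a minimal sketch of such helpers built on torchvision could look like the following (this is only an assumption about how the repo's datasets.get_dataset and model loading might be written; the actual code may differ):

import torch
from torchvision import datasets as tv_datasets, transforms, models

def get_dataset(data_dir, name):
    # download a torchvision dataset and return (train, eval) splits
    transform = transforms.Compose([transforms.ToTensor()])
    if name == "mnist":
        train = tv_datasets.MNIST(data_dir, train=True, download=True, transform=transform)
        test = tv_datasets.MNIST(data_dir, train=False, transform=transform)
    else:  # default to CIFAR-10
        train = tv_datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
        test = tv_datasets.CIFAR10(data_dir, train=False, transform=transform)
    return train, test

def get_model(num_classes=10):
    # pick one of torchvision's built-in architectures (here ResNet-18)
    model = models.resnet18(num_classes=num_classes)
    return model.cuda() if torch.cuda.is_available() else model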

The code is available at: federated learning code framework. If this post helped you, a like, favorite, and follow would be much appreciated~
