客户端:
import torch
import fedml
from fedml import FedMLRunner
from fedml.data.MNIST.data_loader import download_mnist, load_partition_data_mnist
def load_data(args):
    """Download MNIST and partition it across the federated clients.

    Args:
        args: FedML run arguments. ``data_cache_dir``, ``dataset`` and
            ``batch_size`` are read; ``client_num_in_total`` is overwritten
            with the client count reported by the partitioner.

    Returns:
        ``(dataset, class_num)`` where ``dataset`` is the eight-item list
        ``[train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num]`` and ``class_num`` is repeated on its own so the caller can
        size the model's output layer.
    """
    cache_dir = args.data_cache_dir
    download_mnist(cache_dir)
    fedml.logging.info("load_data. dataset_name = %s" % args.dataset)

    batch_size = args.batch_size
    train_dir = cache_dir + "/MNIST/train"
    test_dir = cache_dir + "/MNIST/test"

    (
        client_num,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = load_partition_data_mnist(
        args, batch_size, train_path=train_dir, test_path=test_dir
    )

    # The partitioner decides how many clients exist; publish it back on args.
    args.client_num_in_total = client_num

    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]
    return dataset, class_num
class LogisticRegression(torch.nn.Module):
    """A single fully-connected layer followed by an element-wise sigmoid.

    Args:
        input_dim: number of input features per sample.
        output_dim: number of output units.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # (``linear.weight`` / ``linear.bias``), so it must not change.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``; every output lies in [0, 1]."""
        logits = self.linear(x)
        return logits.sigmoid()
if __name__ == "__main__":
    # Initialise FedML and obtain the run arguments, then choose a device.
    run_args = fedml.init()
    run_device = fedml.device.get_device(run_args)

    # The loader also returns the class count, which sets the model's output width.
    mnist_dataset, num_classes = load_data(run_args)
    # MNIST images are 28x28; this assumes they arrive flattened — confirm.
    lr_model = LogisticRegression(28 * 28, num_classes)

    FedMLRunner(run_args, run_device, mnist_dataset, lr_model).run()
服务器:
import fedml
import torch
from fedml import FedMLRunner
from fedml.cross_silo import Server
from fedml.data.MNIST.data_loader import download_mnist, load_partition_data_mnist
def load_data(args):
    """Download MNIST and partition it across the federated clients.

    Args:
        args: FedML run arguments. ``data_cache_dir``, ``dataset`` and
            ``batch_size`` are read; ``client_num_in_total`` is overwritten
            with the client count reported by the partitioner.

    Returns:
        ``(dataset, class_num)`` where ``dataset`` is the eight-item list
        ``[train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num]`` handed to ``FedMLRunner``.
    """
    download_mnist(args.data_cache_dir)
    fedml.logging.info("load_data. dataset_name = %s" % args.dataset)
    (
        client_num,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = load_partition_data_mnist(
        args,
        args.batch_size,
        train_path=args.data_cache_dir + "/MNIST/train",
        # The test split must come from the held-out MNIST/test files; reading
        # MNIST/train here would evaluate the model on its own training data.
        test_path=args.data_cache_dir + "/MNIST/test",
    )
    # The partitioner decides how many clients exist; publish it back on args.
    args.client_num_in_total = client_num
    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]
    return dataset, class_num
class LogisticRegression(torch.nn.Module):
    """Logistic-regression head: one linear map squashed through a sigmoid.

    Args:
        input_dim: size of each input feature vector.
        output_dim: number of outputs produced per sample.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Keep the name ``linear``: checkpoints are keyed as linear.weight / linear.bias.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer, then map each value into [0, 1] with a sigmoid."""
        return torch.sigmoid(self.linear(x))
if __name__ == "__main__":
    # Bootstrap FedML and resolve the device for this process.
    cfg = fedml.init()
    dev = fedml.device.get_device(cfg)

    # load_data returns the class count alongside the dataset list.
    data_bundle, n_outputs = load_data(cfg)
    # MNIST images are 28x28; this assumes they arrive flattened — confirm.
    net = LogisticRegression(28 * 28, n_outputs)

    runner = FedMLRunner(cfg, dev, data_bundle, net)
    runner.run()
torch_client
import torch
import fedml
from fedml import FedMLRunner
from fedml.data.MNIST.data_loader import download_mnist, load_partition_data_mnist
def load_data(args):
    """Download MNIST and partition it across the federated clients.

    Args:
        args: FedML run arguments. ``data_cache_dir``, ``dataset`` and
            ``batch_size`` are read; ``client_num_in_total`` is overwritten
            with the client count reported by the partitioner.

    Returns:
        ``(dataset, class_num)`` where ``dataset`` is the eight-item list
        ``[train_data_num, test_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
        class_num]`` handed to ``FedMLRunner``.
    """
    download_mnist(args.data_cache_dir)
    fedml.logging.info("load_data. dataset_name = %s" % args.dataset)
    (
        client_num,
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ) = load_partition_data_mnist(
        args,
        args.batch_size,
        train_path=args.data_cache_dir + "/MNIST/train",
        test_path=args.data_cache_dir + "/MNIST/test",
    )
    # The partitioner decides how many clients exist; publish it back on args.
    args.client_num_in_total = client_num
    # All eight entries are required and the order is positional: omitting the
    # global train/test loaders shifts every later field for the consumer.
    dataset = [
        train_data_num,
        test_data_num,
        train_data_global,
        test_data_global,
        train_data_local_num_dict,
        train_data_local_dict,
        test_data_local_dict,
        class_num,
    ]
    return dataset, class_num
class LogisticRegression(torch.nn.Module):
    """Linear layer whose outputs are passed through a sigmoid.

    Args:
        input_dim: input feature count.
        output_dim: output unit count.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # ``linear`` is part of the state_dict key names; do not rename it.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Compute ``sigmoid(W x + b)`` for a batch ``x``."""
        scores = self.linear(x)
        probs = torch.sigmoid(scores)
        return probs
if __name__ == "__main__":
    # Initialise FedML and obtain the run arguments.
    args = fedml.init()
    # fedml.device exposes ``get_device``; the misspelt ``grt_device`` raised
    # AttributeError before any training could start.
    device = fedml.device.get_device(args)
    # load_data also returns the class count, which sets the model's output width.
    dataset, output_dim = load_data(args)
    # MNIST images are 28x28; this assumes they arrive flattened — confirm.
    model = LogisticRegression(28 * 28, output_dim)
    fedml_runner = FedMLRunner(args, device, dataset, model)
    fedml_runner.run()
torch_server
import torch
import fedml
from fedml import FedMLRunner
from fedml.data.MNIST.data_loader import download_mnist, load_partition_data_mnist
def load_data(args):
    """Fetch MNIST into the cache directory and split it for federated training.

    Args:
        args: FedML run arguments. Reads ``data_cache_dir``, ``dataset`` and
            ``batch_size``; sets ``client_num_in_total`` to the number of
            clients produced by the partitioner.

    Returns:
        ``(dataset, class_num)``: ``dataset`` is an eight-item list holding,
        in order, the train/test sample counts, the global train/test
        loaders, the per-client train sample counts, the per-client train and
        test data, and the class count.
    """
    download_mnist(args.data_cache_dir)
    fedml.logging.info("load_data. dataset_name = %s" % args.dataset)

    (
        n_clients,
        n_train,
        n_test,
        train_global,
        test_global,
        train_local_counts,
        train_local,
        test_local,
        n_classes,
    ) = load_partition_data_mnist(
        args,
        args.batch_size,
        train_path=args.data_cache_dir + "/MNIST/train",
        test_path=args.data_cache_dir + "/MNIST/test",
    )
    args.client_num_in_total = n_clients

    return (
        [
            n_train,
            n_test,
            train_global,
            test_global,
            train_local_counts,
            train_local,
            test_local,
            n_classes,
        ],
        n_classes,
    )
class LogisticRegression(torch.nn.Module):
    """Sigmoid-activated linear model.

    Args:
        input_dim: width of each input row.
        output_dim: width of each output row.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Parameter names in the state_dict derive from this attribute name.
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return the sigmoid of the linear projection of ``x``."""
        return self.linear(x).sigmoid()
if __name__ == "__main__":
    # Set up FedML, pick the device, load the data and build the model.
    fl_args = fedml.init()
    fl_device = fedml.device.get_device(fl_args)
    fl_data, n_classes = load_data(fl_args)
    # MNIST images are 28x28; this assumes they arrive flattened — confirm.
    fl_model = LogisticRegression(28 * 28, n_classes)

    # Build the runner and start it.
    FedMLRunner(fl_args, fl_device, fl_data, fl_model).run()