My federated learning notes on GitHub
The TFF platform is still fairly hard to use; just getting fluent with its APIs takes real effort. So I built this code by adapting this developer's code.
Splitting the MNIST dataset into IID client shards
import random
import numpy as np
from termcolor import colored
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from mnist_model import mnist_model
from utils import *


def generate_clients_data(num_examples_list_in_clients: list, num_clients=10, IsIID=True, batch_size=100, tt_rate=0.3):
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, y_train = shuffle_dataset(x_train, y_train)
    x_test, y_test = shuffle_dataset(x_test, y_test)
    x_train = x_train.astype('float32').reshape(-1, 28 * 28) / 255.0
    x_test = x_test.astype('float32').reshape(-1, 28 * 28) / 255.0
    # keep the integer labels around for the non-IID split below
    y_train_labels = y_train
    y_test = tf.one_hot(y_test, depth=10)
    y_train = tf.one_hot(y_train, depth=10)
    # a single entry means every client gets the same number of examples
    if len(num_examples_list_in_clients) == 1:
        num_examples_list_in_clients *= num_clients
    # carve a held-out test set for the server out of x_test/y_test
    client_dataset_test_size = int(sum(x * tt_rate for x in num_examples_list_in_clients))
    dataset_server_size = int(client_dataset_test_size * 0.3)
    server_test_x = x_test[client_dataset_test_size: client_dataset_test_size + dataset_server_size]
    server_test_y = y_test[client_dataset_test_size: client_dataset_test_size + dataset_server_size]
    server_dataset = tf.data.Dataset.from_tensor_slices((server_test_x, server_test_y)).batch(batch_size)
    x_test = x_test[:client_dataset_test_size]
    y_test = y_test[:client_dataset_test_size]
    if IsIID:
        print(colored('---------- IID = True ----------', 'green'))
        # slice the shuffled training set into per-client shards
        train_data_list = []
        start_train = 0
        for size in num_examples_list_in_clients:
            client_i_train_dataset = list(zip(x_train[start_train:start_train + size],
                                              y_train[start_train:start_train + size]))
            train_data_list.append(preprocess_client_data(client_i_train_dataset, bs=batch_size))
            start_train += size
        # slice the client test set the same way: tt_rate test examples per training example
        test_data_list = []
        start_test = 0
        for test_size in [int(x * tt_rate) for x in num_examples_list_in_clients]:
            client_i_test_dataset = list(zip(x_test[start_test:start_test + test_size],
                                             y_test[start_test:start_test + test_size]))
            test_data_list.append(preprocess_client_data(client_i_test_dataset, bs=batch_size))
            start_test += test_size
        return train_data_list, test_data_list, server_dataset
    else:
        # create num_clients non-IID clients, one digit class per client
        # (assumes num_clients equals the number of classes)
        print(colored('---------- IID = False ----------', 'green'))
        unique_labels = np.unique(y_train_labels)  # already sorted; shuffle for a random class order
        train_class = [None] * num_clients
        # group examples by integer label, truncated to the requested client size
        for (item, num_examples_client) in zip(unique_labels, num_examples_list_in_clients):
            idx = np.where(y_train_labels == item)[0][:num_examples_client]
            train_class[item] = list(zip(x_train[idx], tf.gather(y_train, idx)))
        clients_dataset_list = []
        for dataset in train_class:
            clients_dataset_list.append(preprocess_client_data(dataset, bs=batch_size))
        # note: the non-IID branch returns no per-client test sets
        return clients_dataset_list, server_dataset
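A quick sanity check of what the splitter returns (a sketch, using a uniform 1000 examples per client):

train_list, test_list, server_ds = generate_clients_data([1000], num_clients=10,
                                                         IsIID=True, batch_size=100, tt_rate=0.3)
print(len(train_list), len(test_list))   # 10 clients, 10 matching test shards
for x, y in train_list[0].take(1):
    print(x.shape, y.shape)              # (100, 784) (100, 10): one batch, one-hot labels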
This uses the MNIST model that ships with TFF. If you want to swap in your own model, remember to implement the functionality the rest of the code relies on.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow_federated.python.simulation.models import mnist


# build the MNIST Keras model that ships with TFF
def mnist_model(comp_model=False):
    return mnist.create_keras_model(compile_model=comp_model)
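If you do swap in your own model, anything exposed as an uncompiled tf.keras.Model with 784 inputs and 10 logit outputs will slot in, since the client and server only rely on get_weights/set_weights/compile/fit/evaluate. A hypothetical stand-in (my_mnist_model is my name, not part of the original code):

def my_mnist_model(comp_model=False):
    # hypothetical replacement with the same interface: 784 inputs -> 10 logits
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(784,)),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)  # raw logits: training uses from_logits=True
    ])
    if comp_model:
        model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
                      loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
    return model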
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from termcolor import colored
import random
from tensorflow_federated.python.simulation.models import mnist
from utils import *

'''
Client class: holds a local dataset and a local model, trains locally,
and reports its metrics back for aggregation.
'''
class client(object):
    def __init__(self,
                 local_dataset: dict,
                 client_name=0,
                 local_model=None):
        self.local_dataset = local_dataset
        self.client_name = client_name
        # build a fresh model per client; a default argument would be created
        # once and shared by every client instance
        self.local_model = (local_model if local_model is not None
                            else mnist.create_keras_model(compile_model=False))
        self.dataset_size = get_datasize(self.local_dataset['train']) + get_datasize(self.local_dataset['test'])
        self.val_acc_list = []
        self.val_loss_list = []
        self.local_model_size_list = [get_model_size(self.local_model)]

    def set_client_name(self, name):
        self.client_name = name

    def set_model_weights(self, model: tf.keras.Model):
        self.local_model.set_weights(model.get_weights())

    def set_local_dataset(self, dataset):
        self.local_dataset = dataset

    def client_train(self,
                     client_epochs=10,
                     model_loss=None,
                     model_optimizer=None,
                     model_metrics=['accuracy']):
        # create the loss/optimizer per call; a Keras optimizer holds
        # per-variable state and should not be shared across models
        self.local_model.compile(
            optimizer=model_optimizer or tf.keras.optimizers.SGD(learning_rate=0.1),
            loss=model_loss or tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics=model_metrics)
        # validation_freq=client_epochs: validate once, on the last local epoch
        client_train_history = self.local_model.fit(self.local_dataset['train'],
                                                    epochs=client_epochs,
                                                    validation_data=self.local_dataset['test'],
                                                    validation_freq=client_epochs,
                                                    verbose=0)
        self.val_acc_list.append(client_train_history.history['val_accuracy'][0])
        self.val_loss_list.append(client_train_history.history['val_loss'][0])
        self.local_model_size_list.append(get_model_size(self.local_model))

    def get_local_info(self):
        return {'client_name': self.client_name,
                'local_dataset_size': self.dataset_size,
                'client_model_size_history': self.local_model_size_list,
                'client_val_acc_history': self.val_acc_list,
                'client_val_loss_history': self.val_loss_list,
                'current_local_model_size': self.local_model_size_list[-1],
                'current_local_acc': self.val_acc_list[-1],
                'current_local_loss': self.val_loss_list[-1]}
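A standalone usage sketch, assuming train_ds and test_ds are hypothetical batched datasets built with preprocess_client_data:

c = client(local_dataset={'train': train_ds, 'test': test_ds},  # train_ds/test_ds: hypothetical
           client_name='client_0')
c.client_train(client_epochs=5)
print(c.get_local_info()['current_local_acc'])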
import tensorflow as tf
from tensorflow_federated.python.simulation.models import mnist
from client import client
from typing import List
from utils import *

type_client_list = List[client]


class server(object):
    def __init__(self,
                 server_name=0,
                 test_dataset=None,
                 server_model=None):
        self.server_name = server_name
        self.test_dataset = test_dataset
        # build a fresh model per instance; a default argument would be
        # created once and shared by every server object
        self.server_model = (server_model if server_model is not None
                             else mnist.create_keras_model(compile_model=False))
        # evaluate() below requires a compiled model
        self.server_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])
        self.ave_acc_list = []
        self.ave_loss_list = []

    def get_server_info(self):
        return {'server name': self.server_name,
                'server dataset size': get_datasize(self.test_dataset),
                'server model size': get_model_size(self.server_model),
                'server acc history': self.ave_acc_list,
                'server loss history': self.ave_loss_list}

    # aggregate the client models into a new global model (FedAvg)
    def calculate_server_model(self, client_list: type_client_list):
        # total number of examples across all clients
        sum_client_datasets = 0
        for client in client_list:
            sum_client_datasets += client.dataset_size
        # each client's aggregation weight is its share of the total data, n_k / n
        rate_client_dataset_size = []
        for client in client_list:
            rate_client_dataset_size.append(client.dataset_size / sum_client_datasets)
        # scale every client's layer weights by its factor
        clients_modelweight_list = []
        for (client, factor) in zip(client_list, rate_client_dataset_size):
            client_weights = client.local_model.get_weights()
            client_union_weights = [factor * layer for layer in client_weights]
            clients_modelweight_list.append(client_union_weights)
        # sum the scaled weights layer by layer
        new_weights = []
        for weights in zip(*clients_modelweight_list):
            new_weights.append(tf.reduce_sum(weights, axis=0))
        self.server_model.set_weights(new_weights)

    # broadcast the global model to every client
    def broadcast_server_model(self, client_list: type_client_list):
        for client in client_list:
            client.set_model_weights(self.server_model)

    # evaluate the global model on the server's held-out test set
    def server_model_test(self):
        server_loss, server_acc = self.server_model.evaluate(self.test_dataset, verbose=1)
        self.ave_loss_list.append(server_loss)
        self.ave_acc_list.append(server_acc)
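The aggregation above is plain FedAvg: the new global weights are the data-weighted average of the client weights, w = sum_k (n_k / n) * w_k, applied layer by layer. A minimal numpy check with two made-up clients:

import numpy as np

# two hypothetical clients with 100 and 300 examples -> factors 0.25 and 0.75
factors = [0.25, 0.75]
client_weights = [[np.ones((2, 2)), np.zeros(2)],       # client 0: [kernel, bias]
                  [np.full((2, 2), 2.0), np.ones(2)]]   # client 1
scaled = [[f * layer for layer in w] for f, w in zip(factors, client_weights)]
global_weights = [np.sum(layers, axis=0) for layers in zip(*scaled)]
print(global_weights[0])  # kernel: 0.25*1 + 0.75*2 = 1.75 everywhere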
import numpy as np
import tensorflow as tf


def preprocess_client_data(data, bs=100):
    x, y = zip(*data)
    return tf.data.Dataset.from_tensor_slices((list(x), list(y))).batch(bs)


'''
shuffle dataset
'''
def shuffle_dataset(datas, labels):
    shuffle_ix = np.random.permutation(np.arange(len(datas)))
    return datas[shuffle_ix], labels[shuffle_ix]


'''
get the size of a client model, in MB
'''
def get_model_size(model):
    para_num = sum([np.prod(w.shape) for w in model.get_weights()])
    # parameter count * 4 bytes each (float32) / 1024 / 1024 -> MB
    para_size = para_num * 4 / 1024 / 1024
    return para_size


'''
get the number of examples in a batched dataset
'''
def get_datasize(dataset):
    dataset_size = 0
    for x, y in dataset:
        # each element is an (x, y) batch; add the number of rows in the batch
        dataset_size += x.shape[0]
    return dataset_size


def summary_acc_loss(logdir, name, loss: list, acc: list):
    summary_writer = tf.summary.create_file_writer(logdir)
    # scalars must be written while this writer is the default one
    with summary_writer.as_default():
        for rnd in range(len(loss)):
            tf.summary.scalar('{}_acc'.format(name), acc[rnd], step=rnd)
            tf.summary.scalar('{}_loss'.format(name), loss[rnd], step=rnd)
    summary_writer.flush()
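A usage sketch of summary_acc_loss with made-up history values; the curves then show up under `tensorboard --logdir /tmp/logs/scalars`:

summary_acc_loss(logdir='/tmp/logs/scalars/FAVG/', name='server_0',
                 loss=[2.3, 1.1, 0.6], acc=[0.10, 0.62, 0.85])  # made-up values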
from mnist import generate_clients_data
from client import client
from server import server
from mnist_model import mnist_model
from typing import List
import matplotlib.pyplot as plt
from utils import *
import numpy as np
def draw(epoch_sumloss, epoch_acc):
    x = [i for i in range(len(epoch_sumloss))]
    # left y-axis: loss
    fig, ax1 = plt.subplots()
    color = 'red'
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('loss', color=color)
    ax1.plot(x, epoch_sumloss, color=color)
    ax1.tick_params(axis='y', labelcolor=color)
    # right y-axis: accuracy
    ax2 = ax1.twinx()
    color1 = 'blue'
    ax2.set_ylabel('acc', color=color1)
    ax2.plot(x, epoch_acc, color=color1)
    ax2.tick_params(axis='y', labelcolor=color1)
    fig.tight_layout()
    plt.show()
def FAVG_init():
    # per-client training-set sizes for the IID split
    num_examples_list_in_clients = [1000, 2000, 1500, 500, 3000, 1000, 1000, 2000, 1500, 2000]
    client_train_data_list, client_test_data_list, server_test_dataset = generate_clients_data(
        num_examples_list_in_clients,
        num_clients=10,
        IsIID=True,
        batch_size=100,
        tt_rate=0.3)
    # the server gets its own held-out test set and a fresh model
    server_0 = server(test_dataset=server_test_dataset, server_model=mnist_model(comp_model=False))
    # build the client list: each client gets its data shard and a fresh model
    clients_list = []
    client_name_list = ['client_{}'.format(i) for i in range(10)]
    for i in range(len(num_examples_list_in_clients)):
        client_data_dict = {'train': client_train_data_list[i], 'test': client_test_data_list[i]}
        clients_list.append(client(local_dataset=client_data_dict,
                                   client_name=client_name_list[i],
                                   local_model=mnist_model(comp_model=False)))
    return server_0, clients_list
# one communication round: local training on every client, then aggregate, broadcast, evaluate
def FAVG_train(server: server, clients_list: List[client], server_round: int, client_epochs: int):
    for i in range(server_round):
        for client_ in clients_list:
            client_.client_train(client_epochs=client_epochs)
            # print(client_.get_local_info(), '\n')
        server.calculate_server_model(clients_list)
        server.broadcast_server_model(clients_list)
        server.server_model_test()
    return server, clients_list
if __name__ == "__main__":
    server_0, clients_list = FAVG_init()
    server_0, clients_list = FAVG_train(server_0, clients_list, 200, 10)
    log_dir = "/tmp/logs/scalars/FAVG/"
    # name must be a string tag, not the server object itself
    summary_acc_loss(logdir=log_dir, name='server_0', loss=server_0.ave_loss_list, acc=server_0.ave_acc_list)
    # draw(server_0.ave_loss_list, server_0.ave_acc_list)
    np.savez('server_acc', server_0.ave_acc_list)
    np.savez('server_loss', server_0.ave_loss_list)
    for _client in clients_list:
        np.savez('{}_acc'.format(_client.client_name), _client.val_acc_list)
        np.savez('{}_loss'.format(_client.client_name), _client.val_loss_list)
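Each positional array handed to np.savez is stored under the key 'arr_0' in the resulting .npz file, so a short sketch to reload and plot the server curves after the run:

import numpy as np
import matplotlib.pyplot as plt

acc = np.load('server_acc.npz')['arr_0']    # positional np.savez arrays land under 'arr_0'
loss = np.load('server_loss.npz')['arr_0']
plt.plot(acc, label='server acc')
plt.plot(loss, label='server loss')
plt.xlabel('round')
plt.legend()
plt.show()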