FATE是微众银行开发的联邦学习平台,是全球首个工业级的联邦学习开源框架,在github上拥有近4000stars,可谓是相当有名气的,该平台为联邦学习提供了完整的生态和社区支持,为联邦学习初学者提供了很好的环境,否则利用python从零开发,那将会是一件非常痛苦的事情。本篇博客内容涉及《联邦学习实战》第十章内容,使用的fate版本为1.6.0,fate的安装已经在这篇博客中介绍,有需要的朋友可以点击查阅。下面就让我们开始吧。
随着算法的提升,大数据和硬件算力的发展,人工智能在视觉领域出现爆发性的增长,以目标检测为例,主要步骤如下:
但是传统的深度学习容易受到以下因素影响:
本案例对分散在各地的摄像头数据,通过联邦学习,构建一个联邦分布式训练网络,摄像头数据无需上传,便可以协同训练目标检测模型,这样一方面用户的隐私数据不会被泄露,另一方面,充分利用参与方的训练数据,提升机器学习视觉模型的识别效果。
当前常见的计算机视觉任务可以归纳为图像分类、语义分割、目标检测、实例分割,区别如下图所示。
本案例场景为典型的目标检测任务。本节简单回顾目标检测任务的算法步骤。
边界线: 描述目标位置,是一个矩形框,由左上角坐标 ( x 1 , y 1 ) (x_1,y_1) (x1,y1)和右下角坐标 ( x 2 , y 2 ) (x_2,y_2) (x2,y2)共同决定。
锚框: YOLO系列算法定义锚框来提取候选区域,锚框以每个像素为中心,生成多个大小宽高比不同的边界框集合。如下图所示
交并比: 当多个边界框覆盖了图像中物体,如果该物体的真实边界框已知,那么需要一个衡量预测边界框好坏的指标,在目标检测领域,使用交互比(IOU)衡量。
假设有两个边界框A和B,则A和B的IOU为二者的相交面积和相并面积的比值。
I O U ( A , B ) = A ∩ B A ∪ B IOU(A,B)=\frac{A\cap B}{A\cup B} IOU(A,B)=A∪BA∩B
基于候选区域的目标检测算法包括R-CNN、Fast R-CNN、Faster R-CNN等,这类算法在求解目标检测任务时,分为两个阶段:第一阶段先产生所有可能的目标候选框,第二阶段再对所有候选框做分类与回归。因此这类算法也被称为二阶段算法。
R-CNN:先对图像提取大约2000个候选区域,然后将候选区域输入到CNN网络中,提取每个候选框的特征数据,每个候选框的特征数据与其类别一起构成一个样本,训练多个支持向量机对目标分类,其中每个支持向量机用来判断样本是否属于同一个类别,利用每个候选框的特征数据与其边界框一起构成一个样本,用来训练线性回归模型,并预测真实的边界框。
Fast R-CNN:R-CNN的瓶颈在于,候选区域大量重叠,导致单独提取特征出现大量重复计算,所以Fast R-CNN先将图片输入CNN中,得到特征图,在特征图上进行候选区选取工作,并用softmax代替支持向量机,加快训练速度。由于每个候选区域大小不同,得到的特征向量长度不一,所以使用ROI池化将不同大小的输入转变为固定的大小长度。
Faster R-CNN:虽然Fast R-CNN相比R-CNN有了很大的提升,但是候选区域的提取与目标检测仍然是两个独立过程,因此,Faster R-CNN在此基础上,提出了候选区域网络(RPN),将候选区域的提取与目标检测作为同一个网络进行端到端的训练。
仅仅使用一个卷积神经网络直接预测不同目标的分类与位置,不需要预先选取候选区域,因此在效果上,基于区域的算法要比单阶段算法准确度高,但速度慢,相反,单阶段算法速度快,但准确性低,典型的单阶段算法包括SSD,YOLO系列。
以YOLO为例,不需要先找出所有的候选框,而是直接将图片输入到模型中,最后直接得到边界框的位置及物体的标签信息,并且它将边界框定位与目标分类都看成回归问题。这样做到端到端的处理,以Pascal VOC数据集为例,处理步骤如下:
对模型提供方和数据提供方来说,安全威胁是当前最为头疼和亟待解决的问题。安全威胁主要来自数据层面:
因此,急需一种新的模型训练方法:数据保证不离开本地,并且模型性能不能受到影响。这两点都非常适合联邦学习。
对于一个横向联邦学习实现的目标检测模型的工作流程,以本案为例,基本设置如下:
基于联邦学习的目标检测视觉模型对集中式模型的优势:
书中实现方法有基于Flask-SocketIO的python实现,也有基于FATE实现,这里主要介绍python实现过程。
Flask-SocketIO作为服务端和客户端之间的通信框架,可以轻松实现服务端和客户端的双向通信。
首先安装SocketIO库,只需在命令行中输入:
$ pip install flask-socketio
from flask import Flask, render_template
from flask_socketio import SocketIO
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)
if __name__=='__main__':
socketio.run(app)
socketio.run()是服务器启动的接口,通过封装app.run()实现。这段代码没有任何功能,为了能够相应用户请求,需要定义必要的函数。如下创建一个“my event”事件,代码如下:
from flask import Flask, render_template
from flask_socketio import SocketIO
app = Flask(__name__)
app.config['SECRET_KEY'] = 'secret!'
socketio = SocketIO(app)
@socketio.on('my event')
def test_message(message):
emit('my response', {'data':message['data']})
if __name__=='__main__':
socketio.run(app)
事件创建后,服务端等待客户发送“my event”请求,此外,socketIO是双向通信,所以服务端还能向客户端发送请求,用emit和send(命名事件用前者,未命名用后者)。
from socketIO_client import SocketIO
def test_response(data):
print(data)
sio = SocketIO('localhost', 5000, None)
sio.on("my_response", test_response)
sio.emit("my event")
sio.wait()
先用socketIO创建一个客户端,构造函数需要提供端口号和服务器IP,然后利用on连接事件“my_response”,以及处理函数“test_response”,发送“my event”事件,等待服务端事件响应。
服务端主体如下:
构建一个服务端类,在类结构的构造函数中,定义部分变量如下:
class FLServer(object):
def __init__(self, task_config_filename, host, port):
self.task_config = load_json(task_config_filename)
self.ready_client_sids = set()
self.app = Flask(__name__)
self.socketio = SocketIO(self.app, ping_timeout=3600000,
ping_interval=3600000,
max_http_buffer_size=int(1e32))
self.host = host
self.port = port
self.model_id = str(uuid.uuid4())
self.aggregator = Aggregator(self.task_config, self.logger)
...
self.register_handles()
相对于第3章的服务端设计,本章的服务端更为复杂,主要增加了socket通信的信息,一些字段解析如下:
构造函数之后是register_handles函数,用于事件注册,即响应客户端的请求。
def register_handles(self):
# single-threaded async, no need to lock
@self.socketio.on('connect')
def handle_connect():
print(request.sid, "connected")
@self.socketio.on('reconnect')
def handle_reconnect():
print(request.sid, "reconnected")
@self.socketio.on('disconnect')
def handle_disconnect():
print(request.sid, "disconnected")
if request.sid in self.ready_client_sids:
self.ready_client_sids.remove(request.sid)
@self.socketio.on('client_wake_up')
def handle_wake_up():
print("client wake_up: ", request.sid)
emit('init')
@self.socketio.on('client_ready')
def handle_client_ready():
print("client ready for training", request.sid)
self.ready_client_sids.add(request.sid)
if len(self.ready_client_sids) >= self.MIN_NUM_WORKERS and self.current_round == -1:
print("start to federated learning.....")
self.check_client_resource()
elif len(self.ready_client_sids) < self.MIN_NUM_WORKERS:
print("not enough client worker running.....")
else:
print("current_round is not equal to -1, please restart server.")
...
服务端创建完毕等待客户端发送信号,接收到客户端信号后,将它们全放置在候选列表ready_client_sids
中,每一轮训练会随机挑选部分客户端参与下一轮的迭代。
client_sids_selected = random.sample(list(self.ready_client_sids), self.NUM_CLIENTS_CONTACTED_PER_ROUND)
服务端另一个主要功能是进行模型聚合,如下是FedAvg的实现,我们将每轮上传的客户端模型参数放置到model_weights
中,选择本地样本数量占全体样本数量的比例作为模型参数的权重,求取新的全局模型参数值。
def update_weights(self, client_weights, client_sizes):
total_size = np.sum(client_sizes)
new_weights = [np.zeros(param.shape) for param in client_weights[0]]
for c in range(len(client_weights)):
for i in range(len(new_weights)):
new_weights[i] += (client_weights[c][i] * client_sizes[c]
/ total_size)
self.current_weights = new_weights
构造函数主体如下:
class FederatedClient(object):
MAX_DATASET_SIZE_KEPT = 6000
def __init__(self, server_host, server_port, task_config_filename,
gpu, ignore_load):
os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpu
self.task_config = load_json(task_config_filename)
# self.data_path = self.task_config['data_path']
print(self.task_config)
self.ignore_load = ignore_load
self.local_model = None
self.dataset = None
...
在联邦学习中,客户端与服务端是双向通信的,因此需要客户端注册相应的事件函数,用于响应服务端发送事件请求处理。
def register_handles(self):
########## Socket IO messaging ##########
def on_connect():
print('connect')
def on_disconnect():
print('disconnect')
def on_reconnect():
print('reconnect')
def on_request_update(*args):
...
self.sio.on('connect', on_connect)
self.sio.on('disconnect', on_disconnect)
self.sio.on('reconnect', on_reconnect)
self.sio.on('init', self.on_init)
self.sio.on('request_update', on_request_update)
self.sio.on('stop_and_eval', on_stop_and_eval)
self.sio.on('check_client_resource', on_check_client_resource)
on是一个接口函数,参数是事件名称和对应的响应函数。
客户端创建完毕后,等待服务端下发初始化命令,服务端会下发初始的全局模型和配置信息给客户端,客户端初始化主要是将本地模型替换全局模型,同时利用配置信息读取本地训练数据集。
def on_init(self, request):
print('on init')
self.local_model = LocalModel(self.task_config)
print("local model initialized done.")
# ready to be dispatched for training
self.sio.emit('client_ready')
客户端另一个重要环节是本地训练,通常情况和本地训练没有太大区别,这里不再赘述,感兴趣的朋友参考官方代码。
本章最后部分对两个模型在联邦学习中的性能进行了测试,分别测试了它们在不同数量客户参与方(C)以及不同本地训练迭代次数(E)配置下的性能对比,可以看到,参与方越多,其迭代收敛也越快(这是书中原话,但笔者认为并不绝对)。
下图是两个模型在损失值上的对比,可以得出:
本章内容涉及CV领域的目标检测内容,还是比较好理解的,只不过在运行代码的过程中,由于官方代码不全,导致运行不起来,实属遗憾,有时间一定斟酌一下,找到遗漏的的文件。然后FATE的实现文中并没有介绍,但是给了github链接,感兴趣的朋友可以复现一下,我也尽量能够出期FATE进行联邦目标检测实例的博客。接下来的第11章,FL在物联网的应用,应该还是理论居多,就让我们继续吧!
https://blog.csdn.net/tinyzhao/article/details/53729006
https://blog.csdn.net/tinyzhao/article/details/53742626
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter10_Computer_Vision