[pysyft-006]联邦学习pysyft从入门到精通--使用protocol

import torch as th
import syft as sy


'''
Part 8 bis - Introduction to Protocols
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2008%20bis%20-%20Introduction%20to%20Protocols.ipynb
'''

'''
本例演示protocol
一个protocol是若干个plan的集合,这些plan可以分布在不同的远程节点上。protocol提供更多的运算符号
'''

hook = sy.TorchHook(th)
hook.local_worker.is_client_worker = False


#定义三个plan
@sy.func2plan(args_shape=[(1,)])
def inc1(x):
        return x + 1

@sy.func2plan(args_shape=[(1,)])
def inc2(x):
    return x + 1

@sy.func2plan(args_shape=[(1,)])
def inc3(x):
    return x + 1
            
#把三个plan组装成一个protocol
protocol = sy.Protocol([("worker1", inc1), ("worker2", inc2), ("worker3", inc3)])

#创建虚拟节点
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
charlie = sy.VirtualWorker(hook, id="charlie")
workers = alice, bob, charlie

#把protocol部署到三个虚拟节点上
protocol.deploy(*workers)

#计算protocal
x = th.tensor([1.0])
ptr = protocol.run(x)
#打印结果
print(ptr.get())

#创建虚拟节点,把protocol发送到虚拟节点上
james = sy.VirtualWorker(hook, id="james")
protocol.send(james)

#在james节点上计算protocol
x = th.tensor([1.0]).send(james)
ptr = protocol.run(x)

#输出计算结果
ptr = ptr.get()
print(ptr)
ptr = ptr.get()
print(ptr)

 

你可能感兴趣的:([pysyft-006]联邦学习pysyft从入门到精通--使用protocol)