流量预测之联邦学习

时间序列预测算法之联邦学习

介绍

设三个节点,其中一个中心节点,两个子节点,子节点利用LSTM模型训练,保证每个epoch完跟中心节点进行交互,完成参数融合

  1. 子节点部分代码
class LSTM(nn.Module):
    def __init__(self, input_size=2, hidden_size=4, output_size=1, num_layer=1):
        super(LSTM, self).__init__()
        self.layer1 = nn.LSTM(input_size, hidden_size, num_layer)
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.layer1(x)
        x = torch.relu(x)
        s, b, h = x.size()
        x = x.view(s * b, h)
        x = self.layer2(x)
        x = x.view(s, b, -1)
        return x

# 二、模型构建
model = LSTM(look_back, 4, 1, 2)
# print(model)
loss_fun = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
# s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# # 定义中心节点的地址端口号
# host = '127.0.0.1'
# port = 9999
# #建立链接
# s.connect((host, port))
# 三、开始训练
losses = list()
steps = list()
for epoch in range(1, EPOCH + 1):
    log("\033[1;31;40m第\033[1;31;40m%s\033[1;31;40m轮开始训练!\033[1;31;40m" % str(epoch))
    # 第一个网络
    for t in range(10):
        loss_t = list()
        # 前向传播
        out = model(var_x)
        loss = loss_fun(out, var_y)
        loss_t.append(loss.item())
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(sum(loss_t)/len(loss_t))
    steps.append(epoch)
    plt.plot(steps, losses, "o-")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.draw()
    plt.pause(0.1)
    log("建立连接并上传......")
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # 定义中心节点的地址端口号
    host = '127.0.0.1'
    port = 9999
    # 建立链接
    s.connect((host, port))
    # 序列化
    data = {}
    data['num'] = epoch
    data['model'] = model.state_dict()
    keys = model.state_dict().keys()
    data = pickle.dumps(data)
    print(s.send(data))
    # 等待待收
    log("等待接收......")
    try:
        s.settimeout(30000)
        data = s.recv(1024 * 100)
        # print(data)
        data = pickle.loads(data)
        print(data['num'], epoch)
        if data['num'] == epoch:
            global_state_dict = data['model']
        else:
            global_state_dict = model.state_dict()
    except Exception as e:
        print(e)
        # s.sendto(data, (host, port))
        log("没有在规定时间收到正确的包, 利用本地参数更新")
        global_state_dict = model.state_dict()

    # print(global_state_dict)
    # 重新加载全局参数
    model.load_state_dict(global_state_dict)
    s.close()
log("训练完毕,关闭连接")
s.close()

2.中心节点部分代码

def socket_udp_server():
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)  # SOCK_STREAM指类型是UDP
    host = '127.0.0.1'  # 监听指定的ip,host=''即监听所有的ip
    port = 9999
    # 绑定端口
    s.bind((host, port))
    # 开始监听:
    s.listen(5)  # param : 等待连接的最大数量
    print('waiting for connecting')

    res, addrs = [], []
    cnt = 1
    while True:
        log("第%d轮开始接收并计时" % cnt)
        try:
            s.settimeout(30000)
            start = time.time()
            # 接收操作
            sock, addr = s.accept() #accept会等待并发返回一个客户端的连接
            print(sock)
            data = sock.recv(1024*100)  # 接收来自客户端的数据,最大(1k),阻塞式等待
            print('Received from %s:%s' % addr)
            # print('Received data:', data)

            tmp = pickle.loads(data)
            print(tmp['num'], cnt)
            if tmp['num'] == cnt:
                addrs.append(sock)
                res.append(tmp['model'])
            # print(res)
            recv_time = time.time() - start
            print(len(res))
            if len(res) >= 2 or recv_time > 2000000:
                log("第%d轮接收完毕 接收来自%d个节点的参数" % (cnt, len(res)))
                # 处理操作
                log("开始融合处理操作......")
                # time.sleep(5)
                # res = str(sum(res))
                for m, n in zip(res[0].values(), res[1].values()):
                    if len(m.size()) == 1:
                        m1(m, n)
                    elif len(m.size()) == 2:
                        m2(m, n)
                # print(res[0])
                # res = pickle.dumps(res[0])
                data = {}
                data['num'] = cnt
                data['model'] = res[0]
                # 下发操作
                log('第%d轮融合完毕,下发......' % cnt)
                data = pickle.dumps(data)
                # print(data)
                for sock in (addrs):
                    sock.send(data)
                    sock.close()
                    # s.sendto(b'%s' % res.encode('utf-8'), addr)
                # else:
                #     res = '处理完毕,关闭连接'
                #     for addr in (addrs):
                #         s.sendto(b'%s' % res.encode('utf-8'), addr)
                #     break
                res, addrs = [], []
                cnt += 1
                if cnt > Epoch:
                    log('处理完毕,关闭连接')
                    break
            else:
                continue
        except:
            log("超时!!!")
            cnt += 1
    s.close()
  • 完整代码见FL节点流量预测

你可能感兴趣的:(流量预测之联邦学习,python,算法,机器学习,深度学习,socket)