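// websocket 中转 protobuf 请求：从小程序接收 protobuf 请求并转发到 TCP 服务，再把 TCP 服务的响应转发回小程序。
// WebSocket relay for protobuf requests: forwards protobuf requests received from the mini-program
// to a TCP service, and forwards the TCP service's responses back to the mini-program.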

package main

import (
    "flag"
    "net"
    "github.com/BurntSushi/toml"
    "github.com/gorilla/websocket"
    "log"
    "net/http"
    "runtime"
    "encoding/binary"
    "zonst/logging"
    "zonst/qipai/fsnotifyutil"
    "zonst/qipai/messages"
)

var (
	// listenAddr is the address the HTTP/WebSocket server listens on (flag -listenAddr).
	listenAddr string
	// tomlFile is the path of the TOML configuration file (flag -tomlFile).
	tomlFile string
)

// var pid int

// init registers the command-line flags (parsed later in main) and sets
// GOMAXPROCS to the number of CPUs.
func init() {
	flag.StringVar(&tomlFile, "tomlFile", "docs/test.toml", "server TOML config file")
	flag.StringVar(&listenAddr, "listenAddr", ":10000", "server listen addr")
	runtime.GOMAXPROCS(runtime.NumCPU())
}

// upgrader converts incoming HTTP requests into WebSocket connections.
// Its origin check is disabled: every handshake is approved whatever its
// Origin header says.
// NOTE(review): accepting any Origin leaves browser clients open to
// cross-site WebSocket hijacking; restrict this if browsers can reach it.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}

// echo is the HTTP handler for "/". It performs these steps:
//   - upgrades the request to a WebSocket connection;
//   - dials the backend TCP service that handles the protobuf protocol;
//   - starts one goroutine per direction to relay traffic (ReadWebsocket for
//     client -> TCP, ReadTcp for TCP -> client).
//
// If either step fails, nothing is started and any connection that was
// already opened is closed.
func (c *Context) echo(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error; conn
		// is nil here, so the relay goroutines must not be started.
		logging.Errorf("websocket upgrade err : %v \n", err)
		return
	}

	// Address of the TCP service that handles the protobuf protocol.
	// NOTE(review): hard-coded placeholder address; consider moving it into
	// the TOML config.
	tcpConn, err := net.Dial("tcp", "xxx.xx.xx.xxx:10000")
	if err != nil {
		logging.Errorf("connect tcp err : %v \n", err)
		// Without a backend there is nothing to relay; don't leak the socket.
		conn.Close()
		return
	}

	logging.Debugf("connect tcp server success %v \n", tcpConn)

	go c.ReadWebsocket(conn, tcpConn)
	go c.ReadTcp(conn, tcpConn)
}

// ReadWebsocket relays binary frames from the client WebSocket connection to
// the backend TCP connection until one side fails.
//
// Each binary frame must start with an 8-byte header; byte 4 of it is read as
// the protobuf message ID. Frames are written to tcpConn unchanged, header
// included. The one exception is ProForceUserOfflineRequest, which is logged
// but not forwarded. Frames shorter than the header are dropped.
//
// The relay stops on any of the following:
//   - a read error, including a close from the client;
//   - a non-binary frame;
//   - a TCP write error.
//
// On return both connections are closed, so the ReadTcp goroutine for the
// same session is unblocked as well.
func (c *Context) ReadWebsocket(conn *websocket.Conn, tcpConn net.Conn) {
	// Size of the header that precedes the protobuf payload in each frame.
	const headerLen = 8

	defer conn.Close()
	defer tcpConn.Close()

	for {
		// mtype: TextMessage=1 / BinaryMessage=2 / CloseMessage=8 / PingMessage=9 / PongMessage=10
		mtype, msg, err := conn.ReadMessage()
		logging.Debugf("mtype %v \n", mtype)
		logging.Debugf("msg %v \n", msg)
		if err != nil {
			logging.Debugf("websocket read err : %v \n", err)
			return
		}
		if mtype != websocket.BinaryMessage {
			// Only binary protobuf frames are supported; anything else ends the session.
			return
		}

		logging.Debugf("二进制包长度: %v \n", binary.Size(msg))
		if len(msg) < headerLen {
			// Indexing into a short frame would panic inside this goroutine
			// and take down the whole process, so drop the frame instead.
			logging.Errorf("binary frame too short: %d bytes, want at least %d\n", len(msg), headerLen)
			continue
		}

		// protobuf message ID.
		// NOTE(review): only a single header byte is read; IDs that need more
		// than one byte would be truncated - confirm against the frame layout.
		protoMsgID := int32(msg[4])
		logging.Debugf("msg content : %v \n", string(msg))

		switch protoMsgID {
		case int32(messages.ProConnectRequest_ID):
			logging.Debugf("开始连接 : %v \n", protoMsgID)
		case int32(messages.ProHeartBeatRequest_ID):
			logging.Debugf("心跳包 : %v \n", protoMsgID)
		case int32(messages.ProSocketCloseRequest_ID):
			logging.Debugf("关闭连接 : %v \n", protoMsgID)
		case int32(messages.ProForceUserOfflineRequest_ID):
			// Force-offline requests are logged but never forwarded.
			logging.Debugf("强制下线 : %v \n", protoMsgID)
			continue
		default:
			logging.Debugf("转发 : %v \n", protoMsgID)
		}

		// Forward the whole frame (header included) to the TCP service.
		n, err := tcpConn.Write(msg)
		if err != nil {
			// The backend is gone; keep reading from the client is pointless.
			logging.Errorf("tcpConn.Write err : %v\n", err)
			return
		}
		logging.Debugln("=======================websocket write:", n)
	}
}
// ReadTcp relays bytes from the backend TCP connection to the client
// WebSocket connection. Each successful Read is sent as one binary message.
// It runs until the TCP read fails or the WebSocket write fails.
//
// On return both connections are closed, so the ReadWebsocket goroutine for
// the same session is unblocked as well.
//
// NOTE(review): TCP is a byte stream, so a single Read may hold part of a
// protobuf frame or several frames, and the client currently has to
// reassemble them. Consider splitting on the frame header here instead, once
// the header layout (length field, endianness) is confirmed.
func (c *Context) ReadTcp(conn *websocket.Conn, tcpConn net.Conn) {
	defer conn.Close()
	defer tcpConn.Close()

	buf := make([]byte, 1024)
	for {
		n, err := tcpConn.Read(buf)
		// Per the io.Reader contract, bytes read (n > 0) are handled before err.
		if n > 0 {
			logging.Debugf("接收到tcp的内容================== : %v \n", string(buf[:n]))
			if werr := conn.WriteMessage(websocket.BinaryMessage, buf[:n]); werr != nil {
				// The client is gone; stop instead of writing to a dead socket.
				logging.Errorf("websocket write err : %v \n", werr)
				return
			}
		}
		if err != nil {
			logging.Errorf("tcp conn err : %v \n", err)
			return
		}
	}
}

// main performs startup in these steps:
//   - parses the command-line flags;
//   - loads the TOML config file into a new Context;
//   - initializes the Context;
//   - serves the WebSocket relay on listenAddr.
//
// Any startup failure terminates the process with a non-zero exit status.
func main() {
	flag.Parse()

	// Application context.
	c := NewContext()

	// Parse the TOML config file.
	if _, err := toml.DecodeFile(tomlFile, &c.TOMLConfig); err != nil {
		logging.Errorf("TOML:DecodeFile: %v\n", err)
		// A plain return would exit with status 0 and hide the failure from
		// whatever supervises the process.
		// NOTE(review): if the logging package buffers output, flush it first.
		log.Fatalf("TOML:DecodeFile: %v", err)
	}
	logging.Debugf("Config: %#v\n", c.TOMLConfig)
	logging.Debugf("MetricConfig: %#v\n", c.TOMLConfig.MetricConfig)

	// Initialize the context from the TOML config.
	if err := c.Init(); err != nil {
		logging.Errorf("Context Init: %v\n", err)
		log.Fatalf("Context Init: %v", err)
	}

	http.HandleFunc("/", c.echo)

	log.Fatal(http.ListenAndServe(listenAddr, nil))
}

// 你可能感兴趣的: (go)