go 原生rpc源码解析之服务端

导读

rpc协议是内部服务之间通信时较为常用的协议,它具有效率高、跨语言、易扩展等优点。go语言标准库本身就提供了rpc协议的实现,但由于功能比较简单,与thrift、grpc等相比,并不被广泛了解与使用。一叶知秋,通过剖开它的内部实现,也能窥探到实现一个rpc协议的一些套路。

例子

// Shared state for the example servers.
var (
    // newServer is the non-default server created by startNewServer.
    newServer *Server

    // serverAddr is the TCP address of the default server (set by startServer).
    serverAddr string
    // newServerAddr is the TCP address of newServer (set by startNewServer).
    newServerAddr string
    // httpServerAddr is the address of the HTTP test server (set by startHttpServer).
    httpServerAddr string

    // once and newOnce are not used in the code shown here; presumably the
    // callers use them to run startServer / startNewServer a single time.
    once    sync.Once
    newOnce sync.Once
    // httpOnce makes sure startHttpServer runs at most once.
    httpOnce sync.Once
)

// newHttpPath is the RPC path that startNewServer passes to HandleHTTP.
const newHttpPath = "/foo"
// Args carries the two integer operands of an Arith call.
type Args struct {
    A int
    B int
}

// Reply carries the integer result of an Arith call.
type Reply struct {
    C int
}

// Arith is the example RPC receiver.
type Arith int

// Some of Arith's methods have value args, some have pointer args. That's deliberate.

// Add stores args.A + args.B in reply.C.
func (t *Arith) Add(args Args, reply *Reply) error {
    sum := args.A + args.B
    reply.C = sum
    return nil
}

// Mul stores args.A * args.B in reply.C.
func (t *Arith) Mul(args *Args, reply *Reply) error {
    product := args.A * args.B
    reply.C = product
    return nil
}

// Div stores args.A / args.B in reply.C. When args.B is zero it returns an
// error and leaves reply untouched.
func (t *Arith) Div(args Args, reply *Reply) error {
    divisor := args.B
    if divisor == 0 {
        return errors.New("divide by zero")
    }
    reply.C = args.A / divisor
    return nil
}

// String writes the text "A+B=sum" into *reply.
func (t *Arith) String(args *Args, reply *string) error {
    a, b := args.A, args.B
    *reply = fmt.Sprintf("%d+%d=%d", a, b, a+b)
    return nil
}

// Scan parses an integer out of args into reply.C and returns any parse error.
func (t *Arith) Scan(args string, reply *Reply) (err error) {
    _, err = fmt.Sscan(args, &reply.C)
    return err
}

// Error always panics with "ERROR".
func (t *Arith) Error(args *Args, reply *Reply) error {
    panic("ERROR")
}

// SleepMilli sleeps for args.A milliseconds and then returns nil.
func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
    delay := time.Duration(args.A) * time.Millisecond
    time.Sleep(delay)
    return nil
}
// hidden is an unexported receiver type.
type hidden int

// Exported stores args.A + args.B in reply.C.
func (t *hidden) Exported(args Args, reply *Reply) error {
    sum := args.A + args.B
    reply.C = sum
    return nil
}

// Embed is an exported type that embeds hidden, so a *Embed has the
// promoted Exported method in its method set.
type Embed struct {
    hidden
}
// BuiltinTypes is a receiver whose methods reply through built-in
// composite types (map, slice, array) instead of a named struct.
type BuiltinTypes struct{}

// Map records args.B under key args.A in the map that reply points to.
// A nil map is allocated first, because assigning into a nil map panics.
func (BuiltinTypes) Map(args *Args, reply *map[int]int) error {
    if *reply == nil {
        *reply = make(map[int]int)
    }
    (*reply)[args.A] = args.B
    return nil
}

// Slice appends args.A and then args.B to the slice that reply points to.
func (BuiltinTypes) Slice(args *Args, reply *[]int) error {
    *reply = append(*reply, args.A, args.B)
    return nil
}

// Array stores args.A and args.B in elements 0 and 1 of the array that
// reply points to.
func (BuiltinTypes) Array(args *Args, reply *[2]int) error {
    reply[0], reply[1] = args.A, args.B
    return nil
}
// listenTCP opens a TCP listener on the loopback interface and returns it
// together with its address in string form. It terminates the program via
// log.Fatalf if the listener cannot be created.
func listenTCP() (net.Listener, string) {
    // Port 0 requests any available port.
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        log.Fatalf("net.Listen tcp :0: %v", err)
    }
    addr := listener.Addr().String()
    return listener, addr
}
// startServer registers the example receivers on the default server, starts
// accepting TCP connections in the background, and exposes the same server
// over HTTP.
func startServer() {
    // Register the Arith receiver under its own type name.
    Register(new(Arith))
    // Register the Embed receiver.
    Register(new(Embed))
    // Register another Arith receiver under the explicit name "net.rpc.Arith".
    RegisterName("net.rpc.Arith", new(Arith))
    // Register a receiver whose methods reply with built-in types.
    Register(BuiltinTypes{})

    // Open the TCP listener and remember its address.
    listener, addr := listenTCP()
    serverAddr = addr
    log.Println("Test RPC server listening on", serverAddr)
    // Accept connections in a separate goroutine.
    go Accept(listener)

    // Install the HTTP handlers so RPC requests can also arrive over HTTP.
    HandleHTTP()
    // Start the HTTP test server, at most once.
    httpOnce.Do(startHttpServer)
}
// startNewServer does for a freshly created Server what startServer does for
// the default one: it registers the receivers, accepts TCP connections in the
// background, and installs HTTP handlers on its own paths.
func startNewServer() {
    srv := NewServer()
    newServer = srv

    srv.Register(new(Arith))
    srv.Register(new(Embed))
    srv.RegisterName("net.rpc.Arith", new(Arith))
    srv.RegisterName("newServer.Arith", new(Arith))

    listener, addr := listenTCP()
    newServerAddr = addr
    log.Println("NewServer test RPC server listening on", newServerAddr)
    go srv.Accept(listener)

    // Non-default paths, so these handlers do not collide with the ones
    // installed for the default server.
    srv.HandleHTTP(newHttpPath, "/bar")
    httpOnce.Do(startHttpServer)
}
// startHttpServer starts an httptest server with a nil handler and records
// its listening address in httpServerAddr.
func startHttpServer() {
    ts := httptest.NewServer(nil)
    addr := ts.Listener.Addr().String()
    httpServerAddr = addr
    log.Println("Test HTTP RPC server listening on", addr)
}

原理

主要从startServer里面讲起,先看Register的实现

// Register publishes the receiver's methods in the DefaultServer.
// It returns whatever DefaultServer.Register returns.
func Register(rcvr interface{}) error {
    return DefaultServer.Register(rcvr)
}
// NewServer returns a new Server.
// The returned server is the zero value of Server.
func NewServer() *Server {
    return new(Server)
}

// DefaultServer is the default instance of *Server.
// The package-level Register and HandleHTTP functions delegate to it.
var DefaultServer = NewServer()

// register wraps rcvr in a service and stores it in server.serviceMap.
//
// The service name is name when useName is true; otherwise it is the name of
// rcvr's type with any pointer dereferenced. An error is returned when the
// name is empty, when the type name is unexported and no explicit name was
// given, when rcvr has no suitable methods, or when a service with the same
// name is already stored.
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
    // A service bundles one receiver with its RPC-callable methods:
    //
    //  type service struct {
    //      name   string                 // name of service
    //      rcvr   reflect.Value          // receiver of methods for the service
    //      typ    reflect.Type           // type of the receiver
    //      method map[string]*methodType // registered methods
    //  }
    svc := &service{
        typ:  reflect.TypeOf(rcvr),
        rcvr: reflect.ValueOf(rcvr),
    }

    // Default name: the receiver's type name after dereferencing a pointer,
    // e.g. "Arith" for new(Arith). An explicit name, when requested,
    // overrides it.
    svcName := reflect.Indirect(svc.rcvr).Type().Name()
    if useName {
        svcName = name
    }

    // The name must not be empty.
    if svcName == "" {
        msg := "rpc.Register: no service name for type " + svc.typ.String()
        log.Print(msg)
        return errors.New(msg)
    }
    // Without an explicit name, the type name itself must be exported.
    if !isExported(svcName) && !useName {
        msg := "rpc.Register: type " + svcName + " is not exported"
        log.Print(msg)
        return errors.New(msg)
    }
    svc.name = svcName

    // Install the methods
    svc.method = suitableMethods(svc.typ, true)
    if len(svc.method) == 0 {
        msg := "rpc.Register: type " + svcName + " has no exported methods of suitable type"
        // To help the user, see if a pointer receiver would work: T and *T
        // have different method sets.
        if ptrMethods := suitableMethods(reflect.PtrTo(svc.typ), false); len(ptrMethods) != 0 {
            msg += " (hint: pass a pointer to value of that type)"
        }
        log.Print(msg)
        return errors.New(msg)
    }

    // Store the service; a name that is already present is an error.
    if _, dup := server.serviceMap.LoadOrStore(svcName, svc); dup {
        return errors.New("rpc: service already defined: " + svcName)
    }
    return nil
}

Register函数中suitableMethods方法

// suitableMethods returns suitable Rpc methods of typ, it will report
// error using log if reportErr is true.
//
// A method is kept only if it is exported, takes exactly three inputs
// (receiver, args, reply), has an args type and a reply type accepted by
// isExportedOrBuiltinType, has a pointer reply type, and returns exactly one
// value of type error.
func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
    // Each accepted method is recorded as a methodType:
    //
    //  type methodType struct {
    //      sync.Mutex // protects counters
    //      method     reflect.Method
    //      ArgType    reflect.Type
    //      ReplyType  reflect.Type
    //      numCalls   uint
    //  }

    // logSkip reports why a method was rejected, but only when asked to.
    logSkip := func(format string, args ...interface{}) {
        if reportErr {
            log.Printf(format, args...)
        }
    }

    found := make(map[string]*methodType)
    for i := 0; i < typ.NumMethod(); i++ {
        m := typ.Method(i)
        sig, name := m.Type, m.Name

        // Method must be exported.
        if m.PkgPath != "" {
            continue
        }
        // Method needs three ins: receiver, *args, *reply.
        if numIn := sig.NumIn(); numIn != 3 {
            logSkip("rpc.Register: method %q has %d input parameters; needs exactly three\n", name, numIn)
            continue
        }
        // First arg need not be a pointer, but its type is checked with
        // isExportedOrBuiltinType.
        argType := sig.In(1)
        if !isExportedOrBuiltinType(argType) {
            logSkip("rpc.Register: argument type of method %q is not exported: %q\n", name, argType)
            continue
        }
        // Second arg must be a pointer.
        replyType := sig.In(2)
        if replyType.Kind() != reflect.Ptr {
            logSkip("rpc.Register: reply type of method %q is not a pointer: %q\n", name, replyType)
            continue
        }
        // Reply type gets the same isExportedOrBuiltinType check.
        if !isExportedOrBuiltinType(replyType) {
            logSkip("rpc.Register: reply type of method %q is not exported: %q\n", name, replyType)
            continue
        }
        // Method needs one out.
        if numOut := sig.NumOut(); numOut != 1 {
            logSkip("rpc.Register: method %q has %d output parameters; needs exactly one\n", name, numOut)
            continue
        }
        // The return type of the method must be error.
        if returnType := sig.Out(0); returnType != typeOfError {
            logSkip("rpc.Register: return type of method %q is %q, must be error\n", name, returnType)
            continue
        }
        // Keep the method, keyed by its name.
        found[name] = &methodType{method: m, ArgType: argType, ReplyType: replyType}
    }
    return found
}

开启服务,结合例子来看

// listenTCP opens a TCP listener on the loopback interface and returns it
// together with its address in string form. It terminates the program via
// log.Fatalf if the listener cannot be created.
func listenTCP() (net.Listener, string) {
    // Port 0 requests any available port.
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    if err != nil {
        log.Fatalf("net.Listen tcp :0: %v", err)
    }
    addr := listener.Addr().String()
    return listener, addr
}
// Accept accepts connections on the listener and serves each one in its own
// goroutine. It blocks until the listener returns a non-nil error, logs that
// error and returns, so callers typically invoke it in a go statement.
func (server *Server) Accept(lis net.Listener) {
    for {
        conn, err := lis.Accept()
        if err == nil {
            // Serve every connection concurrently.
            go server.ServeConn(conn)
            continue
        }
        log.Print("rpc.Serve: accept:", err.Error())
        return
    }
}

重点看server.ServeConn(conn)

// ServeConn serves a single connection using the gob wire format: it builds
// a gobServerCodec around conn and hands it to ServeCodec.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
    // Encoded output goes through a buffered writer; decoding reads conn directly.
    w := bufio.NewWriter(conn)
    codec := &gobServerCodec{
        rwc:    conn,
        dec:    gob.NewDecoder(conn),
        enc:    gob.NewEncoder(w),
        encBuf: w,
    }
    server.ServeCodec(codec)
}

// ServeCodec reads requests from codec in a loop and dispatches each valid
// one to service.call in its own goroutine. When the loop ends it waits for
// all in-flight calls to finish and then closes the codec.
func (server *Server) ServeCodec(codec ServerCodec) {
    // sending is shared by every sendResponse on this codec.
    // NOTE(review): presumably it serializes response writes — confirm in sendResponse.
    sending := new(sync.Mutex)
    // wg counts in-flight calls so the codec is not closed before their
    // responses have been sent.
    wg := new(sync.WaitGroup)
    for {
        // Read and decode the next request.
        service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
        if err != nil {
            if debugLog && err != io.EOF {
                log.Println("rpc:", err)
            }
            // Stop serving this codec when readRequest says not to continue.
            // NOTE(review): readRequest is not shown here; confirm there when
            // keepReading is false (presumably when the header itself could
            // not be read).
            if !keepReading {
                break
            }
            // send a response if we actually managed to read a header.
            if req != nil {
                // Answer with an invalid-request reply carrying the error text.
                server.sendResponse(sending, req, invalidRequest, codec, err.Error())
                // Hand req back to the server (presumably a free list for reuse).
                server.freeRequest(req)
            }
            // Move on to the next request.
            continue
        }
        wg.Add(1)
        // Everything decoded fine: invoke the receiver's method concurrently.
        go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
    }
    // We've seen that there are no more requests.
    // Wait for responses to be sent before closing codec.
    wg.Wait()
    // Close the codec.
    codec.Close()
}

再看call方法

// call invokes the method described by mtype on the service's receiver with
// argv and replyv, sends the response through server.sendResponse, and then
// hands req back via server.freeRequest. If wg is non-nil, wg.Done is called
// when call returns.
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
    if wg != nil {
        defer wg.Done()
    }

    // Count this call; the methodType's lock protects the counter.
    mtype.Lock()
    mtype.numCalls++
    mtype.Unlock()

    // Invoke the method, providing a new value for the reply.
    in := []reflect.Value{s.rcvr, argv, replyv}
    out := mtype.method.Func.Call(in)

    // The return value for the method is an error.
    var errmsg string
    if result := out[0].Interface(); result != nil {
        errmsg = result.(error).Error()
    }

    server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
    server.freeRequest(req)
}

以上就是注册receiver,以及client发起请求后服务端的主要处理过程。除了直接使用tcp连接,其实也可以通过http协议发送rpc请求。
例子中的HandleHTTP

const (
    // Defaults used by HandleHTTP

    // DefaultRPCPath is the path the package-level HandleHTTP passes as rpcPath.
    DefaultRPCPath   = "/_goRPC_"
    // DefaultDebugPath is the path the package-level HandleHTTP passes as debugPath.
    DefaultDebugPath = "/debug/rpc"
)
// HandleHTTP installs the HTTP handlers of DefaultServer: the RPC handler on
// DefaultRPCPath and the debugging handler on DefaultDebugPath.
func HandleHTTP() {
    srv := DefaultServer
    srv.HandleHTTP(DefaultRPCPath, DefaultDebugPath)
}
// HandleHTTP registers server as the HTTP handler for rpcPath and a
// debugHTTP wrapper around it as the handler for debugPath, both through
// http.Handle.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
    http.Handle(rpcPath, server)

    debugHandler := debugHTTP{server}
    http.Handle(debugPath, debugHandler)
}
// ServeHTTP implements http.Handler for RPC over HTTP.
//
// Only CONNECT requests are accepted; anything else gets a 405 response.
// For a CONNECT request the underlying connection is hijacked, a status
// line is written to it, and the connection is then served by ServeConn
// exactly like a plain TCP connection.
//
// If the ResponseWriter does not support hijacking, a 500 response is sent
// instead of panicking on the type assertion.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    // Reject every method other than CONNECT.
    if req.Method != http.MethodConnect {
        w.Header().Set("Content-Type", "text/plain; charset=utf-8")
        w.WriteHeader(http.StatusMethodNotAllowed)
        io.WriteString(w, "405 must CONNECT\n")
        return
    }

    // Not every ResponseWriter can be hijacked, so check before asserting.
    hijacker, ok := w.(http.Hijacker)
    if !ok {
        log.Print("rpc hijacking ", req.RemoteAddr, ": ResponseWriter does not support hijacking")
        http.Error(w, "500 hijacking not supported", http.StatusInternalServerError)
        return
    }

    // Take over the raw connection from the HTTP server.
    conn, _, err := hijacker.Hijack()
    if err != nil {
        log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
        return
    }

    // Tell the client the connection is established, then speak RPC on it.
    io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
    server.ServeConn(conn)
}
// ServeHTTP renders the registered services and their methods through the
// debug template. Services and the methods within each service are sorted
// before rendering; a template error is written to the response.
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    // Build a sorted version of the data.
    var all serviceArray
    collect := func(key, value interface{}) bool {
        svc := value.(*service)
        entry := debugService{svc, key.(string), make(methodArray, 0, len(svc.method))}
        for name, mt := range svc.method {
            entry.Method = append(entry.Method, debugMethod{mt, name})
        }
        sort.Sort(entry.Method)
        all = append(all, entry)
        // Keep iterating over every stored service.
        return true
    }
    server.serviceMap.Range(collect)
    sort.Sort(all)

    // Render the collected services to the client.
    if err := debug.Execute(w, all); err != nil {
        fmt.Fprintln(w, "rpc: error executing template:", err.Error())
    }
}

结束

本文只是做了go rpc服务端的讲解,等写了client部分才算完整。

你可能感兴趣的:(go 原生rpc源码解析之服务端)