Golang grpc server响应请求

朋友偶然问了一句:Go 的 gRPC 服务端在收到请求后,是如何调用到对应的实现函数的?

当时对着代码讲了一通。后来想想觉得这是个好问题,写下来记录一下。

注册:

// main starts a gRPC server listening on TCP port 2008 and blocks serving
// requests.
//
// The service implementation is registered before Serve is called; the
// server rejects registrations made after Serve has started.
func main() {
	listen, err := net.Listen("tcp", ":2008")
	if err != nil {
		fmt.Println("net.Listen tcp :2008 err", err)
		return
	}
	s := grpc.NewServer()
	hServer := xxx.Server{}
	xxxxxx.RegisterXXXXXXServer(s, &hServer)
	// Serve blocks in its Accept loop; report the error it returns instead
	// of silently dropping it.
	if err := s.Serve(listen); err != nil {
		fmt.Println("grpc Serve err", err)
	}
}

用 Go 启动一个 gRPC server 的代码很简单:准备一个 proto 服务接口的实现(这里是 xxx.Server{}),把它注册(register)到 grpc.Server 上,然后在监听的端口上调用 Serve 即可。

// RegisterService registers the implementation ss for the service described
// by sd.
//
// sd.HandlerType holds a pointer to the service's handler interface (hence
// Elem). Reflection is used to check that the dynamic type of ss implements
// that interface. A mismatch is reported with grpclog.Fatalf. Otherwise the
// bookkeeping is delegated to register.
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	wantIface := reflect.TypeOf(sd.HandlerType).Elem()
	implType := reflect.TypeOf(ss)
	if ok := implType.Implements(wantIface); !ok {
		grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", implType, wantIface)
	}
	s.register(sd, ss)
}

这是 gRPC server 注册服务的入口:先通过反射分别取得服务描述 ServiceDesc 中的接口类型(HandlerType)和接口实现 ss 的类型,判断 ss 是否实现了该接口。若未实现,直接 Fatalf;否则调用 s.register 完成注册。

// service is the registry entry built by register for one ServiceDesc.
// It keeps the user's implementation together with name -> descriptor lookup
// tables, which handleStream consults to route an incoming request.
type service struct {
   // Value passed as ss to RegisterService; handlers are invoked on it.
   server interface{} // the server for service methods
   // Unary methods, keyed by MethodDesc.MethodName.
   md     map[string]*MethodDesc
   // Streaming methods, keyed by StreamDesc.StreamName.
   sd     map[string]*StreamDesc
   // Copy of ServiceDesc.Metadata (presumably the .proto file name in generated code — confirm).
   mdata  interface{}
}

// register records the service described by sd, with implementation ss, in
// s.m under sd.ServiceName.
//
// It holds s.mu for the whole operation. It reports a fatal error via
// grpclog.Fatalf if Serve has already started or if the service name is
// already registered. Method and stream descriptors are indexed by name,
// pointing into sd's own slices.
func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, dup := s.m[sd.ServiceName]; dup {
		grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}

	// Index by position so each entry points at the element inside sd,
	// not at a per-iteration copy.
	methods := make(map[string]*MethodDesc)
	for i := range sd.Methods {
		methods[sd.Methods[i].MethodName] = &sd.Methods[i]
	}
	streams := make(map[string]*StreamDesc)
	for i := range sd.Streams {
		streams[sd.Streams[i].StreamName] = &sd.Streams[i]
	}

	s.m[sd.ServiceName] = &service{
		server: ss,
		md:     methods,
		sd:     streams,
		mdata:  sd.Metadata,
	}
}

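service 结构中的 server 就是实现了服务接口的实体,也就是真正响应请求的服务。
md 保存方法名到方法描述的映射 "MethodName -> *MethodDesc",MethodDesc 中的 Handler 就是最终被调用的响应函数。
sd 保存流式 RPC 的映射 "StreamName -> *StreamDesc"。
mdata 保存 ServiceDesc.Metadata(生成代码中一般是对应的 proto 文件名)。注册完成后,service 以服务名为 key 存入 s.m。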

响应:

// Serve accepts incoming connections on lis and hands each one to
// handleRawConn on its own goroutine.
//
// This is an excerpt: parts of the real implementation are elided
// ("// ......"), including the handling of Accept errors, which is
// presumably where the loop below returns — confirm against grpc-go.
func (s *Server) Serve(lis net.Listener) error {
    // ......
    // Remember the listener in s.lis (how s.lis is consumed is not shown here).
    ls := &listenSocket{Listener: lis}
    s.lis[ls] = true
    // ......
    var tempDelay time.Duration // how long to sleep on accept failure

    for {
        rawConn, err := lis.Accept()
        if err != nil {
            // ......
        }
        // A successful Accept resets the accept-failure backoff.
        tempDelay = 0
        // Start a new goroutine to deal with rawConn so we don't stall this Accept
        // loop goroutine.
        //
        // Make sure we account for the goroutine so GracefulStop doesn't nil out
        // s.conns before this conn can be added.
        s.serveWG.Add(1)
        go func() {
            s.handleRawConn(rawConn)
            s.serveWG.Done()
        }()
    }
}
// handleRawConn turns one accepted net.Conn into a gRPC server transport and
// starts serving streams on it.
//
// Steps:
//  1. Apply s.opts.connectionTimeout as a deadline for the setup phase.
//  2. Run the transport authenticator (error handling elided in this excerpt).
//  3. Drop the connection if s.conns is nil (presumably the server is stopping).
//  4. Finish the HTTP/2 handshake.
//  5. Clear the deadline and register the transport via addConn.
//  6. Serve its streams on a new goroutine, calling removeConn when serving ends.
func (s *Server) handleRawConn(rawConn net.Conn) {
    // Bound the time spent on authentication and the HTTP/2 handshake.
    rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
    conn, authInfo, err := s.useTransportAuthenticator(rawConn)
    if err != nil {
        // ......
        return
    }

    // A nil s.conns means new connections must not be accepted; close this one.
    s.mu.Lock()
    if s.conns == nil {
        s.mu.Unlock()
        conn.Close()
        return
    }
    s.mu.Unlock()

    // Finish handshaking (HTTP2)
    st := s.newHTTP2Transport(conn, authInfo)
    if st == nil {
        return
    }

    // Setup is done: remove the deadline so long-lived streams are not cut off.
    rawConn.SetDeadline(time.Time{})
    if !s.addConn(st) {
        return
    }
    go func() {
        s.serveStreams(st)
        s.removeConn(st)
    }()
}

Serve 在监听 socket(listenSocket)上循环 Accept,每接受一个连接就起一个 goroutine 执行 handleRawConn。handleRawConn 先做传输层认证(useTransportAuthenticator),然后完成 HTTP/2 握手,得到一个 ServerTransport(它是所有 gRPC 服务端传输实现的通用接口)。
最后再起一个 goroutine 执行 serveStreams。

// serveStreams runs the transport's frame loop and handles every new stream
// on its own goroutine via s.handleStream.
//
// It returns only after HandleStreams has returned and all per-stream
// goroutines have finished; the transport is closed on exit. When
// EnableTracing is set, each stream's context gets a trace named
// "grpc.Recv.<method family>".
func (s *Server) serveStreams(st transport.ServerTransport) {
	defer st.Close()

	var inflight sync.WaitGroup

	// dispatch is called by the transport for each newly created stream.
	dispatch := func(stream *transport.Stream) {
		inflight.Add(1)
		go func() {
			defer inflight.Done()
			s.handleStream(st, stream, s.traceInfo(st, stream))
		}()
	}

	// attachTrace optionally decorates a stream's context with a trace.
	attachTrace := func(ctx context.Context, method string) context.Context {
		if EnableTracing {
			tr := trace.New("grpc.Recv."+methodFamily(method), method)
			return trace.NewContext(ctx, tr)
		}
		return ctx
	}

	st.HandleStreams(dispatch, attachTrace)
	inflight.Wait()
}

// HandleStreams runs the read loop of this HTTP/2 server transport. It reads
// frames from the connection and dispatches them by type:
//
//   - MetaHeadersFrame: passed to operateHeaders together with the handle and
//     traceCtx callbacks; a true result closes the transport and ends the loop.
//   - DataFrame, RSTStreamFrame, SettingsFrame, PingFrame, WindowUpdateFrame:
//     forwarded to the matching handleXxx method.
//   - GoAwayFrame: currently ignored.
//
// An http2.StreamError affects only that stream: the stream is reset and the
// loop continues. io.EOF, io.ErrUnexpectedEOF or any other read error closes
// the transport and returns. t.readerDone is closed when the loop exits.
func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
	defer close(t.readerDone)
	for {
		frame, err := t.framer.fr.ReadFrame()
		// Flag read activity on the connection (presumably consumed by
		// keepalive logic outside this excerpt — confirm).
		atomic.StoreUint32(&t.activity, 1)
		if err != nil {
			if se, ok := err.(http2.StreamError); ok {
				warningf("transport: http2Server.HandleStreams encountered http2.StreamError: %v", se)
				t.mu.Lock()
				s := t.activeStreams[se.StreamID]
				t.mu.Unlock()
				if s != nil {
					t.closeStream(s, true, se.Code, false)
				} else {
					// Stream is not active here; still queue an RST_STREAM for it.
					t.controlBuf.put(&cleanupStream{
						streamID: se.StreamID,
						rst:      true,
						rstCode:  se.Code,
						onWrite:  func() {},
					})
				}
				continue
			}
			if err == io.EOF || err == io.ErrUnexpectedEOF {
				t.Close()
				return
			}
			warningf("transport: http2Server.HandleStreams failed to read frame: %v", err)
			t.Close()
			return
		}
		switch frame := frame.(type) {
		case *http2.MetaHeadersFrame:
			if t.operateHeaders(frame, handle, traceCtx) {
				// A bare `break` here would only leave the switch, so the loop
				// would keep reading from the closed transport and log a
				// spurious "failed to read frame" warning. Stop instead.
				t.Close()
				return
			}
		case *http2.DataFrame:
			t.handleData(frame)
		case *http2.RSTStreamFrame:
			t.handleRSTStream(frame)
		case *http2.SettingsFrame:
			t.handleSettings(frame)
		case *http2.PingFrame:
			t.handlePing(frame)
		case *http2.WindowUpdateFrame:
			t.handleWindowUpdate(frame)
		case *http2.GoAwayFrame:
			// TODO: Handle GoAway from the client appropriately.
		default:
			errorf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame)
		}
	}
}

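st.HandleStreams 循环读取连接上的 HTTP/2 帧(frame),并按类型分发。收到 HEADERS 帧(MetaHeadersFrame)时交给 operateHeaders 处理,handle 和 traceCtx 两个回调都一并传给了它。handle 回调会起一个 goroutine 调用 s.handleStream 处理请求,s.handleStream 才是真正调用服务实现函数的地方。traceCtx 回调则在开启 tracing 时把 trace 附加到 ctx 上,并返回新的 context。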

// handleStream routes one incoming stream to the handler registered for it.
//
// stream.Method() carries the full method path. A leading '/' is dropped and
// the rest is split at the last '/' into the service name and the method name.
// Lookup order:
//  1. s.m[service].md[method] -> processUnaryRPC
//  2. s.m[service].sd[method] -> processStreamingRPC
//  3. s.opts.unknownStreamDesc, if set -> processStreamingRPC with a nil service
//  4. otherwise the stream is failed with codes.Unimplemented.
//
// A path containing no '/' is rejected as malformed, also with
// codes.Unimplemented. When trInfo is non-nil, errors are recorded on the
// trace and the trace is finished on every error path handled here.
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	pos := strings.LastIndex(sm, "/")
	if pos == -1 {
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
			trInfo.tr.SetError()
		}
		errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
		// An unparsable path cannot name any implemented method, so report
		// Unimplemented. ResourceExhausted (used here before) signals
		// quota/resource exhaustion and would mislead clients into retrying.
		if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
			if trInfo != nil {
				trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
				trInfo.tr.SetError()
			}
			grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
		}
		if trInfo != nil {
			trInfo.tr.Finish()
		}
		return
	}
	service := sm[:pos]
	method := sm[pos+1:]

	srv, knownService := s.m[service]
	if knownService {
		if md, ok := srv.md[method]; ok {
			s.processUnaryRPC(t, stream, srv, md, trInfo)
			return
		}
		if sd, ok := srv.sd[method]; ok {
			s.processStreamingRPC(t, stream, srv, sd, trInfo)
			return
		}
	}
	// Unknown service, or known service but unknown method.
	if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
		s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
		return
	}
	var errDesc string
	if !knownService {
		errDesc = fmt.Sprintf("unknown service %v", service)
	} else {
		errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
	}
	if trInfo != nil {
		trInfo.tr.LazyPrintf("%s", errDesc)
		trInfo.tr.SetError()
	}
	if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
		if trInfo != nil {
			trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
			trInfo.tr.SetError()
		}
		grpclog.Warningf("grpc: Server.handleStream failed to write status: %v", err)
	}
	if trInfo != nil {
		trInfo.tr.Finish()
	}
}

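server.handleStream 从 stream.Method() 取得方法路径,去掉开头的 '/' 后,按最后一个 '/' 切分出 service name 和 method name,然后在注册时建立的 s.m 以及 srv.md / srv.sd 中查找。命中普通方法时执行 processUnaryRPC,命中流式方法时执行 processStreamingRPC。如果都找不到:配置了 unknownStreamDesc 就交给它处理,否则返回 codes.Unimplemented。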

func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) {
    //...
      这里有一大段代码,都是在执行数据的comp/decomp操作
    //..
    ctx := NewContextWithServerTransportStream(stream.Context(), stream)
    reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
    if appErr != nil {
        appStatus, ok := status.FromError(appErr)
        if !ok {
            // Convert appErr if it is not a grpc status error.
            appErr = status.Error(codes.Unknown, appErr.Error())
            appStatus, _ = status.FromError(appErr)
        }
        if trInfo != nil {
            trInfo.tr.LazyLog(stringer(appStatus.Message()), true)
            trInfo.tr.SetError()
        }
        if e := t.WriteStatus(stream, appStatus); e != nil {
            grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
        }
        if binlog != nil {
            if h, _ := stream.Header(); h.Len() > 0 {
                // Only log serverHeader if there was header. Otherwise it can
                // be trailer only.
                binlog.Log(&binarylog.ServerHeader{
                    Header: h,
                })
            }
            binlog.Log(&binarylog.ServerTrailer{
                Trailer: stream.Trailer(),
                Err:     appErr,
            })
        }
        return appErr
    }
    if trInfo != nil {
        trInfo.tr.LazyLog(stringer("OK"), false)
    }
    opts := &transport.Options{Last: true}

    if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
        if err == io.EOF {
            // The entire stream is done (for unary RPC only).
            return err
        }
        if s, ok := status.FromError(err); ok {
            if e := t.WriteStatus(stream, s); e != nil {
                grpclog.Warningf("grpc: Server.processUnaryRPC failed to write status: %v", e)
            }
        } else {
            switch st := err.(type) {
            case transport.ConnectionError:
                // Nothing to do here.
            default:
                panic(fmt.Sprintf("grpc: Unexpected error (%T) from sendResponse: %v", st, st))
            }
        }
        if binlog != nil {
            h, _ := stream.Header()
            binlog.Log(&binarylog.ServerHeader{
                Header: h,
            })
            binlog.Log(&binarylog.ServerTrailer{
                Trailer: stream.Trailer(),
                Err:     appErr,
            })
        }
        return err
    }
    if binlog != nil {
        h, _ := stream.Header()
        binlog.Log(&binarylog.ServerHeader{
            Header: h,
        })
        binlog.Log(&binarylog.ServerMessage{
            Message: reply,
        })
    }
    if channelz.IsOn() {
        t.IncrMsgSent()
    }
    if trInfo != nil {
        trInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true)
    }
    // TODO: Should we be logging if writing status failed here, like above?
    // Should the logging be in WriteStatus?  Should we ignore the WriteStatus
    // error or allow the stats handler to see it?
    err = t.WriteStatus(stream, status.New(codes.OK, ""))
    if binlog != nil {
        binlog.Log(&binarylog.ServerTrailer{
            Trailer: stream.Trailer(),
            Err:     appErr,
        })
    }
    return err
}

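上面省略了一大段构造 context、对请求数据做解压/解码(decomp)的代码。
经过这一系列处理,终于到了调用请求对应实现方法的地方:reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)。md.Handler 由 protoc 生成的代码提供,通常会先用 df 解码请求参数,再调用 srv.server 上对应的方法;若配置了拦截器 s.opts.unaryInt,则经由拦截器调用。返回的 reply 就是要发回给 client 的数据。最后 WriteStatus 写回 codes.OK,客户端 stub 的调用因此成功返回(在 Go 客户端中表现为 err == nil)。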

本来想画个图的,但是最近时间有点紧,等有空了再补吧,就当又温习一遍

你可能感兴趣的:(Golang grpc server响应请求)