RPC(Remote Procedure Call,远程过程调用)是一种计算机通信协议,允许调用不同进程空间的程序。RPC 的客户端和服务器可以在一台机器上,也可以在不同的机器上。程序员使用时,就像调用本地程序一样,无需关注内部的实现细节。
不同的应用程序之间的通信方式有很多,比如浏览器和服务器之间广泛使用的基于 HTTP 协议的 Restful API。与 RPC 相比,Restful API 有相对统一的标准,因而更通用,兼容性更好,支持不同的语言。HTTP 协议是基于文本的,一般具备更好的可读性。但是缺点也很明显:
需要确定采用的传输协议是什么?如果这个两个应用程序位于不同的机器,那么一般会选择 TCP 协议或者 HTTP 协议;那如果两个应用程序位于相同的机器,也可以选择 Unix Socket 协议。
还需要确定报文的编码格式,比如采用最常用的 JSON 或者 XML,那如果报文比较大,还可能会选择 protobuf 等其他的编码方式,甚至编码之后,再进行压缩。接收端获取报文则需要相反的过程,先解压再解码。
如果服务端的实例很多,客户端并不关心这些实例的地址和部署位置,只关心自己能否获取到期待的结果,那就引出了注册中心(registry)和负载均衡(load balance)的问题。(即客户端和服务端互相不感知对方的存在,服务端启动时将自己注册到注册中心,客户端调用时,从注册中心获取到所有可用的实例,选择一个来调用。)
Go 语言广泛地应用于云计算和微服务,成熟的 RPC 框架和微服务框架汗牛充栋。grpc、rpcx、go-micro 等都是非常成熟的框架。一般而言,RPC 是微服务框架的一个子集,微服务框架可以自己实现 RPC 部分,当然,也可以选择不同的 RPC 框架作为通信基座。
上述成熟的框架代码量都比较庞大,而且通常和第三方库,例如 protobuf、etcd、zookeeper 等有比较深的耦合,难以直观地窥视框架的本质。
因此,从零实现 Go 语言官方的标准库 net/rpc,并在此基础上,新增协议交换(protocol exchange)、注册中心(registry)、服务发现(service discovery)、负载均衡(load balance)、超时处理(timeout processing)等特性。有助于理解 RPC 框架在设计时需要考虑什么。
使用 encoding/gob 实现消息的编解码(序列化与反序列化)。
gob(Go binary)是Goland包自带的一个数据结构序列化的编码/解码工具。编码使用Encoder,解码使用Decoder。一种典型的应用场景就是RPC(remote procedure calls)。
代码中有这样的字眼:var _ Codec = (*GobCodec)(nil)
,作者给出的解释如下:
类似的方法还有:
// 验证httpGetter是否实现了PeerGetter接口
var _ PeerGetter = &httpGetter{}
包括gin
框架的源码:
type IRouter interface{ ... }
...
...
type RouterGroup struct { ... }
...
var _ IRouter = &RouterGroup{}
一个典型的 RPC 调用如下(参考go语言rpc/grpc介绍):
err = client.Call("Arith.Multiply", args, &reply)
客户端发送的请求包括服务名 Arith
,方法名 Multiply
,参数 args
三个,服务端的响应包括错误 error
,返回值 reply
2 个。
抽象出数据结构 Header:
package codec
import "io"
type Header struct {
// ServiceMethod 是服务名和方法名
ServiceMethod string
// Seq 是请求的序号,也可以认为是某个请求的 ID,用来区分不同的请求。
Seq uint64
Error string
}
// Codec 抽象出对消息体进行编解码的接口
type Codec interface {
io.Closer
ReadHeader(*Header) error
ReadBody(any) error
Write(*Header, any) error
}
type NewCodecFunc func(closer io.ReadWriteCloser) Codec
const (
GobType = "application/gob"
JsonType = "application/json"
)
var NewCodecFuncMap map[string]NewCodecFunc
func init() {
NewCodecFuncMap = make(map[string]NewCodecFunc)
NewCodecFuncMap[GobType] = NewGobCodec
}
再实现编解码接口:
package codec
import (
"bufio"
"encoding/gob"
"io"
"log"
)
type GobCodec struct {
// conn 是由构建函数传入,通常是通过 TCP 或者 Unix 建立 socket 时得到的链接实例
conn io.ReadWriteCloser
// buf 是为了防止阻塞而创建的带缓冲的 Writer,一般这么做能提升性能。
buf *bufio.Writer
dec *gob.Decoder
enc *gob.Encoder
}
// 确保GobCodec实现了Codec
var _ Codec = (*GobCodec)(nil)
// NewGobCodec 是GobCodec的构造函数
func NewGobCodec(conn io.ReadWriteCloser) Codec {
buf := bufio.NewWriter(conn)
return &GobCodec{
conn: conn,
buf: buf,
// dec 从conn解码
dec: gob.NewDecoder(conn),
// enc 编码到buf
enc: gob.NewEncoder(buf),
}
}
func (c *GobCodec) ReadHeader(h *Header) error {
return c.dec.Decode(h)
}
func (c *GobCodec) ReadBody(body interface{}) error {
return c.dec.Decode(body)
}
func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
defer func() {
_ = c.buf.Flush()
if err != nil {
_ = c.Close()
}
}()
if err := c.enc.Encode(h); err != nil {
log.Println("rpc codec: gob error encoding header:", err)
return err
}
if err := c.enc.Encode(body); err != nil {
log.Println("rpc codec: gob error encoding body:", err)
return err
}
return nil
}
func (c *GobCodec) Close() error {
return c.conn.Close()
}
客户端与服务端的通信需要协商一些内容,为了提升性能,一般在报文的最开始会规划固定的字节,来协商相关的信息。比如第1个字节用来表示序列化方式,第2个字节表示压缩方式,第3-6字节表示 header 的长度,7-10 字节表示 body 的长度。
服务端首先使用 JSON 解码 Option,然后通过 Option 的 CodeType 解码剩余的内容。即报文将以这样的形式发送:
| Option{MagicNumber: xxx, CodecType: xxx} | Header{ServiceMethod ...} | Body interface{} |
| <------ 固定 JSON 编码 ------> | <------- 编码方式由 CodeType 决定 ------->|
在一次连接中,Option 固定在报文的最开始,Header 和 Body 可以有多个,即报文可能是这样的。
| Option | Header1 | Body1 | Header2 | Body2 | ...
package GenRpc
import (
"github.com/Generlazy/GenGrpc/GenRpc/codec"
"log"
"net"
)
const MagicNumber = 0x3bef5c
type Option struct {
// MagicNumber标记这是一个GenRpc请求
MagicNumber int
// CodecType body编码方式
CodecType string
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
}
type Server struct{}
// NewServer returns a new Server.
func NewServer() *Server {
return &Server{}
}
// DefaultServer 是一个默认的 Server 实例,主要为了用户使用方便。
var DefaultServer = NewServer()
// Accept 接受请求
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
log.Println("rpc server: accept error:", err)
return
}
// 异步服务request
go server.ServeConn(conn)
}
}
func Accept(lis net.Listener) { DefaultServer.Accept(lis) }
启动服务:
lis, _ := net.Listen("tcp", ":9999")
geerpc.Accept(lis)
实现ServerConn(conn):
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
defer func() { _ = conn.Close() }()
var opt Option
// 将 magicNumber 和 Content-type 解码到opt中
if err := json.NewDecoder(conn).Decode(&opt); err != nil {
log.Println("rpc server: options error: ", err)
return
}
// 判断 magicNumber是否正确
if opt.MagicNumber != MagicNumber {
log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
return
}
// 获取解码器gob/json的构造函数
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
log.Printf("rpc server: invalid codec type %s", opt.CodecType)
return
}
// 获取具体构造器gob/json
codecObj := f(conn)
// 开启服务
server.serveCodec(codecObj)
}
func (server *Server) serveCodec(cc codec.Codec) {
// 确保发送完整的响应
sending := new(sync.Mutex)
// wait until all request are handled
wg := new(sync.WaitGroup)
for {
// 一直读取请求(上文将连接对象传入到了gob中)
req, err := server.readRequest(cc)
if err != nil {
if req == nil {
break // it's not possible to recover, so close the connection
}
req.h.Error = err.Error()
// 返回错误响应
server.sendResponse(cc, req.h, invalidRequest, sending)
continue
}
wg.Add(1)
// 异步处理正确响应
go server.handleRequest(cc, req, sending, wg)
}
wg.Wait()
_ = cc.Close()
}
serveCodec 的过程非常简单。主要包含三个阶段:
在一次连接中,允许接收多个请求,即多个 request header 和 request body,因此这里使用了 for 无限制地等待请求的到来,直到发生错误(例如连接被关闭,接收到的报文有问题等),这里需要注意的点有三个:
// request 请求上下文
type request struct {
// h 请求头
h *codec.Header
// argv 请求参数
argv reflect.Value
// respv 响应参数
respv reflect.Value
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
// 读取请求头
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.argv = reflect.New(reflect.TypeOf(""))
// 读取请求体
if err = cc.ReadBody(req.argv.Interface()); err != nil {
log.Println("rpc server: read argv err:", err)
}
return req, nil
}
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
var h codec.Header
// 将头信息解码到h返回
if err := cc.ReadHeader(&h); err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {
log.Println("rpc server: read header error:", err)
}
return nil, err
}
return &h, nil
}
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
sending.Lock()
defer sending.Unlock()
// 将h和body写入到conn中
if err := cc.Write(h, body); err != nil {
log.Println("rpc server: write response error:", err)
}
}
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
defer wg.Done()
// 输出请求参数
log.Println(req.h, req.argv.Elem())
req.respv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq))
server.sendResponse(cc, req.h, req.respv.Interface(), sending)
}
package main
import (
"encoding/json"
"fmt"
"github.com/Generlazy/GenGrpc/GenRpc"
"github.com/Generlazy/GenGrpc/GenRpc/codec"
"log"
"net"
"time"
)
func startServer(addr chan string) {
// 监听tcp:8080
l, err := net.Listen("tcp", ":8080")
if err != nil {
log.Fatal("network error:", err)
}
log.Println("start rpc server on", l.Addr())
addr <- l.Addr().String()
GenRpc.Accept(l)
}
func main() {
addr := make(chan string)
// 开启服务
go startServer(addr)
// 客户端
conn, _ := net.Dial("tcp", <-addr)
defer func() { _ = conn.Close() }()
time.Sleep(time.Second)
// 发送Option 协商好的格式
_ = json.NewEncoder(conn).Encode(GenRpc.DefaultOption)
// 获取gob编码器
cc := codec.NewGobCodec(conn)
// send request & receive response
// 一个conn连接,请求响应了10次
for i := 0; i < 5; i++ {
h := &codec.Header{
// 调用Foo.Sum
ServiceMethod: "Foo.Sum",
// 序列号为 index
Seq: uint64(i),
}
_ = cc.Write(h, fmt.Sprintf("geerpc req %d", h.Seq))
_ = cc.ReadHeader(h)
var reply string
_ = cc.ReadBody(&reply)
log.Println("reply:", reply)
}
}
对 net/rpc 而言,一个函数需要能够被远程调用,需要满足如下五个条件:
func (t *T) MethodName(argType T1, replyType *T2) error
// Call represents an active RPC.
type Call struct {
Seq uint64
ServiceMethod string // format "."
Args interface{} // arguments to the function
Reply interface{} // reply from the function
Error error // if error occurs, it will be set
Done chan *Call // Strobes when call is complete.
}
func (call *Call) done() {
call.Done <- call
}
为了支持异步调用,Call 结构体中添加了一个字段 Done,Done 的类型是 chan *Call,当调用结束时,会调用 call.done() 通知调用方。
// Client Rpc客户端
type Client struct {
// cc 编解码器
cc codec.Codec
// opt 自定义协议选项
opt *Option
// sending 是一个互斥锁,和服务端类似,为了保证请求的有序发送,即防止出现多个请求报文混淆。
sending sync.Mutex
// header 是每个请求的消息头,header 只有在请求发送时才需要,而请求发送是互斥的,因此每个客户端只需要一个,声明在 Client 结构体中可以复用。
header codec.Header
mu sync.Mutex
// seq 用于给发送的请求编号,每个请求拥有唯一编号。
seq uint64
// pending 存储未处理完的请求,键是编号,值是 Call 实例。
pending map[uint64]*Call
// closing 标记客户端是否关闭,通过调用 Close 设置
closing bool
// shutdown 置为 true 一般是有错误发生。
shutdown bool
}
// 验证Client是否实现了io.Closer 便于在编译阶段就报错
var _ io.Closer = (*Client)(nil)
var ErrShutdown = errors.New("connection is shut down")
// Close the connection
func (client *Client) Close() error {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing {
return ErrShutdown
}
client.closing = true
return client.cc.Close()
}
// IsAvailable return true if the client does work
func (client *Client) IsAvailable() bool {
client.mu.Lock()
defer client.mu.Unlock()
return !client.shutdown && !client.closing
}
// registerCall:将参数 call 添加到 client.pending 中,并更新 client.seq。
func (client *Client) registerCall(call *Call) (uint64, error) {
client.mu.Lock()
defer client.mu.Unlock()
if client.closing || client.shutdown {
return 0, ErrShutdown
}
call.Seq = client.seq
client.pending[call.Seq] = call
client.seq++
return call.Seq, nil
}
// removeCall:根据 seq,从 client.pending 中移除对应的 call,并返回。
func (client *Client) removeCall(seq uint64) *Call {
client.mu.Lock()
defer client.mu.Unlock()
call := client.pending[seq]
delete(client.pending, seq)
return call
}
// terminateCalls:服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。
// terminateCalls:服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。
func (client *Client) terminateCalls(err error) {
// 先锁发送
client.sending.Lock()
defer client.sending.Unlock()
// 再锁client
client.mu.Lock()
defer client.mu.Unlock()
client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
}
对一个客户端端来说,接收响应、发送请求是最重要的 2 个功能。
接收到的响应有三种情况:
// receive 接收功能
func (client *Client) receive() {
var err error
for err == nil {
var h codec.Header
// 从conn解码header到h
if err = client.cc.ReadHeader(&h); err != nil {
break
}
// 根据header取出一个调用
call := client.removeCall(h.Seq)
// call的情况
switch {
case call == nil:
// it usually means that Write partially failed
// and call was already removed.
err = client.cc.ReadBody(nil)
case h.Error != "":
call.Error = fmt.Errorf(h.Error)
err = client.cc.ReadBody(nil)
call.done()
default:
err = client.cc.ReadBody(call.Reply)
if err != nil {
call.Error = errors.New("reading body " + err.Error())
}
call.done()
}
}
// error occurs, so terminateCalls pending calls
client.terminateCalls(err)
}
创建 Client 实例时,首先需要完成一开始的协议交换,即发送 Option 信息给服务端。协商好消息的编解码方式之后,再创建一个子协程调用 receive() 接收响应。
func NewClient(conn net.Conn, opt *Option) (*Client, error) {
// 获取编解码器的初始化函数
f := codec.NewCodecFuncMap[opt.CodecType]
if f == nil {
err := fmt.Errorf("invalid codec type %s", opt.CodecType)
log.Println("rpc client: codec error:", err)
return nil, err
}
// send options with server
// 将option 按照规定json序列化 并传输给conn
if err := json.NewEncoder(conn).Encode(opt); err != nil {
log.Println("rpc client: options error: ", err)
_ = conn.Close()
return nil, err
}
return newClientCodec(f(conn), opt), nil
}
func newClientCodec(cc codec.Codec, opt *Option) *Client {
client := &Client{
seq: 1, // seq starts with 1, 0 means invalid call
cc: cc,
opt: opt,
pending: make(map[uint64]*Call),
}
go client.receive()
return client
}
实现 Dial 函数,便于用户传入服务端地址,创建 Client 实例。为了简化用户调用,通过 …*Option 将 Option 实现为可选参数(选项模式)。
func parseOptions(opts ...*Option) (*Option, error) {
// 没有传入选项返回默认值
if len(opts) == 0 || opts[0] == nil {
return DefaultOption, nil
}
if len(opts) != 1 {
return nil, errors.New("number of options is more than 1")
}
opt := opts[0]
opt.MagicNumber = DefaultOption.MagicNumber
if opt.CodecType == "" {
opt.CodecType = DefaultOption.CodecType
}
return opt, nil
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if client == nil {
_ = conn.Close()
}
}()
return NewClient(conn, opt)
}
实现发送请求的能力:
func (client *Client) send(call *Call) {
// make sure that the client will send a complete request
client.sending.Lock()
defer client.sending.Unlock()
// register this call.
seq, err := client.registerCall(call)
if err != nil {
// 用于 receive 判断call的情况
call.Error = err
call.done()
return
}
// prepare request header
client.header.ServiceMethod = call.ServiceMethod
client.header.Seq = seq
client.header.Error = ""
// encode and send the request
if err := client.cc.Write(&client.header, call.Args); err != nil {
call := client.removeCall(seq)
// call may be nil, it usually means that Write partially failed,
// client has received the response and handled
if call != nil {
call.Error = err
call.done()
}
}
}
Go 和 Call 是客户端暴露给用户的两个 RPC 服务调用接口,Go 是一个异步接口,返回 call 实例。
Call 是对 Go 的封装,阻塞 call.Done,等待响应返回,是一个同步接口。
// Go invokes the function asynchronously.
// It returns the Call structure representing the invocation.
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
if done == nil {
done = make(chan *Call, 10)
} else if cap(done) == 0 {
log.Panic("rpc client: done channel is unbuffered")
}
call := &Call{
ServiceMethod: serviceMethod,
Args: args,
Reply: reply,
Done: done,
}
client.send(call)
return call
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(serviceMethod string, args, reply interface{}) error {
call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
return call.Error
}
handleRequest
方法只是打印序列号:fmt.Sprintf("geerpc resp %d", req.h.Seq)
,并没有实现根据ServiceMethod去寻找对应服务的功能。假设客户端发过来一个请求,包含 ServiceMethod 和 Argv:
{
"ServiceMethod": "T.MethodName"
"Argv":"0101110101..." // 序列化之后的字节流
}
通过 “T.MethodName” 可以确定调用的是类型 T 的 MethodName,通过反射能够非常容易地获取某个结构体的所有方法,并且能够通过方法,获取到该方法所有的参数类型与返回值。
func main() {
var wg sync.WaitGroup
typ := reflect.TypeOf(&wg)
for i := 0; i < typ.NumMethod(); i++ {
method := typ.Method(i)
argv := make([]string, 0, method.Type.NumIn())
returns := make([]string, 0, method.Type.NumOut())
// j 从 1 开始,第 0 个入参是 wg 自己。
for j := 1; j < method.Type.NumIn(); j++ {
argv = append(argv, method.Type.In(j).Name())
}
for j := 0; j < method.Type.NumOut(); j++ {
returns = append(returns, method.Type.Out(j).Name())
}
log.Printf("func (w *%s) %s(%s) %s",
typ.Elem().Name(),
method.Name,
strings.Join(argv, ","),
strings.Join(returns, ","))
}
}
func (w *WaitGroup) Add(int)
func (w *WaitGroup) Done()
func (w *WaitGroup) Wait()
定义结构体 methodType:,实现了 2 个方法 newArgv 和 newReplyv,用于创建对应类型的实例。
type methodType struct {
// method:方法本身
method reflect.Method
// ArgType:第一个参数的类型(请求参数)
ArgType reflect.Type
// ReplyType:第二个参数的类型(响应参数)
ReplyType reflect.Type
// numCalls:后续统计方法调用次数时会用到
numCalls uint64
}
// NumCalls 返回调用Method的次数
func (m *methodType) NumCalls() uint64 {
return atomic.LoadUint64(&m.numCalls)
}
func (m *methodType) newArgv() reflect.Value {
var argv reflect.Value
// arg may be a pointer type, or a value type
if m.ArgType.Kind() == reflect.Ptr {
// 如果是指针,需要调用Elem()方法,相等于*ptr获取值
argv = reflect.New(m.ArgType.Elem())
} else {
argv = reflect.New(m.ArgType).Elem()
}
return argv
}
func (m *methodType) newReplyv() reflect.Value {
// reply must be a pointer type
replyv := reflect.New(m.ReplyType.Elem())
switch m.ReplyType.Elem().Kind() {
case reflect.Map:
replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
case reflect.Slice:
replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
}
return replyv
}
定义结构体 service:
type service struct {
// name 即映射的结构体的名称 比如 T,比如 WaitGroup
name string
// typ 是结构体的类型
typ reflect.Type
// rcvr 即结构体的实例本身,保留 rcvr 是因为在调用时需要 rcvr 作为第 0 个参数
rcvr reflect.Value
// method 是 map 类型,存储映射的结构体的所有符合条件的方法。
method map[string]*methodType
}
func newService(rcvr interface{}) *service {
s := new(service)
s.rcvr = reflect.ValueOf(rcvr)
s.name = reflect.Indirect(s.rcvr).Type().Name()
s.typ = reflect.TypeOf(rcvr)
if !ast.IsExported(s.name) {
log.Fatalf("rpc server: %s is not a valid service name", s.name)
}
s.registerMethods()
return s
}
func (s *service) registerMethods() {
s.method = make(map[string]*methodType)
for i := 0; i < s.typ.NumMethod(); i++ {
method := s.typ.Method(i)
mType := method.Type
if mType.NumIn() != 3 || mType.NumOut() != 1 {
continue
}
if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
continue
}
argType, replyType := mType.In(1), mType.In(2)
if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
continue
}
s.method[method.Name] = &methodType{
method: method,
ArgType: argType,
ReplyType: replyType,
}
log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
}
}
func isExportedOrBuiltinType(t reflect.Type) bool {
return ast.IsExported(t.Name()) || t.PkgPath() == ""
}
registerMethods 过滤出了符合条件的方法:
两个导出或内置类型的入参(反射时为 3 个,第 0 个是自身,类似于 python 的 self,java 中的 this)
返回值有且只有 1 个,类型为 error。
还需要实现 call 方法,即能够通过反射值调用方法。
func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
atomic.AddUint64(&m.numCalls, 1)
f := m.method.Func
returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
if errInter := returnValues[0].Interface(); errInter != nil {
return errInter.(error)
}
return nil
}
通过反射结构体已经映射为服务,但请求的处理过程还没有完成。从接收到请求到回复还差以下几个步骤:第一步,根据入参类型,将请求的 body 反序列化;第二步,调用 service.call,完成方法调用;第三步,将 reply 序列化为字节流,构造响应报文,返回。
需要为 Server 实现一个方法 Register:
// Server represents an RPC Server.
type Server struct {
serviceMap sync.Map
}
// Register publishes in the server the set of methods of the
func (server *Server) Register(rcvr interface{}) error {
s := newService(rcvr)
if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup {
return errors.New("rpc: service already defined: " + s.name)
}
return nil
}
// Register publishes the receiver's methods in the DefaultServer.
func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) }
配套实现 findService 方法,即通过 ServiceMethod 从 serviceMap 中找到对应的 service:
func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
dot := strings.LastIndex(serviceMethod, ".")
if dot < 0 {
err = errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
return
}
serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
svci, ok := server.serviceMap.Load(serviceName)
if !ok {
err = errors.New("rpc server: can't find service " + serviceName)
return
}
svc = svci.(*service)
mtype = svc.method[methodName]
if mtype == nil {
err = errors.New("rpc server: can't find method " + methodName)
}
return
}
findService 的实现看似比较繁琐,但是逻辑还是非常清晰的。因为 ServiceMethod 的构成是 “Service.Method”,因此先将其分割成 2 部分,第一部分是 Service 的名称,第二部分即方法名。现在 serviceMap 中找到对应的 service 实例,再从 service 实例的 method 中,找到对应的 methodType。
补全 readRequest 方法:
// request 请求上下文
type request struct {
// h 请求头
h *codec.Header
// argv 请求参数
argv reflect.Value
// respv 响应参数
respv reflect.Value
mtype *methodType
svc *service
}
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
h, err := server.readRequestHeader(cc)
if err != nil {
return nil, err
}
req := &request{h: h}
req.svc, req.mtype, err = server.findService(h.ServiceMethod)
if err != nil {
return req, err
}
req.argv = req.mtype.newArgv()
req.respv = req.mtype.newReplyv()
// make sure that argvi is a pointer, ReadBody need a pointer as parameter
argvi := req.argv.Interface()
if req.argv.Type().Kind() != reflect.Ptr {
argvi = req.argv.Addr().Interface()
}
if err = cc.ReadBody(argvi); err != nil {
log.Println("rpc server: read body err:", err)
return req, err
}
return req, nil
}
readRequest 方法中最重要的部分,即通过 newArgv() 和 newReplyv() 两个方法创建出两个入参实例,然后通过 cc.ReadBody() 将请求报文反序列化为第一个入参 argv,在这里同样需要注意 argv 可能是值类型,也可能是指针类型,所以处理方式有点差异。
补全 handleRequest 方法:
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
defer wg.Done()
err := req.svc.call(req.mtype, req.argv, req.respv)
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, struct {}{}, sending)
return
}
server.sendResponse(cc, req.h, req.respv.Interface(), sending)
}
相对于 readRequest,handleRequest 的实现非常简单,通过 req.svc.call 完成方法调用,将 replyv 传递给 sendResponse 完成序列化即可。
超时处理是 RPC 框架一个比较基本的能力,如果缺少超时处理机制,无论是服务端还是客户端都容易因为网络或其他错误导致挂死,资源耗尽,这些问题的出现大大地降低了服务的可用性。因此,我们需要在 RPC 框架中加入超时处理的能力。
纵观整个远程调用的过程,需要客户端处理超时的地方有:
需要服务端处理超时的地方有:
ConnectTimeout 默认值为 10s,HandleTimeout 默认值为 0,即不设限。
type Option struct {
// MagicNumber标记这是一个GenRpc请求
MagicNumber int
// CodecType body编码方式
CodecType string
ConnectTimeout time.Duration // 0 means no limit
HandleTimeout time.Duration
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10,
}
客户端连接超时,只需要为 Dial 添加一层超时处理的外壳即可:
type clientResult struct {
client *Client
err error
}
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
func dialTimeout(f newClientFunc, network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
// close the connection if client is nil
defer func() {
if err != nil {
_ = conn.Close()
}
}()
ch := make(chan clientResult)
go func() {
client, err := f(conn, opt)
ch <- clientResult{client: client, err: err}
}()
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
// 阻塞在这,直到某一个case有响应
select {
case <-time.After(opt.ConnectTimeout):
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}
// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewClient, network, address, opts...)
}
Client.Call 的超时处理机制,使用 context 包实现,控制权交给用户,控制更为灵活。
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
}
可以这样使用:
ctx, _ := context.WithTimeout(context.Background(), time.Second)
var reply int
err := client.Call(ctx, "Foo.Sum", &Args{1, 2}, &reply)
...
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
called := make(chan struct{})
sent := make(chan struct{})
go func() {
err := req.svc.call(req.mtype, req.argv, req.respv)
called <- struct{}{}
if err != nil {
req.h.Error = err.Error()
server.sendResponse(cc, req.h, struct {
}{}, sending)
sent <- struct{}{}
return
}
server.sendResponse(cc, req.h, req.respv.Interface(), sending)
sent <- struct{}{}
}()
if timeout == 0 {
<-called
<-sent
return
}
select {
case <-time.After(timeout):
req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
server.sendResponse(cc, req.h, struct {
}{}, sending)
case <-called:
<-sent
}
}
这里需要确保 sendResponse 仅调用一次,因此将整个过程拆分为 called 和 sent 两个阶段,在这段代码中只会发生如下两种情况:
RPC 的消息格式与标准的 HTTP 协议并不兼容,在这种情况下,就需要一个协议的转换过程。HTTP 协议的 CONNECT 方法恰好提供了这个能力,CONNECT 一般用于代理服务。
假设浏览器与服务器之间的 HTTPS 通信都是加密的,浏览器通过代理服务器发起 HTTPS 请求时,由于请求的站点地址和端口号都是加密保存在 HTTPS 请求报文头中的,代理服务器如何知道往哪里发送请求呢?为了解决这个问题,浏览器通过 HTTP 明文形式向代理服务器发送一个 CONNECT 请求告诉代理服务器目标地址和端口,代理服务器接收到这个请求后,会在对应端口与目标站点建立一个 TCP 连接,连接建立成功后返回 HTTP 200 状态码告诉浏览器与该站点的加密通道已经完成。接下来代理服务器仅需透传浏览器和服务器之间的加密数据包即可,代理服务器无需解析 HTTPS 报文。
举一个简单例子:
CONNECT geektutu.com:443 HTTP/1.0
HTTP/1.0 200 Connection Established
事实上,这个过程其实是通过代理服务器将 HTTP 协议转换为 HTTPS 协议的过程。对 RPC 服务端来,需要做的是将 HTTP 协议转换为 RPC 协议,对客户端来说,需要新增通过 HTTP CONNECT 请求创建连接的逻辑。
那通信过程应该是这样的:
CONNECT 10.0.0.1:9999/_genrpc_ HTTP/1.0
HTTP/1.0 200 Connected to Gen RPC
const (
connected = "200 Connected to Gen RPC"
defaultRPCPath = "/_genprc_"
defaultDebugPath = "/debug/genrpc"
)
// ServeHTTP implements a http.Handler that answers RPC requests.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != "CONNECT" {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusMethodNotAllowed)
_, _ = io.WriteString(w, "405 must CONNECT\n")
return
}
// 将http请求劫持 获取连接
conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
return
}
_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
server.ServeConn(conn)
}
// HandleHTTP registers an HTTP handler for RPC messages on rpcPath.
// It is still necessary to invoke http.Serve(), typically in a go statement.
func (server *Server) HandleHTTP() {
http.Handle(defaultRPCPath, server)
}
// HandleHTTP is a convenient approach for default server to register HTTP handlers
func HandleHTTP() {
DefaultServer.HandleHTTP()
}
Hijack()可以将HTTP对应的TCP连接取出,连接在Hijack()之后,HTTP的相关操作就会受到影响,调用方需要负责去关闭连接:
type Hijacker interface {
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
func handle1(w http.ResponseWriter, r *http.Request) {
hj, _ := w.(http.Hijacker)
conn, buf, _ := hj.Hijack()
defer conn.Close()
buf.WriteString("hello world")
buf.Flush()
}
func handle2(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello world")
}
上面两个handle方法有什么区别呢?很简单,同样是http请求,返回的结果一个遵循http协议,一个不遵循。
➜ ~ curl -i http://localhost:9090/handle1
hello world%
➜ ~ curl -i http://localhost:9090/handle2
HTTP/1.1 200 OK
Date: Thu, 14 Jun 2018 07:51:31 GMT
Content-Length: 11
Content-Type: text/plain; charset=utf-8
hello world%
分别是以上两者的返回,可以看到,hijack之后的返回,虽然body是相同的,但是完全没有遵循http协议。
http包的源码:
func (c *conn) serve(ctx context.Context) {
...
serverHandler{c.server}.ServeHTTP(w, w.req)
w.cancelCtx()
if c.hijacked() {
return
}
w.finishRequest()
...
}
这是net/http包中的方法,也是http路由的核心方法。调用ServeHTTP方法,如果被hijack(劫持)了就直接return了,而一般的http请求会经过后边的finishRequest方法,加入headers等并关闭连接。
Hijack方法,一般在在创建连接阶段使用HTTP连接,后续自己完全处理connection。符合这样的使用场景的并不多,基于HTTP协议的rpc算一个,从HTTP升级到WebSocket也算一个。
go中自带的rpc可以直接复用http server处理请求的那一套流程去创建连接,连接创建完毕后再使用Hijack方法拿到连接。客户端通过向服务端发送method为connect的请求创建连接,创建成功后即可开始rpc调用。
websocket中的应用:websocket在创建连接的阶段与http使用相同的协议,而在后边的数据传输的过程中使用了他自己的协议,符合了Hijack的用途。通过serveWebSocket方法将HTTP协议升级到Websocket协议。
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
服务端已经能够接受 CONNECT 请求,并返回了 200 状态码 HTTP/1.0 200 Connected to Gee RPC,客户端要做的,发起 CONNECT 请求,检查返回状态码即可成功建立连接。
// NewHTTPClient new a Client instance via HTTP as transport protocol
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, error) {
_, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))
// Require successful HTTP response
// before switching to RPC protocol.
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
if err == nil && resp.Status == connected {
return NewClient(conn, opt)
}
if err == nil {
err = errors.New("unexpected HTTP response: " + resp.Status)
}
return nil, err
}
// DialHTTP connects to an HTTP RPC server at the specified network address
// listening on the default HTTP RPC path.
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
return dialTimeout(NewHTTPClient, network, address, opts...)
}
通过 HTTP CONNECT 请求建立连接之后,后续的通信过程就交给 NewClient 了。
为了简化调用,提供了一个统一入口 XDial
// XDial calls different functions to connect to a RPC server
// according the first parameter rpcAddr.
// rpcAddr is a general format (protocol@addr) to represent a rpc server
// eg, [email protected]:7001, [email protected]:9999, unix@/tmp/geerpc.sock
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
parts := strings.Split(rpcAddr, "@")
if len(parts) != 2 {
return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
}
protocol, addr := parts[0], parts[1]
switch protocol {
case "http":
return DialHTTP("tcp", addr, opts...)
default:
// tcp, unix or other transport protocol
return Dial(protocol, addr, opts...)
}
}
支持 HTTP 协议的好处在于,RPC 服务仅仅使用了监听端口的 /_genrpc 路径,在其他路径上我们可以提供诸如日志、统计等更为丰富的功能。接下来我们在 /debug/genrpc 上展示服务的调用统计视图。
package GenRpc
import (
"fmt"
"html/template"
"net/http"
)
const debugText = `
GenRPC Services
{{range .}}
Service {{.Name}}
Method Calls
{{range $name, $mtype := .Method}}
{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error
{{$mtype.NumCalls}}
{{end}}
{{end}}
`
var debug = template.Must(template.New("RPC debug").Parse(debugText))
type debugHTTP struct {
*Server
}
type debugService struct {
Name string
Method map[string]*methodType
}
// Runs at /debug/geerpc
func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// Build a sorted version of the data.
var services []debugService
server.serviceMap.Range(func(namei, svci interface{}) bool {
svc := svci.(*service)
services = append(services, debugService{
Name: namei.(string),
Method: svc.method,
})
return true
})
err := debug.Execute(w, services)
if err != nil {
_, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
}
}
将 debugHTTP 实例绑定到地址 /debug/genrpc:
func (server *Server) HandleHTTP() {
http.Handle(defaultRPCPath, server)
http.Handle(defaultDebugPath, debugHTTP{server})
log.Println("rpc server debug path:", defaultDebugPath)
}
定义服务:
type Foo int
type Args struct{ Num1, Num2 int }
func (f Foo) Sum(args Args, reply *int) error {
*reply = args.Num1 + args.Num2
return nil
}
func startServer(addrCh chan string) {
var foo Foo
l, _ := net.Listen("tcp", ":9999")
_ = geerpc.Register(&foo)
geerpc.HandleHTTP()
addrCh <- l.Addr().String()
_ = http.Serve(l, nil)
}
客户端:
func call(addrCh chan string) {
client, _ := geerpc.DialHTTP("tcp", <-addrCh)
defer func() { _ = client.Close() }()
time.Sleep(time.Second)
// send request & receive response
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
args := &Args{Num1: i, Num2: i * i}
var reply int
if err := client.Call(context.Background(), "Foo.Sum", args, &reply); err != nil {
log.Fatal("call Foo.Sum error:", err)
}
log.Printf("%d + %d = %d", args.Num1, args.Num2, reply)
}(i)
}
wg.Wait()
}
func main() {
log.SetFlags(0)
ch := make(chan string)
go call(ch)
startServer(ch)
}
假设有多个服务实例,每个实例提供相同的功能,为了提高整个系统的吞吐量,每个实例部署在不同的机器上。客户端可以选择任意一个实例进行调用,获取想要的结果。那如何选择呢?取决了负载均衡的策略。
负载均衡的前提是有多个服务实例,首先实现一个最基础的服务发现模块 Discovery:
定义 2 个类型:
package xclient
import (
"errors"
"math"
"math/rand"
"sync"
"time"
)
type SelectMode int
const (
RandomSelect SelectMode = iota // select randomly
RoundRobinSelect // select using Robbin algorithm
)
type Discovery interface {
Refresh() error // refresh from remote registry
Update(servers []string) error
Get(mode SelectMode) (string, error)
GetAll() ([]string, error)
}
紧接着,实现一个不需要注册中心,服务列表由手工维护的服务发现的结构体:MultiServersDiscovery:
// MultiServersDiscovery is a discovery for multi servers without a registry center
// user provides the server addresses explicitly instead
type MultiServersDiscovery struct {
// r 是一个产生随机数的实例,初始化时使用时间戳设定随机数种子,避免每次产生相同的随机数序列。
r *rand.Rand // generate random number
mu sync.RWMutex // protect following
servers []string
// index 记录 Round Robin 算法已经轮询到的位置,为了避免每次从 0 开始,初始化时随机设定一个值。
index int // record the selected position for robin algorithm
}
// NewMultiServerDiscovery creates a MultiServersDiscovery instance
func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery {
d := &MultiServersDiscovery{
servers: servers,
r: rand.New(rand.NewSource(time.Now().UnixNano())),
}
d.index = d.r.Intn(math.MaxInt32 - 1)
return d
}
实现 Discovery 接口:
var _ Discovery = (*MultiServersDiscovery)(nil)
// Refresh doesn't make sense for MultiServersDiscovery, so ignore it
func (d *MultiServersDiscovery) Refresh() error {
return nil
}
// Update the servers of discovery dynamically if needed
func (d *MultiServersDiscovery) Update(servers []string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.servers = servers
return nil
}
// Get a server according to mode
func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) {
d.mu.Lock()
defer d.mu.Unlock()
n := len(d.servers)
if n == 0 {
return "", errors.New("rpc discovery: no available servers")
}
switch mode {
case RandomSelect:
return d.servers[d.r.Intn(n)], nil
case RoundRobinSelect:
s := d.servers[d.index%n] // servers could be updated, so mode n to ensure safety
d.index = (d.index + 1) % n
return s, nil
default:
return "", errors.New("rpc discovery: not supported select mode")
}
}
// returns all servers in discovery
func (d *MultiServersDiscovery) GetAll() ([]string, error) {
d.mu.RLock()
defer d.mu.RUnlock()
// return a copy of d.servers
servers := make([]string, len(d.servers), len(d.servers))
copy(servers, d.servers)
return servers, nil
}
向用户暴露一个支持负载均衡的客户端 XClient。
type XClient struct {
d Discovery
mode SelectMode
opt *GenRpc.Option
mu sync.Mutex // protect following
clients map[string]*GenRpc.Client
}
var _ io.Closer = (*XClient)(nil)
func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient {
return &XClient{d: d, mode: mode, opt: opt, clients: make(map[string]*Client)}
}
func (xc *XClient) Close() error {
xc.mu.Lock()
defer xc.mu.Unlock()
for key, client := range xc.clients {
// I have no idea how to deal with error, just ignore it.
_ = client.Close()
delete(xc.clients, key)
}
return nil
}
XClient 的构造函数需要传入三个参数,服务发现实例 Discovery、负载均衡模式 SelectMode 以及协议选项 Option。为了尽量地复用已经创建好的 Socket 连接,使用 clients 保存创建成功的 Client 实例,并提供 Close 方法在结束后,关闭已经建立的连接。
实现客户端最基本的功能 Call:
func (xc *XClient) dial(rpcAddr string) (*Client, error) {
xc.mu.Lock()
defer xc.mu.Unlock()
client, ok := xc.clients[rpcAddr]
if ok && !client.IsAvailable() {
_ = client.Close()
delete(xc.clients, rpcAddr)
client = nil
}
if client == nil {
var err error
client, err = XDial(rpcAddr, xc.opt)
if err != nil {
return nil, err
}
xc.clients[rpcAddr] = client
}
return client, nil
}
func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
client, err := xc.dial(rpcAddr)
if err != nil {
return err
}
return client.Call(ctx, serviceMethod, args, reply)
}
// Call invokes the named function, waits for it to complete,
// and returns its error status.
// xc will choose a proper server.
func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
rpcAddr, err := xc.d.Get(xc.mode)
if err != nil {
return err
}
return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}
我们将复用 Client 的能力封装在方法 dial 中,dial 的处理逻辑如下:
另外,我们为 XClient 添加一个常用功能:Broadcast。
// Broadcast invokes the named function for every server registered in discovery
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
servers, err := xc.d.GetAll()
if err != nil {
return err
}
var wg sync.WaitGroup
var mu sync.Mutex // protect e and replyDone
var e error
replyDone := reply == nil // if reply is nil, don't need to set value
ctx, cancel := context.WithCancel(ctx)
for _, rpcAddr := range servers {
wg.Add(1)
go func(rpcAddr string) {
defer wg.Done()
var clonedReply interface{}
if reply != nil {
clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()
}
err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
mu.Lock()
if err != nil && e == nil {
e = err
cancel() // if any call failed, cancel unfinished calls
}
if err == nil && !replyDone {
reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
replyDone = true
}
mu.Unlock()
}(rpcAddr)
}
wg.Wait()
return e
}
如果没有注册中心,客户端需要硬编码服务端的地址,而且没有机制保证服务端是否处于可用状态。当然注册中心的功能还有很多,比如配置的动态同步、通知机制等。比较常用的注册中心有 etcd(推荐)、zookeeper、consul,一般比较出名的微服务或者 RPC 框架,这些主流的注册中心都是支持的。
首先定义 GeeRegistry 结构体,默认超时时间设置为 5 min,也就是说,任何注册的服务超过 5 min,即视为不可用状态。
// GeeRegistry is a simple register center, provide following functions.
// add a server and receive heartbeat to keep it alive.
// returns all alive servers and delete dead servers sync simultaneously.
type GeeRegistry struct {
timeout time.Duration
mu sync.Mutex // protect following
servers map[string]*ServerItem
}
type ServerItem struct {
Addr string
start time.Time
}
const (
defaultPath = "/_geerpc_/registry"
defaultTimeout = time.Minute * 5
)
// New create a registry instance with timeout setting
func New(timeout time.Duration) *GeeRegistry {
return &GeeRegistry{
servers: make(map[string]*ServerItem),
timeout: timeout,
}
}
var DefaultGeeRegister = New(defaultTimeout)
为 GeeRegistry 实现添加服务实例和返回服务列表的方法。
func (r *GeeRegistry) putServer(addr string) {
r.mu.Lock()
defer r.mu.Unlock()
s := r.servers[addr]
if s == nil {
r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()}
} else {
s.start = time.Now() // if exists, update start time to keep alive
}
}
func (r *GeeRegistry) aliveServers() []string {
r.mu.Lock()
defer r.mu.Unlock()
var alive []string
for addr, s := range r.servers {
if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) {
alive = append(alive, addr)
} else {
delete(r.servers, addr)
}
}
sort.Strings(alive)
return alive
}
为了实现上的简单,GenRegistry 采用 HTTP 协议提供服务,且所有的有用信息都承载在 HTTP Header 中。
Get:返回所有可用的服务列表,通过自定义字段 X-Geerpc-Servers 承载。
Post:添加服务实例或发送心跳,通过自定义字段 X-Geerpc-Server 承载。
// Runs at /_geerpc_/registry
func (r *GeeRegistry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case "GET":
// keep it simple, server is in req.Header
w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ","))
case "POST":
// keep it simple, server is in req.Header
addr := req.Header.Get("X-Geerpc-Server")
if addr == "" {
w.WriteHeader(http.StatusInternalServerError)
return
}
r.putServer(addr)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath
func (r *GeeRegistry) HandleHTTP(registryPath string) {
http.Handle(registryPath, r)
log.Println("rpc registry path:", registryPath)
}
func HandleHTTP() {
DefaultGeeRegister.HandleHTTP(defaultPath)
}
另外,提供 Heartbeat 方法,便于服务启动时定时向注册中心发送心跳,默认周期比注册中心设置的过期时间少 1 min。
// Heartbeat send a heartbeat message every once in a while
// it's a helper function for a server to register or send heartbeat
func Heartbeat(registry, addr string, duration time.Duration) {
if duration == 0 {
// make sure there is enough time to send heart beat
// before it's removed from registry
duration = defaultTimeout - time.Duration(1)*time.Minute
}
var err error
err = sendHeartbeat(registry, addr)
go func() {
// 定时发送心跳报文
t := time.NewTicker(duration)
for err == nil {
<-t.C
err = sendHeartbeat(registry, addr)
}
}()
}
func sendHeartbeat(registry, addr string) error {
log.Println(addr, "send heart beat to registry", registry)
httpClient := &http.Client{}
req, _ := http.NewRequest("POST", registry, nil)
req.Header.Set("X-Geerpc-Server", addr)
if _, err := httpClient.Do(req); err != nil {
log.Println("rpc server: heart beat err:", err)
return err
}
return nil
}
在 xclient 中对应实现 Discovery:
package xclient
type GeeRegistryDiscovery struct {
*MultiServersDiscovery
registry string
timeout time.Duration
lastUpdate time.Time
}
const defaultUpdateTimeout = time.Second * 10
func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery {
if timeout == 0 {
timeout = defaultUpdateTimeout
}
d := &GeeRegistryDiscovery{
MultiServersDiscovery: NewMultiServerDiscovery(make([]string, 0)),
registry: registerAddr,
timeout: timeout,
}
return d
}
实现 Update 和 Refresh 方法,超时重新获取的逻辑在 Refresh 中实现:
func (d *GeeRegistryDiscovery) Update(servers []string) error {
d.mu.Lock()
defer d.mu.Unlock()
d.servers = servers
d.lastUpdate = time.Now()
return nil
}
func (d *GeeRegistryDiscovery) Refresh() error {
d.mu.Lock()
defer d.mu.Unlock()
if d.lastUpdate.Add(d.timeout).After(time.Now()) {
return nil
}
log.Println("rpc registry: refresh servers from registry", d.registry)
resp, err := http.Get(d.registry)
if err != nil {
log.Println("rpc registry refresh err:", err)
return err
}
servers := strings.Split(resp.Header.Get("X-Geerpc-Servers"), ",")
d.servers = make([]string, 0, len(servers))
for _, server := range servers {
if strings.TrimSpace(server) != "" {
d.servers = append(d.servers, strings.TrimSpace(server))
}
}
d.lastUpdate = time.Now()
return nil
}
Get 和 GetAll 与 MultiServersDiscovery 相似,唯一的不同在于,GeeRegistryDiscovery 需要先调用 Refresh 确保服务列表没有过期。
func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) {
if err := d.Refresh(); err != nil {
return "", err
}
return d.MultiServersDiscovery.Get(mode)
}
func (d *GeeRegistryDiscovery) GetAll() ([]string, error) {
if err := d.Refresh(); err != nil {
return nil, err
}
return d.MultiServersDiscovery.GetAll()
}
GO语言著名websocket框架:gorilla/websocket,下载方式:go get github.com/gorilla/websocket
gorilla定义了Upgrade方法将HTTP升级为WebSocket,
...
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
...