对比原教程,这里使用context来处理子协程的泄露问题。
超时处理是 RPC 框架一个比较基本的能力,如果缺少超时处理机制,无论是服务端还是客户端都容易因为网络或其他错误导致挂死,资源耗尽,这些问题的出现大大地降低了服务的可用性。因此,我们需要在 RPC 框架中加入超时处理的能力。
纵观整个远程调用的过程,需要客户端处理超时的地方有:
与服务端建立连接,导致的超时
发送请求到服务端,写报文导致的超时
等待服务端处理时,等待处理导致的超时(比如服务端已挂死,迟迟不响应)
从服务端接收响应时,读报文导致的超时
需要服务端处理超时的地方有:
读取客户端请求报文时,读报文导致的超时
发送响应报文时,写报文导致的超时
调用映射服务的方法时,处理报文导致的超时
其RPC 在 3 个地方添加超时处理机制。分别是:
Client.Call()
整个过程导致的超时(不仅包含发送报文,还包括等待处理,接收报文所有阶段)Server.handleRequest
超时。为了实现简单,把一些超时时间设定放在Option结构体中。有两个超时时间,连接超时ConnectTimeout,服务端处理超时HandleTimeout。
type Option struct {
MagicNumber int // MagicNumber marks this's a geerpc request
CodecType codec.CodeType // client may choose different Codec to encode body
ConnectTimeout time.Duration //0 表示没有限制
HandleTimeout time.Duration
}
var DefaultOption = &Option{
MagicNumber: MagicNumber,
CodecType: codec.GobType,
ConnectTimeout: time.Second * 10, //默认连接超时是10s
}
客户端连接时候是使用Dail方法,那我们就为 Dial 添加一层超时处理的外壳即可。
重点在dialTimeout函数。
net.Dial
替换为 net.DialTimeout
,如果连接创建超时,将返回错误。time.After()
信道先接收到消息,则说明 NewClient 执行超时,返回错误。需要讲下newClientFunc,为什么需要这个类型呢,大家应该明白就是直接是用NewClient函数就行的,为什么还要多此一举,要从函数参数中把该函数入参呢,为什么不直接在代码里把f(conn,opt)就写成NewClient(conn,opt)呢。
这是为了后面方便的,下一节会支持HTTP协议的,那就会有HTTP连接,而现在的是TCP连接,而HTTP连接对比TCP连接还有些操作的。这样我们把这个做成参数可以入参,这样我们就可以复用dialTimeout,我们把建立HTTP连接的函数传递给dialTimeout,不再需要为HTTP连接的再重新写dialTimeout。而这也方便了后面对该函数的测试。
type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error)
type clientResult struct {
client *Client
err error
}
func dialTimeout(f newClientFunc, network, address string, opt *Option) (client *Client, err error) {
//超时连接检测
conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
if err != nil {
return nil, err
}
//设置超时时间的情况
ch := make(chan clientResult)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func(ctx context.Context) {
client, err = f(conn, opt) //在这一节,f(conn, opt)就是NewClient(conn, opt)
select {
case <-ctx.Done():
return
default:
ch <- clientResult{client: client, err: err}
}
}(ctx)
if opt.ConnectTimeout == 0 {
result := <-ch
return result.client, result.err
}
select {
case <-time.After(opt.ConnectTimeout):
cancel() //超时通知子协程结束退出
return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
case result := <-ch:
return result.client, result.err
}
}
func Dail(network, address string, opts ...*Option) (client *Client, err error) {
opt, err := parseOptions(opts...)
if err != nil {
return nil, err
}
return dialTimeout(NewClient, network, address, opt)
}
Client.Call 的超时处理机制,使用 context 包实现,控制权交给用户,控制更为灵活。
这里的超时处理,不止是客户端发送给服务端所需的时间,还包括了客户端等待服务端发送回复的这段时间。即是这个超时时间是要等接收完服务端的回复信息。
代码case call := <-call.Done表示要等待收到服务端的回复信息,要是这时候先执行case <-ctx.Done(),那就表示超时了。
所以后面的测试Client.Call 超时要留意其超时。
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply any) error {
call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
select {
case <-ctx.Done():
client.removeCall(call.Seq)
return errors.New("rpc client: call failed: " + ctx.Err().Error())
case call := <-call.Done:
return call.Error
}
//之前的写法
// call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
// return call.Error
}
用户可以使用 context.WithTimeout
创建具备超时检测能力的 context 对象来控制。
var reply int
ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
client.Call(ctx, "My.Sum", args, &reply);
和客户端连接超时处理相似,也是使用context来控制子协程的退出。
开启一个新协程去执行call方法。通道called用来表示消息处理发送是否完毕。超时没有限制时,主协程就会阻塞在timeout==0的<-called中,等到发送完毕后,通道called有数据了,子协程结束,主协程也解除阻塞,退出。
超时有限制情况,假如超时了,就会执行在 case <-time.After(timeout)处调用cancel(),那子协程中的case <-ctx.Done()就会处理,发送超时处理的信息给客户端并退出子协程。
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
called := make(chan struct{})
go func(ctx context.Context) {
err := req.svc.call(req.mtype, req.argv, req.replyv)
select {
case <-ctx.Done():
req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
server.sendResponse(cc, req.h, invalidRequest, sending)
default:
if err != nil {
fmt.Println("call err:", err)
req.h.Error = err.Error()
server.sendResponse(cc, req.h, invalidRequest, sending)
} else {
server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}
called <- struct{}{}
}
}(ctx)
if timeout == 0 {
<-called
return
}
select {
case <-time.After(timeout):
cancel()
case <-called:
return
}
}
上一节的handleRequest中是没有timeout这个参数的,所以在使用该方法时候,需要加上timeout。需要修改下(Server).Serveconn方法。
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
//其余的没有改变
server.servCode(f(conn), &opt) //之前是server.servCode(f(conn))
}
//使用handleRequest的地方
func (server *Server) servCode(cc codec.Codec, opt *Option) {
//...................
for {
go server.handleRequest(cc, req, sending, &wg, opt.HandleTimeout)
}
}
第一个测试用例,用于测试连接超时。NewClient 函数耗时 3s,ConnectionTimeout 分别设置为 1s 和 0 两种场景。
这里newClientFunc类型就派上用场了,这里就可以设置该函数耗时,方便测试。
func TestClient_dialTimeout(t *testing.T) {
t.Parallel() //表示该测试将与(并且仅与)其他并行测试并行运行。
l, _ := net.Listen("tcp", "localhost:10000")
f := func(conn net.Conn, opt *Option) (*Client, error) {
conn.Close()
time.Sleep(time.Second * 2)
return nil, nil
}
//命令行执行 go test -run TestClient_dialTimeout/timeout 测试
t.Run("timeout", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
_assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
})
//命令行执行 go test -run TestClient_dialTimeout/0 测试
t.Run("0", func(t *testing.T) {
_, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
_assert(err == nil, "0 means no limit")
})
}
func _assert(condition bool, msg string, v ...interface{}) {
if !condition {
panic(fmt.Sprintf("assertion failed: "+msg, v...))
}
}
第二个测试用例,用于测试处理超时。Bar.Timeout
耗时 2s。
场景一:客户端设置超时时间为 1s,服务端无限制(这个就是Client.Call超时的情况,需要注意)
场景二:服务端设置超时时间为1s,客户端无限制。
type Bar int
func (b *Bar) Timeout(argv int, reply *int) error {
time.Sleep(time.Second * 3) // 模拟3s的工作
return nil
}
func startServer(addr chan string) {
var b Bar
_ = Register(&b)
l, _ := net.Listen("tcp", "localhost:10000")
addr <- l.Addr().String()
Accept(l)
}
func TestClient_Call(t *testing.T) {
t.Parallel()
addrCh := make(chan string)
go startServer(addrCh)
addr := <-addrCh
time.Sleep(time.Second)
t.Run("client_timeout", func(t *testing.T) {
client, _ := Dail("tcp", addr)
ctx, _ := context.WithTimeout(context.Background(), time.Second*1)
var reply int
err := client.Call(ctx, "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
})
t.Run("server_hander_timeout", func(t *testing.T) {
client, _ := Dail("tcp", addr, &Option{
HandleTimeout: time.Second,
})
var reply int
err := client.Call(context.Background(), "Bar.Timeout", 1, &reply)
_assert(err != nil && strings.Contains(err.Error(), "handle timeout"), "expect a timeout error")
})
}
完整代码:https://githubfast.com/liwook/Go-projects/tree/main/geerpc/4-timeout