gRPC 的拦截器(interceptor)类似于 Java Spring 中的拦截器,可作用于客户端请求和服务端请求:对请求进行拦截,完成业务上的一些封装、校验等工作,起到类似中间件的作用。
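拦截器类型:服务端有一元拦截器(grpc.UnaryServerInterceptor)和流式拦截器(grpc.StreamServerInterceptor);客户端对应有一元拦截器(grpc.UnaryClientInterceptor)和流式拦截器(grpc.StreamClientInterceptor)。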
使用场景:
拦截器可以从请求元数据(metadata)中获取认证信息(例如令牌)并进行校验。
拦截器定义
interceptor.go
package server
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)
// UnaryInterceptor is the server-side interceptor for unary RPCs. It prints
// the call info, rejects the request with the error from oauth2Valid when the
// incoming metadata fails the token check, and otherwise passes the request
// on to the real handler and returns its result.
func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
	fmt.Println("server UnaryInterceptor:", info)

	if authErr := oauth2Valid(ctx); authErr != nil {
		return nil, authErr
	}

	resp, err = handler(ctx, req)
	return resp, err
}
// StreamInterceptor is the server-side interceptor for streaming RPCs. It
// prints the stream info, authenticates using the stream's context, and only
// hands the stream to the real handler when authentication succeeds.
func StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	fmt.Println("server StreamInterceptor")
	fmt.Println(info)

	authErr := oauth2Valid(ss.Context())
	if authErr == nil {
		return handler(srv, ss)
	}
	return authErr
}
// oauth2Valid authenticates a request from the incoming gRPC metadata carried
// by ctx. It returns an error when ctx carries no incoming metadata, or when
// valid rejects the "authorization" values; otherwise it returns nil.
func oauth2Valid(ctx context.Context) error {
	md, found := metadata.FromIncomingContext(ctx)
	switch {
	case !found:
		return errors.New("元数据获取失败, 身份认证失败")
	case !valid(md.Get("authorization")):
		return errors.New("令牌校验不通过, 身份认证失败")
	default:
		return nil
	}
}
// valid reports whether the "authorization" metadata values carry the token
// returned by fetchToken. Only the first value is inspected, and a leading
// "Bearer " prefix is stripped before comparing. An empty or nil slice is
// rejected.
func valid(authorization []string) bool {
	if len(authorization) == 0 {
		return false
	}
	token := strings.TrimPrefix(authorization[0], "Bearer ")
	// Compare in constant time. A plain == can return early at the first
	// mismatching byte, which lets an attacker recover the secret through
	// response timing.
	return subtle.ConstantTimeCompare([]byte(token), []byte(fetchToken())) == 1
}
// fetchToken returns the token the server expects clients to present. The
// value is hard-coded for this demo; a real deployment would load it from
// configuration or a secret store.
func fetchToken() string {
	const expected = "some-secret-token"
	return expected
}
服务端配置拦截器(main.go)
package main
import (
"flag"
"fmt"
"google.golang.org/grpc"
"grpc/echo"
"grpc/echo-server-practice/server"
"log"
"net"
)
// port is the TCP port the server listens on, set with the -port flag
// (default 50053).
var port = flag.Int("port", 50053, "port")
// getOptions assembles the gRPC server options, in order: the option from
// server.GetMTlsOpt, then the unary and stream interceptors that perform the
// token check.
//
// grpc.UnaryInterceptor and grpc.StreamInterceptor each install exactly one
// interceptor. To install several, use grpc.ChainUnaryInterceptor and
// grpc.ChainStreamInterceptor instead.
func getOptions() (opts []grpc.ServerOption) {
	opts = []grpc.ServerOption{
		server.GetMTlsOpt(),
		grpc.UnaryInterceptor(server.UnaryInterceptor),
		grpc.StreamInterceptor(server.StreamInterceptor),
	}
	return opts
}
// main parses the -port flag, opens a TCP listener on that port, and builds
// the gRPC server with the options returned by getOptions (which include the
// server interceptors).
func main() {
	flag.Parse()
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
	if err != nil {
		log.Fatal(err)
	}
	// Build the gRPC server; the interceptors take effect through these options.
	s := grpc.NewServer(getOptions()...)
	// NOTE(review): the rest of main is elided in this excerpt; it presumably
	// registers the service on s and calls s.Serve(lis) — confirm.
	......
}
客户端拦截器定义:这里的拦截器只打印日志后直接放行;oauth2 令牌由下文客户端 main.go 中的 client.GetAuth(...) 拨号选项提供。
package client
import (
"fmt"
"golang.org/x/net/context"
"google.golang.org/grpc"
"grpc/echo-client/client"
)
// UnaryInterceptor is the client-side interceptor for unary RPCs. It prints
// the outgoing request and then forwards the call unchanged through invoker.
// Calling invoker does essentially what the generated unary client method
// does with cc.Invoke.
func UnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
	// This interceptor runs on the client, so label the log line as such
	// (matching the "client ..." label used by the stream interceptor).
	fmt.Println("client UnaryInterceptor: ", req)
	return invoker(ctx, method, req, reply, cc, opts...)
}
// StreamInterceptor is the client-side interceptor for streaming RPCs. It
// prints a marker line and then opens the stream through streamer. This is
// the same call the generated client code makes via NewStream.
func StreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
	fmt.Println("client streamInterceptor")
	stream, err := streamer(ctx, desc, cc, method, opts...)
	return stream, err
}
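如果客户端启动时没有通过拨号选项配置认证信息,也可以在拦截器中添加。例如,在调用 invoker 之前执行 ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)。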
客户端 main.go
package main
import (
"flag"
"google.golang.org/grpc"
"grpc/echo"
"grpc/echo-client-practice/client"
"log"
)
// addr is the server address to dial, set with the -host flag
// (default "localhost:50053").
var addr = flag.String("host", "localhost:50053", "")
// getDiaOption returns the client dial options, in order:
//   - the option from client.GetMTlsOpt;
//   - the unary and stream client interceptors;
//   - the option built by client.GetAuth from client.FetchToken.
//
// grpc.WithUnaryInterceptor and grpc.WithStreamInterceptor each install a
// single interceptor. Use grpc.WithChainUnaryInterceptor and
// grpc.WithChainStreamInterceptor for several.
func getDiaOption() []grpc.DialOption {
	return []grpc.DialOption{
		client.GetMTlsOpt(),
		grpc.WithUnaryInterceptor(client.UnaryInterceptor),
		grpc.WithStreamInterceptor(client.StreamInterceptor),
		client.GetAuth(client.FetchToken()),
	}
}
func main() {
flag.Parse()
conn, err := grpc.Dial(*addr, getDiaOption()...)
if err != nil {
log.Fatal(err)
}
defer conn.Close()
# 下面是伪代码
c := echo.NewYourClient(conn)
c.CallYourRpc(your_request)
}
func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
n := len(interceptors)
if n > 1 {
lastI := n - 1
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
var (
chainHandler grpc.UnaryInvoker
curI int
)
chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
if curI == lastI {
return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)
}
curI++
err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)
curI--
return err
}
return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)
}
}
...
}
当拦截器数量大于 1 时,先执行 interceptors[0],并把一个包装过的 invoker 传给它;拦截器调用这个 invoker 时会进入 interceptors[1],依此递归:每个 interceptors[i] 都通过传入的 invoker 调用下一个拦截器,直到最后一个拦截器调用 invoker 时,才真正执行原始的 invoker 方法发起 RPC。