golang工程——gRpc 拦截器及原理

oauth2认证与拦截器

类似java spring中的拦截器。gRpc也有拦截器的说法,拦截器可作用于客户端请求,服务端请求。对请求进行拦截,进行业务上的一些封装校验等,类似一个中间件的作用

拦截器类型

  • 一元请求拦截器
  • 流式请求拦截器
  • 链式拦截器(一个个调用对应处理函数)

使用场景:

拦截器可以从元数据获取一些认证进行进行校验。

服务端拦截器

拦截器定义

interceptor.go

package server

import (
    "context"
    "errors"
    "fmt"
    "google.golang.org/grpc"
    "google.golang.org/grpc/metadata"
    "strings"
)

// UnaryInterceptor 一元请求拦截器
func UnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {

    fmt.Println("server UnaryInterceptor:", info)

    if err := oauth2Valid(ctx); err != nil {
        return nil, err
    }
    return handler(ctx, req)
}

// StreamInterceptor 流式拦截器
func StreamInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    fmt.Println("server StreamInterceptor")
    fmt.Println(info)

    if err := oauth2Valid(ss.Context()); err != nil {
        return err
    }

    return handler(srv, ss)
}

// oauth2认证,从上下文获取请求元数据
func oauth2Valid(ctx context.Context) error {
    md, ok := metadata.FromIncomingContext(ctx)

    if !ok {
        return errors.New("元数据获取失败, 身份认证失败")
    }
    authorization := md["authorization"]
    if !valid(authorization) {
        return errors.New("令牌校验不通过, 身份认证失败")
    }
    return nil
}

func valid(authorization []string) bool {
    if len(authorization) < 1 {
        return false
    }
    token := strings.TrimPrefix(authorization[0], "Bearer ")
    return token == fetchToken()
}

func fetchToken() string {
    return "some-secret-token"
}

拦截器配置

package main

import (
    "flag"
    "fmt"
    "google.golang.org/grpc"
    "grpc/echo"
    "grpc/echo-server-practice/server"
    "log"
    "net"
)

var (
    port = flag.Int("port", 50053, "port")
)

func getOptions() (opts []grpc.ServerOption) {
    opts = make([]grpc.ServerOption, 0)
    opts = append(opts, server.GetMTlsOpt())
    // 附加一个拦截器,还有链式拦截器 ChainInterceptor
    opts = append(opts, grpc.UnaryInterceptor(server.UnaryInterceptor))
    opts = append(opts, grpc.StreamInterceptor(server.StreamInterceptor))
    return opts
}

func main() {
    flag.Parse()

    lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))

    if err != nil {
        log.Fatal(err)
    }

    // grpc server
    s := grpc.NewServer(getOptions()...)
    ......
}
客户端拦截器

客户端拦截器中简单实现利用oauth2做认证

package client

import (
    "fmt"
    "golang.org/x/net/context"
    "google.golang.org/grpc"
    "grpc/echo-client/client"
)



// UnaryInterceptor 客户端一元请求拦截器

func UnaryInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    fmt.Println("server UnaryInterceptor: ", req)
    // 其实就proto生成的一元请求里的invoke差不多
    return invoker(ctx, method, req, reply, cc, opts...)
}
// StreamInterceptor 客户端流式拦截器
func StreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    fmt.Println("client streamInterceptor")

    // 和流式请求里的NewStream也是一样的
    return streamer(ctx, desc, cc, method, opts...)
}

如果客户端启动没配置,可以在拦截器中添加。

客户端 main.go

package main

import (
    "flag"
    "google.golang.org/grpc"
    "grpc/echo"
    "grpc/echo-client-practice/client"
    "log"
)

var (
    addr = flag.String("host", "localhost:50053", "")
)


func getDiaOption() []grpc.DialOption {
    dialOptions := make([]grpc.DialOption, 0)
    dialOptions = append(dialOptions, client.GetMTlsOpt())

    dialOptions = append(dialOptions, grpc.WithUnaryInterceptor(client.UnaryInterceptor))
    dialOptions = append(dialOptions, grpc.WithStreamInterceptor(client.StreamInterceptor))
    dialOptions = append(dialOptions, client.GetAuth(client.FetchToken()))

    return dialOptions
}

func main() {
    flag.Parse()

    conn, err := grpc.Dial(*addr, getDiaOption()...)

    if err != nil {
        log.Fatal(err)
    }

    defer conn.Close()
    # 下面是伪代码
    c := echo.NewYourClient(conn)

    c.CallYourRpc(your_request)
}

链式拦截器执行核心原理
func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
 n := len(interceptors)
 if n > 1 {
  lastI := n - 1
  return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
   var (
    chainHandler grpc.UnaryInvoker
    curI         int
   )

   chainHandler = func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error {
    if curI == lastI {
     return invoker(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentOpts...)
    }
    curI++
    err := interceptors[curI](currentCtx, currentMethod, currentReq, currentRepl, currentConn, chainHandler, currentOpts...)
    curI--
    return err
   }

   return interceptors[0](ctx, method, req, reply, cc, chainHandler, opts...)
  }
 }
    ...
}

当拦截器数量大于 1 时,从 interceptors[1] 开始递归,每一个递归的拦截器 interceptors[i] 会不断地执行,最后才真正的去执行 handler 方法。

你可能感兴趣的:(golang,后端)