Gin框架学习(三)

Gin框架基础

  • 学习思维导图
    • 提高篇
      • 自定义中间件
        • Logger日志
        • JWT的认证中间件(鉴权)
        • CORS跨域
      • 登录中间件
  • 后记

学习思维导图

(图:Gin框架学习(三)学习思维导图)

提高篇

自定义中间件

Logger日志

package main

import (
    "github.com/gin-gonic/gin"
)
// hello answers the request with HTTP status 200 and the JSON string "ok".
func hello(c *gin.Context) {
	const okStatus = 200
	c.JSON(okStatus, "ok")
}
// main builds a bare Gin engine (no default middleware), attaches Gin's
// built-in request logger and serves GET /hello.
func main() {
	engine := gin.New()
	// The logger prints one line per request; start the server and try
	//   curl http://localhost:8080/hello
	engine.Use(gin.Logger())
	engine.GET("/hello", hello)
	engine.Run()
}

JWT的认证中间件(鉴权)

// JWTAuth returns a Gin middleware that authenticates a request by the JWT
// carried in its "token" HTTP header.
//
// A request with no token, an expired token, or a token that fails to parse is
// aborted with the JSON body {"status": -1, "msg": ..., "data": nil}. The HTTP
// status is 200 in every case; the failure is reported through "status".
// On success the parsed claims are stored in the context under "claims" so
// that later handlers can read them.
func JWTAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		// reject reports msg to the client and stops the handler chain.
		reject := func(msg string) {
			c.JSON(http.StatusOK, gin.H{
				"status": -1,
				"msg":    msg,
				"data":   nil,
			})
			c.Abort()
		}

		// The token is taken from the "token" request header.
		token := c.Request.Header.Get("token")
		if token == "" {
			reject("请求未携带token,无权限访问")
			return
		}
		// The token is a bearer credential: record that one arrived, but
		// never write its value to the log.
		log.Print("JWTAuth: token received (value redacted)")

		// Parse the token and extract its payload (claims).
		j := NewJWT()
		claims, err := j.ParserToken(token)
		if err != nil {
			if err == TokenExpired {
				reject("token授权已过期,请重新申请授权")
				return
			}
			// Any other parse/validation failure is reported verbatim.
			reject(err.Error())
			return
		}

		// Expose the claims to the handlers that run after this middleware.
		c.Set("claims", claims)
	}
}

CORS跨域

// CORSMiddleware returns a Gin middleware that stamps a fixed set of CORS
// headers on every response and answers preflight (OPTIONS) requests with
// status 204 without running the remaining handlers.
//
// NOTE(review): browsers do not honour Access-Control-Allow-Credentials when
// Access-Control-Allow-Origin is the "*" wildcard, so credentialed
// cross-origin requests will still be blocked as written — confirm whether a
// concrete, allow-listed origin should be returned instead.
func CORSMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		headers := c.Writer.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Credentials", "true")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
		headers.Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")

		// Everything except a preflight continues down the chain.
		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}
		c.AbortWithStatus(204)
	}
}

登录中间件

main.go

package main

import (
    "fmt"
    "log"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/zhuge20100104/gin_session/gsession"
)

// main starts a Gin server whose requests carry a Redis-backed session and
// exposes GET /incr, which keeps a per-session counter: the first request of a
// session reports 0 and every later request reports the previous value plus 1.
func main() {
	r := gin.Default()

	mgrObj, err := gsession.CreateSessionMgr(gsession.Redis, "localhost:6379")
	if err != nil {
		// log.Fatalf exits the process, so nothing after it is needed.
		log.Fatalf("Create manager obj failed, err: %v\n", err)
	}

	sm := gsession.SessionMiddleware(mgrObj, gsession.Options{
		Path:     "/",
		Domain:   "127.0.0.1",
		MaxAge:   120,
		Secure:   false,
		HttpOnly: true,
	})
	r.Use(sm)

	r.GET("/incr", func(c *gin.Context) {
		// The session middleware stores the session under this context key.
		session := c.MustGet(gsession.SessionContextName).(gsession.Session)
		fmt.Printf("%#v\n", session)

		count := 0
		v, err := session.Get("count")
		if err != nil {
			// No counter yet: start from zero.
			log.Printf("get count from session failed, err: %v\n", err)
		} else if n, ok := v.(int); ok {
			count = n + 1
		} else {
			// A value of another type must not crash the handler.
			log.Printf("count in session has unexpected type %T, resetting to 0\n", v)
		}

		session.Set("count", count)
		session.Save()
		c.String(http.StatusOK, "count:%v", count)
	})

	if err := r.Run(); err != nil {
		log.Fatalf("run server failed, err: %v\n", err)
	}
}

session.go

package gsession

import (
    "fmt"
    "log"

    "github.com/gin-gonic/gin"
)

// SessionMgrType names a session storage backend.
type SessionMgrType string

// Names under which the session is carried.
const (
	// SessionCookieName is the name of the cookie that holds the session ID.
	SessionCookieName = "session_id"
	// SessionContextName is the context key under which the Session object
	// is stored.
	SessionContextName = "session"
)

// Storage backends accepted by CreateSessionMgr.
const (
	Memory SessionMgrType = "memory"
	Redis  SessionMgrType = "redis"
)

// Session is the per-client key/value store made available to handlers.
type Session interface {
	// ID returns the session's identifier.
	ID() string
	// Load fills the session data from Redis.
	Load() error
	// Get returns the value stored under key.
	Get(key string) (value interface{}, err error)
	// Set stores value under key.
	Set(key string, value interface{})
	// Del removes the value stored under key.
	Del(key string)
	// Save writes the session data to Redis.
	Save()
	// SetExpired sets the expiry of the data in Redis; it has no effect for
	// the in-memory implementation.
	SetExpired(expired int)
}

// SessionMgr creates, looks up and discards Session objects.
type SessionMgr interface {
	// Init prepares the manager; for the Redis backend it sets up the
	// database connection to addr.
	Init(addr string, options ...string) error
	// GetSession returns the already-initialised Session for sessionID.
	GetSession(sessionID string) (Session, error)
	// CreateSession makes a new Session.
	CreateSession() Session
	// Clear discards the Session identified by sessionID.
	Clear(sessionID string)
}

// Options holds the attributes of the session cookie.
type Options struct {
	Path, Domain string

	// MaxAge is the lifetime of the session-ID cookie.
	//   MaxAge == 0: no 'Max-Age' attribute is specified.
	//   MaxAge  < 0: delete the cookie now, equivalent to 'Max-Age: 0'.
	//   MaxAge  > 0: 'Max-Age' attribute present, given in seconds.
	MaxAge int

	Secure, HttpOnly bool
}

// CreateSessionMgr builds the session manager for the backend called name and
// initialises it with addr and options (passed through to the manager's Init).
//
// It returns an error when name is not a known backend or when Init fails; in
// both cases the returned manager is nil, so a caller can never use a manager
// that was not successfully initialised.
func CreateSessionMgr(name SessionMgrType, addr string, options ...string) (sm SessionMgr, err error) {
	switch name {
	case Memory:
		sm = NewMemSessionMgr()
	case Redis:
		sm = NewRedisSessionMgr()
	default:
		return nil, fmt.Errorf("unsupported session manager type %q", name)
	}
	if err = sm.Init(addr, options...); err != nil {
		return nil, err
	}
	return sm, nil
}

// SessionMiddleware returns a Gin middleware that attaches a Session to every
// request.
//
// The session ID is read from the SessionCookieName cookie. When the cookie is
// missing, or the manager cannot return a session for that ID, a new session is
// created. The session is stored in the context under SessionContextName, its
// expiry is set to options.MaxAge, and the cookie is (re)issued with the given
// options. After the remaining handlers have run, sm.Clear is called for the
// session ID.
//
// NOTE(review): with the in-memory manager, Clear removes the session from the
// only place it is stored, so its data does not survive to the next request —
// confirm this is intended.
func SessionMiddleware(sm SessionMgr, options Options) gin.HandlerFunc {
	return func(c *gin.Context) {
		var (
			session Session
			needNew bool
		)

		sessionID, err := c.Cookie(SessionCookieName)
		if err == nil {
			log.Printf("SessionId: %v\n", sessionID)
			session, err = sm.GetSession(sessionID)
			if err != nil {
				log.Printf("Get session by %s failed, err: %v\n", sessionID, err)
				needNew = true
			}
		} else {
			log.Printf("get session_id from cookie failed, err:%v\n", err)
			needNew = true
		}

		// No usable session was found: start a fresh one.
		if needNew {
			session = sm.CreateSession()
			sessionID = session.ID()
		}

		session.SetExpired(options.MaxAge)
		c.Set(SessionContextName, session)
		c.SetCookie(
			SessionCookieName, sessionID,
			options.MaxAge, options.Path, options.Domain,
			options.Secure, options.HttpOnly,
		)

		// Runs once the downstream handlers have returned.
		defer sm.Clear(sessionID)
		c.Next()
	}
}

memory.go

package gsession

import (
    "fmt"
    "sync"
    uuid "github.com/satori/go.uuid"
)
// memSession is the Session implementation that keeps its data in process
// memory.
type memSession struct {
	// id is the session's unique identifier.
	id string
	// data holds the session's key/value pairs.
	data map[string]interface{}
	// expired is the expiry handed to SetExpired.
	expired int
	// rwLock guards data.
	rwLock sync.RWMutex
}

// NewMemSession returns an empty in-memory session identified by id.
func NewMemSession(id string) *memSession {
	s := new(memSession)
	s.id = id
	s.data = make(map[string]interface{}, 8)
	return s
}

// ID returns the session's identifier.
func (s *memSession) ID() string {
	return s.id
}

// Load does nothing for an in-memory session and always returns nil.
func (s *memSession) Load() (err error) {
	return nil
}

// Get returns the value stored under key, or an error when key is absent.
func (s *memSession) Get(key string) (value interface{}, err error) {
	s.rwLock.RLock()
	defer s.rwLock.RUnlock()

	if stored, found := s.data[key]; found {
		return stored, nil
	}
	return nil, fmt.Errorf("Invalid key")
}

// Set stores value under key.
func (s *memSession) Set(key string, value interface{}) {
	s.rwLock.Lock()
	s.data[key] = value
	s.rwLock.Unlock()
}

// Del removes key; removing an absent key is a no-op.
func (s *memSession) Del(key string) {
	s.rwLock.Lock()
	delete(s.data, key)
	s.rwLock.Unlock()
}

// Save does nothing for an in-memory session.
func (s *memSession) Save() {}

// SetExpired records the expiry value.
func (s *memSession) SetExpired(expired int) {
	s.expired = expired
}
// MemSessionMgr is the SessionMgr that keeps every session in process memory.
type MemSessionMgr struct {
	// session maps a session ID to its Session.
	session map[string]Session
	// rwLock guards session; every method that touches the map takes it.
	rwLock sync.RWMutex
}

// NewMemSessionMgr returns an empty in-memory session manager.
func NewMemSessionMgr() *MemSessionMgr {
	return &MemSessionMgr{
		session: make(map[string]Session, 1024),
	}
}

// Init does nothing for the in-memory manager; addr and options are ignored
// and the result is always nil.
func (m *MemSessionMgr) Init(addr string, options ...string) (err error) {
	return nil
}

// GetSession returns the session registered under sessionID, or an error when
// no such session exists.
func (m *MemSessionMgr) GetSession(sessionID string) (sd Session, err error) {
	m.rwLock.RLock()
	defer m.rwLock.RUnlock()
	sd, ok := m.session[sessionID]
	if !ok {
		err = fmt.Errorf("Invalid session id")
		return
	}
	return
}

// CreateSession makes a session with a fresh UUID as its ID, registers it and
// returns it.
func (m *MemSessionMgr) CreateSession() (sd Session) {
	sessionID := uuid.NewV4().String()
	sd = NewMemSession(sessionID)

	// The map is shared by all requests, so the insert must hold the write
	// lock just like Clear does; an unguarded write races with GetSession.
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	m.session[sd.ID()] = sd
	return
}

// Clear removes the session registered under sessionID, if any.
func (m *MemSessionMgr) Clear(sessionID string) {
	m.rwLock.Lock()
	defer m.rwLock.Unlock()
	delete(m.session, sessionID)
}

redis.go

package gsession
import (
    "bytes"
    "encoding/gob"
    "fmt"
    "log"
    "strconv"
    "sync"
    "time"
    "github.com/go-redis/redis"
    uuid "github.com/satori/go.uuid"
)
// redisSession is the Session implementation that persists its data in Redis
// as a gob-encoded map stored under the session ID.
type redisSession struct {
	// id is the session ID and also the Redis key.
	id string
	// data holds the session's key/value pairs.
	data map[string]interface{}
	// modifyFlag is true when data has changes not yet written to Redis.
	modifyFlag bool
	// expired is the Redis expiry of the data, in seconds.
	expired int
	// rwLock guards data and modifyFlag.
	rwLock sync.RWMutex
	client *redis.Client
}

// NewRedisSession returns an empty session identified by id that reads from
// and writes to client.
func NewRedisSession(id string, client *redis.Client) (session Session) {
	session = &redisSession{
		id:     id,
		data:   make(map[string]interface{}, 8),
		client: client,
	}
	return
}

// ID returns the session's identifier.
func (r *redisSession) ID() string {
	return r.id
}

// Load reads the value stored under the session ID in Redis and gob-decodes it
// into the session data. It returns the Redis or decode error, if any.
func (r *redisSession) Load() (err error) {
	data, err := r.client.Get(r.id).Bytes()
	if err != nil {
		log.Printf("get session data from redis by %s failed, err: %v\n", r.id, err)
		return
	}

	// data is shared with Get/Set/Del, so decode under the write lock.
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	dec := gob.NewDecoder(bytes.NewBuffer(data))
	err = dec.Decode(&r.data)
	if err != nil {
		log.Printf("gob decode session data failed, err: %v\n", err)
		return
	}
	return
}

// Get returns the value stored under key, or an error when key is absent.
func (r *redisSession) Get(key string) (value interface{}, err error) {
	r.rwLock.RLock()
	defer r.rwLock.RUnlock()
	value, ok := r.data[key]
	if !ok {
		err = fmt.Errorf("invalid key")
		return
	}
	return
}

// Set stores value under key and marks the session as modified.
func (r *redisSession) Set(key string, value interface{}) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	r.data[key] = value
	r.modifyFlag = true
}

// Del removes key and marks the session as modified.
func (r *redisSession) Del(key string) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	delete(r.data, key)
	r.modifyFlag = true
}

// SetExpired sets the expiry, in seconds, used by the next Save.
func (r *redisSession) SetExpired(expired int) {
	r.expired = expired
}

// Save writes the session data to Redis when it has been modified since the
// last successful Save.
//
// Save has no error result, so failures are logged. A failure leaves the
// session marked as modified so that a later Save tries again; it never
// terminates the process.
func (r *redisSession) Save() {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	if !r.modifyFlag {
		return
	}

	buf := new(bytes.Buffer)
	if err := gob.NewEncoder(buf).Encode(r.data); err != nil {
		log.Printf("gob encode r.data failed, err: %v\n", err)
		return
	}

	ttl := time.Second * time.Duration(r.expired)
	if err := r.client.Set(r.id, buf.Bytes(), ttl).Err(); err != nil {
		log.Printf("set session data to redis by %s failed, err: %v\n", r.id, err)
		return
	}
	// Log only the size: the payload itself is user session data.
	log.Printf("set %d bytes of session data to redis.\n", buf.Len())
	r.modifyFlag = false
}
// redisSessionMgr is the SessionMgr backed by Redis. It also keeps the
// sessions it has handed out in an in-process map until Clear is called.
type redisSessionMgr struct {
	// session maps a session ID to its Session.
	session map[string]Session
	// rwLock guards session.
	rwLock sync.RWMutex
	client *redis.Client
}

// NewRedisSessionMgr returns a Redis session manager; Init must be called
// before it is used, because the Redis client is created there.
func NewRedisSessionMgr() *redisSessionMgr {
	return &redisSessionMgr{
		session: make(map[string]Session, 1024),
	}
}

// Init connects to the Redis server at addr and pings it.
//
// options[0], when present, is the password; options[1], when present, is the
// database number in decimal. It returns an error when the database number is
// not an integer or when the ping fails.
func (r *redisSessionMgr) Init(addr string, options ...string) (err error) {
	var (
		password string
		db       int
	)
	if len(options) > 0 {
		password = options[0]
	}
	if len(options) > 1 {
		db, err = strconv.Atoi(options[1])
		if err != nil {
			// Report the bad parameter to the caller instead of exiting.
			return fmt.Errorf("invalid redis DB param %q: %v", options[1], err)
		}
	}

	r.client = redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: password,
		DB:       db,
	})
	if _, err = r.client.Ping().Result(); err != nil {
		return err
	}
	return nil
}

// GetSession loads the session stored in Redis under sessionID, registers it
// in the manager and returns it. The load error is returned when there is one.
func (r *redisSessionMgr) GetSession(sessionID string) (sd Session, err error) {
	sd = NewRedisSession(sessionID, r.client)
	err = sd.Load()
	if err != nil {
		return
	}
	// Inserting into the map is a write, so it needs the write lock; a read
	// lock would let several writers in at once.
	r.rwLock.Lock()
	r.session[sessionID] = sd
	r.rwLock.Unlock()
	return
}

// CreateSession makes a session with a fresh UUID as its ID, registers it and
// returns it. Nothing is written to Redis until the session is saved.
func (r *redisSessionMgr) CreateSession() (sd Session) {
	sessionID := uuid.NewV4().String()
	sd = NewRedisSession(sessionID, r.client)

	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	r.session[sd.ID()] = sd
	return
}

// Clear removes the session registered under sessionID from the in-process
// map; it does not touch Redis.
func (r *redisSessionMgr) Clear(sessionID string) {
	r.rwLock.Lock()
	defer r.rwLock.Unlock()
	delete(r.session, sessionID)
}

后记

喜欢的话可以三连,后续继续更新其他内容,帮忙推一推,感谢观看!

你可能感兴趣的:(gin框架,web开发,golang,开发语言,后端)