通过阅读gin框架的源码来探究gin框架路由与中间件的秘密。
gin框架使用的是定制版本的httprouter,其路由的原理是大量使用公共前缀的树结构,它基本上是一个紧凑的Trie tree 或者只是(Radix Tree)。具有公共前缀的节点也共享一个公共父节点。
基数树,由成为PAT位树,是一种更节省空间的前缀树。对于基数树的每个节点,如果该节点是唯一的子树的话,就和父节点合并。
Radix Tree 可以被认为是一颗简洁版的前缀树。我们注册路由的过程就是构造前缀树的过程,具有公共前缀的节点也共享一个公共父节点。假设我们现在注册有以下路由信息:
r := gin.Default()
r.GET("/", func1)
r.GET("/search/", func2)
r.GET("/support/", func3)
r.GET("/blog/", func4)
r.GET("/blog/:post/", func5)
r.GET("/about-us/", func6)
r.GET("/about-us/team/", func7)
r.GET("/contact/", func8)
那么我们会得到一个GET方法对应的路由树,具体结构如下:
Priority Path Handle
9 \ *<1>
3 ├s nil
2 |├earch\ *<2>
1 |└upport\ *<3>
2 ├blog\ *<4>
1 | └:post nil
1 | └\ *<5>
2 ├about-us\ *<6>
1 | └team\ *<7>
1 └contact\ *<8>
上面最右边那一列每个*<数字>表示Handle处理函数的内存地址(一个指针)。从根节点遍历到叶子节点我们就能得到完整的路由表。
例如:blog/:post其中:post只是实际文章名称的占位符(参数)。与hash-maps不同,这种树结构还允许我们使用像:post参数这种动态部分,因为我们实际上是根据路由模式进行匹配,而不仅仅是比较哈希值。
由于URL路径具有层次结构,并且只使用优先的一组字符(字节值),所以很可能有许多常见的前缀。这使得我们可以很容易的将路由简化为更小的问题。此外, 路由器为每种请求方法管理一棵单独的树。 一方面,它比在每个节点中都保存一个method-> handle map 更加节省空间,它还使我们甚至可以在开始在前缀树中查找之前大大减少路由问题。
为了获得更好的伸缩性,每个树级别上的子节点都按 Priority (优先级)排序,其中优先级(最左列)就是在子节点(子节点、子子节点等等)注册的句柄的数量。这样做有两个好处:
├------------
├---------
├-----
├----
├--
├--
└-
路由树是由一个个节点构成的,gin框架路由树的节点由 node 结构体表示,它有以下字段:
// tree.go
type node struct {
// 节点路径,比如上面的s,earch,和upport
path string
// 和children字段对应, 保存的是分裂的分支的第一个字符
// 例如search和support, 那么s节点的indices对应的"eu"
// 代表有两个分支, 分支的首字母分别是e和u
indices string
// 儿子节点
children []*node
// 处理函数链条(切片)
handlers HandlersChain
// 优先级,子节点、子子节点等注册的handler数量
priority uint32
// 节点类型,包括static, root, param, catchAll
// static: 静态节点(默认),比如上面的s,earch等节点
// root: 树的根节点
// catchAll: 有*匹配的节点
// param: 参数节点
nType nodeType
// 路径上最大参数个数
maxParams uint8
// 节点是否是参数节点,比如上面的:post
wildChild bool
// 完整路径
fullPath string
}
在gin的路由中,每一个 HTTP Method (GET、POST、PUT、DELETE…)都对应了一颗 radix tree , 我们注册路由的时候会调用下面的 addRoute 函数
// gin.go
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
// liwenzhou.com...
// 获取请求方法对应的树
root := engine.trees.get(method)
if root == nil {
// 如果没有就创建一个
root = new(node)
root.fullPath = "/"
engine.trees = append(engine.trees, methodTree{method: method, root: root})
}
root.addRoute(path, handlers)
}
从上面的代码中我们可以看到在注册路由的时候都是先根据请求方法获取对应的树,也就是gin框架会为每一个请求方法创建一棵对应的树。只不过需要注意到一个细节是gin框架中保存请求方法对应树关系并不是使用的map而是使用的切片,engine.trees的类型是methodTrees,其定义如下:
type methodTree struct {
method string
root *node
}
type methodTrees []methodTree // slice
而获取请求方法对应树的get方法定义如下:
func (trees methodTrees) get(method string) *node {
for _, tree := range trees {
if tree.method == method {
return tree.root
}
}
return nil
}
GPT解答:
1、性能考虑:在某些情况下,遍历小型切片的性能可能优于在映射中查找键。尤其是在请求方法数量相对较少的情况下(比如HTTP有GET, POST, PUT, DELETE等方法),线性搜索切片可能比在哈希表(也就是map)中查找更快;
2、内存占用:相较于map,切片通常会使用更少的内存。如果我们只需要存储少量的元素,那么切片可能会是更经济的选择;
3、简单性:使用切片可能会让代码更加简洁易读。在这个函数中,我们可以直接使用一个for循环来查找匹配的元素。如果我们使用map,我们就需要处理map返回的第二个值,这个值用来表示键是否存在于map中;
4、保持顺序:切片保持了元素的插入顺序,这在某些应用中可能是必需的。与此相反,map在Go语言中是无序的。
顺着这个思路,我们可以看一下gin框架中engine的初始化方法中,确实对tress字段做了一次内存申请:
func New() *Engine {
debugPrintWARNINGNew()
engine := &Engine{
RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
// liwenzhou.com ...
// 初始化容量为9的切片(HTTP1.1请求方法共9种)
trees: make(methodTrees, 0, 9),
// liwenzhou.com...
}
engine.RouterGroup.engine = engine
engine.pool.New = func() interface{} {
return engine.allocateContext()
}
return engine
}
注册路由的逻辑主要有addRoute函数和insertChild方法。
// tree.go
// addRoute 将具有给定句柄的节点添加到路径中。
// 不是并发安全的
func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path
n.priority++
numParams := countParams(path) // 数一下参数个数
// 空树就直接插入当前节点
if len(n.path) == 0 && len(n.children) == 0 {
n.insertChild(numParams, path, fullPath, handlers)
n.nType = root
return
}
parentFullPathIndex := 0
walk:
for {
// 更新当前节点的最大参数个数
if numParams > n.maxParams {
n.maxParams = numParams
}
// 找到最长的通用前缀
// 这也意味着公共前缀不包含“:”"或“*” /
// 因为现有键不能包含这些字符。
i := longestCommonPrefix(path, n.path)
// 分裂边缘(此处分裂的是当前树节点)
// 例如一开始path是search,新加入support,s是他们通用的最长前缀部分
// 那么会将s拿出来作为parent节点,增加earch和upport作为child节点
if i < len(n.path) {
child := node{
path: n.path[i:], // 公共前缀后的部分作为子节点
wildChild: n.wildChild,
indices: n.indices,
children: n.children,
handlers: n.handlers,
priority: n.priority - 1, //子节点优先级-1
fullPath: n.fullPath,
}
// Update maxParams (max of all children)
for _, v := range child.children {
if v.maxParams > child.maxParams {
child.maxParams = v.maxParams
}
}
n.children = []*node{&child}
// []byte for proper unicode char conversion, see #65
n.indices = string([]byte{n.path[i]})
n.path = path[:i]
n.handlers = nil
n.wildChild = false
n.fullPath = fullPath[:parentFullPathIndex+i]
}
// 将新来的节点插入新的parent节点作为子节点
if i < len(path) {
path = path[i:]
if n.wildChild { // 如果是参数节点
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
// Update maxParams of the child node
if numParams > n.maxParams {
n.maxParams = numParams
}
numParams--
// 检查通配符是否匹配
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
// 检查更长的通配符, 例如 :name and :names
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
continue walk
}
}
pathSeg := path
if n.nType != catchAll {
pathSeg = strings.SplitN(path, "/", 2)[0]
}
prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
panic("'" + pathSeg +
"' in new path '" + fullPath +
"' conflicts with existing wildcard '" + n.path +
"' in existing prefix '" + prefix +
"'")
}
// 取path首字母,用来与indices做比较
c := path[0]
// 处理参数后加斜线情况
if n.nType == param && c == '/' && len(n.children) == 1 {
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
continue walk
}
// 检查路path下一个字节的子节点是否存在
// 比如s的子节点现在是earch和upport,indices为eu
// 如果新加一个路由为super,那么就是和upport有匹配的部分u,将继续分列现在的upport节点
for i, max := 0, len(n.indices); i < max; i++ {
if c == n.indices[i] {
parentFullPathIndex += len(n.path)
i = n.incrementChildPrio(i)
n = n.children[i]
continue walk
}
}
// 否则就插入
if c != ':' && c != '*' {
// []byte for proper unicode char conversion, see #65
// 注意这里是直接拼接第一个字符到n.indices
n.indices += string([]byte{c})
child := &node{
maxParams: numParams,
fullPath: fullPath,
}
// 追加子节点
n.children = append(n.children, child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
}
n.insertChild(numParams, path, fullPath, handlers)
return
}
// 已经注册过的节点
if n.handlers != nil {
panic("handlers are already registered for path '" + fullPath + "'")
}
n.handlers = handlers
return
}
}
其实上面的代码很好理解,大家可以参照动画尝试将以下情形代入上面的代码逻辑,体味整个路由树构造的详细过程:
第一次注册路由,例如注册search
继续注册一条没有公共前缀的路由,例如blog
注册一条与先前注册的路由有公共前缀的路由,例如support
// tree.go
func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers HandlersChain) {
// 找到所有的参数
for numParams > 0 {
// 查找前缀直到第一个通配符
wildcard, i, valid := findWildcard(path)
if i < 0 { // 没有发现通配符
break
}
// 通配符的名称必须包含':' 和 '*'
if !valid {
panic("only one wildcard per path segment is allowed, has: '" +
wildcard + "' in path '" + fullPath + "'")
}
// 检查通配符是否有名称
if len(wildcard) < 2 {
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
// 检查这个节点是否有已经存在的子节点
// 如果我们在这里插入通配符,这些子节点将无法访问
if len(n.children) > 0 {
panic("wildcard segment '" + wildcard +
"' conflicts with existing children in path '" + fullPath + "'")
}
if wildcard[0] == ':' { // param
if i > 0 {
// 在当前通配符之前插入前缀
n.path = path[:i]
path = path[i:]
}
n.wildChild = true
child := &node{
nType: param,
path: wildcard,
maxParams: numParams,
fullPath: fullPath,
}
n.children = []*node{child}
n = child
n.priority++
numParams--
// 如果路径没有以通配符结束
// 那么将有另一个以'/'开始的非通配符子路径。
if len(wildcard) < len(path) {
path = path[len(wildcard):]
child := &node{
maxParams: numParams,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
n = child // 继续下一轮循环
continue
}
// 否则我们就完成了。将处理函数插入新叶子中
n.handlers = handlers
return
}
// catchAll
if i+len(wildcard) != len(path) || numParams > 1 {
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
panic("no / before catch-all in path '" + fullPath + "'")
}
n.path = path[:i]
// 第一个节点:路径为空的catchAll节点
child := &node{
wildChild: true,
nType: catchAll,
maxParams: 1,
fullPath: fullPath,
}
// 更新父节点的maxParams
if n.maxParams < 1 {
n.maxParams = 1
}
n.children = []*node{child}
n.indices = string('/')
n = child
n.priority++
// 第二个节点:保存变量的节点
child = &node{
path: path[i:],
nType: catchAll,
maxParams: 1,
handlers: handlers,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
return
}
// 如果没有找到通配符,只需插入路径和句柄
n.path = path
n.handlers = handlers
n.fullPath = fullPath
}
insertChild 函数是根据path本身进行分割,将/分开的部分分别作为节点保存,形成一棵树结构。参数匹配中的:和*的区别是,前者是匹配一个字段而后者是匹配后面所有的路径。
我们先来看gin框架处理请求的入口函数ServeHTTP:
// gin.go
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// 这里使用了对象池
c := engine.pool.Get().(*Context)
// 这里有一个细节就是Get对象后做初始化
c.writermem.reset(w)
c.Request = req
c.reset()
engine.handleHTTPRequest(c) // 我们要找的处理HTTP请求的函数
engine.pool.Put(c) // 处理完请求后将对象放回池子
}
函数很长,这里省略了部分代码,只保留相关逻辑代码:
// gin.go
func (engine *Engine) handleHTTPRequest(c *Context) {
// liwenzhou.com...
// 根据请求方法找到对应的路由树
t := engine.trees
for i, tl := 0, len(t); i < tl; i++ {
if t[i].method != httpMethod {
continue
}
root := t[i].root
// 在路由树中根据path查找
value := root.getValue(rPath, c.Params, unescape)
if value.handlers != nil {
c.handlers = value.handlers
c.Params = value.params
c.fullPath = value.fullPath
c.Next() // 执行函数链条
c.writermem.WriteHeaderNow()
return
}
// liwenzhou.com...
c.handlers = engine.allNoRoute
serveError(c, http.StatusNotFound, default404Body)
}
路由匹配是由节点的 getValue 方法实现的。getValue 根据给定的路径(键)返回 nodeValue 值,保存注册的处理函数和匹配到的路径参数数据。如果找不到任何处理函数,则会尝试TS(尾随斜杠重定向)。代码虽然很长,但还算比较工整。
大家可以借助注释看一下路由查找及参数匹配的逻辑。
// tree.go
type nodeValue struct {
handlers HandlersChain
params Params // []Param
tsr bool
fullPath string
}
// liwenzhou.com...
func (n *node) getValue(path string, po Params, unescape bool) (value nodeValue) {
value.params = po
walk: // Outer loop for walking the tree
for {
prefix := n.path
if path == prefix {
// 我们应该已经到达包含处理函数的节点。
// 检查该节点是否注册有处理函数
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return
}
if path == "/" && n.wildChild && n.nType != root {
value.tsr = true
return
}
// 没有找到处理函数 检查这个路径末尾+/ 是否存在注册函数
indices := n.indices
for i, max := 0, len(indices); i < max; i++ {
if indices[i] == '/' {
n = n.children[i]
value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
(n.nType == catchAll && n.children[0].handlers != nil)
return
}
}
return
}
if len(path) > len(prefix) && path[:len(prefix)] == prefix {
path = path[len(prefix):]
// 如果该节点没有通配符(param或catchAll)子节点
// 我们可以继续查找下一个子节点
if !n.wildChild {
c := path[0]
indices := n.indices
for i, max := 0, len(indices); i < max; i++ {
if c == indices[i] {
n = n.children[i] // 遍历树
continue walk
}
}
// 没找到
// 如果存在一个相同的URL但没有末尾/的叶子节点
// 我们可以建议重定向到那里
value.tsr = path == "/" && n.handlers != nil
return
}
// 根据节点类型处理通配符子节点
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// 保存通配符的值
if cap(value.params) < int(n.maxParams) {
value.params = make(Params, 0, n.maxParams)
}
i := len(value.params)
value.params = value.params[:i+1] // 在预先分配的容量内扩展slice
value.params[i].Key = n.path[1:]
val := path[:end]
if unescape {
var err error
if value.params[i].Value, err = url.QueryUnescape(val); err != nil {
value.params[i].Value = val // fallback, in case of error
}
} else {
value.params[i].Value = val
}
// 继续向下查询
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
value.tsr = len(path) == end+1
return
}
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return
}
if len(n.children) == 1 {
// 没有找到处理函数. 检查此路径末尾加/的路由是否存在注册函数
// 用于 TSR 推荐
n = n.children[0]
value.tsr = n.path == "/" && n.handlers != nil
}
return
case catchAll:
// 保存通配符的值
if cap(value.params) < int(n.maxParams) {
value.params = make(Params, 0, n.maxParams)
}
i := len(value.params)
value.params = value.params[:i+1] // 在预先分配的容量内扩展slice
value.params[i].Key = n.path[2:]
if unescape {
var err error
if value.params[i].Value, err = url.QueryUnescape(path); err != nil {
value.params[i].Value = path // fallback, in case of error
}
} else {
value.params[i].Value = path
}
value.handlers = n.handlers
value.fullPath = n.fullPath
return
default:
panic("invalid node type")
}
}
// 找不到,如果存在一个在当前路径最后添加/的路由
// 我们会建议重定向到那里
value.tsr = (path == "/") ||
(len(prefix) == len(path)+1 && prefix[len(path)] == '/' &&
path == prefix[:len(prefix)-1] && n.handlers != nil)
return
}
}
gin框架涉及中间件相关有4个常用的方法,它们分别是c.Next()、c.Abort()、c.Set()、c.Get()。
gin框架中的中间件设计很巧妙,我们可以首先从我们最常用的r := gin.Default()的Default函数开始看,它内部构造一个新的engine之后就通过Use()函数注册了Logger中间件和Recovery中间件:
func Default() *Engine {
debugPrintWARNINGDefault()
engine := New()
engine.Use(Logger(), Recovery()) // 默认注册的两个中间件
return engine
}
继续往下查看一下Use()函数的代码:
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...) // 实际上还是调用的RouterGroup的Use函数
engine.rebuild404Handlers()
engine.rebuild405Handlers()
return engine
}
从下方的代码可以看出,注册中间件其实就是将中间件函数追加到group.Handlers中:
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
而我们注册路由时会将对应路由的函数和之前的中间件函数结合到一起:
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers HandlersChain) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
handlers = group.combineHandlers(handlers) // 将处理请求的函数与中间件函数结合
group.engine.addRoute(httpMethod, absolutePath, handlers)
return group.returnObj()
}
其中结合操作的函数内容如下,注意观察这里是如何实现拼接两个切片得到一个新切片的。
const abortIndex int8 = math.MaxInt8 / 2
func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
finalSize := len(group.Handlers) + len(handlers)
if finalSize >= int(abortIndex) { // 这里有一个最大限制
panic("too many handlers")
}
mergedHandlers := make(HandlersChain, finalSize)
copy(mergedHandlers, group.Handlers)
copy(mergedHandlers[len(group.Handlers):], handlers)
return mergedHandlers
}
也就是说,我们会将一个路由的中间件函数和处理函数结合到一起组成一条处理函数链条HandlersChain,而它本质上就是一个由HandlerFunc组成的切片:
type HandlersChain []HandlerFunc
我们在上面路由匹配的时候见过如下逻辑:
value := root.getValue(rPath, c.Params, unescape)
if value.handlers != nil {
c.handlers = value.handlers
c.Params = value.params
c.fullPath = value.fullPath
c.Next() // 执行函数链条
c.writermem.WriteHeaderNow()
return
}
其中c.Next()就是很关键的一步,它的代码很简单:
func (c *Context) Next() {
c.index++
for c.index < int8(len(c.handlers)) {
c.handlers[c.index](c)
c.index++
}
}
从上面的代码可以看到,这里通过索引遍历HandlersChain链条,从而实现依次调用该路由的每一个函数(中间件或处理请求的函数)。
我们可以在中间件函数中通过再次调用c。Next()实现嵌套调用(funcl中调用func2;func2中调用func3),
或者通过调用c.Abort()中断整个调用链条,从当前函数返回。
func (c *Context) Abort() {
c.index = abortIndex // 直接将索引置为最大限制值,从而退出循环
}
c.Set()和c.Get()这两个方法多用于在多个函数之间通过c传递数据的,比如我们可以在认证中间件中获取当前请求的相关信息(userID等)通过c.Set()存入c,然后在后续处理业务逻辑的函数中通过c.Get()来获取当前请求的用户。c就像是一根绳子,将该次请求相关的所有的函数都串起来了
Go语言中的 database/sql 包提供了保证SQL或类SQL数据库的泛用接口,并不提供具体的数据库驱动。使用 database/sql 包时必须注入至少一个数据库驱动。
我们常用的数据库基本上都有完整的第三方实现:例如:MySQL驱动https://github.com/go-sql-driver/mysql
go get -u github.com/go-sql-driver/mysql
func Open(driverName, dataSourceName string)(*DB, error){}
Open打开一个dirverName指定的数据库,dataSourceName指定数据源,一般至少包括数据库文件名和其它连接必要的信息。
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
)
func main() {
// DSN:Data Source Name
dsn := "user:password@tcp(127.0.0.1:3306)/dbname"
db, err := sql.Open("mysql", dsn)
if err != nil {
panic(err)
}
defer db.Close() // 注意这行代码要写在上面err判断的下面
}
Open函数可能只是验证其参数格式是否正确,实际上并不创建与数据库的连接。如果要检查数据源的名称是否真实有效,应该调用panic方法。
返回的DB对象可以安全地被多个goroutine并发使用,并且维护其自己的空闲连接池。因此,open函数应该仅仅被调用一次,很少需要关闭这个DB对象。
接下来,我们定义一个全局的变量 db,用来保存数据库连接对象。将上面的示例代码拆分出一个独立的initDB函数,只需要在程序启动时调用一次该函数完成全局变量db的初始化,其他函数就可以直接使用全局变量db 了。
// 定义一个全局对象db
var db *sql.DB
// 定义一个初始化数据库的函数
func initDB() (err error) {
// DSN:Data Source Name
dsn := "user:password@tcp(127.0.0.1:3306)/sql_test?charset=utf8mb4&parseTime=True"
// 不会校验账号密码是否正确
// 注意!!!这里不要使用:=,我们是给全局变量赋值,然后在main函数中使用全局变量db
db, err = sql.Open("mysql", dsn)
if err != nil {
return err
}
// 尝试与数据库建立连接(校验dsn是否正确)
err = db.Ping()
if err != nil {
return err
}
return nil
}
func main() {
err := initDB() // 调用输出化数据库的函数
if err != nil {
fmt.Printf("init db failed,err:%v\n", err)
return
}
}
其中sql.DB是表示连接的数据库对象(结构体实例),它保存了连接数据库相关的所有信息。内部维护着一个具有零到多个底层连接的连接池,它可以安全地被多个goroutine同时使用。
func (db *DB) SetMaxOpenConns(n int)
设置与数据库建立连接的最大数目。如果n大于0且小于最大闲置连接数,会将最大闲置连接数减小到匹配最大开启连接数的限制。如果n<=0,不会限制最大开启连接数,默认为0(无限制)。
func(db *DB)SetMaxIdleConns(n int)
设置连接池中的最大限制连接数。如果n大于最大开启连接数,则新的最大闲置连接数会减小到匹配最大开启连接数的限制。如果n<=0,不会保留闲置连接。
我们先在MySQL中创建一个名为sql_test的数据库
CREATE DATABASE sql_test;
进入该数据库:
use sql_test;
执行以下命令创建一张用于测试的数据表:
CREATE TABLE `user` (
`id` BIGINT(20) NOT NULL AUTO_INCREMENT,
`name` VARCHAR(20) DEFAULT '',
`age` INT(11) DEFAULT '0',
PRIMARY KEY(`id`)
)ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;
单行查询 db.QueryRow() 执行一次查询,并期望返回最多一行结果(即Row)。 QueryRow 总是返回非nil 的值,直到返回值的Scan方法被调用时,才会返回被延迟的错误。(如:未找到结果)
func (db *DB)QueryRow(query sting, args ...interface{}) *Row
单行查询db.QueryRow()执行一次查询,并期望返回最多一行结果(即Row)。QueryRow总是返回非nil的值,直到返回值的Scan方法被调用时,才会返回被延迟的错误。(如:未找到结果)
package main
import (
"database/sql"
"fmt"
_"github.com/go-sql-driver/mysql" // init()初始化驱动
"time"
)
// 创建一个全局的DB实例
var DB *sql.DB
type user struct{
id int
name string
age int
}
func initMySQL()(err error){
//Data source Name
dsn := "root:DENGni1314@tcp(127.0.0.1:3306)/test"
//去初始化全局的DB对象而不是声明一个DB变量
DB, err = sql.Open("mysql", dsn)
if err != nil {
panic(err)
}
//做完错误检查之后,确保DB不为nil
//尝试与数据库建立连接(校验dsn是否正确)
err = DB.Ping()
if err != nil {
fmt.Println("connect to database error,err=", err)
return
}
//数值需要业务具体的情况来定
DB.SetConnMaxLifetime(time.Second*10)
DB.SetMaxOpenConns(200)
DB.SetMaxIdleConns(30)
return
}
func queryRowDemo(){
sqlStr := "select id, name, age from user where id=?"
var u user
//非常重要,确保QueryRow之后调用Scan方法,否则持有的数据库链接不会被释放
err := DB.QueryRow(sqlStr,1).Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Println("Scan failed, err=", err)
return
}
fmt.Printf("id:%d, name:%s, age:%d\n", u.id, u.name, u.age)
}
func main(){
if err := initMySQL();err != nil{
fmt.Printf("init db failed,err:%v\n",err)
return
}
queryRowDemo()
defer DB.Close()
}
多行查询db.Query()执行一次查询,返回多行结果(即Rows),一般用于执行select命令。参数args表示query中的占位参数。
func (db *DB) Query(query string, args ...interface{}) (*Rows, error)
func QueryDemo(){
sqlStr := "select id, name, age from user where id > ?"
rows, err := DB.Query(sqlStr, 0)
if err != nil {
fmt.Println("Query failed,err=", err)
return
}
//非常重要: 关闭rows释放持有的数据库链接
defer rows.Close()
//循环读取结果集中的数据
for rows.Next(){
var u user
err := rows.Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Println("Scan failed,err=", err)
return
}
fmt.Printf("id:%d, name:%s, age:%d\n", u.id, u.name, u.age)
}
}
插入、更新和删除操作都使用Exec方法。
func (db *DB) Exec(query string, args ...interface{}) (Result, error)
Exec执行一次命令(包括查询、删除、更新、插入等),返回的Result是对已执行的SQL命令的总结。参数args表示query中的占位参数。
具体插入数据示例代码如下:
// 插入数据
func insertRowDemo() {
sqlStr := "insert into user(name, age) values (?,?)"
ret, err := db.Exec(sqlStr, "王五", 38)
if err != nil {
fmt.Printf("insert failed, err:%v\n", err)
return
}
theID, err := ret.LastInsertId() // 新插入数据的id
if err != nil {
fmt.Printf("get lastinsert ID failed, err:%v\n", err)
return
}
fmt.Printf("insert success, the id is %d.\n", theID)
}
// 更新数据
func updateRowDemo() {
sqlStr := "update user set age=? where id = ?"
ret, err := db.Exec(sqlStr, 39, 3)
if err != nil {
fmt.Printf("update failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("update success, affected rows:%d\n", n)
}
// 删除数据
func deleteRowDemo() {
sqlStr := "delete from user where id = ?"
ret, err := db.Exec(sqlStr, 3)
if err != nil {
fmt.Printf("delete failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("delete success, affected rows:%d\n", n)
}
普通SQL语句执行过程:
客户端对SQL语句进行占位符替换得到完整的SQL语句。
客户端发送完整SQL语句到MySQL服务端
MySQL服务端执行完整的SQL语句并将结果返回给客户端。
预处理执行过程:
把SQL语句分成两部分,命令部分与数据部分。
先把命令部分发送给MySQL服务端,MySQL服务端进行SQL预处理。
然后把数据部分发送给MySQL服务端,MySQL服务端对SQL语句进行占位符替换。
MySQL服务端执行完整的SQL语句并将结果返回给客户端。
优化MySQL服务器重复执行SQL的方法,可以提升服务器性能,提前让服务器编译,一次编译多次执行,节省后续编译的成本。
避免SQL注入问题。
database/sql 中通过Preper 方法实现了预处理
func (db *DB) Prepare(query string) (*Stmt, error)
Prepare方法会先将sql语句发送给MySQL服务端,返回一个准备好的状态用于之后的查询和命令。返回值可以同时执行多个查询和命令。
查询操作的预处理示例代码如下:
// 预处理查询示例
func prepareQueryDemo() {
sqlStr := "select id, name, age from user where id > ?"
stmt, err := db.Prepare(sqlStr)
if err != nil {
fmt.Printf("prepare failed, err:%v\n", err)
return
}
defer stmt.Close()
rows, err := stmt.Query(0)
if err != nil {
fmt.Printf("query failed, err:%v\n", err)
return
}
defer rows.Close()
// 循环读取结果集中的数据
for rows.Next() {
var u user
err := rows.Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("scan failed, err:%v\n", err)
return
}
fmt.Printf("id:%d name:%s age:%d\n", u.id, u.name, u.age)
}
}
插入、更新和删除操作的预处理十分类似,这里以插入操作的预处理为例:
// 预处理插入示例
func prepareInsertDemo() {
sqlStr := "insert into user(name, age) values (?,?)"
stmt, err := db.Prepare(sqlStr)
if err != nil {
fmt.Printf("prepare failed, err:%v\n", err)
return
}
defer stmt.Close()
_, err = stmt.Exec("小王子", 18)
if err != nil {
fmt.Printf("insert failed, err:%v\n", err)
return
}
_, err = stmt.Exec("沙河娜扎", 18)
if err != nil {
fmt.Printf("insert failed, err:%v\n", err)
return
}
fmt.Println("insert success.")
}
我们任何时候都不应该自己拼接SQL语句!因为总会有人想要搞破坏
这里我们演示一个自行拼接SQL语句的示例,编写一个根据name字段查询user表的函数如下:
// sql注入示例
func sqlInjectDemo(name string) {
sqlStr := fmt.Sprintf("select id, name, age from user where name='%s'", name)
fmt.Printf("SQL:%s\n", sqlStr)
var u user
err := db.QueryRow(sqlStr).Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("exec failed, err:%v\n", err)
return
}
fmt.Printf("user:%#v\n", u)
}
此时以下输入字符串都可以引发SQL注入问题:
sqlInjectDemo("xxx' or 1=1#")
sqlInjectDemo("xxx' union select * from user #")
sqlInjectDemo("xxx' and (select count(*) from user) <10 #")
事务:一个最小的不可再分的工作单元;通常一个事务对应一个完整的业务(例如银行账户转账业务,该业务就是一个最小的工作单元),同时这个完整的业务需要执行多次的DML(insert,update,delete)语句共同联合完成。A转账给B,这里面就需要执行两次update操作
在MySQL中只有使用了 Innodb数据库引擎的数据库或表才支持事务。事务处理可以用来维护数据库的完整性,保证成批的SQL语句要么全部执行,要么全部不执行。
通常事务必须满足4个条件:原子性 (Atomicity,或称不可分割性)、一致性(Consistency)、隔离性(Isolation,又称为独立性)、持久性(Durability)。
Go语言中使用以下三个方法实现MySQL中的事务操作。
开始事务
func (db *DB) Begin() (*Tx, error)
提交事务
func (tx *Tx) Commit() error
回滚事务
func (tx *Tx) Rollback() error
下面的代码演示了一个简单的事务操作,该事物操作能够确保两次更新操作要么同时成功要么同时失败,不会存在中间状态。
// 事务操作示例
func transactionDemo() {
tx, err := db.Begin() // 开启事务
if err != nil {
if tx != nil {
tx.Rollback() // 回滚
}
fmt.Printf("begin trans failed, err:%v\n", err)
return
}
sqlStr1 := "Update user set age=30 where id=?"
ret1, err := tx.Exec(sqlStr1, 2)
if err != nil {
tx.Rollback() // 回滚
fmt.Printf("exec sql1 failed, err:%v\n", err)
return
}
affRow1, err := ret1.RowsAffected()
if err != nil {
tx.Rollback() // 回滚
fmt.Printf("exec ret1.RowsAffected() failed, err:%v\n", err)
return
}
sqlStr2 := "Update user set age=40 where id=?"
ret2, err := tx.Exec(sqlStr2, 3)
if err != nil {
tx.Rollback() // 回滚
fmt.Printf("exec sql2 failed, err:%v\n", err)
return
}
affRow2, err := ret2.RowsAffected()
if err != nil {
tx.Rollback() // 回滚
fmt.Printf("exec ret1.RowsAffected() failed, err:%v\n", err)
return
}
fmt.Println(affRow1, affRow2)
if affRow1 == 1 && affRow2 == 1 {
fmt.Println("事务提交啦...")
tx.Commit() // 提交事务
} else {
tx.Rollback()
fmt.Println("事务回滚啦...")
}
fmt.Println("exec trans success!")
}
结合net/http和database/sql实现一个使用MySQL存储用户信息的注册及登陆的简易web程序。
package main
import (
"database/sql"
"fmt"
"net/http"
)
var DB *sql.DB
func initDB() (err error) {
dsn := "root:DENGni1314@tcp(127.0.0.1:3306)/test"
DB, err = sql.Open("mysql", dsn)
if err != nil {
fmt.Println("sql Open failed,err=", err)
return
}
err = DB.Ping()
if err != nil {
fmt.Println("Ping failed,err=", err)
return
}
return nil
}
func check(account, password string) bool {
sqlStr := "select * from login where account = ? and password = ?"
rows, err := DB.Query(sqlStr, account, password)
if err != nil {
fmt.Println("err=", err)
return false
}
defer rows.Close()
if rows.Next() {
return true
}else{
return false
}
}
func Login(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `
登入系统
登入
`)
err := r.ParseForm()
if err != nil {
fmt.Println("err=", err)
return
}
account := r.Form.Get("account")
password := r.Form.Get("password")
flag := check(account, password)
if flag {
fmt.Fprintln(w, "yes")
} else {
fmt.Fprintln(w, "no")
}
}
func main() {
err := initDB()
if err != nil {
fmt.Println("init DB failed,err=", err)
return
}
http.HandleFunc("/", Login)
err = http.ListenAndServe(":9090", nil)
if err != nil {
fmt.Println("err=", err)
return
}
}
var db *sqlx.DB
func initDB() (err error) {
dsn := "user:password@tcp(127.0.0.1:3306)/sql_test?charset=utf8mb4&parseTime=True"
// 也可以使用MustConnect连接不成功就panic
db, err = sqlx.Connect("mysql", dsn)
if err != nil {
fmt.Printf("connect DB failed, err:%v\n", err)
return
}
db.SetMaxOpenConns(20)
db.SetMaxIdleConns(10)
return
}
查询单行数据示例代码如下:
查询单行数据示例代码如下:
// 查询单条数据示例
func queryRowDemo() {
sqlStr := "select id, name, age from user where id=?"
var u user
err := db.Get(&u, sqlStr, 1)
if err != nil {
fmt.Printf("get failed, err:%v\n", err)
return
}
fmt.Printf("id:%d name:%s age:%d\n", u.ID, u.Name, u.Age)
}
查询多行数据示例代码如下:
// 查询多条数据示例
func queryMultiRowDemo() {
sqlStr := "select id, name, age from user where id > ?"
var users []user
err := db.Select(&users, sqlStr, 0)
if err != nil {
fmt.Printf("query failed, err:%v\n", err)
return
}
fmt.Printf("users:%#v\n", users)
}
sqlx中的exec方法与原生sql中的exec使用基本一致:
// 插入数据
func insertRowDemo() {
sqlStr := "insert into user(name, age) values (?,?)"
ret, err := db.Exec(sqlStr, "沙河小王子", 19)
if err != nil {
fmt.Printf("insert failed, err:%v\n", err)
return
}
theID, err := ret.LastInsertId() // 新插入数据的id
if err != nil {
fmt.Printf("get lastinsert ID failed, err:%v\n", err)
return
}
fmt.Printf("insert success, the id is %d.\n", theID)
}
// 更新数据
func updateRowDemo() {
sqlStr := "update user set age=? where id = ?"
ret, err := db.Exec(sqlStr, 39, 6)
if err != nil {
fmt.Printf("update failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("update success, affected rows:%d\n", n)
}
// 删除数据
func deleteRowDemo() {
sqlStr := "delete from user where id = ?"
ret, err := db.Exec(sqlStr, 6)
if err != nil {
fmt.Printf("delete failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("delete success, affected rows:%d\n", n)
}
DB.NamedExec 方法用来绑定SQL语句与结构体或map中的同名字段。
func insertUserDemo()(err error){
sqlStr := "INSERT INTO user (name,age) VALUES (:name,:age)"
_, err = db.NamedExec(sqlStr,
map[string]interface{}{
"name": "天意",
"age": 18,
})
return
}
与DB.NamedExec同理,这里是支持查询。
func namedQuery(){
sqlStr := "SELECT * FROM user WHERE name=:name"
// 使用map做命名查询
rows, err := db.NamedQuery(sqlStr, map[string]interface{}{"name": "七米"})
if err != nil {
fmt.Printf("db.NamedQuery failed, err:%v\n", err)
return
}
defer rows.Close()
for rows.Next(){
var u user
err := rows.StructScan(&u)
if err != nil {
fmt.Printf("scan failed, err:%v\n", err)
continue
}
fmt.Printf("user:%#v\n", u)
}
u := user{
Name: "七米",
}
// 使用结构体命名查询,根据结构体字段的 db tag进行映射
rows, err = db.NamedQuery(sqlStr, u)
if err != nil {
fmt.Printf("db.NamedQuery failed, err:%v\n", err)
return
}
defer rows.Close()
for rows.Next(){
var u user
err := rows.StructScan(&u)
if err != nil {
fmt.Printf("scan failed, err:%v\n", err)
continue
}
fmt.Printf("user:%#v\n", u)
}
}
对于事务操作,我们可以使用sqlx中提供的db.Beginx()和tx.Exec()方法。示例代码如下:
func transactionDemo2()(err error) {
tx, err := db.Beginx() // 开启事务
if err != nil {
fmt.Printf("begin trans failed, err:%v\n", err)
return err
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p) // re-throw panic after Rollback
} else if err != nil {
fmt.Println("rollback")
tx.Rollback() // err is non-nil; don't change it
} else {
err = tx.Commit() // err is nil; if Commit returns error update err
fmt.Println("commit")
}
}()
sqlStr1 := "Update user set age=20 where id=?"
rs, err := tx.Exec(sqlStr1, 1)
if err!= nil{
return err
}
n, err := rs.RowsAffected()
if err != nil {
return err
}
if n != 1 {
return errors.New("exec sqlStr1 failed")
}
sqlStr2 := "Update user set age=50 where i=?"
rs, err = tx.Exec(sqlStr2, 5)
if err!=nil{
return err
}
n, err = rs.RowsAffected()
if err != nil {
return err
}
if n != 1 {
return errors.New("exec sqlStr1 failed")
}
return err
}
sqlx.In是sqlx提供的一个非常方便的函数。
为了方便演示插入数据操作,这里创建一个user表,表结构如下:
CREATE TABLE `user` (
`id` BIGINT(20) NOT NULL AUTO_INCREMENT,
`name` VARCHAR(20) DEFAULT '',
`age` INT(11) DEFAULT '0',
PRIMARY KEY(`id`)
)ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4;
查询占位符?在内部称为bindvars(查询占位符),它非常重要。你应该始终使用它们向数据库发送值,因为它们可以防止SQL注入攻击。database/sql不尝试对查询文本进行任何验证;它与编码的参数一起按原样发送到服务器。除非驱动程序实现一个特殊的接口,否则在执行之前,查询是在服务器上准备的。因此bindvars是特定于数据库的:
MySQL中使用?
PostgreSQL使用枚举的$1、$2等bindvar语法
SQLite中?和$1的语法都支持
Oracle中使用:name的语法
bindvars的一个常见误解是,它们用来在sql语句中插入值。它们其实仅用于参数化,不允许更改SQL语句的结构。例如,使用bindvars尝试参数化列或表名将不起作用:
// ?不能用来插入表名(做SQL语句中表名的占位符)
db.Query("SELECT * FROM ?", "mytable")
// ?也不能用来插入列名(做SQL语句中列名的占位符)
db.Query("SELECT ?, ? FROM people", "name", "location")
// BatchInsertUsers 自行构造批量插入的语句
func BatchInsertUsers(users []*User) error {
// 存放 (?, ?) 的slice
valueStrings := make([]string, 0, len(users))
// 存放values的slice
valueArgs := make([]interface{}, 0, len(users) * 2)
// 遍历users准备相关数据
for _, u := range users {
// 此处占位符要与插入值的个数对应
valueStrings = append(valueStrings, "(?, ?)")
valueArgs = append(valueArgs, u.Name)
valueArgs = append(valueArgs, u.Age)
}
// 自行拼接要执行的具体语句
stmt := fmt.Sprintf("INSERT INTO user (name, age) VALUES %s",
strings.Join(valueStrings, ","))
_, err := DB.Exec(stmt, valueArgs...)
return err
}
// BatchInsertUsers2 使用sqlx.In帮我们拼接语句和参数, 注意传入的参数是[]interface{}
func BatchInsertUsers2(users []interface{}) error {
query, args, _ := sqlx.In(
"INSERT INTO user (name, age) VALUES (?), (?), (?)",
users..., // 如果arg实现了 driver.Valuer, sqlx.In 会通过调用 Value()来展开它
)
fmt.Println(query) // 查看生成的querystring
fmt.Println(args) // 查看生成的args
_, err := DB.Exec(query, args...)
return err
}、
// BatchInsertUsers3 使用NamedExec实现批量插入
func BatchInsertUsers3(users []*User) error {
_, err := DB.NamedExec("INSERT INTO user (name, age) VALUES (:name, :age)", users)
return err
}
把上面三种方法综合起来试一下:
func main() {
err := initDB()
if err != nil {
panic(err)
}
defer DB.Close()
u1 := User{Name: "七米", Age: 18}
u2 := User{Name: "q1mi", Age: 28}
u3 := User{Name: "小王子", Age: 38}
// 方法1
users := []*User{&u1, &u2, &u3}
err = BatchInsertUsers(users)
if err != nil {
fmt.Printf("BatchInsertUsers failed, err:%v\n", err)
}
// 方法2
users2 := []interface{}{u1, u2, u3}
err = BatchInsertUsers2(users2)
if err != nil {
fmt.Printf("BatchInsertUsers2 failed, err:%v\n", err)
}
// 方法3
users3 := []*User{&u1, &u2, &u3}
err = BatchInsertUsers3(users3)
if err != nil {
fmt.Printf("BatchInsertUsers3 failed, err:%v\n", err)
}
}
关于sqlx.In这里再补充一个用法,在sqlx查询语句中实现In查询和FIND_IN_SET函数。即实现SELECT * FROM user WHERE id in (3, 2, 1);和SELECT * FROM user WHERE id in (3, 2, 1) ORDER BY FIND_IN_SET(id, ‘3,2,1’);。
查询id在给定id集合中的数据。
查询id在给定id集合的数据并维持给定id集合的顺序。
// QueryAndOrderByIDs 按照指定id查询并维护顺序
func QueryAndOrderByIDs(ids []int)(users []User, err error){
// 动态填充id
strIDs := make([]string, 0, len(ids))
for _, id := range ids {
strIDs = append(strIDs, fmt.Sprintf("%d", id))
}
query, args, err := sqlx.In("SELECT name, age FROM user WHERE id IN (?) ORDER BY FIND_IN_SET(id, ?)", ids, strings.Join(strIDs, ","))
if err != nil {
return
}
// sqlx.In 返回带 `?` bindvar的查询语句, 我们使用Rebind()重新绑定它
query = DB.Rebind(query)
err = DB.Select(&users, query, args...)
return
}
当然,在这个例子里面你也可以先使用IN查询,然后通过代码按给定的ids对查询结果进行排序。
Redis是一个开源的内存数据库,Redis提供了多种不同类型的数据结构,很多业务场景下的问题都可以很自然地映射到这些数据结构上。除此之外,通过复制、持久化和客户端分片等特性,我们可以很方便地将Redis扩展成一个能够包含数百GB数据、每秒处理上百万次请求的系统。
Redis支持诸如字符串(string)、哈希(hashe)、列表(list)、集合(set)、带范围查询的排序集合(sorted set)、bitmap、hyperloglog、带半径查询的地理空间索引(geospatial index)和流(stream)等数据结构。
读者可以选择在本机安装 redis 或使用云数据库,这里直接使用Docker启动一个 redis 环境,方便学习使用。
使用下面的命令启动一个名为 redis507 的 5.0.7 版本的 redis server环境。
docker run --name redis507 -p 6379:6379 -d redis:5.0.7
注意:此处的版本、容器名和端口号可以根据自己需要设置。
启动一个 redis-cli 连接上面的 redis server。
docker run -it --network host --rm redis:5.0.7 redis-cli
Go 社区中目前有很多成熟的 redis client 库,比如[https://github.com/gomodule/redigo 和https://github.com/go-redis/redis,读者可以自行选择适合自己的库。本书使用 go-redis 这个库来操作 Redis 数据库。
可以在import中加入以下包
"github.com/go-redis/redis"
go-redis 库中使用 redis.NewClient 函数连接 Redis 服务器。
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379",
Password: "", // 密码
DB: 0, // 数据库
PoolSize: 20, // 连接池大小
})
除此之外,还可以使用 redis.ParseURL 函数从表示数据源的字符串中解析得到 Redis 服务器的配置信息。
opt, err := redis.ParseURL("redis://:@localhost:6379/" )
if err != nil {
panic(err)
}
rdb := redis.NewClient(opt)
如果使用的是 TLS 连接方式,则需要使用 tls.Config 配置。
rdb := redis.NewClient(&redis.Options{
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
// Certificates: []tls.Certificate{cert},
// ServerName: "your.domain.com",
},
})
使用下面的命令连接到由 Redis Sentinel 管理的 Redis 服务器。
rdb := redis.NewFailoverClient(&redis.FailoverOptions{
MasterName: "master-name",
SentinelAddrs: []string{":9126", ":9127", ":9128"},
})
使用下面的命令连接到 Redis Cluster,go-redis 支持按延迟或随机路由命令。
rdb := redis.NewClusterClient(&redis.ClusterOptions{
Addrs: []string{":7000", ":7001", ":7002", ":7003", ":7004", ":7005"},
// 若要根据延迟或随机路由命令,请启用以下命令之一
// RouteByLatency: true,
// RouteRandomly: true,
})
下面的示例代码演示了 go-redis 的基本使用。
// doCommand go-redis基本使用示例
func doCommand() {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// 执行命令获取结果
val, err := rdb.Get(ctx, "key").Result()
fmt.Println(val, err)
// 先获取到命令对象
cmder := rdb.Get(ctx, "key")
fmt.Println(cmder.Val()) // 获取值
fmt.Println(cmder.Err()) // 获取错误
// 直接执行命令获取错误
err = rdb.Set(ctx, "key", 10, time.Hour).Err()
// 直接执行命令获取值
value := rdb.Get(ctx, "key").Val()
fmt.Println(value)
}
go-redis 还提供了一个执行任意命令或自定义命令的 Do 方法,特别是一些 go-redis 库暂时不支持的命令都可以使用该方法执行。具体使用方法如下。
// doDemo rdb.Do 方法使用示例
func doDemo() {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// 直接执行命令获取错误
err := rdb.Do(ctx, "set", "key", 10, "EX", 3600).Err()
fmt.Println(err)
// 执行命令获取结果
val, err := rdb.Do(ctx, "get", "key").Result()
fmt.Println(val, err)
}
go-redis 库提供了一个 redis.Nil 错误来表示 Key 不存在的错误。因此在使用 go-redis 时需要注意对返回错误的判断。在某些场景下我们应该区别处理 redis.Nil 和其他不为 nil 的错误。
// getValueFromRedis redis.Nil判断
func getValueFromRedis(key, defaultValue string) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
val, err := rdb.Get(ctx, key).Result()
if err != nil {
// 如果返回的错误是key不存在
if errors.Is(err, redis.Nil) {
return defaultValue, nil
}
// 出其他错了
return "", err
}
return val, nil
}
下面的示例代码演示了如何使用 go-redis 库操作 zset。
// zsetDemo 操作zset示例
func zsetDemo() {
// key
zsetKey := "language_rank"
// value
languages := []*redis.Z{
{Score: 90.0, Member: "Golang"},
{Score: 98.0, Member: "Java"},
{Score: 95.0, Member: "Python"},
{Score: 97.0, Member: "JavaScript"},
{Score: 99.0, Member: "C/C++"},
}
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// ZADD
err := rdb.ZAdd(ctx, zsetKey, languages...).Err()
if err != nil {
fmt.Printf("zadd failed, err:%v\n", err)
return
}
fmt.Println("zadd success")
// 把Golang的分数加10
newScore, err := rdb.ZIncrBy(ctx, zsetKey, 10.0, "Golang").Result()
if err != nil {
fmt.Printf("zincrby failed, err:%v\n", err)
return
}
fmt.Printf("Golang's score is %f now.\n", newScore)
// 取分数最高的3个
ret := rdb.ZRevRangeWithScores(ctx, zsetKey, 0, 2).Val()
for _, z := range ret {
fmt.Println(z.Member, z.Score)
}
// 取95~100分的
op := &redis.ZRangeBy{
Min: "95",
Max: "100",
}
ret, err = rdb.ZRangeByScoreWithScores(ctx, zsetKey, op).Result()
if err != nil {
fmt.Printf("zrangebyscore failed, err:%v\n", err)
return
}
for _, z := range ret {
fmt.Println(z.Member, z.Score)
}
}
你可以使用KEYS prefix:* 命令按前缀获取所有 key。
vals, err := rdb.Keys(ctx, "prefix*").Result()
但是如果需要扫描数百万的 key ,那速度就会比较慢。这种场景下你可以使用Scan 命令来遍历所有符合要求的 key。
// scanKeysDemo1 按前缀查找所有key示例
func scanKeysDemo1() {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
var cursor uint64
for {
var keys []string
var err error
// 按前缀扫描key
keys, cursor, err = rdb.Scan(ctx, cursor, "prefix:*", 0).Result()
if err != nil {
panic(err)
}
for _, key := range keys {
fmt.Println("key", key)
}
if cursor == 0 { // no more keys
break
}
}
}
Go-redis 允许将上面的代码简化为如下示例。
// scanKeysDemo2 按前缀扫描key示例
func scanKeysDemo2() {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
// 按前缀扫描key
iter := rdb.Scan(ctx, 0, "prefix:*", 0).Iterator()
for iter.Next(ctx) {
fmt.Println("keys", iter.Val())
}
if err := iter.Err(); err != nil {
panic(err)
}
}
例如,我们可以写出一个将所有匹配指定模式的 key 删除的示例。
// delKeysByMatch 按match格式扫描所有key并删除
func delKeysByMatch(match string, timeout time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
iter := rdb.Scan(ctx, 0, match, 0).Iterator()
for iter.Next(ctx) {
err := rdb.Del(ctx, iter.Val()).Err()
if err != nil {
panic(err)
}
}
if err := iter.Err(); err != nil {
panic(err)
}
}
此外,对于 Redis 中的 set、hash、zset 数据类型,go-redis 也支持类似的遍历方法。
iter := rdb.SScan(ctx, "set-key", 0, "prefix:*", 0).Iterator()
iter := rdb.HScan(ctx, "hash-key", 0, "prefix:*", 0).Iterator()
iter := rdb.ZScan(ctx, "sorted-hash-key", 0, "prefix:*", 0).Iterator()
Pipeline 主要是一种网络优化。它本质上意味着客户端缓冲一堆命令并一次性将它们发送到服务器。这些命令不能保
证在事务中执行。这样做的好处是节省了每个命令的网络往返时间(RTT)。
pipe := rdb.Pipeline()
incr := pipe.Incr("pipeline_counter")
pipe.Expire("pipeline_counter", time.Hour)
_, err := pipe.Exec()
fmt.Println(incr,Val(), err)
上面的代码相当于将以下两个命令一次发给redis serveri端执行,与不使用Pipeline相比能减少一次RTT。
INCR pipeline_counter
EXPIRE pipeline_counts 3600
也可以使用 Pipelined:
var incr *redis.IntCmd
_, err := rdb.Pipelined(func(pipe redis.Pipeliner) error{
incr = pipe.Incr("pipelined_counter")
pipe.Expire("pipielined_counter", time.Hour)
return nil
})
fmt.Println(incr.Val(), err)
在某些场景下,当我们有多条命令要执行时,就可以考虑使用pipeline 来优化。
Redis 是单线程执行命令的,因此单个命令始终是原子的,但是来自不同客户端的两个给定命令可以依次执行,例如在它们之间交替执行。但是,Multi/exec能够确保在multi/exec两个语句之间的命令之间没有其他客户端正在执行命令。
在这种场景我们需要使用 TxPipeline 或 TxPipelined 方法将 pipeline 命令使用 MULTI 和EXEC包裹起来。
上面代码相当于在一个RTT下执行了下面的redis命令:
还有一个与上文类似的TxPipelined方法,使用方法如下:
下面的代码片段演示了 Watch 方法搭配 TxPipelined 的使用示例。
// watchDemo 在key值不变的情况下将其值+1
func watchDemo(ctx context.Context, key string) error {
return rdb.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(ctx, key).Int()
if err != nil && err != redis.Nil {
return err
}
// 假设操作耗时5秒
// 5秒内我们通过其他的客户端修改key,当前事务就会失败
time.Sleep(5 * time.Second)
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, key, n+1, time.Hour)
return nil
})
return err
}, key)
}
我们通常搭配 WATCH命令来执行事务操作。从使用WATCH命令监视某个 key 开始,直到执行EXEC命令的这段时间里,如果有其他用户抢先对被监视的 key 进行了替换、更新、删除等操作,那么当用户尝试执行EXEC的时候,事务将失败并返回一个错误,用户可以根据这个错误选择重试事务或者放弃事务。
// watchDemo 在key值不变的情况下将其值+1
func watchDemo(ctx context.Context, key string) error {
return rdb.Watch(ctx, func(tx *redis.Tx) error {
n, err := tx.Get(ctx, key).Int()
if err != nil && err != redis.Nil {
return err
}
// 假设操作耗时5秒
// 5秒内我们通过其他的客户端修改key,当前事务就会失败
time.Sleep(5 * time.Second)
_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
pipe.Set(ctx, key, n+1, time.Hour)
return nil
})
return err
}, key)
}
将上面的函数执行并打印其返回值,如果我们在程序运行后的5秒内修改了被 watch 的 key 的值,那么该事务操作失败,返回redis: transaction failed错误。
最后我们来看一个 go-redis 官方文档中使用 GET 、SET和WATCH命令实现一个 INCR 命令的完整示例。
const routineCount = 100
increment := func(key string) error {
txf := func(tx *redis.Tx) error {
// 获得当前值或零值
n, err := tx.Get(key).Int()
if err != nil && err != redis.Nil {
return err
}
// 实际操作(乐观锁定中的本地操作)
n++
// 仅在监视的Key保持不变的情况下运行
_, err = tx.Pipelined(func(pipe redis.Pipeliner) error {
// pipe 处理错误情况
pipe.Set(key, n, 0)
return nil
})
return err
}
for retries := routineCount; retries > 0; retries-- {
err := rdb.Watch(txf, key)
if err != redis.TxFailedErr {
return err
}
// 乐观锁丢失
}
return errors.New("increment reached maximum number of retries")
}
var wg sync.WaitGroup
wg.Add(routineCount)
for i := 0; i < routineCount; i++ {
go func() {
defer wg.Done()
if err := increment("counter3"); err != nil {
fmt.Println("increment error:", err)
}
}()
}
wg.Wait()
n, err := rdb.Get("counter3").Int()
fmt.Println("ended with", n, err)
本文先介绍了Go语言原生的日志库的使用,然后详细介绍了非常流行的Uber开源的zap日志库,同时介绍了如何搭配Lumberjack实现日志的切割和归档。
在介绍Uber-go的zap包之前,让我们先看看Go语言提供的基本日志功能。Go语言提供的默认日志包是https://golang.org/pkg/log/。
实现一个Go语言中的日志记录器非常简单——创建一个新的日志文件,然后设置它为日志的输出位置。
设置Logger
我们可以像下面的代码一样设置日志记录器
func SetupLogger() {
logFileLocation, _ := os.OpenFile("/Users/q1mi/test.log", os.O_CREATE|os.O_APPEND|os.O_RDWR, 0744)
log.SetOutput(logFileLocation)
}
使用Logger
让我们来写一些虚拟的代码来使用这个日志记录器。
在当前的示例中,我们将建立一个到URL的HTTP连接,并将状态代码/错误记录到日志文件中。
func simpleHttpGet(url string) {
resp, err := http.Get(url)
if err != nil {
log.Printf("Error fetching url %s : %s", url, err.Error())
} else {
log.Printf("Status Code for %s : %s", url, resp.Status)
resp.Body.Close()
}
}
Logger的运行
现在让我们执行上面的代码并查看日志记录器的运行情况。
func main() {
SetupLogger()
simpleHttpGet(“www.google.com”)
simpleHttpGet(“http://www.google.com”)
}
优势
它最大的优点是使用非常简单。我们可以设置任何io.Writer作为日志记录输出并向其发送要写入的日志。
劣势
Zap是非常快的、结构化的,分日志级别的Go日志库。
运行下面的命令安装zap
go get -u go.uber.org/zap
Logger
var logger *zap.Logger
func main() {
InitLogger()
defer logger.Sync()
simpleHttpGet("www.google.com")
simpleHttpGet("http://www.google.com")
}
func InitLogger() {
logger, _ = zap.NewProduction()
}
func simpleHttpGet(url string) {
resp, err := http.Get(url)
if err != nil {
logger.Error(
"Error fetching url..",
zap.String("url", url),
zap.Error(err))
} else {
logger.Info("Success..",
zap.String("statusCode", resp.Status),
zap.String("url", url))
resp.Body.Close()
}
}
在上面的代码中,我们首先创建了一个Logger,然后使用Info/ Error等Logger方法记录消息。
日志记录器方法的语法是这样的:
func (log *Logger) MethodXXX(msg string, fields ...Field)
其中MethodXXX是一个可变参数函数,可以是Info / Error/ Debug / Panic等。每个方法都接受一个消息字符串和任意数量的zapcore.Field场参数。
每个zapcore.Field其实就是一组键值对参数。
我们执行上面的代码会得到如下输出结果:
{"level":"error","ts":1572159218.912792,"caller":"zap_demo/temp.go:25","msg":"Error fetching url..","url":"www.sogo.com","error":"Get www.sogo.com: unsupported protocol scheme \"\"","stacktrace":"main.simpleHttpGet\n\t/Users/q1mi/zap_demo/temp.go:25\nmain.main\n\t/Users/q1mi/zap_demo/temp.go:14\nruntime.main\n\t/usr/local/go/src/runtime/proc.go:203"}
{"level":"info","ts":1572159219.1227388,"caller":"zap_demo/temp.go:30","msg":"Success..","statusCode":"200 OK","url":"http://www.sogo.com"}
Sugared Logger
var sugarLogger *zap.SugaredLogger
func main() {
InitLogger()
defer sugarLogger.Sync()
simpleHttpGet("www.google.com")
simpleHttpGet("http://www.google.com")
}
func InitLogger() {
logger, _ := zap.NewProduction()
sugarLogger = logger.Sugar()
}
func simpleHttpGet(url string) {
sugarLogger.Debugf("Trying to hit GET request for %s", url)
resp, err := http.Get(url)
if err != nil {
sugarLogger.Errorf("Error fetching URL %s : Error = %s", url, err)
} else {
sugarLogger.Infof("Success! statusCode = %s for URL %s", resp.Status, url)
resp.Body.Close()
}
}
当你执行上面的代码会得到如下输出:
{"level":"error","ts":1572159149.923002,"caller":"logic/temp2.go:27","msg":"Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme \"\"","stacktrace":"main.simpleHttpGet\n\t/Users/q1mi/zap_demo/logic/temp2.go:27\nmain.main\n\t/Users/q1mi/zap_demo/logic/temp2.go:14\nruntime.main\n\t/usr/local/go/src/runtime/proc.go:203"}
{"level":"info","ts":1572159150.192585,"caller":"logic/temp2.go:29","msg":"Success! statusCode = 200 OK for URL http://www.sogo.com"}
你应该注意到的了,到目前为止这两个logger都打印输出JSON结构格式。
在本博客的后面部分,我们将更详细地讨论SugaredLogger,并了解如何进一步配置它。
将日志写入文件而不是终端
我们要做的第一个更改是把日志写入文件,而不是打印到应用程序控制台。
func New(core zapcore.Core, options ...Option) *Logger
zapcore.Core需要三个配置——Encoder,WriteSyncer,LogLevel。
func InitLogger() {
writeSyncer := getLogWriter()
encoder := getEncoder()
core := zapcore.NewCore(encoder, writeSyncer, zapcore.DebugLevel)
logger := zap.New(core)
sugarLogger = logger.Sugar()
}
func getEncoder() zapcore.Encoder {
return zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig())
}
func getLogWriter() zapcore.WriteSyncer {
file, _ := os.Create("./test.log")
return zapcore.AddSync(file)
}
当使用这些修改过的logger配置调用上述部分的main()函数时,以下输出将打印在文件——test.log中。
{"level":"debug","ts":1572160754.994731,"msg":"Trying to hit GET request for www.sogo.com"}
{"level":"error","ts":1572160754.994982,"msg":"Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme \"\""}
{"level":"debug","ts":1572160754.994996,"msg":"Trying to hit GET request for http://www.sogo.com"}
{"level":"info","ts":1572160757.3755069,"msg":"Success! statusCode = 200 OK for URL http://www.sogo.com"}
将JSON Encoder更改为普通的Log Encoder
现在,我们希望将编码器从JSON Encoder更改为普通Encoder。为此,我们需要将NewJSONEncoder()更改为NewConsoleEncoder()。
return zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig())
当使用这些修改过的logger配置调用上述部分的main()函数时,以下输出将打印在文件——test.log中。
1.572161051846623e+09 debug Trying to hit GET request for www.sogo.com
1.572161051846828e+09 error Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
1.5721610518468401e+09 debug Trying to hit GET request for http://www.sogo.com
1.572161052068744e+09 info Success! statusCode = 200 OK for URL http://www.sogo.com
更改时间编码并添加调用者详细信息
鉴于我们对配置所做的更改,有下面两个问题:
func getEncoder() zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
return zapcore.NewConsoleEncoder(encoderConfig)
}
接下来,我们将修改zap logger代码,添加将调用函数信息记录到日志中的功能。为此,我们将在zap.New(…)函数中添加一个Option。
logger := zap.New(core, zap.AddCaller())
当使用这些修改过的logger配置调用上述部分的main()函数时,以下输出将打印在文件——test.log中。
2019-10-27T15:33:29.855+0800 DEBUG logic/temp2.go:47 Trying to hit GET request for www.sogo.com
2019-10-27T15:33:29.855+0800 ERROR logic/temp2.go:50 Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
2019-10-27T15:33:29.856+0800 DEBUG logic/temp2.go:47 Trying to hit GET request for http://www.sogo.com
2019-10-27T15:33:30.125+0800 INFO logic/temp2.go:52 Success! statusCode = 200 OK for URL http://www.sogo.com
AddCallerSkip
当我们不是直接使用初始化好的logger实例记录日志,而是将其包装成一个函数等,此时记录日志的函数调用链会增加,想要获得准确的调用信息就需要通过AddCallerSkip函数来跳过。
将日志输出到多个位置
我们可以将日志同时输出到文件和终端。
func getLogWriter() zapcore.WriteSyncer {
file, _ := os.Create("./test.log")
// 利用io.MultiWriter支持文件和终端两个输出目标
ws := io.MultiWriter(file, os.Stdout)
return zapcore.AddSync(ws)
}
将err日志单独输出到文件
有时候我们除了将全量日志输出到xx.log文件中之外,还希望将ERROR级别的日志单独输出到一个名为xx.err.log的日志文件中。我们可以通过以下方式实现。
func InitLogger() {
encoder := getEncoder()
// test.log记录全量日志
logF, _ := os.Create("./test.log")
c1 := zapcore.NewCore(encoder, zapcore.AddSync(logF), zapcore.DebugLevel)
// test.err.log记录ERROR级别的日志
errF, _ := os.Create("./test.err.log")
c2 := zapcore.NewCore(encoder, zapcore.AddSync(errF), zap.ErrorLevel)
// 使用NewTee将c1和c2合并到core
core := zapcore.NewTee(c1, c2)
logger = zap.New(core, zap.AddCaller())
}
这个日志程序中唯一缺少的就是日志切割归档功能。
官方的说法是为了添加日志切割归档功能,我们将使用第三方库Lumberjack来实现。
目前只支持按文件大小切割,原因是按时间切割效率低且不能保证日志数据不被破坏。详情戳https://github.com/natefinch/lumberjack/issues/54。
想按日期切割可以使用github.com/lestrrat-go/file-rotatelogs这个库,虽然目前不维护了,但也够用了。
// 使用file-rotatelogs按天切割日志
import rotatelogs "github.com/lestrrat-go/file-rotatelogs"
l, _ := rotatelogs.New(
filename+".%Y%m%d%H%M",
rotatelogs.WithMaxAge(30*24*time.Hour), // 最长保存30天
rotatelogs.WithRotationTime(time.Hour*24), // 24小时切割一次
)
zapcore.AddSync(l)
安装
执行下面的命令安装 Lumberjack v2 版本。
go get gopkg.in/natefinch/lumberjack.v2
zap logger中加入Lumberjack
zap logger中加入Lumberjack
func getLogWriter2() zapcore.WriteSyncer{
lumberJackLogger := &lumberjack.Logger{
Filename: "./test.log",
MaxSize: 10,
MaxBackups: 5,
MaxAge: 30,
Compress: false,
}
return zapcore.AddSync(lumberJackLogger)
}
测试所有功能
最终,使用Zap/Lumberjack logger的完整示例代码如下:
package main
import (
"net/http"
"gopkg.in/natefinch/lumberjack.v2"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var sugarLogger *zap.SugaredLogger
func main() {
InitLogger()
defer sugarLogger.Sync()
simpleHttpGet("www.sogo.com")
simpleHttpGet("http://www.sogo.com")
}
func InitLogger() {
writeSyncer := getLogWriter()
encoder := getEncoder()
core := zapcore.NewCore(encoder, writeSyncer, zapcore.DebugLevel)
logger := zap.New(core, zap.AddCaller())
sugarLogger = logger.Sugar()
}
func getEncoder() zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
return zapcore.NewConsoleEncoder(encoderConfig)
}
func getLogWriter() zapcore.WriteSyncer {
lumberJackLogger := &lumberjack.Logger{
Filename: "./test.log",
MaxSize: 1,
MaxBackups: 5,
MaxAge: 30,
Compress: false,
}
return zapcore.AddSync(lumberJackLogger)
}
func simpleHttpGet(url string) {
sugarLogger.Debugf("Trying to hit GET request for %s", url)
resp, err := http.Get(url)
if err != nil {
sugarLogger.Errorf("Error fetching URL %s : Error = %s", url, err)
} else {
sugarLogger.Infof("Success! statusCode = %s for URL %s", resp.Status, url)
resp.Body.Close()
}
}
执行上述代码,下面的内容会输出到文件——test.log中。
2019-10-27T15:50:32.944+0800 DEBUG logic/temp2.go:48 Trying to hit GET request for www.sogo.com
2019-10-27T15:50:32.944+0800 ERROR logic/temp2.go:51 Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
2019-10-27T15:50:32.944+0800 DEBUG logic/temp2.go:48 Trying to hit GET request for http://www.sogo.com
2019-10-27T15:50:33.165+0800 INFO logic/temp2.go:53 Success! statusCode = 200 OK for URL http://www.sogo.com
同时,可以在main函数中循环记录日志,测试日志文件是否会自动切割和归档(日志文件每1MB会切割并且在当前目录下最多保存5个备份)。
至此,我们总结了如何将Zap日志程序集成到Go应用程序项目中
本文介绍了在基于gin框架开发的项目中如何配置并使用zap来接收并记录gin框架默认的日志和如何配置日志归档。
我们在基于gin框架开发项目时通常都会选择使用专业的日志库来记录项目中的日志,go语言常用的日志库有zap、logrus等。网上也有很多类似的教程,我之前也翻译过一篇《在Go语言项目中使用Zap日志库》。
但是我们该如何在日志中记录gin框架本身输出的那些日志呢?
首先我们来看一个最简单的gin项目:
func main(){
r := gin.Default()
r.GET("./hello", func (c *gin.Context){
c.String("Hello liwenzhou.com!")
})
r.RUN()
}
接下来我们看一下gin.Default()的源码:
func Default() *Engine {
debugPrintWARNINGDefault()
engine := New()
engine.Use(Logger(), Recovery())
return engine
}
也就是我们在使用gin.Default()的同时是用到了gin框架内的两个默认中间件Logger()和Recovery()。
其中Logger()是把gin框架本身的日志输出到标准输出(我们本地开发调试时在终端输出的那些日志就是它的功劳),而Recovery()是在程序出现panic的时候恢复现场并写入500响应的。
我们可以模仿Logger()和Recovery()的实现,使用我们的日志库来接收gin框架默认输出的日志。
这里以zap为例,我们实现两个中间件如下:
package main
import (
"net"
"net/http"
"net/http/httputil"
"os"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/natefinch/lumberjack"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var lg *zap.Logger
var sugarLogger *zap.SugaredLogger
func GinLogger(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
c.Next()
cost := time.Since(start)
logger.Info(path,
zap.Int("status", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
zap.Duration("cost",cost),
)
}
}
func GinRecovery(logger *zap.Logger, stack bool)gin.HandlerFunc{
return func (c *gin.Context){
defer func(){
if err := recover(); err!=nil{
var brokenPipe bool
if ne, ok := err.(*net.OpError);ok{
if se, ok := ne.Err.(*os.SyscallError);ok{
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer"){
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
logger.Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
c.Error(err.(error))
c.Abort()
return
}
if stack{
logger.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else{
logger.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("requets", string(httpRequest)),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
c.Next()
}
}
func InitLogger(){
writeSyncer := getLogWriter()
encoder := getEncoder()
core := zapcore.NewCore(encoder, writeSyncer, zapcore.DebugLevel)
lg = zap.New(core, zap.AddCaller())
sugarLogger = lg.Sugar()
}
func getEncoder() zapcore.Encoder{
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.TimeKey = "time"
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
encoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder
encoderConfig.EncodeCaller = zapcore.ShortCallerEncoder
return zapcore.NewJSONEncoder(encoderConfig)
}
func getLogWriter() zapcore.WriteSyncer{
lumberJackLogger := &lumberjack.Logger{
Filename: "./test.DEBUG.log",
MaxSize: 10,
MaxBackups: 5,
MaxAge: 30,
Compress: false,
}
return zapcore.AddSync(lumberJackLogger)
}
func main(){
InitLogger()
r := gin.New()
r.Use(GinLogger(lg), GinRecovery(lg, true))
r.GET("./hello", func(c *gin.Context){
c.JSON(http.StatusOK,nil)
})
r.Run(":8080")
}
如果不想自己实现,可以使用github上有别人封装好的https://github.com/gin-contrib/zap。
这样我们就可以在gin框架中使用我们上面定义好的两个中间件来代替gin框架默认的Logger()和Recovery()了。
r := gin.New()
r.Use(GinLogger(), GinRecovery())
最后我们再加入我们项目中常用的日志切割,完整版的logger.go代码如下:
go get github.com/spf13/viper
Viper是适用于Go应用程序(包括Twelve-Factor App) 的完整配置解决方案。它被设计用于在应用程序中工作,并且可以处理所有类型的配置需求和格式。它支持以下特性:
在构建现代应用程序时,你无须担心配置文件格式;你想要专注于构建出色的软件。Viper的出现就是为了在这方面帮助你的。
Viper能够为你执行下列操作:
1、查找、加载和反序列化JSON TOML YAML HCL INI envfile和java properties格式的配置文件
2、提供一种机制为你的不同配置选项设置默认值
3、 提供一种机制来通过命令行参数覆盖指定选项的值
4、提供别名系统,以便在不破坏现有代码的情况下轻松重命名参数
5、当用户提供了与默认值相同的命令行或配置文件时,可以很容易地分辨出它们之间的区别
Viper会按照下面的优先级。每个项目的优先级都高于它下面的项目:
一个好的配置系统应该支持默认值。键不需要默认值,但如果没有通过配置文件、环境变量、远程配置或命令行标志(flag)设置键,则默认值非常有用。
例如:
viper.SetDefault("ContentDir", "content")
viper.SetDefault("LayoutDir", "layouts")
viper.SetDefault("Taxonomies", map[string]string{"tag": "tags", "category": "categories"})
Viper需要最少知道在哪里查找配置文件的配置。Viper支持JSON TOML YAML HCL envfile和java properties 格式的配置文件。Viper可以搜索多个路径,但目前单个Viper实例只支持单个配置文件。Viper不默认任何配置搜索路径,将默认决策留给应用程序。
下面是一个如何使用Viper搜索和读取配置文件的示例。不需要任何特定的路径,但是至少应该提供一个配置文件预期出现的路径。
viper.SetConfigFile("./config.yaml") // 指定配置文件路径
viper.SetConfigName("config") // 配置文件名称(无扩展名)
viper.SetConfigType("yaml") // 如果配置文件的名称中没有扩展名,则需要配置此项
viper.AddConfigPath("/etc/appname/") // 查找配置文件所在的路径
viper.AddConfigPath("$HOME/.appname") // 多次调用以添加多个搜索路径
viper.AddConfigPath(".") // 还可以在工作目录中查找配置
err := viper.ReadInConfig() // 查找并读取配置文件
if err != nil { // 处理读取配置文件的错误
panic(fmt.Errorf("Fatal error config file: %s \n", err))
}
在加载配置文件出错时,你可以像下面这样处理找不到配置文件的特定情况:
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
// 配置文件未找到错误;如果需要可以忽略
} else {
// 配置文件被找到,但产生了另外的错误
}
}
// 配置文件找到并成功解析
注意[自1.6起]: 你也可以有不带扩展名的文件,并以编程方式指定其格式。对于位于用户$HOME目录中的配置文件没有任何扩展名,如.bashrc
这里补充两个问题供读者解答并自行验证
当你使用如下方式读取配置时,viper会从./conf目录下查找任何以config为文件名的配置文件,如果同时存在./conf/config.json和./conf/config.yaml两个配置文件的话,viper会从哪个配置文件加载配置呢?
viper.SetConfigName("config")
viper.AddConfigPath("./conf")
在上面两个语句下搭配使用viper.SetConfigType(“yaml”)指定配置文件类型可以实现预期的效果吗?
从配置文件中读取配置文件是有用的,但是有时你想要存储在运行时所做的所有修改。为此,可以使用下面一组命令,每个命令都有自己的用途:
viper.WriteConfig() // 将当前配置写入“viper.AddConfigPath()”和“viper.SetConfigName”设置的预定义路径
viper.SafeWriteConfig()
viper.WriteConfigAs("/path/to/my/.config")
viper.SafeWriteConfigAs("/path/to/my/.config") // 因为该配置文件写入过,所以会报错
viper.SafeWriteConfigAs("/path/to/my/.other_config")
Viper支持在运行时实时读取配置文件的功能。
需要重新启动服务器以使配置生效的日子已经一去不复返了,viper驱动的应用程序可以在运行时读取配置文件的更新,而不会错过任何消息。
只需告诉viper实例watchConfig。可选地,你可以为Viper提供一个回调函数,以便在每次发生更改时运行。
确保在调用WatchConfig()之前添加了所有的配置路径。
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
// 配置文件发生变更之后会调用的回调函数
fmt.Println("Config file changed:", e.Name)
})
Viper预先定义了许多配置源,如文件、环境变量、标志和远程K/V存储,但你不受其约束。你还可以实现自己所需的配置源并将其提供给viper。
viper.SetConfigType("yaml") // 或者 viper.SetConfigType("YAML")
// 任何需要将此配置添加到程序中的方法。
var yamlExample = []byte(`
Hacker: true
name: steve
hobbies:
- skateboarding
- snowboarding
- go
clothing:
jacket: leather
trousers: denim
age: 35
eyes : brown
beard: true
`)
viper.ReadConfig(bytes.NewBuffer(yamlExample))
viper.Get("name") // 这里会得到 "steve"
这些可能来自命令行标志,也可能来自你自己的应用程序逻辑。
viper.Set("Verbose", true)
viper.Set("LogFile", LogFile)
package main
import (
"fmt"
"github.com/spf13/viper"
"time"
)
func main(){
var runtime_viper = viper.New()
//K/V
//etcd
runtime_viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001", "/config/hugo.json")
runtime_viper.SetConfigType("json")
if err := runtime_viper.ReadRemoteConfig();err != nil{
//如果没有读取到配置文件-处理逻辑
}else{
//如果读取到配置文件,但发生了别的错误-处理逻辑
}
//Consul
//{
// "port": 8080,
// "hostname": "liwenzhou.com"
//
//}
viper.AddRemoteProvider("consul", "localhost:8500", "MY_CONSUL_KEY")
viper.SetConfigType("json")
if err := viper.ReadRemoteConfig(); err != nil{
}else {
}
fmt.Println(viper.Get("port"))
fmt.Println(viper.Get("hostname"))
//Firestore
viper.AddRemoteProvider("firestore", "google-cloud-project-id", "collection/document")
viper.SetConfigType("json")
if err := viper.ReadRemoteConfig();err != nil{
}else{
}
//runtime_viper.Unmarshal(&runtime_conf)
go func(){
for{
time.Sleep(time.Second * 5)
err := runtime_viper.WatchRemoteConfig()
if err != nil {
//log.Errorf("unable to read remote config: %v", err)
continue
}
//将新配置反序列化到我们运行时的配置结构体中
//runtime_viper.Unmarshal(&runtime_conf)
}
}()
}
在Viper中,有几种方法可以根据值的类型获取值。存在以下功能和方法:
需要认识到的一件重要事情是,每一个Get方法在找不到值的时候都会返回零值。为了检查给定的键是否存在,提供了 IsSet() 方法
例如:
viper.GetString("logfile")
if viper.GetBool("verbose"){
fmt.Println("verbose enabled")
}
访问器方法也接收深度嵌套键的格式化路径。例如,如果加载下面的JSON文件:
{
"host": {
"address": "localhost",
"port": 5799
},
"datastore": {
"metric": {
"host": "127.0.0.1",
"port": 3099
},
"warehouse": {
"host": "198.0.0.1",
"port": 2112
}
}
}
Viper可以通过传入.分隔的路径来访问嵌套字段:
GetString("datastore.metric.host") // (返回 "127.0.0.1")
这遵守上面建立的优先规则;搜索路径将遍历其余配置注册表,直到找到为止。(译注:因为Viper支持从多种配置来源,例如磁盘上的配置文件>命令行标志位>环境变量>远程Key/Value存储>默认值,我们在查找一个配置的时候如果在当前配置源中没找到,就会继续从后续的配置源查找,直到找到为止。)
例如,在给定此配置文件的情况下,datastore.metric.host和datastore.metric.port均已定义(并且可以被覆盖)。如果另外在默认值中定义了datastore.metric.protocol,Viper也会找到它。
然而,如果datastore.metric被直接赋值覆盖(被flag,环境变量,set()方法等等…),那么datastore.metric的所有子键都将变为未定义状态,它们被高优先级配置级别“遮蔽”(shadowed)了。
最后,如果存在与分隔的键路径匹配的键,则返回其值。例如:
{
"datastore.metric.host": "0.0.0.0",
"host": {
"address": "localhost",
"port": 5799
},
"datastore": {
"metric": {
"host": "127.0.0.1",
"port": 3099
},
"warehouse": {
"host": "198.0.0.1",
"port": 2112
}
}
}
GetString("datastore.metric.host") // 返回 "0.0.0.0"
从Viper中提取子树。
例如,viper实例现在代表了以下配置:
app:
cache1:
max-items: 100
item-size: 64
cache2:
max-items: 200
item-size: 80
执行后:
subv := viper.Sub("app.cache1")
subv现在就代表:
max-items: 100
item-size: 64
假设我们现在有这么一个函数:
func NewCache(cfg *Viper) *Cache {...}
它基于subv格式的配置信息创建缓存。现在,可以轻松地分别创建这两个缓存,如下所示:
cfg1 := viper.Sub("app.cache1")
cache1 := NewCache(cfg1)
cfg2 := viper.Sub("app.cache2")
cache2 := NewCache(cfg2)
type config struct {
Port int
Name string
PathMap string `mapstructure:"path_map"`
}
var C config
err := viper.Unmarshal(&C)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}
如果你想要解析那些键本身就包含.(默认的键分隔符)的配置,你需要修改分隔符:
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
v.SetDefault("chart::values", map[string]interface{}{
"ingress": map[string]interface{}{
"annotations": map[string]interface{}{
"traefik.frontend.rule.type": "PathPrefix",
"traefik.ingress.kubernetes.io/ssl-redirect": "true",
},
},
})
type config struct {
Chart struct{
Values map[string]interface{}
}
}
var C config
v.Unmarshal(&C)
Viper还支持解析到嵌入的结构体:
/*
Example config:
module:
enabled: true
token: 89h3f98hbwf987h3f98wenf89ehf
*/
type config struct {
Module struct {
Enabled bool
moduleConfig `mapstructure:",squash"`
}
}
// moduleConfig could be in a module specific package
type moduleConfig struct {
Token string
}
var C config
err := viper.Unmarshal(&C)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}
Viper在后台使用github.com/mitchellh/mapstructure来解析值,其默认情况下使用mapstructuretag。
注意 当我们需要将viper读取的配置反序列到我们定义的结构体变量中时,一定要使用mapstructuretag哦!
你可能需要将viper中保存的所有设置序列化到一个字符串中,而不是将它们写入到一个文件中。你可以将自己喜欢的格式的序列化器与AllSettings()返回的配置一起使用。
import (
yaml "gopkg.in/yaml.v2"
)
func yamlStringSettings() stirng{
c := viper.AllSettings()
bs, err := yaml.Marshal(c)
if err != nil {
}
return string(bs)
}
Viper是开箱即用的。你不需要配置或初始化即可开始使用Viper。由于大多数应用程序都希望使用单个中央存储库管理它们的配置信息,所以viper包提供了这个功能。它类似于单例模式。
在上面的所有示例中,它们都以其单例风格的方法演示了如何使用viper。
你还可以在应用程序中创建许多不同的viper实例。每个都有自己独特的一组配置和值。每个人都可以从不同的配置文件,key value存储区等读取数据。每个都可以从不同的配置文件、键值存储等中读取。viper包支持的所有功能都被镜像为viper实例的方法。
x := viper.New()
y := viper.New()
x.SetDefault("ContentDir", "content")
y.SetDefault("ContentDir", "foobar")
//...
当使用多个viper实例时,由用户来管理不同的viper实例。
这里用一个demo演示如何在gin框架搭建的web项目中使用viper,使用viper加载配置文件中的信息,并在代码中直接使用viper.GetXXX()方法获取对应的配置值。
package main
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
func main(){
viper.SetConfigFile("./conf/config.yaml")
if err := viper.ReadInConfig();err != nil{
panic(fmt.Errorf("%s\n", err))
}
viper.WatchConfig()
r := gin.Default()
r.GET("/version", func(c *gin.Context) {
c.String(http.StatusOK, viper.GetString("version"))
})
if err := r.Run(
fmt.Sprintf(":%d", viper.GetInt("port"))); err != nil {
panic(err)
}
}
除了上面的用法外,我们还可以在项目中定义与配置文件对应的结构体,viper加载完配置信息后使用结构体变量保存配置信息。
package main
import (
"fmt"
"net/http"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
type Config struct {
Port int `mapstructure:"port"`
Version string `mapstructure:"version"`
}
var Conf = new(Config)
func main() {
viper.SetConfigFile("./conf/config.yaml") // 指定配置文件路径
err := viper.ReadInConfig() // 读取配置信息
if err != nil { // 读取配置信息失败
panic(fmt.Errorf("Fatal error config file: %s \n", err))
}
// 将读取的配置信息保存至全局变量Conf
if err := viper.Unmarshal(Conf); err != nil {
panic(fmt.Errorf("unmarshal conf failed, err:%s \n", err))
}
// 监控配置文件变化
viper.WatchConfig()
// 注意!!!配置文件发生变化后要同步到全局变量Conf
viper.OnConfigChange(func(in fsnotify.Event) {
fmt.Println("夭寿啦~配置文件被人修改啦...")
if err := viper.Unmarshal(Conf); err != nil {
panic(fmt.Errorf("unmarshal conf failed, err:%s \n", err))
}
})
r := gin.Default()
// 访问/version的返回值会随配置文件的变化而变化
r.GET("/version", func(c *gin.Context) {
c.String(http.StatusOK, Conf.Version)
})
if err := r.Run(fmt.Sprintf(":%d", Conf.Port)); err != nil {
panic(err)
}
}
https://github.com/spf13/viper/blob/master/README.md
优雅的关机就是服务端关机命令发出后不是立即关机,而是等待当前还在处理的请求全部处理完成后再退出程序,是一种对客户端友好的关机方式。而执行ctr-c 关闭服务端时,会强制结束进程导致正在访问的请求出现问题。
Go 1.8版本之后, http.Server 内置的 Shutdown() 方法就支持优雅地关机,具体示例如下:
package main
import (
"context"
"github.com/gin-gonic/gin"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
func main(){
r := gin.Default()
r.GET("/", func(c *gin.Context){
time.Sleep(5 * time.Second)
c.String(http.StatusOK, "Welcone gin Server")
})
srv := &http.Server{
Addr: ":8080",
Handler: r,
}
go func(){
if err := srv.ListenAndServe();err != nil && err != http.ErrServerClosed{
log.Fatalf("Listen: %s", err)
}
}()
quit := make(chan os.Signal, 1)//创建一个接收信号的通道
//kill 默认会发生 syscall.SIGTERM信号
//kill -2 发送 syscall.SIGINT信号, 我们常用的ctr-c就是出发这个信号
//kill -9 发送 syscall.SIGKILL信号,但是不能被捕获,所以不需要添加
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)//此处不会阻塞
<- quit//阻塞在此,当接受到上述两种信号时才会往下执行
log.Println("Shutdown server...")
//创建一个5秒超时的Context
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
log.Fatalf("Server shutdown:", err)
}
log.Printf("Server exiting")
}
优雅关机实现了,那么该如何实现优雅重启呢?
我们可以使用 fvbock/endless 来替换默认的 ListenAndServe启动服务来实现, 示例代码如下:
package main
import (
"github.com/gin-gonic/gin"
"log"
"github.com/fvbock/endless"
"net/http"
"time"
)
func main(){
r := gin.Default()
r.GET("/", func(c *gin.Context){
time.Sleep(5 * time.Second)
c.String(http.StatusOK, "hello girl")
})
if err := endless.ListenAndServe(":8080", r); err!=nil{
log.Fatalf("listen: %s\n", err)
}
log.Println("Server exiting")
}
无论是优雅关机还是优雅重启归根结底都是通过监听特定系统信号,然后执行一定的逻辑处理保障当前系统正在处理的请求被正常处理后再关闭当前进程。使用优雅关机还是使用优雅重启以及怎么实现,这就需要根据项目实际情况来决定了。
MVC模式代表Model-View-Controller (模型-视图-控制器) 模式。这种模式用于应用程序的分层开发
Model(模型) - 模型代表一个存取数据的对象或JAVA POJO。它也可以带有逻辑,在数据变化时更新控制器。
View(视图) - 视图代表模型包含的数据的可视化。
Controller(控制器) - 控制器作用于模型和视图上。它控制数据流向模型对象,并在数据变化时更新视图。它使视图与模型分离开。
如果你只是简单的想要获取命令行参数,可以像下面的代码示例一样使用os.Args来获取命令行参数。
package main
import (
"fmt"
"os"
)
//os.Args demo
func main() {
//os.Args是一个[]string
if len(os.Args) > 0 {
for index, arg := range os.Args {
fmt.Printf("args[%d]=%v\n", index, arg)
}
}
}
将上面的代码执行go build -o "args_demo"编译之后,执行:
$ ./args_demo a b c d
args[0]=./args_demo
args[1]=a
args[2]=b
args[3]=c
args[4]=d
os.Args是一个存储命令行参数的字符串切片,它的第一个元素是执行文件的名称。
https://studygolang.com/pkgdoc
定义命令行flag参数
name := flag.String("name", "张三”, "姓名")
age := flag.Int("age", 18, "年龄")
married := flag.Bool("marrried", false, "婚否")
delay := flag.Duration("d", 0, "时间间隔")
需要注意的是,此时name、age、married、delat均为对应类型的指针。
基本格式如下: flag.TypeVar(Type指针, flag名, 默认值, 帮助信息) 例如我们要定义姓名、年龄、婚否三个命令行参数,我们可以按如下方式定义:
var name string
var age int
var married bool
var delay time.Duration
flag.StringVar(&name, "name", "张三", "姓名")
flag.IntVar(&age, "age", 18, "年龄")
flag.BoolVar(&married, "married", false, "婚否")
flag.DurationVar(&delay, "d", 0, "时间间隔")
定义
package main
import (
"flag"
"fmt"
"time"
)
func main() {
//定义命令行参数方式1
var name string
var age int
var married bool
var delay time.Duration
flag.StringVar(&name, "name", "张三", "姓名")
flag.IntVar(&age, "age", 18, "年龄")
flag.BoolVar(&married, "married", false, "婚否")
flag.DurationVar(&delay, "d", 0, "延迟的时间间隔")
//解析命令行参数
flag.Parse()
fmt.Println(name, age, married, delay)
//返回命令行参数后的其它参数
fmt.Println(flag.Args())
//返回命令行参数后的其他参数个数
fmt.Print(flag.NArg())
//返回使用的命令行从参数个数
fmt.Println(flag.NFlag())
}
使用
命令行参数使用提示:
$ ./flag_demo -help
Usage of ./flag_demo:
-age int
年龄 (default 18)
-d duration
时间间隔
-married
婚否
-name string
姓名 (default "张三")
正常使用命令行flag参数:
$ ./flag_demo -name 沙河娜扎 --age 28 -married=false -d=1h30m
沙河娜扎 28 false 1h30m0s
[]
0
4
使用非flag命令行参数:
$ ./flag_demo a b c
张三 18 false 0s
[a b c]
3
0
分布式ID的特点:
不仅仅是用于用户1D,实际互联网中有很多场景需要能够生成类似MySQL自增1D这样不断增大,同时又不会重复的id。以支持业务中的高并发场景。
比较典型的场景有;电商促销时短时间内会有大量的订单涌入到系统,比如每秒10w+;明星出轨时微博短时问内会产生大量的相关微博转发和评论消息。在这些业务场景下将数据插入数据库之前,我们需要给这些订单和消息先分配一个唯一1D,然后再保存到数据库中。对这个id的要求是希望其中能带有一些时间信息,这样即使我们后端的系统对消息进行了分库分表,也能够以时间顺序对这些消息进行排
序。
单机版
package main
import (
"fmt"
"github.com/bwmarrin/snowflake"
"time"
)
var node *snowflake.Node
func Init(startTime string, machineID int64) (err error) {
//指定时间因子:从哪一年开始
var st time.Time
st, err = time.Parse("2006-01-02", startTime)
if err != nil {
return
}
snowflake.Epoch = st.UnixNano() / 1000000 //需要毫秒值
node, err = snowflake.NewNode(machineID)
return
}
func GenID() int64 {
return node.Generate().Int64() //.String()
}
func main() {
if err := Init("2020-07-01", 1); err != nil {
fmt.Printf("init failed, err:%v\n", err)
return
}
id := GenID()
fmt.Println(id)
}
https://github.com/sony/sonyflake
sonyflake是Sony公司的一个开源项目,基本思路和snowflake差不多,不过位分配上稍有不同:
这里的时间只用了39个bit, 但时间的单位变成了10ms,所以理论上比41位表示的时间还要久(174年)。
sequence ID 和之前的定义一致,Machine 工D 其实就是节点id。sonyflake 库有以下配置参数:
type Settings struct{
StartTime time.Time
MachineID func() (uint16, error)
CheckMachineID func(uint16) bool
}
package main
import (
"fmt"
"github.com/sony/sonyflake"
"time"
)
var (
sonyFlake *sonyflake.Sonyflake
sonyMachineID uint16
)
func getMachineID() (uint16, error) {
return sonyMachineID, nil
}
// 传入当前的机器ID
func Init(startTime string, machineId uint16) (err error) {
sonyMachineID = machineId
var st time.Time
st, err = time.Parse("2006-01-02", startTime)
if err != nil {
return
}
settings := sonyflake.Settings{
StartTime: st,
MachineID: getMachineID,
}
sonyFlake = sonyflake.NewSonyflake(settings)
return
}
func GenID() (id uint64, err error) {
if sonyFlake == nil {
err = fmt.Errorf("sony flake not inited")
return
}
id, err = sonyFlake.NextID()
return
}
func main() {
if err := Init("2020-07-01", 1); err != nil {
fmt.Printf("Init failed, err:%v", err)
return
}
id, _ := GenID()
fmt.Println(id)
}
在web开发中一个不可避免的环节就是对请求参数进行校验,通常我们会在代码中定义与请求参数相对应的模型(结构体),借助模型绑定快捷地解析请求中的参数,例如 gin 框架中的Bind和ShouldBind系列方法。本文就以 gin 框架的请求参数校验为例,介绍一些validator库的实用技巧。
gin框架使用github.com/go-playground/validator进行参数校验,目前已经支持github.com/go-playground/validator/v10了,我们需要在定义结构体时使用 binding tag标识相关校验规则,可以查看validator文档查看支持的所有 tag。
首先来看gin框架内置使用validator做参数校验的基本示例。
package main
import (
"net/http"
"github.com/gin-gonic/gin"
)
type SignUpParam struct {
Age uint8 `json:"age" binding:"gte=1,lte=130"`
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
RePassword string `json:"re_password" binding:"required,eqfield=Password"`
}
func main() {
r := gin.Default()
r.POST("/signup", func(c *gin.Context) {
var u SignUpParam
if err := c.ShouldBind(&u); err != nil {
c.JSON(http.StatusOK, gin.H{
"msg": err.Error(),
})
return
}
// 保存入库等业务逻辑代码...
c.JSON(http.StatusOK, "success")
})
_ = r.Run(":8999")
}
我们使用curl发送一个POST请求测试下:
curl -H "Content-type: application/json" -X POST -d '{"name":"q1mi","age":18,"email":"123.com"}' http://127.0.0.1:8999/signup
输出结果:
{"msg":"Key: 'SignUpParam.Email' Error:Field validation for 'Email' failed on the 'email' tag\nKey: 'SignUpParam.Password' Error:Field validation for 'Password' failed on the 'required' tag\nKey: 'SignUpParam.RePassword' Error:Field validation for 'RePassword' failed on the 'required' tag"}
从最终的输出结果可以看到 validator 的检验生效了,但是错误提示的字段不是特别友好,我们可能需要将它翻译成中文。
validator库本身是支持国际化的,借助相应的语言包可以实现校验错误提示信息的自动翻译。下面的示例代码演示了如何将错误提示信息翻译成中文,翻译成其他语言的方法类似。
package main
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
enTranslations "github.com/go-playground/validator/v10/translations/en"
zhTranslations "github.com/go-playground/validator/v10/translations/zh"
)
// 定义一个全局翻译器T
var trans ut.Translator
// InitTrans 初始化翻译器
func InitTrans(locale string) (err error) {
// 修改gin框架中的Validator引擎属性,实现自定制
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
zhT := zh.New() // 中文翻译器
enT := en.New() // 英文翻译器
// 第一个参数是备用(fallback)的语言环境
// 后面的参数是应该支持的语言环境(支持多个)
// uni := ut.New(zhT, zhT) 也是可以的
uni := ut.New(enT, zhT, enT)
// locale 通常取决于 http 请求头的 'Accept-Language'
var ok bool
// 也可以使用 uni.FindTranslator(...) 传入多个locale进行查找
trans, ok = uni.GetTranslator(locale)
if !ok {
return fmt.Errorf("uni.GetTranslator(%s) failed", locale)
}
// 注册翻译器
switch locale {
case "en":
err = enTranslations.RegisterDefaultTranslations(v, trans)
case "zh":
err = zhTranslations.RegisterDefaultTranslations(v, trans)
default:
err = enTranslations.RegisterDefaultTranslations(v, trans)
}
return
}
return
}
type SignUpParam struct {
Age uint8 `json:"age" binding:"gte=1,lte=130"`
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required,email"`
Password string `json:"password" binding:"required"`
RePassword string `json:"re_password" binding:"required,eqfield=Password"`
}
func main() {
if err := InitTrans("zh"); err != nil {
fmt.Printf("init trans failed, err:%v\n", err)
return
}
r := gin.Default()
r.POST("/signup", func(c *gin.Context) {
var u SignUpParam
if err := c.ShouldBind(&u); err != nil {
// 获取validator.ValidationErrors类型的errors
errs, ok := err.(validator.ValidationErrors)
if !ok {
// 非validator.ValidationErrors类型错误直接返回
c.JSON(http.StatusOK, gin.H{
"msg": err.Error(),
})
return
}
// validator.ValidationErrors类型错误则进行翻译
c.JSON(http.StatusOK, gin.H{
"msg":errs.Translate(trans),
})
return
}
// 保存入库等具体业务逻辑代码...
c.JSON(http.StatusOK, "success")
})
_ = r.Run(":8999")
}
上面的错误提示看起来是可以了,但是还是差点意思,首先是错误提示中的字段并不是请求中使用的字段,例如:RePassword是我们后端定义的结构体中的字段名,而请求中使用的是re_password字段。如何是错误提示中的字段使用自定义的名称,例如jsontag指定的值呢?
只需要在初始化翻译器的时候像下面一样添加一个获取json tag的自定义方法即可。
// InitTrans 初始化翻译器
func InitTrans(locale string) (err error) {
// 修改gin框架中的Validator引擎属性,实现自定制
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// 注册一个获取json tag的自定义方法
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
zhT := zh.New() // 中文翻译器
enT := en.New() // 英文翻译器
// 第一个参数是备用(fallback)的语言环境
// 后面的参数是应该支持的语言环境(支持多个)
// uni := ut.New(zhT, zhT) 也是可以的
uni := ut.New(enT, zhT, enT)
// ... liwenzhou.com ...
}
再尝试发请求,看一下效果:
{"msg":{"SignUpParam.email":"email必须是一个有效的邮箱","SignUpParam.password":"password为必填字段","SignUpParam.re_password":"re_password为必填字段"}}
可以看到现在错误提示信息中使用的就是我们结构体中jsontag设置的名称了。
但是还是有点瑕疵,那就是最终的错误提示信息中心还是有我们后端定义的结构体名称——SignUpParam,这个名称其实是不需要随错误提示返回给前端的,前端并不需要这个值。我们需要想办法把它去掉。
这里参考https://github.com/go-playground/validator/issues/633#issuecomment-654382345提供的方法,定义一个去掉结构体名称前缀的自定义方法:
func removeTopStruct(fields map[string]string) map[string]string {
res := map[string]string{}
for field, err := range fields {
res[field[strings.Index(field, ".")+1:]] = err
}
return res
}
我们在代码中使用上述函数将翻译后的errors做一下处理即可:
if err := c.ShouldBind(&u); err != nil {
// 获取validator.ValidationErrors类型的errors
errs, ok := err.(validator.ValidationErrors)
if !ok {
// 非validator.ValidationErrors类型错误直接返回
c.JSON(http.StatusOK, gin.H{
"msg": err.Error(),
})
return
}
// validator.ValidationErrors类型错误则进行翻译
// 并使用removeTopStruct函数去除字段名中的结构体名称标识
c.JSON(http.StatusOK, gin.H{
"msg": removeTopStruct(errs.Translate(trans)),
})
return
}
看一下最终的效果:
{"msg":{"email":"email必须是一个有效的邮箱","password":"password为必填字段","re_password":"re_password为必填字段"}}
这一次看起来就比较符合我们预期的标准了。
上面的校验还是有点小问题,就是当涉及到一些复杂的校验规则,比如re_password字段需要与password字段的值相等这样的校验规则,我们的自定义错误提示字段名称方法就不能很好解决错误提示信息中的其他字段名称了。
curl -H "Content-type: application/json" -X POST -d '{"name":"q1mi","age":18,"email":"123.com","password":"123","re_password":"321"}' http://127.0.0.1:8999/signup
最后输出的错误提示信息如下:
{"msg":{"email":"email必须是一个有效的邮箱","re_password":"re_password必须等于Password"}}
可以看到re_password字段的提示信息中还是出现了Password这个结构体字段名称。这有点小小的遗憾,毕竟自定义字段名称的方法不能影响被当成param传入的值。
此时如果想要追求更好的提示效果,将上面的Password字段也改为和json tag一致的名称,就需要我们自定义结构体校验的方法。
例如,我们为SignUpParam自定义一个校验方法如下:
// SignUpParamStructLevelValidation 自定义SignUpParam结构体校验函数
func SignUpParamStructLevelValidation(sl validator.StructLevel) {
su := sl.Current().Interface().(SignUpParam)
if su.Password != su.RePassword {
// 输出错误提示信息,最后一个参数就是传递的param
sl.ReportError(su.RePassword, "re_password", "RePassword", "eqfield", "password")
}
}
然后在初始化校验器的函数中注册该自定义校验方法即可:
func InitTrans(locale string) (err error) {
// 修改gin框架中的Validator引擎属性,实现自定制
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// ... liwenzhou.com ...
// 为SignUpParam注册自定义校验方法
v.RegisterStructValidation(SignUpParamStructLevelValidation, SignUpParam{})
zhT := zh.New() // 中文翻译器
enT := en.New() // 英文翻译器
// ... liwenzhou.com ...
}
最终再请求一次,看一下效果:
{"msg":{"email":"email必须是一个有效的邮箱","re_password":"re_password必须等于password"}}
这一次re_password字段的错误提示信息就符合我们预期了。
由于本篇博客示例代码较多,我已经把文中示例代码上传到我的github仓库——https://github.com/Q1mi/validator_demo,大家可以查看完整的示例代码。
本文总结的gin框架中validator的使用技巧同样也适用于直接使用validator库,区别仅仅在于我们配置的是gin框架中的校验器还是由validator.New()创建的校验器。同时使用validator库确实能够在一定程度上减少我们的编码量,但是它不太可能完美解决我们所有需求,所以你需要找到两者之间的平衡点。
HTTP是一个无状态的协议,一次请求结束后,下次再发送服务器就不知道这个请求是谁发的了(同一个IP不代表同一个用户),在web应用中,用户的认证和鉴权是非常重要的一环,时间中有多种可用方案。
IT是JSON Web Token 的缩写,是为了在网络应用环境问传递声明而执行的一种基于SON的开放标准((RFC 7519)。JWT 本身没有定义任何技术实现,它只是定义了一种基于 Token 的会话管理的规则,涵盖 Token 需要包含的标准内容和 Token 的生成过程,特别适用于分布式站点的单点登录(SSO)场景。。
JWT全称JSON Web Token是一种跨域认证解决方案,属于一个开放的标准,它规定了一种Token实现方式,目前多用于前后端分离项目和OAuth2.0业务场景下。
前面讲的Token,都是Access Token,也就是访问资源接口时所需要的Token,还有另外一种
Token,Refresh Token,通常情况下,Refresh Token的有效期会比较长,而Access Token的有效期 比较短,当Access Token由于过期而失效时,使用Refresh Token就可以获取到新的Access Token, 如果Refresh Token也失效了,用户就只能重新登录了。
go get github.com/golang-jwt/jwt/v4
本文将使用这个库来实现我们生成JWT和解析JWT的功能。
#### 使用
###### 默认Claim
如果我们直接使用JWT中默认的字段,没有其他定制化的需求则可以直接使用这个包中的和方法快速生成和解析token。
````go
// 用于签名的字符串
var mySigningKey = []byte("liwenzhou.com")
// GenRegisteredClaims 使用默认声明创建jwt
func GenRegisteredClaims() (string, error) {
// 创建 Claims
claims := &jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 24)), // 过期时间
Issuer: "qimi", // 签发人
}
// 生成token对象
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// 生成签名字符串
return token.SignedString(mySigningKey)
}
// ParseRegisteredClaims 解析jwt
func ValidateRegisteredClaims(tokenString string) bool {
// 解析token
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return mySigningKey, nil
})
if err != nil { // 解析token失败
return false
}
return token.Valid
}
我们需要定制自己的需求来决定JWT中保存哪些数据,比如我们规定在JWT中要存储username信息,那么我们就定义一个MyClaims结构体如下:
// CustomClaims 自定义声明类型 并内嵌jwt.RegisteredClaims
// jwt包自带的jwt.RegisteredClaims只包含了官方字段
// 假设我们这里需要额外记录一个username字段,所以要自定义结构体
// 如果想要保存更多信息,都可以添加到这个结构体中
type CustomClaims struct {
// 可根据需要自行添加字段
Username string `json:"username"`
jwt.RegisteredClaims // 内嵌标准的声明
}
然后我们定义JWT的过期时间,这里以24小时为例:
const TokenExpireDuration = time.Hour * 24
接下来还需要定义一个用于签名的字符串:
// CustomSecret 用于加盐的字符串
var CustomSecret = []byte("夏天夏天悄悄过去")
我们可以根据自己的业务需要封装一个生成 token 的函数。
// GenToken 生成JWT
func GenToken(username string) (string, error) {
// 创建一个我们自己的声明
claims := CustomClaims{
username, // 自定义字段
jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(TokenExpireDuration)),
Issuer: "my-project", // 签发人
},
}
// 使用指定的签名方法创建签名对象
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// 使用指定的secret签名并获得完整的编码后的字符串token
return token.SignedString(CustomSecret)
}
根据给定的 JWT 字符串,解析出数据。
// ParseToken 解析JWT
func ParseToken(tokenString string) (*CustomClaims, error) {
// 解析token
// 如果是自定义Claim结构体则需要使用 ParseWithClaims 方法
token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (i interface{}, err error) {
// 直接使用标准的Claim则可以直接使用Parse方法
//token, err := jwt.Parse(tokenString, func(token *jwt.Token) (i interface{}, err error) {
return CustomSecret, nil
})
if err != nil {
return nil, err
}
// 对token对象中的Claim进行类型断言
if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { // 校验token
return claims, nil
}
return nil, errors.New("invalid token")
}
首先我们注册一条路由/auth,对外提供获取Token的渠道:
r.POST("/auth", authHandler)
我们的authHandler定义如下
func authHandler(c *gin.Context) {
// 用户发送用户名和密码过来
var user UserInfo
err := c.ShouldBind(&user)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": 2001,
"msg": "无效的参数",
})
return
}
// 校验用户名和密码是否正确
if user.Username == "q1mi" && user.Password == "q1mi123" {
// 生成Token
tokenString, _ := GenToken(user.Username)
c.JSON(http.StatusOK, gin.H{
"code": 2000,
"msg": "success",
"data": gin.H{"token": tokenString},
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 2002,
"msg": "鉴权失败",
})
return
}
用户通过上面的接口获取Token之后,后续就会携带着Token再来请求我们的其他接口,这个时候就需要对这些请求的Token进行校验操作了,很显然我们应该实现一个检验Token的中间件,具体实现如下:
// JWTAuthMiddleware 基于JWT的认证中间件
func JWTAuthMiddleware() func(c *gin.Context) {
return func(c *gin.Context) {
// 客户端携带Token有三种方式 1.放在请求头 2.放在请求体 3.放在URI
// 这里假设Token放在Header的Authorization中,并使用Bearer开头
// 这里的具体实现方式要依据你的实际业务情况决定
authHeader := c.Request.Header.Get("Authorization")
if authHeader == "" {
c.JSON(http.StatusOK, gin.H{
"code": 2003,
"msg": "请求头中auth为空",
})
c.Abort()
return
}
// 按空格分割
parts := strings.SplitN(authHeader, " ", 2)
if !(len(parts) == 2 && parts[0] == "Bearer") {
c.JSON(http.StatusOK, gin.H{
"code": 2004,
"msg": "请求头中auth格式有误",
})
c.Abort()
return
}
// parts[1]是获取到的tokenString,我们使用之前定义好的解析JWT的函数来解析它
mc, err := ParseToken(parts[1])
if err != nil {
c.JSON(http.StatusOK, gin.H{
"code": 2005,
"msg": "无效的Token",
})
c.Abort()
return
}
// 将当前请求的username信息保存到请求的上下文c上
c.Set("username", mc.Username)
c.Next() // 后续的处理函数可以用过c.Get("username")来获取当前请求的用户信息
}
}
注册一个/home路由,发个请求验证一下吧。
r.GET("/home", JWTAuthMiddleware(), homeHandler)
func homeHandler(c *gin.Context) {
username := c.MustGet("username").(string)
c.JSON(http.StatusOK, gin.H{
"code": 2000,
"msg": "success",
"data": gin.H{"username": username},
})
}
如果不想自己实现上述功能,你也可以使用Github上别人封装好的包,比如https://github.com/appleboy/gin-jwt。
在某些业务场景下,我们可能还需要使用refresh token。
前面讲的Token,都是Access Token,也就是访问资源接口时所需要的Token,还有另外一种
Token,Refresh Token,通常情况下,Refresh Token的有效期会比较长,而Access Token的有效期 比较短,当Access Token由于过期而失效时,使用Refresh Token就可以获取到新的Access Token, 如果Refresh Token也失效了,用户就只能重新登录了。
Token:接口获取新Access Token的拦截器。
func GenToken1(userID int64) (accessToken, refreshToken string, err error){
//创建一个我们自己的声明
c := MyClaims{
userID,
"username",
jwt.StandardClaims{
ExpiresAt: time.Now().Add(TokenExpireDuration).Unix(),//过期时间
Issuer: "bluebell",//签发人
},
}
//加密并获得完整的编码后的字符串token
accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256,c).SignedString(mySercet)
// refresh token 不需要存任何自定义数据
refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256,jwt.StandardClaims{
ExpiresAt: time.Now().Add(time.Second * 30).Unix(), //过期时间
Issuer: "bluebell", //签发人
}).SignedString(mySercet)
return
}
func ParseToken1(tokenString string)(claims *MyClaims, err error){
//解析token
var token *jwt.Token
//需要主动申请内存
claims = new(MyClaims)
token,err = jwt.ParseWithClaims(tokenString, claims, keyFunc)
if err != nil {
return
}
if !token.Valid{//校验token
err = errors.New("invalid token")
}
return
}
func RefreshToken(aToken, rToken string)(newAToken, newRToken string, err error){
//refresh token 无效直接返回
if _, err = jwt.Parse(rToken, keyFunc);err != nil{
return
}
//从旧的access Token中解析出claims数据
var claims MyClaims
_, err = jwt.ParseWithClaims(aToken,&claims,keyFunc)
v,_ := err.(*jwt.ValidationError)
//当access token过期错误,并且refresh token 没有过期就创建一个新的access token
if v.Errors == jwt.ValidationErrorExpired{
return GenToken1(claims.UserID)
}
return
}
这里可以参考 RFC 6749 OAuth2.0中关于refresh token的介绍https://datatracker.ietf.org/doc/html/rfc6749#section-1.5
今天我们要介绍一个神器——Air能够实时监听项目的代码文件,在代码发生变更之后自动重新编译并执行,大大提高gin框架项目的开发效率。
之前使用Python编写Web项目的时候,常见的Flask或Django框架都是支持实时加载的,你修改了项目代码之后,程序能够自动重新加载并执行(live-reload),这在日常的开发阶段是十分方便的。
在使用Go语言的gin框架在本地做开发调试的时候,经常需要在变更代码之后频繁的按下Ctrl+C停止程序并重新编译再执行,这样就不是很方便。
怎样才能在基于gin框架开发时实现实时加载功能呢?像这种烦恼肯定不会只是你一个人的烦恼,所以我报着肯定有现成轮子的心态开始了全网大搜索。果不其然就在Github上找到了一个工具:Air。它支持以下特性:
完整的air_example.conf示例配置如下,可以根据自己的需要修改。
# [Air](https://github.com/cosmtrek/air) TOML 格式的配置文件
# 工作目录
# 使用 . 或绝对路径,请注意 `tmp_dir` 目录必须在 `root` 目录下
root = "."
tmp_dir = "tmp"
[build]
# 只需要写你平常编译使用的shell命令。你也可以使用 `make`
# Windows平台示例: cmd = "go build -o tmp\main.exe ."
cmd = "go build -o ./tmp/main ."
# 由`cmd`命令得到的二进制文件名
# Windows平台示例:bin = "tmp\main.exe"
bin = "tmp/main"
# 自定义执行程序的命令,可以添加额外的编译标识例如添加 GIN_MODE=release
# Windows平台示例:full_bin = "tmp\main.exe"
full_bin = "APP_ENV=dev APP_USER=air ./tmp/main"
# 监听以下文件扩展名的文件.
include_ext = ["go", "tpl", "tmpl", "html"]
# 忽略这些文件扩展名或目录
exclude_dir = ["assets", "tmp", "vendor", "frontend/node_modules"]
# 监听以下指定目录的文件
include_dir = []
# 排除以下文件
exclude_file = []
# 如果文件更改过于频繁,则没有必要在每次更改时都触发构建。可以设置触发构建的延迟时间
delay = 1000 # ms
# 发生构建错误时,停止运行旧的二进制文件。
stop_on_error = true
# air的日志文件名,该日志文件放置在你的`tmp_dir`中
log = "air_errors.log"
[log]
# 显示日志时间
time = true
[color]
# 自定义每个部分显示的颜色。如果找不到颜色,使用原始的应用程序日志。
main = "magenta"
watcher = "cyan"
build = "yellow"
runner = "green"
[misc]
# 退出时删除tmp目录
clean_on_exit = true
package models
import (
"fmt"
"testing"
"unsafe"
)
type s1 struct{
a int8
c int8
b string
}
type s2 struct{
a int8
b string
c int8
}
func TestString(t *testing.T){
v1 := s1{
a:1,
b:"七米",
c:2,
}
v2 := s2{
a:1,
b:"七米",
c:2,
}
fmt.Println(unsafe.Sizeof(v1), unsafe.Sizeof(v2))
}
有时候一份清晰明了的接口文档能够极大地提高前后端双方的沟通效率和开发效率。本文将介绍如何使用swagger生成接口文档。
Swagger本质上是一种用于描述使用JSON表示的RESTful API的接口描述语言。Swagger与一组开源软件工具一起使用,以设计、构建、记录和使用RESTful Web服务。Swagger包括自动文档,代码生成和测试用例生成。
在前后端分离的项目开发过程中,如果后端同学能够提供一份清晰明了的接口文档,那么就能极大地提高大家的沟通效率和开发效率。可是编写接口文档历来都是令人头痛的,而且后续接口文档的维护也十分耗费精力。
最好是有一种方案能够既满足我们输出文档的需要又能随代码的变更自动更新,而Swagger正是那种能帮我们解决接口文档问题的工具。
这里以gin框架为例,使用gin-swagger库以使用Swagger 2.0自动生成RESTful API文档。
想要使用gin-swagger为你的代码自动生成接口文档,一般需要下面三个步骤:
在程序入口main函数上以注释的方式写下项目相关介绍信息。
package main
// @title 这里写标题
// @version 1.0
// @description 这里写描述信息
// @termsOfService http://swagger.io/terms/
// @contact.name 这里写联系人信息
// @contact.url http://www.swagger.io/support
// @contact.email [email protected]
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host 这里写接口服务的host
// @BasePath 这里写base path
func main() {
r := gin.New()
// liwenzhou.com ...
r.Run()
}
在你代码中处理请求的接口函数(通常位于controller层)按如下方式写上注释:
// GetPostListHandler2 升级版帖子列表接口
// @Summary 升级版帖子列表接口
// @Description 可按社区按时间或分数排序查询帖子列表接口
// @Tags 帖子相关接口
// @Accept application/json
// @Produce application/json
// @Param Authorization header string false "Bearer 用户令牌"
// @Param object query models.ParamPostList false "查询参数"
// @Security ApiKeyAuth
// @Success 200 {object} _ResponsePostList
// @Router /posts2 [get]
func GetPostListHandler2(c *gin.Context) {
// GET请求参数(query string):/api/v1/posts2?page=1&size=10&order=time
// 初始化结构体时指定初始参数
p := &models.ParamPostList{
Page: 1,
Size: 10,
Order: models.OrderTime,
}
if err := c.ShouldBindQuery(p); err != nil {
zap.L().Error("GetPostListHandler2 with invalid params", zap.Error(err))
ResponseError(c, CodeInvalidParam)
return
}
data, err := logic.GetPostListNew(p)
// 获取数据
if err != nil {
zap.L().Error("logic.GetPostList() failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
ResponseSuccess(c, data)
// 返回响应
}
上面注释中参数类型使用了object,models.ParamPostList具体定义如下:
// bluebell/models/params.go
// ParamPostList 获取帖子列表query string参数
type ParamPostList struct {
CommunityID int64 `json:"community_id" form:"community_id"` // 可以为空
Page int64 `json:"page" form:"page" example:"1"` // 页码
Size int64 `json:"size" form:"size" example:"10"` // 每页数据量
Order string `json:"order" form:"order" example:"score"` // 排序依据
}
响应数据类型也使用的object,我个人习惯在controller层专门定义一个docs_models.go文件来存储文档中使用的响应数据model。
// bluebell/controller/docs_models.go
// _ResponsePostList 帖子列表接口响应数据
type _ResponsePostList struct {
Code ResCode `json:"code"` // 业务响应状态码
Message string `json:"message"` // 提示信息
Data []*models.ApiPostDetail `json:"data"` // 数据
}
编写完注释后,使用以下命令安装swag工具:
go get -u github.com/swaggo/swag/cmd/swag
在项目根目录执行以下命令,使用swag工具生成接口文档数据。
swag init
执行完上述命令后,如果你写的注释格式没问题,此时你的项目根目录下会多出一个docs文件夹。
./docs
├── docs.go
├── swagger.json
└── swagger.yaml
import (
// liwenzhou.com ...
_ "bluebell/docs" // 千万不要忘了导入把你上一步生成的docs
gs "github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger/swaggerFiles"
"github.com/gin-gonic/gin"
)
注册swagger api相关路由
r.GET("/swagger/*any", gs.WrapHandler(swaggerFiles.Handler))
把你的项目程序运行起来,打开浏览器访问http://localhost:8080/swagger/index.html就能看到Swagger 2.0 Api文档了。
package controllers
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"testing"
)
func TestCreatePostHandler(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.Default()
url := "/api/v1/post"
r.POST(url, CreatePostHandler)
body := `{
"community_id": 1,
"title": "test",
"content": "just a test"
}`
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader([]byte(body)))
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
// 方法1 :判断响应内容中是不是包含指定的字符串
//assert.Contains(t, w.Body.String(),"用户没有登录")
// 方法2 :将相应的内容反序列化到res,然后判断各字段与其是否一致
res := new(ResponseData)
if err := json.Unmarshal(w.Body.Bytes(), res); err !=nil{
t.Fatalf("json.Unmarshal w.Body faild , err:%v\n",err)
}
assert.Equal(t, res.Code, CodeNeedLogin)
}
package mysql
import (
"testing"
"web_app/models"
"web_app/settings"
)
// 因为全局的 db 一开始是一个空指针,所以需要 init 才可以正常运行
func init(){
dbcfg := &settings.MySQLConfig{
Host: "127.0.0.1",
User: "root",
Password: "DENGni1314",
DB: "bluebell",
Port: 3306,
MaxOpenConns: 10,
MaxIdleConns: 10,
}
err := Init(dbcfg)
if err != nil{
panic(err)
}
}
func TestCreatePost(t *testing.T){
post := &models.Post{
ID: 100000,
AuthorID: 123,
CommunityID: 1,
Title: "test",
Content: "just a test",
}
err := CreatePost(post)
if err != nil {
t.Fatalf("CreatePost insert record into mysql failed, err:%v\n",err)
}
t.Logf("CreatePost insert record into mysql success")
}
在项目正式上线之前,我们通常需要通过压测来评估当前系统能够支撑的请求量、排查可能存在的隐藏bug,同时了解了程序的实际处理能力能够帮我们更好的匹配项目的实际需求,节约资源成本。
ab全称Apache Bench,是Apache自带的性能测试工具。使用这个工具,只须指定同时连接数、请求数以及URL,即可测试网站或网站程序的性能。
通过ab发送请求模拟多个访问者同时对某一URL地址进行访问,可以得到每秒传送字节数、每秒处理请求数、每请求处理时间等统计数据。
例如测试某个GET请求接口:
ab -n 10000 -c 100 -t 10 "http://127.0.0.1:8080/api/v1/posts?size=10"
测试POST请求接口:
ab -n 10000 -c 100 -t 10 -p post.json -T "application/json" "http://127.0.0.1:8080/api/v1/post"
wrk是一款开源的HTTP性能测试工具,它和上面提到的ab同属于HTTP性能测试工具,它比ab功能更加强大,可以通过编写lua脚本来支持更加复杂的测试场景。
使用示例
wrk -t8 -c100 -d30s --latency http://127.0.0.1:8080/api/v1/posts?size=10
Running 30s test @ http://127.0.0.1:8080/api/v1/posts?size=10
8 threads and 100 connections
Thread Stats Avg Stdev Max +/- Stdev
Latency 14.55ms 2.02ms 31.59ms 76.70%
Req/Sec 828.16 85.69 0.97k 60.46%
Latency Distribution
50% 14.44ms
75% 15.76ms
90% 16.63ms
99% 21.07ms
198091 requests in 30.05s, 29.66MB read
Requests/sec: 6592.29
Transfer/sec: 0.99MB
==========================BENCHMARK==========================
URL: http://127.0.0.1:8080/api/v1/posts?size=10
Used Connections: 100
Used Threads: 8
Total number of calls: 10000
===========================TIMINGS===========================
Total time passed: 2.74s
Avg time per request: 27.11ms
Requests per second: 3644.53
Median time per request: 26.88ms
99th percentile time: 39.16ms
Slowest time for request: 45.00ms
=============================DATA=============================
Total response body sizes: 340000
Avg response body per request: 34.00 Byte
Transfer rate per second: 123914.11 Byte/s (0.12 MByte/s)
==========================RESPONSES==========================
20X Responses: 10000 (100.00%)
30X Responses: 0 (0.00%)
40X Responses: 0 (0.00%)
50X Responses: 0 (0.00%)
Errors: 0 (0.00%)
限流又称为流量控制(流控),通常是指限制到达系统的并发请求数。
我们生活中也会经常遇到限流的场景,比如:某景区限制每日进入景区的游客数量为8万人;沙河地铁站早高峰通过站外排队逐一放行的方式限制同一时间进入车站的旅客数量等。
限流虽然会影响部分用户的使用体验,但是却能在一定程度上报障系统的稳定性,不至于崩溃(大家都没了用户体验)。
而互联网上类似需要限流的业务场景也有很多,比如电商系统的秒杀、微博上突发热点新闻、双十一购物节、12306抢票等等。这些场景下的用户请求量通常会激增,远远超过平时正常的请求量,此时如果不加任何限制很容易就会将后端服务打垮,影响服务的稳定性。
此外,一些厂商公开的API服务通常也会限制用户的请求次数,比如百度地图开放平台等会根据用户的付费情况来限制用户的请求数等。
漏桶法限流很好理解,假设我们有一个水桶按固定的速率向下方滴落一滴水,无论有多少请求,请求的速率有多大,都按照固定的速率流出,对应到系统中就是按照固定的速率处理请求。
漏桶法的关键点在于漏桶始终按照固定的速率运行,但是它并不能很好的处理有大量突发请求的场景,毕竟在某些场景下我们可能需要提高系统的处理效率,而不是一味的按照固定速率处理请求。
关于漏桶的实现,uber团队有一个开源的github.com/uber-go/ratelimit库。 这个库的使用方法比较简单,Take() 方法会返回漏桶下一次滴水的时间。
import (
"fmt"
"time"
"go.uber.org/ratelimit"
)
func main() {
rl := ratelimit.New(100) // per second
prev := time.Now()
for i := 0; i < 10; i++ {
now := rl.Take()
fmt.Println(i, now.Sub(prev))
prev = now
}
// Output:
// 0 0
// 1 10ms
// 2 10ms
// 3 10ms
// 4 10ms
// 5 10ms
// 6 10ms
// 7 10ms
// 8 10ms
// 9 10ms
}
它的源码实现也比较简单,这里大致说一下关键的地方,有兴趣的同学可以自己去看一下完整的源码。
限制器是一个接口类型,其要求实现一个Take()方法:
type Limiter interface {
// Take方法应该阻塞已确保满足 RPS
Take() time.Time
}
实现限制器接口的结构体定义如下,这里可以重点留意下maxSlack字段,它在后面的Take()方法中的处理。
type limiter struct {
sync.Mutex // 锁
last time.Time // 上一次的时刻
sleepFor time.Duration // 需要等待的时间
perRequest time.Duration // 每次的时间间隔
maxSlack time.Duration // 最大的富余量
clock Clock // 时钟
}
limiter结构体实现Limiter接口的Take()方法内容如下:
// Take 会阻塞确保两次请求之间的时间走完
// Take 调用平均数为 time.Second/rate.
func (t *limiter) Take() time.Time {
t.Lock()
defer t.Unlock()
now := t.clock.Now()
// 如果是第一次请求就直接放行
if t.last.IsZero() {
t.last = now
return t.last
}
// sleepFor 根据 perRequest 和上一次请求的时刻计算应该sleep的时间
// 由于每次请求间隔的时间可能会超过perRequest, 所以这个数字可能为负数,并在多个请求之间累加
t.sleepFor += t.perRequest - now.Sub(t.last)
// 我们不应该让sleepFor负的太多,因为这意味着一个服务在短时间内慢了很多随后会得到更高的RPS。
if t.sleepFor < t.maxSlack {
t.sleepFor = t.maxSlack
}
// 如果 sleepFor 是正值那么就 sleep
if t.sleepFor > 0 {
t.clock.Sleep(t.sleepFor)
t.last = now.Add(t.sleepFor)
t.sleepFor = 0
} else {
t.last = now
}
return t.last
}
上面的代码根据记录每次请求的间隔时间和上一次请求的时刻来计算当次请求需要阻塞的时间——sleepFor,这里需要留意的是sleepFor的值可能为负,在经过间隔时间长的两次访问之后会导致随后大量的请求被放行,所以代码中针对这个场景有专门的优化处理。创建限制器的New()函数中会为maxSlack设置初始值,也可以通过WithoutSlack这个Option取消这个默认值。
func New(rate int, opts ...Option) Limiter {
l := &limiter{
perRequest: time.Second / time.Duration(rate),
maxSlack: -10 * time.Second / time.Duration(rate),
}
for _, opt := range opts {
opt(l)
}
if l.clock == nil {
l.clock = clock.New()
}
return l
}
令牌桶其实和漏桶的原理类似,令牌桶按固定的速率往桶里放入令牌,并且只要能从桶里取出令牌就能通过,令牌桶支持突发流量的快速处理。
对于从桶里取不到令牌的场景,我们可以选择等待也可以直接拒绝并返回。
对于令牌桶的Go语言实现,大家可以参照github.com/juju/ratelimit库。这个库支持多种令牌桶模式,并且使用起来也比较简单。
创建令牌桶的方法:
// 创建指定填充速率和容量大小的令牌桶
func NewBucket(fillInterval time.Duration, capacity int64) *Bucket
// 创建指定填充速率、容量大小和每次填充的令牌数的令牌桶
func NewBucketWithQuantum(fillInterval time.Duration, capacity, quantum int64) *Bucket
// 创建填充速度为指定速率和容量大小的令牌桶
// NewBucketWithRate(0.1, 200) 表示每秒填充20个令牌
func NewBucketWithRate(rate float64, capacity int64) *Bucket
取令牌方法:
// 取token(非阻塞)
func (tb *Bucket) Take(count int64) time.Duration
func (tb *Bucket) TakeAvailable(count int64) int64
// 最多等maxWait时间取token
func (tb *Bucket) TakeMaxDuration(count int64, maxWait time.Duration) (time.Duration, bool)
// 取token(阻塞)
func (tb *Bucket) Wait(count int64)
func (tb *Bucket) WaitMaxDuration(count int64, maxWait time.Duration) bool
虽说是令牌桶,但是我们没有必要真的去生成令牌放到桶里,我们只需要每次来取令牌的时候计算一下,当前是否有足够的令牌就可以了,具体的计算方式可以总结为下面的公式:
当前令牌数 = 上一次剩余的令牌数 + (本次取令牌的时刻-上一次取令牌的时刻)/放置令牌的时间间隔 * 每次放置的令牌数
github.com/juju/ratelimit这个库中关于令牌数计算的源代码如下
package middlewares
import (
"github.com/gin-gonic/gin"
"github.com/juju/ratelimit"
"net/http"
"time"
)
func RateLimitMiddleware(fillInterval time.Duration, cap int64) func(c *gin.Context) {
bucket := ratelimit.NewBucket(fillInterval, cap)
return func(c *gin.Context) {
// 如果取不到令牌就中断本次请求返回 rate limit...
if bucket.TakeAvailable(1) < 1 {
c.String(http.StatusOK, "rate limit...")
c.Abort()
return
}
// 取到令牌就放行
c.Next()
}
}
自己实现
package main
import (
"github.com/gin-gonic/gin"
ratelimit2 "github.com/juju/ratelimit"
ratelimit1 "go.uber.org/ratelimit"
"net/http"
"time"
)
func pingHandler(c *gin.Context) {
c.String(http.StatusOK, "pong")
}
func heiHandler(c *gin.Context) {
c.String(http.StatusOK, "hapong")
}
// 基于漏桶的限流中间件1
func rateLimit1() func(ctx *gin.Context) {
// 生成一个限流器
r1 := ratelimit1.New(100)
return func(c *gin.Context) {
if r1.Take().Sub(time.Now()) > 0 { // 取水滴
c.String(http.StatusOK, "rate limit ...")
c.Abort()
return
}
c.Next()
}
}
// 基于令牌桶的限流中间件2
func rateLimit2(fillInterval time.Duration, cap int64) func(ctx *gin.Context) {
r2 := ratelimit2.NewBucket(fillInterval, cap)
return func(c *gin.Context) {
//r2.Take() // 这一次可以欠账
//r2.TakeAvailable(1) // 有令牌的时候才可以取出令牌
if r2.TakeAvailable(1) == 1 { // 此次没有取出令牌
c.Next()
return
}
c.String(http.StatusOK, "rate limit ...")
c.Abort()
return
}
}
func main() {
r := gin.Default()
r.GET("/ping", rateLimit1(), pingHandler)
r.GET("/hei", rateLimit2(2*time.Second, 1), heiHandler)
r.Run()
}
在计算机性能调试领域里,profiling 是指对应用程序的画像,画像就是应用程序使用 CPU 和内存的情况。 Go语言是一个对性能特别看重的语言,因此语言中自带了 profiling 的库,这篇文章就要讲解怎么在 golang 中做 profiling。
在Go语言项目中的性能优化主要有以下几个方面:
Go语言内置了获取程序的运行数据的工具,包括以下两个标准库:
pprof开启后,每隔一段时间(10ms)就会收集下当前的堆栈信息,获取各个函数占用的CPU以及内存资源;最后通过对这些采样数据进行分析,形成一个性能分析报告。
注意,我们只应该在性能测试的时候才在代码中引入pprof。
如果你的应用程序是运行一段时间就结束退出类型。那么最好的办法是在应用退出的时候把 profiling 的报告保存到文件中,进行分析。对于这种情况,可以使用runtime/pprof库。 首先在代码中导入runtime/pprof工具:
import "runtime/pprof"
如果你的应用程序是一直运行的,比如 web 应用,那么可以使用net/http/pprof库,它能够在提供 HTTP 服务进行分析。
如果使用了默认的http.DefaultServeMux(通常是代码直接使用 http.ListenAndServe(“0.0.0.0:8000”, nil)),只需要在你的web server端代码中按如下方式导入net/http/pprof
首先我们来写一段有问题的代码:
// runtime_pprof/main.go
package main
import (
"flag"
"fmt"
"os"
"runtime/pprof"
"time"
)
// 一段有问题的代码
func logicCode() {
var c chan int
for {
select {
case v := <-c:
fmt.Printf("recv from chan, value:%v\n", v)
default:
}
}
}
func main() {
var isCPUPprof bool
var isMemPprof bool
flag.BoolVar(&isCPUPprof, "cpu", false, "turn cpu pprof on")
flag.BoolVar(&isMemPprof, "mem", false, "turn mem pprof on")
flag.Parse()
if isCPUPprof {
file, err := os.Create("./cpu.pprof")
if err != nil {
fmt.Printf("create cpu pprof failed, err:%v\n", err)
return
}
pprof.StartCPUProfile(file)
defer pprof.StopCPUProfile()
}
for i := 0; i < 8; i++ {
go logicCode()
}
time.Sleep(20 * time.Second)
if isMemPprof {
file, err := os.Create("./mem.pprof")
if err != nil {
fmt.Printf("create mem pprof failed, err:%v\n", err)
return
}
pprof.WriteHeapProfile(file)
file.Close()
}
}
推荐使用https://github.com/wg/wrk 或 https://github.com/adjust/go-wrk
Go 语言支持跨平台交叉编译,也就是说我们可以在 Windows 或 Mac 平台下编写代码,并且将代码编译成能够在 Linux amd64 服务器上运行的程序。
对于简单的项目,通常我们只需要将编译后的二进制文件拷贝到服务器上,然后设置为后台守护进程运行即可。
编译可以通过以下命令或编写 makefile 来操作。
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o ./bin/web_app
下面假设我们将本地编译好的 bluebell 二进制文件、配置文件和静态文件等上传到服务器的/data/app/bluebell目录下。
补充一点,如果嫌弃编译后的二进制文件太大,可以在编译的时候加上-ldflags "-s -w"参数去掉符号表和调试信息,一般能减小20%的大小。
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -ldflags "-s -w" -o ./bin/web_app
如果还是嫌大的话可以继续使用 upxu 工具对二进制可执行文件进行压缩。
我们在/etc/supervisord.d目录下创建一个名为bluebell.conf的配置文件,具体内容如下。
[program:bluebell] ;程序名称
user=root ;执行程序的用户
command=/data/app/bluebell/bin/bluebell /data/app/bluebell/conf/config.yaml ;执行的命令
directory=/data/app/bluebell/ ;命令执行的目录
stopsignal=TERM ;重启时发送的信号
autostart=true
autorestart=true ;是否自动重启
stdout_logfile=/var/log/bluebell-stdout.log ;标准输出日志位置
stderr_logfile=/var/log/bluebell-stderr.log ;标准错误日志位置
在需要静态文件分离、需要配置多个域名及证书、需要自建负载均衡层等稍复杂的场景下,我们一般需要搭配第三方的web服务器(Nginx、Apache)来部署我们的程序。
Nginx 是一个免费的、开源的、高性能的 HTTP 和反向代理服务,主要负责负载一些访问量比较大的站点。Nginx 可以作为一个独立的 Web 服务,也可以用来给 Apache 或是其他的 Web 服务做反向代理。相比于 Apache,Nginx 可以处理更多的并发连接,而且每个连接的内存占用的非常小。
我们推荐使用 nginx 作为反向代理来部署我们的程序,按下面的内容修改 nginx 的配置文件。
worker_processes 1;
events {
worker_connections 1024;
}
http {
include mime.types;
default_type application/octet-stream;
sendfile on;
keepalive_timeout 65;
server {
listen 80;
server_name localhost;
access_log /var/log/bluebell-access.log;
error_log /var/log/bluebell-error.log;
location / {
proxy_pass http://127.0.0.1:8081;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
}
}
当然我们还可以使用 nginx 的 upstream 配置来添加多个服务器地址实现负载均衡。
worker_processes 1;
events {
worker_connections 1024;
}
http {
include mime.types;
default_type application/octet-stream;
sendfile on;
keepalive_timeout 65;
upstream backend {
server 127.0.0.1:8084;
# 这里需要填真实可用的地址,默认轮询
#server backend1.example.com;
#server backend2.example.com;
}
server {
listen 80;
server_name localhost;
access_log /var/log/bluebell-access.log;
error_log /var/log/bluebell-error.log;
location / {
proxy_pass http://backend/;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
}
}
worker_processes 1;
events {
worker_connections 1024;
}
http {
include mime.types;
default_type application/octet-stream;
sendfile on;
keepalive_timeout 65;
server {
listen 80;
server_name bluebell;
access_log /var/log/bluebell-access.log;
error_log /var/log/bluebell-error.log;
# 静态文件请求
location ~ .*\.(gif|jpg|jpeg|png|js|css|eot|ttf|woff|svg|otf)$ {
access_log off;
expires 1d;
root /data/app/bluebell;
}
# index.html页面请求
# 因为是单页面应用这里使用 try_files 处理一下,避免刷新页面时出现404的问题
location / {
root /data/app/bluebell/templates;
index index.html;
try_files $uri $uri/ /index.html;
}
# API请求
location /api {
proxy_pass http://127.0.0.1:8084;
proxy_redirect off;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
}
}
}
这里使用github.com/gin-contrib/cors库来支持跨域请求。
最简单的允许跨域的配置是使用cors.Default(),它默认允许所有跨域请求。
func main() {
router := gin.Default()
// same as
// config := cors.DefaultConfig()
// config.AllowAllOrigins = true
// router.Use(cors.New(config))
router.Use(cors.Default())
router.Run()
}
此外,还可以使用cors.Config自定义具体的跨域请求相关配置项:
package main
import (
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
func main() {
router := gin.Default()
// CORS for https://foo.com and https://github.com origins, allowing:
// - PUT and PATCH methods
// - Origin header
// - Credentials share
// - Preflight requests cached for 12 hours
router.Use(cors.New(cors.Config{
AllowOrigins: []string{"https://foo.com"},
AllowMethods: []string{"PUT", "PATCH"},
AllowHeaders: []string{"Origin"},
ExposeHeaders: []string{"Content-Length"},
AllowCredentials: true,
AllowOriginFunc: func(origin string) bool {
return origin == "https://github.com"
},
MaxAge: 12 * time.Hour,
}))
router.Run()
}
容器部署方案可参照我之前的博客:使用Docker和Docker Compose部署Go Web应用。