一个 GORM 插件示例:基于 OpenTelemetry 实现 ORM 层的可观测性(链路追踪与连接池指标)。
注:建议优先阅读 main 函数。
package main
import (
"context"
"fmt"
"os"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/jaeger"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/plugin/opentelemetry/logging/logrus"
"gorm.io/plugin/opentelemetry/tracing"
)
func ConfigureOpentelemetry(ctx context.Context) func() {
switch {
case os.Getenv("OTEL_EXPORTER_JAEGER_ENDPOINT") != "":
return configureJaeger(ctx)
default:
return configureStdout(ctx)
}
}
// configureJaeger creates a tracer provider, installs it globally, and attaches
// a batching span processor that exports to the Jaeger collector endpoint
// (jaeger.WithCollectorEndpoint with no options, which by default takes the
// endpoint from the environment).
//
// It panics if the exporter cannot be created. The returned function shuts the
// provider down using ctx and panics if shutdown fails.
//
// NOTE(review): the OpenTelemetry Jaeger exporter is deprecated upstream in
// favor of OTLP; consider migrating.
func configureJaeger(ctx context.Context) func() {
	tp := sdktrace.NewTracerProvider()
	otel.SetTracerProvider(tp)

	exporter, err := jaeger.New(jaeger.WithCollectorEndpoint())
	if err != nil {
		panic(err)
	}
	tp.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter))

	stop := func() {
		if shutdownErr := tp.Shutdown(ctx); shutdownErr != nil {
			panic(shutdownErr)
		}
	}
	return stop
}
// configureStdout creates a tracer provider, installs it globally, and attaches
// a batching span processor whose exporter pretty-prints spans to stdout.
//
// It panics if the exporter cannot be created. The returned function shuts the
// provider down using ctx and panics if shutdown fails.
func configureStdout(ctx context.Context) func() {
	tp := sdktrace.NewTracerProvider()
	otel.SetTracerProvider(tp)

	exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
	if err != nil {
		panic(err)
	}
	tp.RegisterSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter))

	stop := func() {
		if shutdownErr := tp.Shutdown(ctx); shutdownErr != nil {
			panic(shutdownErr)
		}
	}
	return stop
}
func PrintTraceID(ctx context.Context) {
fmt.Println("trace:", TraceURL(trace.SpanFromContext(ctx)))
}
// TraceURL returns a link to span's trace in a Jaeger UI served at
// http://localhost:16686.
//
// The same URL is returned regardless of which exporter is configured. It only
// resolves to a real trace when spans are exported to Jaeger, i.e. when
// OTEL_EXPORTER_JAEGER_ENDPOINT is set.
func TraceURL(span trace.Span) string {
	return fmt.Sprintf("http://localhost:16686/trace/%s", span.SpanContext().TraceID())
}
// main demonstrates GORM tracing end to end. It performs these steps:
//   - Installs an OpenTelemetry tracer provider.
//   - Opens MySQL with a logrus-backed GORM logger.
//   - Registers the tracing plugin.
//   - Runs one query inside a transaction under a "root" span.
//
// The MySQL DSN can be overridden with the MYSQL_DSN environment variable.
// When it is unset, the original local test DSN is used.
func main() {
	ctx := context.Background()
	shutdown := ConfigureOpentelemetry(ctx)
	defer shutdown()

	// Named gormLogger so the imported logger package is not shadowed.
	gormLogger := logger.New(
		logrus.NewWriter(), // route GORM log output through logrus
		logger.Config{
			SlowThreshold: time.Millisecond,
			LogLevel:      logger.Warn,
			Colorful:      false,
		},
	)

	dsn := os.Getenv("MYSQL_DSN")
	if dsn == "" {
		dsn = "root:rootpassword@tcp(127.0.0.1:3306)/test?charset=utf8mb4&parseTime=True&loc=Local"
	}
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{Logger: gormLogger})
	if err != nil {
		panic(err)
	}

	// Register the tracing plugin before deriving any session (such as the
	// transaction below) so all later statements run through its callbacks.
	if err := db.Use(tracing.NewPlugin()); err != nil {
		panic(err)
	}

	tracer := otel.Tracer("gorm.io/plugin/opentelemetry")
	ctx, span := tracer.Start(ctx, "root")
	defer span.End()

	if os.Getenv("OTEL_EXPORTER_JAEGER_ENDPOINT") != "" {
		PrintTraceID(ctx)
	}

	// Start the transaction with the root span's context so the work done in
	// it is associated with that span.
	tx := db.WithContext(ctx).Begin()
	if tx.Error != nil {
		panic(tx.Error)
	}
	// Undo the transaction on any early return. After a successful Commit the
	// rollback fails (database/sql reports the tx as already done), and that
	// error is deliberately ignored.
	defer tx.Rollback()

	ids := []int64{1, 2, 3}
	var users []int64
	// Pluck scans only the "id" column. Find into a []int64 would issue
	// SELECT * and fail to scan multi-column rows into a single int64.
	if err := tx.Debug().Table("users").Where("id IN ?", ids).Pluck("id", &users).Error; err != nil {
		fmt.Println("query err:", err)
		return
	}
	fmt.Println(users)

	if err := tx.Commit().Error; err != nil {
		fmt.Println("commit err:", err)
	}
}
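tracing.NewPlugin 返回一个实现 ORM 层链路监控的 GORM 插件,其实现基于 GORM 的 callback 机制。插件在 Initialize 中为 Create、Query、Delete、Update、Row、Raw 六类操作的内置 callback 前后分别注册 otel:before:* 与 otel:after:* 钩子;在未关闭指标(excludeMetrics 为 false)时,还会上报底层 sql.DB 的连接池统计。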
// gorm-opentelemetry/tracing/tracing.go

// Initialize installs the plugin on db. It does two things:
//   - Unless metrics are excluded, it starts reporting connection-pool
//     statistics for the underlying *sql.DB (silently skipped when no
//     *sql.DB can be obtained).
//   - It registers a before/after callback pair around each of GORM's create,
//     query, delete, update, row and raw processors.
//
// Every hook registration is attempted even if an earlier one fails. The first
// failure is returned, wrapped with the hook name.
func (p otelPlugin) Initialize(db *gorm.DB) (err error) {
	if !p.excludeMetrics {
		// db.DB() unwraps connection pools that merely wrap a *sql.DB (in
		// recent GORM versions, e.g. the PrepareStmt pool). A direct
		// db.ConnPool.(*sql.DB) assertion would silently skip metrics there.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			metrics.ReportDBStatsMetrics(sqlDB)
		}
	}

	cb := db.Callback()
	hooks := []struct {
		callback gormRegister
		hook     gormHookFunc // type gormHookFunc func(tx *gorm.DB)
		name     string
	}{
		{cb.Create().Before("gorm:create"), p.before("gorm.Create"), "before:create"},
		{cb.Create().After("gorm:create"), p.after(), "after:create"},

		{cb.Query().Before("gorm:query"), p.before("gorm.Query"), "before:select"},
		{cb.Query().After("gorm:query"), p.after(), "after:select"},

		{cb.Delete().Before("gorm:delete"), p.before("gorm.Delete"), "before:delete"},
		{cb.Delete().After("gorm:delete"), p.after(), "after:delete"},

		{cb.Update().Before("gorm:update"), p.before("gorm.Update"), "before:update"},
		{cb.Update().After("gorm:update"), p.after(), "after:update"},

		{cb.Row().Before("gorm:row"), p.before("gorm.Row"), "before:row"},
		{cb.Row().After("gorm:row"), p.after(), "after:row"},

		{cb.Raw().Before("gorm:raw"), p.before("gorm.Raw"), "before:raw"},
		{cb.Raw().After("gorm:raw"), p.after(), "after:raw"},
	}

	var firstErr error
	for _, h := range hooks {
		// Keep registering the remaining hooks after a failure; only the
		// first error is reported.
		if regErr := h.callback.Register("otel:"+h.name, h.hook); regErr != nil && firstErr == nil {
			firstErr = fmt.Errorf("callback register %s failed: %w", h.name, regErr)
		}
	}
	return firstErr
}