golang[ssa & callgraph] 获取调用图实战

最近在拆分一个旧服务,需要从几十万行代码中,按业务功能拆分出对应代码,并部署新服务;然而,面对这种巨型服务,代码调用错综复杂,纯人力拆分需要耗费很多时间;基于此,这里借助golang自带callgraph调用图能力,帮我们找到需要拆出的代码;

package main

import (
	"fmt"
	"io/ioutil"
	"path/filepath"
	"sort"
	"strings"

	"github.com/pkg/errors"
	"golang.org/x/tools/go/packages"
	"golang.org/x/tools/go/ssa/ssautil"
	"golang.org/x/tools/go/callgraph"
	"golang.org/x/tools/go/pointer"
)

// getProjectUsedCall 获取项目使用中的调用方法
func getProjectUsedCall(projectPath string) ([]string, error) {
	projectModule, err := parseProjectModule(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectModule fail")
	}
	log.Debugf("projectModule: %+v", projectModule)

	callMap, err := parseProjectCallMap(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectCallMap fail")
	}
	log.Debugf("callMap: %+v", callMap)

	srcCall := fmt.Sprintf("%v.main", projectModule)
	isDeleteEdgeFunc := func(caller, callee string) bool {
		// 非本项目调用
		if !strings.Contains(caller, projectModule) || !strings.Contains(callee, projectModule) {
			return true
		}
		// 非初始化调用
		if isInitCall(caller) || isInitCall(callee) {
			return true
		}
		// 非自我调用
		if caller == callee {
			return true
		}
		return false
	}

	// 过滤不需要的边
	for caller, callees := range callMap {
		for callee := range callees {
			if isDeleteEdgeFunc(caller, callee) {
				delete(callees, callee)
			}
		}
		if len(callees) == 0 {
			delete(callMap, caller)
		}
	}

	// 广度搜索图
	for {
		srcCallees := callMap[srcCall]
		srcSize := len(srcCallees)

		for srcCallee := range srcCallees {
			for nextCallee := range callMap[srcCallee] {
				callMap[srcCall][nextCallee] = true
			}
		}

		if srcSize == len(callMap[srcCall]) {
			break
		}
	}

	// 调用源涉及到的所有方法
	var callees []string
	for c := range callMap[srcCall] {
		callees = append(callees, c)
	}
	sort.Strings(callees)
	return callees, nil
}

// parseProjectCallMap 解析项目调用图
func parseProjectCallMap(projectPath string) (map[string]map[string]bool, error) {
	projectModule, err := parseProjectModule(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "parseProjectModule fail")
	}
	log.Debugf("projectModule: %+v", projectModule)

	result, err := analyzeProject(projectPath)
	if err != nil {
		return nil, errors.Wrap(err, "analyzeProject fail")
	}
	log.Debugf("analyzeProject: %+v", result)

	// 遍历调用链路
	var callMap = make(map[string]map[string]bool)
	visitFunc := func(edge *callgraph.Edge) error {
		if edge == nil {
			return nil
		}
		// 解析调用者和被调用者
		caller, callee, err := parseCallEdge(edge)
		if err != nil {
			return errors.Wrap(err, "parseCallEdge fail")
		}
		// 记录调用关系
		if callMap[caller] == nil {
			callMap[caller] = make(map[string]bool)
		}
		callMap[caller][callee] = true
		return nil
	}
	err = callgraph.GraphVisitEdges(result.CallGraph, visitFunc)
	if err != nil {
		return nil, errors.Wrap(err, "GraphVisitEdges fail")
	}
	return callMap, nil
}

func parseProjectModule(projectPath string) (string, error) {
	modFilename := filepath.Join(projectPath, "go.mod")
	content, err := ioutil.ReadFile(modFilename)
	if err != nil {
		return "", errors.Wrap(err, "ioutil.ReadFile fail")
	}
	lines := strings.Split(string(content), "\n")
	module := strings.TrimPrefix(lines[0], "module ")
	module = strings.TrimSpace(module)
	return module, nil
}

func analyzeProject(projectPath string) (*pointer.Result, error) {
	// 生成Go Packages
	pkgs, err := packages.Load(&packages.Config{
		Mode: packages.LoadAllSyntax,
		Dir:  projectPath,
	})
	if err != nil {
		return nil, errors.Wrap(err, "packages.Load fail")
	}
	log.Debugf("pkgs: %+v", pkgs)

	// 生成ssa 构建编译
	prog, ssaPkgs := ssautil.AllPackages(pkgs, 0)
	prog.Build()
	log.Debugf("ssaPkgs: %+v", ssaPkgs)
	// 使用pointer生成调用链路
	return pointer.Analyze(&pointer.Config{
		Mains:          ssaPkgs,
		BuildCallGraph: true,
	})
}

func parseCallEdge(edge *callgraph.Edge) (string, string, error) {
	const callArrow = "-->"
	edgeStr := fmt.Sprintf("%+v", edge)
	strArray := strings.Split(edgeStr, callArrow)
	if len(strArray) != 2 {
		return "", "", fmt.Errorf("invalid format: %v", edgeStr)
	}
	callerNodeStr, calleeNodeStr := strArray[0], strArray[1]
	caller, callee := getCallRoute(callerNodeStr), getCallRoute(calleeNodeStr)
	return caller, callee, nil
}

func getCallRoute(nodeStr string) string {
	nodeStr = strings.TrimSpace(nodeStr)
	if strings.Contains(nodeStr, ":") {
		nodeStr = nodeStr[strings.Index(nodeStr, ":")+1:]
	}
	nodeStr = strings.ReplaceAll(nodeStr, "*", "")
	nodeStr = strings.ReplaceAll(nodeStr, "(", "")
	nodeStr = strings.ReplaceAll(nodeStr, ")", "")
	nodeStr = strings.ReplaceAll(nodeStr, "<", "")
	nodeStr = strings.ReplaceAll(nodeStr, ">", "")
	if strings.Contains(nodeStr, "$") {
		nodeStr = nodeStr[:strings.Index(nodeStr, "$")]
	}
	if strings.Contains(nodeStr, "#") {
		nodeStr = nodeStr[:strings.Index(nodeStr, "#")]
	}
	return strings.TrimSpace(nodeStr)
}

func isInitCall(call string) bool {
	return strings.HasSuffix(call, ".init")
}

你可能感兴趣的:(golang,数据库,开发语言,调用图)