在Go中使用Failpoint注入故障

转载请声明出处哦~,本篇文章发布于luozhiyun的博客: https://www.luozhiyun.com/archives/595

最近在看 TiDB 源码的时候,发现里面用了 failpoint 来进行故障注入,感觉非常有意思,里面用到了代码生成,以及代码 AST 树解析替换等方式实现了故障注入,我也会尝试解析一下,学习如何解析 AST 树生成代码。

所以这篇文章主要来看看 failpoint 使用详解以及实现原理吧。

前言

failpoint 是用于测试时注入错误的工具,它是 FreeBSD Failpoints的 Golang 实现。通常我们为了提升系统的稳定性,会有各种各样的测试场景,但是有些场景非常难模拟,比如:微服务中某个服务出现随机延迟、某个服务不可用的场景;游戏开发中模拟玩家网络不稳定、掉帧、延迟过大等场景;

所以为了可以很方便的测试出这些问题就有了 failpoint,它极大的简化了我们的测试流程,帮助我们在各种场景中模拟出各种错误,以便我们调试出代码 bug。

对于 failpoint 来说主要有以下几个优势:

  • failpoint 相关代码不应该有任何额外开销;
  • 不能影响正常功能逻辑,不能对功能代码有任何侵入;
  • failpoint 代码必须是易读、易写并且能引入编译器检测;
  • 最终生成的代码必须具有可读性;
  • 生成代码中,功能逻辑代码的行号不能发生变化(便于调试);

使用

首先我们需要使用源码进行构建:

git clone https://github.com/pingcap/failpoint.git
cd failpoint
make
ls bin/failpoint-ctl

译出二进制 failpoint-ctl用于代码转换。

然后在代码里面可以使用 failpoint 来注入故障:

package main

import "github.com/pingcap/failpoint"
import "fmt"

func test() {
    failpoint.Inject("testValue", func(v failpoint.Value) {
        fmt.Println(v)
    })
}

func main(){
    for i:=0;i<100;i++{
        test()
    }
}

我们进入到 Inject 方法中可以看到:

func Inject(fpname string, fpbody interface{}) {}

failpoint 在没有启用的时候,它只是一个空的实现,并不会对我们业务逻辑的性能产生任何影响。当我们的服务代码被编译构建后,这块代码会被 inline 优化掉,这就是 failpoint 所实现的 zero cost 故障注入原理。

下面我们将上面的测试函数全部转换成可用的故障注入代码:

$ failpoint/bin/failpoint-ctl enable .

调用编译好的 failpoint-ctl将当前代码重写转换:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    if v, _err_ := failpoint.Eval(_curpkg_("testValue")); _err_ == nil {
        fmt.Println(v)
    }
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

下面我们对代码执行注入:

$ GO_FAILPOINTS='main/testValue=2*return("abc")' go run main.go binding__failpoint_binding__.go 
abc
abc

上面这个用例中 2 表示注入只会执行两次,return("abc")中的参数对应注入函数中获取到的 v 变量。

除此之外我们还可以设置生效的概率:

$ GO_FAILPOINTS='main/testValue=5%return("abc")' go run main.go binding__failpoint_binding__.go 
abc
abc
abc
abc

上面这个用例中 5%表示只有 5%生效返回 abc。

除了上面简单的例子以外,还可以用它来生成比较复杂的场景:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
    "math/rand"
)

func main() {
    failpoint.Label("outer")
    for i := 0; i < 100; i++ {
    failpoint.Label("inner")
        for j := 1; j < 1000; j++ {
            switch rand.Intn(j) + i {
            case j / 5:
                failpoint.Break()
            case j / 7:
                failpoint.Continue("outer")
            case j / 9:
                failpoint.Fallthrough()
            case j / 10:
                failpoint.Goto("outer")
            default:
                failpoint.Inject("failpoint-name", func(val failpoint.Value) {
                    fmt.Println("unit-test", val.(int))
                    if val == j/11 {
                        failpoint.Break("inner")
                    } else {
                        failpoint.Goto("outer")
                    }
                })
            }
        }
    } 
}

在这个例子中,使用了 failpoint.Breakfailpoint.Gotofailpoint.Continuefailpoint.Label来实现代码的跳转,最后生成的代码:

func main() {
outer:
    for i := 0; i < 100; i++ {
    inner:
        for j := 1; j < 1000; j++ {
            switch rand.Intn(j) + i {
            case j / 5:
                break
            case j / 7:
                continue outer
            case j / 9:
                fallthrough
            case j / 10:
                goto outer
            default:
                if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil {
                    fmt.Println("unit-test", val.(int))
                    if val == j/11 {
                        break inner
                    } else {
                        goto outer
                    }
                }
            }
        }
    }
}

可以看到我们上面的 failpoint 代码都转化为了 Go 语言中的跳转关键字。

在测试完毕之后,最后我们通过 disable 可以将代码还原:

$ failpoint/bin/failpoint-ctl disable .

其他的使用方式可以查看官方文档:

https://github.com/pingcap/failpoint

实现原理

代码注入

举例说明

在使用 failpoint 的时候会通过它提供的一系列 Marker 函数来构建我们的故障埋点:

func Inject(fpname string, fpblock func(val Value)) {}
func InjectContext(fpname string, ctx context.Context, fpblock func(val Value)) {}
func Break(label ...string) {}
func Goto(label string) {}
func Continue(label ...string) {}
func Fallthrough() {}
func Return(results ...interface{}) {}
func Label(label string) {}

然后经 failpoint-ctl转换,构建 AST 替换 marker stmt, 转换成最终的注入函数代码,如下所示:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    failpoint.Inject("testPanic", func(val failpoint.Value){
        fmt.Println(val)
    })
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

转换后:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    if val, _err_ := failpoint.Eval(_curpkg_("testPanic")); _err_ == nil {
        fmt.Println(val)
    }
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

failpoint-ctl转换除了将代码内容进行替换以外,还生成了一个binding__failpoint_binding__.go文件,里面有一个 _curpkg_ 函数用来获取当前的包名:

package main

import "reflect"

type __failpointBindingType struct {pkgpath string}
var __failpointBindingCache = &__failpointBindingType{}

func init() {
    __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath()
}
func _curpkg_(name string) string {
    return  __failpointBindingCache.pkgpath + "/" + name
}

获取代码 AST 树

我们在调用failpoint-ctl进行代码转换的时候,会通过 Rewriter 对代码进行重写。Rewriter 是一个工具结构体,主要就是通过遍历代码 AST 树,检测 Marker 函数并完成函数的替换重写。

type Rewriter struct {
    rewriteDir    string // 重写路径
    currentPath   string // 文件路径
    currentFile   *ast.File // 文件 AST 树
    currsetFset   *token.FileSet // FileSet
    failpointName string // import 中 failpoint 的导入重命名
    rewritten     bool // 是否重写完毕

    output io.Writer // 重定向输出
}

failpoint-ctl执行的时候会调用到 RewriteFile 方法进行代码的重写:

func (r *Rewriter) RewriteFile(path string) (err error) {
    defer func() {
        if e := recover(); e != nil {
            err = fmt.Errorf("%s %v\n%s", r.currentPath, e, debug.Stack())
        }
    }()
    fset := token.NewFileSet()
    // 获取go文件AST树
    file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
    if err != nil {
        return err
    }
    if len(file.Decls) < 1 {
        return nil
    }
    // 文件路径
    r.currentPath = path
    // 文件AST树
    r.currentFile = file
    // 文件 FileSet
    r.currsetFset = fset
    // 标记是否重写完毕
    r.rewritten = false
    // 获取 failpoint import 包
    var failpointImport *ast.ImportSpec
    for _, imp := range file.Imports {
        if strings.Trim(imp.Path.Value, "`\"") == packagePath {
            failpointImport = imp
            break
        }
    }
    if failpointImport == nil {
        panic("import path should be check before rewrite")
    }
    if failpointImport.Name != nil {
        r.failpointName = failpointImport.Name.Name
    } else {
        r.failpointName = packageName
    }
    // 遍历文件中的顶级声明:如type、函数、import、全局常量等
    for _, decl := range file.Decls {
        fn, ok := decl.(*ast.FuncDecl)
        if !ok {
            continue
        }
        // 遍历函数声明节点,将failpoint相关函数进行替换
        if err := r.rewriteFuncDecl(fn); err != nil {
            return err
        }
    }

    if !r.rewritten {
        return nil
    }

    if r.output != nil {
        return format.Node(r.output, fset, file)
    }
    // 生成 binding__failpoint_binding__ 代码
    found, err := isBindingFileExists(path)
    if err != nil {
        return err
    }
    // binding__failpoint_binding__.go文件不存在,那么重新生成一个
    if !found {
        err := writeBindingFile(path, file.Name.Name)
        if err != nil {
            return err
        }
    }
    // 将原文件改名,如:将main.go改名为main.go__failpoint_stash__
    // 用来做作为还原使用
    targetPath := path + failpointStashFileSuffix
    if err := os.Rename(path, targetPath); err != nil {
        return err
    }

    newFile, err := os.OpenFile(path, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, os.ModePerm)
    if err != nil {
        return err
    }
    defer newFile.Close()
    //将构造好的ast树重新生成代码文件
    return format.Node(newFile, fset, file)
}

这个方法首先会调用 Go 提供的parser.ParseFile方法获取文件的 AST 树, AST 树是使用树状结构表示源代码的语法结构,树的每一个节点就代表源代码中的一个结构。然后遍历这颗 AST 树的顶级声明 Decls切片,相当于是从树的顶部往下遍历,是一个深度优先的遍历。

遍历完成之后通过校验 binding__failpoint_binding__ 文件,以及将源文件备份的操作之后调用format.Node将整个文件重写。

代码AST树遍历获取 Rewriter 执行节点替换

func (r *Rewriter) rewriteStmts(stmts []ast.Stmt) error {
    // 遍历函数体节点
    for i, block := range stmts {
        switch v := block.(type) {
        case *ast.DeclStmt:
            ... 
        // 包含单独的表达式语句
        case *ast.ExprStmt:
            call, ok := v.X.(*ast.CallExpr)
            if !ok {
                break
            }
            switch expr := call.Fun.(type) {
            // 函数定义
            case *ast.FuncLit:
                // 递归遍历函数
                err := r.rewriteFuncLit(expr)
                if err != nil {
                    return err
                }
            // 选择结构,类似于a.b的结构
            case *ast.SelectorExpr:
                // 获取函数调用的包名
                packageName, ok := expr.X.(*ast.Ident)
                // 包名是否等于 failpoint 包名
                if !ok || packageName.Name != r.failpointName {
                    break
                }
                // 通过 Marker 名获取 failpoint 的 Rewriter
                exprRewriter, found := exprRewriters[expr.Sel.Name]
                if !found {
                    break
                }
                // 对函数进行重写
                rewritten, stmt, err := exprRewriter(r, call)
                if err != nil {
                    return err
                }
                if !rewritten {
                    continue
                }
                // 获取重新生成好的if节点
                if ifStmt, ok := stmt.(*ast.IfStmt); ok {
                    err := r.rewriteIfStmt(ifStmt)
                    if err != nil {
                        return err
                    }
                }
                // 节点替换为重新生成好的if节点
                stmts[i] = stmt
                r.rewritten = true
            }

        case *ast.AssignStmt:
            ... 
        case *ast.GoStmt:
            ...
        case *ast.DeferStmt:
            ...
        case *ast.ReturnStmt: 
        ... 
        default:
            fmt.Printf("unsupported statement: %T in %s\n", v, r.pos(v.Pos()))
        }
    } 
    return nil
}

这里会依次遍历所有函数,直到找到 failpoint Marker 声明的地方,然后会通过 Marker名称在 exprRewriters 中获取到对应的 Rewriter:

var exprRewriters = map[string]exprRewriter{
    "Inject":        (*Rewriter).rewriteInject,
    "InjectContext": (*Rewriter).rewriteInjectContext,
    "Break":         (*Rewriter).rewriteBreak,
    "Continue":      (*Rewriter).rewriteContinue,
    "Label":         (*Rewriter).rewriteLabel,
    "Goto":          (*Rewriter).rewriteGoto,
    "Fallthrough":   (*Rewriter).rewriteFallthrough,
    "Return":        (*Rewriter).rewriteReturn,
}

Rewriter 重写

我们上面的例子使用的是 failpoint.Inject,所以这里使用 rewriteInject 进行讲解。

通过这个方法最终会将:

    failpoint.Inject("testPanic", func(val failpoint.Value){
        fmt.Println(val)
    })

转变成:

    if val, _err_ := failpoint.Eval(_curpkg_("testPanic")); _err_ == nil {
        fmt.Println(val)
    }

下面看看是如何构造 AST 树:

func (r *Rewriter) rewriteInject(call *ast.CallExpr) (bool, ast.Stmt, error) {
    //判断函数failpoint.Inject调用是否合法
    if len(call.Args) != 2 {
        return false, nil, fmt.Errorf("failpoint.Inject: expect 2 arguments but got %v in %s", len(call.Args), r.pos(call.Pos()))
    } 
    // 获取第一个参数 “testPanic”
    fpname, ok := call.Args[0].(ast.Expr)
    if !ok {
        return false, nil, fmt.Errorf("failpoint.Inject: first argument expect a valid expression in %s", r.pos(call.Pos()))
    }

    // 获取第二个参数 func(val failpoint.Value){}
    ident, ok := call.Args[1].(*ast.Ident)
    // 判断第二个参数是否为空
    isNilFunc := ok && ident.Name == "nil"

    // 校验第二个参数是函数的情况,因为第二个函数参数可以为空
    // failpoint.Inject("failpoint-name", func(){...})
    // failpoint.Inject("failpoint-name", func(val failpoint.Value){...})
    fpbody, isFuncLit := call.Args[1].(*ast.FuncLit)
    if !isNilFunc && !isFuncLit {
        return false, nil, fmt.Errorf("failpoint.Inject: second argument expect closure in %s", r.pos(call.Pos()))
    }

    // 第二个参数是函数的情况
    if isFuncLit {
        if len(fpbody.Type.Params.List) > 1 {
            return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
        }

        if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 {
            return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
        }
    }
    //构建替换函数:_curpkg_("testPanic")
    fpnameExtendCall := &ast.CallExpr{
        Fun:  ast.NewIdent(extendPkgName),
        Args: []ast.Expr{fpname},
    }
    //构建函数 failpoint.Eval
    checkCall := &ast.CallExpr{
        Fun: &ast.SelectorExpr{
            X:   &ast.Ident{NamePos: call.Pos(), Name: r.failpointName},
            Sel: ast.NewIdent(evalFunction),
        },
        Args: []ast.Expr{fpnameExtendCall},
    }
    if isNilFunc || len(fpbody.Body.List) < 1 {
        return true, &ast.ExprStmt{X: checkCall}, nil
    }
    // 构建if代码块
    ifBody := &ast.BlockStmt{
        Lbrace: call.Pos(),
        List:   fpbody.Body.List,
        Rbrace: call.End(),
    }

    // 校验failpoint中的闭包函数是否是包含参数的
    // func(val failpoint.Value) {...}
    // func() {...}
    var argName *ast.Ident
    if len(fpbody.Type.Params.List) > 0 {
        arg := fpbody.Type.Params.List[0]
        selector, ok := arg.Type.(*ast.SelectorExpr)
        if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName {
            return false, nil, fmt.Errorf("failpoint.Inject: invalid signature in %s", r.pos(call.Pos()))
        }
        argName = arg.Names[0]
    } else {
        argName = ast.NewIdent("_")
    }
    // 构建 failpoint.Eval 的返回值
    err := ast.NewIdent("_err_")
    init := &ast.AssignStmt{
        Lhs: []ast.Expr{argName, err},
        Rhs: []ast.Expr{checkCall},
        Tok: token.DEFINE,
    }
    // 构建 if 的判断条件,也就是 _err_ == nil
    cond := &ast.BinaryExpr{
        X:  err,
        Op: token.EQL,
        Y:  ast.NewIdent("nil"),
    }
    // 构建完整 if 代码块
    stmt := &ast.IfStmt{
        If:   call.Pos(),
        Init: init,
        Cond: cond,
        Body: ifBody,
    }
    return true, stmt, nil
}

上面的注释应该很详细了,可以跟着注释看代码。

failpoint 执行

构建故障方案

比如说我们这个故障有5%的概率会被触发,那我们可以这么做:

$ GO_FAILPOINTS='main/testValue=5%return("abc")' go run main.go binding__failpoint_binding__.go

上面声明的 GO_FAILPOINTS 变量里面的内容会在初始化的时候被读取到,然后注册好对应的机制,在执行的时候根据注册的机制进行故障控制。

func init() {
    failpoints.reg = make(map[string]*Failpoint)
    // 获取 GO_FAILPOINTS 变量
    if s := os.Getenv("GO_FAILPOINTS"); len(s) > 0 { 
        // 多个值使用;进行分割
        for _, fp := range strings.Split(s, ";") {
            fpTerms := strings.Split(fp, "=")
            if len(fpTerms) != 2 {
                fmt.Printf("bad failpoint %q\n", fp)
                os.Exit(1)
            }
            // 注册注入方案
            err := Enable(fpTerms[0], fpTerms[1])
            if err != nil {
                fmt.Printf("bad failpoint %s\n", err)
                os.Exit(1)
            }
        }
    }
    if s := os.Getenv("GO_FAILPOINTS_HTTP"); len(s) > 0 {
        if err := serve(s); err != nil {
            fmt.Println(err)
            os.Exit(1)
        }
    }
}

Enable 最后会调用到 Failpoints 结构体的 Enable 方法中,我们先来看看 Failpoints结构体:

type Failpoints struct {
    mu  sync.RWMutex  //并发控制
    reg map[string]*Failpoint //故障方案表
}

Failpoint struct {
    mu       sync.RWMutex  //并发控制
    t        *terms
    waitChan chan struct{} // 用来做暂停
}

Enable 会将 main/testValue=5%return("abc")解析成 key-value 的形式存放到 reg 这个 map 中,value 会被解析成为 Failpoint 结构体。

Failpoint 结构体中的故障控制方案主要存放在 term 结构体中:

type term struct {
    desc string //方案描述,这里是 5%return("abc")

    mods mod // 方案类型,是故障概率控制还是故障次数控制,这里是 5%
    act  actFunc // 故障行为,这里是 return
    val  interface{} // 注入故障的值,这里是 abc

    parent *terms
    fp     *Failpoint
}

我们在上面使用了 return 来执行故障,除此之外还有6个:

  • off: Take no action (does not trigger failpoint code)
  • return: Trigger failpoint with specified argument
  • sleep: Sleep the specified number of milliseconds
  • panic: Panic
  • break: Execute gdb and break into debugger
  • print: Print failpoint path for inject variable
  • pause: Pause will pause until the failpoint is disabled

整个 Filpoint 的层级关系如下:

下面我们看一下 Enable:

func (fp *Failpoint) Enable(inTerms string) error {
    t, err := newTerms(inTerms, fp)
    if err != nil {
        return err
    }
    fp.mu.Lock()
    fp.t = t
    fp.waitChan = make(chan struct{})
    fp.mu.Unlock()
    return nil
}

Enable 主要是调用 newTerms 构建 terms 结构体:

func newTerms(desc string, fp *Failpoint) (*terms, error) {
    // 解析传入的策略
    chain, err := parse(desc, fp)
    if err != nil {
        return nil, err
    }
    t := &terms{chain: chain, desc: desc}
    for _, c := range chain {
        c.parent = t
    }
    return t, nil
}

通过parse解析传入的策略,构建 terms 返回。

故障执行

我们在运行故障代码的时候会执行 failpoint.Eval,然后根据是否返回 err 来判断是否会执行故障函数。

Eval 函数会调用到 Failpoints 的 Eval 方法:

func (fps *Failpoints) Eval(failpath string) (Value, error) {
    fps.mu.RLock()
    // 获取注册的 Failpoint
    fp, found := fps.reg[failpath]
    fps.mu.RUnlock()
    if !found {
        return nil, errors.Wrapf(ErrNotExist, "error on %s", failpath)
    }
    // 执行方案判断
    val, err := fp.Eval()
    if err != nil {
        return nil, errors.Wrapf(err, "error on %s", failpath)
    }
    return val, nil
}

Eval 方法里面调用到的 reg map 是我们上面提到的 init 函数中注册好的方案,获取到 Failpoint 会调用它的 Eval 方法:

Eval 方法会调用到 terms 的 eval 方法遍历 chain []*term字段,获取其中设置的方案调用 allow 方法校验是否通过,通过则调用 do 方法执行对应的行为。

总结

在上面的介绍中首先学习了如何使用 Failpoint 服务于我们的代码,然后学习了 Failpoint 是如何通过代码注入的方式来实现故障注入。其中包含了 Go 的 AST 树遍历修改,以及代码生成,也为我们自己平时在写代码的时候提供了一种思路,通过这种方式代码生成的方式来提供一些额外的功能。

Reference

https://github.com/pingcap/failpoint

https://pingcap.com/zh/blog/golang-failpoint

https://www.modb.pro/db/79460

扫码_搜索联合传播样式-白色版 1