leaktest.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
// Package leaktest provides tools to detect leaked goroutines in tests.
// To use it, call "defer leaktest.Check(t)()" at the beginning of each
// test that may use goroutines.
// Copied out of the CockroachDB source tree with slight modifications to
// make it more reusable.
  9. package leaktest
  10. import (
  11. "context"
  12. "fmt"
  13. "runtime"
  14. "sort"
  15. "strconv"
  16. "strings"
  17. "time"
  18. )
  19. type goroutine struct {
  20. id uint64
  21. stack string
  22. }
  23. type goroutineByID []*goroutine
  24. func (g goroutineByID) Len() int { return len(g) }
  25. func (g goroutineByID) Less(i, j int) bool { return g[i].id < g[j].id }
  26. func (g goroutineByID) Swap(i, j int) { g[i], g[j] = g[j], g[i] }
  27. func interestingGoroutine(g string) (*goroutine, error) {
  28. sl := strings.SplitN(g, "\n", 2)
  29. if len(sl) != 2 {
  30. return nil, fmt.Errorf("error parsing stack: %q", g)
  31. }
  32. stack := strings.TrimSpace(sl[1])
  33. if strings.HasPrefix(stack, "testing.RunTests") {
  34. return nil, nil
  35. }
  36. if stack == "" ||
  37. // Below are the stacks ignored by the upstream leaktest code.
  38. strings.Contains(stack, "testing.Main(") ||
  39. strings.Contains(stack, "testing.(*T).Run(") ||
  40. strings.Contains(stack, "runtime.goexit") ||
  41. strings.Contains(stack, "created by runtime.gc") ||
  42. strings.Contains(stack, "interestingGoroutines") ||
  43. strings.Contains(stack, "runtime.MHeap_Scavenger") ||
  44. strings.Contains(stack, "signal.signal_recv") ||
  45. strings.Contains(stack, "sigterm.handler") ||
  46. strings.Contains(stack, "runtime_mcall") ||
  47. strings.Contains(stack, "goroutine in C code") {
  48. return nil, nil
  49. }
  50. // Parse the goroutine's ID from the header line.
  51. h := strings.SplitN(sl[0], " ", 3)
  52. if len(h) < 3 {
  53. return nil, fmt.Errorf("error parsing stack header: %q", sl[0])
  54. }
  55. id, err := strconv.ParseUint(h[1], 10, 64)
  56. if err != nil {
  57. return nil, fmt.Errorf("error parsing goroutine id: %s", err)
  58. }
  59. return &goroutine{id: id, stack: strings.TrimSpace(g)}, nil
  60. }
  61. // interestingGoroutines returns all goroutines we care about for the purpose
  62. // of leak checking. It excludes testing or runtime ones.
  63. func interestingGoroutines(t ErrorReporter) []*goroutine {
  64. buf := make([]byte, 2<<20)
  65. buf = buf[:runtime.Stack(buf, true)]
  66. var gs []*goroutine
  67. for _, g := range strings.Split(string(buf), "\n\n") {
  68. gr, err := interestingGoroutine(g)
  69. if err != nil {
  70. t.Errorf("leaktest: %s", err)
  71. continue
  72. } else if gr == nil {
  73. continue
  74. }
  75. gs = append(gs, gr)
  76. }
  77. sort.Sort(goroutineByID(gs))
  78. return gs
  79. }
  80. // ErrorReporter is a tiny subset of a testing.TB to make testing not such a
  81. // massive pain
  82. type ErrorReporter interface {
  83. Errorf(format string, args ...interface{})
  84. }
  85. // Check snapshots the currently-running goroutines and returns a
  86. // function to be run at the end of tests to see whether any
  87. // goroutines leaked, waiting up to 5 seconds in error conditions
  88. func Check(t ErrorReporter) func() {
  89. return CheckTimeout(t, 5*time.Second)
  90. }
  91. // CheckTimeout is the same as Check, but with a configurable timeout
  92. func CheckTimeout(t ErrorReporter, dur time.Duration) func() {
  93. ctx, cancel := context.WithCancel(context.Background())
  94. fn := CheckContext(ctx, t)
  95. return func() {
  96. timer := time.AfterFunc(dur, cancel)
  97. fn()
  98. // Remember to clean up the timer and context
  99. timer.Stop()
  100. cancel()
  101. }
  102. }
  103. // CheckContext is the same as Check, but uses a context.Context for
  104. // cancellation and timeout control
  105. func CheckContext(ctx context.Context, t ErrorReporter) func() {
  106. orig := map[uint64]bool{}
  107. for _, g := range interestingGoroutines(t) {
  108. orig[g.id] = true
  109. }
  110. return func() {
  111. var leaked []string
  112. for {
  113. select {
  114. case <-ctx.Done():
  115. t.Errorf("leaktest: timed out checking goroutines")
  116. default:
  117. leaked = make([]string, 0)
  118. for _, g := range interestingGoroutines(t) {
  119. if !orig[g.id] {
  120. leaked = append(leaked, g.stack)
  121. }
  122. }
  123. if len(leaked) == 0 {
  124. return
  125. }
  126. // don't spin needlessly
  127. time.Sleep(time.Millisecond * 50)
  128. continue
  129. }
  130. break
  131. }
  132. for _, g := range leaked {
  133. t.Errorf("leaktest: leaked goroutine: %v", g)
  134. }
  135. }
  136. }