common.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package common
  2. import (
  3. "fmt"
  4. "go/ast"
  5. "go/format"
  6. "go/parser"
  7. "go/token"
  8. "io/ioutil"
  9. "log"
  10. "os"
  11. "regexp"
  12. "strings"
  13. )
  14. // Source source
  15. type Source struct {
  16. Fset *token.FileSet
  17. Src string
  18. F *ast.File
  19. }
  20. // NewSource new source
  21. func NewSource(src string) *Source {
  22. s := &Source{
  23. Fset: token.NewFileSet(),
  24. Src: src,
  25. }
  26. f, err := parser.ParseFile(s.Fset, "", src, 0)
  27. if err != nil {
  28. log.Fatal("无法解析源文件")
  29. }
  30. s.F = f
  31. return s
  32. }
  33. // ExprString expr string
  34. func (s *Source) ExprString(typ ast.Expr) string {
  35. fset := s.Fset
  36. s1 := fset.Position(typ.Pos()).Offset
  37. s2 := fset.Position(typ.End()).Offset
  38. return s.Src[s1:s2]
  39. }
  40. // pkgPath package path
  41. func (s *Source) pkgPath(name string) (res string) {
  42. for _, im := range s.F.Imports {
  43. if im.Name != nil && im.Name.Name == name {
  44. return im.Path.Value
  45. }
  46. }
  47. for _, im := range s.F.Imports {
  48. if strings.HasSuffix(im.Path.Value, name+"\"") {
  49. return im.Path.Value
  50. }
  51. }
  52. return
  53. }
  54. // GetDef get define code
  55. func (s *Source) GetDef(name string) string {
  56. c := s.F.Scope.Lookup(name).Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType)
  57. s1 := s.Fset.Position(c.Pos()).Offset
  58. s2 := s.Fset.Position(c.End()).Offset
  59. line := s.Fset.Position(c.Pos()).Line
  60. lines := []string{strings.Split(s.Src, "\n")[line-1]}
  61. for _, l := range strings.Split(s.Src[s1:s2], "\n")[1:] {
  62. lines = append(lines, "\t"+l)
  63. }
  64. return strings.Join(lines, "\n")
  65. }
  66. // RegexpReplace replace regexp
  67. func RegexpReplace(reg, src, temp string) string {
  68. result := []byte{}
  69. pattern := regexp.MustCompile(reg)
  70. for _, submatches := range pattern.FindAllStringSubmatchIndex(src, -1) {
  71. result = pattern.ExpandString(result, temp, src, submatches)
  72. }
  73. return string(result)
  74. }
  75. // formatPackage format package
  76. func formatPackage(name, path string) (res string) {
  77. if path != "" {
  78. if strings.HasSuffix(path, name+"\"") {
  79. res = path
  80. return
  81. }
  82. res = fmt.Sprintf("%s %s", name, path)
  83. }
  84. return
  85. }
  86. // SourceText get source file text
  87. func SourceText() string {
  88. file := os.Getenv("GOFILE")
  89. data, err := ioutil.ReadFile(file)
  90. if err != nil {
  91. log.Fatal("can't open file", file)
  92. }
  93. return string(data)
  94. }
  95. // FormatCode format code
  96. func FormatCode(source string) string {
  97. src, err := format.Source([]byte(source))
  98. if err != nil {
  99. // Should never happen, but can arise when developing this code.
  100. // The user can compile the output to see the error.
  101. log.Printf("warning: 输出文件不合法: %s", err)
  102. log.Printf("warning: 详细错误请编译查看")
  103. return source
  104. }
  105. return string(src)
  106. }
  107. // Packages get import packages
  108. func (s *Source) Packages(f *ast.Field) (res []string) {
  109. fs := f.Type.(*ast.FuncType).Params.List
  110. fs = append(fs, f.Type.(*ast.FuncType).Results.List...)
  111. var types []string
  112. resMap := make(map[string]bool)
  113. for _, field := range fs {
  114. if p, ok := field.Type.(*ast.MapType); ok {
  115. types = append(types, s.ExprString(p.Key))
  116. types = append(types, s.ExprString(p.Value))
  117. } else if p, ok := field.Type.(*ast.ArrayType); ok {
  118. types = append(types, s.ExprString(p.Elt))
  119. } else {
  120. types = append(types, s.ExprString(field.Type))
  121. }
  122. }
  123. for _, t := range types {
  124. name := RegexpReplace(`(?P<pkg>\w+)\.\w+`, t, "$pkg")
  125. if name == "" {
  126. continue
  127. }
  128. pkg := formatPackage(name, s.pkgPath(name))
  129. if !resMap[pkg] {
  130. resMap[pkg] = true
  131. }
  132. }
  133. for pkg := range resMap {
  134. res = append(res, pkg)
  135. }
  136. return
  137. }