parse.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. // Copyright 2012 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package mockgen
  15. // This file contains the model construction by parsing source files.
  16. import (
  17. "fmt"
  18. "go/ast"
  19. "go/build"
  20. "go/parser"
  21. "go/token"
  22. "log"
  23. "path"
  24. "path/filepath"
  25. "strconv"
  26. "strings"
  27. "github.com/otokaze/mock/mockgen/model"
  28. )
  29. // ParseFile parse by a file
  30. // TODO: simplify error reporting
  31. func ParseFile(source string) (*model.Package, error) {
  32. srcDir, err := filepath.Abs(filepath.Dir(source))
  33. if err != nil {
  34. return nil, fmt.Errorf("failed getting source directory: %v", err)
  35. }
  36. var packageImport string
  37. if p, err := build.ImportDir(srcDir, 0); err == nil {
  38. packageImport = p.ImportPath
  39. } // TODO: should we fail if this returns an error?
  40. fs := token.NewFileSet()
  41. file, err := parser.ParseFile(fs, source, nil, 0)
  42. if err != nil {
  43. return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
  44. }
  45. p := &fileParser{
  46. fileSet: fs,
  47. imports: make(map[string]string),
  48. importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
  49. auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
  50. srcDir: srcDir,
  51. }
  52. // Handle -imports.
  53. dotImports := make(map[string]bool)
  54. if imports != "" {
  55. for _, kv := range strings.Split(imports, ",") {
  56. eq := strings.Index(kv, "=")
  57. k, v := kv[:eq], kv[eq+1:]
  58. if k == "." {
  59. // TODO: Catch dupes?
  60. dotImports[v] = true
  61. } else {
  62. // TODO: Catch dupes?
  63. p.imports[k] = v
  64. }
  65. }
  66. }
  67. // Handle -aux_files.
  68. if err := p.parseAuxFiles(auxFiles); err != nil {
  69. return nil, err
  70. }
  71. p.addAuxInterfacesFromFile(packageImport, file) // this file
  72. pkg, err := p.parseFile(packageImport, file)
  73. if err != nil {
  74. return nil, err
  75. }
  76. pkg.DotImports = make([]string, 0, len(dotImports))
  77. for path := range dotImports {
  78. pkg.DotImports = append(pkg.DotImports, path)
  79. }
  80. return pkg, nil
  81. }
  82. type fileParser struct {
  83. fileSet *token.FileSet
  84. imports map[string]string // package name => import path
  85. importedInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
  86. auxFiles []*ast.File
  87. auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
  88. srcDir string
  89. }
  90. func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
  91. ps := p.fileSet.Position(pos)
  92. format = "%s:%d:%d: " + format
  93. args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
  94. return fmt.Errorf(format, args...)
  95. }
  96. func (p *fileParser) parseAuxFiles(auxFiles string) error {
  97. auxFiles = strings.TrimSpace(auxFiles)
  98. if auxFiles == "" {
  99. return nil
  100. }
  101. for _, kv := range strings.Split(auxFiles, ",") {
  102. parts := strings.SplitN(kv, "=", 2)
  103. if len(parts) != 2 {
  104. return fmt.Errorf("bad aux file spec: %v", kv)
  105. }
  106. pkg, fpath := parts[0], parts[1]
  107. file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
  108. if err != nil {
  109. return err
  110. }
  111. p.auxFiles = append(p.auxFiles, file)
  112. p.addAuxInterfacesFromFile(pkg, file)
  113. }
  114. return nil
  115. }
  116. func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
  117. if _, ok := p.auxInterfaces[pkg]; !ok {
  118. p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
  119. }
  120. for ni := range iterInterfaces(file) {
  121. p.auxInterfaces[pkg][ni.name.Name] = ni.it
  122. }
  123. }
  124. // parseFile loads all file imports and auxiliary files import into the
  125. // fileParser, parses all file interfaces and returns package model.
  126. func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
  127. allImports := importsOfFile(file)
  128. // Don't stomp imports provided by -imports. Those should take precedence.
  129. for pkg, path := range allImports {
  130. if _, ok := p.imports[pkg]; !ok {
  131. p.imports[pkg] = path
  132. }
  133. }
  134. // Add imports from auxiliary files, which might be needed for embedded interfaces.
  135. // Don't stomp any other imports.
  136. for _, f := range p.auxFiles {
  137. for pkg, path := range importsOfFile(f) {
  138. if _, ok := p.imports[pkg]; !ok {
  139. p.imports[pkg] = path
  140. }
  141. }
  142. }
  143. var is []*model.Interface
  144. for ni := range iterInterfaces(file) {
  145. i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
  146. if err != nil {
  147. return nil, err
  148. }
  149. is = append(is, i)
  150. }
  151. return &model.Package{
  152. Name: file.Name.String(),
  153. SrcDir: p.srcDir,
  154. Interfaces: is,
  155. }, nil
  156. }
  157. // parsePackage loads package specified by path, parses it and populates
  158. // corresponding imports and importedInterfaces into the fileParser.
  159. func (p *fileParser) parsePackage(path string) error {
  160. var pkgs map[string]*ast.Package
  161. if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil {
  162. return err
  163. } else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil {
  164. return err
  165. }
  166. for _, pkg := range pkgs {
  167. file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
  168. if _, ok := p.importedInterfaces[path]; !ok {
  169. p.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
  170. }
  171. for ni := range iterInterfaces(file) {
  172. p.importedInterfaces[path][ni.name.Name] = ni.it
  173. }
  174. for pkgName, pkgPath := range importsOfFile(file) {
  175. if _, ok := p.imports[pkgName]; !ok {
  176. p.imports[pkgName] = pkgPath
  177. }
  178. }
  179. }
  180. return nil
  181. }
  182. func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
  183. intf := &model.Interface{Name: name}
  184. for _, field := range it.Methods.List {
  185. switch v := field.Type.(type) {
  186. case *ast.FuncType:
  187. if nn := len(field.Names); nn != 1 {
  188. return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
  189. }
  190. m := &model.Method{
  191. Name: field.Names[0].String(),
  192. }
  193. var err error
  194. m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
  195. if err != nil {
  196. return nil, err
  197. }
  198. intf.Methods = append(intf.Methods, m)
  199. case *ast.Ident:
  200. // Embedded interface in this package.
  201. ei := p.auxInterfaces[pkg][v.String()]
  202. if ei == nil {
  203. if ei = p.importedInterfaces[pkg][v.String()]; ei == nil {
  204. return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
  205. }
  206. }
  207. eintf, err := p.parseInterface(v.String(), pkg, ei)
  208. if err != nil {
  209. return nil, err
  210. }
  211. // Copy the methods.
  212. // TODO: apply shadowing rules.
  213. intf.Methods = append(intf.Methods, eintf.Methods...)
  214. case *ast.SelectorExpr:
  215. // Embedded interface in another package.
  216. fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
  217. epkg, ok := p.imports[fpkg]
  218. if !ok {
  219. return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
  220. }
  221. ei := p.auxInterfaces[fpkg][sel]
  222. if ei == nil {
  223. fpkg = epkg
  224. if _, ok = p.importedInterfaces[epkg]; !ok {
  225. if err := p.parsePackage(epkg); err != nil {
  226. return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err)
  227. }
  228. }
  229. if ei = p.importedInterfaces[epkg][sel]; ei == nil {
  230. return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
  231. }
  232. }
  233. eintf, err := p.parseInterface(sel, fpkg, ei)
  234. if err != nil {
  235. return nil, err
  236. }
  237. // Copy the methods.
  238. // TODO: apply shadowing rules.
  239. intf.Methods = append(intf.Methods, eintf.Methods...)
  240. default:
  241. return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
  242. }
  243. }
  244. return intf, nil
  245. }
  246. func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
  247. if f.Params != nil {
  248. regParams := f.Params.List
  249. if isVariadic(f) {
  250. n := len(regParams)
  251. varParams := regParams[n-1:]
  252. regParams = regParams[:n-1]
  253. vp, err := p.parseFieldList(pkg, varParams)
  254. if err != nil {
  255. return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
  256. }
  257. variadic = vp[0]
  258. }
  259. in, err = p.parseFieldList(pkg, regParams)
  260. if err != nil {
  261. return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
  262. }
  263. }
  264. if f.Results != nil {
  265. out, err = p.parseFieldList(pkg, f.Results.List)
  266. if err != nil {
  267. return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
  268. }
  269. }
  270. return
  271. }
  272. func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
  273. nf := 0
  274. for _, f := range fields {
  275. nn := len(f.Names)
  276. if nn == 0 {
  277. nn = 1 // anonymous parameter
  278. }
  279. nf += nn
  280. }
  281. if nf == 0 {
  282. return nil, nil
  283. }
  284. ps := make([]*model.Parameter, nf)
  285. i := 0 // destination index
  286. for _, f := range fields {
  287. t, err := p.parseType(pkg, f.Type)
  288. if err != nil {
  289. return nil, err
  290. }
  291. if len(f.Names) == 0 {
  292. // anonymous arg
  293. ps[i] = &model.Parameter{Type: t}
  294. i++
  295. continue
  296. }
  297. for _, name := range f.Names {
  298. ps[i] = &model.Parameter{Name: name.Name, Type: t}
  299. i++
  300. }
  301. }
  302. return ps, nil
  303. }
  304. func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
  305. switch v := typ.(type) {
  306. case *ast.ArrayType:
  307. ln := -1
  308. if v.Len != nil {
  309. x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
  310. if err != nil {
  311. return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
  312. }
  313. ln = x
  314. }
  315. t, err := p.parseType(pkg, v.Elt)
  316. if err != nil {
  317. return nil, err
  318. }
  319. return &model.ArrayType{Len: ln, Type: t}, nil
  320. case *ast.ChanType:
  321. t, err := p.parseType(pkg, v.Value)
  322. if err != nil {
  323. return nil, err
  324. }
  325. var dir model.ChanDir
  326. if v.Dir == ast.SEND {
  327. dir = model.SendDir
  328. }
  329. if v.Dir == ast.RECV {
  330. dir = model.RecvDir
  331. }
  332. return &model.ChanType{Dir: dir, Type: t}, nil
  333. case *ast.Ellipsis:
  334. // assume we're parsing a variadic argument
  335. return p.parseType(pkg, v.Elt)
  336. case *ast.FuncType:
  337. in, variadic, out, err := p.parseFunc(pkg, v)
  338. if err != nil {
  339. return nil, err
  340. }
  341. return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
  342. case *ast.Ident:
  343. if v.IsExported() {
  344. // `pkg` may be an aliased imported pkg
  345. // if so, patch the import w/ the fully qualified import
  346. maybeImportedPkg, ok := p.imports[pkg]
  347. if ok {
  348. pkg = maybeImportedPkg
  349. }
  350. // assume type in this package
  351. return &model.NamedType{Package: pkg, Type: v.Name}, nil
  352. }
  353. return model.PredeclaredType(v.Name), nil
  354. case *ast.InterfaceType:
  355. if v.Methods != nil && len(v.Methods.List) > 0 {
  356. return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
  357. }
  358. return model.PredeclaredType("interface{}"), nil
  359. case *ast.MapType:
  360. key, err := p.parseType(pkg, v.Key)
  361. if err != nil {
  362. return nil, err
  363. }
  364. value, err := p.parseType(pkg, v.Value)
  365. if err != nil {
  366. return nil, err
  367. }
  368. return &model.MapType{Key: key, Value: value}, nil
  369. case *ast.SelectorExpr:
  370. pkgName := v.X.(*ast.Ident).String()
  371. pkg, ok := p.imports[pkgName]
  372. if !ok {
  373. return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
  374. }
  375. return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
  376. case *ast.StarExpr:
  377. t, err := p.parseType(pkg, v.X)
  378. if err != nil {
  379. return nil, err
  380. }
  381. return &model.PointerType{Type: t}, nil
  382. case *ast.StructType:
  383. if v.Fields != nil && len(v.Fields.List) > 0 {
  384. return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
  385. }
  386. return model.PredeclaredType("struct{}"), nil
  387. }
  388. return nil, fmt.Errorf("don't know how to parse type %T", typ)
  389. }
  390. // importsOfFile returns a map of package name to import path
  391. // of the imports in file.
  392. func importsOfFile(file *ast.File) map[string]string {
  393. m := make(map[string]string)
  394. for _, is := range file.Imports {
  395. var pkgName string
  396. importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
  397. if is.Name != nil {
  398. // Named imports are always certain.
  399. if is.Name.Name == "_" {
  400. continue
  401. }
  402. pkgName = removeDot(is.Name.Name)
  403. } else {
  404. pkg, err := build.Import(importPath, "", 0)
  405. if err != nil {
  406. // Fallback to import path suffix. Note that this is uncertain.
  407. _, last := path.Split(importPath)
  408. // If the last path component has dots, the first dot-delimited
  409. // field is used as the name.
  410. pkgName = strings.SplitN(last, ".", 2)[0]
  411. } else {
  412. pkgName = pkg.Name
  413. }
  414. }
  415. if _, ok := m[pkgName]; ok {
  416. log.Fatalf("imported package collision: %q imported twice", pkgName)
  417. }
  418. m[pkgName] = importPath
  419. }
  420. return m
  421. }
  422. type namedInterface struct {
  423. name *ast.Ident
  424. it *ast.InterfaceType
  425. }
  426. // Create an iterator over all interfaces in file.
  427. func iterInterfaces(file *ast.File) <-chan namedInterface {
  428. ch := make(chan namedInterface)
  429. go func() {
  430. for _, decl := range file.Decls {
  431. gd, ok := decl.(*ast.GenDecl)
  432. if !ok || gd.Tok != token.TYPE {
  433. continue
  434. }
  435. for _, spec := range gd.Specs {
  436. ts, ok := spec.(*ast.TypeSpec)
  437. if !ok {
  438. continue
  439. }
  440. it, ok := ts.Type.(*ast.InterfaceType)
  441. if !ok {
  442. continue
  443. }
  444. ch <- namedInterface{ts.Name, it}
  445. }
  446. }
  447. close(ch)
  448. }()
  449. return ch
  450. }
  451. // isVariadic returns whether the function is variadic.
  452. func isVariadic(f *ast.FuncType) bool {
  453. nargs := len(f.Params.List)
  454. if nargs == 0 {
  455. return false
  456. }
  457. _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis)
  458. return ok
  459. }