123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494 |
- // Copyright 2012 Google Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package mockgen
- // This file contains the model construction by parsing source files.
- import (
- "fmt"
- "go/ast"
- "go/build"
- "go/parser"
- "go/token"
- "log"
- "path"
- "path/filepath"
- "strconv"
- "strings"
- "github.com/otokaze/mock/mockgen/model"
- )
- // ParseFile parse by a file
- // TODO: simplify error reporting
- func ParseFile(source string) (*model.Package, error) {
- srcDir, err := filepath.Abs(filepath.Dir(source))
- if err != nil {
- return nil, fmt.Errorf("failed getting source directory: %v", err)
- }
- var packageImport string
- if p, err := build.ImportDir(srcDir, 0); err == nil {
- packageImport = p.ImportPath
- } // TODO: should we fail if this returns an error?
- fs := token.NewFileSet()
- file, err := parser.ParseFile(fs, source, nil, 0)
- if err != nil {
- return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
- }
- p := &fileParser{
- fileSet: fs,
- imports: make(map[string]string),
- importedInterfaces: make(map[string]map[string]*ast.InterfaceType),
- auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
- srcDir: srcDir,
- }
- // Handle -imports.
- dotImports := make(map[string]bool)
- if imports != "" {
- for _, kv := range strings.Split(imports, ",") {
- eq := strings.Index(kv, "=")
- k, v := kv[:eq], kv[eq+1:]
- if k == "." {
- // TODO: Catch dupes?
- dotImports[v] = true
- } else {
- // TODO: Catch dupes?
- p.imports[k] = v
- }
- }
- }
- // Handle -aux_files.
- if err := p.parseAuxFiles(auxFiles); err != nil {
- return nil, err
- }
- p.addAuxInterfacesFromFile(packageImport, file) // this file
- pkg, err := p.parseFile(packageImport, file)
- if err != nil {
- return nil, err
- }
- pkg.DotImports = make([]string, 0, len(dotImports))
- for path := range dotImports {
- pkg.DotImports = append(pkg.DotImports, path)
- }
- return pkg, nil
- }
// fileParser carries the state shared while turning parsed source files into
// the mock model.
type fileParser struct {
	// fileSet is used to translate token positions into file:line:column.
	fileSet *token.FileSet

	// imports maps a package name to its import path.
	imports map[string]string

	// importedInterfaces maps package (or "") => name => interface.
	importedInterfaces map[string]map[string]*ast.InterfaceType

	// auxFiles holds the parsed auxiliary files.
	auxFiles []*ast.File

	// auxInterfaces maps package (or "") => name => interface.
	auxInterfaces map[string]map[string]*ast.InterfaceType

	// srcDir is the directory passed to build.Import when an imported
	// package has to be located.
	srcDir string
}
- func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
- ps := p.fileSet.Position(pos)
- format = "%s:%d:%d: " + format
- args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
- return fmt.Errorf(format, args...)
- }
- func (p *fileParser) parseAuxFiles(auxFiles string) error {
- auxFiles = strings.TrimSpace(auxFiles)
- if auxFiles == "" {
- return nil
- }
- for _, kv := range strings.Split(auxFiles, ",") {
- parts := strings.SplitN(kv, "=", 2)
- if len(parts) != 2 {
- return fmt.Errorf("bad aux file spec: %v", kv)
- }
- pkg, fpath := parts[0], parts[1]
- file, err := parser.ParseFile(p.fileSet, fpath, nil, 0)
- if err != nil {
- return err
- }
- p.auxFiles = append(p.auxFiles, file)
- p.addAuxInterfacesFromFile(pkg, file)
- }
- return nil
- }
- func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
- if _, ok := p.auxInterfaces[pkg]; !ok {
- p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
- }
- for ni := range iterInterfaces(file) {
- p.auxInterfaces[pkg][ni.name.Name] = ni.it
- }
- }
- // parseFile loads all file imports and auxiliary files import into the
- // fileParser, parses all file interfaces and returns package model.
- func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Package, error) {
- allImports := importsOfFile(file)
- // Don't stomp imports provided by -imports. Those should take precedence.
- for pkg, path := range allImports {
- if _, ok := p.imports[pkg]; !ok {
- p.imports[pkg] = path
- }
- }
- // Add imports from auxiliary files, which might be needed for embedded interfaces.
- // Don't stomp any other imports.
- for _, f := range p.auxFiles {
- for pkg, path := range importsOfFile(f) {
- if _, ok := p.imports[pkg]; !ok {
- p.imports[pkg] = path
- }
- }
- }
- var is []*model.Interface
- for ni := range iterInterfaces(file) {
- i, err := p.parseInterface(ni.name.String(), importPath, ni.it)
- if err != nil {
- return nil, err
- }
- is = append(is, i)
- }
- return &model.Package{
- Name: file.Name.String(),
- SrcDir: p.srcDir,
- Interfaces: is,
- }, nil
- }
- // parsePackage loads package specified by path, parses it and populates
- // corresponding imports and importedInterfaces into the fileParser.
- func (p *fileParser) parsePackage(path string) error {
- var pkgs map[string]*ast.Package
- if imp, err := build.Import(path, p.srcDir, build.FindOnly); err != nil {
- return err
- } else if pkgs, err = parser.ParseDir(p.fileSet, imp.Dir, nil, 0); err != nil {
- return err
- }
- for _, pkg := range pkgs {
- file := ast.MergePackageFiles(pkg, ast.FilterFuncDuplicates|ast.FilterUnassociatedComments|ast.FilterImportDuplicates)
- if _, ok := p.importedInterfaces[path]; !ok {
- p.importedInterfaces[path] = make(map[string]*ast.InterfaceType)
- }
- for ni := range iterInterfaces(file) {
- p.importedInterfaces[path][ni.name.Name] = ni.it
- }
- for pkgName, pkgPath := range importsOfFile(file) {
- if _, ok := p.imports[pkgName]; !ok {
- p.imports[pkgName] = pkgPath
- }
- }
- }
- return nil
- }
- func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
- intf := &model.Interface{Name: name}
- for _, field := range it.Methods.List {
- switch v := field.Type.(type) {
- case *ast.FuncType:
- if nn := len(field.Names); nn != 1 {
- return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
- }
- m := &model.Method{
- Name: field.Names[0].String(),
- }
- var err error
- m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
- if err != nil {
- return nil, err
- }
- intf.Methods = append(intf.Methods, m)
- case *ast.Ident:
- // Embedded interface in this package.
- ei := p.auxInterfaces[pkg][v.String()]
- if ei == nil {
- if ei = p.importedInterfaces[pkg][v.String()]; ei == nil {
- return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
- }
- }
- eintf, err := p.parseInterface(v.String(), pkg, ei)
- if err != nil {
- return nil, err
- }
- // Copy the methods.
- // TODO: apply shadowing rules.
- intf.Methods = append(intf.Methods, eintf.Methods...)
- case *ast.SelectorExpr:
- // Embedded interface in another package.
- fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
- epkg, ok := p.imports[fpkg]
- if !ok {
- return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
- }
- ei := p.auxInterfaces[fpkg][sel]
- if ei == nil {
- fpkg = epkg
- if _, ok = p.importedInterfaces[epkg]; !ok {
- if err := p.parsePackage(epkg); err != nil {
- return nil, p.errorf(v.Pos(), "could not parse package %s: %v", fpkg, err)
- }
- }
- if ei = p.importedInterfaces[epkg][sel]; ei == nil {
- return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
- }
- }
- eintf, err := p.parseInterface(sel, fpkg, ei)
- if err != nil {
- return nil, err
- }
- // Copy the methods.
- // TODO: apply shadowing rules.
- intf.Methods = append(intf.Methods, eintf.Methods...)
- default:
- return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
- }
- }
- return intf, nil
- }
- func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
- if f.Params != nil {
- regParams := f.Params.List
- if isVariadic(f) {
- n := len(regParams)
- varParams := regParams[n-1:]
- regParams = regParams[:n-1]
- vp, err := p.parseFieldList(pkg, varParams)
- if err != nil {
- return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
- }
- variadic = vp[0]
- }
- in, err = p.parseFieldList(pkg, regParams)
- if err != nil {
- return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
- }
- }
- if f.Results != nil {
- out, err = p.parseFieldList(pkg, f.Results.List)
- if err != nil {
- return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
- }
- }
- return
- }
- func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
- nf := 0
- for _, f := range fields {
- nn := len(f.Names)
- if nn == 0 {
- nn = 1 // anonymous parameter
- }
- nf += nn
- }
- if nf == 0 {
- return nil, nil
- }
- ps := make([]*model.Parameter, nf)
- i := 0 // destination index
- for _, f := range fields {
- t, err := p.parseType(pkg, f.Type)
- if err != nil {
- return nil, err
- }
- if len(f.Names) == 0 {
- // anonymous arg
- ps[i] = &model.Parameter{Type: t}
- i++
- continue
- }
- for _, name := range f.Names {
- ps[i] = &model.Parameter{Name: name.Name, Type: t}
- i++
- }
- }
- return ps, nil
- }
- func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
- switch v := typ.(type) {
- case *ast.ArrayType:
- ln := -1
- if v.Len != nil {
- x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
- if err != nil {
- return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
- }
- ln = x
- }
- t, err := p.parseType(pkg, v.Elt)
- if err != nil {
- return nil, err
- }
- return &model.ArrayType{Len: ln, Type: t}, nil
- case *ast.ChanType:
- t, err := p.parseType(pkg, v.Value)
- if err != nil {
- return nil, err
- }
- var dir model.ChanDir
- if v.Dir == ast.SEND {
- dir = model.SendDir
- }
- if v.Dir == ast.RECV {
- dir = model.RecvDir
- }
- return &model.ChanType{Dir: dir, Type: t}, nil
- case *ast.Ellipsis:
- // assume we're parsing a variadic argument
- return p.parseType(pkg, v.Elt)
- case *ast.FuncType:
- in, variadic, out, err := p.parseFunc(pkg, v)
- if err != nil {
- return nil, err
- }
- return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
- case *ast.Ident:
- if v.IsExported() {
- // `pkg` may be an aliased imported pkg
- // if so, patch the import w/ the fully qualified import
- maybeImportedPkg, ok := p.imports[pkg]
- if ok {
- pkg = maybeImportedPkg
- }
- // assume type in this package
- return &model.NamedType{Package: pkg, Type: v.Name}, nil
- }
- return model.PredeclaredType(v.Name), nil
- case *ast.InterfaceType:
- if v.Methods != nil && len(v.Methods.List) > 0 {
- return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
- }
- return model.PredeclaredType("interface{}"), nil
- case *ast.MapType:
- key, err := p.parseType(pkg, v.Key)
- if err != nil {
- return nil, err
- }
- value, err := p.parseType(pkg, v.Value)
- if err != nil {
- return nil, err
- }
- return &model.MapType{Key: key, Value: value}, nil
- case *ast.SelectorExpr:
- pkgName := v.X.(*ast.Ident).String()
- pkg, ok := p.imports[pkgName]
- if !ok {
- return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
- }
- return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
- case *ast.StarExpr:
- t, err := p.parseType(pkg, v.X)
- if err != nil {
- return nil, err
- }
- return &model.PointerType{Type: t}, nil
- case *ast.StructType:
- if v.Fields != nil && len(v.Fields.List) > 0 {
- return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
- }
- return model.PredeclaredType("struct{}"), nil
- }
- return nil, fmt.Errorf("don't know how to parse type %T", typ)
- }
- // importsOfFile returns a map of package name to import path
- // of the imports in file.
- func importsOfFile(file *ast.File) map[string]string {
- m := make(map[string]string)
- for _, is := range file.Imports {
- var pkgName string
- importPath := is.Path.Value[1 : len(is.Path.Value)-1] // remove quotes
- if is.Name != nil {
- // Named imports are always certain.
- if is.Name.Name == "_" {
- continue
- }
- pkgName = removeDot(is.Name.Name)
- } else {
- pkg, err := build.Import(importPath, "", 0)
- if err != nil {
- // Fallback to import path suffix. Note that this is uncertain.
- _, last := path.Split(importPath)
- // If the last path component has dots, the first dot-delimited
- // field is used as the name.
- pkgName = strings.SplitN(last, ".", 2)[0]
- } else {
- pkgName = pkg.Name
- }
- }
- if _, ok := m[pkgName]; ok {
- log.Fatalf("imported package collision: %q imported twice", pkgName)
- }
- m[pkgName] = importPath
- }
- return m
- }
// namedInterface pairs an interface type with the identifier it is declared
// under.
type namedInterface struct {
	name *ast.Ident
	it   *ast.InterfaceType
}

// iterInterfaces returns a channel that yields, in declaration order, every
// interface type declared at the top level of file, and is then closed.
//
// The channel is buffered to hold all results and is filled and closed before
// it is returned. No goroutine is involved, so a consumer may stop ranging
// over it early (for example on error) without leaving a blocked sender
// behind.
func iterInterfaces(file *ast.File) <-chan namedInterface {
	var found []namedInterface
	for _, decl := range file.Decls {
		// Only "type" declarations can introduce interfaces.
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.TYPE {
			continue
		}
		for _, spec := range gd.Specs {
			ts, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}
			it, ok := ts.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			found = append(found, namedInterface{name: ts.Name, it: it})
		}
	}

	ch := make(chan namedInterface, len(found))
	for _, ni := range found {
		ch <- ni
	}
	close(ch)
	return ch
}
// isVariadic reports whether the last parameter of f is declared with "...".
// A signature with a nil or empty parameter list is not variadic.
func isVariadic(f *ast.FuncType) bool {
	// Params can be nil on hand-built nodes; guard before dereferencing.
	if f.Params == nil || len(f.Params.List) == 0 {
		return false
	}
	last := f.Params.List[len(f.Params.List)-1]
	_, ok := last.Type.(*ast.Ellipsis)
	return ok
}
|