// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
- package astutil
- import (
- "fmt"
- "go/ast"
- "reflect"
- "sort"
- )
- // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
- // before and/or after the node's children, using a Cursor describing
- // the current node and providing operations on it.
- //
- // The return value of ApplyFunc controls the syntax tree traversal.
- // See Apply for details.
- type ApplyFunc func(*Cursor) bool
- // Apply traverses a syntax tree recursively, starting with root,
- // and calling pre and post for each node as described below.
- // Apply returns the syntax tree, possibly modified.
- //
- // If pre is not nil, it is called for each node before the node's
- // children are traversed (pre-order). If pre returns false, no
- // children are traversed, and post is not called for that node.
- //
- // If post is not nil, and a prior call of pre didn't return false,
- // post is called for each node after its children are traversed
- // (post-order). If post returns false, traversal is terminated and
- // Apply returns immediately.
- //
- // Only fields that refer to AST nodes are considered children;
- // i.e., token.Pos, Scopes, Objects, and fields of basic types
- // (strings, etc.) are ignored.
- //
- // Children are traversed in the order in which they appear in the
- // respective node's struct definition. A package's files are
- // traversed in the filenames' alphabetical order.
- //
- func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
- parent := &struct{ ast.Node }{root}
- defer func() {
- if r := recover(); r != nil && r != abort {
- panic(r)
- }
- result = parent.Node
- }()
- a := &application{pre: pre, post: post}
- a.apply(parent, "Node", nil, root)
- return
- }
- var abort = new(int) // singleton, to signal termination of Apply
- // A Cursor describes a node encountered during Apply.
- // Information about the node and its parent is available
- // from the Node, Parent, Name, and Index methods.
- //
- // If p is a variable of type and value of the current parent node
- // c.Parent(), and f is the field identifier with name c.Name(),
- // the following invariants hold:
- //
- // p.f == c.Node() if c.Index() < 0
- // p.f[c.Index()] == c.Node() if c.Index() >= 0
- //
- // The methods Replace, Delete, InsertBefore, and InsertAfter
- // can be used to change the AST without disrupting Apply.
- type Cursor struct {
- parent ast.Node
- name string
- iter *iterator // valid if non-nil
- node ast.Node
- }
- // Node returns the current Node.
- func (c *Cursor) Node() ast.Node { return c.node }
- // Parent returns the parent of the current Node.
- func (c *Cursor) Parent() ast.Node { return c.parent }
- // Name returns the name of the parent Node field that contains the current Node.
- // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
- // the filename for the current Node.
- func (c *Cursor) Name() string { return c.name }
- // Index reports the index >= 0 of the current Node in the slice of Nodes that
- // contains it, or a value < 0 if the current Node is not part of a slice.
- // The index of the current node changes if InsertBefore is called while
- // processing the current node.
- func (c *Cursor) Index() int {
- if c.iter != nil {
- return c.iter.index
- }
- return -1
- }
- // field returns the current node's parent field value.
- func (c *Cursor) field() reflect.Value {
- return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
- }
- // Replace replaces the current Node with n.
- // The replacement node is not walked by Apply.
- func (c *Cursor) Replace(n ast.Node) {
- if _, ok := c.node.(*ast.File); ok {
- file, ok := n.(*ast.File)
- if !ok {
- panic("attempt to replace *ast.File with non-*ast.File")
- }
- c.parent.(*ast.Package).Files[c.name] = file
- return
- }
- v := c.field()
- if i := c.Index(); i >= 0 {
- v = v.Index(i)
- }
- v.Set(reflect.ValueOf(n))
- }
- // Delete deletes the current Node from its containing slice.
- // If the current Node is not part of a slice, Delete panics.
- // As a special case, if the current node is a package file,
- // Delete removes it from the package's Files map.
- func (c *Cursor) Delete() {
- if _, ok := c.node.(*ast.File); ok {
- delete(c.parent.(*ast.Package).Files, c.name)
- return
- }
- i := c.Index()
- if i < 0 {
- panic("Delete node not contained in slice")
- }
- v := c.field()
- l := v.Len()
- reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
- v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
- v.SetLen(l - 1)
- c.iter.step--
- }
- // InsertAfter inserts n after the current Node in its containing slice.
- // If the current Node is not part of a slice, InsertAfter panics.
- // Apply does not walk n.
- func (c *Cursor) InsertAfter(n ast.Node) {
- i := c.Index()
- if i < 0 {
- panic("InsertAfter node not contained in slice")
- }
- v := c.field()
- v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
- l := v.Len()
- reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
- v.Index(i + 1).Set(reflect.ValueOf(n))
- c.iter.step++
- }
- // InsertBefore inserts n before the current Node in its containing slice.
- // If the current Node is not part of a slice, InsertBefore panics.
- // Apply will not walk n.
- func (c *Cursor) InsertBefore(n ast.Node) {
- i := c.Index()
- if i < 0 {
- panic("InsertBefore node not contained in slice")
- }
- v := c.field()
- v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
- l := v.Len()
- reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
- v.Index(i).Set(reflect.ValueOf(n))
- c.iter.index++
- }
// application carries all the shared data so we can pass it around cheaply.
type application struct {
	// pre and post are the callbacks given to Apply; either may be nil.
	pre, post ApplyFunc
	// cursor describes the node currently being visited. A single Cursor
	// is reused (saved and restored by apply) rather than allocated per node.
	cursor Cursor
	// iter tracks the position within the slice field currently walked by
	// applyList; it too is reused (saved and restored by applyList).
	iter iterator
}

// apply visits n, the child of parent stored in the field called name
// (at position iter.index of that slice field when iter is non-nil; for an
// *ast.Package parent, name is the key into its Files map).
//
// It calls a.pre; unless that returns false, it then walks n's children in
// the order of the switch cases below and calls a.post. If a.post returns
// false, apply panics with abort, which Apply recovers to end the walk.
// On normal return a.cursor is restored to the value it had on entry, so
// the caller's Cursor describes the caller's node again.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// convert typed nil into untyped nil
	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
		n = nil
	}

	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// walk children
	// (the order of the cases matches the order of the corresponding node types in go/ast)
	// The order of the calls within each case is the traversal order that
	// Apply documents (struct field order), so it must not be rearranged.
	//
	// NOTE(review): there is no case for *ast.IndexListExpr, and the
	// TypeParams fields of *ast.FuncType / *ast.TypeSpec are never walked.
	// Those were added to go/ast in Go 1.18 for generics, so an AST of
	// generic code would presumably reach the default case and panic —
	// confirm against the Go version this package must support.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// n cannot be nil here (typed nils were converted above);
		// the check is purely defensive.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// Don't walk n.Comments; they have either been walked already if
		// they are Doc comments, or they can be easily walked explicitly.

	case *ast.Package:
		// collect and sort names for reproducible behavior
		// (map iteration order is unspecified). Each file is visited with
		// its filename as the Cursor name; see Cursor.Name.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// A false return from post aborts the whole traversal; Apply's
	// deferred recover turns this panic back into a normal return.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	a.cursor = saved
}
- // An iterator controls iteration over a slice of nodes.
- type iterator struct {
- index, step int
- }
- func (a *application) applyList(parent ast.Node, name string) {
- // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
- saved := a.iter
- a.iter.index = 0
- for {
- // must reload parent.name each time, since cursor modifications might change it
- v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
- if a.iter.index >= v.Len() {
- break
- }
- // element x may be nil in a bad AST - be cautious
- var x ast.Node
- if e := v.Index(a.iter.index); e.IsValid() {
- x = e.Interface().(ast.Node)
- }
- a.iter.step = 1
- a.apply(parent, name, &a.iter, x)
- a.iter.index += a.iter.step
- }
- a.iter = saved
- }
|