rewrite.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. // Copyright 2017 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package astutil
  5. import (
  6. "fmt"
  7. "go/ast"
  8. "reflect"
  9. "sort"
  10. )
  11. // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
  12. // before and/or after the node's children, using a Cursor describing
  13. // the current node and providing operations on it.
  14. //
  15. // The return value of ApplyFunc controls the syntax tree traversal.
  16. // See Apply for details.
  17. type ApplyFunc func(*Cursor) bool
  18. // Apply traverses a syntax tree recursively, starting with root,
  19. // and calling pre and post for each node as described below.
  20. // Apply returns the syntax tree, possibly modified.
  21. //
  22. // If pre is not nil, it is called for each node before the node's
  23. // children are traversed (pre-order). If pre returns false, no
  24. // children are traversed, and post is not called for that node.
  25. //
  26. // If post is not nil, and a prior call of pre didn't return false,
  27. // post is called for each node after its children are traversed
  28. // (post-order). If post returns false, traversal is terminated and
  29. // Apply returns immediately.
  30. //
  31. // Only fields that refer to AST nodes are considered children;
  32. // i.e., token.Pos, Scopes, Objects, and fields of basic types
  33. // (strings, etc.) are ignored.
  34. //
  35. // Children are traversed in the order in which they appear in the
  36. // respective node's struct definition. A package's files are
  37. // traversed in the filenames' alphabetical order.
  38. //
  39. func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
  40. parent := &struct{ ast.Node }{root}
  41. defer func() {
  42. if r := recover(); r != nil && r != abort {
  43. panic(r)
  44. }
  45. result = parent.Node
  46. }()
  47. a := &application{pre: pre, post: post}
  48. a.apply(parent, "Node", nil, root)
  49. return
  50. }
  51. var abort = new(int) // singleton, to signal termination of Apply
  52. // A Cursor describes a node encountered during Apply.
  53. // Information about the node and its parent is available
  54. // from the Node, Parent, Name, and Index methods.
  55. //
  56. // If p is a variable of type and value of the current parent node
  57. // c.Parent(), and f is the field identifier with name c.Name(),
  58. // the following invariants hold:
  59. //
  60. // p.f == c.Node() if c.Index() < 0
  61. // p.f[c.Index()] == c.Node() if c.Index() >= 0
  62. //
  63. // The methods Replace, Delete, InsertBefore, and InsertAfter
  64. // can be used to change the AST without disrupting Apply.
  65. type Cursor struct {
  66. parent ast.Node
  67. name string
  68. iter *iterator // valid if non-nil
  69. node ast.Node
  70. }
  71. // Node returns the current Node.
  72. func (c *Cursor) Node() ast.Node { return c.node }
  73. // Parent returns the parent of the current Node.
  74. func (c *Cursor) Parent() ast.Node { return c.parent }
  75. // Name returns the name of the parent Node field that contains the current Node.
  76. // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
  77. // the filename for the current Node.
  78. func (c *Cursor) Name() string { return c.name }
  79. // Index reports the index >= 0 of the current Node in the slice of Nodes that
  80. // contains it, or a value < 0 if the current Node is not part of a slice.
  81. // The index of the current node changes if InsertBefore is called while
  82. // processing the current node.
  83. func (c *Cursor) Index() int {
  84. if c.iter != nil {
  85. return c.iter.index
  86. }
  87. return -1
  88. }
  89. // field returns the current node's parent field value.
  90. func (c *Cursor) field() reflect.Value {
  91. return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
  92. }
  93. // Replace replaces the current Node with n.
  94. // The replacement node is not walked by Apply.
  95. func (c *Cursor) Replace(n ast.Node) {
  96. if _, ok := c.node.(*ast.File); ok {
  97. file, ok := n.(*ast.File)
  98. if !ok {
  99. panic("attempt to replace *ast.File with non-*ast.File")
  100. }
  101. c.parent.(*ast.Package).Files[c.name] = file
  102. return
  103. }
  104. v := c.field()
  105. if i := c.Index(); i >= 0 {
  106. v = v.Index(i)
  107. }
  108. v.Set(reflect.ValueOf(n))
  109. }
  110. // Delete deletes the current Node from its containing slice.
  111. // If the current Node is not part of a slice, Delete panics.
  112. // As a special case, if the current node is a package file,
  113. // Delete removes it from the package's Files map.
  114. func (c *Cursor) Delete() {
  115. if _, ok := c.node.(*ast.File); ok {
  116. delete(c.parent.(*ast.Package).Files, c.name)
  117. return
  118. }
  119. i := c.Index()
  120. if i < 0 {
  121. panic("Delete node not contained in slice")
  122. }
  123. v := c.field()
  124. l := v.Len()
  125. reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
  126. v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
  127. v.SetLen(l - 1)
  128. c.iter.step--
  129. }
  130. // InsertAfter inserts n after the current Node in its containing slice.
  131. // If the current Node is not part of a slice, InsertAfter panics.
  132. // Apply does not walk n.
  133. func (c *Cursor) InsertAfter(n ast.Node) {
  134. i := c.Index()
  135. if i < 0 {
  136. panic("InsertAfter node not contained in slice")
  137. }
  138. v := c.field()
  139. v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
  140. l := v.Len()
  141. reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
  142. v.Index(i + 1).Set(reflect.ValueOf(n))
  143. c.iter.step++
  144. }
  145. // InsertBefore inserts n before the current Node in its containing slice.
  146. // If the current Node is not part of a slice, InsertBefore panics.
  147. // Apply will not walk n.
  148. func (c *Cursor) InsertBefore(n ast.Node) {
  149. i := c.Index()
  150. if i < 0 {
  151. panic("InsertBefore node not contained in slice")
  152. }
  153. v := c.field()
  154. v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
  155. l := v.Len()
  156. reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
  157. v.Index(i).Set(reflect.ValueOf(n))
  158. c.iter.index++
  159. }
  160. // application carries all the shared data so we can pass it around cheaply.
  161. type application struct {
  162. pre, post ApplyFunc
  163. cursor Cursor
  164. iter iterator
  165. }
  166. func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
  167. // convert typed nil into untyped nil
  168. if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
  169. n = nil
  170. }
  171. // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
  172. saved := a.cursor
  173. a.cursor.parent = parent
  174. a.cursor.name = name
  175. a.cursor.iter = iter
  176. a.cursor.node = n
  177. if a.pre != nil && !a.pre(&a.cursor) {
  178. a.cursor = saved
  179. return
  180. }
  181. // walk children
  182. // (the order of the cases matches the order of the corresponding node types in go/ast)
  183. switch n := n.(type) {
  184. case nil:
  185. // nothing to do
  186. // Comments and fields
  187. case *ast.Comment:
  188. // nothing to do
  189. case *ast.CommentGroup:
  190. if n != nil {
  191. a.applyList(n, "List")
  192. }
  193. case *ast.Field:
  194. a.apply(n, "Doc", nil, n.Doc)
  195. a.applyList(n, "Names")
  196. a.apply(n, "Type", nil, n.Type)
  197. a.apply(n, "Tag", nil, n.Tag)
  198. a.apply(n, "Comment", nil, n.Comment)
  199. case *ast.FieldList:
  200. a.applyList(n, "List")
  201. // Expressions
  202. case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
  203. // nothing to do
  204. case *ast.Ellipsis:
  205. a.apply(n, "Elt", nil, n.Elt)
  206. case *ast.FuncLit:
  207. a.apply(n, "Type", nil, n.Type)
  208. a.apply(n, "Body", nil, n.Body)
  209. case *ast.CompositeLit:
  210. a.apply(n, "Type", nil, n.Type)
  211. a.applyList(n, "Elts")
  212. case *ast.ParenExpr:
  213. a.apply(n, "X", nil, n.X)
  214. case *ast.SelectorExpr:
  215. a.apply(n, "X", nil, n.X)
  216. a.apply(n, "Sel", nil, n.Sel)
  217. case *ast.IndexExpr:
  218. a.apply(n, "X", nil, n.X)
  219. a.apply(n, "Index", nil, n.Index)
  220. case *ast.SliceExpr:
  221. a.apply(n, "X", nil, n.X)
  222. a.apply(n, "Low", nil, n.Low)
  223. a.apply(n, "High", nil, n.High)
  224. a.apply(n, "Max", nil, n.Max)
  225. case *ast.TypeAssertExpr:
  226. a.apply(n, "X", nil, n.X)
  227. a.apply(n, "Type", nil, n.Type)
  228. case *ast.CallExpr:
  229. a.apply(n, "Fun", nil, n.Fun)
  230. a.applyList(n, "Args")
  231. case *ast.StarExpr:
  232. a.apply(n, "X", nil, n.X)
  233. case *ast.UnaryExpr:
  234. a.apply(n, "X", nil, n.X)
  235. case *ast.BinaryExpr:
  236. a.apply(n, "X", nil, n.X)
  237. a.apply(n, "Y", nil, n.Y)
  238. case *ast.KeyValueExpr:
  239. a.apply(n, "Key", nil, n.Key)
  240. a.apply(n, "Value", nil, n.Value)
  241. // Types
  242. case *ast.ArrayType:
  243. a.apply(n, "Len", nil, n.Len)
  244. a.apply(n, "Elt", nil, n.Elt)
  245. case *ast.StructType:
  246. a.apply(n, "Fields", nil, n.Fields)
  247. case *ast.FuncType:
  248. a.apply(n, "Params", nil, n.Params)
  249. a.apply(n, "Results", nil, n.Results)
  250. case *ast.InterfaceType:
  251. a.apply(n, "Methods", nil, n.Methods)
  252. case *ast.MapType:
  253. a.apply(n, "Key", nil, n.Key)
  254. a.apply(n, "Value", nil, n.Value)
  255. case *ast.ChanType:
  256. a.apply(n, "Value", nil, n.Value)
  257. // Statements
  258. case *ast.BadStmt:
  259. // nothing to do
  260. case *ast.DeclStmt:
  261. a.apply(n, "Decl", nil, n.Decl)
  262. case *ast.EmptyStmt:
  263. // nothing to do
  264. case *ast.LabeledStmt:
  265. a.apply(n, "Label", nil, n.Label)
  266. a.apply(n, "Stmt", nil, n.Stmt)
  267. case *ast.ExprStmt:
  268. a.apply(n, "X", nil, n.X)
  269. case *ast.SendStmt:
  270. a.apply(n, "Chan", nil, n.Chan)
  271. a.apply(n, "Value", nil, n.Value)
  272. case *ast.IncDecStmt:
  273. a.apply(n, "X", nil, n.X)
  274. case *ast.AssignStmt:
  275. a.applyList(n, "Lhs")
  276. a.applyList(n, "Rhs")
  277. case *ast.GoStmt:
  278. a.apply(n, "Call", nil, n.Call)
  279. case *ast.DeferStmt:
  280. a.apply(n, "Call", nil, n.Call)
  281. case *ast.ReturnStmt:
  282. a.applyList(n, "Results")
  283. case *ast.BranchStmt:
  284. a.apply(n, "Label", nil, n.Label)
  285. case *ast.BlockStmt:
  286. a.applyList(n, "List")
  287. case *ast.IfStmt:
  288. a.apply(n, "Init", nil, n.Init)
  289. a.apply(n, "Cond", nil, n.Cond)
  290. a.apply(n, "Body", nil, n.Body)
  291. a.apply(n, "Else", nil, n.Else)
  292. case *ast.CaseClause:
  293. a.applyList(n, "List")
  294. a.applyList(n, "Body")
  295. case *ast.SwitchStmt:
  296. a.apply(n, "Init", nil, n.Init)
  297. a.apply(n, "Tag", nil, n.Tag)
  298. a.apply(n, "Body", nil, n.Body)
  299. case *ast.TypeSwitchStmt:
  300. a.apply(n, "Init", nil, n.Init)
  301. a.apply(n, "Assign", nil, n.Assign)
  302. a.apply(n, "Body", nil, n.Body)
  303. case *ast.CommClause:
  304. a.apply(n, "Comm", nil, n.Comm)
  305. a.applyList(n, "Body")
  306. case *ast.SelectStmt:
  307. a.apply(n, "Body", nil, n.Body)
  308. case *ast.ForStmt:
  309. a.apply(n, "Init", nil, n.Init)
  310. a.apply(n, "Cond", nil, n.Cond)
  311. a.apply(n, "Post", nil, n.Post)
  312. a.apply(n, "Body", nil, n.Body)
  313. case *ast.RangeStmt:
  314. a.apply(n, "Key", nil, n.Key)
  315. a.apply(n, "Value", nil, n.Value)
  316. a.apply(n, "X", nil, n.X)
  317. a.apply(n, "Body", nil, n.Body)
  318. // Declarations
  319. case *ast.ImportSpec:
  320. a.apply(n, "Doc", nil, n.Doc)
  321. a.apply(n, "Name", nil, n.Name)
  322. a.apply(n, "Path", nil, n.Path)
  323. a.apply(n, "Comment", nil, n.Comment)
  324. case *ast.ValueSpec:
  325. a.apply(n, "Doc", nil, n.Doc)
  326. a.applyList(n, "Names")
  327. a.apply(n, "Type", nil, n.Type)
  328. a.applyList(n, "Values")
  329. a.apply(n, "Comment", nil, n.Comment)
  330. case *ast.TypeSpec:
  331. a.apply(n, "Doc", nil, n.Doc)
  332. a.apply(n, "Name", nil, n.Name)
  333. a.apply(n, "Type", nil, n.Type)
  334. a.apply(n, "Comment", nil, n.Comment)
  335. case *ast.BadDecl:
  336. // nothing to do
  337. case *ast.GenDecl:
  338. a.apply(n, "Doc", nil, n.Doc)
  339. a.applyList(n, "Specs")
  340. case *ast.FuncDecl:
  341. a.apply(n, "Doc", nil, n.Doc)
  342. a.apply(n, "Recv", nil, n.Recv)
  343. a.apply(n, "Name", nil, n.Name)
  344. a.apply(n, "Type", nil, n.Type)
  345. a.apply(n, "Body", nil, n.Body)
  346. // Files and packages
  347. case *ast.File:
  348. a.apply(n, "Doc", nil, n.Doc)
  349. a.apply(n, "Name", nil, n.Name)
  350. a.applyList(n, "Decls")
  351. // Don't walk n.Comments; they have either been walked already if
  352. // they are Doc comments, or they can be easily walked explicitly.
  353. case *ast.Package:
  354. // collect and sort names for reproducible behavior
  355. var names []string
  356. for name := range n.Files {
  357. names = append(names, name)
  358. }
  359. sort.Strings(names)
  360. for _, name := range names {
  361. a.apply(n, name, nil, n.Files[name])
  362. }
  363. default:
  364. panic(fmt.Sprintf("Apply: unexpected node type %T", n))
  365. }
  366. if a.post != nil && !a.post(&a.cursor) {
  367. panic(abort)
  368. }
  369. a.cursor = saved
  370. }
  371. // An iterator controls iteration over a slice of nodes.
  372. type iterator struct {
  373. index, step int
  374. }
  375. func (a *application) applyList(parent ast.Node, name string) {
  376. // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
  377. saved := a.iter
  378. a.iter.index = 0
  379. for {
  380. // must reload parent.name each time, since cursor modifications might change it
  381. v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
  382. if a.iter.index >= v.Len() {
  383. break
  384. }
  385. // element x may be nil in a bad AST - be cautious
  386. var x ast.Node
  387. if e := v.Index(a.iter.index); e.IsValid() {
  388. x = e.Interface().(ast.Node)
  389. }
  390. a.iter.step = 1
  391. a.apply(parent, name, &a.iter, x)
  392. a.iter.index += a.iter.step
  393. }
  394. a.iter = saved
  395. }