ast.go 6.4 KB

package expr

import (
	"bytes"
	"fmt"
	"strconv"
	"strings"
	"text/scanner"
	"unicode"
)

const (
	TokenEOF = -(iota + 1)
	TokenIdent
	TokenInt
	TokenFloat
	TokenOperator
)

type lexer struct {
	scan  scanner.Scanner
	token rune
	text  string
}

func (lex *lexer) getToken() rune {
	return lex.token
}

func (lex *lexer) getText() string {
	return lex.text
}

func (lex *lexer) next() {
	token := lex.scan.Scan()
	text := lex.scan.TokenText()
	switch token {
	case scanner.EOF:
		lex.token = TokenEOF
		lex.text = text
	case scanner.Ident:
		lex.token = TokenIdent
		lex.text = text
	case scanner.Int:
		lex.token = TokenInt
		lex.text = text
	case scanner.Float:
		lex.token = TokenFloat
		lex.text = text
	case '+', '-', '*', '/', '%', '~':
		lex.token = TokenOperator
		lex.text = text
	case '&', '|', '=':
		// Possibly the first rune of a two-character operator: "&&", "||" or "==".
		var buffer bytes.Buffer
		lex.token = TokenOperator
		buffer.WriteRune(token)
		next := lex.scan.Peek()
		if next == token {
			buffer.WriteRune(next)
			lex.scan.Scan()
		}
		lex.text = buffer.String()
	case '>', '<', '!':
		// Possibly followed by '=': ">=", "<=" or "!=".
		var buffer bytes.Buffer
		lex.token = TokenOperator
		buffer.WriteRune(token)
		next := lex.scan.Peek()
		if next == '=' {
			buffer.WriteRune(next)
			lex.scan.Scan()
		}
		lex.text = buffer.String()
	default:
		if token >= 0 {
			lex.token = token
			lex.text = text
		} else {
			msg := fmt.Sprintf("got unknown token:%q, text:%s", token, text)
			panic(lexPanic(msg))
		}
	}
	// fmt.Printf("token:%d, text:%s\n", lex.token, lex.text)
}
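
// lexAll is an illustrative sketch (not part of the original file) showing how
// the lexer above is driven: initialize the scanner, then call next() until
// TokenEOF, collecting each token's text. For the input `a >= 3 && b` it
// returns ["a", ">=", "3", "&&", "b"], since '>' and '&' are merged with the
// following rune by next().
func lexAll(input string) []string {
	lex := new(lexer)
	lex.scan.Init(strings.NewReader(input))
	lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats
	var texts []string
	for lex.next(); lex.getToken() != TokenEOF; lex.next() {
		texts = append(texts, lex.getText())
	}
	return texts
}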

type lexPanic string

// describe returns a string describing the current token, for use in errors.
func (lex *lexer) describe() string {
	switch lex.token {
	case TokenEOF:
		return "end of file"
	case TokenIdent:
		return fmt.Sprintf("identifier %s", lex.getText())
	case TokenInt, TokenFloat:
		return fmt.Sprintf("number %s", lex.getText())
	}
	return fmt.Sprintf("%q", lex.getToken()) // any other rune
}

func precedence(token rune, text string) int {
	if token == TokenOperator {
		switch text {
		case "~", "!":
			return 9
		case "*", "/", "%":
			return 8
		case "+", "-":
			return 7
		case ">", ">=", "<", "<=":
			return 6
		case "!=", "==", "=":
			return 5
		case "&":
			return 4
		case "|":
			return 3
		case "&&":
			return 2
		case "||":
			return 1
		default:
			msg := fmt.Sprintf("unknown operator:%s", text)
			panic(lexPanic(msg))
		}
	}
	return 0
}

// ---- parser ----

type ExpressionParser struct {
	expression Expr
	variable   map[string]struct{}
}

func NewExpressionParser() *ExpressionParser {
	return &ExpressionParser{
		expression: nil,
		variable:   make(map[string]struct{}),
	}
}

// Parse parses the input string as an arithmetic expression.
//
//	expr = num                        a literal number, e.g., 3.14159
//	     | id                         a variable name, e.g., x
//	     | id '(' expr ',' ... ')'    a function call
//	     | op expr                    a unary operator  ( + - ~ ! )
//	     | expr op expr               a binary operator ( + - * / % & | && || == != = > >= < <= )
func (parser *ExpressionParser) Parse(input string) (err error) {
	defer func() {
		switch x := recover().(type) {
		case nil:
			// no panic
		case lexPanic:
			err = fmt.Errorf("%s", x)
		default:
			// unexpected panic: re-raise it unchanged.
			panic(x)
		}
	}()
	lex := new(lexer)
	lex.scan.Init(strings.NewReader(input))
	lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats
	lex.scan.IsIdentRune = parser.isIdentRune
	lex.next() // initial lookahead
	parser.expression = nil
	parser.variable = make(map[string]struct{})
	e := parser.parseExpr(lex)
	if lex.token != TokenEOF {
		return fmt.Errorf("unexpected %s", lex.describe())
	}
	parser.expression = e
	return nil
}

// GetExpr returns the most recently parsed expression, or nil if Parse has
// not succeeded yet.
func (parser *ExpressionParser) GetExpr() Expr {
	return parser.expression
}

// GetVariable returns the variable names referenced by the parsed expression,
// excluding the identifiers "true" and "false".
func (parser *ExpressionParser) GetVariable() []string {
	variable := make([]string, 0, len(parser.variable))
	for v := range parser.variable {
		if v != "true" && v != "false" {
			variable = append(variable, v)
		}
	}
	return variable
}
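
// exampleParseUsage is an illustrative sketch (not part of the original file)
// showing how the parser API above fits together: build a parser, parse a
// string, then read back the AST and the referenced variables. Note that
// "tax" is not reported by GetVariable, because an identifier followed by
// '(' is parsed as a function call rather than recorded as a variable.
func exampleParseUsage() (Expr, []string, error) {
	parser := NewExpressionParser()
	if err := parser.Parse("price*qty + tax(price)"); err != nil {
		return nil, nil, err
	}
	return parser.GetExpr(), parser.GetVariable(), nil // variables: "price", "qty"
}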

// isIdentRune reports whether ch may appear at position i of an identifier:
// '$', '_' and letters anywhere, digits only after the first rune.
func (parser *ExpressionParser) isIdentRune(ch rune, i int) bool {
	return ch == '$' || ch == '_' || unicode.IsLetter(ch) || (unicode.IsDigit(ch) && i > 0)
}

func (parser *ExpressionParser) parseExpr(lex *lexer) Expr {
	return parser.parseBinary(lex, 1)
}

// binary = unary (op binary)*
//
// parseBinary stops when it encounters an
// operator of lower precedence than prec1.
func (parser *ExpressionParser) parseBinary(lex *lexer, prec1 int) Expr {
	lhs := parser.parseUnary(lex)
	for prec := precedence(lex.getToken(), lex.getText()); prec >= prec1; prec-- {
		for precedence(lex.getToken(), lex.getText()) == prec {
			op := lex.getText()
			lex.next() // consume operator
			rhs := parser.parseBinary(lex, prec+1)
			lhs = binary{op, lhs, rhs}
		}
	}
	return lhs
}
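
// Worked example (illustrative, not part of the original file): for the input
// 1 + 2*3 > 5, parseBinary is first called with prec1 = 1. It parses 1, sees
// '+' (precedence 7) and recursively parses the right-hand side with prec1 = 8,
// which groups 2*3 (precedence 8) before returning. '>' (precedence 6) is too
// low for that recursion, so it unwinds and the outer loop picks '>' up when
// prec counts down to 6, yielding ((1 + (2*3)) > 5).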

// unary = ('+' | '-' | '~' | '!') unary
//       | primary
func (parser *ExpressionParser) parseUnary(lex *lexer) Expr {
	if lex.getToken() == TokenOperator {
		op := lex.getText()
		if op == "+" || op == "-" || op == "~" || op == "!" {
			lex.next()
			return unary{op, parser.parseUnary(lex)}
		}
		msg := fmt.Sprintf("unary got unknown operator:%s", lex.getText())
		panic(lexPanic(msg))
	}
	return parser.parsePrimary(lex)
}

// primary = id
//         | id '(' expr ',' ... ',' expr ')'
//         | num
//         | '(' expr ')'
func (parser *ExpressionParser) parsePrimary(lex *lexer) Expr {
	switch lex.token {
	case TokenIdent:
		id := lex.getText()
		lex.next()
		if lex.token != '(' {
			parser.variable[id] = struct{}{}
			return Var(id)
		}
		lex.next() // consume '('
		var args []Expr
		if lex.token != ')' {
			for {
				args = append(args, parser.parseExpr(lex))
				if lex.token != ',' {
					break
				}
				lex.next() // consume ','
			}
			if lex.token != ')' {
				msg := fmt.Sprintf("got %s, want ')'", lex.describe())
				panic(lexPanic(msg))
			}
		}
		lex.next() // consume ')'
		return call{id, args}
	case TokenFloat:
		f, err := strconv.ParseFloat(lex.getText(), 64)
		if err != nil {
			panic(lexPanic(err.Error()))
		}
		lex.next() // consume number
		return literal{value: f}
	case TokenInt:
		i, err := strconv.ParseInt(lex.getText(), 10, 64)
		if err != nil {
			panic(lexPanic(err.Error()))
		}
		lex.next() // consume number
		return literal{value: i}
	case '(':
		lex.next() // consume '('
		e := parser.parseExpr(lex)
		if lex.token != ')' {
			msg := fmt.Sprintf("got %s, want ')'", lex.describe())
			panic(lexPanic(msg))
		}
		lex.next() // consume ')'
		return e
	}
	msg := fmt.Sprintf("unexpected %s", lex.describe())
	panic(lexPanic(msg))
}
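
// Illustrative trace (not part of the original file): for the primary
// `max(x, 2*y)`, parsePrimary takes the TokenIdent branch. "max" is read and,
// because the next token is '(', it becomes a call rather than a variable (so
// "max" is not added to parser.variable). The arguments x and 2*y are each
// parsed as full expressions separated by ',', the closing ')' is consumed,
// and call{"max", args} is returned; only "x" and "y" end up in GetVariable().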