ast.go 6.9 KB

  1. package expr
  2. import (
  3. "bytes"
  4. "fmt"
  5. "strconv"
  6. "strings"
  7. "text/scanner"
  8. "unicode"
  9. )
// Token codes produced by the lexer. They are negative (per the
// text/scanner convention) so they can never collide with literal runes
// such as '(', ')' or ',', which next() passes through unchanged.
const (
	TokenEOF = -(iota + 1)
	TokenIdent
	TokenInt
	TokenFloat
	TokenString
	TokenOperator
)
// lexer wraps a text/scanner.Scanner with a one-token lookahead:
// token holds the current token code (a Token* constant or a literal
// rune) and text holds its associated text.
type lexer struct {
	scan  scanner.Scanner
	token rune
	text  string
}
// getToken returns the current (lookahead) token code.
func (lex *lexer) getToken() rune {
	return lex.token
}
// getText returns the text associated with the current token.
func (lex *lexer) getText() string {
	return lex.text
}
  29. func (lex *lexer) next() {
  30. token := lex.scan.Scan()
  31. text := lex.scan.TokenText()
  32. switch token {
  33. case scanner.EOF:
  34. lex.token = TokenEOF
  35. lex.text = text
  36. case scanner.Ident:
  37. lex.token = TokenIdent
  38. lex.text = text
  39. case scanner.Int:
  40. lex.token = TokenInt
  41. lex.text = text
  42. case scanner.Float:
  43. lex.token = TokenFloat
  44. lex.text = text
  45. case scanner.RawString:
  46. fallthrough
  47. case scanner.String:
  48. lex.token = TokenString
  49. if len(text) >= 2 && ((text[0] == '`' && text[len(text)-1] == '`') || (text[0] == '"' && text[len(text)-1] == '"')) {
  50. lex.text = text[1 : len(text)-1]
  51. } else {
  52. msg := fmt.Sprintf("got illegal string:%s", text)
  53. panic(lexPanic(msg))
  54. }
  55. case '+', '-', '*', '/', '%', '~':
  56. lex.token = TokenOperator
  57. lex.text = text
  58. case '&', '|', '=':
  59. var buffer bytes.Buffer
  60. lex.token = TokenOperator
  61. buffer.WriteRune(token)
  62. next := lex.scan.Peek()
  63. if next == token {
  64. buffer.WriteRune(next)
  65. lex.scan.Scan()
  66. }
  67. lex.text = buffer.String()
  68. case '>', '<', '!':
  69. var buffer bytes.Buffer
  70. lex.token = TokenOperator
  71. buffer.WriteRune(token)
  72. next := lex.scan.Peek()
  73. if next == '=' {
  74. buffer.WriteRune(next)
  75. lex.scan.Scan()
  76. }
  77. lex.text = buffer.String()
  78. default:
  79. if token >= 0 {
  80. lex.token = token
  81. lex.text = text
  82. } else {
  83. msg := fmt.Sprintf("got unknown token:%d, text:%s", lex.token, lex.text)
  84. panic(lexPanic(msg))
  85. }
  86. }
  87. }
// lexPanic is the panic payload used for lexing/parsing errors; Parse
// recovers values of this type and converts them into ordinary errors.
type lexPanic string
  89. // describe returns a string describing the current token, for use in errors.
  90. func (lex *lexer) describe() string {
  91. switch lex.token {
  92. case TokenEOF:
  93. return "end of file"
  94. case TokenIdent:
  95. return fmt.Sprintf("identifier:%s", lex.getText())
  96. case TokenInt, TokenFloat:
  97. return fmt.Sprintf("number:%s", lex.getText())
  98. case TokenString:
  99. return fmt.Sprintf("string:%s", lex.getText())
  100. }
  101. return fmt.Sprintf("token:%d", rune(lex.getToken())) // any other rune
  102. }
  103. func precedence(token rune, text string) int {
  104. if token == TokenOperator {
  105. switch text {
  106. case "~", "!":
  107. return 9
  108. case "*", "/", "%":
  109. return 8
  110. case "+", "-":
  111. return 7
  112. case ">", ">=", "<", "<=":
  113. return 6
  114. case "!=", "==", "=":
  115. return 5
  116. case "&":
  117. return 4
  118. case "|":
  119. return 3
  120. case "&&":
  121. return 2
  122. case "||":
  123. return 1
  124. default:
  125. msg := fmt.Sprintf("unknown operator:%s", text)
  126. panic(lexPanic(msg))
  127. }
  128. }
  129. return 0
  130. }
// ---- parser ----

// ExpressionParser parses one expression at a time, retaining the
// parsed tree and the set of variable names the expression references.
type ExpressionParser struct {
	expression Expr                // result of the last successful Parse
	variable   map[string]struct{} // set of identifiers seen during parsing
}
  136. func NewExpressionParser() *ExpressionParser {
  137. return &ExpressionParser{
  138. expression: nil,
  139. variable: make(map[string]struct{}),
  140. }
  141. }
  142. // Parse parses the input string as an arithmetic expression.
  143. //
  144. // expr = num a literal number, e.g., 3.14159
  145. // | id a variable name, e.g., x
  146. // | id '(' expr ',' ... ')' a function call
  147. // | '-' expr a unary operator ( + - ! )
  148. // | expr '+' expr a binary operator ( + - * / && & || | == )
  149. //
  150. func (parser *ExpressionParser) Parse(input string) (err error) {
  151. defer func() {
  152. switch x := recover().(type) {
  153. case nil:
  154. // no panic
  155. case lexPanic:
  156. err = fmt.Errorf("%s", x)
  157. default:
  158. // unexpected panic: resume state of panic.
  159. panic(x)
  160. }
  161. }()
  162. lex := new(lexer)
  163. lex.scan.Init(strings.NewReader(input))
  164. lex.scan.Mode = scanner.ScanIdents | scanner.ScanInts | scanner.ScanFloats | scanner.ScanStrings | scanner.ScanRawStrings
  165. lex.scan.IsIdentRune = parser.isIdentRune
  166. lex.next() // initial lookahead
  167. parser.expression = nil
  168. parser.variable = make(map[string]struct{})
  169. e := parser.parseExpr(lex)
  170. if lex.token != scanner.EOF {
  171. return fmt.Errorf("unexpected %s", lex.describe())
  172. }
  173. parser.expression = e
  174. return nil
  175. }
// GetExpr returns the expression tree produced by the last successful
// Parse, or nil if no expression has been parsed.
func (parser *ExpressionParser) GetExpr() Expr {
	return parser.expression
}
  179. func (parser *ExpressionParser) GetVariable() []string {
  180. variable := make([]string, 0, len(parser.variable))
  181. for v := range parser.variable {
  182. if v != "true" && v != "false" {
  183. variable = append(variable, v)
  184. }
  185. }
  186. return variable
  187. }
  188. func (parser *ExpressionParser) isIdentRune(ch rune, i int) bool {
  189. return ch == '$' || ch == '_' || ch == '?' || unicode.IsLetter(ch) || unicode.IsDigit(ch) && i > 0
  190. }
// parseExpr parses a complete expression, starting at the lowest binary
// precedence level (1).
func (parser *ExpressionParser) parseExpr(lex *lexer) Expr {
	return parser.parseBinary(lex, 1)
}
// binary = unary ('+' binary)*
// parseBinary stops when it encounters an
// operator of lower precedence than prec1.
func (parser *ExpressionParser) parseBinary(lex *lexer, prec1 int) Expr {
	lhs := parser.parseUnary(lex)
	// Walk downward from the precedence of the current operator; the
	// inner loop consumes every consecutive operator at exactly that
	// level, so equal-precedence operators associate left-to-right.
	for prec := precedence(lex.getToken(), lex.getText()); prec >= prec1; prec-- {
		for precedence(lex.getToken(), lex.getText()) == prec {
			op := lex.getText()
			lex.next() // consume operator
			// The right operand may only bind operators of strictly
			// higher precedence (prec+1), preserving left associativity.
			rhs := parser.parseBinary(lex, prec+1)
			lhs = binary{op, lhs, rhs}
		}
	}
	return lhs
}
  209. // unary = '+' expr | primary
  210. func (parser *ExpressionParser) parseUnary(lex *lexer) Expr {
  211. if lex.getToken() == TokenOperator {
  212. op := lex.getText()
  213. if op == "+" || op == "-" || op == "~" || op == "!" {
  214. lex.next()
  215. return unary{op, parser.parseUnary(lex)}
  216. } else {
  217. msg := fmt.Sprintf("unary got unknown operator:%s", lex.getText())
  218. panic(lexPanic(msg))
  219. }
  220. }
  221. return parser.parsePrimary(lex)
  222. }
  223. // primary = id
  224. // | id '(' expr ',' ... ',' expr ')'
  225. // | num
  226. // | '(' expr ')'
  227. func (parser *ExpressionParser) parsePrimary(lex *lexer) Expr {
  228. switch lex.token {
  229. case TokenIdent:
  230. id := lex.getText()
  231. lex.next()
  232. if lex.token != '(' {
  233. parser.variable[id] = struct{}{}
  234. return Var(id)
  235. }
  236. lex.next() // consume '('
  237. var args []Expr
  238. if lex.token != ')' {
  239. for {
  240. args = append(args, parser.parseExpr(lex))
  241. if lex.token != ',' {
  242. break
  243. }
  244. lex.next() // consume ','
  245. }
  246. if lex.token != ')' {
  247. msg := fmt.Sprintf("got %q, want ')'", lex.token)
  248. panic(lexPanic(msg))
  249. }
  250. }
  251. lex.next() // consume ')'
  252. return call{id, args}
  253. case TokenFloat:
  254. f, err := strconv.ParseFloat(lex.getText(), 64)
  255. if err != nil {
  256. panic(lexPanic(err.Error()))
  257. }
  258. lex.next()
  259. return literal{value: f}
  260. case TokenInt:
  261. i, err := strconv.ParseInt(lex.getText(), 10, 64)
  262. if err != nil {
  263. panic(lexPanic(err.Error()))
  264. }
  265. lex.next()
  266. return literal{value: i}
  267. case TokenString:
  268. s := lex.getText()
  269. lex.next()
  270. return literal{value: s}
  271. case '(':
  272. lex.next() // consume '('
  273. e := parser.parseExpr(lex)
  274. if lex.token != ')' {
  275. msg := fmt.Sprintf("got %s, want ')'", lex.describe())
  276. panic(lexPanic(msg))
  277. }
  278. lex.next() // consume ')'
  279. return e
  280. }
  281. msg := fmt.Sprintf("unexpected %s", lex.describe())
  282. panic(lexPanic(msg))
  283. }