main.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. package main
  2. import (
  3. "bytes"
  4. "flag"
  5. "go/ast"
  6. "io/ioutil"
  7. "log"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "strconv"
  12. "strings"
  13. "text/template"
  14. "go-common/app/tool/cache/common"
  15. )
  var (
  	// Command-line options. parse() rewrites os.Args from each method's
  	// "// mc:" annotation and calls flag.Parse, so these are re-read per method
  	// (resetFlag restores the defaults in between).
  	encode    = flag.String("encode", "", "encode type: json/pb/raw/gob/gzip")
  	mcType    = flag.String("type", "", "type: get/set/del/replace/only_add")
  	key       = flag.String("key", "", "key name method")
  	expire    = flag.String("expire", "", "expire time code")
  	batchSize = flag.Int("batch", 0, "batch size")
  	batchErr  = flag.String("batch_err", "break", "batch err to continue or break")
  	maxGroup  = flag.Int("max_group", 0, "max group size")

  	// mcValidTypes are the cache operation types accepted by Check.
  	mcValidTypes = []string{"set", "replace", "del", "get", "only_add"}
  	// mcValidPrefix are the method-name prefixes tried, in order, when -type is
  	// not given; parse maps "add" to set and "cache" to get.
  	mcValidPrefix = []string{"set", "replace", "del", "get", "cache", "add"}
  	// optionNamesMap is the set of option names allowed in a "// mc:" annotation.
  	optionNamesMap = map[string]bool{"batch": true, "max_group": true, "encode": true, "type": true, "key": true, "expire": true, "batch_err": true}
  	// simpleTypes are value types that get strconv/byte conversion expressions
  	// and default to memcache.FlagRAW encoding.
  	simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
  	// lenTypes are type-name prefixes (slice, map) that, like string, mark a
  	// value as having a length (options.LenType).
  	lenTypes = []string{"[]", "map"}
  )
  // _interfaceName is the name of the interface type whose methods are turned
  // into memcache helpers.
  const _interfaceName = "_mc"

  // Template shapes chosen by parse for each method.
  const (
  	_multiTpl  = iota + 1 // batch: map/slice of keys or values
  	_singleTpl            // one key
  	_noneTpl              // no key parameter
  )

  // Cache operation types (values of options.MCType).
  const (
  	_typeGet     = "get"
  	_typeSet     = "set"
  	_typeDel     = "del"
  	_typeReplace = "replace"
  	_typeAdd     = "only_add"
  )
  41. func resetFlag() {
  42. *encode = ""
  43. *mcType = ""
  44. *batchSize = 0
  45. *maxGroup = 0
  46. *batchErr = "break"
  47. }
  // options collects what parse learns about one method of the _mc interface,
  // together with the derived settings the code generators need. genHeader also
  // builds one aggregate instance for the file header. Exported fields are the
  // ones text/template (executed in genHeader and genBody) can read; lower-case
  // fields are only used from Go code in this file.
  type options struct {
  	// name is the method name as declared in the interface.
  	name string
  	// keyType is the key type expression; left empty for the none template.
  	keyType string
  	// ValueType is the value type expression, e.g. "int64" or "*Foo".
  	ValueType string
  	// template is the generated shape: _multiTpl, _singleTpl or _noneTpl.
  	template int
  	// SimpleValue is set when ValueType is one of simpleTypes.
  	SimpleValue bool
  	// GetSimpleValue: simple value converted via strconv (int/uint/float/bool kinds).
  	GetSimpleValue bool
  	// GetDirectValue: string or []byte value, used without strconv conversion.
  	GetDirectValue bool
  	// ConvertValue2Bytes is the Go expression from convertValue2Bytes.
  	ConvertValue2Bytes string
  	// ConvertBytes2Value is the Go expression from convertBytes2Value.
  	ConvertBytes2Value string
  	// GoValue is not assigned anywhere in this file. NOTE(review): possibly unused.
  	GoValue bool
  	// ImportPackage is the newline-joined import list (header instance only).
  	ImportPackage string
  	// importPackages is what s.Packages returned for this method
  	// (presumably the imports its signature needs — verify in common).
  	importPackages []string
  	// Args is s.GetDef(_interfaceName); genHeader substitutes it for "ARGS".
  	Args string
  	// PkgName is $GOPACKAGE (header instance only).
  	PkgName string
  	// ExtraArgsType lists the parameters after key/value as ",name type,...".
  	ExtraArgsType string
  	// ExtraArgs lists the same parameters as ",name,...".
  	ExtraArgs string
  	// MCType is one of mcValidTypes.
  	MCType string
  	// KeyMethod defaults to "key"+name and is overridden by -key.
  	KeyMethod string
  	// ExpireCode defaults to "d.mc"+name+"Expire" and is overridden by -expire.
  	ExpireCode string
  	// Encode is the memcache flag expression, e.g. "memcache.FlagJSON".
  	Encode string
  	// UseMemcached is true except for multi-get methods.
  	UseMemcached bool
  	// InitValue is set together with PointType for pointer value types.
  	InitValue bool
  	// OriginValueType is ValueType with its first "*" removed.
  	OriginValueType string
  	// UseStrConv is set on the header instance when any method needs strconv.
  	UseStrConv bool
  	// Comment is the description taken from the comment two lines above the method.
  	Comment string
  	// GroupSize is the -batch value.
  	GroupSize int
  	// MaxGroup is the -max_group value.
  	MaxGroup int
  	// EnableBatch is true when both -batch and -max_group are non-zero.
  	EnableBatch bool
  	// BatchErrBreak is true when -batch_err is "break".
  	BatchErrBreak bool
  	// LenType is set for string, slice and map value types.
  	LenType bool
  	// PointType is set when ValueType starts with "*".
  	PointType bool
  }
  84. func parse(s *common.Source) (opts []*options) {
  85. f := s.F
  86. fset := s.Fset
  87. src := s.Src
  88. c := f.Scope.Lookup(_interfaceName)
  89. if (c == nil) || (c.Kind != ast.Typ) {
  90. log.Fatalln("无法找到缓存声明")
  91. }
  92. lines := strings.Split(src, "\n")
  93. lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
  94. for _, list := range lists {
  95. opt := options{Args: s.GetDef(_interfaceName), UseMemcached: true, importPackages: s.Packages(list)}
  96. opt.name = list.Names[0].Name
  97. opt.KeyMethod = "key" + opt.name
  98. opt.ExpireCode = "d.mc" + opt.name + "Expire"
  99. // get comment
  100. line := fset.Position(list.Pos()).Line - 3
  101. if len(lines)-1 >= line {
  102. comment := lines[line]
  103. opt.Comment = common.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
  104. opt.Comment = strings.TrimSpace(opt.Comment)
  105. }
  106. // get options
  107. line = fset.Position(list.Pos()).Line - 2
  108. comment := lines[line]
  109. os.Args = []string{os.Args[0]}
  110. if regexp.MustCompile(`\s+//\s*mc:.+`).Match([]byte(comment)) {
  111. args := strings.Split(common.RegexpReplace(`//\s*mc:(?P<arg>.+)`, comment, "$arg"), " ")
  112. for _, arg := range args {
  113. arg = strings.TrimSpace(arg)
  114. if arg != "" {
  115. // validate option name
  116. argName := common.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
  117. if !optionNamesMap[argName] {
  118. log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
  119. }
  120. os.Args = append(os.Args, arg)
  121. }
  122. }
  123. }
  124. resetFlag()
  125. flag.Parse()
  126. if *mcType != "" {
  127. opt.MCType = *mcType
  128. }
  129. if *key != "" {
  130. opt.KeyMethod = *key
  131. }
  132. if *expire != "" {
  133. opt.ExpireCode = *expire
  134. }
  135. opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
  136. opt.BatchErrBreak = *batchErr == "break"
  137. opt.GroupSize = *batchSize
  138. opt.MaxGroup = *maxGroup
  139. // get type from prefix
  140. if opt.MCType == "" {
  141. for _, t := range mcValidPrefix {
  142. if strings.HasPrefix(strings.ToLower(opt.name), t) {
  143. if t == "add" {
  144. t = _typeSet
  145. }
  146. opt.MCType = t
  147. break
  148. }
  149. }
  150. if opt.MCType == "" {
  151. log.Fatalln(opt.name + "请指定方法类型(type=get/set/del...)")
  152. }
  153. }
  154. if opt.MCType == "cache" {
  155. opt.MCType = _typeGet
  156. }
  157. params := list.Type.(*ast.FuncType).Params.List
  158. if len(params) == 0 {
  159. log.Fatalln(opt.name + "参数不足")
  160. }
  161. if s.ExprString(params[0].Type) != "context.Context" {
  162. log.Fatalln(opt.name + "第一个参数必须为context")
  163. }
  164. for _, param := range params {
  165. if len(param.Names) > 1 {
  166. log.Fatalln(opt.name + "不支持省略类型")
  167. }
  168. }
  169. // get template
  170. if len(params) == 1 {
  171. opt.template = _noneTpl
  172. } else if (len(params) == 2) && (opt.MCType == _typeSet || opt.MCType == _typeAdd || opt.MCType == _typeReplace) {
  173. if _, ok := params[1].Type.(*ast.MapType); ok {
  174. opt.template = _multiTpl
  175. } else {
  176. opt.template = _noneTpl
  177. }
  178. } else {
  179. if _, ok := params[1].Type.(*ast.ArrayType); ok {
  180. opt.template = _multiTpl
  181. } else {
  182. opt.template = _singleTpl
  183. }
  184. }
  185. // extra args
  186. if len(params) > 2 {
  187. args := []string{""}
  188. allArgs := []string{""}
  189. var pos = 2
  190. if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
  191. pos = 3
  192. }
  193. for _, pa := range params[pos:] {
  194. paType := s.ExprString(pa.Type)
  195. if len(pa.Names) == 0 {
  196. args = append(args, paType)
  197. allArgs = append(allArgs, paType)
  198. continue
  199. }
  200. var names []string
  201. for _, name := range pa.Names {
  202. names = append(names, name.Name)
  203. }
  204. allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
  205. args = append(args, strings.Join(names, ","))
  206. }
  207. if len(args) > 1 {
  208. opt.ExtraArgs = strings.Join(args, ",")
  209. opt.ExtraArgsType = strings.Join(allArgs, ",")
  210. }
  211. }
  212. // get k v from results
  213. results := list.Type.(*ast.FuncType).Results.List
  214. if s.ExprString(results[len(results)-1].Type) != "error" {
  215. log.Fatalln("最后返回值参数需为error")
  216. }
  217. for _, res := range results {
  218. if len(res.Names) > 1 {
  219. log.Fatalln(opt.name + "返回值不支持省略类型")
  220. }
  221. }
  222. if opt.MCType == _typeGet {
  223. if len(results) != 2 {
  224. log.Fatalln("参数个数不对")
  225. }
  226. }
  227. // get key type and value type
  228. if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
  229. if opt.template == _multiTpl {
  230. p, ok := params[1].Type.(*ast.MapType)
  231. if !ok {
  232. log.Fatalf("%s: 参数类型错误 批量设置数据时类型需为map类型\n", opt.name)
  233. }
  234. opt.keyType = s.ExprString(p.Key)
  235. opt.ValueType = s.ExprString(p.Value)
  236. } else if opt.template == _singleTpl {
  237. opt.keyType = s.ExprString(params[1].Type)
  238. opt.ValueType = s.ExprString(params[2].Type)
  239. } else {
  240. opt.ValueType = s.ExprString(params[1].Type)
  241. }
  242. }
  243. if opt.MCType == _typeGet {
  244. if opt.template == _multiTpl {
  245. if p, ok := results[0].Type.(*ast.MapType); ok {
  246. opt.keyType = s.ExprString(p.Key)
  247. opt.ValueType = s.ExprString(p.Value)
  248. } else {
  249. log.Fatalf("%s: 返回值类型错误 批量获取数据时返回值需为map类型\n", opt.name)
  250. }
  251. } else if opt.template == _singleTpl {
  252. opt.keyType = s.ExprString(params[1].Type)
  253. opt.ValueType = s.ExprString(results[0].Type)
  254. } else {
  255. opt.ValueType = s.ExprString(results[0].Type)
  256. }
  257. }
  258. if opt.MCType == _typeDel {
  259. if opt.template == _multiTpl {
  260. p, ok := params[1].Type.(*ast.ArrayType)
  261. if !ok {
  262. log.Fatalf("%s: 类型错误 参数需为[]类型\n", opt.name)
  263. }
  264. opt.keyType = s.ExprString(p.Elt)
  265. } else if opt.template == _singleTpl {
  266. opt.keyType = s.ExprString(params[1].Type)
  267. }
  268. }
  269. for _, t := range simpleTypes {
  270. if t == opt.ValueType {
  271. opt.SimpleValue = true
  272. opt.GetSimpleValue = true
  273. opt.ConvertValue2Bytes = convertValue2Bytes(t)
  274. opt.ConvertBytes2Value = convertBytes2Value(t)
  275. break
  276. }
  277. }
  278. if opt.ValueType == "string" {
  279. opt.LenType = true
  280. } else {
  281. for _, t := range lenTypes {
  282. if strings.HasPrefix(opt.ValueType, t) {
  283. opt.LenType = true
  284. break
  285. }
  286. }
  287. }
  288. if opt.SimpleValue && (opt.ValueType == "[]byte" || opt.ValueType == "string") {
  289. opt.GetSimpleValue = false
  290. opt.GetDirectValue = true
  291. }
  292. if opt.MCType == _typeGet && opt.template == _multiTpl {
  293. opt.UseMemcached = false
  294. }
  295. if strings.HasPrefix(opt.ValueType, "*") {
  296. opt.InitValue = true
  297. opt.PointType = true
  298. opt.OriginValueType = strings.Replace(opt.ValueType, "*", "", 1)
  299. } else {
  300. opt.OriginValueType = opt.ValueType
  301. }
  302. if *encode != "" {
  303. var flags []string
  304. for _, f := range strings.Split(*encode, "|") {
  305. switch f {
  306. case "gob":
  307. flags = append(flags, "memcache.FlagGOB")
  308. case "json":
  309. flags = append(flags, "memcache.FlagJSON")
  310. case "raw":
  311. flags = append(flags, "memcache.FlagRAW")
  312. case "pb":
  313. flags = append(flags, "memcache.FlagProtobuf")
  314. case "gzip":
  315. flags = append(flags, "memcache.FlagGzip")
  316. default:
  317. log.Fatalf("%s: encode类型无效\n", opt.name)
  318. }
  319. }
  320. opt.Encode = strings.Join(flags, " | ")
  321. } else {
  322. if opt.SimpleValue {
  323. opt.Encode = "memcache.FlagRAW"
  324. } else {
  325. opt.Encode = "memcache.FlagJSON"
  326. }
  327. }
  328. opt.Check()
  329. opts = append(opts, &opt)
  330. }
  331. return
  332. }
  333. func (option *options) Check() {
  334. var valid bool
  335. for _, x := range mcValidTypes {
  336. if x == option.MCType {
  337. valid = true
  338. break
  339. }
  340. }
  341. if !valid {
  342. log.Fatalf("%s: 类型错误 不支持%s类型\n", option.name, option.MCType)
  343. }
  344. if (option.MCType != _typeDel) && !option.SimpleValue && !strings.Contains(option.ValueType, "*") && !strings.Contains(option.ValueType, "[]") && !strings.Contains(option.ValueType, "map") {
  345. log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
  346. }
  347. }
  348. func genHeader(opts []*options) (src string) {
  349. option := options{PkgName: os.Getenv("GOPACKAGE"), UseMemcached: false}
  350. var packages []string
  351. packagesMap := map[string]bool{`"context"`: true}
  352. for _, opt := range opts {
  353. if len(opt.importPackages) > 0 {
  354. for _, pkg := range opt.importPackages {
  355. if !packagesMap[pkg] {
  356. packages = append(packages, pkg)
  357. packagesMap[pkg] = true
  358. }
  359. }
  360. }
  361. if opt.Args != "" {
  362. option.Args = opt.Args
  363. }
  364. if opt.UseMemcached {
  365. option.UseMemcached = true
  366. }
  367. if opt.SimpleValue && !opt.GetDirectValue {
  368. option.UseStrConv = true
  369. }
  370. if opt.EnableBatch {
  371. option.EnableBatch = true
  372. }
  373. }
  374. option.ImportPackage = strings.Join(packages, "\n")
  375. src = _headerTemplate
  376. t := template.Must(template.New("header").Parse(src))
  377. var buffer bytes.Buffer
  378. err := t.Execute(&buffer, option)
  379. if err != nil {
  380. log.Fatalf("execute template: %s", err)
  381. }
  382. // Format the output.
  383. src = strings.Replace(buffer.String(), "\t", "", -1)
  384. src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
  385. src = strings.Replace(src, "NEWLINE", "", -1)
  386. src = strings.Replace(src, "ARGS", option.Args, -1)
  387. return
  388. }
  389. func genBody(opts []*options) (res string) {
  390. for _, option := range opts {
  391. var src string
  392. if option.template == _multiTpl {
  393. switch option.MCType {
  394. case _typeGet:
  395. src = _multiGetTemplate
  396. case _typeSet:
  397. src = _multiSetTemplate
  398. case _typeReplace:
  399. src = _multiReplaceTemplate
  400. case _typeDel:
  401. src = _multiDelTemplate
  402. case _typeAdd:
  403. src = _multiAddTemplate
  404. }
  405. } else if option.template == _singleTpl {
  406. switch option.MCType {
  407. case _typeGet:
  408. src = _singleGetTemplate
  409. case _typeSet:
  410. src = _singleSetTemplate
  411. case _typeReplace:
  412. src = _singleReplaceTemplate
  413. case _typeDel:
  414. src = _singleDelTemplate
  415. case _typeAdd:
  416. src = _singleAddTemplate
  417. }
  418. } else {
  419. switch option.MCType {
  420. case _typeGet:
  421. src = _noneGetTemplate
  422. case _typeSet:
  423. src = _noneSetTemplate
  424. case _typeReplace:
  425. src = _noneReplaceTemplate
  426. case _typeDel:
  427. src = _noneDelTemplate
  428. case _typeAdd:
  429. src = _noneAddTemplate
  430. }
  431. }
  432. src = strings.Replace(src, "KEY", option.keyType, -1)
  433. src = strings.Replace(src, "NAME", option.name, -1)
  434. src = strings.Replace(src, "VALUE", option.ValueType, -1)
  435. src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
  436. src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
  437. t := template.Must(template.New("cache").Parse(src))
  438. var buffer bytes.Buffer
  439. err := t.Execute(&buffer, option)
  440. if err != nil {
  441. log.Fatalf("execute template: %s", err)
  442. }
  443. // Format the output.
  444. src = strings.Replace(buffer.String(), "\t", "", -1)
  445. src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
  446. res = res + "\n" + src
  447. }
  448. return
  449. }
  450. func main() {
  451. log.SetFlags(0)
  452. defer func() {
  453. if err := recover(); err != nil {
  454. log.Fatalf("程序解析失败, err: %+v 请企业微信联系 @wangxu01", err)
  455. }
  456. }()
  457. options := parse(common.NewSource(common.SourceText()))
  458. header := genHeader(options)
  459. body := genBody(options)
  460. code := common.FormatCode(header + "\n" + body)
  461. // Write to file.
  462. dir := filepath.Dir(".")
  463. outputName := filepath.Join(dir, "mc.cache.go")
  464. err := ioutil.WriteFile(outputName, []byte(code), 0644)
  465. if err != nil {
  466. log.Fatalf("写入文件失败: %s", err)
  467. }
  468. log.Println("mc.cache.go: 生成成功")
  469. }
  // convertValue2Bytes returns Go source for an expression that converts a
  // variable named val of simple type t into a []byte. It returns "" for types
  // it does not handle. The returned text is emitted into generated code, so it
  // must name real strconv functions with correctly typed arguments.
  func convertValue2Bytes(t string) string {
  	switch t {
  	case "int", "int8", "int16", "int32", "int64":
  		return "[]byte(strconv.FormatInt(int64(val), 10))"
  	case "uint", "uint8", "uint16", "uint32", "uint64":
  		// The function is strconv.FormatUint (not FormatUInt) and takes a uint64.
  		return "[]byte(strconv.FormatUint(uint64(val), 10))"
  	case "bool":
  		return "[]byte(strconv.FormatBool(val))"
  	case "float32":
  		// FormatFloat takes a float64; bitSize 32 keeps float32 precision.
  		return "[]byte(strconv.FormatFloat(float64(val), 'E', -1, 32))"
  	case "float64":
  		return "[]byte(strconv.FormatFloat(val, 'E', -1, 64))"
  	case "string":
  		return "[]byte(val)"
  	case "[]byte":
  		return "val"
  	}
  	return ""
  }
  // convertBytes2Value returns Go source for an expression that parses a
  // string variable named v back into a value of simple type t. It returns ""
  // for types without a strconv parser (including string and []byte). Each
  // expression yields (value, error), and the value is int64, uint64, bool or
  // float64 rather than t itself.
  // NOTE(review): the templates are presumably responsible for converting that
  // value to t — confirm against the template definitions.
  func convertBytes2Value(t string) string {
  	switch t {
  	case "int", "int8", "int16", "int32", "int64":
  		return "strconv.ParseInt(v, 10, 64)"
  	case "uint", "uint8", "uint16", "uint32", "uint64":
  		// The function is strconv.ParseUint (not ParseUInt).
  		return "strconv.ParseUint(v, 10, 64)"
  	case "bool":
  		return "strconv.ParseBool(v)"
  	case "float32":
  		// ParseFloat returns two values, so it cannot be wrapped in float32(...).
  		// bitSize 32 guarantees the float64 result is representable as float32.
  		return "strconv.ParseFloat(v, 32)"
  	case "float64":
  		return "strconv.ParseFloat(v, 64)"
  	}
  	return ""
  }