123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517 |
- package main
- import (
- "bytes"
- "flag"
- "go/ast"
- "io/ioutil"
- "log"
- "os"
- "path/filepath"
- "regexp"
- "strconv"
- "strings"
- "text/template"
- "go-common/app/tool/cache/common"
- )
// Options accepted on a `// mc: -flag=value ...` directive line above each
// `_mc` interface method. parse calls resetFlag and then re-parses these flags
// once per method, so they act as per-method settings.
var (
	encode    = flag.String("encode", "", "encode type: json/pb/raw/gob/gzip")
	mcType    = flag.String("type", "", "type: get/set/del/replace/only_add")
	key       = flag.String("key", "", "key name method")
	expire    = flag.String("expire", "", "expire time code")
	batchSize = flag.Int("batch", 0, "batch size")
	batchErr  = flag.String("batch_err", "break", "batch err to continue or break")
	maxGroup  = flag.Int("max_group", 0, "max group size")

	// mcValidTypes lists the operation types options.Check accepts.
	mcValidTypes = []string{"set", "replace", "del", "get", "only_add"}
	// mcValidPrefix lists method-name prefixes parse uses to infer the
	// operation type when no -type is given ("cache" maps to get and
	// "add" maps to set there).
	mcValidPrefix = []string{"set", "replace", "del", "get", "cache", "add"}
	// optionNamesMap is the set of flag names allowed in a directive; it is
	// also the set of flags resetFlag restores between methods.
	optionNamesMap = map[string]bool{"batch": true, "max_group": true, "encode": true, "type": true, "key": true, "expire": true, "batch_err": true}
	// simpleTypes are value types that parse marks as SimpleValue.
	simpleTypes = []string{"int", "int8", "int16", "int32", "int64", "float32", "float64", "uint", "uint8", "uint16", "uint32", "uint64", "bool", "string", "[]byte"}
	// lenTypes are value-type prefixes for which parse sets LenType.
	lenTypes = []string{"[]", "map"}
)

const (
	// _interfaceName is the interface whose methods are turned into cache code.
	_interfaceName = "_mc"

	// Operation types (values of options.MCType).
	_typeGet     = "get"
	_typeSet     = "set"
	_typeDel     = "del"
	_typeReplace = "replace"
	_typeAdd     = "only_add"
)

// Template families selected from a method's signature (options.template).
const (
	_multiTpl = iota + 1 // batch: map/slice keys
	_singleTpl           // one key argument
	_noneTpl             // no key argument
)

// resetFlag restores every directive option listed in optionNamesMap to the
// default it was registered with, so that a setting on one method's
// `// mc:` line cannot leak into the next method. All options are reset,
// including -key and -expire.
func resetFlag() {
	for name := range optionNamesMap {
		f := flag.Lookup(name)
		if f == nil {
			log.Fatalf("选项:%s 未注册\n", name)
		}
		if err := f.Value.Set(f.DefValue); err != nil {
			log.Fatalf("选项:%s 重置失败: %s\n", name, err)
		}
	}
}
// options describes one `_mc` interface method plus everything the code
// templates need to render its memcache implementation. Exported fields are
// read by the text/template bodies; unexported ones are used only by the
// generator itself.
type options struct {
	// Method name and the key/value type expressions taken from its signature.
	name, keyType, ValueType string
	// template is one of _multiTpl/_singleTpl/_noneTpl.
	template int
	// SimpleValue is set when ValueType is one of simpleTypes.
	SimpleValue bool
	// GetSimpleValue: simple value that needs strconv conversion
	// (the int/float/bool types).
	GetSimpleValue bool
	// GetDirectValue: string or []byte value used without conversion.
	GetDirectValue bool
	// Go source snippets converting `val` to []byte and the string `v`
	// back to a value (see convertValue2Bytes / convertBytes2Value).
	ConvertValue2Bytes, ConvertBytes2Value string
	// GoValue is not assigned in this file — presumably read by a template; TODO confirm.
	GoValue bool
	// ImportPackage is the joined import list rendered into the header;
	// importPackages are the per-method packages it is built from.
	ImportPackage  string
	importPackages []string
	// Args replaces the ARGS placeholder in the header; PkgName is $GOPACKAGE.
	Args, PkgName string
	// Extra trailing method arguments, e.g. ",a int,b string" and ",a,b".
	ExtraArgsType, ExtraArgs string
	// Operation type, key helper name, expiry expression and memcache flags.
	MCType, KeyMethod, ExpireCode, Encode string
	// UseMemcached is false for batch gets; InitValue is set for pointer values.
	UseMemcached, InitValue bool
	// OriginValueType is ValueType without a leading "*".
	OriginValueType string
	// UseStrConv: the header must import strconv.
	UseStrConv bool
	// Comment is the description line above the method's directive.
	Comment string
	// Batch sizing from -batch / -max_group.
	GroupSize, MaxGroup int
	// EnableBatch: both batch sizes set; BatchErrBreak: -batch_err=break;
	// LenType: string/slice/map value; PointType: pointer value.
	EnableBatch, BatchErrBreak, LenType, PointType bool
}
- func parse(s *common.Source) (opts []*options) {
- f := s.F
- fset := s.Fset
- src := s.Src
- c := f.Scope.Lookup(_interfaceName)
- if (c == nil) || (c.Kind != ast.Typ) {
- log.Fatalln("无法找到缓存声明")
- }
- lines := strings.Split(src, "\n")
- lists := c.Decl.(*ast.TypeSpec).Type.(*ast.InterfaceType).Methods.List
- for _, list := range lists {
- opt := options{Args: s.GetDef(_interfaceName), UseMemcached: true, importPackages: s.Packages(list)}
- opt.name = list.Names[0].Name
- opt.KeyMethod = "key" + opt.name
- opt.ExpireCode = "d.mc" + opt.name + "Expire"
- // get comment
- line := fset.Position(list.Pos()).Line - 3
- if len(lines)-1 >= line {
- comment := lines[line]
- opt.Comment = common.RegexpReplace(`\s+//(?P<name>.+)`, comment, "$name")
- opt.Comment = strings.TrimSpace(opt.Comment)
- }
- // get options
- line = fset.Position(list.Pos()).Line - 2
- comment := lines[line]
- os.Args = []string{os.Args[0]}
- if regexp.MustCompile(`\s+//\s*mc:.+`).Match([]byte(comment)) {
- args := strings.Split(common.RegexpReplace(`//\s*mc:(?P<arg>.+)`, comment, "$arg"), " ")
- for _, arg := range args {
- arg = strings.TrimSpace(arg)
- if arg != "" {
- // validate option name
- argName := common.RegexpReplace(`-(?P<name>[\w_-]+)=.+`, arg, "$name")
- if !optionNamesMap[argName] {
- log.Fatalf("选项:%s 不存在 请检查拼写\n", argName)
- }
- os.Args = append(os.Args, arg)
- }
- }
- }
- resetFlag()
- flag.Parse()
- if *mcType != "" {
- opt.MCType = *mcType
- }
- if *key != "" {
- opt.KeyMethod = *key
- }
- if *expire != "" {
- opt.ExpireCode = *expire
- }
- opt.EnableBatch = (*batchSize != 0) && (*maxGroup != 0)
- opt.BatchErrBreak = *batchErr == "break"
- opt.GroupSize = *batchSize
- opt.MaxGroup = *maxGroup
- // get type from prefix
- if opt.MCType == "" {
- for _, t := range mcValidPrefix {
- if strings.HasPrefix(strings.ToLower(opt.name), t) {
- if t == "add" {
- t = _typeSet
- }
- opt.MCType = t
- break
- }
- }
- if opt.MCType == "" {
- log.Fatalln(opt.name + "请指定方法类型(type=get/set/del...)")
- }
- }
- if opt.MCType == "cache" {
- opt.MCType = _typeGet
- }
- params := list.Type.(*ast.FuncType).Params.List
- if len(params) == 0 {
- log.Fatalln(opt.name + "参数不足")
- }
- if s.ExprString(params[0].Type) != "context.Context" {
- log.Fatalln(opt.name + "第一个参数必须为context")
- }
- for _, param := range params {
- if len(param.Names) > 1 {
- log.Fatalln(opt.name + "不支持省略类型")
- }
- }
- // get template
- if len(params) == 1 {
- opt.template = _noneTpl
- } else if (len(params) == 2) && (opt.MCType == _typeSet || opt.MCType == _typeAdd || opt.MCType == _typeReplace) {
- if _, ok := params[1].Type.(*ast.MapType); ok {
- opt.template = _multiTpl
- } else {
- opt.template = _noneTpl
- }
- } else {
- if _, ok := params[1].Type.(*ast.ArrayType); ok {
- opt.template = _multiTpl
- } else {
- opt.template = _singleTpl
- }
- }
- // extra args
- if len(params) > 2 {
- args := []string{""}
- allArgs := []string{""}
- var pos = 2
- if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
- pos = 3
- }
- for _, pa := range params[pos:] {
- paType := s.ExprString(pa.Type)
- if len(pa.Names) == 0 {
- args = append(args, paType)
- allArgs = append(allArgs, paType)
- continue
- }
- var names []string
- for _, name := range pa.Names {
- names = append(names, name.Name)
- }
- allArgs = append(allArgs, strings.Join(names, ",")+" "+paType)
- args = append(args, strings.Join(names, ","))
- }
- if len(args) > 1 {
- opt.ExtraArgs = strings.Join(args, ",")
- opt.ExtraArgsType = strings.Join(allArgs, ",")
- }
- }
- // get k v from results
- results := list.Type.(*ast.FuncType).Results.List
- if s.ExprString(results[len(results)-1].Type) != "error" {
- log.Fatalln("最后返回值参数需为error")
- }
- for _, res := range results {
- if len(res.Names) > 1 {
- log.Fatalln(opt.name + "返回值不支持省略类型")
- }
- }
- if opt.MCType == _typeGet {
- if len(results) != 2 {
- log.Fatalln("参数个数不对")
- }
- }
- // get key type and value type
- if (opt.MCType == _typeAdd) || (opt.MCType == _typeSet) || (opt.MCType == _typeReplace) {
- if opt.template == _multiTpl {
- p, ok := params[1].Type.(*ast.MapType)
- if !ok {
- log.Fatalf("%s: 参数类型错误 批量设置数据时类型需为map类型\n", opt.name)
- }
- opt.keyType = s.ExprString(p.Key)
- opt.ValueType = s.ExprString(p.Value)
- } else if opt.template == _singleTpl {
- opt.keyType = s.ExprString(params[1].Type)
- opt.ValueType = s.ExprString(params[2].Type)
- } else {
- opt.ValueType = s.ExprString(params[1].Type)
- }
- }
- if opt.MCType == _typeGet {
- if opt.template == _multiTpl {
- if p, ok := results[0].Type.(*ast.MapType); ok {
- opt.keyType = s.ExprString(p.Key)
- opt.ValueType = s.ExprString(p.Value)
- } else {
- log.Fatalf("%s: 返回值类型错误 批量获取数据时返回值需为map类型\n", opt.name)
- }
- } else if opt.template == _singleTpl {
- opt.keyType = s.ExprString(params[1].Type)
- opt.ValueType = s.ExprString(results[0].Type)
- } else {
- opt.ValueType = s.ExprString(results[0].Type)
- }
- }
- if opt.MCType == _typeDel {
- if opt.template == _multiTpl {
- p, ok := params[1].Type.(*ast.ArrayType)
- if !ok {
- log.Fatalf("%s: 类型错误 参数需为[]类型\n", opt.name)
- }
- opt.keyType = s.ExprString(p.Elt)
- } else if opt.template == _singleTpl {
- opt.keyType = s.ExprString(params[1].Type)
- }
- }
- for _, t := range simpleTypes {
- if t == opt.ValueType {
- opt.SimpleValue = true
- opt.GetSimpleValue = true
- opt.ConvertValue2Bytes = convertValue2Bytes(t)
- opt.ConvertBytes2Value = convertBytes2Value(t)
- break
- }
- }
- if opt.ValueType == "string" {
- opt.LenType = true
- } else {
- for _, t := range lenTypes {
- if strings.HasPrefix(opt.ValueType, t) {
- opt.LenType = true
- break
- }
- }
- }
- if opt.SimpleValue && (opt.ValueType == "[]byte" || opt.ValueType == "string") {
- opt.GetSimpleValue = false
- opt.GetDirectValue = true
- }
- if opt.MCType == _typeGet && opt.template == _multiTpl {
- opt.UseMemcached = false
- }
- if strings.HasPrefix(opt.ValueType, "*") {
- opt.InitValue = true
- opt.PointType = true
- opt.OriginValueType = strings.Replace(opt.ValueType, "*", "", 1)
- } else {
- opt.OriginValueType = opt.ValueType
- }
- if *encode != "" {
- var flags []string
- for _, f := range strings.Split(*encode, "|") {
- switch f {
- case "gob":
- flags = append(flags, "memcache.FlagGOB")
- case "json":
- flags = append(flags, "memcache.FlagJSON")
- case "raw":
- flags = append(flags, "memcache.FlagRAW")
- case "pb":
- flags = append(flags, "memcache.FlagProtobuf")
- case "gzip":
- flags = append(flags, "memcache.FlagGzip")
- default:
- log.Fatalf("%s: encode类型无效\n", opt.name)
- }
- }
- opt.Encode = strings.Join(flags, " | ")
- } else {
- if opt.SimpleValue {
- opt.Encode = "memcache.FlagRAW"
- } else {
- opt.Encode = "memcache.FlagJSON"
- }
- }
- opt.Check()
- opts = append(opts, &opt)
- }
- return
- }
- func (option *options) Check() {
- var valid bool
- for _, x := range mcValidTypes {
- if x == option.MCType {
- valid = true
- break
- }
- }
- if !valid {
- log.Fatalf("%s: 类型错误 不支持%s类型\n", option.name, option.MCType)
- }
- if (option.MCType != _typeDel) && !option.SimpleValue && !strings.Contains(option.ValueType, "*") && !strings.Contains(option.ValueType, "[]") && !strings.Contains(option.ValueType, "map") {
- log.Fatalf("%s: 值类型只能为基本类型/slice/map/指针类型\n", option.name)
- }
- }
- func genHeader(opts []*options) (src string) {
- option := options{PkgName: os.Getenv("GOPACKAGE"), UseMemcached: false}
- var packages []string
- packagesMap := map[string]bool{`"context"`: true}
- for _, opt := range opts {
- if len(opt.importPackages) > 0 {
- for _, pkg := range opt.importPackages {
- if !packagesMap[pkg] {
- packages = append(packages, pkg)
- packagesMap[pkg] = true
- }
- }
- }
- if opt.Args != "" {
- option.Args = opt.Args
- }
- if opt.UseMemcached {
- option.UseMemcached = true
- }
- if opt.SimpleValue && !opt.GetDirectValue {
- option.UseStrConv = true
- }
- if opt.EnableBatch {
- option.EnableBatch = true
- }
- }
- option.ImportPackage = strings.Join(packages, "\n")
- src = _headerTemplate
- t := template.Must(template.New("header").Parse(src))
- var buffer bytes.Buffer
- err := t.Execute(&buffer, option)
- if err != nil {
- log.Fatalf("execute template: %s", err)
- }
- // Format the output.
- src = strings.Replace(buffer.String(), "\t", "", -1)
- src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
- src = strings.Replace(src, "NEWLINE", "", -1)
- src = strings.Replace(src, "ARGS", option.Args, -1)
- return
- }
- func genBody(opts []*options) (res string) {
- for _, option := range opts {
- var src string
- if option.template == _multiTpl {
- switch option.MCType {
- case _typeGet:
- src = _multiGetTemplate
- case _typeSet:
- src = _multiSetTemplate
- case _typeReplace:
- src = _multiReplaceTemplate
- case _typeDel:
- src = _multiDelTemplate
- case _typeAdd:
- src = _multiAddTemplate
- }
- } else if option.template == _singleTpl {
- switch option.MCType {
- case _typeGet:
- src = _singleGetTemplate
- case _typeSet:
- src = _singleSetTemplate
- case _typeReplace:
- src = _singleReplaceTemplate
- case _typeDel:
- src = _singleDelTemplate
- case _typeAdd:
- src = _singleAddTemplate
- }
- } else {
- switch option.MCType {
- case _typeGet:
- src = _noneGetTemplate
- case _typeSet:
- src = _noneSetTemplate
- case _typeReplace:
- src = _noneReplaceTemplate
- case _typeDel:
- src = _noneDelTemplate
- case _typeAdd:
- src = _noneAddTemplate
- }
- }
- src = strings.Replace(src, "KEY", option.keyType, -1)
- src = strings.Replace(src, "NAME", option.name, -1)
- src = strings.Replace(src, "VALUE", option.ValueType, -1)
- src = strings.Replace(src, "GROUPSIZE", strconv.Itoa(option.GroupSize), -1)
- src = strings.Replace(src, "MAXGROUP", strconv.Itoa(option.MaxGroup), -1)
- t := template.Must(template.New("cache").Parse(src))
- var buffer bytes.Buffer
- err := t.Execute(&buffer, option)
- if err != nil {
- log.Fatalf("execute template: %s", err)
- }
- // Format the output.
- src = strings.Replace(buffer.String(), "\t", "", -1)
- src = regexp.MustCompile("\n+").ReplaceAllString(src, "\n")
- res = res + "\n" + src
- }
- return
- }
- func main() {
- log.SetFlags(0)
- defer func() {
- if err := recover(); err != nil {
- log.Fatalf("程序解析失败, err: %+v 请企业微信联系 @wangxu01", err)
- }
- }()
- options := parse(common.NewSource(common.SourceText()))
- header := genHeader(options)
- body := genBody(options)
- code := common.FormatCode(header + "\n" + body)
- // Write to file.
- dir := filepath.Dir(".")
- outputName := filepath.Join(dir, "mc.cache.go")
- err := ioutil.WriteFile(outputName, []byte(code), 0644)
- if err != nil {
- log.Fatalf("写入文件失败: %s", err)
- }
- log.Println("mc.cache.go: 生成成功")
- }
// convertValue2Bytes returns a Go expression (as source text) that converts a
// variable named `val` of simple type t into a []byte for storage in
// memcache. It returns "" when t is not a supported simple type.
func convertValue2Bytes(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "[]byte(strconv.FormatInt(int64(val), 10))"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		// The strconv function is FormatUint (not FormatUInt) and it takes a uint64.
		return "[]byte(strconv.FormatUint(uint64(val), 10))"
	case "bool":
		return "[]byte(strconv.FormatBool(val))"
	case "float32":
		// FormatFloat takes a float64; bitSize 32 keeps float32 round-tripping.
		return "[]byte(strconv.FormatFloat(float64(val), 'E', -1, 32))"
	case "float64":
		return "[]byte(strconv.FormatFloat(val, 'E', -1, 64))"
	case "string":
		return "[]byte(val)"
	case "[]byte":
		return "val"
	}
	return ""
}
// convertBytes2Value returns a Go call expression (as source text) that parses
// a string variable named `v` for simple type t. Every numeric/bool case
// yields a (value, error) pair of the strconv result type (int64, uint64,
// float64 or bool). The caller's template presumably converts that value to
// the exact target type — TODO confirm. It returns "" for string, []byte and
// unsupported types.
func convertBytes2Value(t string) string {
	switch t {
	case "int", "int8", "int16", "int32", "int64":
		return "strconv.ParseInt(v, 10, 64)"
	case "uint", "uint8", "uint16", "uint32", "uint64":
		// The strconv function is ParseUint (not ParseUInt).
		return "strconv.ParseUint(v, 10, 64)"
	case "bool":
		return "strconv.ParseBool(v)"
	case "float32":
		// A two-value call cannot be wrapped in a conversion, so this returns
		// the (float64, error) call like the other numeric cases. bitSize 32
		// guarantees the value is representable as float32.
		return "strconv.ParseFloat(v, 32)"
	case "float64":
		return "strconv.ParseFloat(v, 64)"
	}
	return ""
}
|