gencscode.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. package generator
  2. import (
  3. "fmt"
  4. "os"
  5. "path"
  6. "strings"
  7. "text/template"
  8. assets "go-common/app/tool/warden/generator/templates"
  9. "go-common/app/tool/warden/types"
  10. )
  // GenCSCodeOptions configures GenCSCode.
  //
  // render copies each field verbatim onto the CSValue that is handed to the
  // code templates.
  type GenCSCodeOptions struct {
  	// PbPackage becomes CSValue.PbPackage (presumably the Go package of the
  	// generated protobuf types — confirm against the templates).
  	PbPackage string
  	// RecvPackage becomes CSValue.RecvPackage.
  	RecvPackage string
  	// RecvName becomes CSValue.RecvName.
  	RecvName string
  }
  17. // CSValue ...
  18. type CSValue struct {
  19. options *GenCSCodeOptions
  20. Name string
  21. PbPackage string
  22. RecvName string
  23. RecvPackage string
  24. Imports map[string]struct{}
  25. ClientImports map[string]struct{}
  26. Methods []CSMethod
  27. }
  // CSMethod describes a single service method for the templates.
  type CSMethod struct {
  	// Name and Comments are copied unchanged from the source method.
  	Name     string
  	Comments []string
  	// ParamBlock and ReturnBlock are produced by formatField: entries of the
  	// form "name Type" (or just "Type" when unnamed), joined by ", ".
  	ParamBlock, ReturnBlock string
  	// ParamPbBlock is not populated by renderMethods.
  	// NOTE(review): confirm whether any template still reads it.
  	ParamPbBlock string
  }
  36. func (c *CSValue) render(spec *types.ServiceSpec) error {
  37. c.PbPackage = c.options.PbPackage
  38. c.Name = spec.Name
  39. c.RecvName = c.options.RecvName
  40. c.RecvPackage = c.options.RecvPackage
  41. c.Imports = map[string]struct{}{"context": struct{}{}}
  42. c.ClientImports = make(map[string]struct{})
  43. return c.renderMethods(spec.Methods)
  44. }
  45. func (c *CSValue) renderMethods(methods []*types.Method) error {
  46. for _, method := range methods {
  47. csMethod := CSMethod{
  48. Name: method.Name,
  49. Comments: method.Comments,
  50. ParamBlock: c.formatField(method.Parameters),
  51. ReturnBlock: c.formatField(method.Results),
  52. }
  53. c.Methods = append(c.Methods, csMethod)
  54. }
  55. return nil
  56. }
  57. func (c *CSValue) formatField(fields []*types.Field) string {
  58. var ss []string
  59. clientImps := make(map[string]struct{})
  60. for _, field := range fields {
  61. if field.Name == "" {
  62. ss = append(ss, field.Type.String())
  63. } else {
  64. ss = append(ss, fmt.Sprintf("%s %s", field.Name, field.Type))
  65. }
  66. importType(clientImps, field.Type)
  67. }
  68. for k := range clientImps {
  69. if _, ok := c.Imports[k]; !ok {
  70. c.ClientImports[k] = struct{}{}
  71. }
  72. }
  73. return strings.Join(ss, ", ")
  74. }
  75. func importType(m map[string]struct{}, t types.Typer) {
  76. if m == nil {
  77. panic("map is nil")
  78. }
  79. switch v := t.(type) {
  80. case *types.StructType:
  81. m[v.ImportPath] = struct{}{}
  82. for _, f := range v.Fields {
  83. importType(m, f.Type)
  84. }
  85. case *types.ArrayType:
  86. importType(m, v.EltType)
  87. case *types.InterfaceType:
  88. m[v.ImportPath] = struct{}{}
  89. }
  90. }
  91. func renderCSValue(spec *types.ServiceSpec, options *GenCSCodeOptions) (*CSValue, error) {
  92. value := &CSValue{
  93. options: options,
  94. }
  95. return value, value.render(spec)
  96. }
  97. // GenCSCode generator client, server code
  98. func GenCSCode(csdir string, spec *types.ServiceSpec, options *GenCSCodeOptions) error {
  99. value, err := renderCSValue(spec, options)
  100. if err != nil {
  101. return err
  102. }
  103. return genCode(value, "server", csdir)
  104. }
  105. func genCode(value *CSValue, name, dir string) error {
  106. if err := os.MkdirAll(dir, 0755); err != nil {
  107. return err
  108. }
  109. fp, err := os.OpenFile(path.Join(dir, fmt.Sprintf("%s.go", name)), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
  110. if err != nil {
  111. return err
  112. }
  113. defer fp.Close()
  114. templateName := fmt.Sprintf("%s.tmpl", name)
  115. t, err := template.New(name).Parse(string(assets.MustAsset(templateName)))
  116. if err != nil {
  117. return err
  118. }
  119. return t.Execute(fp, value)
  120. }