generator.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. package gen
  2. import (
  3. "bytes"
  4. "fmt"
  5. "hash/fnv"
  6. "io"
  7. "path"
  8. "reflect"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "unicode"
  13. )
  // Import paths of the easyjson runtime packages referenced by generated code.
  const (
  	pkgWriter   = "github.com/mailru/easyjson/jwriter"
  	pkgLexer    = "github.com/mailru/easyjson/jlexer"
  	pkgEasyJSON = "github.com/mailru/easyjson"
  )
  // FieldNamer is a pluggable policy that decides the JSON name of a struct
  // field. The implementations in this file prefer the name given in the
  // field's `json` tag and otherwise derive one from the Go field name.
  //
  // t is presumably the struct type that declares f (none of the namers in
  // this file look at it) — confirm against the callers.
  type FieldNamer interface {
  	GetJSONFieldName(t reflect.Type, f reflect.StructField) string
  }
  21. // Generator generates the requested marshaler/unmarshalers.
  22. type Generator struct {
  23. out *bytes.Buffer
  24. pkgName string
  25. pkgPath string
  26. buildTags string
  27. hashString string
  28. varCounter int
  29. noStdMarshalers bool
  30. omitEmpty bool
  31. disallowUnknownFields bool
  32. fieldNamer FieldNamer
  33. // package path to local alias map for tracking imports
  34. imports map[string]string
  35. // types that marshalers were requested for by user
  36. marshalers map[reflect.Type]bool
  37. // types that encoders were already generated for
  38. typesSeen map[reflect.Type]bool
  39. // types that encoders were requested for (e.g. by encoders of other types)
  40. typesUnseen []reflect.Type
  41. // function name to relevant type maps to track names of de-/encoders in
  42. // case of a name clash or unnamed structs
  43. functionNames map[string]reflect.Type
  44. }
  45. // NewGenerator initializes and returns a Generator.
  46. func NewGenerator(filename string) *Generator {
  47. ret := &Generator{
  48. imports: map[string]string{
  49. pkgWriter: "jwriter",
  50. pkgLexer: "jlexer",
  51. pkgEasyJSON: "easyjson",
  52. "encoding/json": "json",
  53. },
  54. fieldNamer: DefaultFieldNamer{},
  55. marshalers: make(map[reflect.Type]bool),
  56. typesSeen: make(map[reflect.Type]bool),
  57. functionNames: make(map[string]reflect.Type),
  58. }
  59. // Use a file-unique prefix on all auxiliary funcs to avoid
  60. // name clashes.
  61. hash := fnv.New32()
  62. hash.Write([]byte(filename))
  63. ret.hashString = fmt.Sprintf("%x", hash.Sum32())
  64. return ret
  65. }
  66. // SetPkg sets the name and path of output package.
  67. func (g *Generator) SetPkg(name, path string) {
  68. g.pkgName = name
  69. g.pkgPath = path
  70. }
  71. // SetBuildTags sets build tags for the output file.
  72. func (g *Generator) SetBuildTags(tags string) {
  73. g.buildTags = tags
  74. }
  75. // SetFieldNamer sets field naming strategy.
  76. func (g *Generator) SetFieldNamer(n FieldNamer) {
  77. g.fieldNamer = n
  78. }
  79. // UseSnakeCase sets snake_case field naming strategy.
  80. func (g *Generator) UseSnakeCase() {
  81. g.fieldNamer = SnakeCaseFieldNamer{}
  82. }
  83. // UseLowerCamelCase sets lowerCamelCase field naming strategy.
  84. func (g *Generator) UseLowerCamelCase() {
  85. g.fieldNamer = LowerCamelCaseFieldNamer{}
  86. }
  87. // NoStdMarshalers instructs not to generate standard MarshalJSON/UnmarshalJSON
  88. // methods (only the custom interface).
  89. func (g *Generator) NoStdMarshalers() {
  90. g.noStdMarshalers = true
  91. }
  92. // DisallowUnknownFields instructs not to skip unknown fields in json and return error.
  93. func (g *Generator) DisallowUnknownFields() {
  94. g.disallowUnknownFields = true
  95. }
  96. // OmitEmpty triggers `json=",omitempty"` behaviour by default.
  97. func (g *Generator) OmitEmpty() {
  98. g.omitEmpty = true
  99. }
  100. // addTypes requests to generate encoding/decoding funcs for the given type.
  101. func (g *Generator) addType(t reflect.Type) {
  102. if g.typesSeen[t] {
  103. return
  104. }
  105. for _, t1 := range g.typesUnseen {
  106. if t1 == t {
  107. return
  108. }
  109. }
  110. g.typesUnseen = append(g.typesUnseen, t)
  111. }
  112. // Add requests to generate marshaler/unmarshalers and encoding/decoding
  113. // funcs for the type of given object.
  114. func (g *Generator) Add(obj interface{}) {
  115. t := reflect.TypeOf(obj)
  116. if t.Kind() == reflect.Ptr {
  117. t = t.Elem()
  118. }
  119. g.addType(t)
  120. g.marshalers[t] = true
  121. }
  122. // printHeader prints package declaration and imports.
  123. func (g *Generator) printHeader() {
  124. if g.buildTags != "" {
  125. fmt.Println("// +build ", g.buildTags)
  126. fmt.Println()
  127. }
  128. fmt.Println("// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT.")
  129. fmt.Println()
  130. fmt.Println("package ", g.pkgName)
  131. fmt.Println()
  132. byAlias := map[string]string{}
  133. var aliases []string
  134. for path, alias := range g.imports {
  135. aliases = append(aliases, alias)
  136. byAlias[alias] = path
  137. }
  138. sort.Strings(aliases)
  139. fmt.Println("import (")
  140. for _, alias := range aliases {
  141. fmt.Printf(" %s %q\n", alias, byAlias[alias])
  142. }
  143. fmt.Println(")")
  144. fmt.Println("")
  145. fmt.Println("// suppress unused package warning")
  146. fmt.Println("var (")
  147. fmt.Println(" _ *json.RawMessage")
  148. fmt.Println(" _ *jlexer.Lexer")
  149. fmt.Println(" _ *jwriter.Writer")
  150. fmt.Println(" _ easyjson.Marshaler")
  151. fmt.Println(")")
  152. fmt.Println()
  153. }
  154. // Run runs the generator and outputs generated code to out.
  155. func (g *Generator) Run(out io.Writer) error {
  156. g.out = &bytes.Buffer{}
  157. for len(g.typesUnseen) > 0 {
  158. t := g.typesUnseen[len(g.typesUnseen)-1]
  159. g.typesUnseen = g.typesUnseen[:len(g.typesUnseen)-1]
  160. g.typesSeen[t] = true
  161. if err := g.genDecoder(t); err != nil {
  162. return err
  163. }
  164. if err := g.genEncoder(t); err != nil {
  165. return err
  166. }
  167. if !g.marshalers[t] {
  168. continue
  169. }
  170. if err := g.genStructMarshaler(t); err != nil {
  171. return err
  172. }
  173. if err := g.genStructUnmarshaler(t); err != nil {
  174. return err
  175. }
  176. }
  177. g.printHeader()
  178. _, err := out.Write(g.out.Bytes())
  179. return err
  180. }
  // fixPkgPathVendoring strips everything up to and including the last
  // "/vendor/" element of pkgPath, turning a vendored path such as
  // "example.com/app/vendor/github.com/x/y" into "github.com/x/y". Paths
  // without "/vendor/" are returned unchanged.
  func fixPkgPathVendoring(pkgPath string) string {
  	const marker = "/vendor/"
  	idx := strings.LastIndex(pkgPath, marker)
  	if idx < 0 {
  		return pkgPath
  	}
  	return pkgPath[idx+len(marker):]
  }
  // fixAliasName turns the last element of an import path into a Go identifier
  // usable as an import alias:
  //   - every rune that is not a letter, digit or '_' (e.g. '.', '-', '+')
  //     becomes '_';
  //   - an empty result, or the blank identifier "_", becomes "pkg";
  //   - a leading digit gets a '_' prefix, since identifiers cannot start
  //     with one;
  //   - a leading 'v' gets a '_' prefix, to avoid conflicts with generated
  //     variable names such as v1.
  //
  // Names that were already valid identifiers map exactly as before; the
  // other cases used to yield invalid code or panic on an empty string.
  func fixAliasName(alias string) string {
  	alias = strings.Map(func(r rune) rune {
  		if r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) {
  			return r
  		}
  		return '_'
  	}, alias)
  	if alias == "" || alias == "_" {
  		return "pkg"
  	}
  	if first := []rune(alias)[0]; first == 'v' || unicode.IsDigit(first) {
  		alias = "_" + alias
  	}
  	return alias
  }
  201. // pkgAlias creates and returns and import alias for a given package.
  202. func (g *Generator) pkgAlias(pkgPath string) string {
  203. pkgPath = fixPkgPathVendoring(pkgPath)
  204. if alias := g.imports[pkgPath]; alias != "" {
  205. return alias
  206. }
  207. for i := 0; ; i++ {
  208. alias := fixAliasName(path.Base(pkgPath))
  209. if i > 0 {
  210. alias += fmt.Sprint(i)
  211. }
  212. exists := false
  213. for _, v := range g.imports {
  214. if v == alias {
  215. exists = true
  216. break
  217. }
  218. }
  219. if !exists {
  220. g.imports[pkgPath] = alias
  221. return alias
  222. }
  223. }
  224. }
  225. // getType return the textual type name of given type that can be used in generated code.
  226. func (g *Generator) getType(t reflect.Type) string {
  227. if t.Name() == "" {
  228. switch t.Kind() {
  229. case reflect.Ptr:
  230. return "*" + g.getType(t.Elem())
  231. case reflect.Slice:
  232. return "[]" + g.getType(t.Elem())
  233. case reflect.Array:
  234. return "[" + strconv.Itoa(t.Len()) + "]" + g.getType(t.Elem())
  235. case reflect.Map:
  236. return "map[" + g.getType(t.Key()) + "]" + g.getType(t.Elem())
  237. }
  238. }
  239. if t.Name() == "" || t.PkgPath() == "" {
  240. if t.Kind() == reflect.Struct {
  241. // the fields of an anonymous struct can have named types,
  242. // and t.String() will not be sufficient because it does not
  243. // remove the package name when it matches g.pkgPath.
  244. // so we convert by hand
  245. nf := t.NumField()
  246. lines := make([]string, 0, nf)
  247. for i := 0; i < nf; i++ {
  248. f := t.Field(i)
  249. var line string
  250. if !f.Anonymous {
  251. line = f.Name + " "
  252. } // else the field is anonymous (an embedded type)
  253. line += g.getType(f.Type)
  254. t := f.Tag
  255. if t != "" {
  256. line += " " + escapeTag(t)
  257. }
  258. lines = append(lines, line)
  259. }
  260. return strings.Join([]string{"struct { ", strings.Join(lines, "; "), " }"}, "")
  261. }
  262. return t.String()
  263. } else if t.PkgPath() == g.pkgPath {
  264. return t.Name()
  265. }
  266. return g.pkgAlias(t.PkgPath()) + "." + t.Name()
  267. }
  // escapeTag renders a struct field tag as a Go string literal for use in a
  // struct type written into generated source. A raw (back-quoted) literal is
  // used unless the tag itself contains a back quote. In that case the tag is
  // returned as an interpreted literal produced by strconv.Quote.
  func escapeTag(tag reflect.StructTag) string {
  	s := string(tag)
  	if strings.IndexByte(s, '`') < 0 {
  		return "`" + s + "`"
  	}
  	return strconv.Quote(s)
  }
  277. // uniqueVarName returns a file-unique name that can be used for generated variables.
  278. func (g *Generator) uniqueVarName() string {
  279. g.varCounter++
  280. return fmt.Sprint("v", g.varCounter)
  281. }
  282. // safeName escapes unsafe characters in pkg/type name and returns a string that can be used
  283. // in encoder/decoder names for the type.
  284. func (g *Generator) safeName(t reflect.Type) string {
  285. name := t.PkgPath()
  286. if t.Name() == "" {
  287. name += "anonymous"
  288. } else {
  289. name += "." + t.Name()
  290. }
  291. parts := []string{}
  292. part := []rune{}
  293. for _, c := range name {
  294. if unicode.IsLetter(c) || unicode.IsDigit(c) {
  295. part = append(part, c)
  296. } else if len(part) > 0 {
  297. parts = append(parts, string(part))
  298. part = []rune{}
  299. }
  300. }
  301. return joinFunctionNameParts(false, parts...)
  302. }
  303. // functionName returns a function name for a given type with a given prefix. If a function
  304. // with this prefix already exists for a type, it is returned.
  305. //
  306. // Method is used to track encoder/decoder names for the type.
  307. func (g *Generator) functionName(prefix string, t reflect.Type) string {
  308. prefix = joinFunctionNameParts(true, "easyjson", g.hashString, prefix)
  309. name := joinFunctionNameParts(true, prefix, g.safeName(t))
  310. // Most of the names will be unique, try a shortcut first.
  311. if e, ok := g.functionNames[name]; !ok || e == t {
  312. g.functionNames[name] = t
  313. return name
  314. }
  315. // Search if the function already exists.
  316. for name1, t1 := range g.functionNames {
  317. if t1 == t && strings.HasPrefix(name1, prefix) {
  318. return name1
  319. }
  320. }
  321. // Create a new name in the case of a clash.
  322. for i := 1; ; i++ {
  323. nm := fmt.Sprint(name, i)
  324. if _, ok := g.functionNames[nm]; ok {
  325. continue
  326. }
  327. g.functionNames[nm] = t
  328. return nm
  329. }
  330. }
  // DefaultFieldNamer implements the trivial naming policy that mirrors
  // encoding/json's default: use the name from the field's `json` tag if one is
  // given, otherwise the Go field name unchanged.
  type DefaultFieldNamer struct{}
  
  // GetJSONFieldName returns the name part (before the first comma) of f's
  // `json` tag, or f.Name when that part is empty. A tag name of "-" is
  // returned as is; this method does not decide whether a field is skipped.
  // t is not used.
  func (DefaultFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
  	if name := strings.Split(f.Tag.Get("json"), ",")[0]; name != "" {
  		return name
  	}
  	return f.Name
  }
  // LowerCamelCaseFieldNamer names a field after its `json` tag if present,
  // otherwise after its Go name converted to lowerCamelCase
  // (e.g. "HTTPRestClient" -> "httpRestClient").
  type LowerCamelCaseFieldNamer struct{}
  // isLower reports whether b is an ASCII lower-case letter ('a'..'z').
  func isLower(b byte) bool {
  	return 'a' <= b && b <= 'z'
  }
  // isUpper reports whether b is an ASCII upper-case letter ('A'..'Z').
  func isUpper(b byte) bool {
  	return 'A' <= b && b <= 'Z'
  }
  349. // convert HTTPRestClient to httpRestClient
  350. func lowerFirst(s string) string {
  351. if s == "" {
  352. return ""
  353. }
  354. str := ""
  355. strlen := len(s)
  356. /**
  357. Loop each char
  358. If is uppercase:
  359. If is first char, LOWER it
  360. If the following char is lower, LEAVE it
  361. If the following char is upper OR numeric, LOWER it
  362. If is the end of string, LEAVE it
  363. Else lowercase
  364. */
  365. foundLower := false
  366. for i := range s {
  367. ch := s[i]
  368. if isUpper(ch) {
  369. if i == 0 {
  370. str += string(ch + 32)
  371. } else if !foundLower { // Currently just a stream of capitals, eg JSONRESTS[erver]
  372. if strlen > (i+1) && isLower(s[i+1]) {
  373. // Next char is lower, keep this a capital
  374. str += string(ch)
  375. } else {
  376. // Either at end of string or next char is capital
  377. str += string(ch + 32)
  378. }
  379. } else {
  380. str += string(ch)
  381. }
  382. } else {
  383. foundLower = true
  384. str += string(ch)
  385. }
  386. }
  387. return str
  388. }
  389. func (LowerCamelCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
  390. jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
  391. if jsonName != "" {
  392. return jsonName
  393. } else {
  394. return lowerFirst(f.Name)
  395. }
  396. }
  // SnakeCaseFieldNamer names a field after its `json` tag if present,
  // otherwise after its Go name converted from CamelCase to snake_case
  // (e.g. "HTTPServer" -> "http_server").
  type SnakeCaseFieldNamer struct{}
  // camelToSnake converts a CamelCase name to snake_case. For example, "FooBar"
  // becomes "foo_bar", "HTTPServer" becomes "http_server" and "UserID" becomes
  // "user_id".
  //
  // Upper-case runes are buffered one at a time. A non-lower-case rune that
  // follows a buffered upper-case one is treated as upper case too. When a
  // buffered rune is flushed (lowered), a '_' goes before it if it starts or
  // ends a run of capitals, unless nothing has been written yet or the last
  // directly written rune was '_'. This splits "HTTPServer" before the final
  // 'S'.
  func camelToSnake(name string) string {
  	var out bytes.Buffer
  	var (
  		pending    rune // buffered upper-case rune, 0 when none
  		lastPlain  rune // last rune written as is (not via the buffer)
  		inUpperRun bool // pending is not the first capital of its run
  	)
  	for _, r := range name {
  		upper := unicode.IsUpper(r) || (pending != 0 && !unicode.IsLower(r))
  		if pending != 0 {
  			startsRun := !inUpperRun
  			endsRun := !upper
  			if out.Len() > 0 && (startsRun || endsRun) && lastPlain != '_' {
  				out.WriteByte('_')
  			}
  			out.WriteRune(unicode.ToLower(pending))
  		}
  		if upper {
  			// Hold it back: whether a '_' precedes it depends on the next rune.
  			inUpperRun = pending != 0
  			pending = r
  			continue
  		}
  		out.WriteRune(r)
  		pending = 0
  		lastPlain = r
  		inUpperRun = false
  	}
  	if pending != 0 {
  		out.WriteRune(unicode.ToLower(pending))
  	}
  	return out.String()
  }
  435. func (SnakeCaseFieldNamer) GetJSONFieldName(t reflect.Type, f reflect.StructField) string {
  436. jsonName := strings.Split(f.Tag.Get("json"), ",")[0]
  437. if jsonName != "" {
  438. return jsonName
  439. }
  440. return camelToSnake(f.Name)
  441. }
  // joinFunctionNameParts concatenates parts into one CamelCase identifier. The
  // first rune of every part is upper-cased, except the first part when
  // keepFirst is true (that part is copied verbatim). Empty parts contribute
  // nothing.
  //
  // The first rune is decoded as a whole. Upper-casing only the first byte would
  // corrupt parts that start with a multi-byte UTF-8 letter, and such parts can
  // occur because unicode letters are valid in Go identifiers.
  func joinFunctionNameParts(keepFirst bool, parts ...string) string {
  	var sb strings.Builder
  	for i, part := range parts {
  		if i == 0 && keepFirst {
  			sb.WriteString(part)
  			continue
  		}
  		for j, r := range part {
  			if j == 0 {
  				sb.WriteRune(unicode.ToUpper(r))
  				continue
  			}
  			// j is the byte offset of the second rune: copy the rest as is.
  			sb.WriteString(part[j:])
  			break
  		}
  	}
  	return sb.String()
  }