encoder.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. package gen
  2. import (
  3. "encoding"
  4. "encoding/json"
  5. "fmt"
  6. "reflect"
  7. "strconv"
  8. "strings"
  9. "github.com/mailru/easyjson"
  10. )
  11. func (g *Generator) getEncoderName(t reflect.Type) string {
  12. return g.functionName("encode", t)
  13. }
  14. var primitiveEncoders = map[reflect.Kind]string{
  15. reflect.String: "out.String(string(%v))",
  16. reflect.Bool: "out.Bool(bool(%v))",
  17. reflect.Int: "out.Int(int(%v))",
  18. reflect.Int8: "out.Int8(int8(%v))",
  19. reflect.Int16: "out.Int16(int16(%v))",
  20. reflect.Int32: "out.Int32(int32(%v))",
  21. reflect.Int64: "out.Int64(int64(%v))",
  22. reflect.Uint: "out.Uint(uint(%v))",
  23. reflect.Uint8: "out.Uint8(uint8(%v))",
  24. reflect.Uint16: "out.Uint16(uint16(%v))",
  25. reflect.Uint32: "out.Uint32(uint32(%v))",
  26. reflect.Uint64: "out.Uint64(uint64(%v))",
  27. reflect.Float32: "out.Float32(float32(%v))",
  28. reflect.Float64: "out.Float64(float64(%v))",
  29. }
  30. var primitiveStringEncoders = map[reflect.Kind]string{
  31. reflect.String: "out.String(string(%v))",
  32. reflect.Int: "out.IntStr(int(%v))",
  33. reflect.Int8: "out.Int8Str(int8(%v))",
  34. reflect.Int16: "out.Int16Str(int16(%v))",
  35. reflect.Int32: "out.Int32Str(int32(%v))",
  36. reflect.Int64: "out.Int64Str(int64(%v))",
  37. reflect.Uint: "out.UintStr(uint(%v))",
  38. reflect.Uint8: "out.Uint8Str(uint8(%v))",
  39. reflect.Uint16: "out.Uint16Str(uint16(%v))",
  40. reflect.Uint32: "out.Uint32Str(uint32(%v))",
  41. reflect.Uint64: "out.Uint64Str(uint64(%v))",
  42. reflect.Uintptr: "out.UintptrStr(uintptr(%v))",
  43. reflect.Float32: "out.Float32Str(float32(%v))",
  44. reflect.Float64: "out.Float64Str(float64(%v))",
  45. }
  46. // fieldTags contains parsed version of json struct field tags.
  47. type fieldTags struct {
  48. name string
  49. omit bool
  50. omitEmpty bool
  51. noOmitEmpty bool
  52. asString bool
  53. required bool
  54. }
  55. // parseFieldTags parses the json field tag into a structure.
  56. func parseFieldTags(f reflect.StructField) fieldTags {
  57. var ret fieldTags
  58. for i, s := range strings.Split(f.Tag.Get("json"), ",") {
  59. switch {
  60. case i == 0 && s == "-":
  61. ret.omit = true
  62. case i == 0:
  63. ret.name = s
  64. case s == "omitempty":
  65. ret.omitEmpty = true
  66. case s == "!omitempty":
  67. ret.noOmitEmpty = true
  68. case s == "string":
  69. ret.asString = true
  70. case s == "required":
  71. ret.required = true
  72. }
  73. }
  74. return ret
  75. }
  76. // genTypeEncoder generates code that encodes in of type t into the writer, but uses marshaler interface if implemented by t.
  77. func (g *Generator) genTypeEncoder(t reflect.Type, in string, tags fieldTags, indent int, assumeNonEmpty bool) error {
  78. ws := strings.Repeat(" ", indent)
  79. marshalerIface := reflect.TypeOf((*easyjson.Marshaler)(nil)).Elem()
  80. if reflect.PtrTo(t).Implements(marshalerIface) {
  81. fmt.Fprintln(g.out, ws+"("+in+").MarshalEasyJSON(out)")
  82. return nil
  83. }
  84. marshalerIface = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
  85. if reflect.PtrTo(t).Implements(marshalerIface) {
  86. fmt.Fprintln(g.out, ws+"out.Raw( ("+in+").MarshalJSON() )")
  87. return nil
  88. }
  89. marshalerIface = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
  90. if reflect.PtrTo(t).Implements(marshalerIface) {
  91. fmt.Fprintln(g.out, ws+"out.RawText( ("+in+").MarshalText() )")
  92. return nil
  93. }
  94. err := g.genTypeEncoderNoCheck(t, in, tags, indent, assumeNonEmpty)
  95. return err
  96. }
  97. // returns true of the type t implements one of the custom marshaler interfaces
  98. func hasCustomMarshaler(t reflect.Type) bool {
  99. t = reflect.PtrTo(t)
  100. return t.Implements(reflect.TypeOf((*easyjson.Marshaler)(nil)).Elem()) ||
  101. t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) ||
  102. t.Implements(reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem())
  103. }
  104. // genTypeEncoderNoCheck generates code that encodes in of type t into the writer.
  105. func (g *Generator) genTypeEncoderNoCheck(t reflect.Type, in string, tags fieldTags, indent int, assumeNonEmpty bool) error {
  106. ws := strings.Repeat(" ", indent)
  107. // Check whether type is primitive, needs to be done after interface check.
  108. if enc := primitiveStringEncoders[t.Kind()]; enc != "" && tags.asString {
  109. fmt.Fprintf(g.out, ws+enc+"\n", in)
  110. return nil
  111. } else if enc := primitiveEncoders[t.Kind()]; enc != "" {
  112. fmt.Fprintf(g.out, ws+enc+"\n", in)
  113. return nil
  114. }
  115. switch t.Kind() {
  116. case reflect.Slice:
  117. elem := t.Elem()
  118. iVar := g.uniqueVarName()
  119. vVar := g.uniqueVarName()
  120. if t.Elem().Kind() == reflect.Uint8 && elem.Name() == "uint8" {
  121. fmt.Fprintln(g.out, ws+"out.Base64Bytes("+in+")")
  122. } else {
  123. if !assumeNonEmpty {
  124. fmt.Fprintln(g.out, ws+"if "+in+" == nil && (out.Flags & jwriter.NilSliceAsEmpty) == 0 {")
  125. fmt.Fprintln(g.out, ws+` out.RawString("null")`)
  126. fmt.Fprintln(g.out, ws+"} else {")
  127. } else {
  128. fmt.Fprintln(g.out, ws+"{")
  129. }
  130. fmt.Fprintln(g.out, ws+" out.RawByte('[')")
  131. fmt.Fprintln(g.out, ws+" for "+iVar+", "+vVar+" := range "+in+" {")
  132. fmt.Fprintln(g.out, ws+" if "+iVar+" > 0 {")
  133. fmt.Fprintln(g.out, ws+" out.RawByte(',')")
  134. fmt.Fprintln(g.out, ws+" }")
  135. if err := g.genTypeEncoder(elem, vVar, tags, indent+2, false); err != nil {
  136. return err
  137. }
  138. fmt.Fprintln(g.out, ws+" }")
  139. fmt.Fprintln(g.out, ws+" out.RawByte(']')")
  140. fmt.Fprintln(g.out, ws+"}")
  141. }
  142. case reflect.Array:
  143. elem := t.Elem()
  144. iVar := g.uniqueVarName()
  145. if t.Elem().Kind() == reflect.Uint8 && elem.Name() == "uint8" {
  146. fmt.Fprintln(g.out, ws+"out.Base64Bytes("+in+"[:])")
  147. } else {
  148. fmt.Fprintln(g.out, ws+"out.RawByte('[')")
  149. fmt.Fprintln(g.out, ws+"for "+iVar+" := range "+in+" {")
  150. fmt.Fprintln(g.out, ws+" if "+iVar+" > 0 {")
  151. fmt.Fprintln(g.out, ws+" out.RawByte(',')")
  152. fmt.Fprintln(g.out, ws+" }")
  153. if err := g.genTypeEncoder(elem, "("+in+")["+iVar+"]", tags, indent+1, false); err != nil {
  154. return err
  155. }
  156. fmt.Fprintln(g.out, ws+"}")
  157. fmt.Fprintln(g.out, ws+"out.RawByte(']')")
  158. }
  159. case reflect.Struct:
  160. enc := g.getEncoderName(t)
  161. g.addType(t)
  162. fmt.Fprintln(g.out, ws+enc+"(out, "+in+")")
  163. case reflect.Ptr:
  164. if !assumeNonEmpty {
  165. fmt.Fprintln(g.out, ws+"if "+in+" == nil {")
  166. fmt.Fprintln(g.out, ws+` out.RawString("null")`)
  167. fmt.Fprintln(g.out, ws+"} else {")
  168. }
  169. if err := g.genTypeEncoder(t.Elem(), "*"+in, tags, indent+1, false); err != nil {
  170. return err
  171. }
  172. if !assumeNonEmpty {
  173. fmt.Fprintln(g.out, ws+"}")
  174. }
  175. case reflect.Map:
  176. key := t.Key()
  177. keyEnc, ok := primitiveStringEncoders[key.Kind()]
  178. if !ok && !hasCustomMarshaler(key) {
  179. return fmt.Errorf("map key type %v not supported: only string and integer keys and types implementing Marshaler interfaces are allowed", key)
  180. } // else assume the caller knows what they are doing and that the custom marshaler performs the translation from the key type to a string or integer
  181. tmpVar := g.uniqueVarName()
  182. if !assumeNonEmpty {
  183. fmt.Fprintln(g.out, ws+"if "+in+" == nil && (out.Flags & jwriter.NilMapAsEmpty) == 0 {")
  184. fmt.Fprintln(g.out, ws+" out.RawString(`null`)")
  185. fmt.Fprintln(g.out, ws+"} else {")
  186. } else {
  187. fmt.Fprintln(g.out, ws+"{")
  188. }
  189. fmt.Fprintln(g.out, ws+" out.RawByte('{')")
  190. fmt.Fprintln(g.out, ws+" "+tmpVar+"First := true")
  191. fmt.Fprintln(g.out, ws+" for "+tmpVar+"Name, "+tmpVar+"Value := range "+in+" {")
  192. fmt.Fprintln(g.out, ws+" if "+tmpVar+"First { "+tmpVar+"First = false } else { out.RawByte(',') }")
  193. if keyEnc != "" {
  194. fmt.Fprintln(g.out, ws+" "+fmt.Sprintf(keyEnc, tmpVar+"Name"))
  195. } else {
  196. if err := g.genTypeEncoder(key, tmpVar+"Name", tags, indent+2, false); err != nil {
  197. return err
  198. }
  199. }
  200. fmt.Fprintln(g.out, ws+" out.RawByte(':')")
  201. if err := g.genTypeEncoder(t.Elem(), tmpVar+"Value", tags, indent+2, false); err != nil {
  202. return err
  203. }
  204. fmt.Fprintln(g.out, ws+" }")
  205. fmt.Fprintln(g.out, ws+" out.RawByte('}')")
  206. fmt.Fprintln(g.out, ws+"}")
  207. case reflect.Interface:
  208. if t.NumMethod() != 0 {
  209. return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t)
  210. }
  211. fmt.Fprintln(g.out, ws+"if m, ok := "+in+".(easyjson.Marshaler); ok {")
  212. fmt.Fprintln(g.out, ws+" m.MarshalEasyJSON(out)")
  213. fmt.Fprintln(g.out, ws+"} else if m, ok := "+in+".(json.Marshaler); ok {")
  214. fmt.Fprintln(g.out, ws+" out.Raw(m.MarshalJSON())")
  215. fmt.Fprintln(g.out, ws+"} else {")
  216. fmt.Fprintln(g.out, ws+" out.Raw(json.Marshal("+in+"))")
  217. fmt.Fprintln(g.out, ws+"}")
  218. default:
  219. return fmt.Errorf("don't know how to encode %v", t)
  220. }
  221. return nil
  222. }
  223. func (g *Generator) notEmptyCheck(t reflect.Type, v string) string {
  224. optionalIface := reflect.TypeOf((*easyjson.Optional)(nil)).Elem()
  225. if reflect.PtrTo(t).Implements(optionalIface) {
  226. return "(" + v + ").IsDefined()"
  227. }
  228. switch t.Kind() {
  229. case reflect.Slice, reflect.Map:
  230. return "len(" + v + ") != 0"
  231. case reflect.Interface, reflect.Ptr:
  232. return v + " != nil"
  233. case reflect.Bool:
  234. return v
  235. case reflect.String:
  236. return v + ` != ""`
  237. case reflect.Float32, reflect.Float64,
  238. reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  239. reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  240. return v + " != 0"
  241. default:
  242. // note: Array types don't have a useful empty value
  243. return "true"
  244. }
  245. }
  246. func (g *Generator) genStructFieldEncoder(t reflect.Type, f reflect.StructField) error {
  247. jsonName := g.fieldNamer.GetJSONFieldName(t, f)
  248. tags := parseFieldTags(f)
  249. if tags.omit {
  250. return nil
  251. }
  252. noOmitEmpty := (!tags.omitEmpty && !g.omitEmpty) || tags.noOmitEmpty
  253. if noOmitEmpty {
  254. fmt.Fprintln(g.out, " {")
  255. } else {
  256. fmt.Fprintln(g.out, " if", g.notEmptyCheck(f.Type, "in."+f.Name), "{")
  257. }
  258. fmt.Fprintf(g.out, " const prefix string = %q\n", ","+strconv.Quote(jsonName)+":")
  259. fmt.Fprintln(g.out, " if first {")
  260. fmt.Fprintln(g.out, " first = false")
  261. fmt.Fprintln(g.out, " out.RawString(prefix[1:])")
  262. fmt.Fprintln(g.out, " } else {")
  263. fmt.Fprintln(g.out, " out.RawString(prefix)")
  264. fmt.Fprintln(g.out, " }")
  265. if err := g.genTypeEncoder(f.Type, "in."+f.Name, tags, 2, !noOmitEmpty); err != nil {
  266. return err
  267. }
  268. fmt.Fprintln(g.out, " }")
  269. return nil
  270. }
  271. func (g *Generator) genEncoder(t reflect.Type) error {
  272. switch t.Kind() {
  273. case reflect.Slice, reflect.Array, reflect.Map:
  274. return g.genSliceArrayMapEncoder(t)
  275. default:
  276. return g.genStructEncoder(t)
  277. }
  278. }
  279. func (g *Generator) genSliceArrayMapEncoder(t reflect.Type) error {
  280. switch t.Kind() {
  281. case reflect.Slice, reflect.Array, reflect.Map:
  282. default:
  283. return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t)
  284. }
  285. fname := g.getEncoderName(t)
  286. typ := g.getType(t)
  287. fmt.Fprintln(g.out, "func "+fname+"(out *jwriter.Writer, in "+typ+") {")
  288. err := g.genTypeEncoderNoCheck(t, "in", fieldTags{}, 1, false)
  289. if err != nil {
  290. return err
  291. }
  292. fmt.Fprintln(g.out, "}")
  293. return nil
  294. }
  295. func (g *Generator) genStructEncoder(t reflect.Type) error {
  296. if t.Kind() != reflect.Struct {
  297. return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t)
  298. }
  299. fname := g.getEncoderName(t)
  300. typ := g.getType(t)
  301. fmt.Fprintln(g.out, "func "+fname+"(out *jwriter.Writer, in "+typ+") {")
  302. fmt.Fprintln(g.out, " out.RawByte('{')")
  303. fmt.Fprintln(g.out, " first := true")
  304. fmt.Fprintln(g.out, " _ = first")
  305. fs, err := getStructFields(t)
  306. if err != nil {
  307. return fmt.Errorf("cannot generate encoder for %v: %v", t, err)
  308. }
  309. for _, f := range fs {
  310. if err := g.genStructFieldEncoder(t, f); err != nil {
  311. return err
  312. }
  313. }
  314. fmt.Fprintln(g.out, " out.RawByte('}')")
  315. fmt.Fprintln(g.out, "}")
  316. return nil
  317. }
  318. func (g *Generator) genStructMarshaler(t reflect.Type) error {
  319. switch t.Kind() {
  320. case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
  321. default:
  322. return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t)
  323. }
  324. fname := g.getEncoderName(t)
  325. typ := g.getType(t)
  326. if !g.noStdMarshalers {
  327. fmt.Fprintln(g.out, "// MarshalJSON supports json.Marshaler interface")
  328. fmt.Fprintln(g.out, "func (v "+typ+") MarshalJSON() ([]byte, error) {")
  329. fmt.Fprintln(g.out, " w := jwriter.Writer{}")
  330. fmt.Fprintln(g.out, " "+fname+"(&w, v)")
  331. fmt.Fprintln(g.out, " return w.Buffer.BuildBytes(), w.Error")
  332. fmt.Fprintln(g.out, "}")
  333. }
  334. fmt.Fprintln(g.out, "// MarshalEasyJSON supports easyjson.Marshaler interface")
  335. fmt.Fprintln(g.out, "func (v "+typ+") MarshalEasyJSON(w *jwriter.Writer) {")
  336. fmt.Fprintln(g.out, " "+fname+"(w, v)")
  337. fmt.Fprintln(g.out, "}")
  338. return nil
  339. }