decoder.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. package gen
  2. import (
  3. "encoding"
  4. "encoding/json"
  5. "fmt"
  6. "reflect"
  7. "strings"
  8. "unicode"
  9. "github.com/mailru/easyjson"
  10. )
  11. // Target this byte size for initial slice allocation to reduce garbage collection.
  12. const minSliceBytes = 64
  13. func (g *Generator) getDecoderName(t reflect.Type) string {
  14. return g.functionName("decode", t)
  15. }
  16. var primitiveDecoders = map[reflect.Kind]string{
  17. reflect.String: "in.String()",
  18. reflect.Bool: "in.Bool()",
  19. reflect.Int: "in.Int()",
  20. reflect.Int8: "in.Int8()",
  21. reflect.Int16: "in.Int16()",
  22. reflect.Int32: "in.Int32()",
  23. reflect.Int64: "in.Int64()",
  24. reflect.Uint: "in.Uint()",
  25. reflect.Uint8: "in.Uint8()",
  26. reflect.Uint16: "in.Uint16()",
  27. reflect.Uint32: "in.Uint32()",
  28. reflect.Uint64: "in.Uint64()",
  29. reflect.Float32: "in.Float32()",
  30. reflect.Float64: "in.Float64()",
  31. }
  32. var primitiveStringDecoders = map[reflect.Kind]string{
  33. reflect.String: "in.String()",
  34. reflect.Int: "in.IntStr()",
  35. reflect.Int8: "in.Int8Str()",
  36. reflect.Int16: "in.Int16Str()",
  37. reflect.Int32: "in.Int32Str()",
  38. reflect.Int64: "in.Int64Str()",
  39. reflect.Uint: "in.UintStr()",
  40. reflect.Uint8: "in.Uint8Str()",
  41. reflect.Uint16: "in.Uint16Str()",
  42. reflect.Uint32: "in.Uint32Str()",
  43. reflect.Uint64: "in.Uint64Str()",
  44. reflect.Uintptr: "in.UintptrStr()",
  45. reflect.Float32: "in.Float32Str()",
  46. reflect.Float64: "in.Float64Str()",
  47. }
  48. var customDecoders = map[string]string{
  49. "json.Number": "in.JsonNumber()",
  50. }
  51. // genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t.
  52. func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error {
  53. ws := strings.Repeat(" ", indent)
  54. unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
  55. if reflect.PtrTo(t).Implements(unmarshalerIface) {
  56. fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
  57. return nil
  58. }
  59. unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
  60. if reflect.PtrTo(t).Implements(unmarshalerIface) {
  61. fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
  62. fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
  63. fmt.Fprintln(g.out, ws+"}")
  64. return nil
  65. }
  66. unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
  67. if reflect.PtrTo(t).Implements(unmarshalerIface) {
  68. fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
  69. fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
  70. fmt.Fprintln(g.out, ws+"}")
  71. return nil
  72. }
  73. err := g.genTypeDecoderNoCheck(t, out, tags, indent)
  74. return err
  75. }
  76. // returns true of the type t implements one of the custom unmarshaler interfaces
  77. func hasCustomUnmarshaler(t reflect.Type) bool {
  78. t = reflect.PtrTo(t)
  79. return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) ||
  80. t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) ||
  81. t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem())
  82. }
  83. // genTypeDecoderNoCheck generates decoding code for the type t.
  84. func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error {
  85. ws := strings.Repeat(" ", indent)
  86. // Check whether type is primitive, needs to be done after interface check.
  87. if dec := customDecoders[t.String()]; dec != "" {
  88. fmt.Fprintln(g.out, ws+out+" = "+dec)
  89. return nil
  90. } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
  91. fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
  92. return nil
  93. } else if dec := primitiveDecoders[t.Kind()]; dec != "" {
  94. fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
  95. return nil
  96. }
  97. switch t.Kind() {
  98. case reflect.Slice:
  99. tmpVar := g.uniqueVarName()
  100. elem := t.Elem()
  101. if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
  102. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  103. fmt.Fprintln(g.out, ws+" in.Skip()")
  104. fmt.Fprintln(g.out, ws+" "+out+" = nil")
  105. fmt.Fprintln(g.out, ws+"} else {")
  106. fmt.Fprintln(g.out, ws+" "+out+" = in.Bytes()")
  107. fmt.Fprintln(g.out, ws+"}")
  108. } else {
  109. capacity := minSliceBytes / elem.Size()
  110. if capacity == 0 {
  111. capacity = 1
  112. }
  113. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  114. fmt.Fprintln(g.out, ws+" in.Skip()")
  115. fmt.Fprintln(g.out, ws+" "+out+" = nil")
  116. fmt.Fprintln(g.out, ws+"} else {")
  117. fmt.Fprintln(g.out, ws+" in.Delim('[')")
  118. fmt.Fprintln(g.out, ws+" if "+out+" == nil {")
  119. fmt.Fprintln(g.out, ws+" if !in.IsDelim(']') {")
  120. fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")")
  121. fmt.Fprintln(g.out, ws+" } else {")
  122. fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"{}")
  123. fmt.Fprintln(g.out, ws+" }")
  124. fmt.Fprintln(g.out, ws+" } else { ")
  125. fmt.Fprintln(g.out, ws+" "+out+" = ("+out+")[:0]")
  126. fmt.Fprintln(g.out, ws+" }")
  127. fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {")
  128. fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem))
  129. if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
  130. return err
  131. }
  132. fmt.Fprintln(g.out, ws+" "+out+" = append("+out+", "+tmpVar+")")
  133. fmt.Fprintln(g.out, ws+" in.WantComma()")
  134. fmt.Fprintln(g.out, ws+" }")
  135. fmt.Fprintln(g.out, ws+" in.Delim(']')")
  136. fmt.Fprintln(g.out, ws+"}")
  137. }
  138. case reflect.Array:
  139. iterVar := g.uniqueVarName()
  140. elem := t.Elem()
  141. if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
  142. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  143. fmt.Fprintln(g.out, ws+" in.Skip()")
  144. fmt.Fprintln(g.out, ws+"} else {")
  145. fmt.Fprintln(g.out, ws+" copy("+out+"[:], in.Bytes())")
  146. fmt.Fprintln(g.out, ws+"}")
  147. } else {
  148. length := t.Len()
  149. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  150. fmt.Fprintln(g.out, ws+" in.Skip()")
  151. fmt.Fprintln(g.out, ws+"} else {")
  152. fmt.Fprintln(g.out, ws+" in.Delim('[')")
  153. fmt.Fprintln(g.out, ws+" "+iterVar+" := 0")
  154. fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {")
  155. fmt.Fprintln(g.out, ws+" if "+iterVar+" < "+fmt.Sprint(length)+" {")
  156. if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil {
  157. return err
  158. }
  159. fmt.Fprintln(g.out, ws+" "+iterVar+"++")
  160. fmt.Fprintln(g.out, ws+" } else {")
  161. fmt.Fprintln(g.out, ws+" in.SkipRecursive()")
  162. fmt.Fprintln(g.out, ws+" }")
  163. fmt.Fprintln(g.out, ws+" in.WantComma()")
  164. fmt.Fprintln(g.out, ws+" }")
  165. fmt.Fprintln(g.out, ws+" in.Delim(']')")
  166. fmt.Fprintln(g.out, ws+"}")
  167. }
  168. case reflect.Struct:
  169. dec := g.getDecoderName(t)
  170. g.addType(t)
  171. fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")")
  172. case reflect.Ptr:
  173. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  174. fmt.Fprintln(g.out, ws+" in.Skip()")
  175. fmt.Fprintln(g.out, ws+" "+out+" = nil")
  176. fmt.Fprintln(g.out, ws+"} else {")
  177. fmt.Fprintln(g.out, ws+" if "+out+" == nil {")
  178. fmt.Fprintln(g.out, ws+" "+out+" = new("+g.getType(t.Elem())+")")
  179. fmt.Fprintln(g.out, ws+" }")
  180. if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil {
  181. return err
  182. }
  183. fmt.Fprintln(g.out, ws+"}")
  184. case reflect.Map:
  185. key := t.Key()
  186. keyDec, ok := primitiveStringDecoders[key.Kind()]
  187. if !ok && !hasCustomUnmarshaler(key) {
  188. return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key)
  189. } // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type
  190. elem := t.Elem()
  191. tmpVar := g.uniqueVarName()
  192. fmt.Fprintln(g.out, ws+"if in.IsNull() {")
  193. fmt.Fprintln(g.out, ws+" in.Skip()")
  194. fmt.Fprintln(g.out, ws+"} else {")
  195. fmt.Fprintln(g.out, ws+" in.Delim('{')")
  196. fmt.Fprintln(g.out, ws+" if !in.IsDelim('}') {")
  197. fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+")")
  198. fmt.Fprintln(g.out, ws+" } else {")
  199. fmt.Fprintln(g.out, ws+" "+out+" = nil")
  200. fmt.Fprintln(g.out, ws+" }")
  201. fmt.Fprintln(g.out, ws+" for !in.IsDelim('}') {")
  202. if keyDec != "" {
  203. fmt.Fprintln(g.out, ws+" key := "+g.getType(key)+"("+keyDec+")")
  204. } else {
  205. fmt.Fprintln(g.out, ws+" var key "+g.getType(key))
  206. if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil {
  207. return err
  208. }
  209. }
  210. fmt.Fprintln(g.out, ws+" in.WantColon()")
  211. fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem))
  212. if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
  213. return err
  214. }
  215. fmt.Fprintln(g.out, ws+" ("+out+")[key] = "+tmpVar)
  216. fmt.Fprintln(g.out, ws+" in.WantComma()")
  217. fmt.Fprintln(g.out, ws+" }")
  218. fmt.Fprintln(g.out, ws+" in.Delim('}')")
  219. fmt.Fprintln(g.out, ws+"}")
  220. case reflect.Interface:
  221. if t.NumMethod() != 0 {
  222. return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t)
  223. }
  224. fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {")
  225. fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)")
  226. fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {")
  227. fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())")
  228. fmt.Fprintln(g.out, ws+"} else {")
  229. fmt.Fprintln(g.out, ws+" "+out+" = in.Interface()")
  230. fmt.Fprintln(g.out, ws+"}")
  231. default:
  232. return fmt.Errorf("don't know how to decode %v", t)
  233. }
  234. return nil
  235. }
  236. func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error {
  237. jsonName := g.fieldNamer.GetJSONFieldName(t, f)
  238. tags := parseFieldTags(f)
  239. if tags.omit {
  240. return nil
  241. }
  242. fmt.Fprintf(g.out, " case %q:\n", jsonName)
  243. if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil {
  244. return err
  245. }
  246. if tags.required {
  247. fmt.Fprintf(g.out, "%sSet = true\n", f.Name)
  248. }
  249. return nil
  250. }
  251. func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) {
  252. tags := parseFieldTags(f)
  253. if !tags.required {
  254. return
  255. }
  256. fmt.Fprintf(g.out, "var %sSet bool\n", f.Name)
  257. }
  258. func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) {
  259. jsonName := g.fieldNamer.GetJSONFieldName(t, f)
  260. tags := parseFieldTags(f)
  261. if !tags.required {
  262. return
  263. }
  264. g.imports["fmt"] = "fmt"
  265. fmt.Fprintf(g.out, "if !%sSet {\n", f.Name)
  266. fmt.Fprintf(g.out, " in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName)
  267. fmt.Fprintf(g.out, "}\n")
  268. }
  269. func mergeStructFields(fields1, fields2 []reflect.StructField) (fields []reflect.StructField) {
  270. used := map[string]bool{}
  271. for _, f := range fields2 {
  272. used[f.Name] = true
  273. fields = append(fields, f)
  274. }
  275. for _, f := range fields1 {
  276. if !used[f.Name] {
  277. fields = append(fields, f)
  278. }
  279. }
  280. return
  281. }
  282. func getStructFields(t reflect.Type) ([]reflect.StructField, error) {
  283. if t.Kind() != reflect.Struct {
  284. return nil, fmt.Errorf("got %v; expected a struct", t)
  285. }
  286. var efields []reflect.StructField
  287. for i := 0; i < t.NumField(); i++ {
  288. f := t.Field(i)
  289. if !f.Anonymous {
  290. continue
  291. }
  292. t1 := f.Type
  293. if t1.Kind() == reflect.Ptr {
  294. t1 = t1.Elem()
  295. }
  296. fs, err := getStructFields(t1)
  297. if err != nil {
  298. return nil, fmt.Errorf("error processing embedded field: %v", err)
  299. }
  300. efields = mergeStructFields(efields, fs)
  301. }
  302. var fields []reflect.StructField
  303. for i := 0; i < t.NumField(); i++ {
  304. f := t.Field(i)
  305. if f.Anonymous {
  306. continue
  307. }
  308. c := []rune(f.Name)[0]
  309. if unicode.IsUpper(c) {
  310. fields = append(fields, f)
  311. }
  312. }
  313. return mergeStructFields(efields, fields), nil
  314. }
  315. func (g *Generator) genDecoder(t reflect.Type) error {
  316. switch t.Kind() {
  317. case reflect.Slice, reflect.Array, reflect.Map:
  318. return g.genSliceArrayDecoder(t)
  319. default:
  320. return g.genStructDecoder(t)
  321. }
  322. }
  323. func (g *Generator) genSliceArrayDecoder(t reflect.Type) error {
  324. switch t.Kind() {
  325. case reflect.Slice, reflect.Array, reflect.Map:
  326. default:
  327. return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t)
  328. }
  329. fname := g.getDecoderName(t)
  330. typ := g.getType(t)
  331. fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
  332. fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
  333. err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1)
  334. if err != nil {
  335. return err
  336. }
  337. fmt.Fprintln(g.out, " if isTopLevel {")
  338. fmt.Fprintln(g.out, " in.Consumed()")
  339. fmt.Fprintln(g.out, " }")
  340. fmt.Fprintln(g.out, "}")
  341. return nil
  342. }
  343. func (g *Generator) genStructDecoder(t reflect.Type) error {
  344. if t.Kind() != reflect.Struct {
  345. return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t)
  346. }
  347. fname := g.getDecoderName(t)
  348. typ := g.getType(t)
  349. fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
  350. fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
  351. fmt.Fprintln(g.out, " if in.IsNull() {")
  352. fmt.Fprintln(g.out, " if isTopLevel {")
  353. fmt.Fprintln(g.out, " in.Consumed()")
  354. fmt.Fprintln(g.out, " }")
  355. fmt.Fprintln(g.out, " in.Skip()")
  356. fmt.Fprintln(g.out, " return")
  357. fmt.Fprintln(g.out, " }")
  358. // Init embedded pointer fields.
  359. for i := 0; i < t.NumField(); i++ {
  360. f := t.Field(i)
  361. if !f.Anonymous || f.Type.Kind() != reflect.Ptr {
  362. continue
  363. }
  364. fmt.Fprintln(g.out, " out."+f.Name+" = new("+g.getType(f.Type.Elem())+")")
  365. }
  366. fs, err := getStructFields(t)
  367. if err != nil {
  368. return fmt.Errorf("cannot generate decoder for %v: %v", t, err)
  369. }
  370. for _, f := range fs {
  371. g.genRequiredFieldSet(t, f)
  372. }
  373. fmt.Fprintln(g.out, " in.Delim('{')")
  374. fmt.Fprintln(g.out, " for !in.IsDelim('}') {")
  375. fmt.Fprintln(g.out, " key := in.UnsafeString()")
  376. fmt.Fprintln(g.out, " in.WantColon()")
  377. fmt.Fprintln(g.out, " if in.IsNull() {")
  378. fmt.Fprintln(g.out, " in.Skip()")
  379. fmt.Fprintln(g.out, " in.WantComma()")
  380. fmt.Fprintln(g.out, " continue")
  381. fmt.Fprintln(g.out, " }")
  382. fmt.Fprintln(g.out, " switch key {")
  383. for _, f := range fs {
  384. if err := g.genStructFieldDecoder(t, f); err != nil {
  385. return err
  386. }
  387. }
  388. fmt.Fprintln(g.out, " default:")
  389. if g.disallowUnknownFields {
  390. fmt.Fprintln(g.out, ` in.AddError(&jlexer.LexerError{
  391. Offset: in.GetPos(),
  392. Reason: "unknown field",
  393. Data: key,
  394. })`)
  395. } else {
  396. fmt.Fprintln(g.out, " in.SkipRecursive()")
  397. }
  398. fmt.Fprintln(g.out, " }")
  399. fmt.Fprintln(g.out, " in.WantComma()")
  400. fmt.Fprintln(g.out, " }")
  401. fmt.Fprintln(g.out, " in.Delim('}')")
  402. fmt.Fprintln(g.out, " if isTopLevel {")
  403. fmt.Fprintln(g.out, " in.Consumed()")
  404. fmt.Fprintln(g.out, " }")
  405. for _, f := range fs {
  406. g.genRequiredFieldCheck(t, f)
  407. }
  408. fmt.Fprintln(g.out, "}")
  409. return nil
  410. }
  411. func (g *Generator) genStructUnmarshaler(t reflect.Type) error {
  412. switch t.Kind() {
  413. case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
  414. default:
  415. return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t)
  416. }
  417. fname := g.getDecoderName(t)
  418. typ := g.getType(t)
  419. if !g.noStdMarshalers {
  420. fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface")
  421. fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {")
  422. fmt.Fprintln(g.out, " r := jlexer.Lexer{Data: data}")
  423. fmt.Fprintln(g.out, " "+fname+"(&r, v)")
  424. fmt.Fprintln(g.out, " return r.Error()")
  425. fmt.Fprintln(g.out, "}")
  426. }
  427. fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface")
  428. fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {")
  429. fmt.Fprintln(g.out, " "+fname+"(l, v)")
  430. fmt.Fprintln(g.out, "}")
  431. return nil
  432. }