123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515 |
- package gen
- import (
- "encoding"
- "encoding/json"
- "fmt"
- "reflect"
- "strings"
- "unicode"
- "github.com/mailru/easyjson"
- )
// minSliceBytes is the byte size targeted by the initial allocation of a
// decoded slice; sizing the first allocation this way reduces garbage
// collection.
const (
	minSliceBytes = 64
)
- func (g *Generator) getDecoderName(t reflect.Type) string {
- return g.functionName("decode", t)
- }
// Lookup tables holding the lexer expression that is emitted into generated
// code to read a value of a given type.
var (
	// primitiveDecoders maps a primitive kind to the lexer call that reads a
	// plain JSON value of that kind.
	primitiveDecoders = map[reflect.Kind]string{
		reflect.String:  "in.String()",
		reflect.Bool:    "in.Bool()",
		reflect.Int:     "in.Int()",
		reflect.Int8:    "in.Int8()",
		reflect.Int16:   "in.Int16()",
		reflect.Int32:   "in.Int32()",
		reflect.Int64:   "in.Int64()",
		reflect.Uint:    "in.Uint()",
		reflect.Uint8:   "in.Uint8()",
		reflect.Uint16:  "in.Uint16()",
		reflect.Uint32:  "in.Uint32()",
		reflect.Uint64:  "in.Uint64()",
		reflect.Float32: "in.Float32()",
		reflect.Float64: "in.Float64()",
	}

	// primitiveStringDecoders maps a primitive kind to the lexer call used
	// when the field is tagged as a string and for map keys.
	primitiveStringDecoders = map[reflect.Kind]string{
		reflect.String:  "in.String()",
		reflect.Int:     "in.IntStr()",
		reflect.Int8:    "in.Int8Str()",
		reflect.Int16:   "in.Int16Str()",
		reflect.Int32:   "in.Int32Str()",
		reflect.Int64:   "in.Int64Str()",
		reflect.Uint:    "in.UintStr()",
		reflect.Uint8:   "in.Uint8Str()",
		reflect.Uint16:  "in.Uint16Str()",
		reflect.Uint32:  "in.Uint32Str()",
		reflect.Uint64:  "in.Uint64Str()",
		reflect.Uintptr: "in.UintptrStr()",
		reflect.Float32: "in.Float32Str()",
		reflect.Float64: "in.Float64Str()",
	}

	// customDecoders maps a type's String() form to a dedicated lexer call;
	// it is consulted before the kind-based tables.
	customDecoders = map[string]string{
		"json.Number": "in.JsonNumber()",
	}
)
- // genTypeDecoder generates decoding code for the type t, but uses unmarshaler interface if implemented by t.
- func (g *Generator) genTypeDecoder(t reflect.Type, out string, tags fieldTags, indent int) error {
- ws := strings.Repeat(" ", indent)
- unmarshalerIface := reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()
- if reflect.PtrTo(t).Implements(unmarshalerIface) {
- fmt.Fprintln(g.out, ws+"("+out+").UnmarshalEasyJSON(in)")
- return nil
- }
- unmarshalerIface = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
- if reflect.PtrTo(t).Implements(unmarshalerIface) {
- fmt.Fprintln(g.out, ws+"if data := in.Raw(); in.Ok() {")
- fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalJSON(data) )")
- fmt.Fprintln(g.out, ws+"}")
- return nil
- }
- unmarshalerIface = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
- if reflect.PtrTo(t).Implements(unmarshalerIface) {
- fmt.Fprintln(g.out, ws+"if data := in.UnsafeBytes(); in.Ok() {")
- fmt.Fprintln(g.out, ws+" in.AddError( ("+out+").UnmarshalText(data) )")
- fmt.Fprintln(g.out, ws+"}")
- return nil
- }
- err := g.genTypeDecoderNoCheck(t, out, tags, indent)
- return err
- }
- // returns true of the type t implements one of the custom unmarshaler interfaces
- func hasCustomUnmarshaler(t reflect.Type) bool {
- t = reflect.PtrTo(t)
- return t.Implements(reflect.TypeOf((*easyjson.Unmarshaler)(nil)).Elem()) ||
- t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) ||
- t.Implements(reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem())
- }
- // genTypeDecoderNoCheck generates decoding code for the type t.
- func (g *Generator) genTypeDecoderNoCheck(t reflect.Type, out string, tags fieldTags, indent int) error {
- ws := strings.Repeat(" ", indent)
- // Check whether type is primitive, needs to be done after interface check.
- if dec := customDecoders[t.String()]; dec != "" {
- fmt.Fprintln(g.out, ws+out+" = "+dec)
- return nil
- } else if dec := primitiveStringDecoders[t.Kind()]; dec != "" && tags.asString {
- fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
- return nil
- } else if dec := primitiveDecoders[t.Kind()]; dec != "" {
- fmt.Fprintln(g.out, ws+out+" = "+g.getType(t)+"("+dec+")")
- return nil
- }
- switch t.Kind() {
- case reflect.Slice:
- tmpVar := g.uniqueVarName()
- elem := t.Elem()
- if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+" "+out+" = nil")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" "+out+" = in.Bytes()")
- fmt.Fprintln(g.out, ws+"}")
- } else {
- capacity := minSliceBytes / elem.Size()
- if capacity == 0 {
- capacity = 1
- }
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+" "+out+" = nil")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" in.Delim('[')")
- fmt.Fprintln(g.out, ws+" if "+out+" == nil {")
- fmt.Fprintln(g.out, ws+" if !in.IsDelim(']') {")
- fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+", 0, "+fmt.Sprint(capacity)+")")
- fmt.Fprintln(g.out, ws+" } else {")
- fmt.Fprintln(g.out, ws+" "+out+" = "+g.getType(t)+"{}")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" } else { ")
- fmt.Fprintln(g.out, ws+" "+out+" = ("+out+")[:0]")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {")
- fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem))
- if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
- return err
- }
- fmt.Fprintln(g.out, ws+" "+out+" = append("+out+", "+tmpVar+")")
- fmt.Fprintln(g.out, ws+" in.WantComma()")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" in.Delim(']')")
- fmt.Fprintln(g.out, ws+"}")
- }
- case reflect.Array:
- iterVar := g.uniqueVarName()
- elem := t.Elem()
- if elem.Kind() == reflect.Uint8 && elem.Name() == "uint8" {
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" copy("+out+"[:], in.Bytes())")
- fmt.Fprintln(g.out, ws+"}")
- } else {
- length := t.Len()
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" in.Delim('[')")
- fmt.Fprintln(g.out, ws+" "+iterVar+" := 0")
- fmt.Fprintln(g.out, ws+" for !in.IsDelim(']') {")
- fmt.Fprintln(g.out, ws+" if "+iterVar+" < "+fmt.Sprint(length)+" {")
- if err := g.genTypeDecoder(elem, "("+out+")["+iterVar+"]", tags, indent+3); err != nil {
- return err
- }
- fmt.Fprintln(g.out, ws+" "+iterVar+"++")
- fmt.Fprintln(g.out, ws+" } else {")
- fmt.Fprintln(g.out, ws+" in.SkipRecursive()")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" in.WantComma()")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" in.Delim(']')")
- fmt.Fprintln(g.out, ws+"}")
- }
- case reflect.Struct:
- dec := g.getDecoderName(t)
- g.addType(t)
- fmt.Fprintln(g.out, ws+dec+"(in, &"+out+")")
- case reflect.Ptr:
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+" "+out+" = nil")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" if "+out+" == nil {")
- fmt.Fprintln(g.out, ws+" "+out+" = new("+g.getType(t.Elem())+")")
- fmt.Fprintln(g.out, ws+" }")
- if err := g.genTypeDecoder(t.Elem(), "*"+out, tags, indent+1); err != nil {
- return err
- }
- fmt.Fprintln(g.out, ws+"}")
- case reflect.Map:
- key := t.Key()
- keyDec, ok := primitiveStringDecoders[key.Kind()]
- if !ok && !hasCustomUnmarshaler(key) {
- return fmt.Errorf("map type %v not supported: only string and integer keys and types implementing json.Unmarshaler are allowed", key)
- } // else assume the caller knows what they are doing and that the custom unmarshaler performs the translation from string or integer keys to the key type
- elem := t.Elem()
- tmpVar := g.uniqueVarName()
- fmt.Fprintln(g.out, ws+"if in.IsNull() {")
- fmt.Fprintln(g.out, ws+" in.Skip()")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" in.Delim('{')")
- fmt.Fprintln(g.out, ws+" if !in.IsDelim('}') {")
- fmt.Fprintln(g.out, ws+" "+out+" = make("+g.getType(t)+")")
- fmt.Fprintln(g.out, ws+" } else {")
- fmt.Fprintln(g.out, ws+" "+out+" = nil")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" for !in.IsDelim('}') {")
- if keyDec != "" {
- fmt.Fprintln(g.out, ws+" key := "+g.getType(key)+"("+keyDec+")")
- } else {
- fmt.Fprintln(g.out, ws+" var key "+g.getType(key))
- if err := g.genTypeDecoder(key, "key", tags, indent+2); err != nil {
- return err
- }
- }
- fmt.Fprintln(g.out, ws+" in.WantColon()")
- fmt.Fprintln(g.out, ws+" var "+tmpVar+" "+g.getType(elem))
- if err := g.genTypeDecoder(elem, tmpVar, tags, indent+2); err != nil {
- return err
- }
- fmt.Fprintln(g.out, ws+" ("+out+")[key] = "+tmpVar)
- fmt.Fprintln(g.out, ws+" in.WantComma()")
- fmt.Fprintln(g.out, ws+" }")
- fmt.Fprintln(g.out, ws+" in.Delim('}')")
- fmt.Fprintln(g.out, ws+"}")
- case reflect.Interface:
- if t.NumMethod() != 0 {
- return fmt.Errorf("interface type %v not supported: only interface{} is allowed", t)
- }
- fmt.Fprintln(g.out, ws+"if m, ok := "+out+".(easyjson.Unmarshaler); ok {")
- fmt.Fprintln(g.out, ws+"m.UnmarshalEasyJSON(in)")
- fmt.Fprintln(g.out, ws+"} else if m, ok := "+out+".(json.Unmarshaler); ok {")
- fmt.Fprintln(g.out, ws+"_ = m.UnmarshalJSON(in.Raw())")
- fmt.Fprintln(g.out, ws+"} else {")
- fmt.Fprintln(g.out, ws+" "+out+" = in.Interface()")
- fmt.Fprintln(g.out, ws+"}")
- default:
- return fmt.Errorf("don't know how to decode %v", t)
- }
- return nil
- }
- func (g *Generator) genStructFieldDecoder(t reflect.Type, f reflect.StructField) error {
- jsonName := g.fieldNamer.GetJSONFieldName(t, f)
- tags := parseFieldTags(f)
- if tags.omit {
- return nil
- }
- fmt.Fprintf(g.out, " case %q:\n", jsonName)
- if err := g.genTypeDecoder(f.Type, "out."+f.Name, tags, 3); err != nil {
- return err
- }
- if tags.required {
- fmt.Fprintf(g.out, "%sSet = true\n", f.Name)
- }
- return nil
- }
- func (g *Generator) genRequiredFieldSet(t reflect.Type, f reflect.StructField) {
- tags := parseFieldTags(f)
- if !tags.required {
- return
- }
- fmt.Fprintf(g.out, "var %sSet bool\n", f.Name)
- }
- func (g *Generator) genRequiredFieldCheck(t reflect.Type, f reflect.StructField) {
- jsonName := g.fieldNamer.GetJSONFieldName(t, f)
- tags := parseFieldTags(f)
- if !tags.required {
- return
- }
- g.imports["fmt"] = "fmt"
- fmt.Fprintf(g.out, "if !%sSet {\n", f.Name)
- fmt.Fprintf(g.out, " in.AddError(fmt.Errorf(\"key '%s' is required\"))\n", jsonName)
- fmt.Fprintf(g.out, "}\n")
- }
// mergeStructFields returns the fields of fields2 followed by those fields of
// fields1 whose Name does not occur in fields2, so entries of fields2 take
// precedence on a name clash. The inputs are not modified; the result is nil
// when both inputs are empty.
func mergeStructFields(fields1, fields2 []reflect.StructField) (fields []reflect.StructField) {
	taken := make(map[string]struct{}, len(fields2))
	for i := range fields2 {
		taken[fields2[i].Name] = struct{}{}
	}

	fields = append(fields, fields2...)
	for i := range fields1 {
		if _, clash := taken[fields1[i].Name]; clash {
			continue
		}
		fields = append(fields, fields1[i])
	}
	return fields
}
- func getStructFields(t reflect.Type) ([]reflect.StructField, error) {
- if t.Kind() != reflect.Struct {
- return nil, fmt.Errorf("got %v; expected a struct", t)
- }
- var efields []reflect.StructField
- for i := 0; i < t.NumField(); i++ {
- f := t.Field(i)
- if !f.Anonymous {
- continue
- }
- t1 := f.Type
- if t1.Kind() == reflect.Ptr {
- t1 = t1.Elem()
- }
- fs, err := getStructFields(t1)
- if err != nil {
- return nil, fmt.Errorf("error processing embedded field: %v", err)
- }
- efields = mergeStructFields(efields, fs)
- }
- var fields []reflect.StructField
- for i := 0; i < t.NumField(); i++ {
- f := t.Field(i)
- if f.Anonymous {
- continue
- }
- c := []rune(f.Name)[0]
- if unicode.IsUpper(c) {
- fields = append(fields, f)
- }
- }
- return mergeStructFields(efields, fields), nil
- }
- func (g *Generator) genDecoder(t reflect.Type) error {
- switch t.Kind() {
- case reflect.Slice, reflect.Array, reflect.Map:
- return g.genSliceArrayDecoder(t)
- default:
- return g.genStructDecoder(t)
- }
- }
- func (g *Generator) genSliceArrayDecoder(t reflect.Type) error {
- switch t.Kind() {
- case reflect.Slice, reflect.Array, reflect.Map:
- default:
- return fmt.Errorf("cannot generate encoder/decoder for %v, not a slice/array/map type", t)
- }
- fname := g.getDecoderName(t)
- typ := g.getType(t)
- fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
- fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
- err := g.genTypeDecoderNoCheck(t, "*out", fieldTags{}, 1)
- if err != nil {
- return err
- }
- fmt.Fprintln(g.out, " if isTopLevel {")
- fmt.Fprintln(g.out, " in.Consumed()")
- fmt.Fprintln(g.out, " }")
- fmt.Fprintln(g.out, "}")
- return nil
- }
- func (g *Generator) genStructDecoder(t reflect.Type) error {
- if t.Kind() != reflect.Struct {
- return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct type", t)
- }
- fname := g.getDecoderName(t)
- typ := g.getType(t)
- fmt.Fprintln(g.out, "func "+fname+"(in *jlexer.Lexer, out *"+typ+") {")
- fmt.Fprintln(g.out, " isTopLevel := in.IsStart()")
- fmt.Fprintln(g.out, " if in.IsNull() {")
- fmt.Fprintln(g.out, " if isTopLevel {")
- fmt.Fprintln(g.out, " in.Consumed()")
- fmt.Fprintln(g.out, " }")
- fmt.Fprintln(g.out, " in.Skip()")
- fmt.Fprintln(g.out, " return")
- fmt.Fprintln(g.out, " }")
- // Init embedded pointer fields.
- for i := 0; i < t.NumField(); i++ {
- f := t.Field(i)
- if !f.Anonymous || f.Type.Kind() != reflect.Ptr {
- continue
- }
- fmt.Fprintln(g.out, " out."+f.Name+" = new("+g.getType(f.Type.Elem())+")")
- }
- fs, err := getStructFields(t)
- if err != nil {
- return fmt.Errorf("cannot generate decoder for %v: %v", t, err)
- }
- for _, f := range fs {
- g.genRequiredFieldSet(t, f)
- }
- fmt.Fprintln(g.out, " in.Delim('{')")
- fmt.Fprintln(g.out, " for !in.IsDelim('}') {")
- fmt.Fprintln(g.out, " key := in.UnsafeString()")
- fmt.Fprintln(g.out, " in.WantColon()")
- fmt.Fprintln(g.out, " if in.IsNull() {")
- fmt.Fprintln(g.out, " in.Skip()")
- fmt.Fprintln(g.out, " in.WantComma()")
- fmt.Fprintln(g.out, " continue")
- fmt.Fprintln(g.out, " }")
- fmt.Fprintln(g.out, " switch key {")
- for _, f := range fs {
- if err := g.genStructFieldDecoder(t, f); err != nil {
- return err
- }
- }
- fmt.Fprintln(g.out, " default:")
- if g.disallowUnknownFields {
- fmt.Fprintln(g.out, ` in.AddError(&jlexer.LexerError{
- Offset: in.GetPos(),
- Reason: "unknown field",
- Data: key,
- })`)
- } else {
- fmt.Fprintln(g.out, " in.SkipRecursive()")
- }
- fmt.Fprintln(g.out, " }")
- fmt.Fprintln(g.out, " in.WantComma()")
- fmt.Fprintln(g.out, " }")
- fmt.Fprintln(g.out, " in.Delim('}')")
- fmt.Fprintln(g.out, " if isTopLevel {")
- fmt.Fprintln(g.out, " in.Consumed()")
- fmt.Fprintln(g.out, " }")
- for _, f := range fs {
- g.genRequiredFieldCheck(t, f)
- }
- fmt.Fprintln(g.out, "}")
- return nil
- }
- func (g *Generator) genStructUnmarshaler(t reflect.Type) error {
- switch t.Kind() {
- case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct:
- default:
- return fmt.Errorf("cannot generate encoder/decoder for %v, not a struct/slice/array/map type", t)
- }
- fname := g.getDecoderName(t)
- typ := g.getType(t)
- if !g.noStdMarshalers {
- fmt.Fprintln(g.out, "// UnmarshalJSON supports json.Unmarshaler interface")
- fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalJSON(data []byte) error {")
- fmt.Fprintln(g.out, " r := jlexer.Lexer{Data: data}")
- fmt.Fprintln(g.out, " "+fname+"(&r, v)")
- fmt.Fprintln(g.out, " return r.Error()")
- fmt.Fprintln(g.out, "}")
- }
- fmt.Fprintln(g.out, "// UnmarshalEasyJSON supports easyjson.Unmarshaler interface")
- fmt.Fprintln(g.out, "func (v *"+typ+") UnmarshalEasyJSON(l *jlexer.Lexer) {")
- fmt.Fprintln(g.out, " "+fname+"(l, v)")
- fmt.Fprintln(g.out, "}")
- return nil
- }
|