123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422 |
- package dsn
- import (
- "encoding"
- "net/url"
- "reflect"
- "runtime"
- "strconv"
- "strings"
- )
// _tagID is the struct-tag key this package reads from each field.
const _tagID = "dsn"

// _queryPrefix marks a tag name as referring to a query-string parameter;
// it is stripped from the name before the parameter is looked up.
const _queryPrefix = "query."
// InvalidBindError reports that the destination handed to the binder cannot
// be written to: the destination must be a non-nil pointer.
type InvalidBindError struct {
	// Type is the dynamic type of the rejected destination; nil when the
	// destination was an untyped nil.
	Type reflect.Type
}

// Error implements the error interface. The message distinguishes an untyped
// nil, a non-pointer value and a nil pointer.
func (e *InvalidBindError) Error() string {
	switch {
	case e.Type == nil:
		return "Bind(nil)"
	case e.Type.Kind() != reflect.Ptr:
		return "Bind(non-pointer " + e.Type.String() + ")"
	default:
		return "Bind(nil " + e.Type.String() + ")"
	}
}
// BindTypeError reports a textual value that could not be converted into
// the Go type of the field it was destined for.
type BindTypeError struct {
	Value string       // the offending text
	Type  reflect.Type // the Go type the text could not be decoded into
}

// Error implements the error interface.
func (e *BindTypeError) Error() string {
	target := e.Type.String()
	return "cannot decode " + e.Value + " into Go value of type " + target
}
- type assignFunc func(v reflect.Value, to tagOpt) error
- func stringsAssignFunc(val string) assignFunc {
- return func(v reflect.Value, to tagOpt) error {
- if v.Kind() != reflect.String || !v.CanSet() {
- return &BindTypeError{Value: "string", Type: v.Type()}
- }
- if val == "" {
- v.SetString(to.Default)
- } else {
- v.SetString(val)
- }
- return nil
- }
- }
- // bindQuery parses url.Values and stores the result in the value pointed to by v.
- // if v is nil or not a pointer, bindQuery returns an InvalidDecodeError
- func bindQuery(query url.Values, v interface{}, assignFuncs map[string]assignFunc) (url.Values, error) {
- if assignFuncs == nil {
- assignFuncs = make(map[string]assignFunc)
- }
- d := decodeState{
- data: query,
- used: make(map[string]bool),
- assignFuncs: assignFuncs,
- }
- err := d.decode(v)
- ret := d.unused()
- return ret, err
- }
// tagOpt is the parsed form of a struct tag value.
type tagOpt struct {
	Name    string // text before the first comma
	Default string // text after the first comma; empty when there is none
}

// parseTag splits tag at its first comma into a name and a default value.
// Later commas stay part of the default. A tag without a comma yields only
// a name.
func parseTag(tag string) tagOpt {
	comma := strings.Index(tag, ",")
	if comma < 0 {
		return tagOpt{Name: tag}
	}
	return tagOpt{Name: tag[:comma], Default: tag[comma+1:]}
}
- type decodeState struct {
- data url.Values
- used map[string]bool
- assignFuncs map[string]assignFunc
- }
- func (d *decodeState) unused() url.Values {
- ret := make(url.Values)
- for k, v := range d.data {
- if !d.used[k] {
- ret[k] = v
- }
- }
- return ret
- }
- func (d *decodeState) decode(v interface{}) (err error) {
- defer func() {
- if r := recover(); r != nil {
- if _, ok := r.(runtime.Error); ok {
- panic(r)
- }
- err = r.(error)
- }
- }()
- rv := reflect.ValueOf(v)
- if rv.Kind() != reflect.Ptr || rv.IsNil() {
- return &InvalidBindError{reflect.TypeOf(v)}
- }
- return d.root(rv)
- }
- func (d *decodeState) root(v reflect.Value) error {
- var tu encoding.TextUnmarshaler
- tu, v = d.indirect(v)
- if tu != nil {
- return tu.UnmarshalText([]byte(d.data.Encode()))
- }
- // TODO support map, slice as root
- if v.Kind() != reflect.Struct {
- return &BindTypeError{Value: d.data.Encode(), Type: v.Type()}
- }
- tv := v.Type()
- for i := 0; i < tv.NumField(); i++ {
- fv := v.Field(i)
- field := tv.Field(i)
- to := parseTag(field.Tag.Get(_tagID))
- if to.Name == "-" {
- continue
- }
- if af, ok := d.assignFuncs[to.Name]; ok {
- if err := af(fv, tagOpt{}); err != nil {
- return err
- }
- continue
- }
- if !strings.HasPrefix(to.Name, _queryPrefix) {
- continue
- }
- to.Name = to.Name[len(_queryPrefix):]
- if err := d.value(fv, "", to); err != nil {
- return err
- }
- }
- return nil
- }
- func combinekey(prefix string, to tagOpt) string {
- key := to.Name
- if prefix != "" {
- key = prefix + "." + key
- }
- return key
- }
- func (d *decodeState) value(v reflect.Value, prefix string, to tagOpt) (err error) {
- key := combinekey(prefix, to)
- d.used[key] = true
- var tu encoding.TextUnmarshaler
- tu, v = d.indirect(v)
- if tu != nil {
- if val, ok := d.data[key]; ok {
- return tu.UnmarshalText([]byte(val[0]))
- }
- if to.Default != "" {
- return tu.UnmarshalText([]byte(to.Default))
- }
- return
- }
- switch v.Kind() {
- case reflect.Bool:
- err = d.valueBool(v, prefix, to)
- case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
- err = d.valueInt64(v, prefix, to)
- case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
- err = d.valueUint64(v, prefix, to)
- case reflect.Float32, reflect.Float64:
- err = d.valueFloat64(v, prefix, to)
- case reflect.String:
- err = d.valueString(v, prefix, to)
- case reflect.Slice:
- err = d.valueSlice(v, prefix, to)
- case reflect.Struct:
- err = d.valueStruct(v, prefix, to)
- case reflect.Ptr:
- if !d.hasKey(combinekey(prefix, to)) {
- break
- }
- if !v.CanSet() {
- break
- }
- nv := reflect.New(v.Type().Elem())
- v.Set(nv)
- err = d.value(nv, prefix, to)
- }
- return
- }
- func (d *decodeState) hasKey(key string) bool {
- for k := range d.data {
- if strings.HasPrefix(k, key+".") || k == key {
- return true
- }
- }
- return false
- }
- func (d *decodeState) valueBool(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- val := d.data.Get(key)
- if val == "" {
- if to.Default == "" {
- return nil
- }
- val = to.Default
- }
- return d.setBool(v, val)
- }
- func (d *decodeState) setBool(v reflect.Value, val string) error {
- bval, err := strconv.ParseBool(val)
- if err != nil {
- return &BindTypeError{Value: val, Type: v.Type()}
- }
- v.SetBool(bval)
- return nil
- }
- func (d *decodeState) valueInt64(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- val := d.data.Get(key)
- if val == "" {
- if to.Default == "" {
- return nil
- }
- val = to.Default
- }
- return d.setInt64(v, val)
- }
- func (d *decodeState) setInt64(v reflect.Value, val string) error {
- ival, err := strconv.ParseInt(val, 10, 64)
- if err != nil {
- return &BindTypeError{Value: val, Type: v.Type()}
- }
- v.SetInt(ival)
- return nil
- }
- func (d *decodeState) valueUint64(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- val := d.data.Get(key)
- if val == "" {
- if to.Default == "" {
- return nil
- }
- val = to.Default
- }
- return d.setUint64(v, val)
- }
- func (d *decodeState) setUint64(v reflect.Value, val string) error {
- uival, err := strconv.ParseUint(val, 10, 64)
- if err != nil {
- return &BindTypeError{Value: val, Type: v.Type()}
- }
- v.SetUint(uival)
- return nil
- }
- func (d *decodeState) valueFloat64(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- val := d.data.Get(key)
- if val == "" {
- if to.Default == "" {
- return nil
- }
- val = to.Default
- }
- return d.setFloat64(v, val)
- }
- func (d *decodeState) setFloat64(v reflect.Value, val string) error {
- fval, err := strconv.ParseFloat(val, 64)
- if err != nil {
- return &BindTypeError{Value: val, Type: v.Type()}
- }
- v.SetFloat(fval)
- return nil
- }
- func (d *decodeState) valueString(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- val := d.data.Get(key)
- if val == "" {
- if to.Default == "" {
- return nil
- }
- val = to.Default
- }
- return d.setString(v, val)
- }
- func (d *decodeState) setString(v reflect.Value, val string) error {
- v.SetString(val)
- return nil
- }
- func (d *decodeState) valueSlice(v reflect.Value, prefix string, to tagOpt) error {
- key := combinekey(prefix, to)
- strs, ok := d.data[key]
- if !ok {
- strs = strings.Split(to.Default, ",")
- }
- if len(strs) == 0 {
- return nil
- }
- et := v.Type().Elem()
- var setFunc func(reflect.Value, string) error
- switch et.Kind() {
- case reflect.Bool:
- setFunc = d.setBool
- case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
- setFunc = d.setInt64
- case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
- setFunc = d.setUint64
- case reflect.Float32, reflect.Float64:
- setFunc = d.setFloat64
- case reflect.String:
- setFunc = d.setString
- default:
- return &BindTypeError{Type: et, Value: strs[0]}
- }
- vals := reflect.MakeSlice(v.Type(), len(strs), len(strs))
- for i, str := range strs {
- if err := setFunc(vals.Index(i), str); err != nil {
- return err
- }
- }
- if v.CanSet() {
- v.Set(vals)
- }
- return nil
- }
- func (d *decodeState) valueStruct(v reflect.Value, prefix string, to tagOpt) error {
- tv := v.Type()
- for i := 0; i < tv.NumField(); i++ {
- fv := v.Field(i)
- field := tv.Field(i)
- fto := parseTag(field.Tag.Get(_tagID))
- if fto.Name == "-" {
- continue
- }
- if af, ok := d.assignFuncs[fto.Name]; ok {
- if err := af(fv, tagOpt{}); err != nil {
- return err
- }
- continue
- }
- if !strings.HasPrefix(fto.Name, _queryPrefix) {
- continue
- }
- fto.Name = fto.Name[len(_queryPrefix):]
- if err := d.value(fv, to.Name, fto); err != nil {
- return err
- }
- }
- return nil
- }
- func (d *decodeState) indirect(v reflect.Value) (encoding.TextUnmarshaler, reflect.Value) {
- v0 := v
- haveAddr := false
- if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
- haveAddr = true
- v = v.Addr()
- }
- for {
- if v.Kind() == reflect.Interface && !v.IsNil() {
- e := v.Elem()
- if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr {
- haveAddr = false
- v = e
- continue
- }
- }
- if v.Kind() != reflect.Ptr {
- break
- }
- if v.Elem().Kind() != reflect.Ptr && v.CanSet() {
- break
- }
- if v.IsNil() {
- v.Set(reflect.New(v.Type().Elem()))
- }
- if v.Type().NumMethod() > 0 {
- if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
- return u, reflect.Value{}
- }
- }
- if haveAddr {
- v = v0
- haveAddr = false
- } else {
- v = v.Elem()
- }
- }
- return nil, v
- }
|