query.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. package dsn
  2. import (
  3. "encoding"
  4. "net/url"
  5. "reflect"
  6. "runtime"
  7. "strconv"
  8. "strings"
  9. )
  // Struct-tag conventions: fields are matched through the "dsn" tag, and only
  // tag names starting with "query." are filled from the DSN query string.
  const (
  	_queryPrefix = "query."
  	_tagID       = "dsn"
  )
  // InvalidBindError reports an invalid target passed to the query binder
  // (bindQuery). The target must be a non-nil pointer.
  type InvalidBindError struct {
  	Type reflect.Type // dynamic type of the rejected target; nil for a nil interface
  }

  // Error describes the rejected target, distinguishing a nil interface, a
  // non-pointer value and a nil pointer.
  func (e *InvalidBindError) Error() string {
  	switch {
  	case e.Type == nil:
  		return "Bind(nil)"
  	case e.Type.Kind() != reflect.Ptr:
  		return "Bind(non-pointer " + e.Type.String() + ")"
  	default:
  		return "Bind(nil " + e.Type.String() + ")"
  	}
  }
  // BindTypeError reports a query value that cannot be stored in a Go value of
  // the given type.
  type BindTypeError struct {
  	Value string       // offending raw value, or a short description of it
  	Type  reflect.Type // destination Go type
  }

  // Error implements the error interface.
  func (e *BindTypeError) Error() string {
  	head := "cannot decode " + e.Value
  	return head + " into Go value of type " + e.Type.String()
  }
  37. type assignFunc func(v reflect.Value, to tagOpt) error
  38. func stringsAssignFunc(val string) assignFunc {
  39. return func(v reflect.Value, to tagOpt) error {
  40. if v.Kind() != reflect.String || !v.CanSet() {
  41. return &BindTypeError{Value: "string", Type: v.Type()}
  42. }
  43. if val == "" {
  44. v.SetString(to.Default)
  45. } else {
  46. v.SetString(val)
  47. }
  48. return nil
  49. }
  50. }
  51. // bindQuery parses url.Values and stores the result in the value pointed to by v.
  52. // if v is nil or not a pointer, bindQuery returns an InvalidDecodeError
  53. func bindQuery(query url.Values, v interface{}, assignFuncs map[string]assignFunc) (url.Values, error) {
  54. if assignFuncs == nil {
  55. assignFuncs = make(map[string]assignFunc)
  56. }
  57. d := decodeState{
  58. data: query,
  59. used: make(map[string]bool),
  60. assignFuncs: assignFuncs,
  61. }
  62. err := d.decode(v)
  63. ret := d.unused()
  64. return ret, err
  65. }
  // tagOpt is the parsed form of a `dsn:"name,default"` struct tag.
  type tagOpt struct {
  	Name    string // key the field binds to (may still carry the "query." prefix)
  	Default string // fallback used when the key supplies no value
  }

  // parseTag splits tag at its first comma into Name and Default. Everything
  // after that comma, further commas included, becomes the default.
  func parseTag(tag string) tagOpt {
  	comma := strings.IndexByte(tag, ',')
  	if comma < 0 {
  		return tagOpt{Name: tag}
  	}
  	return tagOpt{Name: tag[:comma], Default: tag[comma+1:]}
  }
  77. type decodeState struct {
  78. data url.Values
  79. used map[string]bool
  80. assignFuncs map[string]assignFunc
  81. }
  82. func (d *decodeState) unused() url.Values {
  83. ret := make(url.Values)
  84. for k, v := range d.data {
  85. if !d.used[k] {
  86. ret[k] = v
  87. }
  88. }
  89. return ret
  90. }
  91. func (d *decodeState) decode(v interface{}) (err error) {
  92. defer func() {
  93. if r := recover(); r != nil {
  94. if _, ok := r.(runtime.Error); ok {
  95. panic(r)
  96. }
  97. err = r.(error)
  98. }
  99. }()
  100. rv := reflect.ValueOf(v)
  101. if rv.Kind() != reflect.Ptr || rv.IsNil() {
  102. return &InvalidBindError{reflect.TypeOf(v)}
  103. }
  104. return d.root(rv)
  105. }
  106. func (d *decodeState) root(v reflect.Value) error {
  107. var tu encoding.TextUnmarshaler
  108. tu, v = d.indirect(v)
  109. if tu != nil {
  110. return tu.UnmarshalText([]byte(d.data.Encode()))
  111. }
  112. // TODO support map, slice as root
  113. if v.Kind() != reflect.Struct {
  114. return &BindTypeError{Value: d.data.Encode(), Type: v.Type()}
  115. }
  116. tv := v.Type()
  117. for i := 0; i < tv.NumField(); i++ {
  118. fv := v.Field(i)
  119. field := tv.Field(i)
  120. to := parseTag(field.Tag.Get(_tagID))
  121. if to.Name == "-" {
  122. continue
  123. }
  124. if af, ok := d.assignFuncs[to.Name]; ok {
  125. if err := af(fv, tagOpt{}); err != nil {
  126. return err
  127. }
  128. continue
  129. }
  130. if !strings.HasPrefix(to.Name, _queryPrefix) {
  131. continue
  132. }
  133. to.Name = to.Name[len(_queryPrefix):]
  134. if err := d.value(fv, "", to); err != nil {
  135. return err
  136. }
  137. }
  138. return nil
  139. }
  140. func combinekey(prefix string, to tagOpt) string {
  141. key := to.Name
  142. if prefix != "" {
  143. key = prefix + "." + key
  144. }
  145. return key
  146. }
  147. func (d *decodeState) value(v reflect.Value, prefix string, to tagOpt) (err error) {
  148. key := combinekey(prefix, to)
  149. d.used[key] = true
  150. var tu encoding.TextUnmarshaler
  151. tu, v = d.indirect(v)
  152. if tu != nil {
  153. if val, ok := d.data[key]; ok {
  154. return tu.UnmarshalText([]byte(val[0]))
  155. }
  156. if to.Default != "" {
  157. return tu.UnmarshalText([]byte(to.Default))
  158. }
  159. return
  160. }
  161. switch v.Kind() {
  162. case reflect.Bool:
  163. err = d.valueBool(v, prefix, to)
  164. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  165. err = d.valueInt64(v, prefix, to)
  166. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  167. err = d.valueUint64(v, prefix, to)
  168. case reflect.Float32, reflect.Float64:
  169. err = d.valueFloat64(v, prefix, to)
  170. case reflect.String:
  171. err = d.valueString(v, prefix, to)
  172. case reflect.Slice:
  173. err = d.valueSlice(v, prefix, to)
  174. case reflect.Struct:
  175. err = d.valueStruct(v, prefix, to)
  176. case reflect.Ptr:
  177. if !d.hasKey(combinekey(prefix, to)) {
  178. break
  179. }
  180. if !v.CanSet() {
  181. break
  182. }
  183. nv := reflect.New(v.Type().Elem())
  184. v.Set(nv)
  185. err = d.value(nv, prefix, to)
  186. }
  187. return
  188. }
  189. func (d *decodeState) hasKey(key string) bool {
  190. for k := range d.data {
  191. if strings.HasPrefix(k, key+".") || k == key {
  192. return true
  193. }
  194. }
  195. return false
  196. }
  197. func (d *decodeState) valueBool(v reflect.Value, prefix string, to tagOpt) error {
  198. key := combinekey(prefix, to)
  199. val := d.data.Get(key)
  200. if val == "" {
  201. if to.Default == "" {
  202. return nil
  203. }
  204. val = to.Default
  205. }
  206. return d.setBool(v, val)
  207. }
  208. func (d *decodeState) setBool(v reflect.Value, val string) error {
  209. bval, err := strconv.ParseBool(val)
  210. if err != nil {
  211. return &BindTypeError{Value: val, Type: v.Type()}
  212. }
  213. v.SetBool(bval)
  214. return nil
  215. }
  216. func (d *decodeState) valueInt64(v reflect.Value, prefix string, to tagOpt) error {
  217. key := combinekey(prefix, to)
  218. val := d.data.Get(key)
  219. if val == "" {
  220. if to.Default == "" {
  221. return nil
  222. }
  223. val = to.Default
  224. }
  225. return d.setInt64(v, val)
  226. }
  227. func (d *decodeState) setInt64(v reflect.Value, val string) error {
  228. ival, err := strconv.ParseInt(val, 10, 64)
  229. if err != nil {
  230. return &BindTypeError{Value: val, Type: v.Type()}
  231. }
  232. v.SetInt(ival)
  233. return nil
  234. }
  235. func (d *decodeState) valueUint64(v reflect.Value, prefix string, to tagOpt) error {
  236. key := combinekey(prefix, to)
  237. val := d.data.Get(key)
  238. if val == "" {
  239. if to.Default == "" {
  240. return nil
  241. }
  242. val = to.Default
  243. }
  244. return d.setUint64(v, val)
  245. }
  246. func (d *decodeState) setUint64(v reflect.Value, val string) error {
  247. uival, err := strconv.ParseUint(val, 10, 64)
  248. if err != nil {
  249. return &BindTypeError{Value: val, Type: v.Type()}
  250. }
  251. v.SetUint(uival)
  252. return nil
  253. }
  254. func (d *decodeState) valueFloat64(v reflect.Value, prefix string, to tagOpt) error {
  255. key := combinekey(prefix, to)
  256. val := d.data.Get(key)
  257. if val == "" {
  258. if to.Default == "" {
  259. return nil
  260. }
  261. val = to.Default
  262. }
  263. return d.setFloat64(v, val)
  264. }
  265. func (d *decodeState) setFloat64(v reflect.Value, val string) error {
  266. fval, err := strconv.ParseFloat(val, 64)
  267. if err != nil {
  268. return &BindTypeError{Value: val, Type: v.Type()}
  269. }
  270. v.SetFloat(fval)
  271. return nil
  272. }
  273. func (d *decodeState) valueString(v reflect.Value, prefix string, to tagOpt) error {
  274. key := combinekey(prefix, to)
  275. val := d.data.Get(key)
  276. if val == "" {
  277. if to.Default == "" {
  278. return nil
  279. }
  280. val = to.Default
  281. }
  282. return d.setString(v, val)
  283. }
  284. func (d *decodeState) setString(v reflect.Value, val string) error {
  285. v.SetString(val)
  286. return nil
  287. }
  288. func (d *decodeState) valueSlice(v reflect.Value, prefix string, to tagOpt) error {
  289. key := combinekey(prefix, to)
  290. strs, ok := d.data[key]
  291. if !ok {
  292. strs = strings.Split(to.Default, ",")
  293. }
  294. if len(strs) == 0 {
  295. return nil
  296. }
  297. et := v.Type().Elem()
  298. var setFunc func(reflect.Value, string) error
  299. switch et.Kind() {
  300. case reflect.Bool:
  301. setFunc = d.setBool
  302. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  303. setFunc = d.setInt64
  304. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  305. setFunc = d.setUint64
  306. case reflect.Float32, reflect.Float64:
  307. setFunc = d.setFloat64
  308. case reflect.String:
  309. setFunc = d.setString
  310. default:
  311. return &BindTypeError{Type: et, Value: strs[0]}
  312. }
  313. vals := reflect.MakeSlice(v.Type(), len(strs), len(strs))
  314. for i, str := range strs {
  315. if err := setFunc(vals.Index(i), str); err != nil {
  316. return err
  317. }
  318. }
  319. if v.CanSet() {
  320. v.Set(vals)
  321. }
  322. return nil
  323. }
  324. func (d *decodeState) valueStruct(v reflect.Value, prefix string, to tagOpt) error {
  325. tv := v.Type()
  326. for i := 0; i < tv.NumField(); i++ {
  327. fv := v.Field(i)
  328. field := tv.Field(i)
  329. fto := parseTag(field.Tag.Get(_tagID))
  330. if fto.Name == "-" {
  331. continue
  332. }
  333. if af, ok := d.assignFuncs[fto.Name]; ok {
  334. if err := af(fv, tagOpt{}); err != nil {
  335. return err
  336. }
  337. continue
  338. }
  339. if !strings.HasPrefix(fto.Name, _queryPrefix) {
  340. continue
  341. }
  342. fto.Name = fto.Name[len(_queryPrefix):]
  343. if err := d.value(fv, to.Name, fto); err != nil {
  344. return err
  345. }
  346. }
  347. return nil
  348. }
  349. func (d *decodeState) indirect(v reflect.Value) (encoding.TextUnmarshaler, reflect.Value) {
  350. v0 := v
  351. haveAddr := false
  352. if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() {
  353. haveAddr = true
  354. v = v.Addr()
  355. }
  356. for {
  357. if v.Kind() == reflect.Interface && !v.IsNil() {
  358. e := v.Elem()
  359. if e.Kind() == reflect.Ptr && !e.IsNil() && e.Elem().Kind() == reflect.Ptr {
  360. haveAddr = false
  361. v = e
  362. continue
  363. }
  364. }
  365. if v.Kind() != reflect.Ptr {
  366. break
  367. }
  368. if v.Elem().Kind() != reflect.Ptr && v.CanSet() {
  369. break
  370. }
  371. if v.IsNil() {
  372. v.Set(reflect.New(v.Type().Elem()))
  373. }
  374. if v.Type().NumMethod() > 0 {
  375. if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
  376. return u, reflect.Value{}
  377. }
  378. }
  379. if haveAddr {
  380. v = v0
  381. haveAddr = false
  382. } else {
  383. v = v.Elem()
  384. }
  385. }
  386. return nil, v
  387. }