msg_generate.go 10 KB


//go:build ignore
// +build ignore

// msg_generate.go is meant to run with go generate. It uses
// go/{importer,types} to track down all the RR struct types. Then, for each type,
// it generates pack/unpack methods based on the struct tags. The generated source is
// written to zmsg.go and is meant to be checked into git.
  6. package main
  7. import (
  8. "bytes"
  9. "fmt"
  10. "go/format"
  11. "go/importer"
  12. "go/types"
  13. "log"
  14. "os"
  15. "strings"
  16. )
  17. var packageHdr = `
  18. // Code generated by "go run msg_generate.go"; DO NOT EDIT.
  19. package dns
  20. `
  21. // getTypeStruct will take a type and the package scope, and return the
  22. // (innermost) struct if the type is considered a RR type (currently defined as
  23. // those structs beginning with a RR_Header, could be redefined as implementing
  24. // the RR interface). The bool return value indicates if embedded structs were
  25. // resolved.
  26. func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
  27. st, ok := t.Underlying().(*types.Struct)
  28. if !ok {
  29. return nil, false
  30. }
  31. if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
  32. return st, false
  33. }
  34. if st.Field(0).Anonymous() {
  35. st, _ := getTypeStruct(st.Field(0).Type(), scope)
  36. return st, true
  37. }
  38. return nil, false
  39. }
  40. func main() {
  41. // Import and type-check the package
  42. pkg, err := importer.Default().Import("github.com/miekg/dns")
  43. fatalIfErr(err)
  44. scope := pkg.Scope()
  45. // Collect actual types (*X)
  46. var namedTypes []string
  47. for _, name := range scope.Names() {
  48. o := scope.Lookup(name)
  49. if o == nil || !o.Exported() {
  50. continue
  51. }
  52. if st, _ := getTypeStruct(o.Type(), scope); st == nil {
  53. continue
  54. }
  55. if name == "PrivateRR" {
  56. continue
  57. }
  58. // Check if corresponding TypeX exists
  59. if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
  60. log.Fatalf("Constant Type%s does not exist.", o.Name())
  61. }
  62. namedTypes = append(namedTypes, o.Name())
  63. }
  64. b := &bytes.Buffer{}
  65. b.WriteString(packageHdr)
  66. fmt.Fprint(b, "// pack*() functions\n\n")
  67. for _, name := range namedTypes {
  68. o := scope.Lookup(name)
  69. st, _ := getTypeStruct(o.Type(), scope)
  70. fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
  71. fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
  72. if err != nil {
  73. return off, err
  74. }
  75. headerEnd := off
  76. `)
  77. for i := 1; i < st.NumFields(); i++ {
  78. o := func(s string) {
  79. fmt.Fprintf(b, s, st.Field(i).Name())
  80. fmt.Fprint(b, `if err != nil {
  81. return off, err
  82. }
  83. `)
  84. }
  85. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  86. switch st.Tag(i) {
  87. case `dns:"-"`: // ignored
  88. case `dns:"txt"`:
  89. o("off, err = packStringTxt(rr.%s, msg, off)\n")
  90. case `dns:"opt"`:
  91. o("off, err = packDataOpt(rr.%s, msg, off)\n")
  92. case `dns:"nsec"`:
  93. o("off, err = packDataNsec(rr.%s, msg, off)\n")
  94. case `dns:"domain-name"`:
  95. o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
  96. default:
  97. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  98. }
  99. continue
  100. }
  101. switch {
  102. case st.Tag(i) == `dns:"-"`: // ignored
  103. case st.Tag(i) == `dns:"cdomain-name"`:
  104. o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
  105. case st.Tag(i) == `dns:"domain-name"`:
  106. o("off, err = PackDomainName(rr.%s, msg, off, compression, false)\n")
  107. case st.Tag(i) == `dns:"a"`:
  108. o("off, err = packDataA(rr.%s, msg, off)\n")
  109. case st.Tag(i) == `dns:"aaaa"`:
  110. o("off, err = packDataAAAA(rr.%s, msg, off)\n")
  111. case st.Tag(i) == `dns:"uint48"`:
  112. o("off, err = packUint48(rr.%s, msg, off)\n")
  113. case st.Tag(i) == `dns:"txt"`:
  114. o("off, err = packString(rr.%s, msg, off)\n")
  115. case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
  116. fallthrough
  117. case st.Tag(i) == `dns:"base32"`:
  118. o("off, err = packStringBase32(rr.%s, msg, off)\n")
  119. case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
  120. fallthrough
  121. case st.Tag(i) == `dns:"base64"`:
  122. o("off, err = packStringBase64(rr.%s, msg, off)\n")
  123. case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`):
  124. // directly write instead of using o() so we get the error check in the correct place
  125. field := st.Field(i).Name()
  126. fmt.Fprintf(b, `// Only pack salt if value is not "-", i.e. empty
  127. if rr.%s != "-" {
  128. off, err = packStringHex(rr.%s, msg, off)
  129. if err != nil {
  130. return off, err
  131. }
  132. }
  133. `, field, field)
  134. continue
  135. case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
  136. fallthrough
  137. case st.Tag(i) == `dns:"hex"`:
  138. o("off, err = packStringHex(rr.%s, msg, off)\n")
  139. case st.Tag(i) == `dns:"octet"`:
  140. o("off, err = packStringOctet(rr.%s, msg, off)\n")
  141. case st.Tag(i) == "":
  142. switch st.Field(i).Type().(*types.Basic).Kind() {
  143. case types.Uint8:
  144. o("off, err = packUint8(rr.%s, msg, off)\n")
  145. case types.Uint16:
  146. o("off, err = packUint16(rr.%s, msg, off)\n")
  147. case types.Uint32:
  148. o("off, err = packUint32(rr.%s, msg, off)\n")
  149. case types.Uint64:
  150. o("off, err = packUint64(rr.%s, msg, off)\n")
  151. case types.String:
  152. o("off, err = packString(rr.%s, msg, off)\n")
  153. default:
  154. log.Fatalln(name, st.Field(i).Name())
  155. }
  156. default:
  157. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  158. }
  159. }
  160. // We have packed everything, only now we know the rdlength of this RR
  161. fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
  162. fmt.Fprintln(b, "return off, nil }\n")
  163. }
  164. fmt.Fprint(b, "// unpack*() functions\n\n")
  165. for _, name := range namedTypes {
  166. o := scope.Lookup(name)
  167. st, _ := getTypeStruct(o.Type(), scope)
  168. fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
  169. fmt.Fprintf(b, "rr := new(%s)\n", name)
  170. fmt.Fprint(b, "rr.Hdr = h\n")
  171. fmt.Fprint(b, `if noRdata(h) {
  172. return rr, off, nil
  173. }
  174. var err error
  175. rdStart := off
  176. _ = rdStart
  177. `)
  178. for i := 1; i < st.NumFields(); i++ {
  179. o := func(s string) {
  180. fmt.Fprintf(b, s, st.Field(i).Name())
  181. fmt.Fprint(b, `if err != nil {
  182. return rr, off, err
  183. }
  184. `)
  185. }
  186. // size-* are special, because they reference a struct member we should use for the length.
  187. if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
  188. structMember := structMember(st.Tag(i))
  189. structTag := structTag(st.Tag(i))
  190. switch structTag {
  191. case "hex":
  192. fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  193. case "base32":
  194. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  195. case "base64":
  196. fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
  197. default:
  198. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  199. }
  200. fmt.Fprint(b, `if err != nil {
  201. return rr, off, err
  202. }
  203. `)
  204. continue
  205. }
  206. if _, ok := st.Field(i).Type().(*types.Slice); ok {
  207. switch st.Tag(i) {
  208. case `dns:"-"`: // ignored
  209. case `dns:"txt"`:
  210. o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
  211. case `dns:"opt"`:
  212. o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
  213. case `dns:"nsec"`:
  214. o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
  215. case `dns:"domain-name"`:
  216. o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  217. default:
  218. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  219. }
  220. continue
  221. }
  222. switch st.Tag(i) {
  223. case `dns:"-"`: // ignored
  224. case `dns:"cdomain-name"`:
  225. fallthrough
  226. case `dns:"domain-name"`:
  227. o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
  228. case `dns:"a"`:
  229. o("rr.%s, off, err = unpackDataA(msg, off)\n")
  230. case `dns:"aaaa"`:
  231. o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
  232. case `dns:"uint48"`:
  233. o("rr.%s, off, err = unpackUint48(msg, off)\n")
  234. case `dns:"txt"`:
  235. o("rr.%s, off, err = unpackString(msg, off)\n")
  236. case `dns:"base32"`:
  237. o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  238. case `dns:"base64"`:
  239. o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  240. case `dns:"hex"`:
  241. o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
  242. case `dns:"octet"`:
  243. o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
  244. case "":
  245. switch st.Field(i).Type().(*types.Basic).Kind() {
  246. case types.Uint8:
  247. o("rr.%s, off, err = unpackUint8(msg, off)\n")
  248. case types.Uint16:
  249. o("rr.%s, off, err = unpackUint16(msg, off)\n")
  250. case types.Uint32:
  251. o("rr.%s, off, err = unpackUint32(msg, off)\n")
  252. case types.Uint64:
  253. o("rr.%s, off, err = unpackUint64(msg, off)\n")
  254. case types.String:
  255. o("rr.%s, off, err = unpackString(msg, off)\n")
  256. default:
  257. log.Fatalln(name, st.Field(i).Name())
  258. }
  259. default:
  260. log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
  261. }
  262. // If we've hit len(msg) we return without error.
  263. if i < st.NumFields()-1 {
  264. fmt.Fprintf(b, `if off == len(msg) {
  265. return rr, off, nil
  266. }
  267. `)
  268. }
  269. }
  270. fmt.Fprintf(b, "return rr, off, err }\n\n")
  271. }
  272. // Generate typeToUnpack map
  273. fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
  274. for _, name := range namedTypes {
  275. if name == "RFC3597" {
  276. continue
  277. }
  278. fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
  279. }
  280. fmt.Fprintln(b, "}\n")
  281. // gofmt
  282. res, err := format.Source(b.Bytes())
  283. if err != nil {
  284. b.WriteTo(os.Stderr)
  285. log.Fatal(err)
  286. }
  287. // write result
  288. f, err := os.Create("zmsg.go")
  289. fatalIfErr(err)
  290. defer f.Close()
  291. f.Write(res)
  292. }
  293. // structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
  294. func structMember(s string) string {
  295. fields := strings.Split(s, ":")
  296. if len(fields) == 0 {
  297. return ""
  298. }
  299. f := fields[len(fields)-1]
  300. // f should have a closing "
  301. if len(f) > 1 {
  302. return f[:len(f)-1]
  303. }
  304. return f
  305. }
  306. // structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
  307. func structTag(s string) string {
  308. fields := strings.Split(s, ":")
  309. if len(fields) < 2 {
  310. return ""
  311. }
  312. return fields[1][len("\"size-"):]
  313. }
  314. func fatalIfErr(err error) {
  315. if err != nil {
  316. log.Fatal(err)
  317. }
  318. }