generator.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. // Copyright 2018 Twitch Interactive, Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License"). You may not
  4. // use this file except in compliance with the License. A copy of the License is
  5. // located at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // or in the "license" file accompanying this file. This file is distributed on
  10. // an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
  11. // express or implied. See the License for the specific language governing
  12. // permissions and limitations under the License.
  13. package main
  14. import (
  15. "bufio"
  16. "bytes"
  17. "compress/gzip"
  18. "fmt"
  19. "go/parser"
  20. "go/printer"
  21. "go/token"
  22. "path"
  23. "strconv"
  24. "strings"
  25. "go-common/app/tool/liverpc/protoc-gen-liverpc/gen"
  26. "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/stringutils"
  27. "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/typemap"
  28. "github.com/golang/protobuf/proto"
  29. "github.com/golang/protobuf/protoc-gen-go/descriptor"
  30. plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
  31. "github.com/pkg/errors"
  32. )
  33. type liverpc struct {
  34. filesHandled int
  35. reg *typemap.Registry
  36. // Map to record whether we've built each package
  37. pkgs map[string]string
  38. pkgNamesInUse map[string]bool
  39. importPrefix string // String to prefix to imported package file names.
  40. importMap map[string]string // Mapping from .proto file name to import path.
  41. // Package naming:
  42. genPkgName string // Name of the package that we're generating
  43. fileToGoPackageName map[*descriptor.FileDescriptorProto]string
  44. // List of files that were inputs to the generator. We need to hold this in
  45. // the struct so we can write a header for the file that lists its inputs.
  46. genFiles []*descriptor.FileDescriptorProto
  47. // Output buffer that holds the bytes we want to write out for a single file.
  48. // Gets reset after working on a file.
  49. output *bytes.Buffer
  50. }
  51. func liveRPCGenerator() *liverpc {
  52. t := &liverpc{
  53. pkgs: make(map[string]string),
  54. pkgNamesInUse: make(map[string]bool),
  55. importMap: make(map[string]string),
  56. fileToGoPackageName: make(map[*descriptor.FileDescriptorProto]string),
  57. output: bytes.NewBuffer(nil),
  58. }
  59. return t
  60. }
  61. func (t *liverpc) Generate(in *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse {
  62. params, err := parseCommandLineParams(in.GetParameter())
  63. if err != nil {
  64. gen.Fail("could not parse parameters passed to --liverpc_out", err.Error())
  65. }
  66. t.importPrefix = params.importPrefix
  67. t.importMap = params.importMap
  68. t.genFiles = gen.FilesToGenerate(in)
  69. // Collect information on types.
  70. t.reg = typemap.New(in.ProtoFile)
  71. t.registerPackageName("context")
  72. t.registerPackageName("ioutil")
  73. t.registerPackageName("proto")
  74. t.registerPackageName("liverpc")
  75. // Time to figure out package names of objects defined in protobuf. First,
  76. // we'll figure out the name for the package we're generating.
  77. genPkgName, err := deduceGenPkgName(t.genFiles)
  78. if err != nil {
  79. gen.Fail(err.Error())
  80. }
  81. t.genPkgName = genPkgName
  82. // Next, we need to pick names for all the files that are dependencies.
  83. for _, f := range in.ProtoFile {
  84. if fileDescSliceContains(t.genFiles, f) {
  85. // This is a file we are generating. It gets the shared package name.
  86. t.fileToGoPackageName[f] = t.genPkgName
  87. } else {
  88. // This is a dependency. Use its package name.
  89. name := f.GetPackage()
  90. if name == "" {
  91. name = stringutils.BaseName(f.GetName())
  92. }
  93. name = stringutils.CleanIdentifier(name)
  94. alias := t.registerPackageName(name)
  95. t.fileToGoPackageName[f] = alias
  96. }
  97. }
  98. // Showtime! Generate the response.
  99. resp := new(plugin.CodeGeneratorResponse)
  100. var servicesNames []string
  101. for _, f := range t.genFiles {
  102. respFile := t.generate(f)
  103. for _, s := range f.Service {
  104. servicesNames = append(servicesNames, *s.Name)
  105. }
  106. if respFile != nil {
  107. resp.File = append(resp.File, respFile)
  108. }
  109. }
  110. // generate a temp file of service names
  111. // because a protobuf plugin can only generate for a single package
  112. // therefore we generate these temp files for other script to combine
  113. // a single client for all packages
  114. var filename = "client." + genPkgName + ".txt"
  115. var respFile = &plugin.CodeGeneratorResponse_File{}
  116. respFile.Name = &filename
  117. var content = strings.Join(servicesNames, "\n")
  118. content += "\n"
  119. respFile.Content = &content
  120. resp.File = append(resp.File, respFile)
  121. return resp
  122. }
  123. func (t *liverpc) registerPackageName(name string) (alias string) {
  124. alias = name
  125. i := 1
  126. for t.pkgNamesInUse[alias] {
  127. alias = name + strconv.Itoa(i)
  128. i++
  129. }
  130. t.pkgNamesInUse[alias] = true
  131. t.pkgs[name] = alias
  132. return alias
  133. }
  134. func (t *liverpc) generate(file *descriptor.FileDescriptorProto) *plugin.CodeGeneratorResponse_File {
  135. resp := new(plugin.CodeGeneratorResponse_File)
  136. if len(file.Service) == 0 {
  137. return nil
  138. }
  139. t.generateFileHeader(file)
  140. t.generateImports(file)
  141. if t.filesHandled == 0 {
  142. t.generateUtilImports()
  143. }
  144. // For each service, generate client stubs and server
  145. for i, service := range file.Service {
  146. t.generateService(file, service, i)
  147. }
  148. // Util functions only generated once per package
  149. if t.filesHandled == 0 {
  150. t.generateUtils()
  151. }
  152. t.generateFileDescriptor(file)
  153. resp.Name = proto.String(goFileName(file))
  154. resp.Content = proto.String(t.formattedOutput())
  155. t.output.Reset()
  156. t.filesHandled++
  157. return resp
  158. }
  159. func (t *liverpc) generateFileHeader(file *descriptor.FileDescriptorProto) {
  160. t.P("// Code generated by protoc-gen-liverpc ", gen.Version, ", DO NOT EDIT.")
  161. t.P("// source: ", file.GetName())
  162. t.P()
  163. if t.filesHandled == 0 {
  164. t.P("/*")
  165. t.P("Package ", t.genPkgName, " is a generated liverpc stub package.")
  166. t.P("This code was generated with go-common/app/tool/liverpc/protoc-gen-liverpc ", gen.Version, ".")
  167. t.P()
  168. comment, err := t.reg.FileComments(file)
  169. if err == nil && comment.Leading != "" {
  170. for _, line := range strings.Split(comment.Leading, "\n") {
  171. line = strings.TrimPrefix(line, " ")
  172. // ensure we don't escape from the block comment
  173. line = strings.Replace(line, "*/", "* /", -1)
  174. t.P(line)
  175. }
  176. t.P()
  177. }
  178. t.P("It is generated from these files:")
  179. for _, f := range t.genFiles {
  180. t.P("\t", f.GetName())
  181. }
  182. t.P("*/")
  183. }
  184. t.P(`package `, t.genPkgName)
  185. t.P()
  186. }
  187. func (t *liverpc) generateImports(file *descriptor.FileDescriptorProto) {
  188. if len(file.Service) == 0 {
  189. return
  190. }
  191. t.P(`import `, t.pkgs["context"], ` "context"`)
  192. t.P()
  193. t.P(`import `, t.pkgs["proto"], ` "github.com/golang/protobuf/proto"`)
  194. t.P(`import "go-common/library/net/rpc/liverpc"`)
  195. t.P()
  196. // It's legal to import a message and use it as an input or output for a
  197. // method. Make sure to import the package of any such message. First, dedupe
  198. // them.
  199. deps := make(map[string]string) // Map of package name to quoted import path.
  200. ourImportPath := path.Dir(goFileName(file))
  201. for _, s := range file.Service {
  202. for _, m := range s.Method {
  203. defs := []*typemap.MessageDefinition{
  204. t.reg.MethodInputDefinition(m),
  205. t.reg.MethodOutputDefinition(m),
  206. }
  207. for _, def := range defs {
  208. // By default, import path is the dirname of the Go filename.
  209. importPath := path.Dir(goFileName(def.File))
  210. if importPath == ourImportPath {
  211. continue
  212. }
  213. if substitution, ok := t.importMap[def.File.GetName()]; ok {
  214. importPath = substitution
  215. }
  216. importPath = t.importPrefix + importPath
  217. pkg := t.goPackageName(def.File)
  218. deps[pkg] = strconv.Quote(importPath)
  219. }
  220. }
  221. }
  222. for pkg, importPath := range deps {
  223. t.P(`import `, pkg, ` `, importPath)
  224. }
  225. if len(deps) > 0 {
  226. t.P()
  227. }
  228. t.P(`var _ proto.Message // generate to suppress unused imports`)
  229. }
  230. func (t *liverpc) generateUtilImports() {
  231. t.P("// Imports only used by utility functions:")
  232. //t.P(`import `, t.pkgs["io"], ` "io"`)
  233. //t.P(`import `, t.pkgs["strconv"], ` "strconv"`)
  234. //t.P(`import `, t.pkgs["json"], ` "encoding/json"`)
  235. //t.P(`import `, t.pkgs["url"], ` "net/url"`)
  236. }
  237. // Generate utility functions used in LiveRpc code.
  238. // These should be generated just once per package.
  239. func (t *liverpc) generateUtils() {
  240. t.sectionComment(`Utils`)
  241. t.P(`func doRPCRequest(ctx `, t.pkgs["context"], `.Context, client *liverpc.Client, version int, method string, in, out `, t.pkgs["proto"], `.Message, opts []liverpc.CallOption) (err error) {`)
  242. t.P(` err = client.Call(ctx, version, method, in, out, opts...)`)
  243. t.P(` return`)
  244. t.P(`}`)
  245. t.P()
  246. }
  247. // P forwards to g.gen.P, which prints output.
  248. func (t *liverpc) P(args ...string) {
  249. for _, v := range args {
  250. t.output.WriteString(v)
  251. }
  252. t.output.WriteByte('\n')
  253. }
  254. // Big header comments to makes it easier to visually parse a generated file.
  255. func (t *liverpc) sectionComment(sectionTitle string) {
  256. t.P()
  257. t.P(`// `, strings.Repeat("=", len(sectionTitle)))
  258. t.P(`// `, sectionTitle)
  259. t.P(`// `, strings.Repeat("=", len(sectionTitle)))
  260. t.P()
  261. }
  262. func (t *liverpc) generateService(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto, index int) {
  263. servName := serviceName(service)
  264. t.sectionComment(servName + ` Interface`)
  265. t.generateLiveRPCInterface(file, service)
  266. t.sectionComment(servName + ` Live Rpc Client`)
  267. t.generateClient(file, service)
  268. }
  269. func (t *liverpc) generateLiveRPCInterface(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
  270. comments, err := t.reg.ServiceComments(file, service)
  271. if err == nil {
  272. t.printComments(comments)
  273. }
  274. t.P(`type `, clientName(service), ` interface {`)
  275. for _, method := range service.Method {
  276. comments, err = t.reg.MethodComments(file, service, method)
  277. if err == nil {
  278. t.printComments(comments)
  279. }
  280. t.P(t.generateSignature(method))
  281. t.P()
  282. }
  283. t.P(`}`)
  284. }
  285. func (t *liverpc) generateSignature(method *descriptor.MethodDescriptorProto) string {
  286. methName := methodName(method)
  287. inputBodyType := t.goTypeName(method.GetInputType())
  288. outputType := t.goTypeName(method.GetOutputType())
  289. return fmt.Sprintf(` %s(ctx %s.Context, req *%s, opts ...liverpc.CallOption) (resp *%s, err error)`, methName, t.pkgs["context"], inputBodyType, outputType)
  290. }
  291. // valid names: 'JSON', 'Protobuf'
  292. func (t *liverpc) generateClient(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
  293. clientName := clientName(service)
  294. structName := unexported(clientName)
  295. newClientFunc := "New" + clientName
  296. t.P(`type `, structName, ` struct {`)
  297. t.P(` client *liverpc.Client`)
  298. t.P(`}`)
  299. t.P()
  300. t.P(`// `, newClientFunc, ` creates a client that implements the `, clientName, ` interface.`)
  301. t.P(`func `, newClientFunc, `(client *liverpc.Client) `, clientName, ` {`)
  302. t.P(` return &`, structName, `{`)
  303. t.P(` client: client,`)
  304. t.P(` }`)
  305. t.P(`}`)
  306. t.P()
  307. for _, method := range service.Method {
  308. methName := methodName(method)
  309. pkgName := pkgName(file)
  310. inputType := t.goTypeName(method.GetInputType())
  311. outputType := t.goTypeName(method.GetOutputType())
  312. parts := strings.Split(pkgName, ".")
  313. if len(parts) < 2 {
  314. panic("package name must contain at least to parts, eg: service.v1, get " + pkgName + "!")
  315. }
  316. vStr := parts[len(parts)-1]
  317. if len(vStr) < 2 {
  318. panic("package name must contain a valid version, eg: service.v1")
  319. }
  320. _, err := strconv.Atoi(vStr[1:])
  321. if err != nil {
  322. panic("package name must contain a valid version, eg: service.v1, get " + vStr)
  323. }
  324. rpcMethod := method.GetName()
  325. rpcCtrl := service.GetName()
  326. rpcCmd := rpcCtrl + "." + rpcMethod
  327. t.P(`func (c *`, structName, `) `, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `, opts ...liverpc.CallOption) (*`, outputType, `, error) {`)
  328. t.P(` out := new(`, outputType, `)`)
  329. t.P(` err := doRPCRequest(ctx,c.client, `, vStr[1:], `, "`, rpcCmd, `", in, out, opts)`)
  330. t.P(` if err != nil {`)
  331. t.P(` return nil, err`)
  332. t.P(` }`)
  333. t.P(` return out, nil`)
  334. t.P(`}`)
  335. t.P()
  336. }
  337. }
  338. func (t *liverpc) generateFileDescriptor(file *descriptor.FileDescriptorProto) {
  339. // Copied straight of of protoc-gen-go, which trims out comments.
  340. pb := proto.Clone(file).(*descriptor.FileDescriptorProto)
  341. pb.SourceCodeInfo = nil
  342. b, err := proto.Marshal(pb)
  343. if err != nil {
  344. gen.Fail(err.Error())
  345. }
  346. var buf bytes.Buffer
  347. w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression)
  348. w.Write(b)
  349. w.Close()
  350. buf.Bytes()
  351. }
  352. func (t *liverpc) printComments(comments typemap.DefinitionComments) bool {
  353. text := strings.TrimSuffix(comments.Leading, "\n")
  354. if len(strings.TrimSpace(text)) == 0 {
  355. return false
  356. }
  357. split := strings.Split(text, "\n")
  358. for _, line := range split {
  359. t.P("// ", strings.TrimPrefix(line, " "))
  360. }
  361. return len(split) > 0
  362. }
  363. // Given a protobuf name for a Message, return the Go name we will use for that
  364. // type, including its package prefix.
  365. func (t *liverpc) goTypeName(protoName string) string {
  366. def := t.reg.MessageDefinition(protoName)
  367. if def == nil {
  368. gen.Fail("could not find message for", protoName)
  369. }
  370. var prefix string
  371. if pkg := t.goPackageName(def.File); pkg != t.genPkgName {
  372. prefix = pkg + "."
  373. }
  374. var name string
  375. for _, parent := range def.Lineage() {
  376. name += parent.Descriptor.GetName() + "_"
  377. }
  378. name += def.Descriptor.GetName()
  379. return prefix + name
  380. }
  381. func (t *liverpc) goPackageName(file *descriptor.FileDescriptorProto) string {
  382. return t.fileToGoPackageName[file]
  383. }
  384. func (t *liverpc) formattedOutput() string {
  385. // Reformat generated code.
  386. fset := token.NewFileSet()
  387. raw := t.output.Bytes()
  388. ast, err := parser.ParseFile(fset, "", raw, parser.ParseComments)
  389. if err != nil {
  390. // Print out the bad code with line numbers.
  391. // This should never happen in practice, but it can while changing generated code,
  392. // so consider this a debugging aid.
  393. var src bytes.Buffer
  394. s := bufio.NewScanner(bytes.NewReader(raw))
  395. for line := 1; s.Scan(); line++ {
  396. fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
  397. }
  398. gen.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String())
  399. }
  400. out := bytes.NewBuffer(nil)
  401. err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(out, fset, ast)
  402. if err != nil {
  403. gen.Fail("generated Go source code could not be reformatted:", err.Error())
  404. }
  405. return out.String()
  406. }
  407. func unexported(s string) string { return strings.ToLower(s[:1]) + s[1:] }
  408. func pkgName(file *descriptor.FileDescriptorProto) string {
  409. return file.GetPackage()
  410. }
  411. func serviceName(service *descriptor.ServiceDescriptorProto) string {
  412. return stringutils.CamelCase(service.GetName())
  413. }
  414. func clientName(service *descriptor.ServiceDescriptorProto) string {
  415. return serviceName(service) + "RPCClient"
  416. }
  417. func methodName(method *descriptor.MethodDescriptorProto) string {
  418. return stringutils.CamelCase(method.GetName())
  419. }
  420. func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool {
  421. for _, sf := range slice {
  422. if f == sf {
  423. return true
  424. }
  425. }
  426. return false
  427. }
  428. // deduceGenPkgName figures out the go package name to use for generated code.
  429. // Will try to use the explicit go_package setting in a file (if set, must be
  430. // consistent in all files). If no files have go_package set, then use the
  431. // protobuf package name (must be consistent in all files)
  432. func deduceGenPkgName(genFiles []*descriptor.FileDescriptorProto) (string, error) {
  433. var genPkgName string
  434. for _, f := range genFiles {
  435. name, explicit := goPackageName(f)
  436. if explicit {
  437. name = stringutils.CleanIdentifier(name)
  438. if genPkgName != "" && genPkgName != name {
  439. // Make sure they're all set consistently.
  440. return "", errors.Errorf("files have conflicting go_package settings, must be the same: %q and %q", genPkgName, name)
  441. }
  442. genPkgName = name
  443. }
  444. }
  445. if genPkgName != "" {
  446. return genPkgName, nil
  447. }
  448. // If there is no explicit setting, then check the implicit package name
  449. // (derived from the protobuf package name) of the files and make sure it's
  450. // consistent.
  451. for _, f := range genFiles {
  452. name, _ := goPackageName(f)
  453. name = stringutils.CleanIdentifier(name)
  454. if genPkgName != "" && genPkgName != name {
  455. return "", errors.Errorf("files have conflicting package names, must be the same or overridden with go_package: %q and %q", genPkgName, name)
  456. }
  457. genPkgName = name
  458. }
  459. // All the files have the same name, so we're good.
  460. return genPkgName, nil
  461. }