123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530 |
- // Copyright 2018 Twitch Interactive, Inc. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License"). You may not
- // use this file except in compliance with the License. A copy of the License is
- // located at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // or in the "license" file accompanying this file. This file is distributed on
- // an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
- // express or implied. See the License for the specific language governing
- // permissions and limitations under the License.
- package main
- import (
- "bufio"
- "bytes"
- "compress/gzip"
- "fmt"
- "go/parser"
- "go/printer"
- "go/token"
- "path"
- "strconv"
- "strings"
- "go-common/app/tool/liverpc/protoc-gen-liverpc/gen"
- "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/stringutils"
- "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/typemap"
- "github.com/golang/protobuf/proto"
- "github.com/golang/protobuf/protoc-gen-go/descriptor"
- plugin "github.com/golang/protobuf/protoc-gen-go/plugin"
- "github.com/pkg/errors"
- )
- type liverpc struct {
- filesHandled int
- reg *typemap.Registry
- // Map to record whether we've built each package
- pkgs map[string]string
- pkgNamesInUse map[string]bool
- importPrefix string // String to prefix to imported package file names.
- importMap map[string]string // Mapping from .proto file name to import path.
- // Package naming:
- genPkgName string // Name of the package that we're generating
- fileToGoPackageName map[*descriptor.FileDescriptorProto]string
- // List of files that were inputs to the generator. We need to hold this in
- // the struct so we can write a header for the file that lists its inputs.
- genFiles []*descriptor.FileDescriptorProto
- // Output buffer that holds the bytes we want to write out for a single file.
- // Gets reset after working on a file.
- output *bytes.Buffer
- }
- func liveRPCGenerator() *liverpc {
- t := &liverpc{
- pkgs: make(map[string]string),
- pkgNamesInUse: make(map[string]bool),
- importMap: make(map[string]string),
- fileToGoPackageName: make(map[*descriptor.FileDescriptorProto]string),
- output: bytes.NewBuffer(nil),
- }
- return t
- }
- func (t *liverpc) Generate(in *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse {
- params, err := parseCommandLineParams(in.GetParameter())
- if err != nil {
- gen.Fail("could not parse parameters passed to --liverpc_out", err.Error())
- }
- t.importPrefix = params.importPrefix
- t.importMap = params.importMap
- t.genFiles = gen.FilesToGenerate(in)
- // Collect information on types.
- t.reg = typemap.New(in.ProtoFile)
- t.registerPackageName("context")
- t.registerPackageName("ioutil")
- t.registerPackageName("proto")
- t.registerPackageName("liverpc")
- // Time to figure out package names of objects defined in protobuf. First,
- // we'll figure out the name for the package we're generating.
- genPkgName, err := deduceGenPkgName(t.genFiles)
- if err != nil {
- gen.Fail(err.Error())
- }
- t.genPkgName = genPkgName
- // Next, we need to pick names for all the files that are dependencies.
- for _, f := range in.ProtoFile {
- if fileDescSliceContains(t.genFiles, f) {
- // This is a file we are generating. It gets the shared package name.
- t.fileToGoPackageName[f] = t.genPkgName
- } else {
- // This is a dependency. Use its package name.
- name := f.GetPackage()
- if name == "" {
- name = stringutils.BaseName(f.GetName())
- }
- name = stringutils.CleanIdentifier(name)
- alias := t.registerPackageName(name)
- t.fileToGoPackageName[f] = alias
- }
- }
- // Showtime! Generate the response.
- resp := new(plugin.CodeGeneratorResponse)
- var servicesNames []string
- for _, f := range t.genFiles {
- respFile := t.generate(f)
- for _, s := range f.Service {
- servicesNames = append(servicesNames, *s.Name)
- }
- if respFile != nil {
- resp.File = append(resp.File, respFile)
- }
- }
- // generate a temp file of service names
- // because a protobuf plugin can only generate for a single package
- // therefore we generate these temp files for other script to combine
- // a single client for all packages
- var filename = "client." + genPkgName + ".txt"
- var respFile = &plugin.CodeGeneratorResponse_File{}
- respFile.Name = &filename
- var content = strings.Join(servicesNames, "\n")
- content += "\n"
- respFile.Content = &content
- resp.File = append(resp.File, respFile)
- return resp
- }
- func (t *liverpc) registerPackageName(name string) (alias string) {
- alias = name
- i := 1
- for t.pkgNamesInUse[alias] {
- alias = name + strconv.Itoa(i)
- i++
- }
- t.pkgNamesInUse[alias] = true
- t.pkgs[name] = alias
- return alias
- }
- func (t *liverpc) generate(file *descriptor.FileDescriptorProto) *plugin.CodeGeneratorResponse_File {
- resp := new(plugin.CodeGeneratorResponse_File)
- if len(file.Service) == 0 {
- return nil
- }
- t.generateFileHeader(file)
- t.generateImports(file)
- if t.filesHandled == 0 {
- t.generateUtilImports()
- }
- // For each service, generate client stubs and server
- for i, service := range file.Service {
- t.generateService(file, service, i)
- }
- // Util functions only generated once per package
- if t.filesHandled == 0 {
- t.generateUtils()
- }
- t.generateFileDescriptor(file)
- resp.Name = proto.String(goFileName(file))
- resp.Content = proto.String(t.formattedOutput())
- t.output.Reset()
- t.filesHandled++
- return resp
- }
- func (t *liverpc) generateFileHeader(file *descriptor.FileDescriptorProto) {
- t.P("// Code generated by protoc-gen-liverpc ", gen.Version, ", DO NOT EDIT.")
- t.P("// source: ", file.GetName())
- t.P()
- if t.filesHandled == 0 {
- t.P("/*")
- t.P("Package ", t.genPkgName, " is a generated liverpc stub package.")
- t.P("This code was generated with go-common/app/tool/liverpc/protoc-gen-liverpc ", gen.Version, ".")
- t.P()
- comment, err := t.reg.FileComments(file)
- if err == nil && comment.Leading != "" {
- for _, line := range strings.Split(comment.Leading, "\n") {
- line = strings.TrimPrefix(line, " ")
- // ensure we don't escape from the block comment
- line = strings.Replace(line, "*/", "* /", -1)
- t.P(line)
- }
- t.P()
- }
- t.P("It is generated from these files:")
- for _, f := range t.genFiles {
- t.P("\t", f.GetName())
- }
- t.P("*/")
- }
- t.P(`package `, t.genPkgName)
- t.P()
- }
- func (t *liverpc) generateImports(file *descriptor.FileDescriptorProto) {
- if len(file.Service) == 0 {
- return
- }
- t.P(`import `, t.pkgs["context"], ` "context"`)
- t.P()
- t.P(`import `, t.pkgs["proto"], ` "github.com/golang/protobuf/proto"`)
- t.P(`import "go-common/library/net/rpc/liverpc"`)
- t.P()
- // It's legal to import a message and use it as an input or output for a
- // method. Make sure to import the package of any such message. First, dedupe
- // them.
- deps := make(map[string]string) // Map of package name to quoted import path.
- ourImportPath := path.Dir(goFileName(file))
- for _, s := range file.Service {
- for _, m := range s.Method {
- defs := []*typemap.MessageDefinition{
- t.reg.MethodInputDefinition(m),
- t.reg.MethodOutputDefinition(m),
- }
- for _, def := range defs {
- // By default, import path is the dirname of the Go filename.
- importPath := path.Dir(goFileName(def.File))
- if importPath == ourImportPath {
- continue
- }
- if substitution, ok := t.importMap[def.File.GetName()]; ok {
- importPath = substitution
- }
- importPath = t.importPrefix + importPath
- pkg := t.goPackageName(def.File)
- deps[pkg] = strconv.Quote(importPath)
- }
- }
- }
- for pkg, importPath := range deps {
- t.P(`import `, pkg, ` `, importPath)
- }
- if len(deps) > 0 {
- t.P()
- }
- t.P(`var _ proto.Message // generate to suppress unused imports`)
- }
- func (t *liverpc) generateUtilImports() {
- t.P("// Imports only used by utility functions:")
- //t.P(`import `, t.pkgs["io"], ` "io"`)
- //t.P(`import `, t.pkgs["strconv"], ` "strconv"`)
- //t.P(`import `, t.pkgs["json"], ` "encoding/json"`)
- //t.P(`import `, t.pkgs["url"], ` "net/url"`)
- }
- // Generate utility functions used in LiveRpc code.
- // These should be generated just once per package.
- func (t *liverpc) generateUtils() {
- t.sectionComment(`Utils`)
- t.P(`func doRPCRequest(ctx `, t.pkgs["context"], `.Context, client *liverpc.Client, version int, method string, in, out `, t.pkgs["proto"], `.Message, opts []liverpc.CallOption) (err error) {`)
- t.P(` err = client.Call(ctx, version, method, in, out, opts...)`)
- t.P(` return`)
- t.P(`}`)
- t.P()
- }
- // P forwards to g.gen.P, which prints output.
- func (t *liverpc) P(args ...string) {
- for _, v := range args {
- t.output.WriteString(v)
- }
- t.output.WriteByte('\n')
- }
- // Big header comments to makes it easier to visually parse a generated file.
- func (t *liverpc) sectionComment(sectionTitle string) {
- t.P()
- t.P(`// `, strings.Repeat("=", len(sectionTitle)))
- t.P(`// `, sectionTitle)
- t.P(`// `, strings.Repeat("=", len(sectionTitle)))
- t.P()
- }
- func (t *liverpc) generateService(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto, index int) {
- servName := serviceName(service)
- t.sectionComment(servName + ` Interface`)
- t.generateLiveRPCInterface(file, service)
- t.sectionComment(servName + ` Live Rpc Client`)
- t.generateClient(file, service)
- }
- func (t *liverpc) generateLiveRPCInterface(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
- comments, err := t.reg.ServiceComments(file, service)
- if err == nil {
- t.printComments(comments)
- }
- t.P(`type `, clientName(service), ` interface {`)
- for _, method := range service.Method {
- comments, err = t.reg.MethodComments(file, service, method)
- if err == nil {
- t.printComments(comments)
- }
- t.P(t.generateSignature(method))
- t.P()
- }
- t.P(`}`)
- }
- func (t *liverpc) generateSignature(method *descriptor.MethodDescriptorProto) string {
- methName := methodName(method)
- inputBodyType := t.goTypeName(method.GetInputType())
- outputType := t.goTypeName(method.GetOutputType())
- return fmt.Sprintf(` %s(ctx %s.Context, req *%s, opts ...liverpc.CallOption) (resp *%s, err error)`, methName, t.pkgs["context"], inputBodyType, outputType)
- }
- // valid names: 'JSON', 'Protobuf'
- func (t *liverpc) generateClient(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) {
- clientName := clientName(service)
- structName := unexported(clientName)
- newClientFunc := "New" + clientName
- t.P(`type `, structName, ` struct {`)
- t.P(` client *liverpc.Client`)
- t.P(`}`)
- t.P()
- t.P(`// `, newClientFunc, ` creates a client that implements the `, clientName, ` interface.`)
- t.P(`func `, newClientFunc, `(client *liverpc.Client) `, clientName, ` {`)
- t.P(` return &`, structName, `{`)
- t.P(` client: client,`)
- t.P(` }`)
- t.P(`}`)
- t.P()
- for _, method := range service.Method {
- methName := methodName(method)
- pkgName := pkgName(file)
- inputType := t.goTypeName(method.GetInputType())
- outputType := t.goTypeName(method.GetOutputType())
- parts := strings.Split(pkgName, ".")
- if len(parts) < 2 {
- panic("package name must contain at least to parts, eg: service.v1, get " + pkgName + "!")
- }
- vStr := parts[len(parts)-1]
- if len(vStr) < 2 {
- panic("package name must contain a valid version, eg: service.v1")
- }
- _, err := strconv.Atoi(vStr[1:])
- if err != nil {
- panic("package name must contain a valid version, eg: service.v1, get " + vStr)
- }
- rpcMethod := method.GetName()
- rpcCtrl := service.GetName()
- rpcCmd := rpcCtrl + "." + rpcMethod
- t.P(`func (c *`, structName, `) `, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `, opts ...liverpc.CallOption) (*`, outputType, `, error) {`)
- t.P(` out := new(`, outputType, `)`)
- t.P(` err := doRPCRequest(ctx,c.client, `, vStr[1:], `, "`, rpcCmd, `", in, out, opts)`)
- t.P(` if err != nil {`)
- t.P(` return nil, err`)
- t.P(` }`)
- t.P(` return out, nil`)
- t.P(`}`)
- t.P()
- }
- }
- func (t *liverpc) generateFileDescriptor(file *descriptor.FileDescriptorProto) {
- // Copied straight of of protoc-gen-go, which trims out comments.
- pb := proto.Clone(file).(*descriptor.FileDescriptorProto)
- pb.SourceCodeInfo = nil
- b, err := proto.Marshal(pb)
- if err != nil {
- gen.Fail(err.Error())
- }
- var buf bytes.Buffer
- w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression)
- w.Write(b)
- w.Close()
- buf.Bytes()
- }
- func (t *liverpc) printComments(comments typemap.DefinitionComments) bool {
- text := strings.TrimSuffix(comments.Leading, "\n")
- if len(strings.TrimSpace(text)) == 0 {
- return false
- }
- split := strings.Split(text, "\n")
- for _, line := range split {
- t.P("// ", strings.TrimPrefix(line, " "))
- }
- return len(split) > 0
- }
- // Given a protobuf name for a Message, return the Go name we will use for that
- // type, including its package prefix.
- func (t *liverpc) goTypeName(protoName string) string {
- def := t.reg.MessageDefinition(protoName)
- if def == nil {
- gen.Fail("could not find message for", protoName)
- }
- var prefix string
- if pkg := t.goPackageName(def.File); pkg != t.genPkgName {
- prefix = pkg + "."
- }
- var name string
- for _, parent := range def.Lineage() {
- name += parent.Descriptor.GetName() + "_"
- }
- name += def.Descriptor.GetName()
- return prefix + name
- }
- func (t *liverpc) goPackageName(file *descriptor.FileDescriptorProto) string {
- return t.fileToGoPackageName[file]
- }
- func (t *liverpc) formattedOutput() string {
- // Reformat generated code.
- fset := token.NewFileSet()
- raw := t.output.Bytes()
- ast, err := parser.ParseFile(fset, "", raw, parser.ParseComments)
- if err != nil {
- // Print out the bad code with line numbers.
- // This should never happen in practice, but it can while changing generated code,
- // so consider this a debugging aid.
- var src bytes.Buffer
- s := bufio.NewScanner(bytes.NewReader(raw))
- for line := 1; s.Scan(); line++ {
- fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
- }
- gen.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String())
- }
- out := bytes.NewBuffer(nil)
- err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(out, fset, ast)
- if err != nil {
- gen.Fail("generated Go source code could not be reformatted:", err.Error())
- }
- return out.String()
- }
// unexported lowercases the first character of s, e.g. "FooRPCClient" becomes
// "fooRPCClient". An empty string is returned unchanged rather than panicking.
// A multi-byte leading rune is lowercased as a whole instead of having its
// UTF-8 encoding split, which previously produced U+FFFD garbage.
func unexported(s string) string {
	// Ranging over a string yields the byte offset at which each rune starts,
	// so the first non-zero offset is where the leading rune ends.
	for i := range s {
		if i > 0 {
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	// Zero or one rune: lowercasing the whole string is exact.
	return strings.ToLower(s)
}
- func pkgName(file *descriptor.FileDescriptorProto) string {
- return file.GetPackage()
- }
- func serviceName(service *descriptor.ServiceDescriptorProto) string {
- return stringutils.CamelCase(service.GetName())
- }
- func clientName(service *descriptor.ServiceDescriptorProto) string {
- return serviceName(service) + "RPCClient"
- }
- func methodName(method *descriptor.MethodDescriptorProto) string {
- return stringutils.CamelCase(method.GetName())
- }
- func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool {
- for _, sf := range slice {
- if f == sf {
- return true
- }
- }
- return false
- }
- // deduceGenPkgName figures out the go package name to use for generated code.
- // Will try to use the explicit go_package setting in a file (if set, must be
- // consistent in all files). If no files have go_package set, then use the
- // protobuf package name (must be consistent in all files)
- func deduceGenPkgName(genFiles []*descriptor.FileDescriptorProto) (string, error) {
- var genPkgName string
- for _, f := range genFiles {
- name, explicit := goPackageName(f)
- if explicit {
- name = stringutils.CleanIdentifier(name)
- if genPkgName != "" && genPkgName != name {
- // Make sure they're all set consistently.
- return "", errors.Errorf("files have conflicting go_package settings, must be the same: %q and %q", genPkgName, name)
- }
- genPkgName = name
- }
- }
- if genPkgName != "" {
- return genPkgName, nil
- }
- // If there is no explicit setting, then check the implicit package name
- // (derived from the protobuf package name) of the files and make sure it's
- // consistent.
- for _, f := range genFiles {
- name, _ := goPackageName(f)
- name = stringutils.CleanIdentifier(name)
- if genPkgName != "" && genPkgName != name {
- return "", errors.Errorf("files have conflicting package names, must be the same or overridden with go_package: %q and %q", genPkgName, name)
- }
- genPkgName = name
- }
- // All the files have the same name, so we're good.
- return genPkgName, nil
- }
|