// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may not // use this file except in compliance with the License. A copy of the License is // located at // // http://www.apache.org/licenses/LICENSE-2.0 // // or in the "license" file accompanying this file. This file is distributed on // an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either // express or implied. See the License for the specific language governing // permissions and limitations under the License. package main import ( "bufio" "bytes" "compress/gzip" "fmt" "go/parser" "go/printer" "go/token" "path" "strconv" "strings" "go-common/app/tool/liverpc/protoc-gen-liverpc/gen" "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/stringutils" "go-common/app/tool/liverpc/protoc-gen-liverpc/gen/typemap" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/protoc-gen-go/descriptor" plugin "github.com/golang/protobuf/protoc-gen-go/plugin" "github.com/pkg/errors" ) type liverpc struct { filesHandled int reg *typemap.Registry // Map to record whether we've built each package pkgs map[string]string pkgNamesInUse map[string]bool importPrefix string // String to prefix to imported package file names. importMap map[string]string // Mapping from .proto file name to import path. // Package naming: genPkgName string // Name of the package that we're generating fileToGoPackageName map[*descriptor.FileDescriptorProto]string // List of files that were inputs to the generator. We need to hold this in // the struct so we can write a header for the file that lists its inputs. genFiles []*descriptor.FileDescriptorProto // Output buffer that holds the bytes we want to write out for a single file. // Gets reset after working on a file. output *bytes.Buffer } func liveRPCGenerator() *liverpc { t := &liverpc{ pkgs: make(map[string]string), pkgNamesInUse: make(map[string]bool), importMap: make(map[string]string), fileToGoPackageName: make(map[*descriptor.FileDescriptorProto]string), output: bytes.NewBuffer(nil), } return t } func (t *liverpc) Generate(in *plugin.CodeGeneratorRequest) *plugin.CodeGeneratorResponse { params, err := parseCommandLineParams(in.GetParameter()) if err != nil { gen.Fail("could not parse parameters passed to --liverpc_out", err.Error()) } t.importPrefix = params.importPrefix t.importMap = params.importMap t.genFiles = gen.FilesToGenerate(in) // Collect information on types. t.reg = typemap.New(in.ProtoFile) t.registerPackageName("context") t.registerPackageName("ioutil") t.registerPackageName("proto") t.registerPackageName("liverpc") // Time to figure out package names of objects defined in protobuf. First, // we'll figure out the name for the package we're generating. genPkgName, err := deduceGenPkgName(t.genFiles) if err != nil { gen.Fail(err.Error()) } t.genPkgName = genPkgName // Next, we need to pick names for all the files that are dependencies. for _, f := range in.ProtoFile { if fileDescSliceContains(t.genFiles, f) { // This is a file we are generating. It gets the shared package name. t.fileToGoPackageName[f] = t.genPkgName } else { // This is a dependency. Use its package name. name := f.GetPackage() if name == "" { name = stringutils.BaseName(f.GetName()) } name = stringutils.CleanIdentifier(name) alias := t.registerPackageName(name) t.fileToGoPackageName[f] = alias } } // Showtime! Generate the response. resp := new(plugin.CodeGeneratorResponse) var servicesNames []string for _, f := range t.genFiles { respFile := t.generate(f) for _, s := range f.Service { servicesNames = append(servicesNames, *s.Name) } if respFile != nil { resp.File = append(resp.File, respFile) } } // generate a temp file of service names // because a protobuf plugin can only generate for a single package // therefore we generate these temp files for other script to combine // a single client for all packages var filename = "client." + genPkgName + ".txt" var respFile = &plugin.CodeGeneratorResponse_File{} respFile.Name = &filename var content = strings.Join(servicesNames, "\n") content += "\n" respFile.Content = &content resp.File = append(resp.File, respFile) return resp } func (t *liverpc) registerPackageName(name string) (alias string) { alias = name i := 1 for t.pkgNamesInUse[alias] { alias = name + strconv.Itoa(i) i++ } t.pkgNamesInUse[alias] = true t.pkgs[name] = alias return alias } func (t *liverpc) generate(file *descriptor.FileDescriptorProto) *plugin.CodeGeneratorResponse_File { resp := new(plugin.CodeGeneratorResponse_File) if len(file.Service) == 0 { return nil } t.generateFileHeader(file) t.generateImports(file) if t.filesHandled == 0 { t.generateUtilImports() } // For each service, generate client stubs and server for i, service := range file.Service { t.generateService(file, service, i) } // Util functions only generated once per package if t.filesHandled == 0 { t.generateUtils() } t.generateFileDescriptor(file) resp.Name = proto.String(goFileName(file)) resp.Content = proto.String(t.formattedOutput()) t.output.Reset() t.filesHandled++ return resp } func (t *liverpc) generateFileHeader(file *descriptor.FileDescriptorProto) { t.P("// Code generated by protoc-gen-liverpc ", gen.Version, ", DO NOT EDIT.") t.P("// source: ", file.GetName()) t.P() if t.filesHandled == 0 { t.P("/*") t.P("Package ", t.genPkgName, " is a generated liverpc stub package.") t.P("This code was generated with go-common/app/tool/liverpc/protoc-gen-liverpc ", gen.Version, ".") t.P() comment, err := t.reg.FileComments(file) if err == nil && comment.Leading != "" { for _, line := range strings.Split(comment.Leading, "\n") { line = strings.TrimPrefix(line, " ") // ensure we don't escape from the block comment line = strings.Replace(line, "*/", "* /", -1) t.P(line) } t.P() } t.P("It is generated from these files:") for _, f := range t.genFiles { t.P("\t", f.GetName()) } t.P("*/") } t.P(`package `, t.genPkgName) t.P() } func (t *liverpc) generateImports(file *descriptor.FileDescriptorProto) { if len(file.Service) == 0 { return } t.P(`import `, t.pkgs["context"], ` "context"`) t.P() t.P(`import `, t.pkgs["proto"], ` "github.com/golang/protobuf/proto"`) t.P(`import "go-common/library/net/rpc/liverpc"`) t.P() // It's legal to import a message and use it as an input or output for a // method. Make sure to import the package of any such message. First, dedupe // them. deps := make(map[string]string) // Map of package name to quoted import path. ourImportPath := path.Dir(goFileName(file)) for _, s := range file.Service { for _, m := range s.Method { defs := []*typemap.MessageDefinition{ t.reg.MethodInputDefinition(m), t.reg.MethodOutputDefinition(m), } for _, def := range defs { // By default, import path is the dirname of the Go filename. importPath := path.Dir(goFileName(def.File)) if importPath == ourImportPath { continue } if substitution, ok := t.importMap[def.File.GetName()]; ok { importPath = substitution } importPath = t.importPrefix + importPath pkg := t.goPackageName(def.File) deps[pkg] = strconv.Quote(importPath) } } } for pkg, importPath := range deps { t.P(`import `, pkg, ` `, importPath) } if len(deps) > 0 { t.P() } t.P(`var _ proto.Message // generate to suppress unused imports`) } func (t *liverpc) generateUtilImports() { t.P("// Imports only used by utility functions:") //t.P(`import `, t.pkgs["io"], ` "io"`) //t.P(`import `, t.pkgs["strconv"], ` "strconv"`) //t.P(`import `, t.pkgs["json"], ` "encoding/json"`) //t.P(`import `, t.pkgs["url"], ` "net/url"`) } // Generate utility functions used in LiveRpc code. // These should be generated just once per package. func (t *liverpc) generateUtils() { t.sectionComment(`Utils`) t.P(`func doRPCRequest(ctx `, t.pkgs["context"], `.Context, client *liverpc.Client, version int, method string, in, out `, t.pkgs["proto"], `.Message, opts []liverpc.CallOption) (err error) {`) t.P(` err = client.Call(ctx, version, method, in, out, opts...)`) t.P(` return`) t.P(`}`) t.P() } // P forwards to g.gen.P, which prints output. func (t *liverpc) P(args ...string) { for _, v := range args { t.output.WriteString(v) } t.output.WriteByte('\n') } // Big header comments to makes it easier to visually parse a generated file. func (t *liverpc) sectionComment(sectionTitle string) { t.P() t.P(`// `, strings.Repeat("=", len(sectionTitle))) t.P(`// `, sectionTitle) t.P(`// `, strings.Repeat("=", len(sectionTitle))) t.P() } func (t *liverpc) generateService(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto, index int) { servName := serviceName(service) t.sectionComment(servName + ` Interface`) t.generateLiveRPCInterface(file, service) t.sectionComment(servName + ` Live Rpc Client`) t.generateClient(file, service) } func (t *liverpc) generateLiveRPCInterface(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) { comments, err := t.reg.ServiceComments(file, service) if err == nil { t.printComments(comments) } t.P(`type `, clientName(service), ` interface {`) for _, method := range service.Method { comments, err = t.reg.MethodComments(file, service, method) if err == nil { t.printComments(comments) } t.P(t.generateSignature(method)) t.P() } t.P(`}`) } func (t *liverpc) generateSignature(method *descriptor.MethodDescriptorProto) string { methName := methodName(method) inputBodyType := t.goTypeName(method.GetInputType()) outputType := t.goTypeName(method.GetOutputType()) return fmt.Sprintf(` %s(ctx %s.Context, req *%s, opts ...liverpc.CallOption) (resp *%s, err error)`, methName, t.pkgs["context"], inputBodyType, outputType) } // valid names: 'JSON', 'Protobuf' func (t *liverpc) generateClient(file *descriptor.FileDescriptorProto, service *descriptor.ServiceDescriptorProto) { clientName := clientName(service) structName := unexported(clientName) newClientFunc := "New" + clientName t.P(`type `, structName, ` struct {`) t.P(` client *liverpc.Client`) t.P(`}`) t.P() t.P(`// `, newClientFunc, ` creates a client that implements the `, clientName, ` interface.`) t.P(`func `, newClientFunc, `(client *liverpc.Client) `, clientName, ` {`) t.P(` return &`, structName, `{`) t.P(` client: client,`) t.P(` }`) t.P(`}`) t.P() for _, method := range service.Method { methName := methodName(method) pkgName := pkgName(file) inputType := t.goTypeName(method.GetInputType()) outputType := t.goTypeName(method.GetOutputType()) parts := strings.Split(pkgName, ".") if len(parts) < 2 { panic("package name must contain at least to parts, eg: service.v1, get " + pkgName + "!") } vStr := parts[len(parts)-1] if len(vStr) < 2 { panic("package name must contain a valid version, eg: service.v1") } _, err := strconv.Atoi(vStr[1:]) if err != nil { panic("package name must contain a valid version, eg: service.v1, get " + vStr) } rpcMethod := method.GetName() rpcCtrl := service.GetName() rpcCmd := rpcCtrl + "." + rpcMethod t.P(`func (c *`, structName, `) `, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `, opts ...liverpc.CallOption) (*`, outputType, `, error) {`) t.P(` out := new(`, outputType, `)`) t.P(` err := doRPCRequest(ctx,c.client, `, vStr[1:], `, "`, rpcCmd, `", in, out, opts)`) t.P(` if err != nil {`) t.P(` return nil, err`) t.P(` }`) t.P(` return out, nil`) t.P(`}`) t.P() } } func (t *liverpc) generateFileDescriptor(file *descriptor.FileDescriptorProto) { // Copied straight of of protoc-gen-go, which trims out comments. pb := proto.Clone(file).(*descriptor.FileDescriptorProto) pb.SourceCodeInfo = nil b, err := proto.Marshal(pb) if err != nil { gen.Fail(err.Error()) } var buf bytes.Buffer w, _ := gzip.NewWriterLevel(&buf, gzip.BestCompression) w.Write(b) w.Close() buf.Bytes() } func (t *liverpc) printComments(comments typemap.DefinitionComments) bool { text := strings.TrimSuffix(comments.Leading, "\n") if len(strings.TrimSpace(text)) == 0 { return false } split := strings.Split(text, "\n") for _, line := range split { t.P("// ", strings.TrimPrefix(line, " ")) } return len(split) > 0 } // Given a protobuf name for a Message, return the Go name we will use for that // type, including its package prefix. func (t *liverpc) goTypeName(protoName string) string { def := t.reg.MessageDefinition(protoName) if def == nil { gen.Fail("could not find message for", protoName) } var prefix string if pkg := t.goPackageName(def.File); pkg != t.genPkgName { prefix = pkg + "." } var name string for _, parent := range def.Lineage() { name += parent.Descriptor.GetName() + "_" } name += def.Descriptor.GetName() return prefix + name } func (t *liverpc) goPackageName(file *descriptor.FileDescriptorProto) string { return t.fileToGoPackageName[file] } func (t *liverpc) formattedOutput() string { // Reformat generated code. fset := token.NewFileSet() raw := t.output.Bytes() ast, err := parser.ParseFile(fset, "", raw, parser.ParseComments) if err != nil { // Print out the bad code with line numbers. // This should never happen in practice, but it can while changing generated code, // so consider this a debugging aid. var src bytes.Buffer s := bufio.NewScanner(bytes.NewReader(raw)) for line := 1; s.Scan(); line++ { fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes()) } gen.Fail("bad Go source code was generated:", err.Error(), "\n"+src.String()) } out := bytes.NewBuffer(nil) err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(out, fset, ast) if err != nil { gen.Fail("generated Go source code could not be reformatted:", err.Error()) } return out.String() } func unexported(s string) string { return strings.ToLower(s[:1]) + s[1:] } func pkgName(file *descriptor.FileDescriptorProto) string { return file.GetPackage() } func serviceName(service *descriptor.ServiceDescriptorProto) string { return stringutils.CamelCase(service.GetName()) } func clientName(service *descriptor.ServiceDescriptorProto) string { return serviceName(service) + "RPCClient" } func methodName(method *descriptor.MethodDescriptorProto) string { return stringutils.CamelCase(method.GetName()) } func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool { for _, sf := range slice { if f == sf { return true } } return false } // deduceGenPkgName figures out the go package name to use for generated code. // Will try to use the explicit go_package setting in a file (if set, must be // consistent in all files). If no files have go_package set, then use the // protobuf package name (must be consistent in all files) func deduceGenPkgName(genFiles []*descriptor.FileDescriptorProto) (string, error) { var genPkgName string for _, f := range genFiles { name, explicit := goPackageName(f) if explicit { name = stringutils.CleanIdentifier(name) if genPkgName != "" && genPkgName != name { // Make sure they're all set consistently. return "", errors.Errorf("files have conflicting go_package settings, must be the same: %q and %q", genPkgName, name) } genPkgName = name } } if genPkgName != "" { return genPkgName, nil } // If there is no explicit setting, then check the implicit package name // (derived from the protobuf package name) of the files and make sure it's // consistent. for _, f := range genFiles { name, _ := goPackageName(f) name = stringutils.CleanIdentifier(name) if genPkgName != "" && genPkgName != name { return "", errors.Errorf("files have conflicting package names, must be the same or overridden with go_package: %q and %q", genPkgName, name) } genPkgName = name } // All the files have the same name, so we're good. return genPkgName, nil }