serverreflection.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. //go:generate protoc --go_out=plugins=grpc:. grpc_reflection_v1alpha/reflection.proto
  /*
Package reflection implements the gRPC server reflection service.

The service implemented is defined in:
https://github.com/grpc/grpc/blob/master/src/proto/grpc/reflection/v1alpha/reflection.proto.

To register server reflection on a gRPC server:

	import "google.golang.org/grpc/reflection"

	s := grpc.NewServer()
	pb.RegisterYourOwnServer(s, &server{})

	// Register reflection service on gRPC server.
	reflection.Register(s)

	s.Serve(lis)
*/
  31. package reflection // import "google.golang.org/grpc/reflection"
  32. import (
  33. "bytes"
  34. "compress/gzip"
  35. "fmt"
  36. "io"
  37. "io/ioutil"
  38. "reflect"
  39. "sort"
  40. "sync"
  41. "github.com/golang/protobuf/proto"
  42. dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
  43. "google.golang.org/grpc"
  44. "google.golang.org/grpc/codes"
  45. rpb "google.golang.org/grpc/reflection/grpc_reflection_v1alpha"
  46. "google.golang.org/grpc/status"
  47. )
  48. type serverReflectionServer struct {
  49. s *grpc.Server
  50. initSymbols sync.Once
  51. serviceNames []string
  52. symbols map[string]*dpb.FileDescriptorProto // map of fully-qualified names to files
  53. }
  54. // Register registers the server reflection service on the given gRPC server.
  55. func Register(s *grpc.Server) {
  56. rpb.RegisterServerReflectionServer(s, &serverReflectionServer{
  57. s: s,
  58. })
  59. }
  // protoMessage is used for type assertions on generated proto messages.
  //
  // Generated message types implement Descriptor(), but that method is not
  // part of the proto.Message interface, so this interface is needed in order
  // to call it. The []byte result is the gzipped FileDescriptorProto of the
  // declaring file; the []int result is not used in this file.
  type protoMessage interface {
  	Descriptor() ([]byte, []int)
  }
  67. func (s *serverReflectionServer) getSymbols() (svcNames []string, symbolIndex map[string]*dpb.FileDescriptorProto) {
  68. s.initSymbols.Do(func() {
  69. serviceInfo := s.s.GetServiceInfo()
  70. s.symbols = map[string]*dpb.FileDescriptorProto{}
  71. s.serviceNames = make([]string, 0, len(serviceInfo))
  72. processed := map[string]struct{}{}
  73. for svc, info := range serviceInfo {
  74. s.serviceNames = append(s.serviceNames, svc)
  75. fdenc, ok := parseMetadata(info.Metadata)
  76. if !ok {
  77. continue
  78. }
  79. fd, err := decodeFileDesc(fdenc)
  80. if err != nil {
  81. continue
  82. }
  83. s.processFile(fd, processed)
  84. }
  85. sort.Strings(s.serviceNames)
  86. })
  87. return s.serviceNames, s.symbols
  88. }
  89. func (s *serverReflectionServer) processFile(fd *dpb.FileDescriptorProto, processed map[string]struct{}) {
  90. filename := fd.GetName()
  91. if _, ok := processed[filename]; ok {
  92. return
  93. }
  94. processed[filename] = struct{}{}
  95. prefix := fd.GetPackage()
  96. for _, msg := range fd.MessageType {
  97. s.processMessage(fd, prefix, msg)
  98. }
  99. for _, en := range fd.EnumType {
  100. s.processEnum(fd, prefix, en)
  101. }
  102. for _, ext := range fd.Extension {
  103. s.processField(fd, prefix, ext)
  104. }
  105. for _, svc := range fd.Service {
  106. svcName := fqn(prefix, svc.GetName())
  107. s.symbols[svcName] = fd
  108. for _, meth := range svc.Method {
  109. name := fqn(svcName, meth.GetName())
  110. s.symbols[name] = fd
  111. }
  112. }
  113. for _, dep := range fd.Dependency {
  114. fdenc := proto.FileDescriptor(dep)
  115. fdDep, err := decodeFileDesc(fdenc)
  116. if err != nil {
  117. continue
  118. }
  119. s.processFile(fdDep, processed)
  120. }
  121. }
  122. func (s *serverReflectionServer) processMessage(fd *dpb.FileDescriptorProto, prefix string, msg *dpb.DescriptorProto) {
  123. msgName := fqn(prefix, msg.GetName())
  124. s.symbols[msgName] = fd
  125. for _, nested := range msg.NestedType {
  126. s.processMessage(fd, msgName, nested)
  127. }
  128. for _, en := range msg.EnumType {
  129. s.processEnum(fd, msgName, en)
  130. }
  131. for _, ext := range msg.Extension {
  132. s.processField(fd, msgName, ext)
  133. }
  134. for _, fld := range msg.Field {
  135. s.processField(fd, msgName, fld)
  136. }
  137. for _, oneof := range msg.OneofDecl {
  138. oneofName := fqn(msgName, oneof.GetName())
  139. s.symbols[oneofName] = fd
  140. }
  141. }
  142. func (s *serverReflectionServer) processEnum(fd *dpb.FileDescriptorProto, prefix string, en *dpb.EnumDescriptorProto) {
  143. enName := fqn(prefix, en.GetName())
  144. s.symbols[enName] = fd
  145. for _, val := range en.Value {
  146. valName := fqn(enName, val.GetName())
  147. s.symbols[valName] = fd
  148. }
  149. }
  150. func (s *serverReflectionServer) processField(fd *dpb.FileDescriptorProto, prefix string, fld *dpb.FieldDescriptorProto) {
  151. fldName := fqn(prefix, fld.GetName())
  152. s.symbols[fldName] = fd
  153. }
  // fqn joins a scope prefix and a simple name into a dotted, fully-qualified
  // name. An empty prefix (e.g. a file without a package) returns name as is.
  func fqn(prefix, name string) string {
  	if prefix != "" {
  		return prefix + "." + name
  	}
  	return name
  }
  160. // fileDescForType gets the file descriptor for the given type.
  161. // The given type should be a proto message.
  162. func (s *serverReflectionServer) fileDescForType(st reflect.Type) (*dpb.FileDescriptorProto, error) {
  163. m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(protoMessage)
  164. if !ok {
  165. return nil, fmt.Errorf("failed to create message from type: %v", st)
  166. }
  167. enc, _ := m.Descriptor()
  168. return decodeFileDesc(enc)
  169. }
  170. // decodeFileDesc does decompression and unmarshalling on the given
  171. // file descriptor byte slice.
  172. func decodeFileDesc(enc []byte) (*dpb.FileDescriptorProto, error) {
  173. raw, err := decompress(enc)
  174. if err != nil {
  175. return nil, fmt.Errorf("failed to decompress enc: %v", err)
  176. }
  177. fd := new(dpb.FileDescriptorProto)
  178. if err := proto.Unmarshal(raw, fd); err != nil {
  179. return nil, fmt.Errorf("bad descriptor: %v", err)
  180. }
  181. return fd, nil
  182. }
  // decompress gunzips b and returns the decompressed bytes.
  //
  // Any failure is reported as a "bad gzipped descriptor" error. That covers
  // an invalid gzip header and an error while reading the stream, such as
  // truncated input.
  func decompress(b []byte) ([]byte, error) {
  	r, err := gzip.NewReader(bytes.NewReader(b))
  	if err != nil {
  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
  	}
  	// gzip.Reader is an io.Closer; release it once the payload is read. Its
  	// Close error is ignored because ReadAll below already reports read errors.
  	defer r.Close()
  	out, err := ioutil.ReadAll(r)
  	if err != nil {
  		return nil, fmt.Errorf("bad gzipped descriptor: %v", err)
  	}
  	return out, nil
  }
  195. func typeForName(name string) (reflect.Type, error) {
  196. pt := proto.MessageType(name)
  197. if pt == nil {
  198. return nil, fmt.Errorf("unknown type: %q", name)
  199. }
  200. st := pt.Elem()
  201. return st, nil
  202. }
  203. func fileDescContainingExtension(st reflect.Type, ext int32) (*dpb.FileDescriptorProto, error) {
  204. m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
  205. if !ok {
  206. return nil, fmt.Errorf("failed to create message from type: %v", st)
  207. }
  208. var extDesc *proto.ExtensionDesc
  209. for id, desc := range proto.RegisteredExtensions(m) {
  210. if id == ext {
  211. extDesc = desc
  212. break
  213. }
  214. }
  215. if extDesc == nil {
  216. return nil, fmt.Errorf("failed to find registered extension for extension number %v", ext)
  217. }
  218. return decodeFileDesc(proto.FileDescriptor(extDesc.Filename))
  219. }
  220. func (s *serverReflectionServer) allExtensionNumbersForType(st reflect.Type) ([]int32, error) {
  221. m, ok := reflect.Zero(reflect.PtrTo(st)).Interface().(proto.Message)
  222. if !ok {
  223. return nil, fmt.Errorf("failed to create message from type: %v", st)
  224. }
  225. exts := proto.RegisteredExtensions(m)
  226. out := make([]int32, 0, len(exts))
  227. for id := range exts {
  228. out = append(out, id)
  229. }
  230. return out, nil
  231. }
  232. // fileDescEncodingByFilename finds the file descriptor for given filename,
  233. // does marshalling on it and returns the marshalled result.
  234. func (s *serverReflectionServer) fileDescEncodingByFilename(name string) ([]byte, error) {
  235. enc := proto.FileDescriptor(name)
  236. if enc == nil {
  237. return nil, fmt.Errorf("unknown file: %v", name)
  238. }
  239. fd, err := decodeFileDesc(enc)
  240. if err != nil {
  241. return nil, err
  242. }
  243. return proto.Marshal(fd)
  244. }
  245. // parseMetadata finds the file descriptor bytes specified meta.
  246. // For SupportPackageIsVersion4, m is the name of the proto file, we
  247. // call proto.FileDescriptor to get the byte slice.
  248. // For SupportPackageIsVersion3, m is a byte slice itself.
  249. func parseMetadata(meta interface{}) ([]byte, bool) {
  250. // Check if meta is the file name.
  251. if fileNameForMeta, ok := meta.(string); ok {
  252. return proto.FileDescriptor(fileNameForMeta), true
  253. }
  254. // Check if meta is the byte slice.
  255. if enc, ok := meta.([]byte); ok {
  256. return enc, true
  257. }
  258. return nil, false
  259. }
  260. // fileDescEncodingContainingSymbol finds the file descriptor containing the given symbol,
  261. // does marshalling on it and returns the marshalled result.
  262. // The given symbol can be a type, a service or a method.
  263. func (s *serverReflectionServer) fileDescEncodingContainingSymbol(name string) ([]byte, error) {
  264. _, symbols := s.getSymbols()
  265. fd := symbols[name]
  266. if fd == nil {
  267. // Check if it's a type name that was not present in the
  268. // transitive dependencies of the registered services.
  269. if st, err := typeForName(name); err == nil {
  270. fd, err = s.fileDescForType(st)
  271. if err != nil {
  272. return nil, err
  273. }
  274. }
  275. }
  276. if fd == nil {
  277. return nil, fmt.Errorf("unknown symbol: %v", name)
  278. }
  279. return proto.Marshal(fd)
  280. }
  281. // fileDescEncodingContainingExtension finds the file descriptor containing given extension,
  282. // does marshalling on it and returns the marshalled result.
  283. func (s *serverReflectionServer) fileDescEncodingContainingExtension(typeName string, extNum int32) ([]byte, error) {
  284. st, err := typeForName(typeName)
  285. if err != nil {
  286. return nil, err
  287. }
  288. fd, err := fileDescContainingExtension(st, extNum)
  289. if err != nil {
  290. return nil, err
  291. }
  292. return proto.Marshal(fd)
  293. }
  294. // allExtensionNumbersForTypeName returns all extension numbers for the given type.
  295. func (s *serverReflectionServer) allExtensionNumbersForTypeName(name string) ([]int32, error) {
  296. st, err := typeForName(name)
  297. if err != nil {
  298. return nil, err
  299. }
  300. extNums, err := s.allExtensionNumbersForType(st)
  301. if err != nil {
  302. return nil, err
  303. }
  304. return extNums, nil
  305. }
  306. // ServerReflectionInfo is the reflection service handler.
  307. func (s *serverReflectionServer) ServerReflectionInfo(stream rpb.ServerReflection_ServerReflectionInfoServer) error {
  308. for {
  309. in, err := stream.Recv()
  310. if err == io.EOF {
  311. return nil
  312. }
  313. if err != nil {
  314. return err
  315. }
  316. out := &rpb.ServerReflectionResponse{
  317. ValidHost: in.Host,
  318. OriginalRequest: in,
  319. }
  320. switch req := in.MessageRequest.(type) {
  321. case *rpb.ServerReflectionRequest_FileByFilename:
  322. b, err := s.fileDescEncodingByFilename(req.FileByFilename)
  323. if err != nil {
  324. out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
  325. ErrorResponse: &rpb.ErrorResponse{
  326. ErrorCode: int32(codes.NotFound),
  327. ErrorMessage: err.Error(),
  328. },
  329. }
  330. } else {
  331. out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
  332. FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
  333. }
  334. }
  335. case *rpb.ServerReflectionRequest_FileContainingSymbol:
  336. b, err := s.fileDescEncodingContainingSymbol(req.FileContainingSymbol)
  337. if err != nil {
  338. out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
  339. ErrorResponse: &rpb.ErrorResponse{
  340. ErrorCode: int32(codes.NotFound),
  341. ErrorMessage: err.Error(),
  342. },
  343. }
  344. } else {
  345. out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
  346. FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
  347. }
  348. }
  349. case *rpb.ServerReflectionRequest_FileContainingExtension:
  350. typeName := req.FileContainingExtension.ContainingType
  351. extNum := req.FileContainingExtension.ExtensionNumber
  352. b, err := s.fileDescEncodingContainingExtension(typeName, extNum)
  353. if err != nil {
  354. out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
  355. ErrorResponse: &rpb.ErrorResponse{
  356. ErrorCode: int32(codes.NotFound),
  357. ErrorMessage: err.Error(),
  358. },
  359. }
  360. } else {
  361. out.MessageResponse = &rpb.ServerReflectionResponse_FileDescriptorResponse{
  362. FileDescriptorResponse: &rpb.FileDescriptorResponse{FileDescriptorProto: [][]byte{b}},
  363. }
  364. }
  365. case *rpb.ServerReflectionRequest_AllExtensionNumbersOfType:
  366. extNums, err := s.allExtensionNumbersForTypeName(req.AllExtensionNumbersOfType)
  367. if err != nil {
  368. out.MessageResponse = &rpb.ServerReflectionResponse_ErrorResponse{
  369. ErrorResponse: &rpb.ErrorResponse{
  370. ErrorCode: int32(codes.NotFound),
  371. ErrorMessage: err.Error(),
  372. },
  373. }
  374. } else {
  375. out.MessageResponse = &rpb.ServerReflectionResponse_AllExtensionNumbersResponse{
  376. AllExtensionNumbersResponse: &rpb.ExtensionNumberResponse{
  377. BaseTypeName: req.AllExtensionNumbersOfType,
  378. ExtensionNumber: extNums,
  379. },
  380. }
  381. }
  382. case *rpb.ServerReflectionRequest_ListServices:
  383. svcNames, _ := s.getSymbols()
  384. serviceResponses := make([]*rpb.ServiceResponse, len(svcNames))
  385. for i, n := range svcNames {
  386. serviceResponses[i] = &rpb.ServiceResponse{
  387. Name: n,
  388. }
  389. }
  390. out.MessageResponse = &rpb.ServerReflectionResponse_ListServicesResponse{
  391. ListServicesResponse: &rpb.ListServiceResponse{
  392. Service: serviceResponses,
  393. },
  394. }
  395. default:
  396. return status.Errorf(codes.InvalidArgument, "invalid MessageRequest: %v", in.MessageRequest)
  397. }
  398. if err := stream.Send(out); err != nil {
  399. return err
  400. }
  401. }
  402. }