grpclb_test.go 26 KB


  1. /*
  2. *
  3. * Copyright 2016 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package grpclb
  19. import (
  20. "context"
  21. "errors"
  22. "fmt"
  23. "io"
  24. "net"
  25. "strconv"
  26. "strings"
  27. "sync"
  28. "sync/atomic"
  29. "testing"
  30. "time"
  31. durationpb "github.com/golang/protobuf/ptypes/duration"
  32. "google.golang.org/grpc"
  33. "google.golang.org/grpc/balancer"
  34. lbgrpc "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
  35. lbpb "google.golang.org/grpc/balancer/grpclb/grpc_lb_v1"
  36. "google.golang.org/grpc/codes"
  37. "google.golang.org/grpc/credentials"
  38. _ "google.golang.org/grpc/grpclog/glogger"
  39. "google.golang.org/grpc/internal/leakcheck"
  40. "google.golang.org/grpc/metadata"
  41. "google.golang.org/grpc/peer"
  42. "google.golang.org/grpc/resolver"
  43. "google.golang.org/grpc/resolver/manual"
  44. "google.golang.org/grpc/status"
  45. testpb "google.golang.org/grpc/test/grpc_testing"
  46. )
// Names and tokens shared by the fake remote balancer, the backends and the
// clients in these tests.
var (
	// lbServerName is the server name announced by the remote balancer's
	// credentials (see newLoadBalancer) and set as ServerName on the resolved
	// balancer addresses.
	lbServerName = "bar.com"
	// beServerName is the server name announced by the backends, the endpoint
	// of the dial target, and the service name the fake balancer requires in
	// the initial load-balance request.
	beServerName = "foo.com"
	// lbToken is the load-balance token put into server lists and required by
	// testServer.EmptyCall in the "lb-token" metadata.
	lbToken = "iamatoken"
	// fakeName is used as the host of the balancer address built by
	// newLoadBalancer; fakeNameDialer replaces it with localhost when dialing.
	// This checks that the custom dialer passed to Dial is also used by grpclb.
	fakeName = "fake.Name"
)
  56. type serverNameCheckCreds struct {
  57. mu sync.Mutex
  58. sn string
  59. expected string
  60. }
  61. func (c *serverNameCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  62. if _, err := io.WriteString(rawConn, c.sn); err != nil {
  63. fmt.Printf("Failed to write the server name %s to the client %v", c.sn, err)
  64. return nil, nil, err
  65. }
  66. return rawConn, nil, nil
  67. }
  68. func (c *serverNameCheckCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  69. c.mu.Lock()
  70. defer c.mu.Unlock()
  71. b := make([]byte, len(c.expected))
  72. errCh := make(chan error, 1)
  73. go func() {
  74. _, err := rawConn.Read(b)
  75. errCh <- err
  76. }()
  77. select {
  78. case err := <-errCh:
  79. if err != nil {
  80. fmt.Printf("Failed to read the server name from the server %v", err)
  81. return nil, nil, err
  82. }
  83. case <-ctx.Done():
  84. return nil, nil, ctx.Err()
  85. }
  86. if c.expected != string(b) {
  87. fmt.Printf("Read the server name %s want %s", string(b), c.expected)
  88. return nil, nil, errors.New("received unexpected server name")
  89. }
  90. return rawConn, nil, nil
  91. }
  92. func (c *serverNameCheckCreds) Info() credentials.ProtocolInfo {
  93. c.mu.Lock()
  94. defer c.mu.Unlock()
  95. return credentials.ProtocolInfo{}
  96. }
  97. func (c *serverNameCheckCreds) Clone() credentials.TransportCredentials {
  98. c.mu.Lock()
  99. defer c.mu.Unlock()
  100. return &serverNameCheckCreds{
  101. expected: c.expected,
  102. }
  103. }
  104. func (c *serverNameCheckCreds) OverrideServerName(s string) error {
  105. c.mu.Lock()
  106. defer c.mu.Unlock()
  107. c.expected = s
  108. return nil
  109. }
  110. // fakeNameDialer replaces fakeName with localhost when dialing.
  111. // This will test that custom dialer is passed from Dial to grpclb.
  112. func fakeNameDialer(addr string, timeout time.Duration) (net.Conn, error) {
  113. addr = strings.Replace(addr, fakeName, "localhost", 1)
  114. return net.DialTimeout("tcp", addr, timeout)
  115. }
  116. // merge merges the new client stats into current stats.
  117. //
  118. // It's a test-only method. rpcStats is defined in grpclb_picker.
  119. func (s *rpcStats) merge(cs *lbpb.ClientStats) {
  120. atomic.AddInt64(&s.numCallsStarted, cs.NumCallsStarted)
  121. atomic.AddInt64(&s.numCallsFinished, cs.NumCallsFinished)
  122. atomic.AddInt64(&s.numCallsFinishedWithClientFailedToSend, cs.NumCallsFinishedWithClientFailedToSend)
  123. atomic.AddInt64(&s.numCallsFinishedKnownReceived, cs.NumCallsFinishedKnownReceived)
  124. s.mu.Lock()
  125. for _, perToken := range cs.CallsFinishedWithDrop {
  126. s.numCallsDropped[perToken.LoadBalanceToken] += perToken.NumCalls
  127. }
  128. s.mu.Unlock()
  129. }
  130. func mapsEqual(a, b map[string]int64) bool {
  131. if len(a) != len(b) {
  132. return false
  133. }
  134. for k, v1 := range a {
  135. if v2, ok := b[k]; !ok || v1 != v2 {
  136. return false
  137. }
  138. }
  139. return true
  140. }
  141. func atomicEqual(a, b *int64) bool {
  142. return atomic.LoadInt64(a) == atomic.LoadInt64(b)
  143. }
  144. // equal compares two rpcStats.
  145. //
  146. // It's a test-only method. rpcStats is defined in grpclb_picker.
  147. func (s *rpcStats) equal(o *rpcStats) bool {
  148. if !atomicEqual(&s.numCallsStarted, &o.numCallsStarted) {
  149. return false
  150. }
  151. if !atomicEqual(&s.numCallsFinished, &o.numCallsFinished) {
  152. return false
  153. }
  154. if !atomicEqual(&s.numCallsFinishedWithClientFailedToSend, &o.numCallsFinishedWithClientFailedToSend) {
  155. return false
  156. }
  157. if !atomicEqual(&s.numCallsFinishedKnownReceived, &o.numCallsFinishedKnownReceived) {
  158. return false
  159. }
  160. s.mu.Lock()
  161. defer s.mu.Unlock()
  162. o.mu.Lock()
  163. defer o.mu.Unlock()
  164. if !mapsEqual(s.numCallsDropped, o.numCallsDropped) {
  165. return false
  166. }
  167. return true
  168. }
// remoteBalancer is a fake grpclb remote balancer; newLoadBalancer registers
// it as the LoadBalancer service.
type remoteBalancer struct {
	// sls queues server lists; BalanceLoad forwards each one to the connected
	// client. stop closes it.
	sls chan *lbpb.ServerList

	// statsDura is advertised to the client as ClientStatsReportInterval in
	// the initial response.
	statsDura time.Duration

	// done is closed by stop; BalanceLoad waits on it after sls is drained.
	done chan struct{}

	// stats accumulates the ClientStats reports received from the client.
	stats *rpcStats
}
  175. func newRemoteBalancer(intervals []time.Duration) *remoteBalancer {
  176. return &remoteBalancer{
  177. sls: make(chan *lbpb.ServerList, 1),
  178. done: make(chan struct{}),
  179. stats: newRPCStats(),
  180. }
  181. }
// stop shuts the fake balancer down. Closing sls ends BalanceLoad's send loop,
// and closing done then lets BalanceLoad return. A second call panics because
// it closes already-closed channels.
func (b *remoteBalancer) stop() {
	close(b.sls)
	close(b.done)
}
  186. func (b *remoteBalancer) BalanceLoad(stream lbgrpc.LoadBalancer_BalanceLoadServer) error {
  187. req, err := stream.Recv()
  188. if err != nil {
  189. return err
  190. }
  191. initReq := req.GetInitialRequest()
  192. if initReq.Name != beServerName {
  193. return status.Errorf(codes.InvalidArgument, "invalid service name: %v", initReq.Name)
  194. }
  195. resp := &lbpb.LoadBalanceResponse{
  196. LoadBalanceResponseType: &lbpb.LoadBalanceResponse_InitialResponse{
  197. InitialResponse: &lbpb.InitialLoadBalanceResponse{
  198. ClientStatsReportInterval: &durationpb.Duration{
  199. Seconds: int64(b.statsDura.Seconds()),
  200. Nanos: int32(b.statsDura.Nanoseconds() - int64(b.statsDura.Seconds())*1e9),
  201. },
  202. },
  203. },
  204. }
  205. if err := stream.Send(resp); err != nil {
  206. return err
  207. }
  208. go func() {
  209. for {
  210. var (
  211. req *lbpb.LoadBalanceRequest
  212. err error
  213. )
  214. if req, err = stream.Recv(); err != nil {
  215. return
  216. }
  217. b.stats.merge(req.GetClientStats())
  218. }
  219. }()
  220. for v := range b.sls {
  221. resp = &lbpb.LoadBalanceResponse{
  222. LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
  223. ServerList: v,
  224. },
  225. }
  226. if err := stream.Send(resp); err != nil {
  227. return err
  228. }
  229. }
  230. <-b.done
  231. return nil
  232. }
  233. type testServer struct {
  234. testpb.TestServiceServer
  235. addr string
  236. fallback bool
  237. }
  238. const testmdkey = "testmd"
  239. func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
  240. md, ok := metadata.FromIncomingContext(ctx)
  241. if !ok {
  242. return nil, status.Error(codes.Internal, "failed to receive metadata")
  243. }
  244. if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
  245. return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
  246. }
  247. grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
  248. return &testpb.Empty{}, nil
  249. }
  250. func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
  251. return nil
  252. }
  253. func startBackends(sn string, fallback bool, lis ...net.Listener) (servers []*grpc.Server) {
  254. for _, l := range lis {
  255. creds := &serverNameCheckCreds{
  256. sn: sn,
  257. }
  258. s := grpc.NewServer(grpc.Creds(creds))
  259. testpb.RegisterTestServiceServer(s, &testServer{addr: l.Addr().String(), fallback: fallback})
  260. servers = append(servers, s)
  261. go func(s *grpc.Server, l net.Listener) {
  262. s.Serve(l)
  263. }(s, l)
  264. }
  265. return
  266. }
  267. func stopBackends(servers []*grpc.Server) {
  268. for _, s := range servers {
  269. s.Stop()
  270. }
  271. }
// testServers describes the servers started by newLoadBalancer.
type testServers struct {
	// lbAddr is the balancer address, "fakeName:port"; it is only dialable
	// through fakeNameDialer.
	lbAddr string

	// ls is the fake balancer service; tests queue server lists on ls.sls.
	ls *remoteBalancer

	// lb is the gRPC server hosting ls.
	lb *grpc.Server

	// beIPs and bePorts hold each backend's IP and port, by backend index.
	beIPs   []net.IP
	bePorts []int
}
  279. func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) {
  280. var (
  281. beListeners []net.Listener
  282. ls *remoteBalancer
  283. lb *grpc.Server
  284. beIPs []net.IP
  285. bePorts []int
  286. )
  287. for i := 0; i < numberOfBackends; i++ {
  288. // Start a backend.
  289. beLis, e := net.Listen("tcp", "localhost:0")
  290. if e != nil {
  291. err = fmt.Errorf("Failed to listen %v", err)
  292. return
  293. }
  294. beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
  295. bePorts = append(bePorts, beLis.Addr().(*net.TCPAddr).Port)
  296. beListeners = append(beListeners, beLis)
  297. }
  298. backends := startBackends(beServerName, false, beListeners...)
  299. // Start a load balancer.
  300. lbLis, err := net.Listen("tcp", "localhost:0")
  301. if err != nil {
  302. err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
  303. return
  304. }
  305. lbCreds := &serverNameCheckCreds{
  306. sn: lbServerName,
  307. }
  308. lb = grpc.NewServer(grpc.Creds(lbCreds))
  309. ls = newRemoteBalancer(nil)
  310. lbgrpc.RegisterLoadBalancerServer(lb, ls)
  311. go func() {
  312. lb.Serve(lbLis)
  313. }()
  314. tss = &testServers{
  315. lbAddr: fakeName + ":" + strconv.Itoa(lbLis.Addr().(*net.TCPAddr).Port),
  316. ls: ls,
  317. lb: lb,
  318. beIPs: beIPs,
  319. bePorts: bePorts,
  320. }
  321. cleanup = func() {
  322. defer stopBackends(backends)
  323. defer func() {
  324. ls.stop()
  325. lb.Stop()
  326. }()
  327. }
  328. return
  329. }
  330. func TestGRPCLB(t *testing.T) {
  331. defer leakcheck.Check(t)
  332. r, cleanup := manual.GenerateAndRegisterManualResolver()
  333. defer cleanup()
  334. tss, cleanup, err := newLoadBalancer(1)
  335. if err != nil {
  336. t.Fatalf("failed to create new load balancer: %v", err)
  337. }
  338. defer cleanup()
  339. be := &lbpb.Server{
  340. IpAddress: tss.beIPs[0],
  341. Port: int32(tss.bePorts[0]),
  342. LoadBalanceToken: lbToken,
  343. }
  344. var bes []*lbpb.Server
  345. bes = append(bes, be)
  346. sl := &lbpb.ServerList{
  347. Servers: bes,
  348. }
  349. tss.ls.sls <- sl
  350. creds := serverNameCheckCreds{
  351. expected: beServerName,
  352. }
  353. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  354. defer cancel()
  355. cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
  356. grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  357. if err != nil {
  358. t.Fatalf("Failed to dial to the backend %v", err)
  359. }
  360. defer cc.Close()
  361. testC := testpb.NewTestServiceClient(cc)
  362. r.NewAddress([]resolver.Address{{
  363. Addr: tss.lbAddr,
  364. Type: resolver.GRPCLB,
  365. ServerName: lbServerName,
  366. }})
  367. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  368. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  369. }
  370. }
  371. // The remote balancer sends response with duplicates to grpclb client.
  372. func TestGRPCLBWeighted(t *testing.T) {
  373. defer leakcheck.Check(t)
  374. r, cleanup := manual.GenerateAndRegisterManualResolver()
  375. defer cleanup()
  376. tss, cleanup, err := newLoadBalancer(2)
  377. if err != nil {
  378. t.Fatalf("failed to create new load balancer: %v", err)
  379. }
  380. defer cleanup()
  381. beServers := []*lbpb.Server{{
  382. IpAddress: tss.beIPs[0],
  383. Port: int32(tss.bePorts[0]),
  384. LoadBalanceToken: lbToken,
  385. }, {
  386. IpAddress: tss.beIPs[1],
  387. Port: int32(tss.bePorts[1]),
  388. LoadBalanceToken: lbToken,
  389. }}
  390. portsToIndex := make(map[int]int)
  391. for i := range beServers {
  392. portsToIndex[tss.bePorts[i]] = i
  393. }
  394. creds := serverNameCheckCreds{
  395. expected: beServerName,
  396. }
  397. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  398. defer cancel()
  399. cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
  400. grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  401. if err != nil {
  402. t.Fatalf("Failed to dial to the backend %v", err)
  403. }
  404. defer cc.Close()
  405. testC := testpb.NewTestServiceClient(cc)
  406. r.NewAddress([]resolver.Address{{
  407. Addr: tss.lbAddr,
  408. Type: resolver.GRPCLB,
  409. ServerName: lbServerName,
  410. }})
  411. sequences := []string{"00101", "00011"}
  412. for _, seq := range sequences {
  413. var (
  414. bes []*lbpb.Server
  415. p peer.Peer
  416. result string
  417. )
  418. for _, s := range seq {
  419. bes = append(bes, beServers[s-'0'])
  420. }
  421. tss.ls.sls <- &lbpb.ServerList{Servers: bes}
  422. for i := 0; i < 1000; i++ {
  423. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
  424. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  425. }
  426. result += strconv.Itoa(portsToIndex[p.Addr.(*net.TCPAddr).Port])
  427. }
  428. // The generated result will be in format of "0010100101".
  429. if !strings.Contains(result, strings.Repeat(seq, 2)) {
  430. t.Errorf("got result sequence %q, want patten %q", result, seq)
  431. }
  432. }
  433. }
  434. func TestDropRequest(t *testing.T) {
  435. defer leakcheck.Check(t)
  436. r, cleanup := manual.GenerateAndRegisterManualResolver()
  437. defer cleanup()
  438. tss, cleanup, err := newLoadBalancer(1)
  439. if err != nil {
  440. t.Fatalf("failed to create new load balancer: %v", err)
  441. }
  442. defer cleanup()
  443. tss.ls.sls <- &lbpb.ServerList{
  444. Servers: []*lbpb.Server{{
  445. IpAddress: tss.beIPs[0],
  446. Port: int32(tss.bePorts[0]),
  447. LoadBalanceToken: lbToken,
  448. Drop: false,
  449. }, {
  450. Drop: true,
  451. }},
  452. }
  453. creds := serverNameCheckCreds{
  454. expected: beServerName,
  455. }
  456. ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
  457. defer cancel()
  458. cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
  459. grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
  460. if err != nil {
  461. t.Fatalf("Failed to dial to the backend %v", err)
  462. }
  463. defer cc.Close()
  464. testC := testpb.NewTestServiceClient(cc)
  465. r.NewAddress([]resolver.Address{{
  466. Addr: tss.lbAddr,
  467. Type: resolver.GRPCLB,
  468. ServerName: lbServerName,
  469. }})
  470. // Wait for the 1st, non-fail-fast RPC to succeed. This ensures both server
  471. // connections are made, because the first one has DropForLoadBalancing set
  472. // to true.
  473. var i int
  474. for i = 0; i < 1000; i++ {
  475. if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil {
  476. break
  477. }
  478. time.Sleep(time.Millisecond)
  479. }
  480. if i >= 1000 {
  481. t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", testC, err)
  482. }
  483. select {
  484. case <-ctx.Done():
  485. t.Fatal("timed out", ctx.Err())
  486. default:
  487. }
  488. for _, failfast := range []bool{true, false} {
  489. for i := 0; i < 3; i++ {
  490. // Even RPCs should fail, because the 2st backend has
  491. // DropForLoadBalancing set to true.
  492. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); status.Code(err) != codes.Unavailable {
  493. t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.Unavailable)
  494. }
  495. // Odd RPCs should succeed since they choose the non-drop-request
  496. // backend according to the round robin policy.
  497. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(failfast)); err != nil {
  498. t.Errorf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  499. }
  500. }
  501. }
  502. }
// TestBalancerDisconnects checks that when the balancer in use disconnects,
// grpclb connects to the next address from the resolved balancer address list.
//
// Two balancer/backend pairs are started; each balancer serves a list that
// holds only its own backend. The first RPC must reach backend 0. After
// balancer 0 is stopped, RPCs must reach backend 1 within 1000 attempts.
func TestBalancerDisconnects(t *testing.T) {
	defer leakcheck.Check(t)
	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()
	var (
		tests []*testServers
		lbs   []*grpc.Server
	)
	for i := 0; i < 2; i++ {
		tss, cleanup, err := newLoadBalancer(1)
		if err != nil {
			t.Fatalf("failed to create new load balancer: %v", err)
		}
		// Deferred inside the loop on purpose: both pairs stay up until the
		// test returns.
		defer cleanup()
		be := &lbpb.Server{
			IpAddress:        tss.beIPs[0],
			Port:             int32(tss.bePorts[0]),
			LoadBalanceToken: lbToken,
		}
		var bes []*lbpb.Server
		bes = append(bes, be)
		sl := &lbpb.ServerList{
			Servers: bes,
		}
		tss.ls.sls <- sl
		tests = append(tests, tss)
		lbs = append(lbs, tss.lb)
	}
	creds := serverNameCheckCreds{
		expected: beServerName,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)
	// Both balancers are resolved; balancer 0 is listed first.
	r.NewAddress([]resolver.Address{{
		Addr:       tests[0].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       tests[1].lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}})
	var p peer.Peer
	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}
	if p.Addr.(*net.TCPAddr).Port != tests[0].bePorts[0] {
		t.Fatalf("got peer: %v, want peer port: %v", p.Addr, tests[0].bePorts[0])
	}
	// Only balancer 0's server is stopped here; its remoteBalancer and
	// backend are shut down later by the deferred cleanup.
	lbs[0].Stop()
	// Stop balancer[0], balancer[1] should be used by grpclb.
	// Check peer address to see if that happened.
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tests[1].bePorts[0] {
			return
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatalf("No RPC sent to second backend after 1 second")
}
  574. type customGRPCLBBuilder struct {
  575. balancer.Builder
  576. name string
  577. }
  578. func (b *customGRPCLBBuilder) Name() string {
  579. return b.name
  580. }
  581. const grpclbCustomFallbackName = "grpclb_with_custom_fallback_timeout"
  582. func init() {
  583. balancer.Register(&customGRPCLBBuilder{
  584. Builder: newLBBuilderWithFallbackTimeout(100 * time.Millisecond),
  585. name: grpclbCustomFallbackName,
  586. })
  587. }
// TestFallback checks grpclb's fallback to resolved backend addresses.
//
// The client uses the grpclbCustomFallbackName balancer. The first resolver
// update gives an empty balancer address plus a standalone fallback backend;
// the first RPC must reach that standalone backend. The second update gives
// the real balancer address; RPCs must then reach the backend behind the
// balancer within 1000 attempts.
func TestFallback(t *testing.T) {
	defer leakcheck.Check(t)
	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()
	tss, cleanup, err := newLoadBalancer(1)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()
	// Start a standalone backend.
	beLis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen %v", err)
	}
	defer beLis.Close()
	// fallback=true: this backend accepts RPCs without an lb-token.
	standaloneBEs := startBackends(beServerName, true, beLis)
	defer stopBackends(standaloneBEs)
	be := &lbpb.Server{
		IpAddress:        tss.beIPs[0],
		Port:             int32(tss.bePorts[0]),
		LoadBalanceToken: lbToken,
	}
	var bes []*lbpb.Server
	bes = append(bes, be)
	sl := &lbpb.ServerList{
		Servers: bes,
	}
	tss.ls.sls <- sl
	creds := serverNameCheckCreds{
		expected: beServerName,
	}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithBalancerName(grpclbCustomFallbackName),
		grpc.WithTransportCredentials(&creds), grpc.WithDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	testC := testpb.NewTestServiceClient(cc)
	// With an empty balancer address no server list can arrive, so the client
	// is expected to fall back to the Backend address.
	r.NewAddress([]resolver.Address{{
		Addr:       "",
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       beLis.Addr().String(),
		Type:       resolver.Backend,
		ServerName: beServerName,
	}})
	var p peer.Peer
	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
	}
	if p.Addr.String() != beLis.Addr().String() {
		t.Fatalf("got peer: %v, want peer: %v", p.Addr, beLis.Addr())
	}
	// Now provide the real balancer; RPCs should move to its backend.
	r.NewAddress([]resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}, {
		Addr:       beLis.Addr().String(),
		Type:       resolver.Backend,
		ServerName: beServerName,
	}})
	for i := 0; i < 1000; i++ {
		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
		}
		if p.Addr.(*net.TCPAddr).Port == tss.bePorts[0] {
			return
		}
		time.Sleep(time.Millisecond)
	}
	t.Fatalf("No RPC sent to backend behind remote balancer after 1 second")
}
  665. type failPreRPCCred struct{}
  666. func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
  667. if strings.Contains(uri[0], failtosendURI) {
  668. return nil, fmt.Errorf("rpc should fail to send")
  669. }
  670. return nil, nil
  671. }
  672. func (failPreRPCCred) RequireTransportSecurity() bool {
  673. return false
  674. }
  675. func checkStats(stats, expected *rpcStats) error {
  676. if !stats.equal(expected) {
  677. return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
  678. }
  679. return nil
  680. }
// runAndGetStats starts a fake balancer with one backend (marked Drop when
// drop is true) and dials it through grpclb with failPreRPCCred and a 100ms
// client stats report interval. It then runs runRPCs on the connection,
// sleeps one second (presumably so several load reports arrive), and returns
// the stats accumulated by the fake balancer.
func runAndGetStats(t *testing.T, drop bool, runRPCs func(*grpc.ClientConn)) *rpcStats {
	defer leakcheck.Check(t)
	r, cleanup := manual.GenerateAndRegisterManualResolver()
	defer cleanup()
	tss, cleanup, err := newLoadBalancer(1)
	if err != nil {
		t.Fatalf("failed to create new load balancer: %v", err)
	}
	defer cleanup()
	tss.ls.sls <- &lbpb.ServerList{
		Servers: []*lbpb.Server{{
			IpAddress:        tss.beIPs[0],
			Port:             int32(tss.bePorts[0]),
			LoadBalanceToken: lbToken,
			Drop:             drop,
		}},
	}
	// Set before dialing so BalanceLoad advertises it in its initial response.
	// NOTE(review): BalanceLoad reads this on a server goroutine; the only
	// ordering with this write is the network connection, which the race
	// detector may not treat as synchronization - confirm under -race.
	tss.ls.statsDura = 100 * time.Millisecond
	creds := serverNameCheckCreds{expected: beServerName}
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName,
		grpc.WithTransportCredentials(&creds),
		grpc.WithPerRPCCredentials(failPreRPCCred{}),
		grpc.WithDialer(fakeNameDialer))
	if err != nil {
		t.Fatalf("Failed to dial to the backend %v", err)
	}
	defer cc.Close()
	r.NewAddress([]resolver.Address{{
		Addr:       tss.lbAddr,
		Type:       resolver.GRPCLB,
		ServerName: lbServerName,
	}})
	runRPCs(cc)
	time.Sleep(1 * time.Second)
	stats := tss.ls.stats
	return stats
}
const (
	// countRPC is the base number of RPCs issued by each stats test.
	countRPC = 40
	// failtosendURI makes failPreRPCCred fail an RPC when its first URI
	// contains this string.
	failtosendURI = "failtosend"
	// dropErrDesc is matched against RPC error text to recognize an RPC
	// dropped by grpclb.
	dropErrDesc = "request dropped by grpclb"
)
  725. func TestGRPCLBStatsUnarySuccess(t *testing.T) {
  726. defer leakcheck.Check(t)
  727. stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
  728. testC := testpb.NewTestServiceClient(cc)
  729. // The first non-failfast RPC succeeds, all connections are up.
  730. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  731. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  732. }
  733. for i := 0; i < countRPC-1; i++ {
  734. testC.EmptyCall(context.Background(), &testpb.Empty{})
  735. }
  736. })
  737. if err := checkStats(stats, &rpcStats{
  738. numCallsStarted: int64(countRPC),
  739. numCallsFinished: int64(countRPC),
  740. numCallsFinishedKnownReceived: int64(countRPC),
  741. }); err != nil {
  742. t.Fatal(err)
  743. }
  744. }
  745. func TestGRPCLBStatsUnaryDrop(t *testing.T) {
  746. defer leakcheck.Check(t)
  747. c := 0
  748. stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
  749. testC := testpb.NewTestServiceClient(cc)
  750. for {
  751. c++
  752. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  753. if strings.Contains(err.Error(), dropErrDesc) {
  754. break
  755. }
  756. }
  757. }
  758. for i := 0; i < countRPC; i++ {
  759. testC.EmptyCall(context.Background(), &testpb.Empty{})
  760. }
  761. })
  762. if err := checkStats(stats, &rpcStats{
  763. numCallsStarted: int64(countRPC + c),
  764. numCallsFinished: int64(countRPC + c),
  765. numCallsFinishedWithClientFailedToSend: int64(c - 1),
  766. numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
  767. }); err != nil {
  768. t.Fatal(err)
  769. }
  770. }
  771. func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
  772. defer leakcheck.Check(t)
  773. stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
  774. testC := testpb.NewTestServiceClient(cc)
  775. // The first non-failfast RPC succeeds, all connections are up.
  776. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
  777. t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
  778. }
  779. for i := 0; i < countRPC-1; i++ {
  780. cc.Invoke(context.Background(), failtosendURI, &testpb.Empty{}, nil)
  781. }
  782. })
  783. if err := checkStats(stats, &rpcStats{
  784. numCallsStarted: int64(countRPC),
  785. numCallsFinished: int64(countRPC),
  786. numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
  787. numCallsFinishedKnownReceived: 1,
  788. }); err != nil {
  789. t.Fatal(err)
  790. }
  791. }
  792. func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
  793. defer leakcheck.Check(t)
  794. stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
  795. testC := testpb.NewTestServiceClient(cc)
  796. // The first non-failfast RPC succeeds, all connections are up.
  797. stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
  798. if err != nil {
  799. t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
  800. }
  801. for {
  802. if _, err = stream.Recv(); err == io.EOF {
  803. break
  804. }
  805. }
  806. for i := 0; i < countRPC-1; i++ {
  807. stream, err = testC.FullDuplexCall(context.Background())
  808. if err == nil {
  809. // Wait for stream to end if err is nil.
  810. for {
  811. if _, err = stream.Recv(); err == io.EOF {
  812. break
  813. }
  814. }
  815. }
  816. }
  817. })
  818. if err := checkStats(stats, &rpcStats{
  819. numCallsStarted: int64(countRPC),
  820. numCallsFinished: int64(countRPC),
  821. numCallsFinishedKnownReceived: int64(countRPC),
  822. }); err != nil {
  823. t.Fatal(err)
  824. }
  825. }
  826. func TestGRPCLBStatsStreamingDrop(t *testing.T) {
  827. defer leakcheck.Check(t)
  828. c := 0
  829. stats := runAndGetStats(t, true, func(cc *grpc.ClientConn) {
  830. testC := testpb.NewTestServiceClient(cc)
  831. for {
  832. c++
  833. if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
  834. if strings.Contains(err.Error(), dropErrDesc) {
  835. break
  836. }
  837. }
  838. }
  839. for i := 0; i < countRPC; i++ {
  840. testC.FullDuplexCall(context.Background())
  841. }
  842. })
  843. if err := checkStats(stats, &rpcStats{
  844. numCallsStarted: int64(countRPC + c),
  845. numCallsFinished: int64(countRPC + c),
  846. numCallsFinishedWithClientFailedToSend: int64(c - 1),
  847. numCallsDropped: map[string]int64{lbToken: int64(countRPC + 1)},
  848. }); err != nil {
  849. t.Fatal(err)
  850. }
  851. }
  852. func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
  853. defer leakcheck.Check(t)
  854. stats := runAndGetStats(t, false, func(cc *grpc.ClientConn) {
  855. testC := testpb.NewTestServiceClient(cc)
  856. // The first non-failfast RPC succeeds, all connections are up.
  857. stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
  858. if err != nil {
  859. t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
  860. }
  861. for {
  862. if _, err = stream.Recv(); err == io.EOF {
  863. break
  864. }
  865. }
  866. for i := 0; i < countRPC-1; i++ {
  867. cc.NewStream(context.Background(), &grpc.StreamDesc{}, failtosendURI)
  868. }
  869. })
  870. if err := checkStats(stats, &rpcStats{
  871. numCallsStarted: int64(countRPC),
  872. numCallsFinished: int64(countRPC),
  873. numCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
  874. numCallsFinishedKnownReceived: 1,
  875. }); err != nil {
  876. t.Fatal(err)
  877. }
  878. }