server_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613
  1. package rpc
  2. import (
  3. "errors"
  4. "fmt"
  5. "log"
  6. "net"
  7. "runtime"
  8. "strings"
  9. "sync"
  10. "sync/atomic"
  11. "testing"
  12. "time"
  13. "go-common/library/net/rpc/context"
  14. )
  15. var (
  16. newSvr, newInterceptorServer *Server
  17. serverAddr, newServerAddr, newServerInterceptorAddr string
  18. once, newOnce, newInterceptorOnce sync.Once
  19. testInterceptor *TestInterceptor
  20. statCount uint64
  21. )
  22. const (
  23. _testToken = "test_token"
  24. )
  25. type TestInterceptor struct {
  26. Token, RateMethod string
  27. }
  28. type Args struct {
  29. A, B int
  30. }
  31. type Reply struct {
  32. C int
  33. }
  34. type Arith int
  35. // Some of Arith's methods have value args, some have pointer args. That's deliberate.
  36. func (t *TestInterceptor) Rate(c context.Context) error {
  37. log.Printf("Interceptor rate method: %s, current: %s", t.RateMethod, c.ServiceMethod())
  38. if t.RateMethod == c.ServiceMethod() {
  39. return fmt.Errorf("Interceptor rate method: %s, time: %s", c.ServiceMethod(), c.Now())
  40. }
  41. return nil
  42. }
  43. func (t *TestInterceptor) Auth(c context.Context, addr net.Addr, token string) error {
  44. if t.Token != token {
  45. return fmt.Errorf("Interceptor auth token: %s, ip: %s seq: %d failed", token, addr, c.Seq())
  46. }
  47. log.Printf("Interceptor auth token: %s, ip: %s, seq: %d ok", token, addr, c.Seq())
  48. return nil
  49. }
  50. func (t *TestInterceptor) Stat(c context.Context, args interface{}, err error) {
  51. atomic.AddUint64(&statCount, 1)
  52. }
  53. func (t *Arith) Auth(c context.Context, args Auth, reply *Reply) error {
  54. return nil
  55. }
  56. func (t *Arith) Add(c context.Context, args Args, reply *Reply) error {
  57. reply.C = args.A + args.B
  58. return nil
  59. }
  60. func (t *Arith) Mul(c context.Context, args *Args, reply *Reply) error {
  61. reply.C = args.A * args.B
  62. return nil
  63. }
  64. func (t *Arith) Div(c context.Context, args Args, reply *Reply) error {
  65. if args.B == 0 {
  66. return errors.New("divide by zero")
  67. }
  68. reply.C = args.A / args.B
  69. return nil
  70. }
  71. func (t *Arith) String(c context.Context, args *Args, reply *string) error {
  72. *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
  73. return nil
  74. }
  75. func (t *Arith) Scan(c context.Context, args string, reply *Reply) (err error) {
  76. _, err = fmt.Sscan(args, &reply.C)
  77. return
  78. }
  79. func (t *Arith) Error(c context.Context, args *Args, reply *Reply) error {
  80. panic("ERROR")
  81. }
  82. type hidden int
  83. func (t *hidden) Exported(c context.Context, args Args, reply *Reply) error {
  84. reply.C = args.A + args.B
  85. return nil
  86. }
  87. type Embed struct {
  88. hidden
  89. }
  90. // NOTE listen and start the server
  91. func listenTCP() (net.Listener, string) {
  92. l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
  93. if e != nil {
  94. log.Fatalf("net.Listen tcp :0: %v", e)
  95. }
  96. return l, l.Addr().String()
  97. }
  98. func startServer() {
  99. Register(new(Arith))
  100. Register(new(Embed))
  101. RegisterName("net.rpc.Arith", new(Arith))
  102. var l net.Listener
  103. l, serverAddr = listenTCP()
  104. log.Println("Test RPC server listening on", serverAddr)
  105. go Accept(l)
  106. }
  107. func startNewServer() {
  108. newSvr = newServer()
  109. newSvr.Register(new(Arith))
  110. newSvr.Register(new(Embed))
  111. newSvr.RegisterName("net.rpc.Arith", new(Arith))
  112. newSvr.RegisterName("newServer.Arith", new(Arith))
  113. var l net.Listener
  114. l, newServerAddr = listenTCP()
  115. log.Println("NewServer test RPC server listening on", newServerAddr)
  116. go newSvr.Accept(l)
  117. }
  118. func startNewInterceptorServer() {
  119. testInterceptor = &TestInterceptor{Token: _testToken}
  120. newInterceptorServer = newServer()
  121. newInterceptorServer.Register(new(Arith))
  122. newInterceptorServer.Register(new(Embed))
  123. newInterceptorServer.RegisterName("net.rpc.Arith", new(Arith))
  124. newInterceptorServer.RegisterName("newServer.Arith", new(Arith))
  125. newInterceptorServer.Interceptor = testInterceptor
  126. var l net.Listener
  127. l, newServerInterceptorAddr = listenTCP()
  128. log.Println("NewInterceptorServer test RPC server listening on", newServerAddr)
  129. go newInterceptorServer.Accept(l)
  130. }
  131. // NOTE test rpc call with check expected
  132. func TestServerRPC(t *testing.T) {
  133. once.Do(startServer)
  134. newOnce.Do(startNewServer)
  135. newInterceptorOnce.Do(startNewInterceptorServer)
  136. testRPC(t, serverAddr)
  137. testRPC(t, newServerAddr)
  138. testNewServerRPC(t, newServerAddr)
  139. testNewServerAuthRPC(t, newServerInterceptorAddr)
  140. testNewInterceptorServerRPC(t, newServerInterceptorAddr)
  141. testNewInterceptorServerRateRPC(t, newServerInterceptorAddr)
  142. }
  143. func testRPC(t *testing.T, addr string) {
  144. client, err := dial("tcp", addr, time.Second)
  145. if err != nil {
  146. t.Fatal("dialing", err)
  147. }
  148. defer client.Close()
  149. // Synchronous calls
  150. args := &Args{7, 8}
  151. reply := new(Reply)
  152. err = client.Call("Arith.Add", args, reply)
  153. if err != nil {
  154. t.Errorf("Add: expected no error but got string %q", err.Error())
  155. }
  156. if reply.C != args.A+args.B {
  157. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  158. }
  159. // Methods exported from unexported embedded structs
  160. args = &Args{7, 0}
  161. reply = new(Reply)
  162. err = client.Call("Embed.Exported", args, reply)
  163. if err != nil {
  164. t.Errorf("Add: expected no error but got string %q", err.Error())
  165. }
  166. if reply.C != args.A+args.B {
  167. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  168. }
  169. // Nonexistent method
  170. args = &Args{7, 0}
  171. reply = new(Reply)
  172. err = client.Call("Arith.BadOperation", args, reply)
  173. // expect an error
  174. if err == nil {
  175. t.Error("BadOperation: expected error")
  176. } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
  177. t.Errorf("BadOperation: expected can't find method error; got %q", err)
  178. }
  179. // Unknown service
  180. args = &Args{7, 8}
  181. reply = new(Reply)
  182. err = client.Call("Arith.Unknown", args, reply)
  183. if err == nil {
  184. t.Error("expected error calling unknown service")
  185. } else if !strings.Contains(err.Error(), "method") {
  186. t.Error("expected error about method; got", err)
  187. }
  188. // Out of order.
  189. args = &Args{7, 8}
  190. mulReply := new(Reply)
  191. mulCall := client.Go("Arith.Mul", args, mulReply, nil)
  192. addReply := new(Reply)
  193. addCall := client.Go("Arith.Add", args, addReply, nil)
  194. addCall = <-addCall.Done
  195. if addCall.Error != nil {
  196. t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
  197. }
  198. if addReply.C != args.A+args.B {
  199. t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
  200. }
  201. mulCall = <-mulCall.Done
  202. if mulCall.Error != nil {
  203. t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
  204. }
  205. if mulReply.C != args.A*args.B {
  206. t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
  207. }
  208. // Error test
  209. args = &Args{7, 0}
  210. reply = new(Reply)
  211. err = client.Call("Arith.Div", args, reply)
  212. // expect an error: zero divide
  213. if err == nil {
  214. t.Error("Div: expected error")
  215. } else if err.Error() != "divide by zero" {
  216. t.Error("Div: expected divide by zero error; got", err)
  217. }
  218. // Bad type.
  219. reply = new(Reply)
  220. err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
  221. if err == nil {
  222. t.Error("expected error calling Arith.Add with wrong arg type")
  223. } else if !strings.Contains(err.Error(), "type") {
  224. t.Error("expected error about type; got", err)
  225. }
  226. // Non-struct argument
  227. const Val = 12345
  228. str := fmt.Sprint(Val)
  229. reply = new(Reply)
  230. err = client.Call("Arith.Scan", &str, reply)
  231. if err != nil {
  232. t.Errorf("Scan: expected no error but got string %q", err.Error())
  233. } else if reply.C != Val {
  234. t.Errorf("Scan: expected %d got %d", Val, reply.C)
  235. }
  236. // Non-struct reply
  237. args = &Args{27, 35}
  238. str = ""
  239. err = client.Call("Arith.String", args, &str)
  240. if err != nil {
  241. t.Errorf("String: expected no error but got string %q", err.Error())
  242. }
  243. expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
  244. if str != expect {
  245. t.Errorf("String: expected %s got %s", expect, str)
  246. }
  247. args = &Args{7, 8}
  248. reply = new(Reply)
  249. err = client.Call("Arith.Mul", args, reply)
  250. if err != nil {
  251. t.Errorf("Mul: expected no error but got string %q", err.Error())
  252. }
  253. if reply.C != args.A*args.B {
  254. t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
  255. }
  256. // ServiceName contain "." character
  257. args = &Args{7, 8}
  258. reply = new(Reply)
  259. err = client.Call("net.rpc.Arith.Add", args, reply)
  260. if err != nil {
  261. t.Errorf("Add: expected no error but got string %q", err.Error())
  262. }
  263. if reply.C != args.A+args.B {
  264. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  265. }
  266. }
  267. func testNewServerRPC(t *testing.T, addr string) {
  268. client, err := dial("tcp", addr, time.Second)
  269. if err != nil {
  270. t.Fatal("dialing", err)
  271. }
  272. defer client.Close()
  273. // Synchronous calls
  274. args := &Args{7, 8}
  275. reply := new(Reply)
  276. err = client.Call("newServer.Arith.Add", args, reply)
  277. if err != nil {
  278. t.Errorf("Add: expected no error but got string %q", err.Error())
  279. }
  280. if reply.C != args.A+args.B {
  281. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  282. }
  283. }
  284. func testNewInterceptorServerRPC(t *testing.T, addr string) {
  285. client, err := dial("tcp", addr, time.Second)
  286. if err != nil {
  287. t.Fatal("authing", err)
  288. }
  289. defer client.Close()
  290. // Synchronous calls
  291. args := &Args{7, 8}
  292. reply := new(Reply)
  293. err = client.Call("newServer.Arith.Add", args, reply)
  294. if err != nil {
  295. t.Errorf("Add: expected no error but got string %q", err.Error())
  296. }
  297. if reply.C != args.A+args.B {
  298. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  299. }
  300. }
  301. func testNewInterceptorServerRateRPC(t *testing.T, addr string) {
  302. client, err := dial("tcp", addr, time.Second)
  303. if err != nil {
  304. t.Fatal("authing", err)
  305. }
  306. defer client.Close()
  307. // Synchronous calls
  308. args := &Args{7, 8}
  309. reply := new(Reply)
  310. err = client.Call("Arith.Add", args, reply)
  311. if err != nil {
  312. t.Errorf("Add: expected no error but got string %q", err.Error())
  313. }
  314. if reply.C != args.A+args.B {
  315. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  316. }
  317. // check rate the method
  318. testInterceptor.RateMethod = "Arith.Add"
  319. args = &Args{7, 8}
  320. reply = new(Reply)
  321. err = client.Call("Arith.Add", args, reply)
  322. if err == nil {
  323. t.Errorf("Add: expected error this rate method")
  324. }
  325. }
  326. func testNewServerAuthRPC(t *testing.T, addr string) {
  327. _, err := dial("tcp", addr, time.Second)
  328. if err != nil {
  329. t.Errorf("Auth: expected error %q", err.Error())
  330. }
  331. }
  332. // NOTE test no point structs for registration
  333. type ReplyNotPointer int
  334. type ArgNotPublic int
  335. type ReplyNotPublic int
  336. type NeedsPtrType int
  337. type local struct{}
  338. func (t *ReplyNotPointer) ReplyNotPointer(c context.Context, args *Args, reply Reply) error {
  339. return nil
  340. }
  341. func (t *ArgNotPublic) ArgNotPublic(c context.Context, args *local, reply *Reply) error {
  342. return nil
  343. }
  344. func (t *ReplyNotPublic) ReplyNotPublic(c context.Context, args *Args, reply *local) error {
  345. return nil
  346. }
  347. func (t *NeedsPtrType) NeedsPtrType(c context.Context, args *Args, reply *Reply) error {
  348. return nil
  349. }
  350. // Check that registration handles lots of bad methods and a type with no suitable methods.
  351. func TestRegistrationError(t *testing.T) {
  352. err := Register(new(ReplyNotPointer))
  353. if err == nil {
  354. t.Error("expected error registering ReplyNotPointer")
  355. }
  356. err = Register(new(ArgNotPublic))
  357. if err == nil {
  358. t.Error("expected error registering ArgNotPublic")
  359. }
  360. err = Register(new(ReplyNotPublic))
  361. if err == nil {
  362. t.Error("expected error registering ReplyNotPublic")
  363. }
  364. err = Register(NeedsPtrType(0))
  365. if err == nil {
  366. t.Error("expected error registering NeedsPtrType")
  367. } else if !strings.Contains(err.Error(), "pointer") {
  368. t.Error("expected hint when registering NeedsPtrType")
  369. }
  370. }
  371. // NOTE test multiple call methods
  372. func dialDirect() (*client, error) {
  373. return dial("tcp", serverAddr, time.Second)
  374. }
  375. func countMallocs(dial func() (*client, error), t *testing.T) float64 {
  376. once.Do(startServer)
  377. client, err := dial()
  378. if err != nil {
  379. t.Fatal("error dialing", err)
  380. }
  381. defer client.Close()
  382. args := &Args{7, 8}
  383. reply := new(Reply)
  384. return testing.AllocsPerRun(100, func() {
  385. err := client.Call("Arith.Add", args, reply)
  386. if err != nil {
  387. t.Errorf("Add: expected no error but got string %q", err.Error())
  388. }
  389. if reply.C != args.A+args.B {
  390. t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
  391. }
  392. })
  393. }
  394. func TestCountMallocs(t *testing.T) {
  395. if testing.Short() {
  396. t.Skip("skipping malloc count in short mode")
  397. }
  398. if runtime.GOMAXPROCS(0) > 1 {
  399. t.Skip("skipping; GOMAXPROCS>1")
  400. }
  401. fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
  402. }
  403. func TestTCPClose(t *testing.T) {
  404. once.Do(startServer)
  405. client, err := dialDirect()
  406. if err != nil {
  407. t.Fatalf("dialing: %v", err)
  408. }
  409. defer client.Close()
  410. args := Args{17, 8}
  411. var reply Reply
  412. err = client.Call("Arith.Mul", args, &reply)
  413. if err != nil {
  414. t.Fatal("arith error:", err)
  415. }
  416. t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
  417. if reply.C != args.A*args.B {
  418. t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
  419. }
  420. }
  421. func TestErrorAfterClientClose(t *testing.T) {
  422. once.Do(startServer)
  423. client, err := dialDirect()
  424. if err != nil {
  425. t.Fatalf("dialing: %v", err)
  426. }
  427. err = client.Close()
  428. if err != nil {
  429. t.Fatal("close error:", err)
  430. }
  431. err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
  432. if err != ErrShutdown {
  433. t.Errorf("Forever: expected ErrShutdown got %v", err)
  434. }
  435. }
  436. // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
  437. func TestAcceptExitAfterListenerClose(t *testing.T) {
  438. newSvr = newServer()
  439. newSvr.Register(new(Arith))
  440. newSvr.RegisterName("net.rpc.Arith", new(Arith))
  441. newSvr.RegisterName("newServer.Arith", new(Arith))
  442. var l net.Listener
  443. l, newServerAddr = listenTCP()
  444. l.Close()
  445. newSvr.Accept(l)
  446. }
  447. func TestParseDSN(t *testing.T) {
  448. c := parseDSN("tcp://127.0.0.1:8099")
  449. if c.Proto != "tcp" {
  450. t.Error("parse dsn proto not equal tcp")
  451. }
  452. if c.Addr != "127.0.0.1:8099" {
  453. t.Error("parse dsn addr not equal")
  454. }
  455. }
  456. func benchmarkEndToEnd(dial func() (*client, error), b *testing.B) {
  457. once.Do(startServer)
  458. client, err := dial()
  459. if err != nil {
  460. b.Fatal("error dialing:", err)
  461. }
  462. defer client.Close()
  463. // Synchronous calls
  464. args := &Args{7, 8}
  465. b.ResetTimer()
  466. b.RunParallel(func(pb *testing.PB) {
  467. reply := new(Reply)
  468. for pb.Next() {
  469. err := client.Call("Arith.Add", args, reply)
  470. if err != nil {
  471. b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
  472. }
  473. if reply.C != args.A+args.B {
  474. b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
  475. }
  476. }
  477. })
  478. }
  479. func benchmarkEndToEndAsync(dial func() (*client, error), b *testing.B) {
  480. if b.N == 0 {
  481. return
  482. }
  483. const MaxConcurrentCalls = 100
  484. once.Do(startServer)
  485. client, err := dial()
  486. if err != nil {
  487. b.Fatal("error dialing:", err)
  488. }
  489. defer client.Close()
  490. // Asynchronous calls
  491. args := &Args{7, 8}
  492. procs := 4 * runtime.GOMAXPROCS(-1)
  493. send := int32(b.N)
  494. recv := int32(b.N)
  495. var wg sync.WaitGroup
  496. wg.Add(procs)
  497. gate := make(chan bool, MaxConcurrentCalls)
  498. res := make(chan *Call, MaxConcurrentCalls)
  499. b.ResetTimer()
  500. for p := 0; p < procs; p++ {
  501. go func() {
  502. for atomic.AddInt32(&send, -1) >= 0 {
  503. gate <- true
  504. reply := new(Reply)
  505. client.Go("Arith.Add", args, reply, res)
  506. }
  507. }()
  508. go func() {
  509. for call := range res {
  510. A := call.Args.(*Args).A
  511. B := call.Args.(*Args).B
  512. C := call.Reply.(*Reply).C
  513. if A+B != C {
  514. return
  515. }
  516. <-gate
  517. if atomic.AddInt32(&recv, -1) == 0 {
  518. close(res)
  519. }
  520. }
  521. wg.Done()
  522. }()
  523. }
  524. wg.Wait()
  525. }
  526. func BenchmarkEndToEnd(b *testing.B) {
  527. benchmarkEndToEnd(dialDirect, b)
  528. }
  529. func BenchmarkEndToEndAsync(b *testing.B) {
  530. benchmarkEndToEndAsync(dialDirect, b)
  531. }