client_conn.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. package liverpc
  2. import (
  3. "context"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net"
  9. "strconv"
  10. "strings"
  11. "time"
  12. "go-common/library/conf/env"
  13. "go-common/library/log"
  14. "go-common/library/net/metadata"
  15. "go-common/library/net/trace"
  16. "github.com/gogo/protobuf/proto"
  17. "github.com/json-iterator/go"
  18. "github.com/pkg/errors"
  19. )
  // ClientConn represents one client connection to a liverpc server.
//
// It holds the dialed connection together with the timeouts and per-call
// options that Call and CallRaw consult.
type ClientConn struct {
	// addr and network are the dial target, as passed to Dial.
	addr    string
	network string
	// rwc is the underlying connection; Dial stores the dialed net.Conn here.
	rwc io.ReadWriteCloser
	// Timeout is the fallback I/O deadline for a call, used by setupDeadline
	// when neither callInfo nor the context provides a timeout.
	Timeout time.Duration
	// DialTimeout is the timeout Dial uses when establishing the connection.
	DialTimeout time.Duration
	// callInfo carries per-call options (Header, HTTP, Timeout) read by
	// setupDeadline, Call and CallRaw.
	callInfo *callInfo
}
  // fullReqMsg is the JSON envelope Call marshals as the request body: the
// liverpc Header, optional HTTP metadata (a nil pointer encodes as null),
// and the request payload.
type fullReqMsg struct {
	Header *Header     `json:"header"`
	HTTP   *HTTP       `json:"http"`
	Body   interface{} `json:"body"`
}
  34. // Dial dial a rpc server
  35. func Dial(ctx context.Context, network, addr string, timeout time.Duration, connTimeout time.Duration) (*ClientConn, error) {
  36. c := &ClientConn{
  37. addr: addr,
  38. network: network,
  39. Timeout: timeout,
  40. DialTimeout: connTimeout,
  41. }
  42. conn, err := net.DialTimeout(c.network, c.addr, c.DialTimeout)
  43. if err != nil {
  44. return nil, err
  45. }
  46. c.rwc = conn
  47. return c, err
  48. }
  49. // Close close the caller connection.
  50. func (c *ClientConn) Close() error {
  51. if c.rwc != nil {
  52. return c.rwc.Close()
  53. }
  54. return nil
  55. }
  56. func (c *ClientConn) writeRequest(ctx context.Context, req *protoReq) (err error) {
  57. var (
  58. headerBuf = make([]byte, _headerLen)
  59. header = req.Header
  60. body = req.Body
  61. )
  62. binary.BigEndian.PutUint32(headerBuf[0:4], header.magic)
  63. binary.BigEndian.PutUint32(headerBuf[4:8], header.timestamp)
  64. binary.BigEndian.PutUint32(headerBuf[8:12], header.checkSum)
  65. binary.BigEndian.PutUint32(headerBuf[12:16], header.version)
  66. binary.BigEndian.PutUint32(headerBuf[16:20], header.reserved)
  67. binary.BigEndian.PutUint32(headerBuf[20:24], header.seq)
  68. binary.BigEndian.PutUint32(headerBuf[24:28], uint32(len(body)))
  69. copy(headerBuf[28:60], header.cmd)
  70. if _, err = c.rwc.Write(headerBuf); err != nil {
  71. err = errors.Wrap(err, "write req header error")
  72. return
  73. }
  74. if log.V(2) {
  75. log.Info("liverpc body: %s", string(body))
  76. }
  77. if _, err = c.rwc.Write(body); err != nil {
  78. err = errors.Wrap(err, "write req body error")
  79. return
  80. }
  81. return
  82. }
  83. func (c *ClientConn) readResponse(ctx context.Context, resp *protoResp) (err error) {
  84. var (
  85. headerBuf = make([]byte, _headerLen)
  86. length int
  87. )
  88. if _, err = c.rwc.Read(headerBuf); err != nil {
  89. err = errors.Wrap(err, "read resp header error")
  90. return
  91. }
  92. resp.Header.magic = binary.BigEndian.Uint32(headerBuf[0:4])
  93. resp.Header.timestamp = binary.BigEndian.Uint32(headerBuf[4:8])
  94. resp.Header.checkSum = binary.BigEndian.Uint32(headerBuf[8:12])
  95. resp.Header.version = binary.BigEndian.Uint32(headerBuf[12:16])
  96. resp.Header.reserved = binary.BigEndian.Uint32(headerBuf[16:20])
  97. resp.Header.seq = binary.BigEndian.Uint32(headerBuf[20:24])
  98. resp.Header.length = binary.BigEndian.Uint32(headerBuf[24:28])
  99. resp.Header.cmd = headerBuf[28:60]
  100. resp.Body = make([]byte, resp.Header.length)
  101. if length, err = io.ReadFull(c.rwc, resp.Body); err != nil {
  102. err = errors.Wrap(err, "read resp body error")
  103. return
  104. }
  105. if uint32(length) != resp.Header.length {
  106. err = errors.New("bad resp body data")
  107. return
  108. }
  109. return
  110. }
  111. func (c *ClientConn) composeReqPackHeader(reqPack *protoReq, version int, serviceMethod string) {
  112. reqPack.Header.magic = _magic
  113. reqPack.Header.checkSum = 0
  114. reqPack.Header.seq = 1
  115. reqPack.Header.timestamp = uint32(time.Now().Unix())
  116. reqPack.Header.reserved = 0
  117. reqPack.Header.version = uint32(version)
  118. // command: {message_type}controller.method
  119. reqPack.Header.cmd = make([]byte, 32)
  120. reqPack.Header.cmd[0] = _cmdReqType
  121. // serviceMethod: Room.room_init
  122. copy(reqPack.Header.cmd[1:], []byte(serviceMethod))
  123. }
  124. func (c *ClientConn) setupDeadline(ctx context.Context) error {
  125. var t time.Duration
  126. if c.callInfo.Timeout != 0 {
  127. t = c.callInfo.Timeout
  128. } else {
  129. t, _ = ctx.Value(KeyTimeout).(time.Duration)
  130. }
  131. if t == 0 {
  132. t = c.Timeout
  133. }
  134. conn := c.rwc.(net.Conn)
  135. if conn != nil {
  136. err := conn.SetDeadline(time.Now().Add(t))
  137. if err != nil {
  138. conn.Close()
  139. return err
  140. }
  141. }
  142. return nil
  143. }
  144. // CallRaw call the service method, waits for it to complete, and returns reply its error status.
  145. // this is can be use without protobuf
  146. // client: {service}
  147. // serviceMethod: {version}|{controller.method}
  148. // httpURL: /room/v1/Room/room_init
  149. // httpURL: /{service}/{version}/{controller}/{method}
  150. func (c *ClientConn) CallRaw(ctx context.Context, version int, serviceMethod string, in *Args) (out *Reply, err error) {
  151. var (
  152. reqPack protoReq
  153. respPack protoResp
  154. code = "0"
  155. now = time.Now()
  156. uid int64
  157. )
  158. defer func() {
  159. stats.Timing(serviceMethod, int64(time.Since(now)/time.Millisecond))
  160. if code != "" {
  161. stats.Incr(serviceMethod, code)
  162. }
  163. logging(ctx, version, serviceMethod, c.addr, err, time.Since(now), uid)
  164. }()
  165. if err = c.setupDeadline(ctx); err != nil {
  166. return
  167. }
  168. // it is ok for request http field to be nil
  169. if in.Header == nil {
  170. if c.callInfo.Header != nil {
  171. in.Header = c.callInfo.Header
  172. } else if hdr, _ := ctx.Value(KeyHeader).(*Header); hdr != nil {
  173. in.Header = hdr
  174. } else {
  175. in.Header = createHeader(ctx)
  176. }
  177. }
  178. uid = in.Header.Uid
  179. if in.HTTP == nil {
  180. if c.callInfo.HTTP != nil {
  181. in.HTTP = c.callInfo.HTTP
  182. }
  183. }
  184. if in.Body == nil {
  185. in.Body = map[string]interface{}{}
  186. }
  187. c.composeReqPackHeader(&reqPack, version, serviceMethod)
  188. var reqBytes []byte
  189. if reqBytes, err = json.Marshal(in); err != nil {
  190. err = errors.Wrap(err, "CallRaw json marshal error")
  191. code = "marshalErr"
  192. return
  193. }
  194. reqPack.Body = reqBytes
  195. ch := make(chan error, 1)
  196. go func() {
  197. ch <- c.sendAndRecv(ctx, &reqPack, &respPack)
  198. }()
  199. select {
  200. case <-ctx.Done():
  201. err = errors.WithStack(ctx.Err())
  202. code = "canceled"
  203. return
  204. case err = <-ch:
  205. if err != nil {
  206. code = "ioErr"
  207. return
  208. }
  209. }
  210. out = &Reply{}
  211. if err = json.Unmarshal(respPack.Body, out); err != nil {
  212. err = errors.Wrap(err, "proto unmarshal error: "+string(respPack.Body))
  213. code = "unmarshalErr"
  214. return
  215. }
  216. return
  217. }
  218. func logging(ctx context.Context, version int, serviceMethod string, addr string, err error, ts time.Duration, uid int64) {
  219. var (
  220. path string
  221. errMsg string
  222. )
  223. logFunc := log.Infov
  224. if err != nil {
  225. if errors.Cause(err) == context.Canceled {
  226. logFunc = log.Warnv
  227. } else {
  228. logFunc = log.Errorv
  229. }
  230. errMsg = fmt.Sprintf("%+v", err)
  231. }
  232. path = "/v" + strconv.Itoa(version) + "/" + strings.Replace(serviceMethod, ".", "/", 1)
  233. logFunc(ctx,
  234. log.KV("path", path),
  235. log.KV("error", errMsg),
  236. log.KV("addr", addr),
  237. log.KV("ts", float64(ts.Seconds())),
  238. log.KV("uid", uid),
  239. log.KV("log", "LIVERPC"),
  240. )
  241. }
  242. func (c *ClientConn) sendAndRecv(ctx context.Context, reqPack *protoReq, respPack *protoResp) (err error) {
  243. if err = c.writeRequest(ctx, reqPack); err != nil {
  244. return
  245. }
  246. if err = c.readResponse(ctx, respPack); err != nil {
  247. return
  248. }
  249. return
  250. }
  251. // Call call the service method, waits for it to complete, and returns its error status.
  252. // this is used with protobuf generated msg
  253. // client: {service}
  254. // serviceMethod: {version}|{controller.method}
  255. // httpURL: /room/v1/Room/room_init
  256. // httpURL: /{service}/{version}/{controller}/{method}
  257. func (c *ClientConn) Call(ctx context.Context, version int, serviceMethod string, in, out proto.Message) (err error) {
  258. var (
  259. reqPack protoReq
  260. respPack protoResp
  261. code = "0"
  262. now = time.Now()
  263. uid int64
  264. )
  265. defer func() {
  266. stats.Timing(serviceMethod, int64(time.Since(now)/time.Millisecond))
  267. if code != "" {
  268. stats.Incr(serviceMethod, code)
  269. }
  270. logging(ctx, version, serviceMethod, c.addr, err, time.Since(now), uid)
  271. }()
  272. if err = c.setupDeadline(ctx); err != nil {
  273. return
  274. }
  275. fullMsg := &fullReqMsg{}
  276. if c.callInfo.Header != nil {
  277. fullMsg.Header = c.callInfo.Header
  278. } else if hdr, _ := ctx.Value(KeyHeader).(*Header); hdr != nil {
  279. fullMsg.Header = hdr
  280. } else {
  281. fullMsg.Header = createHeader(ctx)
  282. }
  283. uid = fullMsg.Header.Uid
  284. if c.callInfo.HTTP != nil {
  285. fullMsg.HTTP = c.callInfo.HTTP
  286. }
  287. fullMsg.Body = in
  288. // it is ok for request http field to be nil
  289. c.composeReqPackHeader(&reqPack, version, serviceMethod)
  290. var reqBody []byte
  291. if reqBody, err = json.Marshal(fullMsg); err != nil {
  292. err = errors.Wrap(err, "Call json marshal error")
  293. code = "marshalErr"
  294. return
  295. }
  296. reqPack.Body = reqBody
  297. ch := make(chan error, 1)
  298. go func() {
  299. ch <- c.sendAndRecv(ctx, &reqPack, &respPack)
  300. }()
  301. select {
  302. case <-ctx.Done():
  303. err = errors.WithStack(ctx.Err())
  304. code = "canceled"
  305. return
  306. case err = <-ch:
  307. if err != nil {
  308. code = "ioErr"
  309. return
  310. }
  311. }
  312. if err = jsoniter.Unmarshal(respPack.Body, out); err != nil {
  313. err = errors.Wrap(err, "proto unmarshal error: "+string(respPack.Body))
  314. code = "unmarshalErr"
  315. return
  316. }
  317. return
  318. }
  319. func createHeader(ctx context.Context) *Header {
  320. header := &Header{}
  321. header.UserIp = metadata.String(ctx, metadata.RemoteIP)
  322. header.Caller = strings.Replace(env.AppID, ".", "-", -1)
  323. if header.Caller == "" {
  324. header.Caller = "unknown"
  325. }
  326. tracer, ok := metadata.Value(ctx, metadata.Trace).(trace.Trace)
  327. if ok {
  328. trace.Inject(tracer, nil, header)
  329. }
  330. mid, _ := metadata.Value(ctx, "mid").(int64)
  331. header.Uid = mid
  332. if color := metadata.String(ctx, metadata.Color); color != "" {
  333. header.SourceGroup = color
  334. } else {
  335. header.SourceGroup = env.Color
  336. }
  337. //header.Platform = ctx.Request.FormValue("platform")
  338. return header
  339. }