123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- package liverpc
- import (
- "context"
- "encoding/binary"
- "encoding/json"
- "fmt"
- "io"
- "net"
- "strconv"
- "strings"
- "time"
- "go-common/library/conf/env"
- "go-common/library/log"
- "go-common/library/net/metadata"
- "go-common/library/net/trace"
- "github.com/gogo/protobuf/proto"
- "github.com/json-iterator/go"
- "github.com/pkg/errors"
- )
- // ClientConn connect represent a real client connection to a rpc server
- type ClientConn struct {
- addr string
- network string
- rwc io.ReadWriteCloser
- Timeout time.Duration
- DialTimeout time.Duration
- callInfo *callInfo
- }
// fullReqMsg is the JSON envelope that Call marshals as the request body.
// Field order matters: encoding/json emits fields in declaration order.
type fullReqMsg struct {
	// Header is the liverpc request header (caller, uid, trace, ...).
	Header *Header `json:"header"`
	// HTTP is optional; when nil it is encoded as JSON null.
	HTTP *HTTP `json:"http"`
	// Body holds the caller's request message (a proto.Message in Call).
	Body interface{} `json:"body"`
}
- // Dial dial a rpc server
- func Dial(ctx context.Context, network, addr string, timeout time.Duration, connTimeout time.Duration) (*ClientConn, error) {
- c := &ClientConn{
- addr: addr,
- network: network,
- Timeout: timeout,
- DialTimeout: connTimeout,
- }
- conn, err := net.DialTimeout(c.network, c.addr, c.DialTimeout)
- if err != nil {
- return nil, err
- }
- c.rwc = conn
- return c, err
- }
- // Close close the caller connection.
- func (c *ClientConn) Close() error {
- if c.rwc != nil {
- return c.rwc.Close()
- }
- return nil
- }
- func (c *ClientConn) writeRequest(ctx context.Context, req *protoReq) (err error) {
- var (
- headerBuf = make([]byte, _headerLen)
- header = req.Header
- body = req.Body
- )
- binary.BigEndian.PutUint32(headerBuf[0:4], header.magic)
- binary.BigEndian.PutUint32(headerBuf[4:8], header.timestamp)
- binary.BigEndian.PutUint32(headerBuf[8:12], header.checkSum)
- binary.BigEndian.PutUint32(headerBuf[12:16], header.version)
- binary.BigEndian.PutUint32(headerBuf[16:20], header.reserved)
- binary.BigEndian.PutUint32(headerBuf[20:24], header.seq)
- binary.BigEndian.PutUint32(headerBuf[24:28], uint32(len(body)))
- copy(headerBuf[28:60], header.cmd)
- if _, err = c.rwc.Write(headerBuf); err != nil {
- err = errors.Wrap(err, "write req header error")
- return
- }
- if log.V(2) {
- log.Info("liverpc body: %s", string(body))
- }
- if _, err = c.rwc.Write(body); err != nil {
- err = errors.Wrap(err, "write req body error")
- return
- }
- return
- }
- func (c *ClientConn) readResponse(ctx context.Context, resp *protoResp) (err error) {
- var (
- headerBuf = make([]byte, _headerLen)
- length int
- )
- if _, err = c.rwc.Read(headerBuf); err != nil {
- err = errors.Wrap(err, "read resp header error")
- return
- }
- resp.Header.magic = binary.BigEndian.Uint32(headerBuf[0:4])
- resp.Header.timestamp = binary.BigEndian.Uint32(headerBuf[4:8])
- resp.Header.checkSum = binary.BigEndian.Uint32(headerBuf[8:12])
- resp.Header.version = binary.BigEndian.Uint32(headerBuf[12:16])
- resp.Header.reserved = binary.BigEndian.Uint32(headerBuf[16:20])
- resp.Header.seq = binary.BigEndian.Uint32(headerBuf[20:24])
- resp.Header.length = binary.BigEndian.Uint32(headerBuf[24:28])
- resp.Header.cmd = headerBuf[28:60]
- resp.Body = make([]byte, resp.Header.length)
- if length, err = io.ReadFull(c.rwc, resp.Body); err != nil {
- err = errors.Wrap(err, "read resp body error")
- return
- }
- if uint32(length) != resp.Header.length {
- err = errors.New("bad resp body data")
- return
- }
- return
- }
- func (c *ClientConn) composeReqPackHeader(reqPack *protoReq, version int, serviceMethod string) {
- reqPack.Header.magic = _magic
- reqPack.Header.checkSum = 0
- reqPack.Header.seq = 1
- reqPack.Header.timestamp = uint32(time.Now().Unix())
- reqPack.Header.reserved = 0
- reqPack.Header.version = uint32(version)
- // command: {message_type}controller.method
- reqPack.Header.cmd = make([]byte, 32)
- reqPack.Header.cmd[0] = _cmdReqType
- // serviceMethod: Room.room_init
- copy(reqPack.Header.cmd[1:], []byte(serviceMethod))
- }
- func (c *ClientConn) setupDeadline(ctx context.Context) error {
- var t time.Duration
- if c.callInfo.Timeout != 0 {
- t = c.callInfo.Timeout
- } else {
- t, _ = ctx.Value(KeyTimeout).(time.Duration)
- }
- if t == 0 {
- t = c.Timeout
- }
- conn := c.rwc.(net.Conn)
- if conn != nil {
- err := conn.SetDeadline(time.Now().Add(t))
- if err != nil {
- conn.Close()
- return err
- }
- }
- return nil
- }
- // CallRaw call the service method, waits for it to complete, and returns reply its error status.
- // this is can be use without protobuf
- // client: {service}
- // serviceMethod: {version}|{controller.method}
- // httpURL: /room/v1/Room/room_init
- // httpURL: /{service}/{version}/{controller}/{method}
- func (c *ClientConn) CallRaw(ctx context.Context, version int, serviceMethod string, in *Args) (out *Reply, err error) {
- var (
- reqPack protoReq
- respPack protoResp
- code = "0"
- now = time.Now()
- uid int64
- )
- defer func() {
- stats.Timing(serviceMethod, int64(time.Since(now)/time.Millisecond))
- if code != "" {
- stats.Incr(serviceMethod, code)
- }
- logging(ctx, version, serviceMethod, c.addr, err, time.Since(now), uid)
- }()
- if err = c.setupDeadline(ctx); err != nil {
- return
- }
- // it is ok for request http field to be nil
- if in.Header == nil {
- if c.callInfo.Header != nil {
- in.Header = c.callInfo.Header
- } else if hdr, _ := ctx.Value(KeyHeader).(*Header); hdr != nil {
- in.Header = hdr
- } else {
- in.Header = createHeader(ctx)
- }
- }
- uid = in.Header.Uid
- if in.HTTP == nil {
- if c.callInfo.HTTP != nil {
- in.HTTP = c.callInfo.HTTP
- }
- }
- if in.Body == nil {
- in.Body = map[string]interface{}{}
- }
- c.composeReqPackHeader(&reqPack, version, serviceMethod)
- var reqBytes []byte
- if reqBytes, err = json.Marshal(in); err != nil {
- err = errors.Wrap(err, "CallRaw json marshal error")
- code = "marshalErr"
- return
- }
- reqPack.Body = reqBytes
- ch := make(chan error, 1)
- go func() {
- ch <- c.sendAndRecv(ctx, &reqPack, &respPack)
- }()
- select {
- case <-ctx.Done():
- err = errors.WithStack(ctx.Err())
- code = "canceled"
- return
- case err = <-ch:
- if err != nil {
- code = "ioErr"
- return
- }
- }
- out = &Reply{}
- if err = json.Unmarshal(respPack.Body, out); err != nil {
- err = errors.Wrap(err, "proto unmarshal error: "+string(respPack.Body))
- code = "unmarshalErr"
- return
- }
- return
- }
- func logging(ctx context.Context, version int, serviceMethod string, addr string, err error, ts time.Duration, uid int64) {
- var (
- path string
- errMsg string
- )
- logFunc := log.Infov
- if err != nil {
- if errors.Cause(err) == context.Canceled {
- logFunc = log.Warnv
- } else {
- logFunc = log.Errorv
- }
- errMsg = fmt.Sprintf("%+v", err)
- }
- path = "/v" + strconv.Itoa(version) + "/" + strings.Replace(serviceMethod, ".", "/", 1)
- logFunc(ctx,
- log.KV("path", path),
- log.KV("error", errMsg),
- log.KV("addr", addr),
- log.KV("ts", float64(ts.Seconds())),
- log.KV("uid", uid),
- log.KV("log", "LIVERPC"),
- )
- }
- func (c *ClientConn) sendAndRecv(ctx context.Context, reqPack *protoReq, respPack *protoResp) (err error) {
- if err = c.writeRequest(ctx, reqPack); err != nil {
- return
- }
- if err = c.readResponse(ctx, respPack); err != nil {
- return
- }
- return
- }
- // Call call the service method, waits for it to complete, and returns its error status.
- // this is used with protobuf generated msg
- // client: {service}
- // serviceMethod: {version}|{controller.method}
- // httpURL: /room/v1/Room/room_init
- // httpURL: /{service}/{version}/{controller}/{method}
- func (c *ClientConn) Call(ctx context.Context, version int, serviceMethod string, in, out proto.Message) (err error) {
- var (
- reqPack protoReq
- respPack protoResp
- code = "0"
- now = time.Now()
- uid int64
- )
- defer func() {
- stats.Timing(serviceMethod, int64(time.Since(now)/time.Millisecond))
- if code != "" {
- stats.Incr(serviceMethod, code)
- }
- logging(ctx, version, serviceMethod, c.addr, err, time.Since(now), uid)
- }()
- if err = c.setupDeadline(ctx); err != nil {
- return
- }
- fullMsg := &fullReqMsg{}
- if c.callInfo.Header != nil {
- fullMsg.Header = c.callInfo.Header
- } else if hdr, _ := ctx.Value(KeyHeader).(*Header); hdr != nil {
- fullMsg.Header = hdr
- } else {
- fullMsg.Header = createHeader(ctx)
- }
- uid = fullMsg.Header.Uid
- if c.callInfo.HTTP != nil {
- fullMsg.HTTP = c.callInfo.HTTP
- }
- fullMsg.Body = in
- // it is ok for request http field to be nil
- c.composeReqPackHeader(&reqPack, version, serviceMethod)
- var reqBody []byte
- if reqBody, err = json.Marshal(fullMsg); err != nil {
- err = errors.Wrap(err, "Call json marshal error")
- code = "marshalErr"
- return
- }
- reqPack.Body = reqBody
- ch := make(chan error, 1)
- go func() {
- ch <- c.sendAndRecv(ctx, &reqPack, &respPack)
- }()
- select {
- case <-ctx.Done():
- err = errors.WithStack(ctx.Err())
- code = "canceled"
- return
- case err = <-ch:
- if err != nil {
- code = "ioErr"
- return
- }
- }
- if err = jsoniter.Unmarshal(respPack.Body, out); err != nil {
- err = errors.Wrap(err, "proto unmarshal error: "+string(respPack.Body))
- code = "unmarshalErr"
- return
- }
- return
- }
- func createHeader(ctx context.Context) *Header {
- header := &Header{}
- header.UserIp = metadata.String(ctx, metadata.RemoteIP)
- header.Caller = strings.Replace(env.AppID, ".", "-", -1)
- if header.Caller == "" {
- header.Caller = "unknown"
- }
- tracer, ok := metadata.Value(ctx, metadata.Trace).(trace.Trace)
- if ok {
- trace.Inject(tracer, nil, header)
- }
- mid, _ := metadata.Value(ctx, "mid").(int64)
- header.Uid = mid
- if color := metadata.String(ctx, metadata.Color); color != "" {
- header.SourceGroup = color
- } else {
- header.SourceGroup = env.Color
- }
- //header.Platform = ctx.Request.FormValue("platform")
- return header
- }
|