client.go 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package liverpc
  2. import (
  3. "context"
  4. "fmt"
  5. "net/url"
  6. "sync/atomic"
  7. "time"
  8. "go-common/library/conf/env"
  9. "go-common/library/log"
  10. "go-common/library/naming"
  11. "go-common/library/naming/discovery"
  12. "go-common/library/net/metadata"
  13. "go-common/library/net/trace"
  14. "go-common/library/stat"
  15. xtime "go-common/library/time"
  16. "github.com/golang/protobuf/proto"
  17. "github.com/pkg/errors"
  18. )
  19. // Key is ContextKey
  20. type Key int
  21. const (
  22. _ Key = iota
  23. // KeyHeader use this in context to pass rpc header field
  24. // Depreated 请使用HeaderOption来传递Header
  25. KeyHeader
  26. // KeyTimeout deprecated
  27. // Depreated 请使用HTTPOption来传递HTTP
  28. KeyTimeout
  29. )
  30. const (
  31. _scheme = "liverpc"
  32. _dialRetries = 3
  33. )
  34. // Get Implement tracer carrier interface
  35. func (m *Header) Get(key string) string {
  36. if key == trace.KeyTraceID {
  37. return m.TraceId
  38. }
  39. return ""
  40. }
  41. // Set Implement tracer carrier interface
  42. func (m *Header) Set(key string, val string) {
  43. if key == trace.KeyTraceID {
  44. m.TraceId = val
  45. }
  46. }
  47. var (
  48. // ErrNoClient no RPC client.
  49. errNoClient = errors.New("no rpc client")
  50. errGroupInvalid = errors.New("invalid group")
  51. stats = stat.RPCClient
  52. )
  53. // GroupAddrs a map struct storing addrs vary groups
  54. type GroupAddrs map[string][]string
  55. // ClientConfig client config.
  56. type ClientConfig struct {
  57. AppID string
  58. Group string
  59. Timeout xtime.Duration
  60. ConnTimeout xtime.Duration
  61. Addr string // if addr is provided, it will use add, else, use discovery
  62. }
  63. // Client is a RPC client.
  64. type Client struct {
  65. conf *ClientConfig
  66. dis naming.Resolver
  67. addrs atomic.Value // GroupAddrs
  68. addrsIdx int64
  69. }
  70. // NewClient new a RPC client with discovery.
  71. func NewClient(c *ClientConfig) *Client {
  72. if c.Timeout <= 0 {
  73. c.Timeout = xtime.Duration(time.Second)
  74. }
  75. if c.ConnTimeout <= 0 {
  76. c.ConnTimeout = xtime.Duration(time.Second)
  77. }
  78. cli := &Client{
  79. conf: c,
  80. }
  81. if c.Addr != "" {
  82. groupAddrs := make(GroupAddrs)
  83. groupAddrs[""] = []string{c.Addr}
  84. cli.addrs.Store(groupAddrs)
  85. return cli
  86. }
  87. cli.dis = discovery.Build(c.AppID)
  88. // discovery watch & fetch nodes
  89. event := cli.dis.Watch()
  90. select {
  91. case _, ok := <-event:
  92. if !ok {
  93. panic("刚启动就从discovery拉到了关闭的event")
  94. }
  95. cli.disFetch()
  96. fmt.Printf("开始创建:%s 的liverpc client,等待从discovery拉取节点:%s\n", c.AppID, time.Now().Format("2006-01-02 15:04:05"))
  97. case <-time.After(10 * time.Second):
  98. fmt.Printf("失败创建:%s 的liverpc client,竟然从discovery拉取节点超时了:%s\n", c.AppID, time.Now().Format("2006-01-02 15:04:05"))
  99. }
  100. go cli.disproc(event)
  101. return cli
  102. }
  103. func (c *Client) disproc(event <-chan struct{}) {
  104. for {
  105. _, ok := <-event
  106. if !ok {
  107. return
  108. }
  109. c.disFetch()
  110. }
  111. }
  112. func (c *Client) disFetch() {
  113. ins, ok := c.dis.Fetch(context.Background())
  114. if !ok {
  115. return
  116. }
  117. insZone, ok := ins[env.Zone]
  118. if !ok {
  119. return
  120. }
  121. addrs := make(GroupAddrs)
  122. for _, svr := range insZone {
  123. group, ok := svr.Metadata["color"]
  124. if !ok {
  125. group = ""
  126. }
  127. for _, addr := range svr.Addrs {
  128. u, err := url.Parse(addr)
  129. if err == nil && u.Scheme == _scheme {
  130. addrs[group] = append(addrs[group], u.Host)
  131. }
  132. }
  133. }
  134. if len(addrs) > 0 {
  135. c.addrs.Store(addrs)
  136. }
  137. }
  138. // pickConn pick conn by addrs
  139. func (c *Client) pickConn(ctx context.Context, addrs []string, dialTimeout time.Duration) (*ClientConn, error) {
  140. var (
  141. lastErr error
  142. )
  143. if len(addrs) == 0 {
  144. lastErr = errors.New("addrs empty")
  145. } else {
  146. for i := 0; i < _dialRetries; i++ {
  147. idx := atomic.AddInt64(&c.addrsIdx, 1)
  148. addr := addrs[int(idx)%len(addrs)]
  149. if dialTimeout == 0 {
  150. dialTimeout = time.Duration(c.conf.ConnTimeout)
  151. }
  152. cc, err := Dial(ctx, "tcp", addr, time.Duration(c.conf.Timeout), dialTimeout)
  153. if err != nil {
  154. lastErr = errors.Wrapf(err, "Dial %s error", addr)
  155. continue
  156. }
  157. return cc, nil
  158. }
  159. }
  160. if lastErr != nil {
  161. return nil, errors.WithMessage(errNoClient, lastErr.Error())
  162. }
  163. return nil, errors.WithStack(errNoClient)
  164. }
  165. // fetchAddrs fetch addrs by different strategies
  166. // source_group first, come from request header if exists, currently only CallRaw supports source_group
  167. // then env group, come from os.env
  168. // since no invalid group found, return error
  169. func (c *Client) fetchAddrs(ctx context.Context, request interface{}) (addrs []string, err error) {
  170. var (
  171. args *Args
  172. groupAddrs GroupAddrs
  173. ok bool
  174. sourceGroup string
  175. groups []string
  176. )
  177. defer func() {
  178. if err != nil {
  179. err = errors.WithMessage(errGroupInvalid, err.Error())
  180. }
  181. }()
  182. // try parse request header and fetch source group
  183. if args, ok = request.(*Args); ok && args.Header != nil {
  184. sourceGroup = args.Header.SourceGroup
  185. if sourceGroup != "" {
  186. groups = append(groups, sourceGroup)
  187. }
  188. }
  189. metaColor := metadata.String(ctx, metadata.Color)
  190. if metaColor != "" && metaColor != sourceGroup {
  191. groups = append(groups, metaColor)
  192. }
  193. if env.Color != "" && env.Color != metaColor {
  194. groups = append(groups, env.Color)
  195. }
  196. groups = append(groups, "")
  197. if groupAddrs, ok = c.addrs.Load().(GroupAddrs); !ok {
  198. err = errors.New("addrs load error")
  199. return
  200. }
  201. if len(groupAddrs) == 0 {
  202. err = errors.New("group addrs empty")
  203. return
  204. }
  205. for _, group := range groups {
  206. if addrs, ok = groupAddrs[group]; ok {
  207. break
  208. }
  209. }
  210. if len(addrs) == 0 {
  211. err = errors.Errorf("addrs empty source(%s), metadata(%s), env(%s), default empty, allAddrs(%+v)",
  212. sourceGroup, metaColor, env.Color, groupAddrs)
  213. return
  214. }
  215. return
  216. }
  217. // Call call the service method, waits for it to complete, and returns its error status.
  218. // client: {service}
  219. // serviceMethod: {version}|{controller.method}
  220. // httpURL: /room/v1/Room/room_init
  221. // httpURL: /{service}/{version}/{controller}/{method}
  222. func (c *Client) Call(ctx context.Context, version int, serviceMethod string, in proto.Message, out proto.Message, opts ...CallOption) (err error) {
  223. var (
  224. cc *ClientConn
  225. addrs []string
  226. )
  227. isPickErr := true
  228. defer func() {
  229. if cc != nil {
  230. cc.Close()
  231. }
  232. if err != nil && isPickErr {
  233. log.Error("liverpc Call pick connection error, version %d, method: %s, error: %+v", version, serviceMethod, err)
  234. }
  235. }() // for now it is non-persistent connection
  236. var cInfo = &callInfo{}
  237. for _, o := range opts {
  238. o.before(cInfo)
  239. }
  240. addrs, err = c.fetchAddrs(ctx, in)
  241. if err != nil {
  242. return
  243. }
  244. cc, err = c.pickConn(ctx, addrs, cInfo.DialTimeout)
  245. if err != nil {
  246. return
  247. }
  248. isPickErr = false
  249. cc.callInfo = cInfo
  250. err = cc.Call(ctx, version, serviceMethod, in, out)
  251. if err != nil {
  252. return
  253. }
  254. for _, o := range opts {
  255. o.after(cc.callInfo)
  256. }
  257. return
  258. }
  259. // CallRaw call the service method, waits for it to complete, and returns reply its error status.
  260. // this is can be use without protobuf
  261. // client: {service}
  262. // serviceMethod: {version}|{controller.method}
  263. // httpURL: /room/v1/Room/room_init
  264. // httpURL: /{service}/{version}/{controller}/{method}
  265. func (c *Client) CallRaw(ctx context.Context, version int, serviceMethod string, in *Args, opts ...CallOption) (out *Reply, err error) {
  266. var (
  267. cc *ClientConn
  268. addrs []string
  269. )
  270. isPickErr := true
  271. defer func() {
  272. if cc != nil {
  273. cc.Close()
  274. }
  275. if err != nil && isPickErr {
  276. log.Error("liverpc CallRaw pick connection error, version %d, method: %s, error: %+v", version, serviceMethod, err)
  277. }
  278. }() // for now it is non-persistent connection
  279. var cInfo = &callInfo{}
  280. for _, o := range opts {
  281. o.before(cInfo)
  282. }
  283. addrs, err = c.fetchAddrs(ctx, in)
  284. if err != nil {
  285. return
  286. }
  287. cc, err = c.pickConn(ctx, addrs, cInfo.DialTimeout)
  288. if err != nil {
  289. return
  290. }
  291. isPickErr = false
  292. cc.callInfo = cInfo
  293. out, err = cc.CallRaw(ctx, version, serviceMethod, in)
  294. if err != nil {
  295. return
  296. }
  297. for _, o := range opts {
  298. o.after(cc.callInfo)
  299. }
  300. return
  301. }
  302. //Close handle client exit
  303. func (c *Client) Close() {
  304. if c.dis != nil {
  305. c.dis.Close()
  306. }
  307. }