server_websocket.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. package server
import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"strings"
	"time"

	iModel "go-common/app/interface/main/broadcast/model"
	"go-common/app/service/main/broadcast/libs/bytes"
	itime "go-common/app/service/main/broadcast/libs/time"
	"go-common/app/service/main/broadcast/libs/websocket"
	"go-common/app/service/main/broadcast/model"
	"go-common/library/log"
	"go-common/library/net/metadata"
)
  17. // InitWebsocket listen all tcp.bind and start accept connections.
  18. func InitWebsocket(server *Server, addrs []string, accept int) (err error) {
  19. var (
  20. bind string
  21. listener *net.TCPListener
  22. addr *net.TCPAddr
  23. )
  24. for _, bind = range addrs {
  25. if addr, err = net.ResolveTCPAddr("tcp", bind); err != nil {
  26. log.Error("net.ResolveTCPAddr(\"tcp\", \"%s\") error(%v)", bind, err)
  27. return
  28. }
  29. if listener, err = net.ListenTCP("tcp", addr); err != nil {
  30. log.Error("net.ListenTCP(\"tcp\", \"%s\") error(%v)", bind, err)
  31. return
  32. }
  33. log.Info("start ws listen: \"%s\"", bind)
  34. // split N core accept
  35. for i := 0; i < accept; i++ {
  36. go acceptWebsocket(server, listener)
  37. }
  38. }
  39. return
  40. }
  41. // InitWebsocketWithTLS init websocket with tls.
  42. func InitWebsocketWithTLS(server *Server, addrs []string, certFile, privateFile string, accept int) (err error) {
  43. var (
  44. bind string
  45. listener net.Listener
  46. cert tls.Certificate
  47. certs []tls.Certificate
  48. )
  49. certFiles := strings.Split(certFile, ",")
  50. privateFiles := strings.Split(privateFile, ",")
  51. for i := range certFiles {
  52. cert, err = tls.LoadX509KeyPair(certFiles[i], privateFiles[i])
  53. if err != nil {
  54. log.Error("Error loading certificate. error(%v)", err)
  55. return
  56. }
  57. certs = append(certs, cert)
  58. }
  59. tlsCfg := &tls.Config{Certificates: certs}
  60. tlsCfg.BuildNameToCertificate()
  61. for _, bind = range addrs {
  62. if listener, err = tls.Listen("tcp", bind, tlsCfg); err != nil {
  63. log.Error("net.ListenTCP(\"tcp\", \"%s\") error(%v)", bind, err)
  64. return
  65. }
  66. log.Info("start wss listen: \"%s\"", bind)
  67. // split N core accept
  68. for i := 0; i < accept; i++ {
  69. go acceptWebsocketWithTLS(server, listener)
  70. }
  71. }
  72. return
  73. }
  74. // Accept accepts connections on the listener and serves requests
  75. // for each incoming connection. Accept blocks; the caller typically
  76. // invokes it in a go statement.
  77. func acceptWebsocket(server *Server, lis *net.TCPListener) {
  78. var (
  79. conn *net.TCPConn
  80. err error
  81. r int
  82. )
  83. for {
  84. if conn, err = lis.AcceptTCP(); err != nil {
  85. // if listener close then return
  86. log.Error("listener.Accept(\"%s\") error(%v)", lis.Addr().String(), err)
  87. return
  88. }
  89. if err = conn.SetKeepAlive(server.c.TCP.Keepalive); err != nil {
  90. log.Error("conn.SetKeepAlive() error(%v)", err)
  91. return
  92. }
  93. if err = conn.SetReadBuffer(server.c.TCP.Rcvbuf); err != nil {
  94. log.Error("conn.SetReadBuffer() error(%v)", err)
  95. return
  96. }
  97. if err = conn.SetWriteBuffer(server.c.TCP.Sndbuf); err != nil {
  98. log.Error("conn.SetWriteBuffer() error(%v)", err)
  99. return
  100. }
  101. go serveWebsocket(server, conn, r)
  102. if r++; r == _maxInt {
  103. r = 0
  104. }
  105. }
  106. }
  107. // Accept accepts connections on the listener and serves requests
  108. // for each incoming connection. Accept blocks; the caller typically
  109. // invokes it in a go statement.
  110. func acceptWebsocketWithTLS(server *Server, lis net.Listener) {
  111. var (
  112. conn net.Conn
  113. err error
  114. r int
  115. )
  116. for {
  117. if conn, err = lis.Accept(); err != nil {
  118. // if listener close then return
  119. log.Error("listener.Accept(\"%s\") error(%v)", lis.Addr().String(), err)
  120. return
  121. }
  122. go serveWebsocket(server, conn, r)
  123. if r++; r == _maxInt {
  124. r = 0
  125. }
  126. }
  127. }
  128. func serveWebsocket(s *Server, conn net.Conn, r int) {
  129. var (
  130. // timer
  131. tr = s.round.Timer(r)
  132. rp = s.round.Reader(r)
  133. wp = s.round.Writer(r)
  134. )
  135. if s.c.Broadcast.Debug {
  136. // ip addr
  137. lAddr := conn.LocalAddr().String()
  138. rAddr := conn.RemoteAddr().String()
  139. log.Info("start tcp serve \"%s\" with \"%s\"", lAddr, rAddr)
  140. }
  141. s.ServeWebsocket(conn, rp, wp, tr)
  142. }
  143. // ServeWebsocket .
  144. func (s *Server) ServeWebsocket(conn net.Conn, rp, wp *bytes.Pool, tr *itime.Timer) {
  145. var (
  146. err error
  147. accepts []int32
  148. rid string
  149. white bool
  150. p *model.Proto
  151. b *Bucket
  152. trd *itime.TimerData
  153. lastHB = time.Now()
  154. rb = rp.Get()
  155. ch = NewChannel(s.c.ProtoSection.CliProto, s.c.ProtoSection.SvrProto)
  156. rr = &ch.Reader
  157. wr = &ch.Writer
  158. ws *websocket.Conn // websocket
  159. req *websocket.Request
  160. )
  161. // reader
  162. ch.Reader.ResetBuffer(conn, rb.Bytes())
  163. // handshake
  164. step := 0
  165. trd = tr.Add(time.Duration(s.c.ProtoSection.HandshakeTimeout), func() {
  166. conn.SetDeadline(time.Now().Add(time.Millisecond * 100))
  167. conn.Close()
  168. log.Error("key: %s remoteIP: %s step: %d ws handshake timeout", ch.Key, conn.RemoteAddr().String(), step)
  169. })
  170. // websocket
  171. ch.IP, _, _ = net.SplitHostPort(conn.RemoteAddr().String())
  172. step = 1
  173. if req, err = websocket.ReadRequest(rr); err != nil || req.RequestURI != "/sub" {
  174. conn.Close()
  175. tr.Del(trd)
  176. rp.Put(rb)
  177. if err != io.EOF {
  178. log.Error("http.ReadRequest(rr) error(%v)", err)
  179. }
  180. return
  181. }
  182. // writer
  183. wb := wp.Get()
  184. ch.Writer.ResetBuffer(conn, wb.Bytes())
  185. step = 2
  186. if ws, err = websocket.Upgrade(conn, rr, wr, req); err != nil {
  187. conn.Close()
  188. tr.Del(trd)
  189. rp.Put(rb)
  190. wp.Put(wb)
  191. if err != io.EOF {
  192. log.Error("websocket.NewServerConn error(%v)", err)
  193. }
  194. return
  195. }
  196. // must not setadv, only used in auth
  197. step = 3
  198. md := metadata.MD{
  199. metadata.RemoteIP: ch.IP,
  200. }
  201. ctx := metadata.NewContext(context.Background(), md)
  202. ctx, cancel := context.WithCancel(ctx)
  203. defer cancel()
  204. if p, err = ch.CliProto.Set(); err == nil {
  205. if ch.Mid, ch.Key, rid, ch.Platform, accepts, err = s.authWebsocket(ctx, ws, p, req.Header.Get("Cookie")); err == nil {
  206. ch.Watch(accepts...)
  207. b = s.Bucket(ch.Key)
  208. err = b.Put(rid, ch)
  209. if s.c.Broadcast.Debug {
  210. log.Info("websocket connnected key:%s mid:%d proto:%+v", ch.Key, ch.Mid, p)
  211. }
  212. }
  213. }
  214. step = 4
  215. if err != nil {
  216. ws.Close()
  217. rp.Put(rb)
  218. wp.Put(wb)
  219. tr.Del(trd)
  220. if err != io.EOF && err != websocket.ErrMessageClose {
  221. log.Error("key: %s remoteIP: %s step: %d ws handshake failed error(%v)", ch.Key, conn.RemoteAddr().String(), step, err)
  222. }
  223. return
  224. }
  225. trd.Key = ch.Key
  226. tr.Set(trd, _clientHeartbeat)
  227. white = whitelist.Contains(ch.Mid)
  228. if white {
  229. whitelist.Printf("key: %s[%s] auth\n", ch.Key, rid)
  230. }
  231. // hanshake ok start dispatch goroutine
  232. step = 5
  233. reportCh(actionConnect, ch)
  234. go s.dispatchWebsocket(ws, wp, wb, ch)
  235. serverHeartbeat := s.RandServerHearbeat()
  236. for {
  237. if p, err = ch.CliProto.Set(); err != nil {
  238. break
  239. }
  240. if white {
  241. whitelist.Printf("key: %s start read proto\n", ch.Key)
  242. }
  243. if err = p.ReadWebsocket(ws); err != nil {
  244. break
  245. }
  246. if white {
  247. whitelist.Printf("key: %s read proto:%v\n", ch.Key, p)
  248. }
  249. if p.Operation == model.OpHeartbeat {
  250. tr.Set(trd, _clientHeartbeat)
  251. p.Body = nil
  252. p.Operation = model.OpHeartbeatReply
  253. // last server heartbeat
  254. if now := time.Now(); now.Sub(lastHB) > serverHeartbeat {
  255. if err = s.Heartbeat(ctx, ch.Mid, ch.Key); err == nil {
  256. lastHB = now
  257. } else {
  258. err = nil
  259. }
  260. }
  261. if s.c.Broadcast.Debug {
  262. log.Info("websocket heartbeat receive key:%s, mid:%d", ch.Key, ch.Mid)
  263. }
  264. step++
  265. } else {
  266. if err = s.Operate(p, ch, b); err != nil {
  267. break
  268. }
  269. }
  270. if white {
  271. whitelist.Printf("key: %s process proto:%v\n", ch.Key, p)
  272. }
  273. ch.CliProto.SetAdv()
  274. ch.Signal()
  275. if white {
  276. whitelist.Printf("key: %s signal\n", ch.Key)
  277. }
  278. }
  279. if white {
  280. whitelist.Printf("key: %s server tcp error(%v)\n", ch.Key, err)
  281. }
  282. if err != nil && err != io.EOF && err != websocket.ErrMessageClose && !strings.Contains(err.Error(), "closed") {
  283. log.Error("key: %s server ws failed error(%v)", ch.Key, err)
  284. }
  285. b.Del(ch)
  286. tr.Del(trd)
  287. ws.Close()
  288. ch.Close()
  289. rp.Put(rb)
  290. if err = s.Disconnect(ctx, ch.Mid, ch.Key); err != nil {
  291. log.Error("key: %s operator do disconnect error(%v)", ch.Key, err)
  292. }
  293. if white {
  294. whitelist.Printf("key: %s disconnect error(%v)\n", ch.Key, err)
  295. }
  296. reportCh(actionDisconnect, ch)
  297. if s.c.Broadcast.Debug {
  298. log.Info("websocket disconnected key: %s mid:%d", ch.Key, ch.Mid)
  299. }
  300. }
  301. // dispatch accepts connections on the listener and serves requests
  302. // for each incoming connection. dispatch blocks; the caller typically
  303. // invokes it in a go statement.
  304. func (s *Server) dispatchWebsocket(ws *websocket.Conn, wp *bytes.Pool, wb *bytes.Buffer, ch *Channel) {
  305. var (
  306. err error
  307. finish bool
  308. online int32
  309. white = whitelist.Contains(ch.Mid)
  310. )
  311. if s.c.Broadcast.Debug {
  312. log.Info("key: %s start dispatch tcp goroutine", ch.Key)
  313. }
  314. for {
  315. if white {
  316. whitelist.Printf("key: %s wait proto ready\n", ch.Key)
  317. }
  318. var p = ch.Ready()
  319. if white {
  320. whitelist.Printf("key: %s proto ready\n", ch.Key)
  321. }
  322. if s.c.Broadcast.Debug {
  323. log.Info("key:%s dispatch msg:%s", ch.Key, p.Body)
  324. }
  325. switch p {
  326. case model.ProtoFinish:
  327. if white {
  328. whitelist.Printf("key: %s receive proto finish\n", ch.Key)
  329. }
  330. if s.c.Broadcast.Debug {
  331. log.Info("key: %s wakeup exit dispatch goroutine", ch.Key)
  332. }
  333. finish = true
  334. goto failed
  335. case model.ProtoReady:
  336. // fetch message from svrbox(client send)
  337. for {
  338. if p, err = ch.CliProto.Get(); err != nil {
  339. err = nil // must be empty error
  340. break
  341. }
  342. if white {
  343. whitelist.Printf("key: %s start write client proto%v\n", ch.Key, p)
  344. }
  345. if p.Operation == model.OpHeartbeatReply {
  346. if ch.Room != nil {
  347. online = ch.Room.OnlineNum()
  348. b := map[string]interface{}{"room": map[string]interface{}{"online": online, "room_id": ch.Room.ID}}
  349. p.Body = iModel.Message(b, nil)
  350. }
  351. if err = p.WriteWebsocketHeart(ws); err != nil {
  352. goto failed
  353. }
  354. } else {
  355. if err = p.WriteWebsocket(ws); err != nil {
  356. goto failed
  357. }
  358. }
  359. if white {
  360. whitelist.Printf("key: %s write client proto%v\n", ch.Key, p)
  361. }
  362. p.Body = nil // avoid memory leak
  363. ch.CliProto.GetAdv()
  364. }
  365. default:
  366. if white {
  367. whitelist.Printf("key: %s start write server proto%v\n", ch.Key, p)
  368. }
  369. if err = p.WriteWebsocket(ws); err != nil {
  370. goto failed
  371. }
  372. if white {
  373. whitelist.Printf("key: %s write server proto%v\n", ch.Key, p)
  374. }
  375. if s.c.Broadcast.Debug {
  376. log.Info("websocket sent a message key:%s mid:%d proto:%+v", ch.Key, ch.Mid, p)
  377. }
  378. }
  379. if white {
  380. whitelist.Printf("key: %s start flush \n", ch.Key)
  381. }
  382. // only hungry flush response
  383. if err = ws.Flush(); err != nil {
  384. break
  385. }
  386. if white {
  387. whitelist.Printf("key: %s flush\n", ch.Key)
  388. }
  389. }
  390. failed:
  391. if white {
  392. whitelist.Printf("key: %s dispatch tcp error(%v)\n", ch.Key, err)
  393. }
  394. if err != nil && err != io.EOF && err != websocket.ErrMessageClose {
  395. log.Error("key: %s dispatch ws error(%v)", ch.Key, err)
  396. }
  397. ws.Close()
  398. wp.Put(wb)
  399. // must ensure all channel message discard, for reader won't blocking Signal
  400. for !finish {
  401. finish = (ch.Ready() == model.ProtoFinish)
  402. }
  403. if s.c.Broadcast.Debug {
  404. log.Info("key: %s dispatch goroutine exit", ch.Key)
  405. }
  406. }
  407. // auth for goim handshake with client, use rsa & aes.
  408. func (s *Server) authWebsocket(ctx context.Context, ws *websocket.Conn, p *model.Proto, cookie string) (mid int64, key string, rid string, platform string, accepts []int32, err error) {
  409. for {
  410. if err = p.ReadWebsocket(ws); err != nil {
  411. return
  412. }
  413. if p.Operation == model.OpAuth {
  414. break
  415. } else {
  416. log.Error("ws request operation(%d) not auth", p.Operation)
  417. }
  418. }
  419. if mid, key, rid, platform, accepts, err = s.Connect(ctx, p, cookie); err != nil {
  420. return
  421. }
  422. p.Body = []byte(`{"code":0,"message":"ok"}`)
  423. p.Operation = model.OpAuthReply
  424. if err = p.WriteWebsocket(ws); err != nil {
  425. return
  426. }
  427. err = ws.Flush()
  428. return
  429. }