server_websocket_v1.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. package server
  2. import (
  3. "crypto/tls"
  4. "io"
  5. "net"
  6. "time"
  7. "go-common/app/service/main/broadcast/libs/bytes"
  8. itime "go-common/app/service/main/broadcast/libs/time"
  9. "go-common/app/service/main/broadcast/libs/websocket"
  10. "go-common/app/service/main/broadcast/model"
  11. "go-common/library/log"
  12. )
  13. // InitWebsocketV1 listen all tcp.bind and start accept connections.
  14. func InitWebsocketV1(s *Server, addrs []string, accept int) (err error) {
  15. var (
  16. bind string
  17. listener *net.TCPListener
  18. addr *net.TCPAddr
  19. )
  20. for _, bind = range addrs {
  21. if addr, err = net.ResolveTCPAddr("tcp", bind); err != nil {
  22. log.Error("net.ResolveTCPAddr(\"tcp\", \"%s\") error(%v)", bind, err)
  23. return
  24. }
  25. if listener, err = net.ListenTCP("tcp", addr); err != nil {
  26. log.Error("net.ListenTCP(\"tcp\", \"%s\") error(%v)", bind, err)
  27. return
  28. }
  29. log.Info("start ws listen: \"%s\"", bind)
  30. // split N core accept
  31. for i := 0; i < accept; i++ {
  32. go acceptWebsocketV1(s, listener)
  33. }
  34. }
  35. return
  36. }
  37. // InitWebsocketWithTLSV1 .
  38. func InitWebsocketWithTLSV1(s *Server, addrs []string, certFile, privateFile string, accept int) (err error) {
  39. var (
  40. bind string
  41. listener net.Listener
  42. cert tls.Certificate
  43. )
  44. cert, err = tls.LoadX509KeyPair(certFile, privateFile)
  45. if err != nil {
  46. log.Error("Error loading certificate. error(%v)", err)
  47. return
  48. }
  49. tlsCfg := &tls.Config{Certificates: []tls.Certificate{cert}}
  50. for _, bind = range addrs {
  51. if listener, err = tls.Listen("tcp", bind, tlsCfg); err != nil {
  52. log.Error("net.ListenTCP(\"tcp\", \"%s\") error(%v)", bind, err)
  53. return
  54. }
  55. log.Info("start wss listen: \"%s\"", bind)
  56. // split N core accept
  57. for i := 0; i < accept; i++ {
  58. go acceptWebsocketWithTLSV1(s, listener)
  59. }
  60. }
  61. return
  62. }
  63. // Accept accepts connections on the listener and serves requests
  64. // for each incoming connection. Accept blocks; the caller typically
  65. // invokes it in a go statement.
  66. func acceptWebsocketV1(s *Server, lis *net.TCPListener) {
  67. var (
  68. conn *net.TCPConn
  69. err error
  70. r int
  71. )
  72. for {
  73. if conn, err = lis.AcceptTCP(); err != nil {
  74. // if listener close then return
  75. log.Error("listener.Accept(\"%s\") error(%v)", lis.Addr().String(), err)
  76. time.Sleep(time.Second)
  77. continue
  78. }
  79. if err = conn.SetKeepAlive(s.c.TCP.Keepalive); err != nil {
  80. log.Error("conn.SetKeepAlive() error(%v)", err)
  81. return
  82. }
  83. if err = conn.SetReadBuffer(s.c.TCP.Rcvbuf); err != nil {
  84. log.Error("conn.SetReadBuffer() error(%v)", err)
  85. return
  86. }
  87. if err = conn.SetWriteBuffer(s.c.TCP.Sndbuf); err != nil {
  88. log.Error("conn.SetWriteBuffer() error(%v)", err)
  89. return
  90. }
  91. go serveWebsocketV1(s, conn, r)
  92. if r++; r == _maxInt {
  93. r = 0
  94. }
  95. }
  96. }
  97. // Accept accepts connections on the listener and serves requests
  98. // for each incoming connection. Accept blocks; the caller typically
  99. // invokes it in a go statement.
  100. func acceptWebsocketWithTLSV1(server *Server, lis net.Listener) {
  101. var (
  102. conn net.Conn
  103. err error
  104. r int
  105. )
  106. for {
  107. if conn, err = lis.Accept(); err != nil {
  108. // if listener close then return
  109. log.Error("listener.Accept(\"%s\") error(%v)", lis.Addr().String(), err)
  110. return
  111. }
  112. go serveWebsocketV1(server, conn, r)
  113. if r++; r == _maxInt {
  114. r = 0
  115. }
  116. }
  117. }
  118. func serveWebsocketV1(server *Server, conn net.Conn, r int) {
  119. var (
  120. // timer
  121. tr = server.round.Timer(r)
  122. rp = server.round.Reader(r)
  123. wp = server.round.Writer(r)
  124. )
  125. server.serveWebsocketV1(conn, rp, wp, tr)
  126. }
  127. // TODO linger close?
  128. func (s *Server) serveWebsocketV1(conn net.Conn, rp, wp *bytes.Pool, tr *itime.Timer) {
  129. var (
  130. err error
  131. roomID string
  132. hb time.Duration // heartbeat
  133. p *model.Proto
  134. b *Bucket
  135. trd *itime.TimerData
  136. rb = rp.Get()
  137. ch = NewChannel(s.c.ProtoSection.CliProto, s.c.ProtoSection.SvrProto)
  138. rr = &ch.Reader
  139. wr = &ch.Writer
  140. ws *websocket.Conn // websocket
  141. req *websocket.Request
  142. rpt *Report
  143. )
  144. // reader
  145. ch.Reader.ResetBuffer(conn, rb.Bytes())
  146. // handshake
  147. trd = tr.Add(time.Duration(s.c.ProtoSection.HandshakeTimeout), func() {
  148. conn.SetDeadline(time.Now().Add(time.Millisecond))
  149. conn.Close()
  150. })
  151. // websocket
  152. if req, err = websocket.ReadRequest(rr); err != nil || req.RequestURI != "/sub" {
  153. conn.Close()
  154. tr.Del(trd)
  155. rp.Put(rb)
  156. if err != io.EOF {
  157. log.Error("http.ReadRequest(rr) error(%v)", err)
  158. }
  159. return
  160. }
  161. // writer
  162. wb := wp.Get()
  163. ch.Writer.ResetBuffer(conn, wb.Bytes())
  164. if ws, err = websocket.Upgrade(conn, rr, wr, req); err != nil {
  165. conn.Close()
  166. tr.Del(trd)
  167. rp.Put(rb)
  168. wp.Put(wb)
  169. if err != io.EOF {
  170. log.Error("websocket.NewServerConn error(%v)", err)
  171. }
  172. return
  173. }
  174. ch.V1 = true
  175. ch.IP, _, _ = net.SplitHostPort(conn.RemoteAddr().String())
  176. // must not setadv, only used in auth
  177. if p, err = ch.CliProto.Set(); err == nil {
  178. if ch.Key, roomID, ch.Mid, hb, rpt, err = s.authWebsocketV1(ws, p, ch.IP); err == nil {
  179. b = s.Bucket(ch.Key)
  180. err = b.Put(roomID, ch)
  181. }
  182. }
  183. if err != nil {
  184. if err != io.EOF && err != websocket.ErrMessageClose {
  185. log.Error("key: %s ip: %s handshake failed error(%v)", ch.Key, conn.RemoteAddr().String(), err)
  186. }
  187. ws.Close()
  188. rp.Put(rb)
  189. wp.Put(wb)
  190. tr.Del(trd)
  191. return
  192. }
  193. trd.Key = ch.Key
  194. tr.Set(trd, hb)
  195. var online int32
  196. if ch.Room != nil {
  197. online = ch.Room.OnlineNum()
  198. }
  199. report(actionConnect, rpt, online)
  200. // hanshake ok start dispatch goroutine
  201. go s.dispatchWebsocketV1(ch.Key, ws, wp, wb, ch)
  202. for {
  203. if p, err = ch.CliProto.Set(); err != nil {
  204. break
  205. }
  206. if err = p.ReadWebsocketV1(ws); err != nil {
  207. break
  208. }
  209. if p.Operation == model.OpHeartbeat {
  210. tr.Set(trd, hb)
  211. p.Operation = model.OpHeartbeatReply
  212. } else {
  213. if err = s.Operate(p, ch, b); err != nil {
  214. break
  215. }
  216. }
  217. ch.CliProto.SetAdv()
  218. ch.Signal()
  219. }
  220. if err != nil && err != io.EOF && err != websocket.ErrMessageClose {
  221. log.Error("key: %s server tcp failed error(%v)", ch.Key, err)
  222. }
  223. b.Del(ch)
  224. tr.Del(trd)
  225. ws.Close()
  226. ch.Close()
  227. rp.Put(rb)
  228. //if err = s.Disconnect(context.Background(), ch.Mid, roomID); err != nil {
  229. // log.Error("key: %s operator do disconnect error(%v)", ch.Key, err)
  230. //}
  231. if ch.Room != nil {
  232. online = ch.Room.OnlineNum()
  233. }
  234. report(actionDisconnect, rpt, online)
  235. }
  // dispatchWebsocketV1 is the writer side of one websocket v1 connection. It
// writes outgoing protos for ch to ws until ch reports model.ProtoFinish or a
// write/flush fails. It blocks; serveWebsocketV1 starts it with the go
// statement once the handshake succeeds.
//
// key is used only in log messages. wb is the connection's write buffer and
// is returned to wp here once writing stops.
//
// ch.Ready() drives the loop:
//   - ProtoFinish: stop;
//   - ProtoReady: drain replies queued in ch.CliProto by the reader loop;
//   - any other proto: a server-pushed message, written as-is.
//
// After a failure, it keeps consuming ch.Ready() until ProtoFinish so that
// the reader's Signal does not block.
func (s *Server) dispatchWebsocketV1(key string, ws *websocket.Conn, wp *bytes.Pool, wb *bytes.Buffer, ch *Channel) {
	var (
		err    error
		finish bool
		online int32
	)
	for {
		var p = ch.Ready()
		switch p {
		case model.ProtoFinish:
			finish = true
			goto failed
		case model.ProtoReady:
			// fetch message from svrbox(client send): drain the CliProto ring
			for {
				if p, err = ch.CliProto.Get(); err != nil {
					// Get failing here presumably just means the ring is empty;
					// clear it so it is not reported as a write error.
					err = nil // must be empty error
					break
				}
				if p.Operation == model.OpHeartbeatReply {
					// Heartbeat replies carry the room's current online count
					// (left unchanged when ch.Room is nil).
					if ch.Room != nil {
						online = ch.Room.OnlineNum()
					}
					if err = p.WriteWebsocketHeartV1(ws, online); err != nil {
						goto failed
					}
				} else {
					if err = p.WriteWebsocketV1(ws); err != nil {
						goto failed
					}
				}
				p.Body = nil // avoid memory leak
				ch.CliProto.GetAdv()
			}
		default:
			// server send
			if err = p.WriteWebsocketV1(ws); err != nil {
				goto failed
			}
		}
		// only hungry flush response
		if err = ws.Flush(); err != nil {
			// This break is outside the switch, so it leaves the outer for
			// loop and falls through to failed.
			break
		}
	}
failed:
	if err != nil && err != io.EOF && err != websocket.ErrMessageClose {
		log.Error("key: %s dispatch tcp error(%v)", key, err)
	}
	ws.Close()
	wp.Put(wb)
	// must ensure all channel message discard, for reader won't blocking Signal
	for !finish {
		finish = (ch.Ready() == model.ProtoFinish)
	}
}
  295. // auth for goim handshake with client, use rsa & aes.
  296. func (s *Server) authWebsocketV1(ws *websocket.Conn, p *model.Proto, ip string) (key, roomID string, userID int64, heartbeat time.Duration, rpt *Report, err error) {
  297. if err = p.ReadWebsocketV1(ws); err != nil {
  298. return
  299. }
  300. if p.Operation != model.OpAuth {
  301. err = ErrOperation
  302. return
  303. }
  304. if userID, roomID, key, rpt, err = s.NoAuth(int16(p.Ver), p.Body, ip); err != nil {
  305. return
  306. }
  307. heartbeat = _clientHeartbeat
  308. p.Body = nil
  309. p.Operation = model.OpAuthReply
  310. if err = p.WriteWebsocketV1(ws); err != nil {
  311. return
  312. }
  313. err = ws.Flush()
  314. return
  315. }