conn.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package websocket
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "go-common/app/service/main/broadcast/libs/bufio"
  8. )
  9. const (
  10. // Frame header byte 0 bits from Section 5.2 of RFC 6455
  11. finBit = 1 << 7
  12. rsv1Bit = 1 << 6
  13. rsv2Bit = 1 << 5
  14. rsv3Bit = 1 << 4
  15. opBit = 0x0f
  16. // Frame header byte 1 bits from Section 5.2 of RFC 6455
  17. maskBit = 1 << 7
  18. lenBit = 0x7f
  19. continuationFrame = 0
  20. continuationFrameMaxRead = 100
  21. )
  22. // The message types are defined in RFC 6455, section 11.8.
  23. const (
  24. // TextMessage denotes a text data message. The text message payload is
  25. // interpreted as UTF-8 encoded text data.
  26. TextMessage = 1
  27. // BinaryMessage denotes a binary data message.
  28. BinaryMessage = 2
  29. // CloseMessage denotes a close control message. The optional message
  30. // payload contains a numeric code and text. Use the FormatCloseMessage
  31. // function to format a close message payload.
  32. CloseMessage = 8
  33. // PingMessage denotes a ping control message. The optional message payload
  34. // is UTF-8 encoded text.
  35. PingMessage = 9
  36. // PongMessage denotes a ping control message. The optional message payload
  37. // is UTF-8 encoded text.
  38. PongMessage = 10
  39. )
  40. var (
  41. // ErrMessageClose close control message
  42. ErrMessageClose = errors.New("close control message")
  43. // ErrMessageMaxRead continuation frrame max read
  44. ErrMessageMaxRead = errors.New("continuation frame max read")
  45. )
// Conn represents a WebSocket connection.
//
// Frames are read through r (see readFrame) and written through w (see
// WriteHeader/WriteBody); rwc is the transport closed by Close.
// NOTE(review): r and w presumably buffer rwc, but newConn receives all
// three separately — confirm at the construction site.
type Conn struct {
	rwc io.ReadWriteCloser // transport; closed by Close
	r   *bufio.Reader      // source of incoming frames
	w   *bufio.Writer      // destination of outgoing frames; see Flush
}
  52. // new connection
  53. func newConn(rwc io.ReadWriteCloser, r *bufio.Reader, w *bufio.Writer) *Conn {
  54. return &Conn{rwc: rwc, r: r, w: w}
  55. }
  56. // WriteMessage write a message by type.
  57. func (c *Conn) WriteMessage(msgType int, msg []byte) (err error) {
  58. if err = c.WriteHeader(msgType, len(msg)); err != nil {
  59. return
  60. }
  61. err = c.WriteBody(msg)
  62. return
  63. }
  64. // WriteHeader write header frame.
  65. func (c *Conn) WriteHeader(msgType int, length int) (err error) {
  66. var h []byte
  67. if h, err = c.w.Peek(2); err != nil {
  68. return
  69. }
  70. // 1.First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
  71. h[0] = 0
  72. h[0] |= finBit | byte(msgType)
  73. // 2.Second byte. Mask/Payload len(7bits)
  74. h[1] = 0
  75. switch {
  76. case length <= 125:
  77. // 7 bits
  78. h[1] |= byte(length)
  79. case length < 65536:
  80. // 16 bits
  81. h[1] |= 126
  82. if h, err = c.w.Peek(2); err != nil {
  83. return
  84. }
  85. binary.BigEndian.PutUint16(h, uint16(length))
  86. default:
  87. // 64 bits
  88. h[1] |= 127
  89. if h, err = c.w.Peek(8); err != nil {
  90. return
  91. }
  92. binary.BigEndian.PutUint64(h, uint64(length))
  93. }
  94. return
  95. }
  96. // WriteBody write a message body.
  97. func (c *Conn) WriteBody(b []byte) (err error) {
  98. if len(b) > 0 {
  99. _, err = c.w.Write(b)
  100. }
  101. return
  102. }
  103. // Peek write peek.
  104. func (c *Conn) Peek(n int) ([]byte, error) {
  105. return c.w.Peek(n)
  106. }
  107. // Flush flush writer buffer
  108. func (c *Conn) Flush() error {
  109. return c.w.Flush()
  110. }
  111. // ReadMessage read a message.
  112. func (c *Conn) ReadMessage() (op int, payload []byte, err error) {
  113. var (
  114. fin bool
  115. finOp, n int
  116. partPayload []byte
  117. )
  118. for {
  119. // read frame
  120. if fin, op, partPayload, err = c.readFrame(); err != nil {
  121. return
  122. }
  123. switch op {
  124. case BinaryMessage, TextMessage, continuationFrame:
  125. if fin && len(payload) == 0 {
  126. return op, partPayload, nil
  127. }
  128. // continuation frame
  129. payload = append(payload, partPayload...)
  130. if op != continuationFrame {
  131. finOp = op
  132. }
  133. // final frame
  134. if fin {
  135. op = finOp
  136. return
  137. }
  138. case PingMessage:
  139. // handler ping
  140. if err = c.WriteMessage(PongMessage, partPayload); err != nil {
  141. return
  142. }
  143. case PongMessage:
  144. // handler pong
  145. case CloseMessage:
  146. // handler close
  147. err = ErrMessageClose
  148. return
  149. default:
  150. err = fmt.Errorf("unknown control message, fin=%t, op=%d", fin, op)
  151. return
  152. }
  153. if n > continuationFrameMaxRead {
  154. err = ErrMessageMaxRead
  155. return
  156. }
  157. n++
  158. }
  159. }
  160. func (c *Conn) readFrame() (fin bool, op int, payload []byte, err error) {
  161. var (
  162. b byte
  163. p []byte
  164. mask bool
  165. maskKey []byte
  166. payloadLen int64
  167. )
  168. // 1.First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
  169. b, err = c.r.ReadByte()
  170. if err != nil {
  171. return
  172. }
  173. // final frame
  174. fin = (b & finBit) != 0
  175. // rsv MUST be 0
  176. if rsv := b & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
  177. return false, 0, nil, fmt.Errorf("unexpected reserved bits rsv1=%d, rsv2=%d, rsv3=%d", b&rsv1Bit, b&rsv2Bit, b&rsv3Bit)
  178. }
  179. // op code
  180. op = int(b & opBit)
  181. // 2.Second byte. Mask/Payload len(7bits)
  182. b, err = c.r.ReadByte()
  183. if err != nil {
  184. return
  185. }
  186. // is mask payload
  187. mask = (b & maskBit) != 0
  188. // payload length
  189. switch b & lenBit {
  190. case 126:
  191. // 16 bits
  192. if p, err = c.r.Pop(2); err != nil {
  193. return
  194. }
  195. payloadLen = int64(binary.BigEndian.Uint16(p))
  196. case 127:
  197. // 64 bits
  198. if p, err = c.r.Pop(8); err != nil {
  199. return
  200. }
  201. payloadLen = int64(binary.BigEndian.Uint64(p))
  202. default:
  203. // 7 bits
  204. payloadLen = int64(b & lenBit)
  205. }
  206. // read mask key
  207. if mask {
  208. maskKey, err = c.r.Pop(4)
  209. if err != nil {
  210. return
  211. }
  212. }
  213. // read payload
  214. if payloadLen > 0 {
  215. if payload, err = c.r.Pop(int(payloadLen)); err != nil {
  216. return
  217. }
  218. if mask {
  219. maskBytes(maskKey, 0, payload)
  220. }
  221. }
  222. return
  223. }
  224. // Close close the connection.
  225. func (c *Conn) Close() error {
  226. return c.rwc.Close()
  227. }
  228. func maskBytes(key []byte, pos int, b []byte) int {
  229. for i := range b {
  230. b[i] ^= key[pos&3]
  231. pos++
  232. }
  233. return pos & 3
  234. }