123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- package websocket
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "go-common/app/service/main/broadcast/libs/bufio"
- )
// Bit masks and limits used when encoding and decoding frame headers.
const (
	// First header byte: FIN flag, three reserved bits, 4-bit opcode.
	finBit  = 0x80
	rsv1Bit = 0x40
	rsv2Bit = 0x20
	rsv3Bit = 0x10
	opBit   = 0x0f

	// Second header byte: MASK flag and 7-bit payload length.
	maskBit = 0x80
	lenBit  = 0x7f

	// continuationFrame is the opcode of a frame continuing a fragmented message.
	continuationFrame = 0
	// continuationFrameMaxRead bounds how many frames ReadMessage consumes
	// while waiting for a message to complete.
	continuationFrameMaxRead = 100
)
// Message types; each value is the frame opcode used on the wire.
const (
	// TextMessage denotes a text data message.
	TextMessage = 0x1
	// BinaryMessage denotes a binary data message.
	BinaryMessage = 0x2
	// CloseMessage denotes a close control message.
	CloseMessage = 0x8
	// PingMessage denotes a ping control message.
	PingMessage = 0x9
	// PongMessage denotes a pong control message.
	PongMessage = 0xA
)
- var (
-
- ErrMessageClose = errors.New("close control message")
-
- ErrMessageMaxRead = errors.New("continuation frame max read")
- )
- type Conn struct {
- rwc io.ReadWriteCloser
- r *bufio.Reader
- w *bufio.Writer
- }
- func newConn(rwc io.ReadWriteCloser, r *bufio.Reader, w *bufio.Writer) *Conn {
- return &Conn{rwc: rwc, r: r, w: w}
- }
- func (c *Conn) WriteMessage(msgType int, msg []byte) (err error) {
- if err = c.WriteHeader(msgType, len(msg)); err != nil {
- return
- }
- err = c.WriteBody(msg)
- return
- }
- func (c *Conn) WriteHeader(msgType int, length int) (err error) {
- var h []byte
- if h, err = c.w.Peek(2); err != nil {
- return
- }
-
- h[0] = 0
- h[0] |= finBit | byte(msgType)
-
- h[1] = 0
- switch {
- case length <= 125:
-
- h[1] |= byte(length)
- case length < 65536:
-
- h[1] |= 126
- if h, err = c.w.Peek(2); err != nil {
- return
- }
- binary.BigEndian.PutUint16(h, uint16(length))
- default:
-
- h[1] |= 127
- if h, err = c.w.Peek(8); err != nil {
- return
- }
- binary.BigEndian.PutUint64(h, uint64(length))
- }
- return
- }
- func (c *Conn) WriteBody(b []byte) (err error) {
- if len(b) > 0 {
- _, err = c.w.Write(b)
- }
- return
- }
- func (c *Conn) Peek(n int) ([]byte, error) {
- return c.w.Peek(n)
- }
- func (c *Conn) Flush() error {
- return c.w.Flush()
- }
- func (c *Conn) ReadMessage() (op int, payload []byte, err error) {
- var (
- fin bool
- finOp, n int
- partPayload []byte
- )
- for {
-
- if fin, op, partPayload, err = c.readFrame(); err != nil {
- return
- }
- switch op {
- case BinaryMessage, TextMessage, continuationFrame:
- if fin && len(payload) == 0 {
- return op, partPayload, nil
- }
-
- payload = append(payload, partPayload...)
- if op != continuationFrame {
- finOp = op
- }
-
- if fin {
- op = finOp
- return
- }
- case PingMessage:
-
- if err = c.WriteMessage(PongMessage, partPayload); err != nil {
- return
- }
- case PongMessage:
-
- case CloseMessage:
-
- err = ErrMessageClose
- return
- default:
- err = fmt.Errorf("unknown control message, fin=%t, op=%d", fin, op)
- return
- }
- if n > continuationFrameMaxRead {
- err = ErrMessageMaxRead
- return
- }
- n++
- }
- }
- func (c *Conn) readFrame() (fin bool, op int, payload []byte, err error) {
- var (
- b byte
- p []byte
- mask bool
- maskKey []byte
- payloadLen int64
- )
-
- b, err = c.r.ReadByte()
- if err != nil {
- return
- }
-
- fin = (b & finBit) != 0
-
- if rsv := b & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
- return false, 0, nil, fmt.Errorf("unexpected reserved bits rsv1=%d, rsv2=%d, rsv3=%d", b&rsv1Bit, b&rsv2Bit, b&rsv3Bit)
- }
-
- op = int(b & opBit)
-
- b, err = c.r.ReadByte()
- if err != nil {
- return
- }
-
- mask = (b & maskBit) != 0
-
- switch b & lenBit {
- case 126:
-
- if p, err = c.r.Pop(2); err != nil {
- return
- }
- payloadLen = int64(binary.BigEndian.Uint16(p))
- case 127:
-
- if p, err = c.r.Pop(8); err != nil {
- return
- }
- payloadLen = int64(binary.BigEndian.Uint64(p))
- default:
-
- payloadLen = int64(b & lenBit)
- }
-
- if mask {
- maskKey, err = c.r.Pop(4)
- if err != nil {
- return
- }
- }
-
- if payloadLen > 0 {
- if payload, err = c.r.Pop(int(payloadLen)); err != nil {
- return
- }
- if mask {
- maskBytes(maskKey, 0, payload)
- }
- }
- return
- }
- func (c *Conn) Close() error {
- return c.rwc.Close()
- }
// maskBytes XORs b in place with the 4-byte key. It starts at key offset pos
// (taken mod 4) and returns the key offset for the byte after b. Because XOR
// is its own inverse, applying it twice with the same key and offset restores
// the original bytes.
func maskBytes(key []byte, pos int, b []byte) int {
	for i := 0; i < len(b); i++ {
		b[i] ^= key[(pos+i)&3]
	}
	return (pos + len(b)) & 3
}
|