conn.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. package packet
  2. import (
  3. "bufio"
  4. "bytes"
  5. "io"
  6. "net"
  7. "github.com/juju/errors"
  8. . "github.com/siddontang/go-mysql/mysql"
  9. )
  // Conn wraps a net.Conn and frames its traffic as MySQL protocol packets.
  //
  // Writes go straight to the embedded net.Conn, while reads are served from
  // a buffered reader layered on top of it.
  type Conn struct {
  	// Conn is the underlying transport. WritePacket writes to it and Close
  	// closes it.
  	net.Conn

  	// br buffers reads from the transport; ReadPacketTo consumes packets
  	// from it.
  	br *bufio.Reader

  	// Sequence is the sequence id expected on the next packet read and
  	// stamped on the next packet written. It is incremented once per wire
  	// packet and wraps around as a uint8.
  	Sequence uint8
  }
  18. func NewConn(conn net.Conn) *Conn {
  19. c := new(Conn)
  20. c.br = bufio.NewReaderSize(conn, 4096)
  21. c.Conn = conn
  22. return c
  23. }
  24. func (c *Conn) ReadPacket() ([]byte, error) {
  25. var buf bytes.Buffer
  26. if err := c.ReadPacketTo(&buf); err != nil {
  27. return nil, errors.Trace(err)
  28. } else {
  29. return buf.Bytes(), nil
  30. }
  31. // header := []byte{0, 0, 0, 0}
  32. // if _, err := io.ReadFull(c.br, header); err != nil {
  33. // return nil, ErrBadConn
  34. // }
  35. // length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
  36. // if length < 1 {
  37. // return nil, fmt.Errorf("invalid payload length %d", length)
  38. // }
  39. // sequence := uint8(header[3])
  40. // if sequence != c.Sequence {
  41. // return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
  42. // }
  43. // c.Sequence++
  44. // data := make([]byte, length)
  45. // if _, err := io.ReadFull(c.br, data); err != nil {
  46. // return nil, ErrBadConn
  47. // } else {
  48. // if length < MaxPayloadLen {
  49. // return data, nil
  50. // }
  51. // var buf []byte
  52. // buf, err = c.ReadPacket()
  53. // if err != nil {
  54. // return nil, ErrBadConn
  55. // } else {
  56. // return append(data, buf...), nil
  57. // }
  58. // }
  59. }
  60. func (c *Conn) ReadPacketTo(w io.Writer) error {
  61. header := []byte{0, 0, 0, 0}
  62. if _, err := io.ReadFull(c.br, header); err != nil {
  63. return ErrBadConn
  64. }
  65. length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
  66. if length < 1 {
  67. return errors.Errorf("invalid payload length %d", length)
  68. }
  69. sequence := uint8(header[3])
  70. if sequence != c.Sequence {
  71. return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
  72. }
  73. c.Sequence++
  74. if n, err := io.CopyN(w, c.br, int64(length)); err != nil {
  75. return ErrBadConn
  76. } else if n != int64(length) {
  77. return ErrBadConn
  78. } else {
  79. if length < MaxPayloadLen {
  80. return nil
  81. }
  82. if err := c.ReadPacketTo(w); err != nil {
  83. return err
  84. }
  85. }
  86. return nil
  87. }
  88. // data already has 4 bytes header
  89. // will modify data inplace
  90. func (c *Conn) WritePacket(data []byte) error {
  91. length := len(data) - 4
  92. for length >= MaxPayloadLen {
  93. data[0] = 0xff
  94. data[1] = 0xff
  95. data[2] = 0xff
  96. data[3] = c.Sequence
  97. if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil {
  98. return ErrBadConn
  99. } else if n != (4 + MaxPayloadLen) {
  100. return ErrBadConn
  101. } else {
  102. c.Sequence++
  103. length -= MaxPayloadLen
  104. data = data[MaxPayloadLen:]
  105. }
  106. }
  107. data[0] = byte(length)
  108. data[1] = byte(length >> 8)
  109. data[2] = byte(length >> 16)
  110. data[3] = c.Sequence
  111. if n, err := c.Write(data); err != nil {
  112. return ErrBadConn
  113. } else if n != len(data) {
  114. return ErrBadConn
  115. } else {
  116. c.Sequence++
  117. return nil
  118. }
  119. }
  120. func (c *Conn) ResetSequence() {
  121. c.Sequence = 0
  122. }
  123. func (c *Conn) Close() error {
  124. c.Sequence = 0
  125. if c.Conn != nil {
  126. return c.Conn.Close()
  127. }
  128. return nil
  129. }