123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- package packet
- import (
- "bufio"
- "bytes"
- "io"
- "net"
- "github.com/juju/errors"
- . "github.com/siddontang/go-mysql/mysql"
- )
// Conn wraps a net.Conn and frames MySQL protocol packets on it.
//
// Packet reads are served from the buffered reader br; packet writes go to
// the embedded net.Conn. Sequence holds the sequence id that is expected on
// the next packet read and stamped on the next packet written.
type Conn struct {
	net.Conn // embedded transport

	br *bufio.Reader // buffered reader over the transport, used for packet reads

	Sequence byte // next packet sequence id (byte is an alias of uint8)
}
- func NewConn(conn net.Conn) *Conn {
- c := new(Conn)
- c.br = bufio.NewReaderSize(conn, 4096)
- c.Conn = conn
- return c
- }
- func (c *Conn) ReadPacket() ([]byte, error) {
- var buf bytes.Buffer
- if err := c.ReadPacketTo(&buf); err != nil {
- return nil, errors.Trace(err)
- } else {
- return buf.Bytes(), nil
- }
- // header := []byte{0, 0, 0, 0}
- // if _, err := io.ReadFull(c.br, header); err != nil {
- // return nil, ErrBadConn
- // }
- // length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
- // if length < 1 {
- // return nil, fmt.Errorf("invalid payload length %d", length)
- // }
- // sequence := uint8(header[3])
- // if sequence != c.Sequence {
- // return nil, fmt.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
- // }
- // c.Sequence++
- // data := make([]byte, length)
- // if _, err := io.ReadFull(c.br, data); err != nil {
- // return nil, ErrBadConn
- // } else {
- // if length < MaxPayloadLen {
- // return data, nil
- // }
- // var buf []byte
- // buf, err = c.ReadPacket()
- // if err != nil {
- // return nil, ErrBadConn
- // } else {
- // return append(data, buf...), nil
- // }
- // }
- }
- func (c *Conn) ReadPacketTo(w io.Writer) error {
- header := []byte{0, 0, 0, 0}
- if _, err := io.ReadFull(c.br, header); err != nil {
- return ErrBadConn
- }
- length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
- if length < 1 {
- return errors.Errorf("invalid payload length %d", length)
- }
- sequence := uint8(header[3])
- if sequence != c.Sequence {
- return errors.Errorf("invalid sequence %d != %d", sequence, c.Sequence)
- }
- c.Sequence++
- if n, err := io.CopyN(w, c.br, int64(length)); err != nil {
- return ErrBadConn
- } else if n != int64(length) {
- return ErrBadConn
- } else {
- if length < MaxPayloadLen {
- return nil
- }
- if err := c.ReadPacketTo(w); err != nil {
- return err
- }
- }
- return nil
- }
- // data already has 4 bytes header
- // will modify data inplace
- func (c *Conn) WritePacket(data []byte) error {
- length := len(data) - 4
- for length >= MaxPayloadLen {
- data[0] = 0xff
- data[1] = 0xff
- data[2] = 0xff
- data[3] = c.Sequence
- if n, err := c.Write(data[:4+MaxPayloadLen]); err != nil {
- return ErrBadConn
- } else if n != (4 + MaxPayloadLen) {
- return ErrBadConn
- } else {
- c.Sequence++
- length -= MaxPayloadLen
- data = data[MaxPayloadLen:]
- }
- }
- data[0] = byte(length)
- data[1] = byte(length >> 8)
- data[2] = byte(length >> 16)
- data[3] = c.Sequence
- if n, err := c.Write(data); err != nil {
- return ErrBadConn
- } else if n != len(data) {
- return ErrBadConn
- } else {
- c.Sequence++
- return nil
- }
- }
- func (c *Conn) ResetSequence() {
- c.Sequence = 0
- }
- func (c *Conn) Close() error {
- c.Sequence = 0
- if c.Conn != nil {
- return c.Conn.Close()
- }
- return nil
- }
|