conn.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package tcp
  2. import (
  3. "bufio"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "strconv"
  9. "strings"
  10. "time"
  11. )
  12. type conn struct {
  13. // conn
  14. conn net.Conn
  15. // Read
  16. readTimeout time.Duration
  17. br *bufio.Reader
  18. // Write
  19. writeTimeout time.Duration
  20. bw *bufio.Writer
  21. // Scratch space for formatting argument length.
  22. // '*' or '$', length, "\r\n"
  23. lenScratch [32]byte
  24. // Scratch space for formatting integers and floats.
  25. numScratch [40]byte
  26. }
  27. // newConn returns a new connection for the given net connection.
  28. func newConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) *conn {
  29. return &conn{
  30. conn: netConn,
  31. readTimeout: readTimeout,
  32. writeTimeout: writeTimeout,
  33. br: bufio.NewReaderSize(netConn, _readBufSize),
  34. bw: bufio.NewWriterSize(netConn, _writeBufSize),
  35. }
  36. }
  37. // Read read data from connection
  38. func (c *conn) Read() (cmd string, args [][]byte, err error) {
  39. var (
  40. ln, cn int
  41. bs []byte
  42. )
  43. if c.readTimeout > 0 {
  44. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  45. }
  46. // start read
  47. if bs, err = c.readLine(); err != nil {
  48. return
  49. }
  50. if len(bs) < 2 {
  51. err = fmt.Errorf("read error data(%s) from connection", bs)
  52. return
  53. }
  54. // maybe a cmd that without any params is received,such as: QUIT
  55. if strings.ToLower(string(bs)) == _quit {
  56. cmd = _quit
  57. return
  58. }
  59. // get param number
  60. if ln, err = parseLen(bs[1:]); err != nil {
  61. return
  62. }
  63. args = make([][]byte, 0, ln-1)
  64. for i := 0; i < ln; i++ {
  65. if cn, err = c.readLen(_protoBulk); err != nil {
  66. return
  67. }
  68. if bs, err = c.readData(cn); err != nil {
  69. return
  70. }
  71. if i == 0 {
  72. cmd = strings.ToLower(string(bs))
  73. continue
  74. }
  75. args = append(args, bs)
  76. }
  77. return
  78. }
  79. // WriteError write error to connection and close connection
  80. func (c *conn) WriteError(err error) {
  81. if c.writeTimeout > 0 {
  82. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  83. }
  84. if err = c.Write(proto{prefix: _protoErr, message: err.Error()}); err != nil {
  85. c.Close()
  86. return
  87. }
  88. c.Flush()
  89. c.Close()
  90. }
  91. // Write write data to connection
  92. func (c *conn) Write(p proto) (err error) {
  93. if c.writeTimeout > 0 {
  94. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  95. }
  96. // start write
  97. switch p.prefix {
  98. case _protoStr:
  99. err = c.writeStatus(p.message)
  100. case _protoErr:
  101. err = c.writeError(p.message)
  102. case _protoInt:
  103. err = c.writeInt64(int64(p.integer))
  104. case _protoBulk:
  105. // c.writeString(p.message)
  106. err = c.writeBytes([]byte(p.message))
  107. case _protoArray:
  108. err = c.writeLen(p.prefix, p.integer)
  109. }
  110. return
  111. }
  112. // Flush flush connection
  113. func (c *conn) Flush() error {
  114. if c.writeTimeout > 0 {
  115. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  116. }
  117. return c.bw.Flush()
  118. }
  119. // Close close connection
  120. func (c *conn) Close() error {
  121. return c.conn.Close()
  122. }
  123. // parseLen parses bulk string and array lengths.
  124. func parseLen(p []byte) (int, error) {
  125. if len(p) == 0 {
  126. return -1, errors.New("malformed length")
  127. }
  128. if p[0] == '-' && len(p) == 2 && p[1] == '1' {
  129. // handle $-1 and $-1 null replies.
  130. return -1, nil
  131. }
  132. var n int
  133. for _, b := range p {
  134. n *= 10
  135. if b < '0' || b > '9' {
  136. return -1, errors.New("illegal bytes in length")
  137. }
  138. n += int(b - '0')
  139. }
  140. return n, nil
  141. }
  142. func (c *conn) readLine() ([]byte, error) {
  143. p, err := c.br.ReadBytes('\n')
  144. if err == bufio.ErrBufferFull {
  145. return nil, errors.New("long response line")
  146. }
  147. if err != nil {
  148. return nil, err
  149. }
  150. i := len(p) - 2
  151. if i < 0 || p[i] != '\r' {
  152. return nil, errors.New("bad response line terminator")
  153. }
  154. return p[:i], nil
  155. }
  156. func (c *conn) readLen(prefix byte) (int, error) {
  157. ls, err := c.readLine()
  158. if err != nil {
  159. return 0, err
  160. }
  161. if len(ls) < 2 {
  162. return 0, errors.New("illegal bytes in length")
  163. }
  164. if ls[0] != prefix {
  165. return 0, errors.New("illegal bytes in length")
  166. }
  167. return parseLen(ls[1:])
  168. }
  169. func (c *conn) readData(n int) ([]byte, error) {
  170. if n > _maxValueSize {
  171. return nil, errors.New("exceeding max value limit")
  172. }
  173. buf := make([]byte, n+2)
  174. r, err := io.ReadFull(c.br, buf)
  175. if err != nil {
  176. return nil, err
  177. }
  178. if n != r-2 {
  179. return nil, errors.New("invalid bytes in len")
  180. }
  181. return buf[:n], err
  182. }
  183. func (c *conn) writeLen(prefix byte, n int) error {
  184. c.lenScratch[len(c.lenScratch)-1] = '\n'
  185. c.lenScratch[len(c.lenScratch)-2] = '\r'
  186. i := len(c.lenScratch) - 3
  187. for {
  188. c.lenScratch[i] = byte('0' + n%10)
  189. i--
  190. n = n / 10
  191. if n == 0 {
  192. break
  193. }
  194. }
  195. c.lenScratch[i] = prefix
  196. _, err := c.bw.Write(c.lenScratch[i:])
  197. return err
  198. }
  199. func (c *conn) writeStatus(s string) (err error) {
  200. c.bw.WriteByte(_protoStr)
  201. c.bw.WriteString(s)
  202. _, err = c.bw.WriteString("\r\n")
  203. return
  204. }
  205. func (c *conn) writeError(s string) (err error) {
  206. c.bw.WriteByte(_protoErr)
  207. c.bw.WriteString(s)
  208. _, err = c.bw.WriteString("\r\n")
  209. return
  210. }
  211. func (c *conn) writeInt64(n int64) (err error) {
  212. c.bw.WriteByte(_protoInt)
  213. c.bw.Write(strconv.AppendInt(c.numScratch[:0], n, 10))
  214. _, err = c.bw.WriteString("\r\n")
  215. return
  216. }
  217. func (c *conn) writeString(s string) (err error) {
  218. c.writeLen(_protoBulk, len(s))
  219. c.bw.WriteString(s)
  220. _, err = c.bw.WriteString("\r\n")
  221. return
  222. }
  223. func (c *conn) writeBytes(s []byte) (err error) {
  224. if len(s) == 0 {
  225. c.bw.WriteByte('$')
  226. c.bw.Write(_nullBulk)
  227. } else {
  228. c.writeLen(_protoBulk, len(s))
  229. c.bw.Write(s)
  230. }
  231. _, err = c.bw.WriteString("\r\n")
  232. return
  233. }
  234. func (c *conn) writeStrings(ss []string) (err error) {
  235. c.writeLen(_protoArray, len(ss))
  236. for _, s := range ss {
  237. if err = c.writeString(s); err != nil {
  238. return
  239. }
  240. }
  241. return
  242. }