conn.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. package memcache
  2. import (
  3. "bufio"
  4. "bytes"
  5. "compress/gzip"
  6. "context"
  7. "encoding/gob"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "time"
  16. "github.com/gogo/protobuf/proto"
  17. pkgerr "github.com/pkg/errors"
  18. )
  19. var (
  20. crlf = []byte("\r\n")
  21. spaceStr = string(" ")
  22. replyOK = []byte("OK\r\n")
  23. replyStored = []byte("STORED\r\n")
  24. replyNotStored = []byte("NOT_STORED\r\n")
  25. replyExists = []byte("EXISTS\r\n")
  26. replyNotFound = []byte("NOT_FOUND\r\n")
  27. replyDeleted = []byte("DELETED\r\n")
  28. replyEnd = []byte("END\r\n")
  29. replyTouched = []byte("TOUCHED\r\n")
  30. replyValueStr = "VALUE"
  31. replyClientErrorPrefix = []byte("CLIENT_ERROR ")
  32. replyServerErrorPrefix = []byte("SERVER_ERROR ")
  33. )
  34. const (
  35. _encodeBuf = 4096 // 4kb
  36. // 1024*1024 - 1, set error???
  37. _largeValue = 1000 * 1000 // 1MB
  38. )
  39. type reader struct {
  40. io.Reader
  41. }
  42. func (r *reader) Reset(rd io.Reader) {
  43. r.Reader = rd
  44. }
  45. // conn is the low-level implementation of Conn
  46. type conn struct {
  47. // Shared
  48. mu sync.Mutex
  49. err error
  50. conn net.Conn
  51. // Read & Write
  52. readTimeout time.Duration
  53. writeTimeout time.Duration
  54. rw *bufio.ReadWriter
  55. // Item Reader
  56. ir bytes.Reader
  57. // Compress
  58. gr gzip.Reader
  59. gw *gzip.Writer
  60. cb bytes.Buffer
  61. // Encoding
  62. edb bytes.Buffer
  63. // json
  64. jr reader
  65. jd *json.Decoder
  66. je *json.Encoder
  67. // protobuffer
  68. ped *proto.Buffer
  69. }
  70. // DialOption specifies an option for dialing a Memcache server.
  71. type DialOption struct {
  72. f func(*dialOptions)
  73. }
  74. type dialOptions struct {
  75. readTimeout time.Duration
  76. writeTimeout time.Duration
  77. dial func(network, addr string) (net.Conn, error)
  78. }
  79. // DialReadTimeout specifies the timeout for reading a single command reply.
  80. func DialReadTimeout(d time.Duration) DialOption {
  81. return DialOption{func(do *dialOptions) {
  82. do.readTimeout = d
  83. }}
  84. }
  85. // DialWriteTimeout specifies the timeout for writing a single command.
  86. func DialWriteTimeout(d time.Duration) DialOption {
  87. return DialOption{func(do *dialOptions) {
  88. do.writeTimeout = d
  89. }}
  90. }
  91. // DialConnectTimeout specifies the timeout for connecting to the Memcache server.
  92. func DialConnectTimeout(d time.Duration) DialOption {
  93. return DialOption{func(do *dialOptions) {
  94. dialer := net.Dialer{Timeout: d}
  95. do.dial = dialer.Dial
  96. }}
  97. }
  98. // DialNetDial specifies a custom dial function for creating TCP
  99. // connections. If this option is left out, then net.Dial is
  100. // used. DialNetDial overrides DialConnectTimeout.
  101. func DialNetDial(dial func(network, addr string) (net.Conn, error)) DialOption {
  102. return DialOption{func(do *dialOptions) {
  103. do.dial = dial
  104. }}
  105. }
  106. // Dial connects to the Memcache server at the given network and
  107. // address using the specified options.
  108. func Dial(network, address string, options ...DialOption) (Conn, error) {
  109. do := dialOptions{
  110. dial: net.Dial,
  111. }
  112. for _, option := range options {
  113. option.f(&do)
  114. }
  115. netConn, err := do.dial(network, address)
  116. if err != nil {
  117. return nil, pkgerr.WithStack(err)
  118. }
  119. return NewConn(netConn, do.readTimeout, do.writeTimeout), nil
  120. }
  121. // NewConn returns a new memcache connection for the given net connection.
  122. func NewConn(netConn net.Conn, readTimeout, writeTimeout time.Duration) Conn {
  123. if writeTimeout <= 0 || readTimeout <= 0 {
  124. panic("must config memcache timeout")
  125. }
  126. c := &conn{
  127. conn: netConn,
  128. rw: bufio.NewReadWriter(bufio.NewReader(netConn),
  129. bufio.NewWriter(netConn)),
  130. readTimeout: readTimeout,
  131. writeTimeout: writeTimeout,
  132. }
  133. c.jd = json.NewDecoder(&c.jr)
  134. c.je = json.NewEncoder(&c.edb)
  135. c.gw = gzip.NewWriter(&c.cb)
  136. c.edb.Grow(_encodeBuf)
  137. // NOTE reuse bytes.Buffer internal buf
  138. // DON'T concurrency call Scan
  139. c.ped = proto.NewBuffer(c.edb.Bytes())
  140. return c
  141. }
  142. func (c *conn) Close() error {
  143. c.mu.Lock()
  144. err := c.err
  145. if c.err == nil {
  146. c.err = pkgerr.New("memcache: closed")
  147. err = c.conn.Close()
  148. }
  149. c.mu.Unlock()
  150. return err
  151. }
  152. func (c *conn) fatal(err error) error {
  153. c.mu.Lock()
  154. if c.err == nil {
  155. c.err = pkgerr.WithStack(err)
  156. // Close connection to force errors on subsequent calls and to unblock
  157. // other reader or writer.
  158. c.conn.Close()
  159. }
  160. c.mu.Unlock()
  161. return c.err
  162. }
  163. func (c *conn) Err() error {
  164. c.mu.Lock()
  165. err := c.err
  166. c.mu.Unlock()
  167. return err
  168. }
  169. func (c *conn) Add(item *Item) error {
  170. return c.populate("add", item)
  171. }
  172. func (c *conn) Set(item *Item) error {
  173. return c.populate("set", item)
  174. }
  175. func (c *conn) Replace(item *Item) error {
  176. return c.populate("replace", item)
  177. }
  178. func (c *conn) CompareAndSwap(item *Item) error {
  179. return c.populate("cas", item)
  180. }
  181. func (c *conn) populate(cmd string, item *Item) (err error) {
  182. if !legalKey(item.Key) {
  183. return pkgerr.WithStack(ErrMalformedKey)
  184. }
  185. var res []byte
  186. if res, err = c.encode(item); err != nil {
  187. return
  188. }
  189. l := len(res)
  190. count := l/(_largeValue) + 1
  191. if count == 1 {
  192. item.Value = res
  193. return c.populateOne(cmd, item)
  194. }
  195. nItem := &Item{
  196. Key: item.Key,
  197. Value: []byte(strconv.Itoa(l)),
  198. Expiration: item.Expiration,
  199. Flags: item.Flags | flagLargeValue,
  200. }
  201. err = c.populateOne(cmd, nItem)
  202. if err != nil {
  203. return
  204. }
  205. k := item.Key
  206. nItem.Flags = item.Flags
  207. for i := 1; i <= count; i++ {
  208. if i == count {
  209. nItem.Value = res[_largeValue*(count-1):]
  210. } else {
  211. nItem.Value = res[_largeValue*(i-1) : _largeValue*i]
  212. }
  213. nItem.Key = fmt.Sprintf("%s%d", k, i)
  214. if err = c.populateOne(cmd, nItem); err != nil {
  215. return
  216. }
  217. }
  218. return
  219. }
  220. func (c *conn) populateOne(cmd string, item *Item) (err error) {
  221. if c.writeTimeout != 0 {
  222. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  223. }
  224. // <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
  225. if cmd == "cas" {
  226. _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d %d\r\n",
  227. cmd, item.Key, item.Flags, item.Expiration, len(item.Value), item.cas)
  228. } else {
  229. _, err = fmt.Fprintf(c.rw, "%s %s %d %d %d\r\n",
  230. cmd, item.Key, item.Flags, item.Expiration, len(item.Value))
  231. }
  232. if err != nil {
  233. return c.fatal(err)
  234. }
  235. c.rw.Write(item.Value)
  236. c.rw.Write(crlf)
  237. if err = c.rw.Flush(); err != nil {
  238. return c.fatal(err)
  239. }
  240. if c.readTimeout != 0 {
  241. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  242. }
  243. line, err := c.rw.ReadSlice('\n')
  244. if err != nil {
  245. return c.fatal(err)
  246. }
  247. switch {
  248. case bytes.Equal(line, replyStored):
  249. return nil
  250. case bytes.Equal(line, replyNotStored):
  251. return ErrNotStored
  252. case bytes.Equal(line, replyExists):
  253. return ErrCASConflict
  254. case bytes.Equal(line, replyNotFound):
  255. return ErrNotFound
  256. }
  257. return pkgerr.WithStack(protocolError(string(line)))
  258. }
  259. func (c *conn) Get(key string) (r *Item, err error) {
  260. if !legalKey(key) {
  261. return nil, pkgerr.WithStack(ErrMalformedKey)
  262. }
  263. if c.writeTimeout != 0 {
  264. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  265. }
  266. if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", key); err != nil {
  267. return nil, c.fatal(err)
  268. }
  269. if err = c.rw.Flush(); err != nil {
  270. return nil, c.fatal(err)
  271. }
  272. if err = c.parseGetReply(func(it *Item) {
  273. r = it
  274. }); err != nil {
  275. return
  276. }
  277. if r == nil {
  278. err = ErrNotFound
  279. return
  280. }
  281. if r.Flags&flagLargeValue != flagLargeValue {
  282. return
  283. }
  284. if r, err = c.getLargeValue(r); err != nil {
  285. return
  286. }
  287. return
  288. }
  289. func (c *conn) GetMulti(keys []string) (res map[string]*Item, err error) {
  290. for _, key := range keys {
  291. if !legalKey(key) {
  292. return nil, pkgerr.WithStack(ErrMalformedKey)
  293. }
  294. }
  295. if c.writeTimeout != 0 {
  296. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  297. }
  298. if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
  299. return nil, c.fatal(err)
  300. }
  301. if err = c.rw.Flush(); err != nil {
  302. return nil, c.fatal(err)
  303. }
  304. res = make(map[string]*Item, len(keys))
  305. if err = c.parseGetReply(func(it *Item) {
  306. res[it.Key] = it
  307. }); err != nil {
  308. return
  309. }
  310. for k, v := range res {
  311. if v.Flags&flagLargeValue != flagLargeValue {
  312. continue
  313. }
  314. r, err := c.getLargeValue(v)
  315. if err != nil {
  316. return res, err
  317. }
  318. res[k] = r
  319. }
  320. return
  321. }
  322. func (c *conn) getMulti(keys []string) (res map[string]*Item, err error) {
  323. for _, key := range keys {
  324. if !legalKey(key) {
  325. return nil, pkgerr.WithStack(ErrMalformedKey)
  326. }
  327. }
  328. if c.writeTimeout != 0 {
  329. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  330. }
  331. if _, err = fmt.Fprintf(c.rw, "gets %s\r\n", strings.Join(keys, " ")); err != nil {
  332. return nil, c.fatal(err)
  333. }
  334. if err = c.rw.Flush(); err != nil {
  335. return nil, c.fatal(err)
  336. }
  337. res = make(map[string]*Item, len(keys))
  338. err = c.parseGetReply(func(it *Item) {
  339. res[it.Key] = it
  340. })
  341. return
  342. }
  343. func (c *conn) getLargeValue(it *Item) (r *Item, err error) {
  344. l, err := strconv.Atoi(string(it.Value))
  345. if err != nil {
  346. return
  347. }
  348. count := l/_largeValue + 1
  349. keys := make([]string, 0, count)
  350. for i := 1; i <= count; i++ {
  351. keys = append(keys, fmt.Sprintf("%s%d", it.Key, i))
  352. }
  353. items, err := c.getMulti(keys)
  354. if err != nil {
  355. return
  356. }
  357. if len(items) < count {
  358. err = ErrNotFound
  359. return
  360. }
  361. v := make([]byte, 0, l)
  362. for _, k := range keys {
  363. if items[k] == nil || items[k].Value == nil {
  364. err = ErrNotFound
  365. return
  366. }
  367. v = append(v, items[k].Value...)
  368. }
  369. it.Value = v
  370. it.Flags = it.Flags ^ flagLargeValue
  371. r = it
  372. return
  373. }
  374. func (c *conn) parseGetReply(f func(*Item)) error {
  375. if c.readTimeout != 0 {
  376. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  377. }
  378. for {
  379. line, err := c.rw.ReadSlice('\n')
  380. if err != nil {
  381. return c.fatal(err)
  382. }
  383. if bytes.Equal(line, replyEnd) {
  384. return nil
  385. }
  386. if bytes.HasPrefix(line, replyServerErrorPrefix) {
  387. errMsg := line[len(replyServerErrorPrefix):]
  388. return c.fatal(protocolError(errMsg))
  389. }
  390. it := new(Item)
  391. size, err := scanGetReply(line, it)
  392. if err != nil {
  393. return c.fatal(err)
  394. }
  395. it.Value = make([]byte, size+2)
  396. if _, err = io.ReadFull(c.rw, it.Value); err != nil {
  397. return c.fatal(err)
  398. }
  399. if !bytes.HasSuffix(it.Value, crlf) {
  400. return c.fatal(protocolError("corrupt get reply, no except CRLF"))
  401. }
  402. it.Value = it.Value[:size]
  403. f(it)
  404. }
  405. }
  406. func scanGetReply(line []byte, item *Item) (size int, err error) {
  407. if !bytes.HasSuffix(line, crlf) {
  408. return 0, protocolError("corrupt get reply, no except CRLF")
  409. }
  410. // VALUE <key> <flags> <bytes> [<cas unique>]
  411. chunks := strings.Split(string(line[:len(line)-2]), spaceStr)
  412. if len(chunks) < 4 {
  413. return 0, protocolError("corrupt get reply")
  414. }
  415. if chunks[0] != replyValueStr {
  416. return 0, protocolError("corrupt get reply, no except VALUE")
  417. }
  418. item.Key = chunks[1]
  419. flags64, err := strconv.ParseUint(chunks[2], 10, 32)
  420. if err != nil {
  421. return 0, err
  422. }
  423. item.Flags = uint32(flags64)
  424. if size, err = strconv.Atoi(chunks[3]); err != nil {
  425. return
  426. }
  427. if len(chunks) > 4 {
  428. item.cas, err = strconv.ParseUint(chunks[4], 10, 64)
  429. }
  430. return
  431. }
  432. func (c *conn) Touch(key string, expire int32) (err error) {
  433. if !legalKey(key) {
  434. return pkgerr.WithStack(ErrMalformedKey)
  435. }
  436. line, err := c.writeReadLine("touch %s %d\r\n", key, expire)
  437. if err != nil {
  438. return err
  439. }
  440. switch {
  441. case bytes.Equal(line, replyTouched):
  442. return nil
  443. case bytes.Equal(line, replyNotFound):
  444. return ErrNotFound
  445. default:
  446. return pkgerr.WithStack(protocolError(string(line)))
  447. }
  448. }
  449. func (c *conn) Increment(key string, delta uint64) (uint64, error) {
  450. return c.incrDecr("incr", key, delta)
  451. }
  452. func (c *conn) Decrement(key string, delta uint64) (newValue uint64, err error) {
  453. return c.incrDecr("decr", key, delta)
  454. }
  455. func (c *conn) incrDecr(cmd, key string, delta uint64) (uint64, error) {
  456. if !legalKey(key) {
  457. return 0, pkgerr.WithStack(ErrMalformedKey)
  458. }
  459. line, err := c.writeReadLine("%s %s %d\r\n", cmd, key, delta)
  460. if err != nil {
  461. return 0, err
  462. }
  463. switch {
  464. case bytes.Equal(line, replyNotFound):
  465. return 0, ErrNotFound
  466. case bytes.HasPrefix(line, replyClientErrorPrefix):
  467. errMsg := line[len(replyClientErrorPrefix):]
  468. return 0, pkgerr.WithStack(protocolError(errMsg))
  469. }
  470. val, err := strconv.ParseUint(string(line[:len(line)-2]), 10, 64)
  471. if err != nil {
  472. return 0, err
  473. }
  474. return val, nil
  475. }
  476. func (c *conn) Delete(key string) (err error) {
  477. if !legalKey(key) {
  478. return pkgerr.WithStack(ErrMalformedKey)
  479. }
  480. line, err := c.writeReadLine("delete %s\r\n", key)
  481. if err != nil {
  482. return err
  483. }
  484. switch {
  485. case bytes.Equal(line, replyOK):
  486. return nil
  487. case bytes.Equal(line, replyDeleted):
  488. return nil
  489. case bytes.Equal(line, replyNotStored):
  490. return ErrNotStored
  491. case bytes.Equal(line, replyExists):
  492. return ErrCASConflict
  493. case bytes.Equal(line, replyNotFound):
  494. return ErrNotFound
  495. }
  496. return pkgerr.WithStack(protocolError(string(line)))
  497. }
  498. func (c *conn) writeReadLine(format string, args ...interface{}) ([]byte, error) {
  499. if c.writeTimeout != 0 {
  500. c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
  501. }
  502. _, err := fmt.Fprintf(c.rw, format, args...)
  503. if err != nil {
  504. return nil, c.fatal(pkgerr.WithStack(err))
  505. }
  506. if err = c.rw.Flush(); err != nil {
  507. return nil, c.fatal(pkgerr.WithStack(err))
  508. }
  509. if c.readTimeout != 0 {
  510. c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
  511. }
  512. line, err := c.rw.ReadSlice('\n')
  513. if err != nil {
  514. return line, c.fatal(pkgerr.WithStack(err))
  515. }
  516. return line, nil
  517. }
  518. func (c *conn) Scan(item *Item, v interface{}) (err error) {
  519. c.ir.Reset(item.Value)
  520. if item.Flags&FlagGzip == FlagGzip {
  521. if err = c.gr.Reset(&c.ir); err != nil {
  522. return
  523. }
  524. if err = c.decode(&c.gr, item, v); err != nil {
  525. err = pkgerr.WithStack(err)
  526. return
  527. }
  528. err = c.gr.Close()
  529. } else {
  530. err = c.decode(&c.ir, item, v)
  531. }
  532. err = pkgerr.WithStack(err)
  533. return
  534. }
  535. func (c *conn) WithContext(ctx context.Context) Conn {
  536. // FIXME: implement WithContext
  537. return c
  538. }
// encode produces the bytes to store for item, according to item.Flags.
//
// Validation: when every bit set in item.Flags is also in _flagEncoding,
// item.Value must be non-nil; otherwise item.Object must be non-nil
// (ErrItem if not). NOTE(review): despite its name, _flagEncoding appears
// to be the mask of non-encoding bits (e.g. compression) — confirm against
// its definition.
//
// Encoding: Object is encoded with gob, protobuf, or JSON, checked in that
// order. ErrItemObject is returned if protobuf is requested and Object is
// not a proto.Message. With none of those flags, Value is used as-is. If
// FlagGzip is set, the result is then gzip-compressed.
//
// If the final payload exceeds 8,000,000 bytes, ErrValueSize is returned
// together with the data.
//
// The returned slice may alias the conn's scratch buffers (c.edb, c.cb, or
// the protobuf buffer) or item.Value. It is only valid until the next
// encode/decode on this conn.
func (c *conn) encode(item *Item) (data []byte, err error) {
	if (item.Flags | _flagEncoding) == _flagEncoding {
		if item.Value == nil {
			return nil, ErrItem
		}
	} else if item.Object == nil {
		return nil, ErrItem
	}
	// encoding
	switch {
	case item.Flags&FlagGOB == FlagGOB:
		c.edb.Reset()
		if err = gob.NewEncoder(&c.edb).Encode(item.Object); err != nil {
			return
		}
		data = c.edb.Bytes()
	case item.Flags&FlagProtobuf == FlagProtobuf:
		c.edb.Reset()
		// Point the protobuf buffer at edb's emptied backing array so Marshal
		// appends into storage that is already allocated.
		c.ped.SetBuf(c.edb.Bytes())
		pb, ok := item.Object.(proto.Message)
		if !ok {
			err = ErrItemObject
			return
		}
		if err = c.ped.Marshal(pb); err != nil {
			return
		}
		data = c.ped.Bytes()
	case item.Flags&FlagJSON == FlagJSON:
		c.edb.Reset()
		// c.je writes into c.edb; json.Encoder appends a trailing newline.
		if err = c.je.Encode(item.Object); err != nil {
			return
		}
		data = c.edb.Bytes()
	default:
		data = item.Value
	}
	// compress
	if item.Flags&FlagGzip == FlagGzip {
		c.cb.Reset()
		c.gw.Reset(&c.cb)
		if _, err = c.gw.Write(data); err != nil {
			return
		}
		// Close flushes the gzip footer into c.cb.
		if err = c.gw.Close(); err != nil {
			return
		}
		data = c.cb.Bytes()
	}
	if len(data) > 8000000 {
		err = ErrValueSize
	}
	return
}
  593. func (c *conn) decode(rd io.Reader, item *Item, v interface{}) (err error) {
  594. var data []byte
  595. switch {
  596. case item.Flags&FlagGOB == FlagGOB:
  597. err = gob.NewDecoder(rd).Decode(v)
  598. case item.Flags&FlagJSON == FlagJSON:
  599. c.jr.Reset(rd)
  600. err = c.jd.Decode(v)
  601. default:
  602. data = item.Value
  603. if item.Flags&FlagGzip == FlagGzip {
  604. c.edb.Reset()
  605. if _, err = io.Copy(&c.edb, rd); err != nil {
  606. return
  607. }
  608. data = c.edb.Bytes()
  609. }
  610. if item.Flags&FlagProtobuf == FlagProtobuf {
  611. m, ok := v.(proto.Message)
  612. if !ok {
  613. err = ErrItemObject
  614. return
  615. }
  616. c.ped.SetBuf(data)
  617. err = c.ped.Unmarshal(m)
  618. } else {
  619. switch v.(type) {
  620. case *[]byte:
  621. d := v.(*[]byte)
  622. *d = data
  623. case *string:
  624. d := v.(*string)
  625. *d = string(data)
  626. case interface{}:
  627. err = json.Unmarshal(data, v)
  628. }
  629. }
  630. }
  631. return
  632. }
  633. func legalKey(key string) bool {
  634. if len(key) > 250 || len(key) == 0 {
  635. return false
  636. }
  637. for i := 0; i < len(key); i++ {
  638. if key[i] <= ' ' || key[i] == 0x7f {
  639. return false
  640. }
  641. }
  642. return true
  643. }