conn.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. // Package zk is a native Go client library for the ZooKeeper orchestration service.
  2. package zk
  3. /*
  4. TODO:
  5. * make sure a ping response comes back in a reasonable time
  6. Possible watcher events:
  7. * Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err}
  8. */
  9. import (
  10. "crypto/rand"
  11. "encoding/binary"
  12. "errors"
  13. "fmt"
  14. "io"
  15. "net"
  16. "strconv"
  17. "strings"
  18. "sync"
  19. "sync/atomic"
  20. "time"
  21. )
  22. // ErrNoServer indicates that an operation cannot be completed
  23. // because attempts to connect to all servers in the list failed.
  24. var ErrNoServer = errors.New("zk: could not connect to a server")
  25. // ErrInvalidPath indicates that an operation was being attempted on
  26. // an invalid path. (e.g. empty path)
  27. var ErrInvalidPath = errors.New("zk: invalid path")
  28. // DefaultLogger uses the stdlib log package for logging.
  29. var DefaultLogger Logger = defaultLogger{}
  // Sizes and markers shared by the connection machinery.
  const (
  	// bufferSize is the size of the packet encode/decode scratch buffers (1.5 MiB).
  	bufferSize = 1536 << 10
  	// eventChanSize is the capacity of the session event channel returned by Connect.
  	eventChanSize = 6
  	// sendChanSize is the capacity of the queue of requests waiting to be sent.
  	sendChanSize = 16
  	// protectedPrefix starts the names of nodes created by CreateProtectedEphemeralSequential.
  	protectedPrefix = "_c_"
  )
  // watchType identifies which kind of watch was registered on a path.
  type watchType int

  // The constants are typed so they are genuine watchType values rather than
  // untyped ints that merely convert implicitly.
  const (
  	// watchTypeData is a data watch (set by GetW, or by ExistsW on an existing node).
  	watchTypeData watchType = iota
  	// watchTypeExist is an existence watch (set by ExistsW when the node is missing).
  	watchTypeExist
  	// watchTypeChild is a child-list watch (set by ChildrenW).
  	watchTypeChild
  )

  // watchPathType keys the watcher registry: a path plus the kind of watch on it.
  type watchPathType struct {
  	path  string
  	wType watchType
  }
  // Dialer opens a connection to address over network; timeout bounds how long
  // the attempt may take. net.DialTimeout has this signature and is the default
  // used by Connect; WithDialer installs a different one.
  type Dialer func(network, addr string, timeout time.Duration) (conn net.Conn, err error)
  // Logger receives the connection's diagnostic messages. Any value with a
  // Printf-style method can be installed with SetLogger.
  type Logger interface {
  	Printf(format string, args ...interface{})
  }
  // authCreds records one successful AddAuth call so that the credentials can
  // be re-submitted after a reconnect.
  type authCreds struct {
  	scheme string // auth scheme passed to AddAuth
  	auth   []byte // credential bytes passed to AddAuth
  }
  55. type Conn struct {
  56. lastZxid int64
  57. sessionID int64
  58. state State // must be 32-bit aligned
  59. xid uint32
  60. sessionTimeoutMs int32 // session timeout in milliseconds
  61. passwd []byte
  62. dialer Dialer
  63. hostProvider HostProvider
  64. serverMu sync.Mutex // protects server
  65. server string // remember the address/port of the current server
  66. conn net.Conn
  67. eventChan chan Event
  68. eventCallback EventCallback // may be nil
  69. shouldQuit chan struct{}
  70. pingInterval time.Duration
  71. recvTimeout time.Duration
  72. connectTimeout time.Duration
  73. creds []authCreds
  74. credsMu sync.Mutex // protects server
  75. sendChan chan *request
  76. requests map[int32]*request // Xid -> pending request
  77. requestsLock sync.Mutex
  78. watchers map[watchPathType][]chan Event
  79. watchersLock sync.Mutex
  80. closeChan chan struct{} // channel to tell send loop stop
  81. // Debug (used by unit tests)
  82. reconnectDelay time.Duration
  83. logger Logger
  84. buf []byte
  85. }
  86. // connOption represents a connection option.
  87. type connOption func(c *Conn)
  88. type request struct {
  89. xid int32
  90. opcode int32
  91. pkt interface{}
  92. recvStruct interface{}
  93. recvChan chan response
  94. // Because sending and receiving happen in separate go routines, there's
  95. // a possible race condition when creating watches from outside the read
  96. // loop. We must ensure that a watcher gets added to the list synchronously
  97. // with the response from the server on any request that creates a watch.
  98. // In order to not hard code the watch logic for each opcode in the recv
  99. // loop the caller can use recvFunc to insert some synchronously code
  100. // after a response.
  101. recvFunc func(*request, *responseHeader, error)
  102. }
  // response is the outcome of a request as handed back to its caller.
  type response struct {
  	zxid int64 // zxid from the reply header; -1 when the request failed locally
  	err  error // nil on success
  }
  107. type Event struct {
  108. Type EventType
  109. State State
  110. Path string // For non-session events, the path of the watched node.
  111. Err error
  112. Server string // For connection events
  113. }
  // HostProvider supplies the servers a ZooKeeper client connects to. It
  // mirrors the HostProvider interface of the Java client:
  // http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup
  type HostProvider interface {
  	// Init is called first, with the servers given to Connect.
  	Init(servers []string) error
  	// Next returns the next server to try. retryStart is true once every
  	// known server has been tried without Connected being called.
  	Next() (server string, retryStart bool)
  	// Connected tells the provider that a connection succeeded.
  	Connected()
  	// Len returns the number of servers.
  	Len() int
  }
  128. // ConnectWithDialer establishes a new connection to a pool of zookeeper servers
  129. // using a custom Dialer. See Connect for further information about session timeout.
  130. // This method is deprecated and provided for compatibility: use the WithDialer option instead.
  131. func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) {
  132. return Connect(servers, sessionTimeout, WithDialer(dialer))
  133. }
  134. // Connect establishes a new connection to a pool of zookeeper
  135. // servers. The provided session timeout sets the amount of time for which
  136. // a session is considered valid after losing connection to a server. Within
  137. // the session timeout it's possible to reestablish a connection to a different
  138. // server and keep the same session. This is means any ephemeral nodes and
  139. // watches are maintained.
  140. func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) {
  141. if len(servers) == 0 {
  142. return nil, nil, errors.New("zk: server list must not be empty")
  143. }
  144. srvs := make([]string, len(servers))
  145. for i, addr := range servers {
  146. if strings.Contains(addr, ":") {
  147. srvs[i] = addr
  148. } else {
  149. srvs[i] = addr + ":" + strconv.Itoa(DefaultPort)
  150. }
  151. }
  152. // Randomize the order of the servers to avoid creating hotspots
  153. stringShuffle(srvs)
  154. ec := make(chan Event, eventChanSize)
  155. conn := &Conn{
  156. dialer: net.DialTimeout,
  157. hostProvider: &DNSHostProvider{},
  158. conn: nil,
  159. state: StateDisconnected,
  160. eventChan: ec,
  161. shouldQuit: make(chan struct{}),
  162. connectTimeout: 1 * time.Second,
  163. sendChan: make(chan *request, sendChanSize),
  164. requests: make(map[int32]*request),
  165. watchers: make(map[watchPathType][]chan Event),
  166. passwd: emptyPassword,
  167. logger: DefaultLogger,
  168. buf: make([]byte, bufferSize),
  169. // Debug
  170. reconnectDelay: 0,
  171. }
  172. // Set provided options.
  173. for _, option := range options {
  174. option(conn)
  175. }
  176. if err := conn.hostProvider.Init(srvs); err != nil {
  177. return nil, nil, err
  178. }
  179. conn.setTimeouts(int32(sessionTimeout / time.Millisecond))
  180. go func() {
  181. conn.loop()
  182. conn.flushRequests(ErrClosing)
  183. conn.invalidateWatches(ErrClosing)
  184. close(conn.eventChan)
  185. }()
  186. return conn, ec, nil
  187. }
  188. // WithDialer returns a connection option specifying a non-default Dialer.
  189. func WithDialer(dialer Dialer) connOption {
  190. return func(c *Conn) {
  191. c.dialer = dialer
  192. }
  193. }
  194. // WithHostProvider returns a connection option specifying a non-default HostProvider.
  195. func WithHostProvider(hostProvider HostProvider) connOption {
  196. return func(c *Conn) {
  197. c.hostProvider = hostProvider
  198. }
  199. }
  200. // EventCallback is a function that is called when an Event occurs.
  201. type EventCallback func(Event)
  202. // WithEventCallback returns a connection option that specifies an event
  203. // callback.
  204. // The callback must not block - doing so would delay the ZK go routines.
  205. func WithEventCallback(cb EventCallback) connOption {
  206. return func(c *Conn) {
  207. c.eventCallback = cb
  208. }
  209. }
  210. func (c *Conn) Close() {
  211. close(c.shouldQuit)
  212. select {
  213. case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil):
  214. case <-time.After(time.Second):
  215. }
  216. }
  217. // State returns the current state of the connection.
  218. func (c *Conn) State() State {
  219. return State(atomic.LoadInt32((*int32)(&c.state)))
  220. }
  221. // SessionID returns the current session id of the connection.
  222. func (c *Conn) SessionID() int64 {
  223. return atomic.LoadInt64(&c.sessionID)
  224. }
  225. // SetLogger sets the logger to be used for printing errors.
  226. // Logger is an interface provided by this package.
  227. func (c *Conn) SetLogger(l Logger) {
  228. c.logger = l
  229. }
  230. func (c *Conn) setTimeouts(sessionTimeoutMs int32) {
  231. c.sessionTimeoutMs = sessionTimeoutMs
  232. sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond
  233. c.recvTimeout = sessionTimeout * 2 / 3
  234. c.pingInterval = c.recvTimeout / 2
  235. }
  236. func (c *Conn) setState(state State) {
  237. atomic.StoreInt32((*int32)(&c.state), int32(state))
  238. c.sendEvent(Event{Type: EventSession, State: state, Server: c.Server()})
  239. }
  240. func (c *Conn) sendEvent(evt Event) {
  241. if c.eventCallback != nil {
  242. c.eventCallback(evt)
  243. }
  244. select {
  245. case c.eventChan <- evt:
  246. default:
  247. // panic("zk: event channel full - it must be monitored and never allowed to be full")
  248. }
  249. }
  250. func (c *Conn) connect() error {
  251. var retryStart bool
  252. for {
  253. c.serverMu.Lock()
  254. c.server, retryStart = c.hostProvider.Next()
  255. c.serverMu.Unlock()
  256. c.setState(StateConnecting)
  257. if retryStart {
  258. c.flushUnsentRequests(ErrNoServer)
  259. select {
  260. case <-time.After(time.Second):
  261. // pass
  262. case <-c.shouldQuit:
  263. c.setState(StateDisconnected)
  264. c.flushUnsentRequests(ErrClosing)
  265. return ErrClosing
  266. }
  267. }
  268. zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout)
  269. if err == nil {
  270. c.conn = zkConn
  271. c.setState(StateConnected)
  272. c.logger.Printf("Connected to %s", c.Server())
  273. return nil
  274. }
  275. c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err)
  276. }
  277. }
  278. func (c *Conn) resendZkAuth(reauthReadyChan chan struct{}) {
  279. c.credsMu.Lock()
  280. defer c.credsMu.Unlock()
  281. defer close(reauthReadyChan)
  282. c.logger.Printf("Re-submitting `%d` credentials after reconnect",
  283. len(c.creds))
  284. for _, cred := range c.creds {
  285. resChan, err := c.sendRequest(
  286. opSetAuth,
  287. &setAuthRequest{Type: 0,
  288. Scheme: cred.scheme,
  289. Auth: cred.auth,
  290. },
  291. &setAuthResponse{},
  292. nil)
  293. if err != nil {
  294. c.logger.Printf("Call to sendRequest failed during credential resubmit: %s", err)
  295. // FIXME(prozlach): lets ignore errors for now
  296. continue
  297. }
  298. res := <-resChan
  299. if res.err != nil {
  300. c.logger.Printf("Credential re-submit failed: %s", res.err)
  301. // FIXME(prozlach): lets ignore errors for now
  302. continue
  303. }
  304. }
  305. }
  306. func (c *Conn) sendRequest(
  307. opcode int32,
  308. req interface{},
  309. res interface{},
  310. recvFunc func(*request, *responseHeader, error),
  311. ) (
  312. <-chan response,
  313. error,
  314. ) {
  315. rq := &request{
  316. xid: c.nextXid(),
  317. opcode: opcode,
  318. pkt: req,
  319. recvStruct: res,
  320. recvChan: make(chan response, 1),
  321. recvFunc: recvFunc,
  322. }
  323. if err := c.sendData(rq); err != nil {
  324. return nil, err
  325. }
  326. return rq.recvChan, nil
  327. }
  328. func (c *Conn) loop() {
  329. for {
  330. if err := c.connect(); err != nil {
  331. // c.Close() was called
  332. return
  333. }
  334. err := c.authenticate()
  335. switch {
  336. case err == ErrSessionExpired:
  337. c.logger.Printf("Authentication failed: %s", err)
  338. c.invalidateWatches(err)
  339. case err != nil && c.conn != nil:
  340. c.logger.Printf("Authentication failed: %s", err)
  341. c.conn.Close()
  342. case err == nil:
  343. c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs)
  344. c.hostProvider.Connected() // mark success
  345. c.closeChan = make(chan struct{}) // channel to tell send loop stop
  346. reauthChan := make(chan struct{}) // channel to tell send loop that authdata has been resubmitted
  347. var wg sync.WaitGroup
  348. wg.Add(1)
  349. go func() {
  350. <-reauthChan
  351. err := c.sendLoop()
  352. c.logger.Printf("Send loop terminated: err=%v", err)
  353. c.conn.Close() // causes recv loop to EOF/exit
  354. wg.Done()
  355. }()
  356. wg.Add(1)
  357. go func() {
  358. err := c.recvLoop(c.conn)
  359. c.logger.Printf("Recv loop terminated: err=%v", err)
  360. if err == nil {
  361. panic("zk: recvLoop should never return nil error")
  362. }
  363. close(c.closeChan) // tell send loop to exit
  364. wg.Done()
  365. }()
  366. c.resendZkAuth(reauthChan)
  367. c.sendSetWatches()
  368. wg.Wait()
  369. }
  370. c.setState(StateDisconnected)
  371. select {
  372. case <-c.shouldQuit:
  373. c.flushRequests(ErrClosing)
  374. return
  375. default:
  376. }
  377. if err != ErrSessionExpired {
  378. err = ErrConnectionClosed
  379. }
  380. c.flushRequests(err)
  381. if c.reconnectDelay > 0 {
  382. select {
  383. case <-c.shouldQuit:
  384. return
  385. case <-time.After(c.reconnectDelay):
  386. }
  387. }
  388. }
  389. }
  390. func (c *Conn) flushUnsentRequests(err error) {
  391. for {
  392. select {
  393. default:
  394. return
  395. case req := <-c.sendChan:
  396. req.recvChan <- response{-1, err}
  397. }
  398. }
  399. }
  400. // Send error to all pending requests and clear request map
  401. func (c *Conn) flushRequests(err error) {
  402. c.requestsLock.Lock()
  403. for _, req := range c.requests {
  404. req.recvChan <- response{-1, err}
  405. }
  406. c.requests = make(map[int32]*request)
  407. c.requestsLock.Unlock()
  408. }
  409. // Send error to all watchers and clear watchers map
  410. func (c *Conn) invalidateWatches(err error) {
  411. c.watchersLock.Lock()
  412. defer c.watchersLock.Unlock()
  413. if len(c.watchers) >= 0 {
  414. for pathType, watchers := range c.watchers {
  415. ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err}
  416. for _, ch := range watchers {
  417. ch <- ev
  418. close(ch)
  419. }
  420. }
  421. c.watchers = make(map[watchPathType][]chan Event)
  422. }
  423. }
  424. func (c *Conn) sendSetWatches() {
  425. c.watchersLock.Lock()
  426. defer c.watchersLock.Unlock()
  427. if len(c.watchers) == 0 {
  428. return
  429. }
  430. req := &setWatchesRequest{
  431. RelativeZxid: c.lastZxid,
  432. DataWatches: make([]string, 0),
  433. ExistWatches: make([]string, 0),
  434. ChildWatches: make([]string, 0),
  435. }
  436. n := 0
  437. for pathType, watchers := range c.watchers {
  438. if len(watchers) == 0 {
  439. continue
  440. }
  441. switch pathType.wType {
  442. case watchTypeData:
  443. req.DataWatches = append(req.DataWatches, pathType.path)
  444. case watchTypeExist:
  445. req.ExistWatches = append(req.ExistWatches, pathType.path)
  446. case watchTypeChild:
  447. req.ChildWatches = append(req.ChildWatches, pathType.path)
  448. }
  449. n++
  450. }
  451. if n == 0 {
  452. return
  453. }
  454. go func() {
  455. res := &setWatchesResponse{}
  456. _, err := c.request(opSetWatches, req, res, nil)
  457. if err != nil {
  458. c.logger.Printf("Failed to set previous watches: %s", err.Error())
  459. }
  460. }()
  461. }
  462. func (c *Conn) authenticate() error {
  463. buf := make([]byte, 256)
  464. // Encode and send a connect request.
  465. n, err := encodePacket(buf[4:], &connectRequest{
  466. ProtocolVersion: protocolVersion,
  467. LastZxidSeen: c.lastZxid,
  468. TimeOut: c.sessionTimeoutMs,
  469. SessionID: c.SessionID(),
  470. Passwd: c.passwd,
  471. })
  472. if err != nil {
  473. return err
  474. }
  475. binary.BigEndian.PutUint32(buf[:4], uint32(n))
  476. c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10))
  477. _, err = c.conn.Write(buf[:n+4])
  478. c.conn.SetWriteDeadline(time.Time{})
  479. if err != nil {
  480. return err
  481. }
  482. // Receive and decode a connect response.
  483. c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10))
  484. _, err = io.ReadFull(c.conn, buf[:4])
  485. c.conn.SetReadDeadline(time.Time{})
  486. if err != nil {
  487. return err
  488. }
  489. blen := int(binary.BigEndian.Uint32(buf[:4]))
  490. if cap(buf) < blen {
  491. buf = make([]byte, blen)
  492. }
  493. _, err = io.ReadFull(c.conn, buf[:blen])
  494. if err != nil {
  495. return err
  496. }
  497. r := connectResponse{}
  498. _, err = decodePacket(buf[:blen], &r)
  499. if err != nil {
  500. return err
  501. }
  502. if r.SessionID == 0 {
  503. atomic.StoreInt64(&c.sessionID, int64(0))
  504. c.passwd = emptyPassword
  505. c.lastZxid = 0
  506. c.setState(StateExpired)
  507. return ErrSessionExpired
  508. }
  509. atomic.StoreInt64(&c.sessionID, r.SessionID)
  510. c.setTimeouts(r.TimeOut)
  511. c.passwd = r.Passwd
  512. c.setState(StateHasSession)
  513. return nil
  514. }
  515. func (c *Conn) sendData(req *request) error {
  516. header := &requestHeader{req.xid, req.opcode}
  517. n, err := encodePacket(c.buf[4:], header)
  518. if err != nil {
  519. req.recvChan <- response{-1, err}
  520. return nil
  521. }
  522. n2, err := encodePacket(c.buf[4+n:], req.pkt)
  523. if err != nil {
  524. req.recvChan <- response{-1, err}
  525. return nil
  526. }
  527. n += n2
  528. binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
  529. c.requestsLock.Lock()
  530. select {
  531. case <-c.closeChan:
  532. req.recvChan <- response{-1, ErrConnectionClosed}
  533. c.requestsLock.Unlock()
  534. return ErrConnectionClosed
  535. default:
  536. }
  537. c.requests[req.xid] = req
  538. c.requestsLock.Unlock()
  539. c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
  540. _, err = c.conn.Write(c.buf[:n+4])
  541. c.conn.SetWriteDeadline(time.Time{})
  542. if err != nil {
  543. req.recvChan <- response{-1, err}
  544. c.conn.Close()
  545. return err
  546. }
  547. return nil
  548. }
  549. func (c *Conn) sendLoop() error {
  550. pingTicker := time.NewTicker(c.pingInterval)
  551. defer pingTicker.Stop()
  552. for {
  553. select {
  554. case req := <-c.sendChan:
  555. if err := c.sendData(req); err != nil {
  556. return err
  557. }
  558. case <-pingTicker.C:
  559. n, err := encodePacket(c.buf[4:], &requestHeader{Xid: -2, Opcode: opPing})
  560. if err != nil {
  561. panic("zk: opPing should never fail to serialize")
  562. }
  563. binary.BigEndian.PutUint32(c.buf[:4], uint32(n))
  564. c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout))
  565. _, err = c.conn.Write(c.buf[:n+4])
  566. c.conn.SetWriteDeadline(time.Time{})
  567. if err != nil {
  568. c.conn.Close()
  569. return err
  570. }
  571. case <-c.closeChan:
  572. return nil
  573. }
  574. }
  575. }
  576. func (c *Conn) recvLoop(conn net.Conn) error {
  577. buf := make([]byte, bufferSize)
  578. for {
  579. // package length
  580. conn.SetReadDeadline(time.Now().Add(c.recvTimeout))
  581. _, err := io.ReadFull(conn, buf[:4])
  582. if err != nil {
  583. return err
  584. }
  585. blen := int(binary.BigEndian.Uint32(buf[:4]))
  586. if cap(buf) < blen {
  587. buf = make([]byte, blen)
  588. }
  589. _, err = io.ReadFull(conn, buf[:blen])
  590. conn.SetReadDeadline(time.Time{})
  591. if err != nil {
  592. return err
  593. }
  594. res := responseHeader{}
  595. _, err = decodePacket(buf[:16], &res)
  596. if err != nil {
  597. return err
  598. }
  599. if res.Xid == -1 {
  600. res := &watcherEvent{}
  601. _, err := decodePacket(buf[16:blen], res)
  602. if err != nil {
  603. return err
  604. }
  605. ev := Event{
  606. Type: res.Type,
  607. State: res.State,
  608. Path: res.Path,
  609. Err: nil,
  610. }
  611. c.sendEvent(ev)
  612. wTypes := make([]watchType, 0, 2)
  613. switch res.Type {
  614. case EventNodeCreated:
  615. wTypes = append(wTypes, watchTypeExist)
  616. case EventNodeDeleted, EventNodeDataChanged:
  617. wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild)
  618. case EventNodeChildrenChanged:
  619. wTypes = append(wTypes, watchTypeChild)
  620. }
  621. c.watchersLock.Lock()
  622. for _, t := range wTypes {
  623. wpt := watchPathType{res.Path, t}
  624. if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 {
  625. for _, ch := range watchers {
  626. ch <- ev
  627. close(ch)
  628. }
  629. delete(c.watchers, wpt)
  630. }
  631. }
  632. c.watchersLock.Unlock()
  633. } else if res.Xid == -2 {
  634. // Ping response. Ignore.
  635. } else if res.Xid < 0 {
  636. c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid)
  637. } else {
  638. if res.Zxid > 0 {
  639. c.lastZxid = res.Zxid
  640. }
  641. c.requestsLock.Lock()
  642. req, ok := c.requests[res.Xid]
  643. if ok {
  644. delete(c.requests, res.Xid)
  645. }
  646. c.requestsLock.Unlock()
  647. if !ok {
  648. c.logger.Printf("Response for unknown request with xid %d", res.Xid)
  649. } else {
  650. if res.Err != 0 {
  651. err = res.Err.toError()
  652. } else {
  653. _, err = decodePacket(buf[16:blen], req.recvStruct)
  654. }
  655. if req.recvFunc != nil {
  656. req.recvFunc(req, &res, err)
  657. }
  658. req.recvChan <- response{res.Zxid, err}
  659. if req.opcode == opClose {
  660. return io.EOF
  661. }
  662. }
  663. }
  664. }
  665. }
  666. func (c *Conn) nextXid() int32 {
  667. return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff)
  668. }
  669. func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event {
  670. c.watchersLock.Lock()
  671. defer c.watchersLock.Unlock()
  672. ch := make(chan Event, 1)
  673. wpt := watchPathType{path, watchType}
  674. c.watchers[wpt] = append(c.watchers[wpt], ch)
  675. return ch
  676. }
  677. func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response {
  678. rq := &request{
  679. xid: c.nextXid(),
  680. opcode: opcode,
  681. pkt: req,
  682. recvStruct: res,
  683. recvChan: make(chan response, 1),
  684. recvFunc: recvFunc,
  685. }
  686. c.sendChan <- rq
  687. return rq.recvChan
  688. }
  689. func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) {
  690. r := <-c.queueRequest(opcode, req, res, recvFunc)
  691. return r.zxid, r.err
  692. }
  693. func (c *Conn) AddAuth(scheme string, auth []byte) error {
  694. _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil)
  695. if err != nil {
  696. return err
  697. }
  698. // Remember authdata so that it can be re-submitted on reconnect
  699. //
  700. // FIXME(prozlach): For now we treat "userfoo:passbar" and "userfoo:passbar2"
  701. // as two different entries, which will be re-submitted on reconnet. Some
  702. // research is needed on how ZK treats these cases and
  703. // then maybe switch to something like "map[username] = password" to allow
  704. // only single password for given user with users being unique.
  705. obj := authCreds{
  706. scheme: scheme,
  707. auth: auth,
  708. }
  709. c.credsMu.Lock()
  710. c.creds = append(c.creds, obj)
  711. c.credsMu.Unlock()
  712. return nil
  713. }
  714. func (c *Conn) Children(path string) ([]string, *Stat, error) {
  715. res := &getChildren2Response{}
  716. _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil)
  717. return res.Children, &res.Stat, err
  718. }
  719. func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) {
  720. var ech <-chan Event
  721. res := &getChildren2Response{}
  722. _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  723. if err == nil {
  724. ech = c.addWatcher(path, watchTypeChild)
  725. }
  726. })
  727. if err != nil {
  728. return nil, nil, nil, err
  729. }
  730. return res.Children, &res.Stat, ech, err
  731. }
  732. func (c *Conn) Get(path string) ([]byte, *Stat, error) {
  733. res := &getDataResponse{}
  734. _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil)
  735. return res.Data, &res.Stat, err
  736. }
  737. // GetW returns the contents of a znode and sets a watch
  738. func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) {
  739. var ech <-chan Event
  740. res := &getDataResponse{}
  741. _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  742. if err == nil {
  743. ech = c.addWatcher(path, watchTypeData)
  744. }
  745. })
  746. if err != nil {
  747. return nil, nil, nil, err
  748. }
  749. return res.Data, &res.Stat, ech, err
  750. }
  751. func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) {
  752. if path == "" {
  753. return nil, ErrInvalidPath
  754. }
  755. res := &setDataResponse{}
  756. _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil)
  757. return &res.Stat, err
  758. }
  759. func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) {
  760. res := &createResponse{}
  761. _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil)
  762. return res.Path, err
  763. }
  764. // CreateProtectedEphemeralSequential fixes a race condition if the server crashes
  765. // after it creates the node. On reconnect the session may still be valid so the
  766. // ephemeral node still exists. Therefore, on reconnect we need to check if a node
  767. // with a GUID generated on create exists.
  768. func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) {
  769. var guid [16]byte
  770. _, err := io.ReadFull(rand.Reader, guid[:16])
  771. if err != nil {
  772. return "", err
  773. }
  774. guidStr := fmt.Sprintf("%x", guid)
  775. parts := strings.Split(path, "/")
  776. parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1])
  777. rootPath := strings.Join(parts[:len(parts)-1], "/")
  778. protectedPath := strings.Join(parts, "/")
  779. var newPath string
  780. for i := 0; i < 3; i++ {
  781. newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl)
  782. switch err {
  783. case ErrSessionExpired:
  784. // No need to search for the node since it can't exist. Just try again.
  785. case ErrConnectionClosed:
  786. children, _, err := c.Children(rootPath)
  787. if err != nil {
  788. return "", err
  789. }
  790. for _, p := range children {
  791. parts := strings.Split(p, "/")
  792. if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) {
  793. if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr {
  794. return rootPath + "/" + p, nil
  795. }
  796. }
  797. }
  798. case nil:
  799. return newPath, nil
  800. default:
  801. return "", err
  802. }
  803. }
  804. return "", err
  805. }
  806. func (c *Conn) Delete(path string, version int32) error {
  807. _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil)
  808. return err
  809. }
  810. func (c *Conn) Exists(path string) (bool, *Stat, error) {
  811. res := &existsResponse{}
  812. _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil)
  813. exists := true
  814. if err == ErrNoNode {
  815. exists = false
  816. err = nil
  817. }
  818. return exists, &res.Stat, err
  819. }
  820. func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) {
  821. var ech <-chan Event
  822. res := &existsResponse{}
  823. _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) {
  824. if err == nil {
  825. ech = c.addWatcher(path, watchTypeData)
  826. } else if err == ErrNoNode {
  827. ech = c.addWatcher(path, watchTypeExist)
  828. }
  829. })
  830. exists := true
  831. if err == ErrNoNode {
  832. exists = false
  833. err = nil
  834. }
  835. if err != nil {
  836. return false, nil, nil, err
  837. }
  838. return exists, &res.Stat, ech, err
  839. }
  840. func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) {
  841. res := &getAclResponse{}
  842. _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil)
  843. return res.Acl, &res.Stat, err
  844. }
  845. func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) {
  846. res := &setAclResponse{}
  847. _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil)
  848. return &res.Stat, err
  849. }
  850. func (c *Conn) Sync(path string) (string, error) {
  851. res := &syncResponse{}
  852. _, err := c.request(opSync, &syncRequest{Path: path}, res, nil)
  853. return res.Path, err
  854. }
  855. type MultiResponse struct {
  856. Stat *Stat
  857. String string
  858. Error error
  859. }
  860. // Multi executes multiple ZooKeeper operations or none of them. The provided
  861. // ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or
  862. // *CheckVersionRequest.
  863. func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) {
  864. req := &multiRequest{
  865. Ops: make([]multiRequestOp, 0, len(ops)),
  866. DoneHeader: multiHeader{Type: -1, Done: true, Err: -1},
  867. }
  868. for _, op := range ops {
  869. var opCode int32
  870. switch op.(type) {
  871. case *CreateRequest:
  872. opCode = opCreate
  873. case *SetDataRequest:
  874. opCode = opSetData
  875. case *DeleteRequest:
  876. opCode = opDelete
  877. case *CheckVersionRequest:
  878. opCode = opCheck
  879. default:
  880. return nil, fmt.Errorf("unknown operation type %T", op)
  881. }
  882. req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op})
  883. }
  884. res := &multiResponse{}
  885. _, err := c.request(opMulti, req, res, nil)
  886. mr := make([]MultiResponse, len(res.Ops))
  887. for i, op := range res.Ops {
  888. mr[i] = MultiResponse{Stat: op.Stat, String: op.String, Error: op.Err.toError()}
  889. }
  890. return mr, err
  891. }
  892. // Server returns the current or last-connected server name.
  893. func (c *Conn) Server() string {
  894. c.serverMu.Lock()
  895. defer c.serverMu.Unlock()
  896. return c.server
  897. }