auth.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package client
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "encoding/binary"
  6. "github.com/juju/errors"
  7. . "github.com/siddontang/go-mysql/mysql"
  8. "github.com/siddontang/go-mysql/packet"
  9. )
  10. func (c *Conn) readInitialHandshake() error {
  11. data, err := c.ReadPacket()
  12. if err != nil {
  13. return errors.Trace(err)
  14. }
  15. if data[0] == ERR_HEADER {
  16. return errors.New("read initial handshake error")
  17. }
  18. if data[0] < MinProtocolVersion {
  19. return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
  20. }
  21. //skip mysql version
  22. //mysql version end with 0x00
  23. pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
  24. //connection id length is 4
  25. c.connectionID = uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
  26. pos += 4
  27. c.salt = []byte{}
  28. c.salt = append(c.salt, data[pos:pos+8]...)
  29. //skip filter
  30. pos += 8 + 1
  31. //capability lower 2 bytes
  32. c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
  33. pos += 2
  34. if len(data) > pos {
  35. //skip server charset
  36. //c.charset = data[pos]
  37. pos += 1
  38. c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
  39. pos += 2
  40. c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
  41. pos += 2
  42. //skip auth data len or [00]
  43. //skip reserved (all [00])
  44. pos += 10 + 1
  45. // The documentation is ambiguous about the length.
  46. // The official Python library uses the fixed length 12
  47. // mysql-proxy also use 12
  48. // which is not documented but seems to work.
  49. c.salt = append(c.salt, data[pos:pos+12]...)
  50. }
  51. return nil
  52. }
  53. func (c *Conn) writeAuthHandshake() error {
  54. // Adjust client capability flags based on server support
  55. capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
  56. CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_LONG_FLAG
  57. // To enable TLS / SSL
  58. if c.TLSConfig != nil {
  59. capability |= CLIENT_PLUGIN_AUTH
  60. capability |= CLIENT_SSL
  61. }
  62. capability &= c.capability
  63. //packet length
  64. //capbility 4
  65. //max-packet size 4
  66. //charset 1
  67. //reserved all[0] 23
  68. length := 4 + 4 + 1 + 23
  69. //username
  70. length += len(c.user) + 1
  71. //we only support secure connection
  72. auth := CalcPassword(c.salt, []byte(c.password))
  73. length += 1 + len(auth)
  74. if len(c.db) > 0 {
  75. capability |= CLIENT_CONNECT_WITH_DB
  76. length += len(c.db) + 1
  77. }
  78. // mysql_native_password + null-terminated
  79. length += 21 + 1
  80. c.capability = capability
  81. data := make([]byte, length+4)
  82. //capability [32 bit]
  83. data[4] = byte(capability)
  84. data[5] = byte(capability >> 8)
  85. data[6] = byte(capability >> 16)
  86. data[7] = byte(capability >> 24)
  87. //MaxPacketSize [32 bit] (none)
  88. //data[8] = 0x00
  89. //data[9] = 0x00
  90. //data[10] = 0x00
  91. //data[11] = 0x00
  92. //Charset [1 byte]
  93. //use default collation id 33 here, is utf-8
  94. data[12] = byte(DEFAULT_COLLATION_ID)
  95. // SSL Connection Request Packet
  96. // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
  97. if c.TLSConfig != nil {
  98. // Send TLS / SSL request packet
  99. if err := c.WritePacket(data[:(4+4+1+23)+4]); err != nil {
  100. return err
  101. }
  102. // Switch to TLS
  103. tlsConn := tls.Client(c.Conn.Conn, c.TLSConfig)
  104. if err := tlsConn.Handshake(); err != nil {
  105. return err
  106. }
  107. currentSequence := c.Sequence
  108. c.Conn = packet.NewConn(tlsConn)
  109. c.Sequence = currentSequence
  110. }
  111. //Filler [23 bytes] (all 0x00)
  112. pos := 13 + 23
  113. //User [null terminated string]
  114. if len(c.user) > 0 {
  115. pos += copy(data[pos:], c.user)
  116. }
  117. data[pos] = 0x00
  118. pos++
  119. // auth [length encoded integer]
  120. data[pos] = byte(len(auth))
  121. pos += 1 + copy(data[pos+1:], auth)
  122. // db [null terminated string]
  123. if len(c.db) > 0 {
  124. pos += copy(data[pos:], c.db)
  125. data[pos] = 0x00
  126. pos++
  127. }
  128. // Assume native client during response
  129. pos += copy(data[pos:], "mysql_native_password")
  130. data[pos] = 0x00
  131. return c.WritePacket(data)
  132. }