schema.go 8.2 KB


  1. // Copyright 2012, Google Inc. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package schema
  import (
  	"database/sql"
  	"fmt"
  	"strconv"
  	"strings"

  	"github.com/juju/errors"
  	"github.com/siddontang/go-mysql/mysql"
  )
  var (
  	// ErrTableNotExist is the sentinel error whose message is
  	// "table is not exist"; compare against it by identity.
  	ErrTableNotExist = errors.New("table is not exist")

  	// ErrMissingTableMeta is the sentinel error whose message is
  	// "missing table meta".
  	ErrMissingTableMeta = errors.New("missing table meta")
  )

  // HAHealthCheckSchema holds the qualified name "mysql.ha_health_check".
  // It is not referenced in this file; presumably it names a table used for
  // HA health checks that callers treat specially (TODO confirm).
  var HAHealthCheckSchema = "mysql.ha_health_check"
  // Column type categories stored in TableColumn.Type by Table.AddColumn.
  // Numbering starts at 1, so the zero value of an int matches no category.
  const (
  	TYPE_NUMBER    = 1 + iota // tinyint, smallint, mediumint, int, bigint, year
  	TYPE_FLOAT                // float, double
  	TYPE_ENUM                 // enum
  	TYPE_SET                  // set
  	TYPE_STRING               // other
  	TYPE_DATETIME             // datetime
  	TYPE_TIMESTAMP            // timestamp
  	TYPE_DATE                 // date
  	TYPE_TIME                 // time
  	TYPE_BIT                  // bit
  	TYPE_JSON                 // json
  )
  // TableColumn describes one column of a Table, as filled in by AddColumn.
  type TableColumn struct {
  	Name       string   // column name
  	Type       int      // one of the TYPE_* categories
  	Collation  string   // collation text as passed to AddColumn (may be empty)
  	RawType    string   // unparsed column type text, e.g. "int(11) unsigned"
  	IsAuto     bool     // true when the Extra field was "auto_increment"
  	IsUnsigned bool     // true for unsigned (or zerofill) columns
  	EnumValues []string // members of an enum column, in the order listed in RawType
  	SetValues  []string // members of a set column, in the order listed in RawType
  }
  // Index is a named table index. Columns lists its column names in the order
  // they were added, and Cardinality holds one value per entry of Columns.
  type Index struct {
  	Name        string
  	Columns     []string
  	Cardinality []uint64
  }
  43. type Table struct {
  44. Schema string
  45. Name string
  46. Columns []TableColumn
  47. Indexes []*Index
  48. PKColumns []int
  49. }
  50. func (ta *Table) String() string {
  51. return fmt.Sprintf("%s.%s", ta.Schema, ta.Name)
  52. }
  53. func (ta *Table) AddColumn(name string, columnType string, collation string, extra string) {
  54. index := len(ta.Columns)
  55. ta.Columns = append(ta.Columns, TableColumn{Name: name, Collation: collation})
  56. ta.Columns[index].RawType = columnType
  57. if strings.HasPrefix(columnType, "float") ||
  58. strings.HasPrefix(columnType, "double") ||
  59. strings.HasPrefix(columnType, "decimal") {
  60. ta.Columns[index].Type = TYPE_FLOAT
  61. } else if strings.HasPrefix(columnType, "enum") {
  62. ta.Columns[index].Type = TYPE_ENUM
  63. ta.Columns[index].EnumValues = strings.Split(strings.Replace(
  64. strings.TrimSuffix(
  65. strings.TrimPrefix(
  66. columnType, "enum("),
  67. ")"),
  68. "'", "", -1),
  69. ",")
  70. } else if strings.HasPrefix(columnType, "set") {
  71. ta.Columns[index].Type = TYPE_SET
  72. ta.Columns[index].SetValues = strings.Split(strings.Replace(
  73. strings.TrimSuffix(
  74. strings.TrimPrefix(
  75. columnType, "set("),
  76. ")"),
  77. "'", "", -1),
  78. ",")
  79. } else if strings.HasPrefix(columnType, "datetime") {
  80. ta.Columns[index].Type = TYPE_DATETIME
  81. } else if strings.HasPrefix(columnType, "timestamp") {
  82. ta.Columns[index].Type = TYPE_TIMESTAMP
  83. } else if strings.HasPrefix(columnType, "time") {
  84. ta.Columns[index].Type = TYPE_TIME
  85. } else if "date" == columnType {
  86. ta.Columns[index].Type = TYPE_DATE
  87. } else if strings.HasPrefix(columnType, "bit") {
  88. ta.Columns[index].Type = TYPE_BIT
  89. } else if strings.HasPrefix(columnType, "json") {
  90. ta.Columns[index].Type = TYPE_JSON
  91. } else if strings.Contains(columnType, "int") || strings.HasPrefix(columnType, "year") {
  92. ta.Columns[index].Type = TYPE_NUMBER
  93. } else {
  94. ta.Columns[index].Type = TYPE_STRING
  95. }
  96. if strings.Contains(columnType, "unsigned") || strings.Contains(columnType, "zerofill") {
  97. ta.Columns[index].IsUnsigned = true
  98. }
  99. if extra == "auto_increment" {
  100. ta.Columns[index].IsAuto = true
  101. }
  102. }
  103. func (ta *Table) FindColumn(name string) int {
  104. for i, col := range ta.Columns {
  105. if col.Name == name {
  106. return i
  107. }
  108. }
  109. return -1
  110. }
  111. func (ta *Table) GetPKColumn(index int) *TableColumn {
  112. return &ta.Columns[ta.PKColumns[index]]
  113. }
  114. func (ta *Table) AddIndex(name string) (index *Index) {
  115. index = NewIndex(name)
  116. ta.Indexes = append(ta.Indexes, index)
  117. return index
  118. }
  119. func NewIndex(name string) *Index {
  120. return &Index{name, make([]string, 0, 8), make([]uint64, 0, 8)}
  121. }
  122. func (idx *Index) AddColumn(name string, cardinality uint64) {
  123. idx.Columns = append(idx.Columns, name)
  124. if cardinality == 0 {
  125. cardinality = uint64(len(idx.Cardinality) + 1)
  126. }
  127. idx.Cardinality = append(idx.Cardinality, cardinality)
  128. }
  129. func (idx *Index) FindColumn(name string) int {
  130. for i, colName := range idx.Columns {
  131. if name == colName {
  132. return i
  133. }
  134. }
  135. return -1
  136. }
  137. func IsTableExist(conn mysql.Executer, schema string, name string) (bool, error) {
  138. query := fmt.Sprintf("SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '%s' and TABLE_NAME = '%s' LIMIT 1", schema, name)
  139. r, err := conn.Execute(query)
  140. if err != nil {
  141. return false, errors.Trace(err)
  142. }
  143. return r.RowNumber() == 1, nil
  144. }
  145. func NewTableFromSqlDB(conn *sql.DB, schema string, name string) (*Table, error) {
  146. ta := &Table{
  147. Schema: schema,
  148. Name: name,
  149. Columns: make([]TableColumn, 0, 16),
  150. Indexes: make([]*Index, 0, 8),
  151. }
  152. if err := ta.fetchColumnsViaSqlDB(conn); err != nil {
  153. return nil, errors.Trace(err)
  154. }
  155. if err := ta.fetchIndexesViaSqlDB(conn); err != nil {
  156. return nil, errors.Trace(err)
  157. }
  158. return ta, nil
  159. }
  160. func NewTable(conn mysql.Executer, schema string, name string) (*Table, error) {
  161. ta := &Table{
  162. Schema: schema,
  163. Name: name,
  164. Columns: make([]TableColumn, 0, 16),
  165. Indexes: make([]*Index, 0, 8),
  166. }
  167. if err := ta.fetchColumns(conn); err != nil {
  168. return nil, errors.Trace(err)
  169. }
  170. if err := ta.fetchIndexes(conn); err != nil {
  171. return nil, errors.Trace(err)
  172. }
  173. return ta, nil
  174. }
  175. func (ta *Table) fetchColumns(conn mysql.Executer) error {
  176. r, err := conn.Execute(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name))
  177. if err != nil {
  178. return errors.Trace(err)
  179. }
  180. for i := 0; i < r.RowNumber(); i++ {
  181. name, _ := r.GetString(i, 0)
  182. colType, _ := r.GetString(i, 1)
  183. collation, _ := r.GetString(i, 2)
  184. extra, _ := r.GetString(i, 6)
  185. ta.AddColumn(name, colType, collation, extra)
  186. }
  187. return nil
  188. }
  189. func (ta *Table) fetchColumnsViaSqlDB(conn *sql.DB) error {
  190. r, err := conn.Query(fmt.Sprintf("show full columns from `%s`.`%s`", ta.Schema, ta.Name))
  191. if err != nil {
  192. return errors.Trace(err)
  193. }
  194. defer r.Close()
  195. var unusedVal interface{}
  196. unused := &unusedVal
  197. for r.Next() {
  198. var name, colType, extra string
  199. var collation sql.NullString
  200. err := r.Scan(&name, &colType, &collation, &unused, &unused, &unused, &extra, &unused, &unused)
  201. if err != nil {
  202. return errors.Trace(err)
  203. }
  204. ta.AddColumn(name, colType, collation.String, extra)
  205. }
  206. return r.Err()
  207. }
  208. func (ta *Table) fetchIndexes(conn mysql.Executer) error {
  209. r, err := conn.Execute(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
  210. if err != nil {
  211. return errors.Trace(err)
  212. }
  213. var currentIndex *Index
  214. currentName := ""
  215. for i := 0; i < r.RowNumber(); i++ {
  216. indexName, _ := r.GetString(i, 2)
  217. if currentName != indexName {
  218. currentIndex = ta.AddIndex(indexName)
  219. currentName = indexName
  220. }
  221. cardinality, _ := r.GetUint(i, 6)
  222. colName, _ := r.GetString(i, 4)
  223. currentIndex.AddColumn(colName, cardinality)
  224. }
  225. return ta.fetchPrimaryKeyColumns()
  226. }
  227. func (ta *Table) fetchIndexesViaSqlDB(conn *sql.DB) error {
  228. r, err := conn.Query(fmt.Sprintf("show index from `%s`.`%s`", ta.Schema, ta.Name))
  229. if err != nil {
  230. return errors.Trace(err)
  231. }
  232. defer r.Close()
  233. var currentIndex *Index
  234. currentName := ""
  235. var unusedVal interface{}
  236. unused := &unusedVal
  237. for r.Next() {
  238. var indexName, colName string
  239. var cardinality interface{}
  240. err := r.Scan(
  241. &unused,
  242. &unused,
  243. &indexName,
  244. &unused,
  245. &colName,
  246. &unused,
  247. &cardinality,
  248. &unused,
  249. &unused,
  250. &unused,
  251. &unused,
  252. &unused,
  253. &unused,
  254. )
  255. if err != nil {
  256. return errors.Trace(err)
  257. }
  258. if currentName != indexName {
  259. currentIndex = ta.AddIndex(indexName)
  260. currentName = indexName
  261. }
  262. c := toUint64(cardinality)
  263. currentIndex.AddColumn(colName, c)
  264. }
  265. return ta.fetchPrimaryKeyColumns()
  266. }
  // toUint64 converts a value scanned from database/sql into a uint64.
  //
  // Every Go integer kind is converted directly; negative signed values wrap,
  // as a plain conversion does. A []byte or string holding a base-10 unsigned
  // integer is parsed, because text-protocol query results typically arrive
  // as []byte. Anything else yields 0, including nil (SQL NULL), floats and
  // unparsable text.
  func toUint64(i interface{}) uint64 {
  	switch v := i.(type) {
  	case int:
  		return uint64(v)
  	case int8:
  		return uint64(v)
  	case int16:
  		return uint64(v)
  	case int32:
  		return uint64(v)
  	case int64:
  		return uint64(v)
  	case uint:
  		return uint64(v)
  	case uint8:
  		return uint64(v)
  	case uint16:
  		return uint64(v)
  	case uint32:
  		return uint64(v)
  	case uint64:
  		return v
  	case []byte:
  		return toUint64(string(v))
  	case string:
  		n, err := strconv.ParseUint(v, 10, 64)
  		if err != nil {
  			return 0
  		}
  		return n
  	}
  	return 0
  }
  292. func (ta *Table) fetchPrimaryKeyColumns() error {
  293. if len(ta.Indexes) == 0 {
  294. return nil
  295. }
  296. pkIndex := ta.Indexes[0]
  297. if pkIndex.Name != "PRIMARY" {
  298. return nil
  299. }
  300. ta.PKColumns = make([]int, len(pkIndex.Columns))
  301. for i, pkCol := range pkIndex.Columns {
  302. ta.PKColumns[i] = ta.FindColumn(pkCol)
  303. }
  304. return nil
  305. }