// dialect.go (6.7 KB)
  1. package core
  2. import (
  3. "fmt"
  4. "strings"
  5. "time"
  6. )
  7. type DbType string
  8. type Uri struct {
  9. DbType DbType
  10. Proto string
  11. Host string
  12. Port string
  13. DbName string
  14. User string
  15. Passwd string
  16. Charset string
  17. Laddr string
  18. Raddr string
  19. Timeout time.Duration
  20. Schema string
  21. }
  22. // a dialect is a driver's wrapper
  23. type Dialect interface {
  24. SetLogger(logger ILogger)
  25. Init(*DB, *Uri, string, string) error
  26. URI() *Uri
  27. DB() *DB
  28. DBType() DbType
  29. SqlType(*Column) string
  30. FormatBytes(b []byte) string
  31. DriverName() string
  32. DataSourceName() string
  33. QuoteStr() string
  34. IsReserved(string) bool
  35. Quote(string) string
  36. AndStr() string
  37. OrStr() string
  38. EqStr() string
  39. RollBackStr() string
  40. AutoIncrStr() string
  41. SupportInsertMany() bool
  42. SupportEngine() bool
  43. SupportCharset() bool
  44. SupportDropIfExists() bool
  45. IndexOnTable() bool
  46. ShowCreateNull() bool
  47. IndexCheckSql(tableName, idxName string) (string, []interface{})
  48. TableCheckSql(tableName string) (string, []interface{})
  49. IsColumnExist(tableName string, colName string) (bool, error)
  50. CreateTableSql(table *Table, tableName, storeEngine, charset string) string
  51. DropTableSql(tableName string) string
  52. CreateIndexSql(tableName string, index *Index) string
  53. DropIndexSql(tableName string, index *Index) string
  54. ModifyColumnSql(tableName string, col *Column) string
  55. ForUpdateSql(query string) string
  56. //CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error
  57. //MustDropTable(tableName string) error
  58. GetColumns(tableName string) ([]string, map[string]*Column, error)
  59. GetTables() ([]*Table, error)
  60. GetIndexes(tableName string) (map[string]*Index, error)
  61. Filters() []Filter
  62. }
  63. func OpenDialect(dialect Dialect) (*DB, error) {
  64. return Open(dialect.DriverName(), dialect.DataSourceName())
  65. }
  66. type Base struct {
  67. db *DB
  68. dialect Dialect
  69. driverName string
  70. dataSourceName string
  71. logger ILogger
  72. *Uri
  73. }
  74. func (b *Base) DB() *DB {
  75. return b.db
  76. }
  77. func (b *Base) SetLogger(logger ILogger) {
  78. b.logger = logger
  79. }
  80. func (b *Base) Init(db *DB, dialect Dialect, uri *Uri, drivername, dataSourceName string) error {
  81. b.db, b.dialect, b.Uri = db, dialect, uri
  82. b.driverName, b.dataSourceName = drivername, dataSourceName
  83. return nil
  84. }
  85. func (b *Base) URI() *Uri {
  86. return b.Uri
  87. }
  88. func (b *Base) DBType() DbType {
  89. return b.Uri.DbType
  90. }
  91. func (b *Base) FormatBytes(bs []byte) string {
  92. return fmt.Sprintf("0x%x", bs)
  93. }
  94. func (b *Base) DriverName() string {
  95. return b.driverName
  96. }
  97. func (b *Base) ShowCreateNull() bool {
  98. return true
  99. }
  100. func (b *Base) DataSourceName() string {
  101. return b.dataSourceName
  102. }
  103. func (b *Base) AndStr() string {
  104. return "AND"
  105. }
  106. func (b *Base) OrStr() string {
  107. return "OR"
  108. }
  109. func (b *Base) EqStr() string {
  110. return "="
  111. }
  112. func (db *Base) RollBackStr() string {
  113. return "ROLL BACK"
  114. }
  115. func (db *Base) SupportDropIfExists() bool {
  116. return true
  117. }
  118. func (db *Base) DropTableSql(tableName string) string {
  119. return fmt.Sprintf("DROP TABLE IF EXISTS `%s`", tableName)
  120. }
  121. func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) {
  122. db.LogSQL(query, args)
  123. rows, err := db.DB().Query(query, args...)
  124. if err != nil {
  125. return false, err
  126. }
  127. defer rows.Close()
  128. if rows.Next() {
  129. return true, nil
  130. }
  131. return false, nil
  132. }
  133. func (db *Base) IsColumnExist(tableName, colName string) (bool, error) {
  134. query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?"
  135. query = strings.Replace(query, "`", db.dialect.QuoteStr(), -1)
  136. return db.HasRecords(query, db.DbName, tableName, colName)
  137. }
  138. /*
  139. func (db *Base) CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error {
  140. sql, args := db.dialect.TableCheckSql(tableName)
  141. rows, err := db.DB().Query(sql, args...)
  142. if db.Logger != nil {
  143. db.Logger.Info("[sql]", sql, args)
  144. }
  145. if err != nil {
  146. return err
  147. }
  148. defer rows.Close()
  149. if rows.Next() {
  150. return nil
  151. }
  152. sql = db.dialect.CreateTableSql(table, tableName, storeEngine, charset)
  153. _, err = db.DB().Exec(sql)
  154. if db.Logger != nil {
  155. db.Logger.Info("[sql]", sql)
  156. }
  157. return err
  158. }*/
  159. func (db *Base) CreateIndexSql(tableName string, index *Index) string {
  160. quote := db.dialect.Quote
  161. var unique string
  162. var idxName string
  163. if index.Type == UniqueType {
  164. unique = " UNIQUE"
  165. }
  166. idxName = index.XName(tableName)
  167. return fmt.Sprintf("CREATE%s INDEX %v ON %v (%v)", unique,
  168. quote(idxName), quote(tableName),
  169. quote(strings.Join(index.Cols, quote(","))))
  170. }
  171. func (db *Base) DropIndexSql(tableName string, index *Index) string {
  172. quote := db.dialect.Quote
  173. var name string
  174. if index.IsRegular {
  175. name = index.XName(tableName)
  176. } else {
  177. name = index.Name
  178. }
  179. return fmt.Sprintf("DROP INDEX %v ON %s", quote(name), quote(tableName))
  180. }
  181. func (db *Base) ModifyColumnSql(tableName string, col *Column) string {
  182. return fmt.Sprintf("alter table %s MODIFY COLUMN %s", tableName, col.StringNoPk(db.dialect))
  183. }
  184. func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset string) string {
  185. var sql string
  186. sql = "CREATE TABLE IF NOT EXISTS "
  187. if tableName == "" {
  188. tableName = table.Name
  189. }
  190. sql += b.dialect.Quote(tableName)
  191. sql += " ("
  192. if len(table.ColumnsSeq()) > 0 {
  193. pkList := table.PrimaryKeys
  194. for _, colName := range table.ColumnsSeq() {
  195. col := table.GetColumn(colName)
  196. if col.IsPrimaryKey && len(pkList) == 1 {
  197. sql += col.String(b.dialect)
  198. } else {
  199. sql += col.StringNoPk(b.dialect)
  200. }
  201. sql = strings.TrimSpace(sql)
  202. sql += ", "
  203. }
  204. if len(pkList) > 1 {
  205. sql += "PRIMARY KEY ( "
  206. sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(",")))
  207. sql += " ), "
  208. }
  209. sql = sql[:len(sql)-2]
  210. }
  211. sql += ")"
  212. if b.dialect.SupportEngine() && storeEngine != "" {
  213. sql += " ENGINE=" + storeEngine
  214. }
  215. if b.dialect.SupportCharset() {
  216. if len(charset) == 0 {
  217. charset = b.dialect.URI().Charset
  218. }
  219. if len(charset) > 0 {
  220. sql += " DEFAULT CHARSET " + charset
  221. }
  222. }
  223. return sql
  224. }
  225. func (b *Base) ForUpdateSql(query string) string {
  226. return query + " FOR UPDATE"
  227. }
  228. func (b *Base) LogSQL(sql string, args []interface{}) {
  229. if b.logger != nil && b.logger.IsShowSQL() {
  230. if len(args) > 0 {
  231. b.logger.Infof("[SQL] %v %v", sql, args)
  232. } else {
  233. b.logger.Infof("[SQL] %v", sql)
  234. }
  235. }
  236. }
  237. var (
  238. dialects = map[string]func() Dialect{}
  239. )
  240. // RegisterDialect register database dialect
  241. func RegisterDialect(dbName DbType, dialectFunc func() Dialect) {
  242. if dialectFunc == nil {
  243. panic("core: Register dialect is nil")
  244. }
  245. dialects[strings.ToLower(string(dbName))] = dialectFunc // !nashtsai! allow override dialect
  246. }
  247. // QueryDialect query if registed database dialect
  248. func QueryDialect(dbName DbType) Dialect {
  249. if d, ok := dialects[strings.ToLower(string(dbName))]; ok {
  250. return d()
  251. }
  252. return nil
  253. }