connection.go 13 KB


  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "context"
  11. "database/sql"
  12. "database/sql/driver"
  13. "io"
  14. "net"
  15. "strconv"
  16. "strings"
  17. "time"
  18. )
// mysqlConn is one client connection to a MySQL server. Its methods in this
// file implement the database/sql/driver connection interfaces (Begin,
// Prepare, Close, Ping, BeginTx, QueryContext, ExecContext, ...).
type mysqlConn struct {
	buf              buffer   // packet buffer; also borrowed as scratch space by interpolateParams
	netConn          net.Conn // underlying network connection; may be nil (cleanup checks for this)
	affectedRows     uint64   // affected-row count of the last statement; reset to 0 by Exec
	insertId         uint64   // last insert ID of the last statement; reset to 0 by Exec
	cfg              *Config  // connection configuration (Params, Loc, InterpolateParams, ...)
	maxAllowedPacket int      // packet size limit; interpolateParams skips queries that would exceed it
	maxWriteSize     int
	writeTimeout     time.Duration
	flags            clientFlag
	status           statusFlag // server status flags; checked for statusNoBackslashEscapes when escaping
	sequence         uint8
	parseTime        bool
	// for context support (Go 1.8+)
	watching bool                   // true while an operation's context has been handed to the watcher
	watcher  chan<- context.Context // delivers contexts to the watcher goroutine; nil if startWatcher was not called
	closech  chan struct{}          // closed by cleanup to signal that the connection is closed
	finished chan<- struct{}        // used by finish to tell the watcher the current operation completed
	canceled atomicError            // set non-nil if conn is canceled
	closed   atomicBool             // set when conn is closed, before closech is closed
}
  40. // Handles parameters set in DSN after the connection is established
  41. func (mc *mysqlConn) handleParams() (err error) {
  42. for param, val := range mc.cfg.Params {
  43. switch param {
  44. // Charset
  45. case "charset":
  46. charsets := strings.Split(val, ",")
  47. for i := range charsets {
  48. // ignore errors here - a charset may not exist
  49. err = mc.exec("SET NAMES " + charsets[i])
  50. if err == nil {
  51. break
  52. }
  53. }
  54. if err != nil {
  55. return
  56. }
  57. // System Vars
  58. default:
  59. err = mc.exec("SET " + param + "=" + val + "")
  60. if err != nil {
  61. return
  62. }
  63. }
  64. }
  65. return
  66. }
  67. func (mc *mysqlConn) markBadConn(err error) error {
  68. if mc == nil {
  69. return err
  70. }
  71. if err != errBadConnNoWrite {
  72. return err
  73. }
  74. return driver.ErrBadConn
  75. }
  76. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  77. return mc.begin(false)
  78. }
  79. func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
  80. if mc.closed.IsSet() {
  81. errLog.Print(ErrInvalidConn)
  82. return nil, driver.ErrBadConn
  83. }
  84. var q string
  85. if readOnly {
  86. q = "START TRANSACTION READ ONLY"
  87. } else {
  88. q = "START TRANSACTION"
  89. }
  90. err := mc.exec(q)
  91. if err == nil {
  92. return &mysqlTx{mc}, err
  93. }
  94. return nil, mc.markBadConn(err)
  95. }
  96. func (mc *mysqlConn) Close() (err error) {
  97. // Makes Close idempotent
  98. if !mc.closed.IsSet() {
  99. err = mc.writeCommandPacket(comQuit)
  100. }
  101. mc.cleanup()
  102. return
  103. }
  104. // Closes the network connection and unsets internal variables. Do not call this
  105. // function after successfully authentication, call Close instead. This function
  106. // is called before auth or on auth failure because MySQL will have already
  107. // closed the network connection.
  108. func (mc *mysqlConn) cleanup() {
  109. if !mc.closed.TrySet(true) {
  110. return
  111. }
  112. // Makes cleanup idempotent
  113. close(mc.closech)
  114. if mc.netConn == nil {
  115. return
  116. }
  117. if err := mc.netConn.Close(); err != nil {
  118. errLog.Print(err)
  119. }
  120. }
  121. func (mc *mysqlConn) error() error {
  122. if mc.closed.IsSet() {
  123. if err := mc.canceled.Value(); err != nil {
  124. return err
  125. }
  126. return ErrInvalidConn
  127. }
  128. return nil
  129. }
  130. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  131. if mc.closed.IsSet() {
  132. errLog.Print(ErrInvalidConn)
  133. return nil, driver.ErrBadConn
  134. }
  135. // Send command
  136. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  137. if err != nil {
  138. return nil, mc.markBadConn(err)
  139. }
  140. stmt := &mysqlStmt{
  141. mc: mc,
  142. }
  143. // Read Result
  144. columnCount, err := stmt.readPrepareResultPacket()
  145. if err == nil {
  146. if stmt.paramCount > 0 {
  147. if err = mc.readUntilEOF(); err != nil {
  148. return nil, err
  149. }
  150. }
  151. if columnCount > 0 {
  152. err = mc.readUntilEOF()
  153. }
  154. }
  155. return stmt, err
  156. }
  157. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  158. // Number of ? should be same to len(args)
  159. if strings.Count(query, "?") != len(args) {
  160. return "", driver.ErrSkip
  161. }
  162. buf, err := mc.buf.takeCompleteBuffer()
  163. if err != nil {
  164. // can not take the buffer. Something must be wrong with the connection
  165. errLog.Print(err)
  166. return "", ErrInvalidConn
  167. }
  168. buf = buf[:0]
  169. argPos := 0
  170. for i := 0; i < len(query); i++ {
  171. q := strings.IndexByte(query[i:], '?')
  172. if q == -1 {
  173. buf = append(buf, query[i:]...)
  174. break
  175. }
  176. buf = append(buf, query[i:i+q]...)
  177. i += q
  178. arg := args[argPos]
  179. argPos++
  180. if arg == nil {
  181. buf = append(buf, "NULL"...)
  182. continue
  183. }
  184. switch v := arg.(type) {
  185. case int64:
  186. buf = strconv.AppendInt(buf, v, 10)
  187. case float64:
  188. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  189. case bool:
  190. if v {
  191. buf = append(buf, '1')
  192. } else {
  193. buf = append(buf, '0')
  194. }
  195. case time.Time:
  196. if v.IsZero() {
  197. buf = append(buf, "'0000-00-00'"...)
  198. } else {
  199. v := v.In(mc.cfg.Loc)
  200. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  201. year := v.Year()
  202. year100 := year / 100
  203. year1 := year % 100
  204. month := v.Month()
  205. day := v.Day()
  206. hour := v.Hour()
  207. minute := v.Minute()
  208. second := v.Second()
  209. micro := v.Nanosecond() / 1000
  210. buf = append(buf, []byte{
  211. '\'',
  212. digits10[year100], digits01[year100],
  213. digits10[year1], digits01[year1],
  214. '-',
  215. digits10[month], digits01[month],
  216. '-',
  217. digits10[day], digits01[day],
  218. ' ',
  219. digits10[hour], digits01[hour],
  220. ':',
  221. digits10[minute], digits01[minute],
  222. ':',
  223. digits10[second], digits01[second],
  224. }...)
  225. if micro != 0 {
  226. micro10000 := micro / 10000
  227. micro100 := micro / 100 % 100
  228. micro1 := micro % 100
  229. buf = append(buf, []byte{
  230. '.',
  231. digits10[micro10000], digits01[micro10000],
  232. digits10[micro100], digits01[micro100],
  233. digits10[micro1], digits01[micro1],
  234. }...)
  235. }
  236. buf = append(buf, '\'')
  237. }
  238. case []byte:
  239. if v == nil {
  240. buf = append(buf, "NULL"...)
  241. } else {
  242. buf = append(buf, "_binary'"...)
  243. if mc.status&statusNoBackslashEscapes == 0 {
  244. buf = escapeBytesBackslash(buf, v)
  245. } else {
  246. buf = escapeBytesQuotes(buf, v)
  247. }
  248. buf = append(buf, '\'')
  249. }
  250. case string:
  251. buf = append(buf, '\'')
  252. if mc.status&statusNoBackslashEscapes == 0 {
  253. buf = escapeStringBackslash(buf, v)
  254. } else {
  255. buf = escapeStringQuotes(buf, v)
  256. }
  257. buf = append(buf, '\'')
  258. default:
  259. return "", driver.ErrSkip
  260. }
  261. if len(buf)+4 > mc.maxAllowedPacket {
  262. return "", driver.ErrSkip
  263. }
  264. }
  265. if argPos != len(args) {
  266. return "", driver.ErrSkip
  267. }
  268. return string(buf), nil
  269. }
  270. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  271. if mc.closed.IsSet() {
  272. errLog.Print(ErrInvalidConn)
  273. return nil, driver.ErrBadConn
  274. }
  275. if len(args) != 0 {
  276. if !mc.cfg.InterpolateParams {
  277. return nil, driver.ErrSkip
  278. }
  279. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  280. prepared, err := mc.interpolateParams(query, args)
  281. if err != nil {
  282. return nil, err
  283. }
  284. query = prepared
  285. }
  286. mc.affectedRows = 0
  287. mc.insertId = 0
  288. err := mc.exec(query)
  289. if err == nil {
  290. return &mysqlResult{
  291. affectedRows: int64(mc.affectedRows),
  292. insertId: int64(mc.insertId),
  293. }, err
  294. }
  295. return nil, mc.markBadConn(err)
  296. }
  297. // Internal function to execute commands
  298. func (mc *mysqlConn) exec(query string) error {
  299. // Send command
  300. if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
  301. return mc.markBadConn(err)
  302. }
  303. // Read Result
  304. resLen, err := mc.readResultSetHeaderPacket()
  305. if err != nil {
  306. return err
  307. }
  308. if resLen > 0 {
  309. // columns
  310. if err := mc.readUntilEOF(); err != nil {
  311. return err
  312. }
  313. // rows
  314. if err := mc.readUntilEOF(); err != nil {
  315. return err
  316. }
  317. }
  318. return mc.discardResults()
  319. }
  320. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  321. return mc.query(query, args)
  322. }
  323. func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
  324. if mc.closed.IsSet() {
  325. errLog.Print(ErrInvalidConn)
  326. return nil, driver.ErrBadConn
  327. }
  328. if len(args) != 0 {
  329. if !mc.cfg.InterpolateParams {
  330. return nil, driver.ErrSkip
  331. }
  332. // try client-side prepare to reduce roundtrip
  333. prepared, err := mc.interpolateParams(query, args)
  334. if err != nil {
  335. return nil, err
  336. }
  337. query = prepared
  338. }
  339. // Send command
  340. err := mc.writeCommandPacketStr(comQuery, query)
  341. if err == nil {
  342. // Read Result
  343. var resLen int
  344. resLen, err = mc.readResultSetHeaderPacket()
  345. if err == nil {
  346. rows := new(textRows)
  347. rows.mc = mc
  348. if resLen == 0 {
  349. rows.rs.done = true
  350. switch err := rows.NextResultSet(); err {
  351. case nil, io.EOF:
  352. return rows, nil
  353. default:
  354. return nil, err
  355. }
  356. }
  357. // Columns
  358. rows.rs.columns, err = mc.readColumns(resLen)
  359. return rows, err
  360. }
  361. }
  362. return nil, mc.markBadConn(err)
  363. }
  364. // Gets the value of the given MySQL System Variable
  365. // The returned byte slice is only valid until the next read
  366. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  367. // Send command
  368. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  369. return nil, err
  370. }
  371. // Read Result
  372. resLen, err := mc.readResultSetHeaderPacket()
  373. if err == nil {
  374. rows := new(textRows)
  375. rows.mc = mc
  376. rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  377. if resLen > 0 {
  378. // Columns
  379. if err := mc.readUntilEOF(); err != nil {
  380. return nil, err
  381. }
  382. }
  383. dest := make([]driver.Value, resLen)
  384. if err = rows.readRow(dest); err == nil {
  385. return dest[0].([]byte), mc.readUntilEOF()
  386. }
  387. }
  388. return nil, err
  389. }
  390. // finish is called when the query has canceled.
  391. func (mc *mysqlConn) cancel(err error) {
  392. mc.canceled.Set(err)
  393. mc.cleanup()
  394. }
  395. // finish is called when the query has succeeded.
  396. func (mc *mysqlConn) finish() {
  397. if !mc.watching || mc.finished == nil {
  398. return
  399. }
  400. select {
  401. case mc.finished <- struct{}{}:
  402. mc.watching = false
  403. case <-mc.closech:
  404. }
  405. }
  406. // Ping implements driver.Pinger interface
  407. func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
  408. if mc.closed.IsSet() {
  409. errLog.Print(ErrInvalidConn)
  410. return driver.ErrBadConn
  411. }
  412. if err = mc.watchCancel(ctx); err != nil {
  413. return
  414. }
  415. defer mc.finish()
  416. if err = mc.writeCommandPacket(comPing); err != nil {
  417. return mc.markBadConn(err)
  418. }
  419. return mc.readResultOK()
  420. }
  421. // BeginTx implements driver.ConnBeginTx interface
  422. func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  423. if err := mc.watchCancel(ctx); err != nil {
  424. return nil, err
  425. }
  426. defer mc.finish()
  427. if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
  428. level, err := mapIsolationLevel(opts.Isolation)
  429. if err != nil {
  430. return nil, err
  431. }
  432. err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
  433. if err != nil {
  434. return nil, err
  435. }
  436. }
  437. return mc.begin(opts.ReadOnly)
  438. }
  439. func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
  440. dargs, err := namedValueToValue(args)
  441. if err != nil {
  442. return nil, err
  443. }
  444. if err := mc.watchCancel(ctx); err != nil {
  445. return nil, err
  446. }
  447. rows, err := mc.query(query, dargs)
  448. if err != nil {
  449. mc.finish()
  450. return nil, err
  451. }
  452. rows.finish = mc.finish
  453. return rows, err
  454. }
  455. func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
  456. dargs, err := namedValueToValue(args)
  457. if err != nil {
  458. return nil, err
  459. }
  460. if err := mc.watchCancel(ctx); err != nil {
  461. return nil, err
  462. }
  463. defer mc.finish()
  464. return mc.Exec(query, dargs)
  465. }
  466. func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
  467. if err := mc.watchCancel(ctx); err != nil {
  468. return nil, err
  469. }
  470. stmt, err := mc.Prepare(query)
  471. mc.finish()
  472. if err != nil {
  473. return nil, err
  474. }
  475. select {
  476. default:
  477. case <-ctx.Done():
  478. stmt.Close()
  479. return nil, ctx.Err()
  480. }
  481. return stmt, nil
  482. }
  483. func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
  484. dargs, err := namedValueToValue(args)
  485. if err != nil {
  486. return nil, err
  487. }
  488. if err := stmt.mc.watchCancel(ctx); err != nil {
  489. return nil, err
  490. }
  491. rows, err := stmt.query(dargs)
  492. if err != nil {
  493. stmt.mc.finish()
  494. return nil, err
  495. }
  496. rows.finish = stmt.mc.finish
  497. return rows, err
  498. }
  499. func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
  500. dargs, err := namedValueToValue(args)
  501. if err != nil {
  502. return nil, err
  503. }
  504. if err := stmt.mc.watchCancel(ctx); err != nil {
  505. return nil, err
  506. }
  507. defer stmt.mc.finish()
  508. return stmt.Exec(dargs)
  509. }
  510. func (mc *mysqlConn) watchCancel(ctx context.Context) error {
  511. if mc.watching {
  512. // Reach here if canceled,
  513. // so the connection is already invalid
  514. mc.cleanup()
  515. return nil
  516. }
  517. // When ctx is already cancelled, don't watch it.
  518. if err := ctx.Err(); err != nil {
  519. return err
  520. }
  521. // When ctx is not cancellable, don't watch it.
  522. if ctx.Done() == nil {
  523. return nil
  524. }
  525. // When watcher is not alive, can't watch it.
  526. if mc.watcher == nil {
  527. return nil
  528. }
  529. mc.watching = true
  530. mc.watcher <- ctx
  531. return nil
  532. }
// startWatcher creates the watcher/finished channels and starts the
// goroutine that enforces context cancellation for this connection.
//
// For each context received from watchCancel, the goroutine waits until one
// of three things happens:
//   - the context is done: the connection is canceled and closed via cancel;
//   - finish signals that the operation completed;
//   - mc.closech is closed: the goroutine exits.
//
// The goroutine also exits if closech is closed while it is idle.
func (mc *mysqlConn) startWatcher() {
	// A buffer of one lets watchCancel hand off a context without waiting
	// for the goroutine to reach its receive.
	watcher := make(chan context.Context, 1)
	mc.watcher = watcher
	// Unbuffered: a send from finish completes only once the goroutine
	// has received it.
	finished := make(chan struct{})
	mc.finished = finished
	go func() {
		for {
			var ctx context.Context
			// Wait for the next operation's context, or for the connection to close.
			select {
			case ctx = <-watcher:
			case <-mc.closech:
				return
			}

			// Watch that operation until it is canceled, finishes, or the
			// connection closes. After cancel, closech is closed by cleanup,
			// so the next iteration's first select returns.
			select {
			case <-ctx.Done():
				mc.cancel(ctx.Err())
			case <-finished:
			case <-mc.closech:
				return
			}
		}
	}()
}
  556. func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
  557. nv.Value, err = converter{}.ConvertValue(nv.Value)
  558. return
  559. }
  560. // ResetSession implements driver.SessionResetter.
  561. // (From Go 1.10)
  562. func (mc *mysqlConn) ResetSession(ctx context.Context) error {
  563. if mc.closed.IsSet() {
  564. return driver.ErrBadConn
  565. }
  566. return nil
  567. }