diff --git a/connection.go b/connection.go index 462e7d13..2b19c927 100644 --- a/connection.go +++ b/connection.go @@ -111,14 +111,13 @@ func (mc *mysqlConn) handleParams() (err error) { return } +// markBadConn replaces errBadConnNoWrite with driver.ErrBadConn. +// This function is used to return driver.ErrBadConn only when safe to retry. func (mc *mysqlConn) markBadConn(err error) error { - if mc == nil { - return err - } - if err != errBadConnNoWrite { - return err + if err == errBadConnNoWrite { + return driver.ErrBadConn } - return driver.ErrBadConn + return err } func (mc *mysqlConn) Begin() (driver.Tx, error) { @@ -127,7 +126,6 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } var q string @@ -189,7 +187,6 @@ func (mc *mysqlConn) error() error { func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -324,7 +321,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -384,7 +380,6 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) handleOk := mc.clearResult() if mc.closed.Load() { - mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { @@ -408,7 +403,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) var resLen int resLen, err = handleOk.readResultSetHeaderPacket() if err != nil { - return nil, mc.markBadConn(err) + return nil, err } rows := new(textRows) @@ -482,7 +477,6 @@ func (mc *mysqlConn) finish() { // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { - mc.log(ErrInvalidConn) return driver.ErrBadConn } @@ -704,3 +698,6 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error { func (mc *mysqlConn) IsValid() bool { return !mc.closed.Load() } + +var _ driver.SessionResetter = &mysqlConn{} +var _ driver.Validator = &mysqlConn{} diff --git a/errors.go b/errors.go index 238e480f..584617b1 100644 --- a/errors.go +++ b/errors.go @@ -32,7 +32,7 @@ var ( // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn - // to trigger a resend. + // to trigger a resend. Use mc.markBadConn(err) to do this. // See https://github.com/go-sql-driver/mysql/pull/302 errBadConnNoWrite = errors.New("bad connection") ) diff --git a/statement.go b/statement.go index 0436f224..35b02bbe 100644 --- a/statement.go +++ b/statement.go @@ -51,7 +51,6 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { if stmt.mc.closed.Load() { - stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command @@ -95,7 +94,6 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.closed.Load() { - stmt.mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command