Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ Daniel Montoya <dsmontoyam at gmail.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
Diego Dupin <diego.dupin at gmail.com>
Dirkjan Bussink <d.bussink at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Expand Down
28 changes: 14 additions & 14 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func TestAuthFastCachingSHA256PasswordCached(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -176,7 +176,7 @@ func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -232,7 +232,7 @@ func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -284,7 +284,7 @@ func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -357,7 +357,7 @@ func TestAuthFastCleartextPassword(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -400,7 +400,7 @@ func TestAuthFastCleartextPasswordEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -459,7 +459,7 @@ func TestAuthFastNativePassword(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -502,7 +502,7 @@ func TestAuthFastNativePasswordEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -544,7 +544,7 @@ func TestAuthFastSHA256PasswordEmpty(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -592,7 +592,7 @@ func TestAuthFastSHA256PasswordRSA(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -641,7 +641,7 @@ func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -678,7 +678,7 @@ func TestAuthFastSHA256PasswordSecure(t *testing.T) {
// unset TLS config to prevent the actual establishment of a TLS wrapper
mc.cfg.TLS = nil

err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1343,7 +1343,7 @@ func TestEd25519Auth(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = mc.writeHandshakeResponsePacket(authResp, plugin)
err = mc.writeHandshakeResponsePacket(authResp, 0, 0, plugin)
if err != nil {
t.Fatal(err)
}
Expand Down
50 changes: 50 additions & 0 deletions benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,53 @@ func BenchmarkReceiveMassiveRows(b *testing.B) {
}
})
}

// BenchmarkReceiveMetadata measures performance of receiving more metadata than real data
func BenchmarkReceiveMetadata(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()

// Create a table with 1000 integer fields
createTableQuery := "CREATE TABLE large_integer_table ("
for i := 0; i < 1000; i++ {
createTableQuery += fmt.Sprintf("col_%d INT", i)
if i < 999 {
createTableQuery += ", "
}
}
createTableQuery += ")"

// Initialize database
db := initDB(b, false,
"DROP TABLE IF EXISTS large_integer_table",
createTableQuery,
"INSERT INTO large_integer_table VALUES ("+
strings.Repeat("0,", 999)+"0)", // Insert a row of zeros
)
defer db.Close()

// Prepare a SELECT query to retrieve metadata
stmt := tb.checkStmt(db.Prepare("SELECT * FROM large_integer_table LIMIT 1"))
defer stmt.Close()

b.StartTimer()

// Benchmark metadata retrieval
for i := 0; i < b.N; i++ {
rows := tb.checkRows(stmt.Query())

// Create a slice to scan all columns
values := make([]interface{}, 1000)
valuePtrs := make([]interface{}, 1000)
for j := range values {
valuePtrs[j] = &values[j]
}
rows.Next()
// Scan the row
err := rows.Scan(valuePtrs...)
tb.check(err)

rows.Close()
}
}
57 changes: 33 additions & 24 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,22 @@ import (
)

type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
compIO *compIO
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
flags clientFlag
status statusFlag
sequence uint8
compressSequence uint8
parseTime bool
compress bool
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
result mysqlResult // managed by clearResult() and handleOkPacket().
compIO *compIO
cfg *Config
connector *connector
maxAllowedPacket int
maxWriteSize int
clientCapabilities capabilityFlag
clientExtCapabilities extendedCapabilityFlag
status statusFlag
sequence uint8
compressSequence uint8
parseTime bool
compress bool

// for context support (Go 1.8+)
watching bool
Expand Down Expand Up @@ -223,13 +224,21 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
if err = mc.skipColumns(stmt.paramCount); err != nil {
return nil, err
}
}

if columnCount > 0 {
err = mc.readUntilEOF()
if mc.clientExtCapabilities&clientCacheMetadata != 0 {
if stmt.columns, err = mc.readColumns(int(columnCount)); err != nil {
return nil, err
}
} else {
if err = mc.skipColumns(int(columnCount)); err != nil {
return nil, err
}
}
}
}

Expand Down Expand Up @@ -370,19 +379,19 @@ func (mc *mysqlConn) exec(query string) error {
}

// Read Result
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, _, err := handleOk.readResultSetHeaderPacket()
if err != nil {
return err
}

if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.skipColumns(resLen); err != nil {
return err
}

// rows
if err := mc.readUntilEOF(); err != nil {
if err := mc.skipResultSetRows(); err != nil {
return err
}
}
Expand Down Expand Up @@ -419,7 +428,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)

// Read Result
var resLen int
resLen, err = handleOk.readResultSetHeaderPacket()
resLen, _, err = handleOk.readResultSetHeaderPacket()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -453,22 +462,22 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
}

// Read Result
resLen, err := handleOk.readResultSetHeaderPacket()
resLen, _, err := handleOk.readResultSetHeaderPacket()
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.skipColumns(resLen); err != nil {
return nil, err
}
}

dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
return dest[0].([]byte), mc.skipResultSetRows()
}
}
return nil, err
Expand Down
6 changes: 3 additions & 3 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
mc.buf = newBuffer()

// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
authData, serverCapabilities, serverExtendedCapabilities, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
Expand All @@ -153,7 +153,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
if err = mc.writeHandshakeResponsePacket(authResp, serverCapabilities, serverExtendedCapabilities, plugin); err != nil {
mc.cleanup()
return nil, err
}
Expand All @@ -167,7 +167,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return nil, err
}

if mc.cfg.compress && mc.flags&clientCompress == clientCompress {
if mc.cfg.compress && mc.clientCapabilities&clientCompress > 0 {
mc.compress = true
mc.compIO = newCompIO(mc)
}
Expand Down
18 changes: 16 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ const (
)

// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
type clientFlag uint32
type capabilityFlag uint32

const (
clientLongPassword clientFlag = 1 << iota
clientMySQL capabilityFlag = 1 << iota
clientFoundRows
clientLongFlag
clientConnectWithDB
Expand All @@ -73,6 +73,20 @@ const (
clientDeprecateEOF
)

// https://mariadb.com/kb/en/connection/#capabilities
type extendedCapabilityFlag uint32

const (
progressIndicator extendedCapabilityFlag = 1 << iota
clientComMulti
clientStmtBulkOperations
clientExtendedMetadata
clientCacheMetadata
clientUnitBulkResult
)

// https://mariadb.com/kb/en/connection/#capabilities

const (
comQuit byte = iota + 1
comInitDB
Expand Down
Loading