MySQL prepared statement support (#10283)

This commit is contained in:
STeve (Xin) Huang 2022-02-16 14:46:54 -05:00 committed by GitHub
parent daa2bcb6ad
commit 55fbd56217
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 5044 additions and 439 deletions

File diff suppressed because it is too large Load diff

View file

@ -1844,6 +1844,13 @@ message OneOf {
events.DesktopRecording DesktopRecording = 69;
events.DesktopClipboardSend DesktopClipboardSend = 70;
events.DesktopClipboardReceive DesktopClipboardReceive = 71;
events.MySQLStatementPrepare MySQLStatementPrepare = 72;
events.MySQLStatementExecute MySQLStatementExecute = 73;
events.MySQLStatementSendLongData MySQLStatementSendLongData = 74;
events.MySQLStatementClose MySQLStatementClose = 75;
events.MySQLStatementReset MySQLStatementReset = 76;
events.MySQLStatementFetch MySQLStatementFetch = 77;
events.MySQLStatementBulkExecute MySQLStatementBulkExecute = 78;
}
}
@ -1967,3 +1974,146 @@ message RouteToDatabase {
// Database is an optional database name to embed.
string Database = 4 [ (gogoproto.jsontag) = "database,omitempty" ];
}
// MySQLStatementPrepare is emitted when a MySQL client creates a prepared
// statement using the prepared statement protocol.
message MySQLStatementPrepare {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Query is the prepared statement query.
string Query = 5 [ (gogoproto.jsontag) = "query" ];
}
// MySQLStatementExecute is emitted when a MySQL client executes a prepared
// statement using the prepared statement protocol.
message MySQLStatementExecute {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
// Parameters are the parameters used to execute the prepared statement.
repeated string Parameters = 6 [ (gogoproto.jsontag) = "parameters" ];
}
// MySQLStatementSendLongData is emitted when a MySQL client sends long bytes
// stream using the prepared statement protocol.
message MySQLStatementSendLongData {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
// ParameterID is the identifier of the parameter.
uint32 ParameterID = 6 [ (gogoproto.jsontag) = "parameter_id" ];
// DataSize is the size of the data.
uint32 DataSize = 7 [ (gogoproto.jsontag) = "data_size" ];
}
// MySQLStatementClose is emitted when a MySQL client deallocates a prepared
// statement using the prepared statement protocol.
message MySQLStatementClose {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
}
// MySQLStatementReset is emitted when a MySQL client resets the data of a
// prepared statement using the prepared statement protocol.
message MySQLStatementReset {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
}
// MySQLStatementFetch is emitted when a MySQL client fetches rows from a
// prepared statement using the prepared statement protocol.
message MySQLStatementFetch {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
// RowsCount is the number of rows to fetch.
uint32 RowsCount = 6 [ (gogoproto.jsontag) = "rows_count" ];
}
// MySQLStatementBulkExecute is emitted when a MySQL client executes a bulk
// insert of a prepared statement using the prepared statement protocol.
message MySQLStatementBulkExecute {
// Metadata is a common event metadata.
Metadata Metadata = 1
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// User is a common user event metadata.
UserMetadata User = 2
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// SessionMetadata is a common event session metadata.
SessionMetadata Session = 3
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// Database contains database related metadata.
DatabaseMetadata Database = 4
[ (gogoproto.nullable) = false, (gogoproto.embed) = true, (gogoproto.jsontag) = "" ];
// StatementID is the identifier of the prepared statement.
uint32 StatementID = 5 [ (gogoproto.jsontag) = "statement_id" ];
// Parameters are the parameters used to execute the prepared statement.
repeated string Parameters = 6 [ (gogoproto.jsontag) = "parameters" ];
}

View file

@ -321,6 +321,34 @@ func ToOneOf(in AuditEvent) (*OneOf, error) {
out.Event = &OneOf_DesktopClipboardSend{
DesktopClipboardSend: e,
}
case *MySQLStatementPrepare:
out.Event = &OneOf_MySQLStatementPrepare{
MySQLStatementPrepare: e,
}
case *MySQLStatementExecute:
out.Event = &OneOf_MySQLStatementExecute{
MySQLStatementExecute: e,
}
case *MySQLStatementSendLongData:
out.Event = &OneOf_MySQLStatementSendLongData{
MySQLStatementSendLongData: e,
}
case *MySQLStatementClose:
out.Event = &OneOf_MySQLStatementClose{
MySQLStatementClose: e,
}
case *MySQLStatementReset:
out.Event = &OneOf_MySQLStatementReset{
MySQLStatementReset: e,
}
case *MySQLStatementFetch:
out.Event = &OneOf_MySQLStatementFetch{
MySQLStatementFetch: e,
}
case *MySQLStatementBulkExecute:
out.Event = &OneOf_MySQLStatementBulkExecute{
MySQLStatementBulkExecute: e,
}
default:
return nil, trace.BadParameter("event type %T is not supported", in)
}

View file

@ -406,20 +406,45 @@ const (
// DatabaseSessionPostgresParseEvent is emitted when a Postgres client
// creates a prepared statement using extended query protocol.
DatabaseSessionPostgresParseEvent = "db.session.postgres.parse"
DatabaseSessionPostgresParseEvent = "db.session.postgres.statements.parse"
// DatabaseSessionPostgresBindEvent is emitted when a Postgres client
// readies a prepared statement for execution and binds it to parameters.
DatabaseSessionPostgresBindEvent = "db.session.postgres.bind"
DatabaseSessionPostgresBindEvent = "db.session.postgres.statements.bind"
// DatabaseSessionPostgresExecuteEvent is emitted when a Postgres client
// executes a previously bound prepared statement.
DatabaseSessionPostgresExecuteEvent = "db.session.postgres.execute"
DatabaseSessionPostgresExecuteEvent = "db.session.postgres.statements.execute"
// DatabaseSessionPostgresCloseEvent is emitted when a Postgres client
// closes an existing prepared statement.
DatabaseSessionPostgresCloseEvent = "db.session.postgres.close"
DatabaseSessionPostgresCloseEvent = "db.session.postgres.statements.close"
// DatabaseSessionPostgresFunctionEvent is emitted when a Postgres client
// calls an internal function.
DatabaseSessionPostgresFunctionEvent = "db.session.postgres.function"
// DatabaseSessionMySQLStatementPrepareEvent is emitted when a MySQL client
// creates a prepared statement using the prepared statement protocol.
DatabaseSessionMySQLStatementPrepareEvent = "db.session.mysql.statements.prepare"
// DatabaseSessionMySQLStatementExecuteEvent is emitted when a MySQL client
// executes a prepared statement using the prepared statement protocol.
DatabaseSessionMySQLStatementExecuteEvent = "db.session.mysql.statements.execute"
// DatabaseSessionMySQLStatementSendLongDataEvent is emitted when a MySQL
// client sends long bytes stream using the prepared statement protocol.
DatabaseSessionMySQLStatementSendLongDataEvent = "db.session.mysql.statements.send_long_data"
// DatabaseSessionMySQLStatementCloseEvent is emitted when a MySQL client
// deallocates a prepared statement using the prepared statement protocol.
DatabaseSessionMySQLStatementCloseEvent = "db.session.mysql.statements.close"
// DatabaseSessionMySQLStatementResetEvent is emitted when a MySQL client
// resets the data of a prepared statement using the prepared statement
// protocol.
DatabaseSessionMySQLStatementResetEvent = "db.session.mysql.statements.reset"
// DatabaseSessionMySQLStatementFetchEvent is emitted when a MySQL client
// fetches rows from a prepared statement using the prepared statement
// protocol.
DatabaseSessionMySQLStatementFetchEvent = "db.session.mysql.statements.fetch"
// DatabaseSessionMySQLStatementBulkExecuteEvent is emitted when a MySQL
// client executes a bulk insert of a prepared statement using the prepared
// statement protocol.
DatabaseSessionMySQLStatementBulkExecuteEvent = "db.session.mysql.statements.bulk_execute"
// SessionRejectedReasonMaxConnections indicates that a session.rejected event
// corresponds to enforcement of the max_connections control.
SessionRejectedReasonMaxConnections = "max_connections limit reached"

View file

@ -359,17 +359,32 @@ const (
// DatabaseSessionQueryFailedCode is the database query failure event code.
DatabaseSessionQueryFailedCode = "TDB02W"
// PostgresParseCode is the db.session.postgres.parse event code.
// PostgresParseCode is the db.session.postgres.statements.parse event code.
PostgresParseCode = "TPG00I"
// PostgresBindCode is the db.session.postgres.bind event code.
// PostgresBindCode is the db.session.postgres.statements.bind event code.
PostgresBindCode = "TPG01I"
// PostgresExecuteCode is the db.session.postgres.execute event code.
// PostgresExecuteCode is the db.session.postgres.statements.execute event code.
PostgresExecuteCode = "TPG02I"
// PostgresCloseCode is the db.session.postgres.close event code.
// PostgresCloseCode is the db.session.postgres.statements.close event code.
PostgresCloseCode = "TPG03I"
// PostgresFunctionCallCode is the db.session.postgres.function event code.
PostgresFunctionCallCode = "TPG04I"
// MySQLStatementPrepareCode is the db.session.mysql.statements.prepare event code.
MySQLStatementPrepareCode = "TMY00I"
// MySQLStatementExecuteCode is the db.session.mysql.statements.execute event code.
MySQLStatementExecuteCode = "TMY01I"
// MySQLStatementSendLongDataCode is the db.session.mysql.statements.send_long_data event code.
MySQLStatementSendLongDataCode = "TMY02I"
// MySQLStatementCloseCode is the db.session.mysql.statements.close event code.
MySQLStatementCloseCode = "TMY03I"
// MySQLStatementResetCode is the db.session.mysql.statements.reset event code.
MySQLStatementResetCode = "TMY04I"
// MySQLStatementFetchCode is the db.session.mysql.statements.fetch event code.
MySQLStatementFetchCode = "TMY05I"
// MySQLStatementBulkExecuteCode is the db.session.mysql.statements.bulk_execute event code.
MySQLStatementBulkExecuteCode = "TMY06I"
// DatabaseCreateCode is the db.create event code.
DatabaseCreateCode = "TDB03I"
// DatabaseUpdateCode is the db.update event code.

View file

@ -165,6 +165,20 @@ func FromEventFields(fields EventFields) (apievents.AuditEvent, error) {
e = &events.PostgresClose{}
case DatabaseSessionPostgresFunctionEvent:
e = &events.PostgresFunctionCall{}
case DatabaseSessionMySQLStatementPrepareEvent:
e = &events.MySQLStatementPrepare{}
case DatabaseSessionMySQLStatementExecuteEvent:
e = &events.MySQLStatementExecute{}
case DatabaseSessionMySQLStatementSendLongDataEvent:
e = &events.MySQLStatementSendLongData{}
case DatabaseSessionMySQLStatementCloseEvent:
e = &events.MySQLStatementClose{}
case DatabaseSessionMySQLStatementResetEvent:
e = &events.MySQLStatementReset{}
case DatabaseSessionMySQLStatementFetchEvent:
e = &events.MySQLStatementFetch{}
case DatabaseSessionMySQLStatementBulkExecuteEvent:
e = &events.MySQLStatementBulkExecute{}
case KubeRequestEvent:
e = &events.KubeRequest{}
case MFADeviceAddEvent:

View file

@ -552,9 +552,211 @@ func TestJSON(t *testing.T) {
EndTime: time.Date(2020, 04, 23, 18, 26, 35, 350*int(time.Millisecond), time.UTC),
},
},
{
name: "MySQL statement prepare",
json: `{"cluster_name":"test-cluster","code":"TMY00I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.prepare","query":"select 1","sid":"test-session","time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementPrepare{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementPrepareEvent,
Code: MySQLStatementPrepareCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
Query: "select 1",
},
},
{
name: "MySQL statement execute",
json: `{"cluster_name":"test-cluster","code":"TMY01I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.execute","parameters":null,"sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementExecute{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementExecuteEvent,
Code: MySQLStatementExecuteCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
StatementID: 222,
},
},
{
name: "MySQL statement send long data",
json: `{"cluster_name":"test-cluster","code":"TMY02I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","data_size":55,"db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.send_long_data","parameter_id":5,"sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementSendLongData{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementSendLongDataEvent,
Code: MySQLStatementSendLongDataCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
ParameterID: 5,
StatementID: 222,
DataSize: 55,
},
},
{
name: "MySQL statement close",
json: `{"cluster_name":"test-cluster","code":"TMY03I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.close","sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementClose{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementCloseEvent,
Code: MySQLStatementCloseCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
StatementID: 222,
},
},
{
name: "MySQL statement reset",
json: `{"cluster_name":"test-cluster","code":"TMY04I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.reset","sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementReset{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementResetEvent,
Code: MySQLStatementResetCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
StatementID: 222,
},
},
{
name: "MySQL statement fetch",
json: `{"cluster_name":"test-cluster","code":"TMY05I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.fetch","rows_count": 5,"sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementFetch{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementFetchEvent,
Code: MySQLStatementFetchCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
StatementID: 222,
RowsCount: 5,
},
},
{
name: "MySQL statement bulk execute",
json: `{"cluster_name":"test-cluster","code":"TMY06I","db_name":"test","db_protocol":"mysql","db_service":"test-mysql","db_uri":"localhost:3306","db_user":"alice","ei":22,"event":"db.session.mysql.statements.bulk_execute","parameters":null,"sid":"test-session","statement_id":222,"time":"2022-02-22T22:22:22.222Z","uid":"test-id","user":"alice@example.com"}`,
event: apievents.MySQLStatementBulkExecute{
Metadata: apievents.Metadata{
Index: 22,
ID: "test-id",
Type: DatabaseSessionMySQLStatementBulkExecuteEvent,
Code: MySQLStatementBulkExecuteCode,
Time: time.Date(2022, 02, 22, 22, 22, 22, 222*int(time.Millisecond), time.UTC),
ClusterName: "test-cluster",
},
UserMetadata: apievents.UserMetadata{
User: "alice@example.com",
},
SessionMetadata: apievents.SessionMetadata{
SessionID: "test-session",
},
DatabaseMetadata: apievents.DatabaseMetadata{
DatabaseService: "test-mysql",
DatabaseProtocol: "mysql",
DatabaseURI: "localhost:3306",
DatabaseName: "test",
DatabaseUser: "alice",
},
StatementID: 222,
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
outJSON, err := utils.FastMarshal(tc.event)
require.NoError(t, err)
require.JSONEq(t, tc.json, string(outJSON))

View file

@ -298,6 +298,28 @@ func (e *Engine) receiveFromClient(clientConn, serverConn net.Conn, clientErrCh
return
case *protocol.Quit:
return
case *protocol.StatementPreparePacket:
e.Audit.EmitEvent(e.Context, makeStatementPrepareEvent(sessionCtx, pkt))
case *protocol.StatementExecutePacket:
// TODO(greedy52) Number of parameters is required to parse
// paremeters out of the packet. Parameter definitions are required
// to properly format the parameters for including in the audit
// log. Both number of parameters and parameter definitions can be
// obtained from the response of COM_STMT_PREPARE.
e.Audit.EmitEvent(e.Context, makeStatementExecuteEvent(sessionCtx, pkt))
case *protocol.StatementSendLongDataPacket:
e.Audit.EmitEvent(e.Context, makeStatementSendLongDataEvent(sessionCtx, pkt))
case *protocol.StatementClosePacket:
e.Audit.EmitEvent(e.Context, makeStatementCloseEvent(sessionCtx, pkt))
case *protocol.StatementResetPacket:
e.Audit.EmitEvent(e.Context, makeStatementResetEvent(sessionCtx, pkt))
case *protocol.StatementFetchPacket:
e.Audit.EmitEvent(e.Context, makeStatementFetchEvent(sessionCtx, pkt))
case *protocol.StatementBulkExecutePacket:
// TODO(greedy52) Number of parameters and parameter definitions
// are required. See above comments for StatementExecutePacket.
e.Audit.EmitEvent(e.Context, makeStatementBulkExecuteEvent(sessionCtx, pkt))
}
_, err = protocol.WritePacket(packet.Bytes(), serverConn)
if err != nil {

127
lib/srv/db/mysql/events.go Normal file
View file

@ -0,0 +1,127 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
"github.com/gravitational/teleport/api/types/events"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mysql/protocol"
)
// makeStatementPrepareEvent creates an audit event for MySQL statement prepare
// command.
func makeStatementPrepareEvent(session *common.Session, packet *protocol.StatementPreparePacket) events.AuditEvent {
return &events.MySQLStatementPrepare{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementPrepareEvent,
libevents.MySQLStatementPrepareCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
Query: packet.Query(),
}
}
// makeStatementExecuteEvent creates an audit event for MySQL statement execute
// command.
func makeStatementExecuteEvent(session *common.Session, packet *protocol.StatementExecutePacket) events.AuditEvent {
// TODO(greedy52) get parameters from packet and format them for audit.
return &events.MySQLStatementExecute{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementExecuteEvent,
libevents.MySQLStatementExecuteCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
}
}
// makeStatementSendLongDataEvent creates an audit event for MySQL statement
// send long data command.
func makeStatementSendLongDataEvent(session *common.Session, packet *protocol.StatementSendLongDataPacket) events.AuditEvent {
return &events.MySQLStatementSendLongData{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementSendLongDataEvent,
libevents.MySQLStatementSendLongDataCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
ParameterID: uint32(packet.ParameterID()),
DataSize: uint32(len(packet.Data())),
}
}
// makeStatementCloseEvent creates an audit event for MySQL statement close
// command.
func makeStatementCloseEvent(session *common.Session, packet *protocol.StatementClosePacket) events.AuditEvent {
return &events.MySQLStatementClose{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementCloseEvent,
libevents.MySQLStatementCloseCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
}
}
// makeStatementResetEvent creates an audit event for MySQL statement close
// command.
func makeStatementResetEvent(session *common.Session, packet *protocol.StatementResetPacket) events.AuditEvent {
return &events.MySQLStatementReset{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementResetEvent,
libevents.MySQLStatementResetCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
}
}
// makeStatementFetchEvent creates an audit event for MySQL statement fetch
// command.
func makeStatementFetchEvent(session *common.Session, packet *protocol.StatementFetchPacket) events.AuditEvent {
return &events.MySQLStatementFetch{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementFetchEvent,
libevents.MySQLStatementFetchCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
RowsCount: packet.RowsCount(),
}
}
// makeStatementBulkExecuteEvent creates an audit event for MySQL statement
// bulk execute command.
func makeStatementBulkExecuteEvent(session *common.Session, packet *protocol.StatementBulkExecutePacket) events.AuditEvent {
// TODO(greedy52) get parameters from packet and format them for audit.
return &events.MySQLStatementBulkExecute{
Metadata: common.MakeEventMetadata(session,
libevents.DatabaseSessionMySQLStatementBulkExecuteEvent,
libevents.MySQLStatementBulkExecuteCode),
UserMetadata: common.MakeUserMetadata(session),
SessionMetadata: common.MakeSessionMetadata(session),
DatabaseMetadata: common.MakeDatabaseMetadata(session),
StatementID: packet.StatementID(),
}
}

View file

@ -19,7 +19,6 @@ package protocol
import (
"bytes"
"io"
"net"
"github.com/gravitational/trace"
"github.com/siddontang/go-mysql/mysql"
@ -103,7 +102,7 @@ func (p *ChangeUser) User() string {
// ParsePacket reads a protocol packet from the connection and returns it
// in a parsed form. See ReadPacket below for the packet structure.
func ParsePacket(conn net.Conn) (Packet, error) {
func ParsePacket(conn io.Reader) (Packet, error) {
packetBytes, packetType, err := ReadPacket(conn)
if err != nil {
return nil, trace.Wrap(err)
@ -122,9 +121,9 @@ func ParsePacket(conn net.Conn) (Packet, error) {
// fields. In protocol version 4.1 it includes '#' marker:
//
// https://dev.mysql.com/doc/internals/en/packet-ERR_Packet.html
minLen := 7 // 4-byte header + 3-byte payload before message
if bytes.Contains(packetBytes, []byte("#")) {
minLen = 13 // 4-byte header + 9-byte payload before message
minLen := packetHeaderSize + packetTypeSize + 2 // 4-byte header + 1-byte type + 2-byte error code
if len(packetBytes) > minLen && packetBytes[minLen] == '#' {
minLen += 6 // 1-byte marker '#' + 5-byte state
}
// Be a bit paranoid and make sure the packet is not truncated.
if len(packetBytes) < minLen {
@ -134,26 +133,75 @@ func ParsePacket(conn net.Conn) (Packet, error) {
case mysql.COM_QUERY:
// Be a bit paranoid and make sure the packet is not truncated.
if len(packetBytes) < 5 {
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_QUERY packet: %v", packetBytes)
}
// 4-byte packet header + 1-byte payload header, then query text.
return &Query{packet: packet, query: string(packetBytes[5:])}, nil
return &Query{packet: packet, query: string(packetBytes[packetHeaderAndTypeSize:])}, nil
case mysql.COM_QUIT:
return &Quit{packet: packet}, nil
case mysql.COM_CHANGE_USER:
if len(packetBytes) < 5 {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %s", packetBytes)
if len(packetBytes) < packetHeaderAndTypeSize {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}
// User is the first null-terminated string in the payload:
// https://dev.mysql.com/doc/internals/en/com-change-user.html#packet-COM_CHANGE_USER
idx := bytes.IndexByte(packetBytes[5:], 0x00)
idx := bytes.IndexByte(packetBytes[packetHeaderAndTypeSize:], 0x00)
if idx < 0 {
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %s", packetBytes)
return nil, trace.BadParameter("failed to parse COM_CHANGE_USER packet: %v", packetBytes)
}
return &ChangeUser{packet: packet, user: string(packetBytes[5 : 5+idx])}, nil
return &ChangeUser{packet: packet, user: string(packetBytes[packetHeaderAndTypeSize : packetHeaderAndTypeSize+idx])}, nil
case mysql.COM_STMT_PREPARE:
packet, ok := parseStatementPreparePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_PREPARE packet: %v", packetBytes)
}
return packet, nil
case mysql.COM_STMT_SEND_LONG_DATA:
packet, ok := parseStatementSendLongDataPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_SEND_LONG_DATA packet: %v", packetBytes)
}
return packet, nil
case mysql.COM_STMT_EXECUTE:
packet, ok := parseStatementExecutePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_EXECUTE packet: %v", packetBytes)
}
return packet, nil
case mysql.COM_STMT_CLOSE:
packet, ok := parseStatementClosePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_CLOSE packet: %v", packetBytes)
}
return packet, nil
case mysql.COM_STMT_RESET:
packet, ok := parseStatementResetPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_RESET packet: %v", packetBytes)
}
return packet, nil
case mysql.COM_STMT_FETCH:
packet, ok := parseStatementFetchPacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_FETCH packet: %v", packetBytes)
}
return packet, nil
case packetTypeStatementBulkExecute:
packet, ok := parseStatementBulkExecutePacket(packet)
if !ok {
return nil, trace.BadParameter("failed to parse COM_STMT_BULK_EXECUTE packet: %v", packetBytes)
}
return packet, nil
}
return &Generic{packet: packet}, nil
@ -176,7 +224,7 @@ func ParsePacket(conn net.Conn) (Packet, error) {
// number
//
// https://dev.mysql.com/doc/internals/en/mysql-packet.html
func ReadPacket(conn net.Conn) (pkt []byte, pktType byte, err error) {
func ReadPacket(conn io.Reader) (pkt []byte, pktType byte, err error) {
// Read 4-byte packet header.
var header [4]byte
if _, err := io.ReadFull(conn, header[:]); err != nil {
@ -206,10 +254,30 @@ func ReadPacket(conn net.Conn) (pkt []byte, pktType byte, err error) {
}
// WritePacket writes the provided protocol packet to the connection.
func WritePacket(pkt []byte, conn net.Conn) (int, error) {
func WritePacket(pkt []byte, conn io.Writer) (int, error) {
n, err := conn.Write(pkt)
if err != nil {
return 0, trace.ConvertSystemError(err)
}
return n, nil
}
const (
// packetHeaderSize is the size of the packet header.
packetHeaderSize = 4
// packetTypeSize is the size of the command type.
packetTypeSize = 1
// packetHeaderAndTypeSize is the combined size of the packet header and
// type.
packetHeaderAndTypeSize = packetHeaderSize + packetTypeSize
)
const (
// packetTypeStatementBulkExecute is a MariaDB specific packet type for
// COM_STMT_BULK_EXECUTE packets.
//
// https://mariadb.com/kb/en/com_stmt_bulk_execute/
packetTypeStatementBulkExecute = 0xfa
)

View file

@ -0,0 +1,322 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package protocol
import (
"bytes"
"io"
"net"
"testing"
"testing/iotest"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)
var (
sampleOKPacket = &OK{
packet: packet{
bytes: []byte{
0x03, 0x00, 0x00, 0x00, // header
0x00, // type
0x00, 0x00,
},
},
}
sampleQueryPacket = &Query{
packet: packet{
bytes: []byte{
0x09, 0x00, 0x00, 0x00, // header
0x03, // type
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x31, // query
},
},
query: "select 1",
}
sampleErrorPacket = &Error{
packet: packet{
bytes: []byte{
0x09, 0x00, 0x00, 0x00, // header
0xff, // type
0x51, 0x04, // error code
0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, // message
},
},
message: "denied",
}
sampleErrorWithSQLStatePacket = &Error{
packet: packet{
bytes: []byte{
0x0f, 0x00, 0x00, 0x00, // header
0xff, // type
0x51, 0x04, // error code
0x23, // marker #
0x48, 0x59, 0x30, 0x30, 0x30, // state - HY000
0x64, 0x65, 0x6e, 0x69, 0x65, 0x64, // message
},
},
message: "denied",
}
sampleQuitPacket = &Quit{
packet: packet{
bytes: []byte{
0x01, 0x00, 0x00, 0x00, // header
0x01, //type
},
},
}
sampleChangeUserPacket = &ChangeUser{
packet: packet{
bytes: []byte{
0x05, 0x00, 0x00, 0x04, // header
0x11, // type
0x62, 0x6f, 0x62, 0x00, // null terminated "bob"
},
},
user: "bob",
}
sampleStatementPreparePacket = &StatementPreparePacket{
packet: packet{
bytes: []byte{
0x09, 0x00, 0x00, 0x00, // header
0x16, // type
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x31, // query
},
},
query: "select 1",
}
sampleStatementSendLongDataPacket = &StatementSendLongDataPacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x0a, 0x00, 0x00, 0x00, // header
0x18, // type
0x05, 0x00, 0x00, 0x00, // statement ID
0x02, 0x00, // parameter ID
0x62, 0x6f, 0x62, //data
},
},
statementID: 5,
},
parameterID: 2,
data: []byte("bob"),
}
sampleStatementExecutePacket = &StatementExecutePacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x1e, 0x00, 0x00, 0x00, // header
0x17, // type
0x02, 0x00, 0x00, 0x00, // statement ID
0x00, // cursor flag
0x01, 0x00, 0x00, 0x00, // iteration count
0x00, // nullbit map
0x01, // new-params-bound flag
// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-type
0xfe, 0x00, // param 1 type - MYSQL_TYPE_STRING
0x08, 0x00, // param 2 type - MYSQL_TYPE_LONGLONG
0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // param 1 value - "hello"
0xc8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // param 2 value - 200
},
},
statementID: 2,
},
cursorFlag: 0x00,
iterations: 1,
nullBitmapAndParameters: []byte{
0x00, // null bitmap
0x01, // new-params-bound flag
0xfe, 0x00, // param 1 type - MYSQL_TYPE_STRING
0x08, 0x00, // param 2 type - MYSQL_TYPE_LONGLONG
0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f, // param 1 value - "hello"
0xc8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // param 2 value - 200
},
}
sampleStatementClosePacket = &StatementClosePacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x05, 0x00, 0x00, 0x00, // header
0x19, // type
0x01, 0x00, 0x00, 0x00, // statement ID
},
},
statementID: 1,
},
}
sampleStatementResetPacket = &StatementResetPacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x05, 0x00, 0x00, 0x00, // header
0x1a, // type
0x01, 0x00, 0x00, 0x00, // statement ID
},
},
statementID: 1,
},
}
sampleStatementFetchPacket = &StatementFetchPacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x09, 0x00, 0x00, 0x00, // header
0x1c, // type
0x01, 0x00, 0x00, 0x00, // statement ID
0x0a, 0x00, 0x00, 0x00, // num rows
},
},
statementID: 1,
},
rowsCount: 10,
}
sampleStatementBulkExecutePacket = &StatementBulkExecutePacket{
statementIDPacket: statementIDPacket{
packet: packet{
bytes: []byte{
0x15, 0x00, 0x00, 0x00, // header
0xfa, // type
0x01, 0x00, 0x00, 0x00, // statement ID
0x80, 0x00, // bulkFlag
0xfe, 0x00, // param 1 type - MYSQL_TYPE_STRING
0x08, 0x00, // param 2 type - MYSQL_TYPE_LONGLONG
0x01, // param 1 - null
0x00, 0xc8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // param 2 value - 200
},
},
statementID: 1,
},
bulkFlag: 128,
parameters: []byte{
0xfe, 0x00, // param 1 type - MYSQL_TYPE_STRING
0x08, 0x00, // param 2 type - MYSQL_TYPE_LONGLONG
0x01, // param 1 - null
0x00, 0xc8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // param 2 value - 200
},
}
)
func TestParsePacket(t *testing.T) {
tests := []struct {
name string
input io.Reader
expectedPacket Packet
expectErrorIs func(error) bool
}{
{
name: "network error",
input: iotest.ErrReader(&net.OpError{}),
expectErrorIs: trace.IsConnectionProblem,
},
{
name: "OK_HEADER",
input: bytes.NewBuffer(sampleOKPacket.Bytes()),
expectedPacket: sampleOKPacket,
},
{
name: "ERR_HEADER",
input: bytes.NewBuffer(sampleErrorPacket.Bytes()),
expectedPacket: sampleErrorPacket,
},
{
name: "ERR_HEADER protocol 4.1",
input: bytes.NewBuffer(sampleErrorWithSQLStatePacket.Bytes()),
expectedPacket: sampleErrorWithSQLStatePacket,
},
{
name: "COM_QUERY",
input: bytes.NewBuffer(sampleQueryPacket.Bytes()),
expectedPacket: sampleQueryPacket,
},
{
name: "COM_QUIT",
input: bytes.NewBuffer(sampleQuitPacket.Bytes()),
expectedPacket: sampleQuitPacket,
},
{
name: "COM_CHANGE_USER",
input: bytes.NewBuffer(sampleChangeUserPacket.Bytes()),
expectedPacket: sampleChangeUserPacket,
},
{
name: "COM_STMT_PREPARE",
input: bytes.NewBuffer(sampleStatementPreparePacket.Bytes()),
expectedPacket: sampleStatementPreparePacket,
},
{
name: "COM_STMT_SEND_LONG_DATA",
input: bytes.NewBuffer(sampleStatementSendLongDataPacket.Bytes()),
expectedPacket: sampleStatementSendLongDataPacket,
},
{
name: "COM_STMT_EXECUTE",
input: bytes.NewBuffer(sampleStatementExecutePacket.Bytes()),
expectedPacket: sampleStatementExecutePacket,
},
{
name: "COM_STMT_CLOSE",
input: bytes.NewBuffer(sampleStatementClosePacket.Bytes()),
expectedPacket: sampleStatementClosePacket,
},
{
name: "COM_STMT_RESET",
input: bytes.NewBuffer(sampleStatementResetPacket.Bytes()),
expectedPacket: sampleStatementResetPacket,
},
{
name: "COM_STMT_FETCH",
input: bytes.NewBuffer(sampleStatementFetchPacket.Bytes()),
expectedPacket: sampleStatementFetchPacket,
},
{
name: "COM_STMT_BULK_EXECUTE",
input: bytes.NewBuffer(sampleStatementBulkExecutePacket.Bytes()),
expectedPacket: sampleStatementBulkExecutePacket,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
actualPacket, err := ParsePacket(test.input)
if test.expectErrorIs != nil {
require.Error(t, err)
require.True(t, test.expectErrorIs(err))
} else {
require.NoError(t, err)
require.Equal(t, test.expectedPacket, actualPacket)
}
})
}
}

View file

@ -0,0 +1,57 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package protocol
import "encoding/binary"
// skipHeaderAndType skips packet header and command type, and returns rest of
// the bytes.
func skipHeaderAndType(input []byte) (unread []byte, ok bool) {
return skipBytes(input, packetHeaderAndTypeSize)
}
// skipBytes skips n bytes from input and returns rest of the bytes.
func skipBytes(input []byte, n int) (unread []byte, ok bool) {
if len(input) < n {
return nil, false
}
return input[n:], true
}
// readByte reads one byte from input and returns rest of the bytes.
func readByte(input []byte) (unread []byte, read byte, ok bool) {
if len(input) < 1 {
return nil, 0x00, false
}
return input[1:], input[0], true
}
// readUint32 reads an uint32 from input and returns rest of the bytes.
func readUint32(input []byte) (unread []byte, read uint32, ok bool) {
if len(input) < 4 {
return nil, 0, false
}
return input[4:], binary.LittleEndian.Uint32(input[:4]), true
}
// readUint16 reads an uint16 from input and returns rest of the bytes.
func readUint16(input []byte) (unread []byte, read uint16, ok bool) {
if len(input) < 2 {
return nil, 0, false
}
return input[2:], binary.LittleEndian.Uint16(input[:2]), true
}

View file

@ -0,0 +1,329 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package protocol
import "github.com/siddontang/go-mysql/mysql"
// StatementPreparePacket represents the COM_STMT_PREPARE command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html
// https://mariadb.com/kb/en/com_stmt_prepare/
//
// COM_STMT_PREPARE creates a prepared statement from passed query string.
// Parameter placeholders are marked with "?" in the query. A COM_STMT_PREPARE
// response is expected from the server after sending this command.
type StatementPreparePacket struct {
packet
// query is the query to prepare.
query string
}
// Query returns the query text.
func (p *StatementPreparePacket) Query() string {
return p.query
}
// statementIDPacket represents a common packet format where statement ID is
// after the packet type.
//
// The statement ID is returned by the server in the COM_STMT_PREPARE response.
// All prepared statement packets except COM_STMT_PREPARE starts with the
// statement ID after the packet type to identify the prepared statement to
// use.
//
// The statement ID is an unsigned integer counter, usually starting at 1 for
// each client connection.
type statementIDPacket struct {
packet
// statementID is the ID of the associated statement.
statementID uint32
}
// StatementID returns the statement ID.
func (p *statementIDPacket) StatementID() uint32 {
return p.statementID
}
// StatementSendLongDataPacket represents the COM_STMT_SEND_LONG_DATA command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
// https://mariadb.com/kb/en/com_stmt_send_long_data/
//
// COM_STMT_SEND_LONG_DATA is used to send byte stream data to the server, and
// the server appends this data to the specified parameter upon receiving it.
// It is usually used for big blobs.
type StatementSendLongDataPacket struct {
statementIDPacket
// parameterID is the identifier of the parameter or column.
parameterID uint16
// data is the byte data sent in the packet.
data []byte
}
// ParameterID returns the parameter ID.
func (p *StatementSendLongDataPacket) ParameterID() uint16 {
return p.parameterID
}
// Data returns the data in bytes.
func (p *StatementSendLongDataPacket) Data() []byte {
return p.data
}
// StatementExecutePacket represents the COM_STMT_EXECUTE command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
// https://mariadb.com/kb/en/com_stmt_execute/
//
// COM_STMT_EXECUTE asks the server to execute a prepared statement, with the
// types and values for the placeholders.
//
// Statement ID "-1" (0xffffffff) can be used to indicate the last statement
// prepared on current connection, for MariaDB server version 10.2 and above.
type StatementExecutePacket struct {
statementIDPacket
// cursorFlag specifies type of the cursor.
cursorFlag byte
// iterations is the iteration count specified in the command. The MySQL
// doc states that it is always 1.
iterations uint32
// nullBitmapAndParameters are raw packet bytes that represent a null
// bitmap and parameters with types and values. They are not decoded in the
// initial parsing because number of parameters is unknown.
nullBitmapAndParameters []byte
}
// Parameters returns a slice of parameters.
func (p *StatementExecutePacket) Parameters(definitions []mysql.Field) (parameters []interface{}, ok bool) {
// TODO(greedy52) implement parsing of null bitmap, parameter types, and
// paramerter binary values.
return nil, true
}
// StatementClosePacket represents the COM_STMT_CLOSE command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-close.html
// https://mariadb.com/kb/en/3-binary-protocol-prepared-statements-com_stmt_close/
//
// COM_STMT_CLOSE deallocates a prepared statement.
type StatementClosePacket struct {
statementIDPacket
}
// StatementResetPacket represents the COM_STMT_RESET command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-reset.html
// https://mariadb.com/kb/en/com_stmt_reset/
//
// COM_STMT_RESET resets the data of a prepared statement which was accumulated
// with COM_STMT_SEND_LONG_DATA.
type StatementResetPacket struct {
statementIDPacket
}
// StatementFetchPacket represents the COM_STMT_FETCH command.
//
// https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html
// https://mariadb.com/kb/en/com_stmt_fetch/
//
// COM_STMT_FETCH fetch rows from a existing resultset after a
// COM_STMT_EXECUTE.
type StatementFetchPacket struct {
statementIDPacket
// rowsCount number of rows to fetch.
rowsCount uint32
}
// RowsCount returns number of rows to fetch.
func (s *StatementFetchPacket) RowsCount() uint32 {
return s.rowsCount
}
// StatementBulkExecutePacket represents the COM_STMT_BULK_EXECUTE command.
//
// https://mariadb.com/kb/en/com_stmt_bulk_execute/
//
// COM_STMT_BULK_EXECUTE executes a bulk insert of a previously prepared
// statement.
type StatementBulkExecutePacket struct {
statementIDPacket
// bulkFlag is a flag specifies either 64 (return generated auto-increment
// IDs) or 128 (send types to server).
bulkFlag uint16
// parameters are raw packet bytes that contain parameter type and values.
// They are not decoded in the initial parsing because number of parameters
// is unknown.
parameters []byte
}
// Parameters returns a slice of parameters.
func (p *StatementBulkExecutePacket) Parameters(definitions []mysql.Field) (parameters []interface{}, ok bool) {
// TODO(greedy52) implement parsing of parameters from
// COM_STMT_BULK_EXECUTE packet.
return nil, true
}
// parseStatementPreparePacket parses packet bytes and returns a Packet if
// successful.
func parseStatementPreparePacket(rawPacket packet) (Packet, bool) {
unread, ok := skipHeaderAndType(rawPacket.bytes)
if !ok {
return nil, false
}
return &StatementPreparePacket{
packet: rawPacket,
query: string(unread),
}, true
}
// parseStatementIDPacket parses packet bytes and returns a statementIDPacket
// if successful.
func parseStatementIDPacket(rawPacket packet) (statementIDPacket, []byte, bool) {
unread, ok := skipHeaderAndType(rawPacket.bytes)
if !ok {
return statementIDPacket{}, nil, false
}
unread, statementID, ok := readUint32(unread)
if !ok {
return statementIDPacket{}, nil, false
}
return statementIDPacket{
packet: rawPacket,
statementID: statementID,
}, unread, true
}
// parseStatementSendLongDataPacket parses packet bytes and returns a Packet if
// successful.
func parseStatementSendLongDataPacket(rawPacket packet) (Packet, bool) {
parent, unread, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
unread, parameterID, ok := readUint16(unread)
if !ok {
return nil, false
}
return &StatementSendLongDataPacket{
statementIDPacket: parent,
parameterID: parameterID,
data: unread,
}, true
}
// parseStatementExecutePacket parses packet bytes and returns a Packet if
// successful.
func parseStatementExecutePacket(rawPacket packet) (Packet, bool) {
parent, unread, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
unread, cursorFlag, ok := readByte(unread)
if !ok {
return nil, false
}
unread, iterations, ok := readUint32(unread)
if !ok {
return nil, false
}
return &StatementExecutePacket{
statementIDPacket: parent,
cursorFlag: cursorFlag,
iterations: iterations,
nullBitmapAndParameters: unread,
}, true
}
// parseStatementClosePacket parses packet bytes and returns a Packet if
// successful.
func parseStatementClosePacket(rawPacket packet) (Packet, bool) {
parent, _, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
return &StatementClosePacket{
statementIDPacket: parent,
}, true
}
// parseStatementResetPacket parses packet bytes and returns a Packet if
// successful.
func parseStatementResetPacket(rawPacket packet) (Packet, bool) {
parent, _, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
return &StatementResetPacket{
statementIDPacket: parent,
}, true
}
// parseStatementFetchPacket parses packet bytes and returns a Packet if
// successful.
func parseStatementFetchPacket(rawPacket packet) (Packet, bool) {
parent, unread, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
_, rowsCount, ok := readUint32(unread)
if !ok {
return nil, false
}
return &StatementFetchPacket{
statementIDPacket: parent,
rowsCount: rowsCount,
}, true
}
// parseStatementBulkExecutePacket parses packet bytes and returns a Packet if
// successful.
func parseStatementBulkExecutePacket(rawPacket packet) (Packet, bool) {
parent, unread, ok := parseStatementIDPacket(rawPacket)
if !ok {
return nil, false
}
unread, bulkFlag, ok := readUint16(unread)
if !ok {
return nil, false
}
return &StatementBulkExecutePacket{
statementIDPacket: parent,
bulkFlag: bulkFlag,
parameters: unread,
}, true
}