diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index b55446d54d5..a16caaace8d 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -82,6 +82,9 @@ import ( func TestMain(m *testing.M) { utils.InitLoggerForTests() native.PrecomputeTestKeys(m) + registerTestSnowflakeEngine() + registerTestElasticsearchEngine() + registerTestSQLServerEngine() os.Exit(m.Run()) } @@ -599,7 +602,7 @@ func TestGCPRequireSSL(t *testing.T) { require.NoError(t, err) } -func init() { +func registerTestSQLServerEngine() { // Override SQL Server engine that is used normally with the test one // that mocks connection dial and Kerberos auth. common.RegisterEngine(newTestSQLServerEngine, defaults.ProtocolSQLServer) diff --git a/lib/srv/db/cassandra/engine.go b/lib/srv/db/cassandra/engine.go index 5a5ba18eec6..81b1381e54e 100644 --- a/lib/srv/db/cassandra/engine.go +++ b/lib/srv/db/cassandra/engine.go @@ -29,7 +29,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/defaults" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/cassandra/protocol" "github.com/gravitational/teleport/lib/srv/db/common" @@ -37,12 +36,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolCassandra) -} - -// newEngine create new Cassandra engine. -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new Cassandra engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, } diff --git a/lib/srv/db/common/engines.go b/lib/srv/db/common/engines.go index 30d63d4a323..f363c00fc1b 100644 --- a/lib/srv/db/common/engines.go +++ b/lib/srv/db/common/engines.go @@ -64,6 +64,18 @@ func GetEngine(name string, conf EngineConfig) (Engine, error) { return engineFn(conf), nil } +// CheckEngines checks if provided engine names are registered. +func CheckEngines(names ...string) error { + enginesMu.RLock() + defer enginesMu.RUnlock() + for _, name := range names { + if engines[name] == nil { + return trace.NotFound("database engine %q is not registered", name) + } + } + return nil +} + // EngineConfig is the common configuration every database engine uses. type EngineConfig struct { // Auth handles database access authentication. diff --git a/lib/srv/db/common/engines_test.go b/lib/srv/db/common/engines_test.go index 091866226cc..033cead41ef 100644 --- a/lib/srv/db/common/engines_test.go +++ b/lib/srv/db/common/engines_test.go @@ -33,6 +33,9 @@ import ( func TestRegisterEngine(t *testing.T) { // Cleanup "test" engine in case this test is run in a loop. RegisterEngine(nil, "test") + t.Cleanup(func() { + RegisterEngine(nil, "test") + }) ec := EngineConfig{ Context: context.Background(), @@ -48,6 +51,7 @@ func TestRegisterEngine(t *testing.T) { engine, err := GetEngine("test", ec) require.Nil(t, engine) require.IsType(t, trace.NotFound(""), err) + require.IsType(t, trace.NotFound(""), CheckEngines("test")) // Register a "test" engine. RegisterEngine(func(ec EngineConfig) Engine { diff --git a/lib/srv/db/elasticsearch/engine.go b/lib/srv/db/elasticsearch/engine.go index 380d6579242..f8ea271d99b 100644 --- a/lib/srv/db/elasticsearch/engine.go +++ b/lib/srv/db/elasticsearch/engine.go @@ -36,19 +36,14 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolElasticsearch) -} - -// newEngine create new elasticsearch engine. -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new elasticsearch engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{EngineConfig: ec} } diff --git a/lib/srv/db/elasticsearch_test.go b/lib/srv/db/elasticsearch_test.go index 76266e6ade8..f8330b6eb39 100644 --- a/lib/srv/db/elasticsearch_test.go +++ b/lib/srv/db/elasticsearch_test.go @@ -36,7 +36,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/elasticsearch" ) -func init() { +func registerTestElasticsearchEngine() { // Override Elasticsearch engine that is used normally with the test one // with custom HTTP client. common.RegisterEngine(newTestElasticsearchEngine, defaults.ProtocolElasticsearch) diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 929fc9c7eda..12603eb797a 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -31,11 +31,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolMongoDB) -} - -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new MongoDB engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, } diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index b02b87864e4..b7520f2474e 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -41,11 +41,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolMySQL) -} - -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new MySQL engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, } diff --git a/lib/srv/db/postgres/engine.go b/lib/srv/db/postgres/engine.go index 87b215deb82..f60d618062a 100644 --- a/lib/srv/db/postgres/engine.go +++ b/lib/srv/db/postgres/engine.go @@ -28,20 +28,14 @@ import ( "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/srv/db/common/role" "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, - defaults.ProtocolPostgres, - defaults.ProtocolCockroachDB) -} - -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new Postgres engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, } diff --git a/lib/srv/db/redis/engine.go b/lib/srv/db/redis/engine.go index 513fad32b89..1fdae35fb9f 100644 --- a/lib/srv/db/redis/engine.go +++ b/lib/srv/db/redis/engine.go @@ -37,12 +37,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolRedis) -} - -// newEngine create new Redis engine. -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new Redis engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 38f000c50bc..6d0f3365aaa 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -16,7 +16,6 @@ limitations under the License. package db -//nolint:goimports // goimports disagree with gci on blank imports import ( "context" "crypto/tls" @@ -42,25 +41,31 @@ import ( "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv" - // Import to register Cassandra engine. - _ "github.com/gravitational/teleport/lib/srv/db/cassandra" + "github.com/gravitational/teleport/lib/srv/db/cassandra" "github.com/gravitational/teleport/lib/srv/db/cloud" "github.com/gravitational/teleport/lib/srv/db/cloud/users" "github.com/gravitational/teleport/lib/srv/db/common" - // Import to register Elasticsearch engine. - _ "github.com/gravitational/teleport/lib/srv/db/elasticsearch" - // Import to register MongoDB engine. - _ "github.com/gravitational/teleport/lib/srv/db/mongodb" + "github.com/gravitational/teleport/lib/srv/db/elasticsearch" + "github.com/gravitational/teleport/lib/srv/db/mongodb" "github.com/gravitational/teleport/lib/srv/db/mysql" - // Import to register Postgres engine. - _ "github.com/gravitational/teleport/lib/srv/db/postgres" - // Import to register Redis engine. - _ "github.com/gravitational/teleport/lib/srv/db/redis" - // Import to register Snowflake engine. - _ "github.com/gravitational/teleport/lib/srv/db/snowflake" + "github.com/gravitational/teleport/lib/srv/db/postgres" + "github.com/gravitational/teleport/lib/srv/db/redis" + "github.com/gravitational/teleport/lib/srv/db/snowflake" + "github.com/gravitational/teleport/lib/srv/db/sqlserver" "github.com/gravitational/teleport/lib/utils" ) +func init() { + common.RegisterEngine(cassandra.NewEngine, defaults.ProtocolCassandra) + common.RegisterEngine(elasticsearch.NewEngine, defaults.ProtocolElasticsearch) + common.RegisterEngine(mongodb.NewEngine, defaults.ProtocolMongoDB) + common.RegisterEngine(mysql.NewEngine, defaults.ProtocolMySQL) + common.RegisterEngine(postgres.NewEngine, defaults.ProtocolPostgres, defaults.ProtocolCockroachDB) + common.RegisterEngine(redis.NewEngine, defaults.ProtocolRedis) + common.RegisterEngine(snowflake.NewEngine, defaults.ProtocolSnowflake) + common.RegisterEngine(sqlserver.NewEngine, defaults.ProtocolSQLServer) +} + // Config is the configuration for a database proxy server. type Config struct { // Clock used to control time. @@ -292,6 +297,10 @@ func (m *monitoredDatabases) get() types.ResourcesWithLabelsMap { // New returns a new database server. func New(ctx context.Context, config Config) (*Server, error) { + if err := common.CheckEngines(defaults.DatabaseProtocols...); err != nil { + return nil, trace.Wrap(err) + } + err := config.CheckAndSetDefaults(ctx) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/db/snowflake/engine.go b/lib/srv/db/snowflake/engine.go index 2c42f39b568..e18caccfbbb 100644 --- a/lib/srv/db/snowflake/engine.go +++ b/lib/srv/db/snowflake/engine.go @@ -44,12 +44,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolSnowflake) -} - -// newEngine create new Snowflake engine. -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new Snowflake engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, HTTPClient: getDefaultHTTPClient(), diff --git a/lib/srv/db/snowflake_test.go b/lib/srv/db/snowflake_test.go index 981f596439d..94d87976b1e 100644 --- a/lib/srv/db/snowflake_test.go +++ b/lib/srv/db/snowflake_test.go @@ -46,7 +46,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/snowflake" ) -func init() { +func registerTestSnowflakeEngine() { // Override Snowflake engine that is used normally with the test one // with custom HTTP client. common.RegisterEngine(newTestSnowflakeEngine, defaults.ProtocolSnowflake) diff --git a/lib/srv/db/sqlserver/engine.go b/lib/srv/db/sqlserver/engine.go index 873094e4887..655e480b669 100644 --- a/lib/srv/db/sqlserver/engine.go +++ b/lib/srv/db/sqlserver/engine.go @@ -24,7 +24,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/defaults" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common" @@ -32,11 +31,8 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func init() { - common.RegisterEngine(newEngine, defaults.ProtocolSQLServer) -} - -func newEngine(ec common.EngineConfig) common.Engine { +// NewEngine create new SQL Server engine. +func NewEngine(ec common.EngineConfig) common.Engine { return &Engine{ EngineConfig: ec, Connector: &connector{