Register database engines directly from db lib (#19279)

This commit is contained in:
STeve (Xin) Huang 2022-12-18 00:59:47 -05:00 committed by GitHub
parent 66b65dd2d7
commit 6c858c09ee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 60 additions and 66 deletions

View file

@ -82,6 +82,9 @@ import (
func TestMain(m *testing.M) {
utils.InitLoggerForTests()
native.PrecomputeTestKeys(m)
registerTestSnowflakeEngine()
registerTestElasticsearchEngine()
registerTestSQLServerEngine()
os.Exit(m.Run())
}
@ -599,7 +602,7 @@ func TestGCPRequireSSL(t *testing.T) {
require.NoError(t, err)
}
func init() {
func registerTestSQLServerEngine() {
// Override SQL Server engine that is used normally with the test one
// that mocks connection dial and Kerberos auth.
common.RegisterEngine(newTestSQLServerEngine, defaults.ProtocolSQLServer)

View file

@ -29,7 +29,6 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/defaults"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/srv/db/cassandra/protocol"
"github.com/gravitational/teleport/lib/srv/db/common"
@ -37,12 +36,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolCassandra)
}
// newEngine create new Cassandra engine.
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new Cassandra engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}

View file

@ -64,6 +64,18 @@ func GetEngine(name string, conf EngineConfig) (Engine, error) {
return engineFn(conf), nil
}
// CheckEngines checks if provided engine names are registered.
func CheckEngines(names ...string) error {
enginesMu.RLock()
defer enginesMu.RUnlock()
for _, name := range names {
if engines[name] == nil {
return trace.NotFound("database engine %q is not registered", name)
}
}
return nil
}
// EngineConfig is the common configuration every database engine uses.
type EngineConfig struct {
// Auth handles database access authentication.

View file

@ -33,6 +33,9 @@ import (
func TestRegisterEngine(t *testing.T) {
// Cleanup "test" engine in case this test is run in a loop.
RegisterEngine(nil, "test")
t.Cleanup(func() {
RegisterEngine(nil, "test")
})
ec := EngineConfig{
Context: context.Background(),
@ -48,6 +51,7 @@ func TestRegisterEngine(t *testing.T) {
engine, err := GetEngine("test", ec)
require.Nil(t, engine)
require.IsType(t, trace.NotFound(""), err)
require.IsType(t, trace.NotFound(""), CheckEngines("test"))
// Register a "test" engine.
RegisterEngine(func(ec EngineConfig) Engine {

View file

@ -36,19 +36,14 @@ import (
"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/role"
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolElasticsearch)
}
// newEngine create new elasticsearch engine.
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new elasticsearch engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{EngineConfig: ec}
}

View file

@ -36,7 +36,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/elasticsearch"
)
func init() {
func registerTestElasticsearchEngine() {
// Override Elasticsearch engine that is used normally with the test one
// with custom HTTP client.
common.RegisterEngine(newTestElasticsearchEngine, defaults.ProtocolElasticsearch)

View file

@ -31,11 +31,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolMongoDB)
}
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new MongoDB engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}

View file

@ -41,11 +41,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolMySQL)
}
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new MySQL engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}

View file

@ -28,20 +28,14 @@ import (
"github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/role"
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine,
defaults.ProtocolPostgres,
defaults.ProtocolCockroachDB)
}
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new Postgres engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}

View file

@ -37,12 +37,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolRedis)
}
// newEngine create new Redis engine.
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new Redis engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
}

View file

@ -16,7 +16,6 @@ limitations under the License.
package db
//nolint:goimports // goimports disagree with gci on blank imports
import (
"context"
"crypto/tls"
@ -42,25 +41,31 @@ import (
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv"
// Import to register Cassandra engine.
_ "github.com/gravitational/teleport/lib/srv/db/cassandra"
"github.com/gravitational/teleport/lib/srv/db/cassandra"
"github.com/gravitational/teleport/lib/srv/db/cloud"
"github.com/gravitational/teleport/lib/srv/db/cloud/users"
"github.com/gravitational/teleport/lib/srv/db/common"
// Import to register Elasticsearch engine.
_ "github.com/gravitational/teleport/lib/srv/db/elasticsearch"
// Import to register MongoDB engine.
_ "github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/elasticsearch"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
"github.com/gravitational/teleport/lib/srv/db/mysql"
// Import to register Postgres engine.
_ "github.com/gravitational/teleport/lib/srv/db/postgres"
// Import to register Redis engine.
_ "github.com/gravitational/teleport/lib/srv/db/redis"
// Import to register Snowflake engine.
_ "github.com/gravitational/teleport/lib/srv/db/snowflake"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/srv/db/redis"
"github.com/gravitational/teleport/lib/srv/db/snowflake"
"github.com/gravitational/teleport/lib/srv/db/sqlserver"
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(cassandra.NewEngine, defaults.ProtocolCassandra)
common.RegisterEngine(elasticsearch.NewEngine, defaults.ProtocolElasticsearch)
common.RegisterEngine(mongodb.NewEngine, defaults.ProtocolMongoDB)
common.RegisterEngine(mysql.NewEngine, defaults.ProtocolMySQL)
common.RegisterEngine(postgres.NewEngine, defaults.ProtocolPostgres, defaults.ProtocolCockroachDB)
common.RegisterEngine(redis.NewEngine, defaults.ProtocolRedis)
common.RegisterEngine(snowflake.NewEngine, defaults.ProtocolSnowflake)
common.RegisterEngine(sqlserver.NewEngine, defaults.ProtocolSQLServer)
}
// Config is the configuration for a database proxy server.
type Config struct {
// Clock used to control time.
@ -292,6 +297,10 @@ func (m *monitoredDatabases) get() types.ResourcesWithLabelsMap {
// New returns a new database server.
func New(ctx context.Context, config Config) (*Server, error) {
if err := common.CheckEngines(defaults.DatabaseProtocols...); err != nil {
return nil, trace.Wrap(err)
}
err := config.CheckAndSetDefaults(ctx)
if err != nil {
return nil, trace.Wrap(err)

View file

@ -44,12 +44,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolSnowflake)
}
// newEngine create new Snowflake engine.
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new Snowflake engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
HTTPClient: getDefaultHTTPClient(),

View file

@ -46,7 +46,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/snowflake"
)
func init() {
func registerTestSnowflakeEngine() {
// Override Snowflake engine that is used normally with the test one
// with custom HTTP client.
common.RegisterEngine(newTestSnowflakeEngine, defaults.ProtocolSnowflake)

View file

@ -24,7 +24,6 @@ import (
"github.com/gravitational/trace"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/defaults"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common"
@ -32,11 +31,8 @@ import (
"github.com/gravitational/teleport/lib/utils"
)
func init() {
common.RegisterEngine(newEngine, defaults.ProtocolSQLServer)
}
func newEngine(ec common.EngineConfig) common.Engine {
// NewEngine create new SQL Server engine.
func NewEngine(ec common.EngineConfig) common.Engine {
return &Engine{
EngineConfig: ec,
Connector: &connector{