teleport/integration/ec2_test.go
2022-07-28 13:34:27 +00:00

476 lines
16 KiB
Go

/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package integration
import (
"context"
"fmt"
"io"
"os"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/breaker"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/labels/ec2"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)
// newSilentLogger returns a test logger configured to stay quiet: its output
// is sent to io.Discard and its level is raised to logrus.PanicLevel.
func newSilentLogger() utils.Logger {
	silent := utils.NewLoggerForTests()
	silent.SetOutput(io.Discard)
	silent.SetLevel(logrus.PanicLevel)
	return silent
}
// newNodeConfig builds the configuration of a Teleport process that runs only
// the SSH service.
//
// authAddr is appended to the process's AuthServers list; the tests in this
// file pass either an auth server address or a proxy web address. tokenName
// and joinMethod are the credentials the process will present when joining.
// The data directory is a per-test temporary directory and the SSH listener is
// allocated through helpers.NewListener.
func newNodeConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *service.Config {
	cfg := service.MakeDefaultConfig()

	// Quiet logging and a no-op circuit breaker.
	cfg.Log = newSilentLogger()
	cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()

	// How the process joins the cluster.
	cfg.Token = tokenName
	cfg.JoinMethod = joinMethod
	cfg.AuthServers = append(cfg.AuthServers, authAddr)

	// Only the SSH service runs in this process.
	cfg.Auth.Enabled = false
	cfg.Proxy.Enabled = false
	cfg.SSH.Enabled = true
	cfg.SSH.Addr.Addr = helpers.NewListener(t, service.ListenerNodeSSH, &cfg.FileDescriptors)

	cfg.DataDir = t.TempDir()
	return cfg
}
// newProxyConfig builds the configuration of a Teleport process that runs only
// the proxy service, using config version v2.
//
// authAddr is appended to the process's AuthServers list. tokenName and
// joinMethod are the credentials the process will present when joining. The
// web interface is disabled, proxy protocol is enabled, and the web listener
// is allocated through helpers.NewListener.
func newProxyConfig(t *testing.T, authAddr utils.NetAddr, tokenName string, joinMethod types.JoinMethod) *service.Config {
	cfg := service.MakeDefaultConfig()
	cfg.Version = defaults.TeleportConfigVersionV2

	// Quiet logging and a no-op circuit breaker.
	cfg.Log = newSilentLogger()
	cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()

	// How the process joins the cluster.
	cfg.Token = tokenName
	cfg.JoinMethod = joinMethod
	cfg.AuthServers = append(cfg.AuthServers, authAddr)

	// Only the proxy service runs in this process.
	cfg.SSH.Enabled = false
	cfg.Auth.Enabled = false
	cfg.Proxy.Enabled = true
	cfg.Proxy.DisableWebInterface = true
	cfg.Proxy.EnableProxyProtocol = true
	cfg.Proxy.WebAddr.Addr = helpers.NewListener(t, service.ListenerProxyWeb, &cfg.FileDescriptors)

	cfg.DataDir = t.TempDir()
	return cfg
}
// newAuthConfig builds the configuration of a Teleport process that runs only
// the auth service.
//
// The backend is the lite backend stored in a per-test temporary directory,
// the cluster is named "testcluster" (via
// services.NewClusterNameWithRandomID), the proxy listener mode is set to
// multiplex, and no static tokens are configured. clock is stored in the
// returned config unchanged; callers in this file pass either a fake clock or
// nil.
//
// Any error while assembling the config fails the test immediately.
func newAuthConfig(t *testing.T, clock clockwork.Clock) *service.Config {
	// Attribute the require failures below to the calling test rather than to
	// this helper.
	t.Helper()

	storageConfig := backend.Config{
		Type: lite.GetName(),
		Params: backend.Params{
			"path":               t.TempDir(),
			"poll_stream_period": 50 * time.Millisecond,
		},
	}

	cfg := service.MakeDefaultConfig()
	cfg.DataDir = t.TempDir()
	cfg.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &cfg.FileDescriptors)

	var err error
	cfg.Auth.ClusterName, err = services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
		ClusterName: "testcluster",
	})
	require.NoError(t, err)

	// The auth server is its own (and only) auth server.
	cfg.AuthServers = append(cfg.AuthServers, cfg.Auth.ListenAddr)
	cfg.Auth.StorageConfig = storageConfig
	cfg.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex)

	// Start with an explicitly empty set of static tokens.
	cfg.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{
		StaticTokens: []types.ProvisionTokenV1{},
	})
	require.NoError(t, err)

	// Only the auth service runs in this process.
	cfg.Proxy.Enabled = false
	cfg.SSH.Enabled = false

	cfg.Clock = clock
	cfg.Log = newSilentLogger()
	cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig()
	return cfg
}
// getIID fetches the instance identity document of the machine the test is
// running on from the instance metadata service, using the default AWS SDK v2
// configuration. Any error fails the test immediately.
func getIID(t *testing.T) imds.InstanceIdentityDocument {
	// Attribute the require failures below to the calling test rather than to
	// this helper.
	t.Helper()

	ctx := context.Background()

	cfg, err := config.LoadDefaultConfig(ctx)
	require.NoError(t, err)

	imdsClient := imds.NewFromConfig(cfg)
	output, err := imdsClient.GetInstanceIdentityDocument(ctx, nil)
	require.NoError(t, err)

	return output.InstanceIdentityDocument
}
// getCallerIdentity asks AWS STS who the current credentials belong to, using
// an AWS SDK v1 session with shared config enabled. Any error fails the test
// immediately.
func getCallerIdentity(t *testing.T) *sts.GetCallerIdentityOutput {
	// Attribute the require failures below to the calling test rather than to
	// this helper.
	t.Helper()

	sess, err := session.NewSessionWithOptions(session.Options{
		SharedConfigState: session.SharedConfigEnable,
	})
	require.NoError(t, err)

	stsService := sts.New(sess)
	output, err := stsService.GetCallerIdentity(nil /*input*/)
	require.NoError(t, err)

	return output
}
// TestEC2NodeJoin is an integration test which asserts that the EC2 method for
// Simplified Node Joining works when run on a real EC2 instance with access to
// the IMDS and the ec2:DescribeInstances API. This is a very basic test, unit
// testing with mocked AWS endpoints is in lib/auth/join_ec2_test.go
//
// The test is skipped unless the TELEPORT_TEST_EC2 environment variable is
// set.
func TestEC2NodeJoin(t *testing.T) {
	if os.Getenv("TELEPORT_TEST_EC2") == "" {
		t.Skip("Skipping TestEC2NodeJoin because TELEPORT_TEST_EC2 is not set")
	}

	ctx := context.Background()

	// fetch the IID to create a token which will match this instance
	iid := getIID(t)

	tokenName := "test_token"
	token, err := types.NewProvisionTokenFromSpec(
		tokenName,
		time.Now().Add(time.Hour),
		types.ProvisionTokenSpecV2{
			Roles: []types.SystemRole{types.RoleNode},
			Allow: []*types.TokenRule{
				{
					AWSAccount: iid.AccountID,
					AWSRegions: []string{iid.Region},
				},
			},
		})
	require.NoError(t, err)

	// mock the current time so that the IID will pass the TTL check
	clock := clockwork.NewFakeClockAt(iid.PendingTime.Add(time.Second))

	// create and start the auth server
	authConfig := newAuthConfig(t, clock)
	authSvc, err := service.NewTeleport(authConfig)
	require.NoError(t, err)
	require.NoError(t, authSvc.Start())
	t.Cleanup(func() { require.NoError(t, authSvc.Close()) })

	authServer := authSvc.GetAuthServer()
	authServer.SetClock(clock)

	err = authServer.UpsertToken(ctx, token)
	require.NoError(t, err)

	// sanity check there are no nodes to start with
	nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
	require.NoError(t, err)
	require.Empty(t, nodes)

	// create and start the node
	nodeConfig := newNodeConfig(t, authConfig.Auth.ListenAddr, tokenName, types.JoinMethodEC2)
	nodeSvc, err := service.NewTeleport(nodeConfig)
	require.NoError(t, err)
	require.NoError(t, nodeSvc.Start())
	t.Cleanup(func() { require.NoError(t, nodeSvc.Close()) })

	_, err = nodeSvc.WaitForEventTimeout(10*time.Second, service.TeleportReadyEvent)
	require.NoError(t, err, "timeout waiting for node readiness")

	// the node should eventually join the cluster and heartbeat
	require.Eventually(t, func() bool {
		// The condition runs on a goroutine started by require.Eventually,
		// where FailNow-style assertions must not be used. Treat an error as
		// "not ready yet" and let the timeout report a persistent failure.
		joined, err := authServer.GetNodes(ctx, apidefaults.Namespace)
		return err == nil && len(joined) > 0
	}, time.Minute, time.Second, "waiting for node to join cluster")
}
// TestIAMNodeJoin is an integration test which asserts that the IAM method for
// Simplified Node Joining works when run on a real EC2 instance with an
// attached IAM role. This is a very basic test, unit testing with mocked AWS
// endpoints is in lib/auth/join_iam_test.go
//
// The test is skipped unless the TELEPORT_TEST_EC2 environment variable is
// set.
func TestIAMNodeJoin(t *testing.T) {
	if os.Getenv("TELEPORT_TEST_EC2") == "" {
		t.Skip("Skipping TestIAMNodeJoin because TELEPORT_TEST_EC2 is not set")
	}

	ctx := context.Background()

	// create and start the auth server
	authConfig := newAuthConfig(t, nil /*clock*/)
	authSvc, err := service.NewTeleport(authConfig)
	require.NoError(t, err)
	require.NoError(t, authSvc.Start())
	t.Cleanup(func() { require.NoError(t, authSvc.Close()) })

	authServer := authSvc.GetAuthServer()

	// fetch the caller identity to find the AWS account and create the token
	id := getCallerIdentity(t)
	// Account is a pointer; fail with a clear message instead of panicking on
	// the dereference below if STS did not return it.
	require.NotNil(t, id.Account, "STS caller identity has no account")

	tokenName := "test_token"
	token, err := types.NewProvisionTokenFromSpec(
		tokenName,
		time.Now().Add(time.Hour),
		types.ProvisionTokenSpecV2{
			Roles: []types.SystemRole{types.RoleNode, types.RoleProxy},
			Allow: []*types.TokenRule{
				{
					AWSAccount: *id.Account,
				},
			},
			JoinMethod: types.JoinMethodIAM,
		})
	require.NoError(t, err)

	err = authServer.UpsertToken(ctx, token)
	require.NoError(t, err)

	// sanity check there are no proxies to start with
	proxies, err := authServer.GetProxies()
	require.NoError(t, err)
	require.Empty(t, proxies)

	// create and start the proxy, will use the IAM method to join by connecting
	// directly to the auth server
	proxyConfig := newProxyConfig(t, authConfig.Auth.ListenAddr, tokenName, types.JoinMethodIAM)
	proxySvc, err := service.NewTeleport(proxyConfig)
	require.NoError(t, err)
	require.NoError(t, proxySvc.Start())
	t.Cleanup(func() { require.NoError(t, proxySvc.Close()) })

	// the proxy should eventually join the cluster and heartbeat
	require.Eventually(t, func() bool {
		// The condition runs on a goroutine started by require.Eventually,
		// where FailNow-style assertions must not be used. Treat an error as
		// "not ready yet" and let the timeout report a persistent failure.
		joined, err := authServer.GetProxies()
		return err == nil && len(joined) > 0
	}, time.Minute, time.Second, "waiting for proxy to join cluster")

	// InsecureDevMode needed for node to trust proxy
	wasInsecureDevMode := lib.IsInsecureDevMode()
	t.Cleanup(func() { lib.SetInsecureDevMode(wasInsecureDevMode) })
	lib.SetInsecureDevMode(true)

	// sanity check there are no nodes to start with
	nodes, err := authServer.GetNodes(ctx, apidefaults.Namespace)
	require.NoError(t, err)
	require.Empty(t, nodes)

	// create and start a node, will use the IAM method to join in IoT mode by
	// connecting to the proxy
	nodeConfig := newNodeConfig(t, proxyConfig.Proxy.WebAddr, tokenName, types.JoinMethodIAM)
	nodeSvc, err := service.NewTeleport(nodeConfig)
	require.NoError(t, err)
	require.NoError(t, nodeSvc.Start())
	t.Cleanup(func() { require.NoError(t, nodeSvc.Close()) })

	// the node should eventually join the cluster and heartbeat
	require.Eventually(t, func() bool {
		// Same goroutine restriction as above: report errors as "not yet".
		joined, err := authServer.GetNodes(ctx, apidefaults.Namespace)
		return err == nil && len(joined) > 0
	}, time.Minute, time.Second, "waiting for node to join cluster")
}
// mockIMDSClient is an in-memory stand-in for an instance metadata client. The
// tests in this file hand it to service.WithIMDSClient so that Teleport sees a
// fixed set of EC2 tags.
type mockIMDSClient struct {
	// tags maps each tag key to the value the mock reports for it.
	tags map[string]string
}
// IsAvailable always reports true; the context is not consulted.
func (m *mockIMDSClient) IsAvailable(ctx context.Context) bool {
	return true
}
// GetTagKeys returns the keys of every configured tag. The order follows map
// iteration and is therefore unspecified. The error is always nil and the
// context is not consulted.
func (m *mockIMDSClient) GetTagKeys(ctx context.Context) ([]string, error) {
	names := make([]string, len(m.tags))
	next := 0
	for name := range m.tags {
		names[next] = name
		next++
	}
	return names, nil
}
// GetTagValue returns the value configured for key. When the key is not
// configured it returns an empty string and a trace.NotFound error. The
// context is not consulted.
func (m *mockIMDSClient) GetTagValue(ctx context.Context, key string) (string, error) {
	value, found := m.tags[key]
	if !found {
		return "", trace.NotFound("Tag %q not found", key)
	}
	return value, nil
}
// TestEC2Labels is an integration test which asserts that Teleport correctly picks up
// EC2 tags when running on an EC2 instance.
//
// Instance metadata is supplied by mockIMDSClient, which reports a single
// "Name" tag. The test starts one process running auth, proxy, SSH, app,
// database and Kubernetes services and checks that each registered resource
// ends up carrying the label built from ec2.AWSNamespace and "Name".
func TestEC2Labels(t *testing.T) {
	storageConfig := backend.Config{
		Type: lite.GetName(),
		Params: backend.Params{
			"path":               t.TempDir(),
			"poll_stream_period": 50 * time.Millisecond,
		},
	}
	tconf := service.MakeDefaultConfig()
	tconf.Log = newSilentLogger()
	tconf.DataDir = t.TempDir()
	tconf.Auth.Enabled = true
	tconf.Proxy.Enabled = true
	tconf.Proxy.SSHAddr.Addr = helpers.NewListener(t, service.ListenerProxySSH, &tconf.FileDescriptors)
	tconf.Proxy.WebAddr.Addr = helpers.NewListener(t, service.ListenerProxyWeb, &tconf.FileDescriptors)
	tconf.Proxy.ReverseTunnelListenAddr.Addr = helpers.NewListener(t, service.ListenerProxyTunnel, &tconf.FileDescriptors)
	tconf.Proxy.DisableWebInterface = true
	tconf.Auth.StorageConfig = storageConfig
	tconf.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &tconf.FileDescriptors)
	tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.ListenAddr)

	tconf.SSH.Enabled = true
	tconf.SSH.Addr.Addr = helpers.NewListener(t, service.ListenerNodeSSH, &tconf.FileDescriptors)

	appConf := service.App{
		Name: "test-app",
		URI:  "app.example.com",
	}
	tconf.Apps.Enabled = true
	tconf.Apps.Apps = []service.App{appConf}

	dbConfig := service.Database{
		Name:     "test-db",
		Protocol: "postgres",
		URI:      "postgres://somewhere.example.com",
	}
	tconf.Databases.Enabled = true
	tconf.Databases.Databases = []service.Database{dbConfig}

	helpers.EnableKubernetesService(t, tconf)

	imClient := &mockIMDSClient{
		tags: map[string]string{
			"Name": "my-instance",
		},
	}

	proc, err := service.NewTeleport(tconf, service.WithIMDSClient(imClient))
	require.NoError(t, err)
	require.NoError(t, proc.Start())
	t.Cleanup(func() { require.NoError(t, proc.Close()) })

	ctx := context.Background()
	authServer := proc.GetAuthServer()

	// nodes is filled in by the first Eventually and read by the second one to
	// look the node up by name.
	var nodes []types.Server

	// Wait for everything to come online.
	//
	// The condition runs on a goroutine started by require.Eventually, where
	// FailNow-style assertions must not be used. Errors are therefore reported
	// as "not ready yet" and a persistent failure surfaces as a timeout.
	require.Eventually(t, func() bool {
		var err error
		nodes, err = authServer.GetNodes(ctx, tconf.SSH.Namespace)
		if err != nil {
			return false
		}
		apps, err := authServer.GetApplicationServers(ctx, tconf.SSH.Namespace)
		if err != nil {
			return false
		}
		databases, err := authServer.GetDatabaseServers(ctx, tconf.SSH.Namespace)
		if err != nil {
			return false
		}
		kubes, err := authServer.GetKubeServices(ctx)
		if err != nil {
			return false
		}
		return len(nodes) == 1 && len(apps) == 1 && len(databases) == 1 && len(kubes) == 1
	}, 10*time.Second, time.Second)

	tagName := fmt.Sprintf("%s/Name", ec2.AWSNamespace)

	// Check that EC2 labels were applied.
	require.Eventually(t, func() bool {
		node, err := authServer.GetNode(ctx, tconf.SSH.Namespace, nodes[0].GetName())
		if err != nil {
			return false
		}
		_, nodeHasLabel := node.GetAllLabels()[tagName]

		appServers, err := authServer.GetApplicationServers(ctx, tconf.SSH.Namespace)
		if err != nil || len(appServers) != 1 {
			return false
		}
		_, appHasLabel := appServers[0].GetApp().GetAllLabels()[tagName]

		dbServers, err := authServer.GetDatabaseServers(ctx, tconf.SSH.Namespace)
		if err != nil || len(dbServers) != 1 {
			return false
		}
		_, dbHasLabel := dbServers[0].GetDatabase().GetAllLabels()[tagName]

		// NOTE(review): GetKubeClusters receives t and may itself assert with
		// require from this goroutine — confirm against the helper.
		kubeClusters := helpers.GetKubeClusters(t, authServer)
		if len(kubeClusters) != 1 {
			return false
		}
		_, kubeHasLabel := kubeClusters[0].StaticLabels[tagName]

		return nodeHasLabel && appHasLabel && dbHasLabel && kubeHasLabel
	}, 10*time.Second, time.Second)
}
// TestEC2Hostname is an integration test which asserts that Teleport sets its
// hostname if the EC2 tag `TeleportHostname` is available.
//
// Instance metadata is supplied by mockIMDSClient, which reports the hostname
// under the types.EC2HostnameTag key. The test starts one process running
// auth, proxy and SSH and checks the hostname of the single registered node.
func TestEC2Hostname(t *testing.T) {
	teleportHostname := "fakehost.example.com"

	storageConfig := backend.Config{
		Type: lite.GetName(),
		Params: backend.Params{
			"path":               t.TempDir(),
			"poll_stream_period": 50 * time.Millisecond,
		},
	}
	tconf := service.MakeDefaultConfig()
	tconf.Log = newSilentLogger()
	tconf.DataDir = t.TempDir()
	tconf.Auth.Enabled = true
	tconf.Proxy.Enabled = true
	tconf.Proxy.DisableWebInterface = true
	tconf.Proxy.SSHAddr.Addr = helpers.NewListener(t, service.ListenerProxySSH, &tconf.FileDescriptors)
	tconf.Proxy.WebAddr.Addr = helpers.NewListener(t, service.ListenerProxyWeb, &tconf.FileDescriptors)
	tconf.Auth.StorageConfig = storageConfig
	tconf.Auth.ListenAddr.Addr = helpers.NewListener(t, service.ListenerAuth, &tconf.FileDescriptors)
	tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.ListenAddr)

	tconf.SSH.Enabled = true
	tconf.SSH.Addr.Addr = helpers.NewListener(t, service.ListenerNodeSSH, &tconf.FileDescriptors)

	imClient := &mockIMDSClient{
		tags: map[string]string{
			types.EC2HostnameTag: teleportHostname,
		},
	}

	proc, err := service.NewTeleport(tconf, service.WithIMDSClient(imClient))
	require.NoError(t, err)
	require.NoError(t, proc.Start())
	t.Cleanup(func() { require.NoError(t, proc.Close()) })

	ctx := context.Background()
	authServer := proc.GetAuthServer()

	// Wait for exactly one node to register and remember it.
	var node types.Server
	require.Eventually(t, func() bool {
		// The condition runs on a goroutine started by require.Eventually,
		// where FailNow-style assertions must not be used. Treat an error as
		// "not ready yet" and let the timeout report a persistent failure.
		nodes, err := authServer.GetNodes(ctx, tconf.SSH.Namespace)
		if err != nil || len(nodes) != 1 {
			return false
		}
		node = nodes[0]
		return true
	}, 10*time.Second, time.Second)

	require.Equal(t, teleportHostname, node.GetHostname())
}