/* Copyright 2015-2020 Gravitational, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package web import ( "archive/tar" "bufio" "bytes" "compress/gzip" "context" "crypto" "crypto/tls" "crypto/x509" "encoding/base32" "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" "errors" "fmt" "io" "net" "net/http" "net/http/cookiejar" "net/http/httptest" "net/url" "os" "os/user" "path/filepath" "sort" "strings" "testing" "time" "github.com/gogo/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/julienschmidt/httprouter" lemma_secret "github.com/mailgun/lemma/secret" "github.com/mailgun/timetools" "github.com/pquerna/otp/totp" "github.com/sashabaranov/go-openai" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" commonv1 "go.opentelemetry.io/proto/otlp/common/v1" resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1" tracepb "go.opentelemetry.io/proto/otlp/trace/v1" "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/testing/protocmp" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" authztypes 
"k8s.io/client-go/kubernetes/typed/authorization/v1" "k8s.io/client-go/tools/clientcmd" clientcmdapi "k8s.io/client-go/tools/clientcmd/api" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/breaker" authproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" kubeproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/kube/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth" tlsutils "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/auth/mocku2f" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/testauthority" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/bpf" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/conntest" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/httplib/csrf" kubeproxy "github.com/gravitational/teleport/lib/kube/proxy" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/proxy" restricted "github.com/gravitational/teleport/lib/restrictedsession" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/reversetunnelclient" 
"github.com/gravitational/teleport/lib/secret" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/srv/desktop" "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" websession "github.com/gravitational/teleport/lib/web/session" "github.com/gravitational/teleport/lib/web/ui" ) const hostID = "00000000-0000-0000-0000-000000000000" type WebSuite struct { ctx context.Context cancel context.CancelFunc node *regular.Server proxy *regular.Server proxyTunnel reversetunnelclient.Server srvID string user string webServer *httptest.Server webHandler *APIHandler mockU2F *mocku2f.Key server *auth.TestServer proxyClient *auth.Client clock clockwork.FakeClock } // TestMain will re-execute Teleport to run a command if "exec" is passed to // it as an argument. Otherwise, it will run tests as normal. func TestMain(m *testing.M) { utils.InitLoggerForTests() // If the test is re-executing itself, execute the command that comes over // the pipe. if srv.IsReexec() { srv.RunAndExit(os.Args[1]) return } native.PrecomputeTestKeys(m) // Otherwise run tests as normal. code := m.Run() os.Exit(code) } func newWebSuite(t *testing.T) *WebSuite { return newWebSuiteWithConfig(t, webSuiteConfig{}) } type webSuiteConfig struct { // AuthPreferenceSpec is custom initial AuthPreference spec for the test. authPreferenceSpec *types.AuthPreferenceSpecV2 disableDiskBasedRecording bool uiConfig webclient.UIConfig // Custom "HealthCheckAppServer" function. Can be used to avoid dialing app // services. HealthCheckAppServer healthCheckAppServerFunc // OpenAIConfig is a custom OpenAI config for the test. 
OpenAIConfig *openai.ClientConfig // ClusterFeatures allows overriding default auth server features ClusterFeatures *authproto.Features } func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { mockU2F, err := mocku2f.Create() require.NoError(t, err) require.NotNil(t, mockU2F) u, err := user.Current() require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) s := &WebSuite{ mockU2F: mockU2F, clock: clockwork.NewFakeClock(), user: u.Username, ctx: ctx, cancel: cancel, } networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ KeepAliveInterval: types.Duration(10 * time.Second), }) require.NoError(t, err) authCfg := auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ ClusterName: "localhost", Dir: t.TempDir(), Clock: s.clock, ClusterNetworkingConfig: networkingConfig, AuthPreferenceSpec: cfg.authPreferenceSpec, }, } if cfg.disableDiskBasedRecording { authCfg.Auth.AuditLog = events.NewDiscardAuditLog() } s.server, err = auth.NewTestServer(authCfg) require.NoError(t, err) if cfg.disableDiskBasedRecording { // use a sync recording mode because the disk-based uploader // that runs in the background introduces races with test cleanup recConfig := types.DefaultSessionRecordingConfig() recConfig.SetMode(types.RecordAtNodeSync) err := s.server.AuthServer.AuthServer.SetSessionRecordingConfig(context.Background(), recConfig) require.NoError(t, err) } // Register the auth server, since test auth server doesn't start its own // heartbeat. 
err = s.server.Auth().UpsertAuthServer(ctx, &types.ServerV2{ Kind: types.KindAuthServer, Version: types.V2, Metadata: types.Metadata{ Namespace: apidefaults.Namespace, Name: "auth", }, Spec: types.ServerSpecV2{ Addr: s.server.TLS.Listener.Addr().String(), Hostname: "localhost", Version: teleport.Version, }, }) require.NoError(t, err) priv, pub, err := testauthority.New().GenerateKeyPair() require.NoError(t, err) tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) require.NoError(t, err) nodeID := "node" // start node certs, err := s.server.Auth().GenerateHostCerts(s.ctx, &authproto.HostCertsRequest{ HostID: hostID, NodeName: nodeID, Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, }) require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) require.NoError(t, err) nodeClient, err := s.server.NewClient(auth.TestIdentity{ I: authz.BuiltinRole{ Role: types.RoleNode, Username: nodeID, }, }) require.NoError(t, err) nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentNode, Client: nodeClient, }, }) require.NoError(t, err) nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ Semaphores: nodeClient, AccessPoint: nodeClient, LockEnforcer: nodeLockWatcher, Emitter: nodeClient, Component: teleport.ComponentNode, ServerID: nodeID, }) require.NoError(t, err) // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, nodeID, []ssh.Signer{signer}, nodeClient, nodeDataDir, "", utils.NetAddr{}, nodeClient, regular.SetUUID(nodeID), regular.SetNamespace(apidefaults.Namespace), regular.SetShell("/bin/sh"), regular.SetEmitter(nodeClient), regular.SetPAMConfig(&servicecfg.PAMConfig{Enabled: false}), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), 
regular.SetLockWatcher(nodeLockWatcher), regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) s.node = node s.srvID = node.ID() require.NoError(t, s.node.Start()) // create reverse tunnel service: proxyID := "proxy" s.proxyClient, err = s.server.NewClient(auth.TestIdentity{ I: authz.BuiltinRole{ Role: types.RoleProxy, Username: proxyID, }, }) require.NoError(t, err) revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", s.server.ClusterName())) require.NoError(t, err) proxyLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: s.proxyClient, }, }) require.NoError(t, err) proxyNodeWatcher, err := services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: s.proxyClient, }, }) require.NoError(t, err) caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: s.proxyClient, }, Types: []types.CertAuthType{types.HostCA, types.UserCA}, }) require.NoError(t, err) defer caWatcher.Close() revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, ClientTLS: s.proxyClient.TLSConfig(), ClusterName: s.server.ClusterName(), HostSigners: []ssh.Signer{signer}, LocalAuthClient: s.proxyClient, LocalAccessPoint: s.proxyClient, Emitter: s.proxyClient, NewCachingAccessPoint: noCache, DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{s.server.TLS.Listener.Addr().String()}, Clock: s.clock, }) require.NoError(t, err) s.proxyTunnel = revTunServer router, err := proxy.NewRouter(proxy.RouterConfig{ ClusterName: 
s.server.ClusterName(), Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), RemoteClusterGetter: s.proxyClient, SiteGetter: revTunServer, TracerProvider: tracing.NoopProvider(), }) require.NoError(t, err) proxySessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ Semaphores: s.proxyClient, AccessPoint: s.proxyClient, LockEnforcer: proxyLockWatcher, Emitter: s.proxyClient, Component: teleport.ComponentProxy, ServerID: proxyID, }) require.NoError(t, err) // proxy server: s.proxy, err = regular.New( ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, s.server.ClusterName(), []ssh.Signer{signer}, s.proxyClient, t.TempDir(), "", utils.NetAddr{}, s.proxyClient, regular.SetUUID(proxyID), regular.SetProxyMode("", revTunServer, s.proxyClient, router), regular.SetEmitter(s.proxyClient), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), regular.SetSessionController(proxySessionController), ) require.NoError(t, err) // Expired sessions are purged immediately var sessionLingeringThreshold time.Duration fs, err := newDebugFileSystem() require.NoError(t, err) features := *modules.GetModules().Features().ToProto() // safe to dereference because ToProto creates a struct and return a pointer to it if cfg.ClusterFeatures != nil { features = *cfg.ClusterFeatures } handlerConfig := Config{ ClusterFeatures: features, Proxy: revTunServer, AuthServers: utils.FromAddr(s.server.TLS.Addr()), DomainName: s.server.ClusterName(), ProxyClient: s.proxyClient, CipherSuites: utils.DefaultCipherSuites(), AccessPoint: s.proxyClient, Context: s.ctx, HostUUID: proxyID, Emitter: s.proxyClient, StaticFS: fs, CachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: &mockProxySettings{}, SessionControl: SessionControllerFunc(func(ctx context.Context, 
sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { controller := srv.WebSessionController(proxySessionController) ctx, err := controller(ctx, sctx, login, localAddr, remoteAddr) return ctx, trace.Wrap(err) }), Router: router, HealthCheckAppServer: cfg.HealthCheckAppServer, UI: cfg.uiConfig, OpenAIConfig: cfg.OpenAIConfig, } if handlerConfig.HealthCheckAppServer == nil { handlerConfig.HealthCheckAppServer = func(context.Context, string, string) error { return nil } } handler, err := NewHandler(handlerConfig, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) require.NoError(t, err) s.webServer = httptest.NewUnstartedServer(handler) s.webHandler = handler s.webServer.StartTLS() err = s.proxy.Start() require.NoError(t, err) // Wait for proxy to fully register before starting the test. for start := time.Now(); ; { proxies, err := s.proxyClient.GetProxies() require.NoError(t, err) if len(proxies) != 0 { break } if time.Since(start) > 5*time.Second { t.Fatal("proxy didn't register within 5s after startup") } } proxyAddr := utils.MustParseAddr(s.proxy.Addr()) addr := utils.MustParseAddr(s.webServer.Listener.Addr().String()) handler.handler.cfg.ProxyWebAddr = *addr handler.handler.cfg.ProxySSHAddr = *proxyAddr _, sshPort, err := net.SplitHostPort(proxyAddr.String()) require.NoError(t, err) handler.handler.sshPort = sshPort t.Cleanup(func() { // In particular close the lock watchers by canceling the context. 
s.cancel() s.webServer.Close() var errors []error if err := s.proxyTunnel.Close(); err != nil { errors = append(errors, err) } if err := s.node.Close(); err != nil { errors = append(errors, err) } s.webServer.Close() if err := s.proxy.Close(); err != nil { errors = append(errors, err) } if err := s.server.Shutdown(context.Background()); err != nil { errors = append(errors, err) } require.Empty(t, errors) }) return s } func (s *WebSuite) addNode(t *testing.T, uuid string, hostname string, address string) *regular.Server { priv, pub, err := testauthority.New().GenerateKeyPair() require.NoError(t, err) tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) require.NoError(t, err) // start node certs, err := s.server.Auth().GenerateHostCerts(s.ctx, &authproto.HostCertsRequest{ HostID: uuid, NodeName: hostname, Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, }) require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) require.NoError(t, err) nodeClient, err := s.server.NewClient(auth.TestIdentity{ I: authz.BuiltinRole{ Role: types.RoleNode, Username: uuid, }, }) require.NoError(t, err) nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentNode, Client: nodeClient, }, }) require.NoError(t, err) nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ Semaphores: nodeClient, AccessPoint: nodeClient, LockEnforcer: nodeLockWatcher, Emitter: nodeClient, Component: teleport.ComponentNode, ServerID: uuid, }) require.NoError(t, err) // create SSH service: node, err := regular.New( context.Background(), utils.NetAddr{AddrNetwork: "tcp", Addr: address}, hostname, []ssh.Signer{signer}, nodeClient, t.TempDir(), "", utils.NetAddr{}, nodeClient, regular.SetUUID(uuid), regular.SetNamespace(apidefaults.Namespace), regular.SetShell("/bin/sh"), regular.SetEmitter(nodeClient), 
regular.SetPAMConfig(&servicecfg.PAMConfig{Enabled: false}), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(s.clock), regular.SetLockWatcher(nodeLockWatcher), regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) require.NoError(t, node.Start()) t.Cleanup(func() { require.NoError(t, node.Close()) node.Wait() }) return node } func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { return clt, nil } func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { resp, err := r.clt.PostJSON(ctx, r.clt.Endpoint("webapi", "sessions", "web", "renew"), nil) require.NoError(t, err) return resp } func (r *authPack) validateAPI(ctx context.Context, t *testing.T) { _, err := r.clt.Get(ctx, r.clt.Endpoint("webapi", "sites"), url.Values{}) require.NoError(t, err) } type authPack struct { otpSecret string user string login string password string session *CreateSessionResponse clt *TestWebClient cookies []*http.Cookie device *auth.TestDevice } // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. func (s *WebSuite) authPack(t *testing.T, user string, roles ...string) *authPack { login := s.user pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) require.NoError(t, err) s.createUser(t, user, login, pass, otpSecret, roles...) 
// create a valid otp token validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) require.NoError(t, err) clt := s.client(t) req := CreateSessionReq{ User: user, Pass: pass, SecondFactorToken: validToken, } csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" re, err := s.login(clt, csrfToken, csrfToken, req) require.NoError(t, err) var rawSess *CreateSessionResponse require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) sess, err := rawSess.response() require.NoError(t, err) jar, err := cookiejar.New(nil) require.NoError(t, err) clt = s.client(t, roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) return &authPack{ otpSecret: otpSecret, user: user, login: login, session: sess, clt: clt, cookies: re.Cookies(), } } func (s *WebSuite) authPackWithMFA(t *testing.T, name string, roles ...types.Role) *authPack { const password = "testing" user, err := types.NewUser(name) require.NoError(t, err) userRole := services.RoleForUser(user) userRole.SetLogins(types.Allow, []string{s.user}) err = s.server.Auth().UpsertRole(s.ctx, userRole) require.NoError(t, err) for _, role := range roles { err = s.server.Auth().UpsertRole(s.ctx, role) require.NoError(t, err) user.AddRole(role.GetName()) } user.AddRole(userRole.GetName()) err = s.server.Auth().CreateUser(s.ctx, user) require.NoError(t, err) clt := s.client(t) // create register challenge token, err := s.server.Auth().CreateResetPasswordToken(s.ctx, auth.CreateUserTokenRequest{ Name: name, }) require.NoError(t, err) res, err := s.server.Auth().CreateRegisterChallenge(s.ctx, &authproto.CreateRegisterChallengeRequest{ TokenID: token.GetName(), DeviceType: authproto.DeviceType_DEVICE_TYPE_WEBAUTHN, DeviceUsage: authproto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, }) require.NoError(t, err) cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) // use passwordless as auth method device, err := mocku2f.Create() require.NoError(t, err) 
device.SetPasswordless() ccr, err := device.SignCredentialCreation("https://localhost", cc) require.NoError(t, err) _, err = s.server.Auth().ChangeUserAuthentication(s.ctx, &authproto.ChangeUserAuthenticationRequest{ TokenID: token.GetName(), NewPassword: []byte(password), NewMFARegisterResponse: &authproto.MFARegisterResponse{ Response: &authproto.MFARegisterResponse_Webauthn{ Webauthn: wanlib.CredentialCreationResponseToProto(ccr), }, }, }) require.NoError(t, err) beginReq := &client.MFAChallengeRequest{ User: name, Pass: password, } re, err := s.loginMFA(clt, beginReq, device) require.NoError(t, err) var rawSess *CreateSessionResponse require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) sess, err := rawSess.response() require.NoError(t, err) jar, err := cookiejar.New(nil) require.NoError(t, err) clt = s.client(t, roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) return &authPack{ user: name, login: s.user, session: sess, clt: clt, cookies: re.Cookies(), device: &auth.TestDevice{Key: device}, } } func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string, roles ...string) { teleUser, err := types.NewUser(user) require.NoError(t, err) role := services.RoleForUser(teleUser) role.SetLogins(types.Allow, []string{login}) options := role.GetOptions() options.ForwardAgent = types.NewBool(true) role.SetOptions(options) err = s.server.Auth().UpsertRole(s.ctx, role) require.NoError(t, err) teleUser.AddRole(role.GetName()) for _, r := range roles { teleUser.AddRole(r) } teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) err = s.server.Auth().CreateUser(s.ctx, teleUser) require.NoError(t, err) err = s.server.Auth().UpsertPassword(user, []byte(pass)) require.NoError(t, err) if otpSecret != "" { dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), 
user, dev) require.NoError(t, err) } } func verifySecurityResponseHeaders(t *testing.T, h http.Header) { t.Helper() cases := []struct { header string expectedValue string }{ { header: "X-Content-Type-Options", expectedValue: "nosniff", }, { header: "Referrer-Policy", expectedValue: "strict-origin", }, { header: "X-Frame-Options", expectedValue: "SAMEORIGIN", }, { header: "Strict-Transport-Security", expectedValue: "max-age=31536000; includeSubDomains", }, } for _, tc := range cases { require.Contains(t, h, tc.header) require.Equal(t, tc.expectedValue, h.Get(tc.header)) } } func TestValidRedirectURL(t *testing.T) { t.Parallel() for _, tt := range []struct { desc, url string valid bool }{ {"valid absolute https url", "https://example.com?a=1", true}, {"valid absolute http url", "http://example.com?a=1", true}, {"valid relative url", "/path/to/something", true}, {"garbage", "fjoiewjwpods302j09", false}, {"empty string", "", false}, {"block bad protocol", "javascript:alert('xss')", false}, } { t.Run(tt.desc, func(t *testing.T) { require.Equal(t, tt.valid, isValidRedirectURL(tt.url)) }) } } func TestMetaRedirect(t *testing.T) { t.Parallel() h := &Handler{} redirectHandler := h.WithMetaRedirect(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) string { return "https://example.com" }) req := httptest.NewRequest(http.MethodPost, "/some/route", nil) resp := httptest.NewRecorder() redirectHandler(resp, req, nil) targetElement := `` require.Equal(t, http.StatusOK, resp.Code) body := resp.Body.String() require.Contains(t, body, targetElement) } func Test_clientMetaFromReq(t *testing.T) { ua := "foobar" r := httptest.NewRequest( http.MethodGet, "https://example.com/webapi/foo", nil, ) r.Header.Set("User-Agent", ua) got := clientMetaFromReq(r) require.Equal(t, &auth.ForwardedClientMetadata{ UserAgent: ua, RemoteAddr: "192.0.2.1:1234", }, got) } func TestWebSessionsCRUD(t *testing.T) { t.Parallel() s := newWebSuite(t) pack := s.authPack(t, "foo") // make sure we 
can use client to make authenticated requests re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) require.NoError(t, err) var clusters []ui.Cluster require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // now delete session _, err = pack.clt.Delete( context.Background(), pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // subsequent requests trying to use this session will fail _, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) } func TestCSRF(t *testing.T) { t.Parallel() s := newWebSuite(t) type input struct { reqToken string cookieToken string } // create a valid user user := "csrfuser" pass := "abc123" otpSecret := base32.StdEncoding.EncodeToString([]byte("def456")) s.createUser(t, user, user, pass, otpSecret) // create a valid login form request validToken, err := totp.GenerateCode(otpSecret, time.Now()) require.NoError(t, err) loginForm := CreateSessionReq{ User: user, Pass: pass, SecondFactorToken: validToken, } encodedToken1 := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" encodedToken2 := "bf355921bbf3ef3672a03e410d4194077dfa5fe863c652521763b3e7f81e7b11" invalid := []input{ {reqToken: encodedToken2, cookieToken: encodedToken1}, {reqToken: "", cookieToken: encodedToken1}, {reqToken: "", cookieToken: ""}, {reqToken: encodedToken1, cookieToken: ""}, } clt := s.client(t) // valid _, err = s.login(clt, encodedToken1, encodedToken1, loginForm) require.NoError(t, err) // invalid for i := range invalid { _, err := s.login(clt, invalid[i].cookieToken, invalid[i].reqToken, loginForm) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) } } func TestPasswordChange(t *testing.T) { t.Parallel() s := newWebSuite(t) pack := s.authPack(t, "foo") // invalidate the token s.clock.Advance(1 * time.Minute) validToken, err := totp.GenerateCode(pack.otpSecret, 
s.clock.Now()) require.NoError(t, err) req := changePasswordReq{ OldPassword: []byte("abc123"), NewPassword: []byte("abc1234"), SecondFactorToken: validToken, } _, err = pack.clt.PutJSON(context.Background(), pack.clt.Endpoint("webapi", "users", "password"), req) require.NoError(t, err) } // TestValidateBearerToken tests that the bearer token's user name // matches the user name on the cookie. func TestValidateBearerToken(t *testing.T) { t.Parallel() env := newWebPack(t, 1) proxy := env.proxies[0] pack1 := proxy.authPack(t, "user1", nil /* roles */) pack2 := proxy.authPack(t, "user2", nil /* roles */) // Swap pack1's session token with pack2's sessionToken jar, err := cookiejar.New(nil) require.NoError(t, err) pack1.clt = proxy.newClient(t, roundtrip.BearerAuth(pack2.session.Token), roundtrip.CookieJar(jar)) jar.SetCookies(&proxy.webURL, pack1.cookies) // Auth protected endpoint. req := changePasswordReq{} _, err = pack1.clt.PutJSON(context.Background(), pack1.clt.Endpoint("webapi", "users", "password"), req) require.True(t, trace.IsAccessDenied(err)) require.True(t, strings.Contains(err.Error(), "bad bearer token")) } func TestWebSessionsBadInput(t *testing.T) { t.Parallel() s := newWebSuite(t) user := "bob" pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) err := s.server.Auth().UpsertPassword(user, []byte(pass)) require.NoError(t, err) dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) require.NoError(t, err) err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) require.NoError(t, err) // create valid token validToken, err := totp.GenerateCode(otpSecret, time.Now()) require.NoError(t, err) clt := s.client(t) reqs := []CreateSessionReq{ // empty request {}, // missing user { Pass: pass, SecondFactorToken: validToken, }, // missing pass { User: user, SecondFactorToken: validToken, }, // bad pass { User: user, Pass: "bla bla", SecondFactorToken: validToken, }, // bad otp token 
{ User: user, Pass: pass, SecondFactorToken: "bad token", }, // missing otp token { User: user, Pass: pass, }, } for i, req := range reqs { t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions", "web"), req) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) }) } } type clusterNodesGetResponse struct { Items []ui.Server `json:"items"` StartKey string `json:"startKey"` TotalCount int `json:"totalCount"` } func TestClusterNodesGet(t *testing.T) { t.Parallel() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "test-user@example.com", nil /* roles */) // Get the node already added by `newWebPack` servers, err := env.server.Auth().GetNodes(context.Background(), apidefaults.Namespace) require.NoError(t, err) require.Len(t, servers, 1) server1 := servers[0] // Add another node. server2, err := types.NewServerWithLabels("server2", types.KindNode, types.ServerSpecV2{}, map[string]string{"test-field": "test-value"}) require.NoError(t, err) _, err = env.server.Auth().UpsertNode(context.Background(), server2) require.NoError(t, err) // Get nodes from endpoint. clusterName := env.server.ClusterName() endpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "nodes") query := url.Values{"sort": []string{"name"}} // Get nodes. re, err := pack.clt.Get(context.Background(), endpoint, query) require.NoError(t, err) // Test response. 
res := clusterNodesGetResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &res)) require.Len(t, res.Items, 2) require.Equal(t, 2, res.TotalCount) require.ElementsMatch(t, res.Items, []ui.Server{ { ClusterName: clusterName, Name: server1.GetName(), Hostname: server1.GetHostname(), Tunnel: server1.GetUseTunnel(), Addr: server1.GetAddr(), Labels: []ui.Label{}, SSHLogins: []string{pack.login}, }, { ClusterName: clusterName, Name: "server2", Labels: []ui.Label{{Name: "test-field", Value: "test-value"}}, Tunnel: false, SSHLogins: []string{pack.login}, }, }) // Get nodes using shortcut. re, err = pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", currentSiteShortcut, "nodes"), query) require.NoError(t, err) res2 := clusterNodesGetResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &res2)) require.Len(t, res.Items, 2) require.Equal(t, res, res2) } type clusterAlertsGetResponse struct { Alerts []types.ClusterAlert `json:"alerts"` } func TestClusterAlertsGet(t *testing.T) { t.Parallel() env := newWebPack(t, 1) // generate alert alert, err := types.NewClusterAlert( "test-alert", "test alert message", types.WithAlertSeverity(0), types.WithAlertLabel(types.AlertOnLogin, "yes"), // AlertPermitAll is necessary because the alert is only shown to // admin clients by default. types.WithAlertLabel(types.AlertPermitAll, "yes"), ) require.NoError(t, err) err = env.server.Auth().UpsertClusterAlert(context.Background(), alert) require.NoError(t, err) // get alerts. 
clusterName := env.server.ClusterName() pack := env.proxies[0].authPack(t, "test-user@example.com", nil) endpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "alerts") re, err := pack.clt.Get(context.Background(), endpoint, nil) require.NoError(t, err) alerts := clusterAlertsGetResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &alerts)) require.Len(t, alerts.Alerts, 1) } func TestSiteNodeConnectInvalidSessionID(t *testing.T) { t.Parallel() s := newWebSuite(t) _, _, err := s.makeTerminal(t, s.authPack(t, "foo"), withSessionID("/../../../foo")) require.Error(t, err) } func TestResolveServerHostPort(t *testing.T) { t.Parallel() sampleNode := types.ServerV2{} sampleNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") sampleNode.Spec.Hostname = "nodehostname" // valid cases validCases := []struct { server string nodes []types.Server expectedHost string expectedPort int }{ { server: "localhost", expectedHost: "localhost", expectedPort: 0, }, { server: "localhost:8080", expectedHost: "localhost", expectedPort: 8080, }, { server: "eca53e45-86a9-11e7-a893-0242ac0a0101", nodes: []types.Server{&sampleNode}, expectedHost: "nodehostname", expectedPort: 0, }, } // invalid cases invalidCases := []struct { server string expectedErr string }{ { server: ":22", expectedErr: "empty hostname", }, { server: ":", expectedErr: "empty hostname", }, { server: "", expectedErr: "empty server name", }, { server: "host:", expectedErr: "invalid port", }, { server: "host:port", expectedErr: "invalid port", }, } for _, testCase := range validCases { host, port, err := resolveServerHostPort(testCase.server, testCase.nodes) require.NoError(t, err, testCase.server) require.Equal(t, testCase.expectedHost, host, testCase.server) require.Equal(t, testCase.expectedPort, port, testCase.server) } for _, testCase := range invalidCases { _, _, err := resolveServerHostPort(testCase.server, nil) require.Error(t, err, testCase.server) require.Regexp(t, ".*"+testCase.expectedErr+".*", 
err.Error(), testCase.server) } } func isFileTransferRequest(e *Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } var ef events.EventFields if err := json.Unmarshal([]byte(e.GetPayload()), &ef); err != nil { return false } return ef.GetType() == string(srv.FileTransferUpdate) } func isFileTransferDecision(e *Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } var ef events.EventFields if err := json.Unmarshal([]byte(e.GetPayload()), &ef); err != nil { return false } return ef.GetType() == string(srv.FileTransferApproved) } func getRequestId(e *Envelope) (string, error) { var ef events.EventFields if err := json.Unmarshal([]byte(e.GetPayload()), &ef); err != nil { return "", err } return ef.GetString("requestID"), nil } func TestFileTransferEvents(t *testing.T) { t.Parallel() s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) errs := make(chan error, 2) readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *Envelope) { for { select { case <-ctx.Done(): return default: } typ, b, err := ws.ReadMessage() if err != nil { errs <- err return } if typ != websocket.BinaryMessage { errs <- trace.BadParameter("expected binary message, got %v", typ) return } var envelope Envelope if err := proto.Unmarshal(b, &envelope); err != nil { errs <- trace.Wrap(err) return } ch <- &envelope } } // Create a new user "foo", open a terminal to a new session pack := s.authPack(t, "foo") ws, _, err := s.makeTerminal(t, pack) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) wsMessages := make(chan *Envelope) go readLoop(ctx, ws, wsMessages) // Create file transfer event data, err := json.Marshal(events.EventFields{ "download": true, "location": "~/myfile.txt", }) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketFileTransferRequest, Payload: 
string(data), } envelopeBytes, err := proto.Marshal(envelope) require.NoError(t, err) err = ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) require.NoError(t, err) done := time.After(5 * time.Second) for { select { case <-done: require.FailNow(t, "expected to receive a file transfer event") case err := <-errs: require.NoError(t, err) case e := <-wsMessages: if isFileTransferRequest(e) { requestId, err := getRequestId(e) require.NoError(t, err) data, err := json.Marshal(events.EventFields{ "requestId": requestId, "approved": true, }) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketFileTransferDecision, Payload: string(data), } envelopeBytes, err := proto.Marshal(envelope) require.NoError(t, err) err = ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) require.NoError(t, err) } if isFileTransferDecision(e) { return } } } } func TestNewTerminalHandler(t *testing.T) { ctx := context.Background() invalidCases := []struct { expectedErr string cfg TerminalHandlerConfig }{ { expectedErr: "sid: invalid session id", cfg: TerminalHandlerConfig{ SessionData: session.Session{ ID: session.ID("not a uuid"), }, }, }, { expectedErr: "login: missing login", cfg: TerminalHandlerConfig{ SessionData: session.Session{ ID: session.NewID(), Login: "", }, }, }, { expectedErr: "server: missing server", cfg: TerminalHandlerConfig{ SessionData: session.Session{ ID: session.NewID(), Login: "root", ServerID: "", }, }, }, { expectedErr: "term: bad dimensions(-1x0)", cfg: TerminalHandlerConfig{ SessionData: session.Session{ ID: session.NewID(), Login: "root", ServerID: uuid.New().String(), }, Term: session.TerminalParams{ W: -1, H: 0, }, }, }, { expectedErr: "term: bad dimensions(1x4097)", cfg: TerminalHandlerConfig{ SessionData: session.Session{ ID: session.NewID(), Login: "root", ServerID: uuid.New().String(), }, Term: session.TerminalParams{ W: 1, H: 4097, }, }, }, } for _, testCase := range invalidCases { _, err := 
NewTerminal(ctx, testCase.cfg) require.Equal(t, err.Error(), testCase.expectedErr) } validNode := types.ServerV2{} validNode.SetName("eca53e45-86a9-11e7-a893-0242ac0a0101") validNode.Spec.Hostname = "nodehostname" // Valid Case validCfg := TerminalHandlerConfig{ Term: session.TerminalParams{ W: 100, H: 100, }, SessionCtx: &SessionContext{}, AuthProvider: authProviderMock{ server: validNode, }, LocalAuthProvider: authProviderMock{}, SessionData: session.Session{ ID: session.NewID(), Login: "root", ServerID: uuid.New().String(), }, KeepAliveInterval: time.Duration(100), ProxyHostPort: "1234", InteractiveCommand: make([]string, 1), DisplayLogin: "tree", Router: &proxy.Router{}, } term, err := NewTerminal(ctx, validCfg) require.NoError(t, err) // passed through require.Equal(t, validCfg.SessionCtx, term.ctx) require.Equal(t, validCfg.AuthProvider, term.authProvider) require.Equal(t, validCfg.SessionData, term.sessionData) require.Equal(t, validCfg.KeepAliveInterval, term.keepAliveInterval) require.Equal(t, validCfg.ProxyHostPort, term.proxyHostPort) require.Equal(t, validCfg.InteractiveCommand, term.interactiveCommand) require.Equal(t, validCfg.Term, term.term) require.Equal(t, validCfg.DisplayLogin, term.displayLogin) // newly added require.NotNil(t, term.log) } func TestUIConfig(t *testing.T) { uiConfig := webclient.UIConfig{ ScrollbackLines: 555, } t.Parallel() ctx, cancel := context.WithCancel(context.Background()) s := newWebSuiteWithConfig(t, webSuiteConfig{uiConfig: uiConfig}) clt := s.client(t) endpoint := clt.Endpoint("web", "config.js") re, err := clt.Get(ctx, endpoint, nil) require.NoError(t, err) require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) t.Cleanup(cancel) // Response is type application/javascript, we need to strip off the variable name // and the semicolon at the end, then we are left with json like object. 
var cfg webclient.WebConfig str := strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) require.NoError(t, err) require.Equal(t, uiConfig, cfg.UI) } func TestResizeTerminal(t *testing.T) { t.Parallel() s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) sid := session.NewID() errs := make(chan error, 2) readLoop := func(ctx context.Context, ws *websocket.Conn, ch chan<- *Envelope) { for { select { case <-ctx.Done(): return default: } typ, b, err := ws.ReadMessage() if err != nil { errs <- err return } if typ != websocket.BinaryMessage { errs <- trace.BadParameter("expected binary message, got %v", typ) return } var envelope Envelope if err := proto.Unmarshal(b, &envelope); err != nil { errs <- trace.Wrap(err) return } ch <- &envelope } } // Create a new user "foo", open a terminal to a new session pack1 := s.authPack(t, "foo") ws1, sess, err := s.makeTerminal(t, pack1) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws1.Close()) }) // Create a new user "bar", open a terminal to the session created above pack2 := s.authPack(t, "bar") ws2, sess2, err := s.makeTerminal(t, pack2, withSessionID(sess.ID), withParticipantMode(types.SessionPeerMode)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws2.Close()) }) require.Equal(t, sess.ID, sess2.ID) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) ws1Messages := make(chan *Envelope) ws2Messages := make(chan *Envelope) go readLoop(ctx, ws1, ws1Messages) go readLoop(ctx, ws2, ws2Messages) // consume events from the first terminal // we exect to see at least one raw event with PTY data (indicating terminal ready) // and 2 resize events from the second user joining the session (one for the default // size, and one for the manual resize request) done := time.After(10 * time.Second) t1ResizeEvents, t1RawEvents := 0, 0 t1ready: for { select { case <-done: require.FailNowf(t, "", "expected 
to receive 2 resize events (got %d) and at least 1 raw event (got %d)", t1ResizeEvents, t1RawEvents) case err := <-errs: require.NoError(t, err) case e := <-ws1Messages: if isResizeEventEnvelope(e) { t1ResizeEvents++ } if e.GetType() == defaults.WebsocketRaw { t1RawEvents++ } if t1ResizeEvents == 2 && t1RawEvents > 0 { break t1ready } } } // we should not expect to see a resize event on terminal 2, // since they are not broadcasted back to the originator select { case e := <-ws2Messages: if isResizeEventEnvelope(e) { require.FailNow(t, "terminal 2 should not have received a resize event") } case err := <-errs: require.NoError(t, err) case <-time.After(1 * time.Second): } // Resize the second terminal. This should be reflected only on the first terminal // because resize events are sent to participants but not the originator.. params, err := session.NewTerminalParamsFromInt(300, 120) require.NoError(t, err) data, err := json.Marshal(events.EventFields{ events.EventType: events.ResizeEvent, events.EventNamespace: apidefaults.Namespace, events.SessionEventID: sid.String(), events.TerminalSize: params.Serialize(), }) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketResize, Payload: string(data), } envelopeBytes, err := proto.Marshal(envelope) require.NoError(t, err) err = ws2.WriteMessage(websocket.BinaryMessage, envelopeBytes) require.NoError(t, err) // the first terminal should see the resize event done = time.After(5 * time.Second) for { select { case <-done: require.FailNow(t, "expected to receive a final resize event") case err := <-errs: require.NoError(t, err) case e := <-ws1Messages: if isResizeEventEnvelope(e) { return } } } } func isResizeEventEnvelope(e *Envelope) bool { if e.GetType() != defaults.WebsocketAudit { return false } var ef events.EventFields if err := json.Unmarshal([]byte(e.GetPayload()), &ef); err != nil { return false } return ef.GetType() == events.ResizeEvent } // TestTerminalPing 
tests that the server sends continuous ping control messages. func TestTerminalPing(t *testing.T) { t.Parallel() s := newWebSuite(t) ws, _, err := s.makeTerminal(t, s.authPack(t, "foo"), withKeepaliveInterval(time.Second)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) closed := false done := make(chan struct{}) ws.SetPingHandler(func(message string) error { if closed == false { close(done) closed = true } err := ws.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second)) if err == websocket.ErrCloseSent { return nil } else if e, ok := err.(net.Error); ok && e.Timeout() { return nil } return err }) // We need to continuously read incoming messages in order to process ping messages. // We only care about receiving a ping here so dropping them is fine. go func() { for { _, _, err := ws.ReadMessage() if err != nil { return } } }() select { case <-done: case <-time.After(6 * time.Second): t.Fatal("timeout waiting for ping") } } func TestTerminal(t *testing.T) { t.Parallel() cases := []struct { name string recordingConfig types.SessionRecordingConfigV2 }{ { name: "node recording mode", recordingConfig: types.SessionRecordingConfigV2{ Spec: types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtNodeSync, }, }, }, { name: "proxy recording mode", recordingConfig: types.SessionRecordingConfigV2{ Spec: types.SessionRecordingConfigSpecV2{ Mode: types.RecordAtProxySync, }, }, }, } for _, tt := range cases { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() s := newWebSuite(t) require.NoError(t, s.server.Auth().SetSessionRecordingConfig(context.Background(), &tt.recordingConfig)) ws, _, err := s.makeTerminal(t, s.authPack(t, "foo")) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) validateTerminalStream(t, ws) }) } } func TestTerminalRouting(t *testing.T) { t.Parallel() s := newWebSuite(t) // add nodes with various conflicting values llama := s.addNode(t, uuid.NewString(), "llama", 
"127.0.0.1:0") s.addNode(t, uuid.NewString(), "llamas", "127.0.0.1:0") alpaca1 := s.addNode(t, uuid.NewString(), "alpaca", "127.0.0.1:0") s.addNode(t, uuid.NewString(), "alpaca", "127.0.0.1:0") closeNoError := func(t *testing.T, err error) { require.NoError(t, err) } closeOkNetworkError := func(t *testing.T, err error) { if err == nil { return } require.True(t, utils.IsOKNetworkError(err), "websocket closure should have return an error indicating that the server already terminated the connection") } cases := []struct { name string target string output string wsCloseAssertion func(t *testing.T, err error) }{ { name: "exact match by uuid", target: llama.ID(), output: "teleport", wsCloseAssertion: closeNoError, }, { name: "exact match by hostname", target: "llama", output: "teleport", wsCloseAssertion: closeNoError, }, { name: "exact match by ip", target: llama.Addr(), output: "teleport", wsCloseAssertion: closeNoError, }, { name: "ambiguous host", target: "alpaca", output: "error: ambiguous host could match multiple nodes", // failed resolution results in the server closing the socket first, so expect an ok close error wsCloseAssertion: closeOkNetworkError, }, { name: "connect by uuid successful when multiple hostnames match", target: alpaca1.ID(), output: "teleport", wsCloseAssertion: closeNoError, }, { name: "ambiguous ip", target: "127.0.0.1", output: "error: ambiguous host could match multiple nodes", // failed resolution results in the server closing the socket first, so expect an ok close error wsCloseAssertion: closeOkNetworkError, }, } for i, tt := range cases { i, tt := i, tt t.Run(tt.name, func(t *testing.T) { t.Parallel() ws, _, err := s.makeTerminal(t, s.authPack(t, fmt.Sprintf("foo-%d", i)), withServer(tt.target)) require.NoError(t, err) t.Cleanup(func() { tt.wsCloseAssertion(t, ws.Close()) }) stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) // here we intentionally run a command where the output we're looking // for is not present in 
the command itself _, err = io.WriteString(stream, "echo txlxport | sed 's/x/e/g'\r\n") require.NoError(t, err) require.NoError(t, waitForOutput(stream, tt.output)) }) } } func TestTerminalNameResolution(t *testing.T) { t.Parallel() s := newWebSuite(t) pack := s.authPack(t, "foo") llama := s.addNode(t, uuid.NewString(), "llama", "127.0.0.1:0") ctx, cancel := context.WithTimeout(context.Background(), 7*time.Second) t.Cleanup(cancel) // Wait for the node to be registered as the registration is asynchronous. require.Eventuallyf(t, func() bool { nodes, err := s.proxyClient.GetNodes(ctx, "default") require.NoError(t, err) return len(nodes) == 2 // one created by default and llama }, 5*time.Second, 200*time.Millisecond, "failed to register node") tests := []struct { name string target string serverID string serverHostname string port int }{ { name: "registered node by name", target: "llama", serverID: llama.ID(), serverHostname: "llama", }, { name: "registered node by address", target: llama.Addr(), serverID: llama.ID(), serverHostname: llama.Addr(), }, { name: "direct dial", target: "root@example.com", serverID: "root@example.com", serverHostname: "root@example.com", }, { name: "direct dial with port", target: "root@example.com:1234", serverID: "root@example.com", serverHostname: "root@example.com", port: 1234, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() ws, resp, err := s.makeTerminal(t, pack, withServer(tt.target)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) require.Equal(t, tt.serverID, resp.ServerID) require.Equal(t, tt.serverHostname, resp.ServerHostname) require.Equal(t, tt.port, resp.ServerHostPort) }) } } func TestTerminalRequireSessionMFA(t *testing.T) { ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "llama", nil /* roles */) clt, err := env.server.NewClient(auth.TestUser("llama")) require.NoError(t, err) cases := []struct { name 
string getAuthPreference func() types.AuthPreference registerDevice func() *auth.TestDevice getChallengeResponseBytes func(chals *client.MFAAuthenticateChallenge, dev *auth.TestDevice) []byte }{ { name: "with webauthn", getAuthPreference: func() types.AuthPreference { ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorWebauthn, Webauthn: &types.Webauthn{ RPID: "localhost", }, RequireMFAType: types.RequireMFAType_SESSION, }) require.NoError(t, err) return ap }, registerDevice: func() *auth.TestDevice { webauthnDev, err := auth.RegisterTestDevice(ctx, clt, "webauthn", authproto.DeviceType_DEVICE_TYPE_WEBAUTHN, nil /* authenticator */) require.NoError(t, err) return webauthnDev }, getChallengeResponseBytes: func(chals *client.MFAAuthenticateChallenge, dev *auth.TestDevice) []byte { res, err := dev.SolveAuthn(&authproto.MFAAuthenticateChallenge{ WebauthnChallenge: wanlib.CredentialAssertionToProto(chals.WebauthnChallenge), }) require.NoError(t, err) webauthnResBytes, err := json.Marshal(wanlib.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), } protoBytes, err := proto.Marshal(envelope) require.NoError(t, err) return protoBytes }, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err = env.server.Auth().SetAuthPreference(ctx, tc.getAuthPreference()) require.NoError(t, err) dev := tc.registerDevice() // Open a terminal to a new session. ws, _ := proxy.makeTerminal(t, pack, "") // Wait for websocket authn challenge event. ty, raw, err := ws.ReadMessage() require.Nil(t, err) require.Equal(t, websocket.BinaryMessage, ty) var env Envelope require.Nil(t, proto.Unmarshal(raw, &env)) chals := &client.MFAAuthenticateChallenge{} require.Nil(t, json.Unmarshal([]byte(env.Payload), &chals)) // Send response over ws. 
stream := NewTerminalStream(ctx, ws, utils.NewLoggerForTests()) err = stream.ws.WriteMessage(websocket.BinaryMessage, tc.getChallengeResponseBytes(chals, dev)) require.Nil(t, err) // Test we can write. _, err = io.WriteString(stream, "echo txlxport | sed 's/x/e/g'\r\n") require.Nil(t, err) require.Nil(t, waitForOutput(stream, "teleport")) }) } } type windowsDesktopServiceMock struct { listener net.Listener } func mustStartWindowsDesktopMock(t *testing.T, authClient *auth.Server) *windowsDesktopServiceMock { l, err := net.Listen("tcp", "localhost:0") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, l.Close()) }) authID := auth.IdentityID{ Role: types.RoleWindowsDesktop, HostUUID: "windows_server", NodeName: "windows_server", } n, err := authClient.GetClusterName() require.NoError(t, err) dns := []string{"localhost", "127.0.0.1", desktop.WildcardServiceDNS} identity, err := auth.LocalRegister(authID, authClient, nil, dns, "", nil) require.NoError(t, err) tlsConfig, err := identity.TLSConfig(nil) require.NoError(t, err) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert require.NoError(t, err) ca, err := authClient.GetCertAuthority(context.Background(), types.CertAuthID{Type: types.UserCA, DomainName: n.GetClusterName()}, false) require.NoError(t, err) for _, kp := range services.GetTLSCerts(ca) { require.True(t, tlsConfig.ClientCAs.AppendCertsFromPEM(kp)) } wd := &windowsDesktopServiceMock{ listener: l, } go func() { conn, err := l.Accept() if err != nil { return } tlsConn := tls.Server(conn, tlsConfig) if err := tlsConn.Handshake(); err != nil { t.Errorf("Unexpected error %v", err) return } wd.handleConn(t, tlsConn) }() return wd } func (w *windowsDesktopServiceMock) handleConn(t *testing.T, conn *tls.Conn) { tdpConn := tdp.NewConn(conn) // Ensure that incoming connection is MFAVerified. 
require.NotEmpty(t, conn.ConnectionState().PeerCertificates) cert := conn.ConnectionState().PeerCertificates[0] identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) require.NoError(t, err) require.NotEmpty(t, identity.MFAVerified) msg, err := tdpConn.ReadMessage() require.NoError(t, err) require.IsType(t, tdp.ClientUsername{}, msg) msg, err = tdpConn.ReadMessage() require.NoError(t, err) require.IsType(t, tdp.ClientScreenSpec{}, msg) err = tdpConn.WriteMessage(tdp.Notification{Message: "test", Severity: tdp.SeverityWarning}) require.NoError(t, err) } func TestDesktopAccessMFARequiresMfa(t *testing.T) { tests := []struct { name string authPref types.AuthPreferenceSpecV2 mfaHandler func(t *testing.T, ws *websocket.Conn, dev *auth.TestDevice) registerDevice func(t *testing.T, ctx context.Context, clt *auth.Client) *auth.TestDevice }{ { name: "webauthn", authPref: types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorWebauthn, Webauthn: &types.Webauthn{ RPID: "localhost", }, RequireMFAType: types.RequireMFAType_SESSION, }, mfaHandler: handleDesktopMFAWebauthnChallenge, registerDevice: func(t *testing.T, ctx context.Context, clt *auth.Client) *auth.TestDevice { webauthnDev, err := auth.RegisterTestDevice(ctx, clt, "webauthn", authproto.DeviceType_DEVICE_TYPE_WEBAUTHN, nil /* authenticator */) require.NoError(t, err) return webauthnDev }, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "llama", nil /* roles */) clt, err := env.server.NewClient(auth.TestUser("llama")) require.NoError(t, err) wdID := uuid.New().String() wdMock := mustStartWindowsDesktopMock(t, env.server.Auth()) wd, err := types.NewWindowsDesktopV3("desktop1", nil, types.WindowsDesktopSpecV3{ Addr: wdMock.listener.Addr().String(), Domain: "CORP", HostID: wdID, }) require.NoError(t, err) err = 
env.server.Auth().UpsertWindowsDesktop(context.Background(), wd) require.NoError(t, err) wds, err := types.NewWindowsDesktopServiceV3(types.Metadata{Name: wdID}, types.WindowsDesktopServiceSpecV3{ Addr: wdMock.listener.Addr().String(), TeleportVersion: teleport.Version, }) require.NoError(t, err) _, err = env.server.Auth().UpsertWindowsDesktopService(context.Background(), wds) require.NoError(t, err) ap, err := types.NewAuthPreference(tc.authPref) require.NoError(t, err) err = env.server.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) dev := tc.registerDevice(t, ctx, clt) ws := proxy.makeDesktopSession(t, pack, session.NewID(), env.server.TLS.Listener.Addr()) tc.mfaHandler(t, ws, dev) tdpClient := tdp.NewConn(&WebsocketIO{Conn: ws}) msg, err := tdpClient.ReadMessage() require.NoError(t, err) require.IsType(t, tdp.Notification{}, msg) }) } } func handleDesktopMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.TestDevice) { br := bufio.NewReader(&WebsocketIO{Conn: ws}) mt, err := br.ReadByte() require.NoError(t, err) require.Equal(t, tdp.TypeMFA, tdp.MessageType(mt)) mfaChallange, err := tdp.DecodeMFAChallenge(br) require.NoError(t, err) res, err := dev.SolveAuthn(&authproto.MFAAuthenticateChallenge{ WebauthnChallenge: wanlib.CredentialAssertionToProto(mfaChallange.WebauthnChallenge), }) require.NoError(t, err) err = tdp.NewConn(&WebsocketIO{Conn: ws}).WriteMessage(tdp.MFA{ Type: defaults.WebsocketWebauthnChallenge[0], MFAAuthenticateResponse: &authproto.MFAAuthenticateResponse{ Response: &authproto.MFAAuthenticateResponse_Webauthn{ Webauthn: res.GetWebauthn(), }, }, }) require.NoError(t, err) } func TestWebAgentForward(t *testing.T) { t.Parallel() s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) ws, _, err := s.makeTerminal(t, s.authPack(t, "foo")) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) _, err = 
io.WriteString(stream, "echo $SSH_AUTH_SOCK\r\n") require.NoError(t, err) err = waitForOutput(stream, "/") require.NoError(t, err) } func TestActiveSessions(t *testing.T) { // Use enterprise license (required for moderated sessions). modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) s := newWebSuite(t) pack := s.authPack(t, "foo") start := time.Now() kinds := []types.SessionKind{ types.SSHSessionKind, types.KubernetesSessionKind, types.WindowsDesktopSessionKind, types.DatabaseSessionKind, types.AppSessionKind, } ids := make(map[string]struct{}) for _, kind := range kinds { tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ SessionID: string(session.NewID()), ClusterName: s.server.ClusterName(), Kind: string(kind), State: types.SessionState_SessionStateRunning, Created: start, Expires: start.Add(1 * time.Hour), Hostname: s.node.GetInfo().GetHostname(), DesktopName: s.node.GetInfo().GetHostname(), AppName: s.node.GetInfo().GetHostname(), DatabaseName: s.node.GetInfo().GetHostname(), Address: s.srvID, Login: pack.login, Participants: []types.Participant{ {ID: "id", User: "user-1", LastActive: start}, }, HostPolicies: []*types.SessionTrackerPolicySet{ { Name: "foo", Version: "5", RequireSessionJoin: []*types.SessionRequirePolicy{ { Name: "foo", }, }, }, }, }) require.NoError(t, err) ids[tracker.GetSessionID()] = struct{}{} _, err = s.server.Auth().CreateSessionTracker(context.Background(), tracker) require.NoError(t, err) } // create an inactive session, which should not show up inactive, err := types.NewSessionTracker(types.SessionTrackerSpecV1{ SessionID: string(session.NewID()), ClusterName: s.server.ClusterName(), Kind: string(types.SSHSessionKind), State: types.SessionState_SessionStateTerminated, Created: time.Now(), Expires: time.Now().Add(1 * time.Hour), Hostname: s.node.GetInfo().GetHostname(), Address: s.srvID, Login: pack.login, Participants: nil, }) require.NoError(t, err) _, err = 
s.server.Auth().CreateSessionTracker(context.Background(), inactive) require.NoError(t, err) re, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) require.NoError(t, err) var sessResp siteSessionsGetResponse require.NoError(t, json.Unmarshal(re.Bytes(), &sessResp)) require.Len(t, sessResp.Sessions, len(kinds)) for _, session := range sessResp.Sessions { require.Contains(t, ids, string(session.ID)) require.Equal(t, s.node.GetNamespace(), session.Namespace) require.NotNil(t, session.Parties) require.Greater(t, session.TerminalParams.H, 0) require.Greater(t, session.TerminalParams.W, 0) require.Equal(t, pack.login, session.Login) require.False(t, session.Created.IsZero()) require.False(t, session.LastActive.IsZero()) require.Equal(t, s.srvID, session.ServerID) require.Equal(t, s.node.GetInfo().GetHostname(), session.ServerHostname) require.Equal(t, s.srvID, session.ServerAddr) require.Equal(t, s.server.ClusterName(), session.ClusterName) require.ElementsMatch(t, []types.SessionParticipantMode{"peer"}, session.ParticipantModes) } } func TestCloseConnectionsOnLogout(t *testing.T) { t.Parallel() s := newWebSuite(t) pack := s.authPack(t, "foo") ws, _, err := s.makeTerminal(t, pack) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) stream := NewTerminalStream(s.ctx, ws, utils.NewLoggerForTests()) // to make sure we have a session _, err = io.WriteString(stream, "expr 137 + 39\r\n") require.NoError(t, err) // make sure the server has replied out := make([]byte, 100) _, err = stream.Read(out) require.NoError(t, err) _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // wait until timeout or detect that the connection has been closed. 
after := time.After(5 * time.Second) errC := make(chan error) go func() { for { _, err := stream.Read(out) if err != nil { errC <- err return } } }() select { case <-after: t.Fatalf("timeout") case err := <-errC: require.ErrorIs(t, err, io.EOF) } } func TestPlayback(t *testing.T) { t.Parallel() s := newWebSuite(t) pack := s.authPack(t, "foo") ws, _, err := s.makeTerminal(t, pack) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) } type httpErrorMessage struct { Message string `json:"message"` } type httpErrorResponse struct { Error httpErrorMessage `json:"error"` } func TestLogin_PrivateKeyEnabledError(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{ MockAttestHardwareKey: func(_ context.Context, _ interface{}, policy keys.PrivateKeyPolicy, _ *keys.AttestationStatement, _ crypto.PublicKey, _ time.Duration) (keys.PrivateKeyPolicy, error) { return "", keys.NewPrivateKeyPolicyError(policy) }, }) s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, RequireMFAType: types.RequireMFAType_HARDWARE_KEY_TOUCH, }) require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) require.NoError(t, err) // create user s.createUser(t, "user1", "root", "password", "") loginReq, err := json.Marshal(CreateSessionReq{ User: "user1", Pass: "password", }) require.NoError(t, err) clt := s.client(t) req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(loginReq)) require.NoError(t, err) ua := "test-ua" req.Header.Set("User-Agent", ua) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) require.NoError(t, err) var resErr httpErrorResponse 
require.NoError(t, json.Unmarshal(re.Bytes(), &resErr)) require.Contains(t, resErr.Error.Message, keys.PrivateKeyPolicyHardwareKeyTouch) } func TestLogin(t *testing.T) { t.Parallel() s := newWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, }) require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, ap) require.NoError(t, err) // create user s.createUser(t, "user1", "root", "password", "") loginReq, err := json.Marshal(CreateSessionReq{ User: "user1", Pass: "password", }) require.NoError(t, err) clt := s.client(t) ua := "test-ua" req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(loginReq)) require.NoError(t, err) req.Header.Set("User-Agent", ua) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" addCSRFCookieToReq(req, csrfToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) require.NoError(t, err) ctx := context.Background() events, _, err := s.server.AuthServer.AuditLog.SearchEvents(ctx, events.SearchEventsRequest{ From: s.clock.Now().Add(-time.Hour), To: s.clock.Now().Add(time.Hour), EventTypes: []string{events.UserLoginEvent}, Limit: 1, Order: types.EventOrderDescending, }) require.NoError(t, err) event := events[0].(*apievents.UserLogin) require.Equal(t, true, event.Success) require.Equal(t, ua, event.UserAgent) require.True(t, strings.HasPrefix(event.RemoteAddr, "127.0.0.1:")) var rawSess *CreateSessionResponse require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) cookies := re.Cookies() require.Len(t, cookies, 1) require.NotEmpty(t, rawSess.SessionExpires) // now make sure we are logged in by calling authenticated method // we need to supply both session cookie and bearer token for // request to succeed jar, err := 
cookiejar.New(nil) require.NoError(t, err) clt = s.client(t, roundtrip.BearerAuth(rawSess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) re, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.NoError(t, err) var clusters []ui.Cluster require.NoError(t, json.Unmarshal(re.Bytes(), &clusters)) // in absence of session cookie or bearer auth the same request fill fail // no session cookie: clt = s.client(t, roundtrip.BearerAuth(rawSess.Token)) _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) // no bearer token: clt = s.client(t, roundtrip.CookieJar(jar)) _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) } // TestEmptyMotD ensures that responses returned by both /webapi/ping and // /webapi/motd work when no MotD is set func TestEmptyMotD(t *testing.T) { t.Parallel() s := newWebSuite(t) wc := s.client(t) // Given an auth server configured *not* to expose a Message Of The // Day... // When I issue a ping request... re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) // Expect that the MotD flag in the ping response is *not* set var pingResponse *webclient.PingResponse require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) require.False(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) require.NoError(t, err) // Expect that an empty response returned var motdResponse *webclient.MotD require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) require.Empty(t, motdResponse.Text) } // TestMotD ensures that a response is returned by both /webapi/ping and /webapi/motd // and that that the response bodies contain their MOTD components func TestMotD(t *testing.T) { t.Parallel() const motd = "Hello. I'm a Teleport cluster!" 
s := newWebSuite(t) wc := s.client(t) // Given an auth server configured to expose a Message Of The Day... prefs := types.DefaultAuthPreference() prefs.SetMessageOfTheDay(motd) require.NoError(t, s.server.AuthServer.AuthServer.SetAuthPreference(s.ctx, prefs)) // When I issue a ping request... re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) // Expect that the MotD flag in the ping response is set to indicate // a MotD var pingResponse *webclient.PingResponse require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) require.True(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) require.NoError(t, err) // Expect that the text returned is the configured value var motdResponse *webclient.MotD require.NoError(t, json.Unmarshal(re.Bytes(), &motdResponse)) require.Equal(t, motd, motdResponse.Text) } // TestPingAutomaticUpgrades ensures /webapi/ping returns whether AutomaticUpgrades are enabled. 
func TestPingAutomaticUpgrades(t *testing.T) { t.Run("Automatic Upgrades are enabled", func(t *testing.T) { // Enable Automatic Upgrades modules.SetTestModules(t, &modules.TestModules{TestFeatures: modules.Features{ AutomaticUpgrades: true, }}) // Set up s := newWebSuite(t) wc := s.client(t) var pingResponse *webclient.PingResponse // Get Ping response re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) require.True(t, pingResponse.AutomaticUpgrades, "expected automatic upgrades to be enabled") }) t.Run("Automatic Upgrades are disabled", func(t *testing.T) { // Disable Automatic Upgrades modules.SetTestModules(t, &modules.TestModules{TestFeatures: modules.Features{ AutomaticUpgrades: false, }}) // Set up s := newWebSuite(t) wc := s.client(t) var pingResponse *webclient.PingResponse // Get Ping response re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) require.NoError(t, json.Unmarshal(re.Bytes(), &pingResponse)) require.False(t, pingResponse.AutomaticUpgrades, "expected automatic upgrades to be disabled") }) } // TestInstallerRepoChannel ensures the returned installer script has the proper repo channel func TestInstallerRepoChannel(t *testing.T) { t.Run("cloud with automatic upgrades", func(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{ TestFeatures: modules.Features{ Cloud: true, AutomaticUpgrades: true, }, }) s := newWebSuiteWithConfig(t, webSuiteConfig{ authPreferenceSpec: &types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOn, Webauthn: &types.Webauthn{RPID: "localhost"}, }, }) wc := s.client(t) t.Run("documented variables are injected", func(t *testing.T) { // Variables documented here: https://goteleport.com/docs/server-access/guides/ec2-discovery/#step-67-optional-customize-the-default-installer-script err := s.server.Auth().SetInstaller(s.ctx, 
types.MustNewInstallerV1("custom", `#!/usr/bin/env bash echo {{ .PublicProxyAddr }} echo Teleport-{{ .MajorVersion }} echo Repository Channel: {{ .RepoChannel }} echo AutomaticUpgrades: {{ .AutomaticUpgrades }} `)) require.NoError(t, err) re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "custom"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // Variables must be injected require.Contains(t, responseString, "echo Teleport-v") require.NotContains(t, responseString, "echo Repository Channel: stable/v") require.Contains(t, responseString, "echo Repository Channel: stable/cloud") require.Contains(t, responseString, "echo AutomaticUpgrades: true") }) t.Run("default-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "default-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // The repo's channel to use is stable/cloud require.Contains(t, responseString, "stable/cloud") require.NotContains(t, responseString, "stable/v") require.Contains(t, responseString, ""+ " PACKAGE_LIST=\"teleport-ent jq\"\n"+ " # shellcheck disable=SC2050\n"+ " if [ \"true\" = \"true\" ]; then\n"+ " PACKAGE_LIST=\"${PACKAGE_LIST} teleport-ent-updater\"\n"+ " fi", ) }) t.Run("default-agentless-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "default-agentless-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // The repo's channel to use is stable/cloud require.Contains(t, responseString, "stable/cloud") require.NotContains(t, responseString, "stable/v") require.Contains(t, responseString, ""+ " PACKAGE_LIST=\"jq teleport-ent\"\n"+ " # shellcheck disable=SC2050\n"+ " if [[ \"true\" == \"true\" ]]; then\n"+ " PACKAGE_LIST=\"${PACKAGE_LIST} teleport-ent-updater\"\n"+ " fi\n", ) }) }) t.Run("cloud without automatic upgrades", func(t *testing.T) { modules.SetTestModules(t, 
&modules.TestModules{ TestFeatures: modules.Features{ Cloud: true, AutomaticUpgrades: false, }, }) s := newWebSuiteWithConfig(t, webSuiteConfig{ authPreferenceSpec: &types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOn, Webauthn: &types.Webauthn{RPID: "localhost"}, }, }) wc := s.client(t) t.Run("documented variables are injected", func(t *testing.T) { // Variables documented here: https://goteleport.com/docs/server-access/guides/ec2-discovery/#step-67-optional-customize-the-default-installer-script err := s.server.Auth().SetInstaller(s.ctx, types.MustNewInstallerV1("custom", `#!/usr/bin/env bash echo {{ .PublicProxyAddr }} echo Teleport-{{ .MajorVersion }} echo Repository Channel: {{ .RepoChannel }} echo AutomaticUpgrades: {{ .AutomaticUpgrades }} `)) require.NoError(t, err) re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "custom"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // Variables must be injected require.Contains(t, responseString, "echo Teleport-v") require.Contains(t, responseString, "echo Repository Channel: stable/v") require.NotContains(t, responseString, "echo Repository Channel: stable/cloud") require.Contains(t, responseString, "echo AutomaticUpgrades: false") }) t.Run("default-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "default-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) require.NotContains(t, responseString, "stable/cloud") }) t.Run("default-agentless-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "default-agentless-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) require.NotContains(t, responseString, "stable/cloud") }) }) t.Run("oss or enterprise with automatic upgrades", func(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{ TestBuildType: 
modules.BuildOSS, TestFeatures: modules.Features{ Cloud: false, AutomaticUpgrades: true, }, }) s := newWebSuiteWithConfig(t, webSuiteConfig{ authPreferenceSpec: &types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOn, Webauthn: &types.Webauthn{RPID: "localhost"}, }, }) wc := s.client(t) t.Run("documented variables are injected", func(t *testing.T) { // Variables documented here: https://goteleport.com/docs/server-access/guides/ec2-discovery/#step-67-optional-customize-the-default-installer-script err := s.server.Auth().SetInstaller(s.ctx, types.MustNewInstallerV1("custom", `#!/usr/bin/env bash echo {{ .PublicProxyAddr }} echo Teleport-{{ .MajorVersion }} echo Repository Channel: {{ .RepoChannel }} echo AutomaticUpgrades: {{ .AutomaticUpgrades }} `)) require.NoError(t, err) re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "custom"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // Variables must be injected require.Contains(t, responseString, "echo Teleport-v") require.Contains(t, responseString, "echo Repository Channel: stable/v") require.NotContains(t, responseString, "echo Repository Channel: stable/cloud") require.Contains(t, responseString, "echo AutomaticUpgrades: false") }) t.Run("default-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", "default-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // The repo's channel to use is stable/cloud require.NotContains(t, responseString, "stable/cloud") require.Contains(t, responseString, "stable/v") require.Contains(t, responseString, ""+ " PACKAGE_LIST=\"teleport jq\"\n"+ " # shellcheck disable=SC2050\n"+ " if [ \"false\" = \"true\" ]; then\n"+ " PACKAGE_LIST=\"${PACKAGE_LIST} teleport-updater\"\n"+ " fi", ) }) t.Run("default-agentless-installer", func(t *testing.T) { re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "scripts", "installer", 
"default-agentless-installer"), url.Values{}) require.NoError(t, err) responseString := string(re.Bytes()) // The repo's channel to use is stable/cloud require.NotContains(t, responseString, "stable/cloud") require.Contains(t, responseString, "stable/v") require.Contains(t, responseString, ""+ " PACKAGE_LIST=\"jq teleport\"\n"+ " # shellcheck disable=SC2050\n"+ " if [[ \"false\" == \"true\" ]]; then\n"+ " PACKAGE_LIST=\"${PACKAGE_LIST} teleport-updater\"\n"+ " fi\n", ) }) }) } func TestMultipleConnectors(t *testing.T) { t.Parallel() s := newWebSuite(t) wc := s.client(t) // create two oidc connectors, one named "foo" and another named "bar" oidcConnectorSpec := types.OIDCConnectorSpecV3{ RedirectURLs: []string{"https://localhost:3080/v1/webapi/oidc/callback"}, ClientID: "000000000000-aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.example.com", ClientSecret: "AAAAAAAAAAAAAAAAAAAAAAAA", IssuerURL: "https://oidc.example.com", Display: "Login with Example", Scope: []string{"group"}, ClaimsToRoles: []types.ClaimMapping{ { Claim: "group", Value: "admin", Roles: []string{"admin"}, }, }, } o, err := types.NewOIDCConnector("foo", oidcConnectorSpec) require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o) require.NoError(t, err) o2, err := types.NewOIDCConnector("bar", oidcConnectorSpec) require.NoError(t, err) err = s.server.Auth().UpsertOIDCConnector(s.ctx, o2) require.NoError(t, err) // set the auth preferences to oidc with no connector name authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", }) require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) require.NoError(t, err) // hit the ping endpoint to get the auth type and connector name re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) var out *webclient.PingResponse require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector name we got back was the first connector // in the backend, in 
this case it's "bar" oidcConnectors, err := s.server.Auth().GetOIDCConnectors(s.ctx, false) require.NoError(t, err) require.Equal(t, oidcConnectors[0].GetName(), out.Auth.OIDC.Name) // update the auth preferences and this time specify the connector name authPreference, err = types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: "oidc", ConnectorName: "foo", }) require.NoError(t, err) err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) require.NoError(t, err) // hit the ping endpoing to get the auth type and connector name re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector we get back is "foo" require.Equal(t, "foo", out.Auth.OIDC.Name) } // TestConstructSSHResponse checks if the secret package uses AES-GCM to // encrypt and decrypt data that passes through the ConstructSSHResponse // function. func TestConstructSSHResponse(t *testing.T) { key, err := secret.NewKey() require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") require.NoError(t, err) query := u.Query() query.Set("secret_key", key.String()) u.RawQuery = query.Encode() rawresp, err := ConstructSSHResponse(AuthParams{ Username: "foo", Cert: []byte{0x00}, TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) require.NoError(t, err) require.Empty(t, rawresp.Query().Get("secret")) require.Empty(t, rawresp.Query().Get("secret_key")) require.NotEmpty(t, rawresp.Query().Get("response")) plaintext, err := key.Open([]byte(rawresp.Query().Get("response"))) require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) require.NoError(t, err) require.Equal(t, "foo", resp.Username) require.EqualValues(t, []byte{0x00}, resp.Cert) require.EqualValues(t, []byte{0x01}, resp.TLSCert) } // TestConstructSSHResponseLegacy checks if the secret package uses NaCl to // encrypt and decrypt data that passes through the ConstructSSHResponse 
// function. func TestConstructSSHResponseLegacy(t *testing.T) { key, err := lemma_secret.NewKey() require.NoError(t, err) lemma, err := lemma_secret.New(&lemma_secret.Config{KeyBytes: key}) require.NoError(t, err) u, err := url.Parse("http://www.example.com/callback") require.NoError(t, err) query := u.Query() query.Set("secret", lemma_secret.KeyToEncodedString(key)) u.RawQuery = query.Encode() rawresp, err := ConstructSSHResponse(AuthParams{ Username: "foo", Cert: []byte{0x00}, TLSCert: []byte{0x01}, ClientRedirectURL: u.String(), }) require.NoError(t, err) require.Empty(t, rawresp.Query().Get("secret")) require.Empty(t, rawresp.Query().Get("secret_key")) require.NotEmpty(t, rawresp.Query().Get("response")) var sealedData *lemma_secret.SealedBytes err = json.Unmarshal([]byte(rawresp.Query().Get("response")), &sealedData) require.NoError(t, err) plaintext, err := lemma.Open(sealedData) require.NoError(t, err) var resp *auth.SSHLoginResponse err = json.Unmarshal(plaintext, &resp) require.NoError(t, err) require.Equal(t, "foo", resp.Username) require.EqualValues(t, []byte{0x00}, resp.Cert) require.EqualValues(t, []byte{0x01}, resp.TLSCert) } type byTimeAndIndex []apievents.AuditEvent func (f byTimeAndIndex) Len() int { return len(f) } func (f byTimeAndIndex) Less(i, j int) bool { itime := f[i].GetTime() jtime := f[j].GetTime() if itime.Equal(jtime) && events.GetSessionID(f[i]) == events.GetSessionID(f[j]) { return f[i].GetIndex() < f[j].GetIndex() } return itime.Before(jtime) } func (f byTimeAndIndex) Swap(i, j int) { f[i], f[j] = f[j], f[i] } // TestSearchClusterEvents makes sure web API allows querying events by type. 
func TestSearchClusterEvents(t *testing.T) { t.Parallel() s := newWebSuite(t) clock := s.clock sessionEvents := eventstest.GenerateTestSession(eventstest.SessionParams{ PrintEvents: 3, Clock: clock, ServerID: s.proxy.ID(), }) for _, e := range sessionEvents { require.NoError(t, s.proxyClient.EmitAuditEvent(s.ctx, e)) } sort.Sort(sort.Reverse(byTimeAndIndex(sessionEvents))) sessionStart := sessionEvents[0] sessionPrint := sessionEvents[1] sessionEnd := sessionEvents[4] fromTime := []string{clock.Now().AddDate(0, -1, 0).UTC().Format(time.RFC3339)} toTime := []string{clock.Now().AddDate(0, 1, 0).UTC().Format(time.RFC3339)} testCases := []struct { // Comment is the test case description. Comment string // Query is the search query sent to the API. Query url.Values // Result is the expected returned list of events. Result []apievents.AuditEvent // TestStartKey is a flag to test start key value. TestStartKey bool // StartKeyValue is the value of start key to expect. StartKeyValue string }{ { Comment: "Empty query", Query: url.Values{ "from": fromTime, "to": toTime, }, Result: sessionEvents, }, { Comment: "Query by session start event", Query: url.Values{ "include": []string{sessionStart.GetType()}, "from": fromTime, "to": toTime, }, Result: sessionEvents[:1], }, { Comment: "Query session start and session end events", Query: url.Values{ "include": []string{sessionEnd.GetType() + "," + sessionStart.GetType()}, "from": fromTime, "to": toTime, }, Result: []apievents.AuditEvent{sessionStart, sessionEnd}, }, { Comment: "Query events with filter by type and limit", Query: url.Values{ "include": []string{sessionPrint.GetType() + "," + sessionEnd.GetType()}, "limit": []string{"1"}, "from": fromTime, "to": toTime, }, Result: []apievents.AuditEvent{sessionPrint}, }, { Comment: "Query session start and session end events with limit and test returned start key", Query: url.Values{ "include": []string{sessionEnd.GetType() + "," + sessionStart.GetType()}, "limit": []string{"1"}, 
"from": fromTime, "to": toTime, }, Result: []apievents.AuditEvent{sessionStart}, TestStartKey: true, StartKeyValue: sessionStart.GetID(), }, { Comment: "Query session start and session end events with limit and given start key", Query: url.Values{ "include": []string{sessionEnd.GetType() + "," + sessionStart.GetType()}, "startKey": []string{sessionStart.GetID()}, "from": fromTime, "to": toTime, }, Result: []apievents.AuditEvent{sessionEnd}, TestStartKey: true, StartKeyValue: "", }, } pack := s.authPack(t, "foo") for _, tc := range testCases { tc := tc t.Run(tc.Comment, func(t *testing.T) { t.Parallel() response, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), tc.Query) require.NoError(t, err) var result eventsListGetResponse require.NoError(t, json.Unmarshal(response.Bytes(), &result)) // filter out irrelvant auth events filteredEvents := []events.EventFields{} for _, e := range result.Events { t := e.GetType() if t == events.SessionStartEvent || t == events.SessionPrintEvent || t == events.SessionEndEvent { filteredEvents = append(filteredEvents, e) } } require.Len(t, filteredEvents, len(tc.Result)) for i, resultEvent := range filteredEvents { require.Equal(t, tc.Result[i].GetType(), resultEvent.GetType()) require.Equal(t, tc.Result[i].GetID(), resultEvent.GetID()) } // Session prints do not have IDs, only sessionStart and sessionEnd. // When retrieving events for sessionStart and sessionEnd, sessionStart is returned first. 
if tc.TestStartKey { require.Equal(t, tc.StartKeyValue, result.StartKey) } }) } } func TestGetClusterDetails(t *testing.T) { t.Parallel() s := newWebSuite(t) site, err := s.proxyTunnel.GetSite(s.server.ClusterName()) require.NoError(t, err) require.NotNil(t, site) cluster, err := ui.GetClusterDetails(s.ctx, site) require.NoError(t, err) require.Equal(t, s.server.ClusterName(), cluster.Name) require.Equal(t, teleport.Version, cluster.ProxyVersion) require.Equal(t, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) require.Equal(t, teleport.RemoteClusterStatusOnline, cluster.Status) require.NotNil(t, cluster.LastConnected) require.Equal(t, teleport.Version, cluster.AuthVersion) nodes, err := s.proxyClient.GetNodes(s.ctx, apidefaults.Namespace) require.NoError(t, err) require.Len(t, nodes, cluster.NodeCount) } func TestTokenGeneration(t *testing.T) { const username = "test-user@example.com" // Users should be able to create Tokens even if they can't update them roleTokenCRD, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{ Allow: types.RoleConditions{ Rules: []types.Rule{ types.NewRule(types.KindToken, []string{types.VerbCreate, types.VerbRead}), }, }, }) require.NoError(t, err) env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, username, []types.Role{roleTokenCRD}) endpoint := pack.clt.Endpoint("webapi", "token") tt := []struct { name string roles types.SystemRoles shouldErr bool joinMethod types.JoinMethod suggestedAgentMatcherLabels types.Labels allow []*types.TokenRule }{ { name: "single node role", roles: types.SystemRoles{types.RoleNode}, shouldErr: false, }, { name: "single app role", roles: types.SystemRoles{types.RoleApp}, shouldErr: false, }, { name: "single db role", roles: types.SystemRoles{types.RoleDatabase}, shouldErr: false, }, { name: "multiple roles", roles: types.SystemRoles{types.RoleNode, types.RoleApp, types.RoleDatabase}, shouldErr: false, }, { name: "return 
error if no role is requested", roles: types.SystemRoles{}, shouldErr: true, }, { name: "cannot request token with IAM join method without allow field", roles: types.SystemRoles{types.RoleNode}, joinMethod: types.JoinMethodIAM, shouldErr: true, }, { name: "can request token with IAM join method", roles: types.SystemRoles{types.RoleNode}, joinMethod: types.JoinMethodIAM, allow: []*types.TokenRule{{AWSAccount: "1234"}}, shouldErr: false, }, { name: "adds the agent match labels", roles: types.SystemRoles{types.RoleDatabase}, suggestedAgentMatcherLabels: types.Labels{ "*": apiutils.Strings{"*"}, }, shouldErr: false, }, } for _, tc := range tt { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() re, err := pack.clt.PostJSON(context.Background(), endpoint, types.ProvisionTokenSpecV2{ Roles: tc.roles, JoinMethod: tc.joinMethod, Allow: tc.allow, SuggestedAgentMatcherLabels: tc.suggestedAgentMatcherLabels, }) if tc.shouldErr { require.Error(t, err) return } require.NoError(t, err) var responseToken nodeJoinToken err = json.Unmarshal(re.Bytes(), &responseToken) require.NoError(t, err) require.NotEmpty(t, responseToken.SuggestedLabels) require.Condition(t, func() (success bool) { for _, uiLabel := range responseToken.SuggestedLabels { if uiLabel.Name == types.InternalResourceIDLabel && uiLabel.Value != "" { return true } } return false }) // generated token roles should match the requested ones generatedToken, err := proxy.auth.Auth().GetToken(context.Background(), responseToken.ID) require.NoError(t, err) require.Equal(t, tc.roles, generatedToken.GetRoles()) expectedJoinMethod := tc.joinMethod if tc.joinMethod == "" { expectedJoinMethod = types.JoinMethodToken } // if no joinMethod is provided, expect token method require.Equal(t, expectedJoinMethod, generatedToken.GetJoinMethod()) require.Equal(t, tc.suggestedAgentMatcherLabels, generatedToken.GetSuggestedAgentMatcherLabels()) }) } } func TestInstallDatabaseScriptGeneration(t *testing.T) { const username = 
"test-user@example.com" // Users should be able to create Tokens even if they can't update them roleTokenCRD, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{ Allow: types.RoleConditions{ Rules: []types.Rule{ types.NewRule(types.KindToken, []string{types.VerbCreate, types.VerbRead}), }, }, }) require.NoError(t, err) env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, username, []types.Role{roleTokenCRD}) // Create a new token with the desired SuggestedAgentMatcherLabels endpointGenerateToken := pack.clt.Endpoint("webapi", "token") re, err := pack.clt.PostJSON( context.Background(), endpointGenerateToken, types.ProvisionTokenSpecV2{ Roles: types.SystemRoles{types.RoleDatabase}, SuggestedAgentMatcherLabels: types.Labels{ "stage": apiutils.Strings{"prod"}, }, }) require.NoError(t, err) var responseToken nodeJoinToken require.NoError(t, json.Unmarshal(re.Bytes(), &responseToken)) // Generating the script with the token should return the SuggestedAgentMatcherLabels provided in the first request endpointInstallDatabase := pack.clt.Endpoint("scripts", responseToken.ID, "install-database.sh") t.Log(responseToken, endpointInstallDatabase) req, err := http.NewRequest(http.MethodGet, endpointInstallDatabase, nil) require.NoError(t, err) anonHTTPClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }, } resp, err := anonHTTPClient.Do(req) require.NoError(t, err) scriptBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) require.NoError(t, resp.Body.Close()) script := string(scriptBytes) // It contains the agenbtMatchLabels require.Contains(t, script, "stage: prod") } func TestSignMTLS(t *testing.T) { env := newWebPack(t, 1) clusterName := env.server.ClusterName() proxy := env.proxies[0] pack := proxy.authPack(t, "test-user@example.com", nil) endpoint := pack.clt.Endpoint("webapi", "token") re, err := pack.clt.PostJSON(context.Background(), endpoint, 
types.ProvisionTokenSpecV2{ Roles: types.SystemRoles{types.RoleDatabase}, }) require.NoError(t, err) var responseToken nodeJoinToken err = json.Unmarshal(re.Bytes(), &responseToken) require.NoError(t, err) // download mTLS files from /webapi/sites/:site/sign/db endpointSign := pack.clt.Endpoint("webapi", "sites", clusterName, "sign", "db") bs, err := json.Marshal(struct { Hostname string `json:"hostname"` TTL string `json:"ttl"` }{ Hostname: "mypg.example.com", TTL: "2h", }) require.NoError(t, err) req, err := http.NewRequest(http.MethodPost, endpointSign, bytes.NewReader(bs)) require.NoError(t, err) req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+responseToken.ID) anonHTTPClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }, } resp, err := anonHTTPClient.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) gzipReader, err := gzip.NewReader(resp.Body) require.NoError(t, err) tarReader := tar.NewReader(gzipReader) tarContentFileNames := []string{} for { header, err := tarReader.Next() if errors.Is(err, io.EOF) { break } require.NoError(t, err) require.Equal(t, byte(tar.TypeReg), header.Typeflag) require.Equal(t, int64(0o600), header.Mode) tarContentFileNames = append(tarContentFileNames, header.Name) } expectedFileNames := []string{"server.cas", "server.key", "server.crt"} require.ElementsMatch(t, tarContentFileNames, expectedFileNames) // the token is no longer valid, so trying again should return an error req, err = http.NewRequest(http.MethodPost, endpointSign, bytes.NewReader(bs)) require.NoError(t, err) req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+responseToken.ID) respSecondCall, err := anonHTTPClient.Do(req) require.NoError(t, err) defer respSecondCall.Body.Close() require.Equal(t, http.StatusForbidden, respSecondCall.StatusCode) } func 
TestSignMTLS_failsAccessDenied(t *testing.T) { env := newWebPack(t, 1) clusterName := env.server.ClusterName() username := "test-user@example.com" roleUserUpdate, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{ Allow: types.RoleConditions{ Rules: []types.Rule{ types.NewRule(types.KindUser, []string{types.VerbUpdate}), types.NewRule(types.KindToken, []string{types.VerbCreate}), }, }, }) require.NoError(t, err) proxy := env.proxies[0] pack := proxy.authPack(t, username, []types.Role{roleUserUpdate}) endpoint := pack.clt.Endpoint("webapi", "token") re, err := pack.clt.PostJSON(context.Background(), endpoint, types.ProvisionTokenSpecV2{ Roles: types.SystemRoles{types.RoleProxy}, }) require.NoError(t, err) var responseToken nodeJoinToken err = json.Unmarshal(re.Bytes(), &responseToken) require.NoError(t, err) // download mTLS files from /webapi/sites/:site/sign/db endpointSign := pack.clt.Endpoint("webapi", "sites", clusterName, "sign", "db") bs, err := json.Marshal(struct { Hostname string `json:"hostname"` TTL string `json:"ttl"` Format string `json:"format"` }{ Hostname: "mypg.example.com", TTL: "2h", Format: "db", }) require.NoError(t, err) req, err := http.NewRequest(http.MethodPost, endpointSign, bytes.NewReader(bs)) require.NoError(t, err) req.Header.Add("Content-Type", "application/json") req.Header.Add("Authorization", "Bearer "+responseToken.ID) anonHTTPClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, }, } resp, err := anonHTTPClient.Do(req) require.NoError(t, err) defer resp.Body.Close() // It fails because we passed a Provision Token with the wrong Role: Proxy require.Equal(t, http.StatusForbidden, resp.StatusCode) // using a user token also returns Forbidden endpointResetToken := pack.clt.Endpoint("webapi", "users", "password", "token") _, err = pack.clt.PostJSON(context.Background(), endpointResetToken, auth.CreateUserTokenRequest{ Name: username, TTL: time.Minute, 
Type: auth.UserTokenTypeResetPassword, }) require.NoError(t, err) req, err = http.NewRequest(http.MethodPost, endpointSign, bytes.NewReader(bs)) require.NoError(t, err) resp, err = anonHTTPClient.Do(req) require.NoError(t, err) defer resp.Body.Close() require.Equal(t, http.StatusForbidden, resp.StatusCode) } // TestCheckAccessToRegisteredResource_AccessDenied tests that access denied error // is ignored. func TestCheckAccessToRegisteredResource_AccessDenied(t *testing.T) { t.Parallel() ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "foo", nil /* roles */) // newWebPack already registers 1 node. n, err := env.server.Auth().GetNodes(ctx, env.node.GetNamespace()) require.NoError(t, err) require.Len(t, n, 1) // Checking for access returns true. endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "resources", "check") re, err := pack.clt.Get(ctx, endpoint, url.Values{}) require.NoError(t, err) resp := checkAccessToRegisteredResourceResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &resp)) require.True(t, resp.HasResource) // Deny this resource. fooRole, err := env.server.Auth().GetRole(ctx, "user:foo") require.NoError(t, err) fooRole.SetRules(types.Deny, []types.Rule{types.NewRule(types.KindNode, services.RW())}) require.NoError(t, env.server.Auth().UpsertRole(ctx, fooRole)) // Direct querying should return a access denied error. endpoint = pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "nodes") _, err = pack.clt.Get(ctx, endpoint, url.Values{}) require.True(t, trace.IsAccessDenied(err)) // Checking for access returns false, not an error. 
endpoint = pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "resources", "check") re, err = pack.clt.Get(ctx, endpoint, url.Values{}) require.NoError(t, err) resp = checkAccessToRegisteredResourceResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &resp)) require.False(t, resp.HasResource) } func TestCheckAccessToRegisteredResource(t *testing.T) { t.Parallel() ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "foo", nil /* roles */) // Delete the node that was created by the `newWebPack` to start afresh. require.NoError(t, env.server.Auth().DeleteNode(ctx, env.node.GetNamespace(), env.node.ID())) n, err := env.server.Auth().GetNodes(ctx, env.node.GetNamespace()) require.NoError(t, err) require.Len(t, n, 0) // Double check we start of with no resources. endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "resources", "check") re, err := pack.clt.Get(ctx, endpoint, url.Values{}) require.NoError(t, err) resp := checkAccessToRegisteredResourceResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &resp)) require.False(t, resp.HasResource) // Test all cases return true. 
tests := []struct { name string resourceKind string insertResource func() deleteResource func() }{ { name: "has registered windows desktop", insertResource: func() { wd, err := types.NewWindowsDesktopV3("test-desktop", nil, types.WindowsDesktopSpecV3{ Addr: "addr", HostID: "hostid", }) require.NoError(t, err) require.NoError(t, env.server.Auth().UpsertWindowsDesktop(ctx, wd)) }, deleteResource: func() { require.NoError(t, env.server.Auth().DeleteWindowsDesktop(ctx, "hostid", "test-desktop")) wds, err := env.server.Auth().GetWindowsDesktops(ctx, types.WindowsDesktopFilter{}) require.NoError(t, err) require.Len(t, wds, 0) }, }, { name: "has registered node", insertResource: func() { resource, err := types.NewServer("test-node", types.KindNode, types.ServerSpecV2{}) require.NoError(t, err) _, err = env.server.Auth().UpsertNode(ctx, resource) require.NoError(t, err) }, deleteResource: func() { require.NoError(t, env.server.Auth().DeleteNode(ctx, apidefaults.Namespace, "test-node")) nodes, err := env.server.Auth().GetNodes(ctx, apidefaults.Namespace) require.NoError(t, err) require.Len(t, nodes, 0) }, }, { name: "has registered app server", insertResource: func() { resource := &types.AppServerV3{ Metadata: types.Metadata{Name: "test-app"}, Kind: types.KindAppServer, Version: types.V2, Spec: types.AppServerSpecV3{ HostID: "hostid", App: &types.AppV3{ Metadata: types.Metadata{ Name: "app-name", }, Spec: types.AppSpecV3{ URI: "https://console.aws.amazon.com", }, }, }, } _, err := env.server.Auth().UpsertApplicationServer(ctx, resource) require.NoError(t, err) }, deleteResource: func() { require.NoError(t, env.server.Auth().DeleteApplicationServer(ctx, apidefaults.Namespace, "hostid", "test-app")) apps, err := env.server.Auth().GetApplicationServers(ctx, apidefaults.Namespace) require.NoError(t, err) require.Len(t, apps, 0) }, }, { name: "has registered db server", insertResource: func() { db, err := types.NewDatabaseServerV3(types.Metadata{ Name: "test-db", }, 
types.DatabaseServerSpecV3{
					Protocol: "test-protocol",
					URI:      "test-uri",
					Hostname: "test-hostname",
					HostID:   "test-hostID",
				})
				require.NoError(t, err)
				_, err = env.server.Auth().UpsertDatabaseServer(ctx, db)
				require.NoError(t, err)
			},
			deleteResource: func() {
				require.NoError(t, env.server.Auth().DeleteDatabaseServer(ctx, apidefaults.Namespace, "test-hostID", "test-db"))
				dbs, err := env.server.Auth().GetDatabaseServers(ctx, apidefaults.Namespace)
				require.NoError(t, err)
				require.Len(t, dbs, 0)
			},
		},
		{
			name: "has registered kube server",
			insertResource: func() {
				kubeCluster, err := types.NewKubernetesClusterV3(types.Metadata{Name: "test-kube-name"}, types.KubernetesClusterSpecV3{})
				require.NoError(t, err)
				kubeServer, err := types.NewKubernetesServerV3FromCluster(kubeCluster, "test-kube", "test-kube")
				require.NoError(t, err)
				_, err = env.server.Auth().UpsertKubernetesServer(ctx, kubeServer)
				require.NoError(t, err)
			},
			deleteResource: func() {
				require.NoError(t, env.server.Auth().DeleteKubernetesServer(ctx, "test-kube", "test-kube-name"))
				kubes, err := env.server.Auth().GetKubernetesServers(ctx)
				require.NoError(t, err)
				require.Len(t, kubes, 0)
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			tc.insertResource()

			re, err := pack.clt.Get(ctx, endpoint, url.Values{})
			require.NoError(t, err)

			resp := checkAccessToRegisteredResourceResponse{}
			require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
			require.True(t, resp.HasResource)

			tc.deleteResource()
		})
	}
}

// TestAuthExport checks the CA export endpoints for every authority type:
// the SSH types ("" for all, "host", "user") must contain cert-authority
// lines, "windows" must return a DER-encoded certificate, "db" and "tls" must
// return PEM-encoded certificates, and an unknown type must yield 400 Bad
// Request. Every case is run against both the deprecated cluster-scoped path
// and the newer cluster-less path.
func TestAuthExport(t *testing.T) {
	env := newWebPack(t, 1)
	clusterName := env.server.ClusterName()

	proxy := env.proxies[0]
	// pack is only used to build endpoint URLs; the export requests themselves
	// are sent without credentials by authExportTestByEndpoint.
	pack := proxy.authPack(t, "test-user@example.com", nil)

	// validateTLSCertificateDERFunc parses b as a DER-encoded X.509
	// certificate and expects its subject CN to be "localhost".
	validateTLSCertificateDERFunc := func(t *testing.T, b []byte) {
		cert, err := x509.ParseCertificate(b)
		require.NoError(t, err)
		require.NotNil(t, cert, "ParseCertificate failed")
		require.Equal(t, "localhost", cert.Subject.CommonName, "unexpected certificate subject CN")
	}

	// validateTLSCertificatePEMFunc decodes the first PEM block in b and
	// validates its payload with validateTLSCertificateDERFunc.
	validateTLSCertificatePEMFunc := func(t *testing.T, b []byte) {
		pemBlock, _ := pem.Decode(b)
		require.NotNil(t, pemBlock, "pem.Decode failed")

		validateTLSCertificateDERFunc(t, pemBlock.Bytes)
	}

	for _, tt := range []struct {
		name           string
		authType       string
		expectedStatus int
		assertBody     func(t *testing.T, bs []byte)
	}{
		{
			name:           "all",
			authType:       "",
			expectedStatus: http.StatusOK,
			assertBody: func(t *testing.T, b []byte) {
				require.Contains(t, string(b), "@cert-authority localhost,*.localhost ssh-rsa ")
				require.Contains(t, string(b), "cert-authority ssh-rsa")
			},
		},
		{
			name:           "host",
			authType:       "host",
			expectedStatus: http.StatusOK,
			assertBody: func(t *testing.T, b []byte) {
				require.Contains(t, string(b), "@cert-authority localhost,*.localhost ssh-rsa ")
			},
		},
		{
			name:           "user",
			authType:       "user",
			expectedStatus: http.StatusOK,
			assertBody: func(t *testing.T, b []byte) {
				require.Contains(t, string(b), "cert-authority ssh-rsa")
			},
		},
		{
			name:           "windows",
			authType:       "windows",
			expectedStatus: http.StatusOK,
			assertBody:     validateTLSCertificateDERFunc,
		},
		{
			name:           "db",
			authType:       "db",
			expectedStatus: http.StatusOK,
			assertBody:     validateTLSCertificatePEMFunc,
		},
		{
			name:           "tls",
			authType:       "tls",
			expectedStatus: http.StatusOK,
			assertBody:     validateTLSCertificatePEMFunc,
		},
		{
			name:           "invalid",
			authType:       "invalid",
			expectedStatus: http.StatusBadRequest,
			assertBody: func(t *testing.T, b []byte) {
				require.Contains(t, string(b), `"invalid" authority type is not supported`)
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			// Exercise the same case against the legacy (cluster-scoped) and
			// the current export endpoint.
			t.Run("deprecated endpoint", func(t *testing.T) {
				endpointExport := pack.clt.Endpoint("webapi", "sites", clusterName, "auth", "export")
				authExportTestByEndpoint(t, endpointExport, tt.authType, tt.expectedStatus, tt.assertBody)
			})
			t.Run("new endpoint", func(t *testing.T) {
				endpointExport := pack.clt.Endpoint("webapi", "auth", "export")
				authExportTestByEndpoint(t, endpointExport, tt.authType, tt.expectedStatus, tt.assertBody)
			})
		})
	}
}

// authExportTestByEndpoint sends an unauthenticated GET to endpointExport,
// adding "?type=<authType>" when authType is non-empty. The request uses a
// one-second timeout and a client that skips TLS verification.
//
// It then requires that:
//   - the status code equals expectedStatus (the body is included in the
//     failure message);
//   - the body is non-empty.
//
// If assertBody is set, it is finally called with the body.
//
// NOTE(review): authType is interpolated into the query string without URL
// escaping. That is fine for the fixed values used by TestAuthExport but would
// break for values containing reserved characters.
func authExportTestByEndpoint(t *testing.T, endpointExport, authType string, expectedStatus int, assertBody func(t *testing.T, bs []byte)) {
	ctx := context.Background()

	if authType != "" {
		endpointExport = fmt.Sprintf("%s?type=%s", endpointExport, authType)
	}

	reqCtx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, endpointExport, nil)
	require.NoError(t, err)

	// Deliberately anonymous: no bearer token or cookies are attached.
	anonHTTPClient := &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
		},
	}

	resp, err := anonHTTPClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	bs, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	require.Equal(t, expectedStatus, resp.StatusCode, "invalid status code with body %s", string(bs))

	require.NotEmpty(t, bs, "unexpected empty body from http response")
	if assertBody != nil {
		assertBody(t, bs)
	}
}

// TestClusterDatabasesGet exercises the cluster "databases" listing endpoint
// in three steps:
//   - with no databases registered, the list is empty;
//   - after two database servers are registered, a user without extra roles
//     sees both, with no database users/names;
//   - a user whose extra role allows database name "name1" and user "user1"
//     sees those values on every listed database.
func TestClusterDatabasesGet(t *testing.T) {
	t.Parallel()
	env := newWebPack(t, 1)

	proxy := env.proxies[0]
	pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

	query := url.Values{"sort": []string{"name"}}
	endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases")
	re, err := pack.clt.Get(context.Background(), endpoint, query)
	require.NoError(t, err)

	// testResponse mirrors the paginated JSON shape returned by the endpoint.
	type testResponse struct {
		Items      []ui.Database `json:"items"`
		TotalCount int           `json:"totalCount"`
	}

	// No db registered.
	resp := testResponse{}
	require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
	require.Len(t, resp.Items, 0)

	// Register databases.
	db, err := types.NewDatabaseServerV3(types.Metadata{
		Name: "dbServer1",
	}, types.DatabaseServerSpecV3{
		Hostname: "test-hostname",
		HostID:   "test-hostID",
		Database: &types.DatabaseV3{
			Metadata: types.Metadata{
				Name:        "db1",
				Description: "test-description",
				Labels:      map[string]string{"test-field": "test-value"},
			},
			Spec: types.DatabaseSpecV3{
				Protocol: "test-protocol",
				URI:      "test-uri:1234",
			},
		},
	})
	require.NoError(t, err)
	db2, err := types.NewDatabaseServerV3(types.Metadata{
		Name: "dbServer2",
	}, types.DatabaseServerSpecV3{
		Hostname: "test-hostname",
		HostID:   "test-hostID",
		Database: &types.DatabaseV3{
			Metadata: types.Metadata{
				Name: "db2",
			},
			Spec: types.DatabaseSpecV3{
				Protocol: "test-protocol",
				URI:      "test-uri:1234",
			},
		},
	})
	require.NoError(t, err)

	_, err = env.server.Auth().UpsertDatabaseServer(context.Background(), db)
	require.NoError(t, err)
	_, err = env.server.Auth().UpsertDatabaseServer(context.Background(), db2)
	require.NoError(t, err)

	// Test without defined database names or users in role.
	re, err = pack.clt.Get(context.Background(), endpoint, query)
	require.NoError(t, err)

	resp = testResponse{}
	require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
	require.Len(t, resp.Items, 2)
	require.Equal(t, 2, resp.TotalCount)
	// The expected Hostname is the host part of the database URI ("test-uri"),
	// not the server's Hostname field.
	require.ElementsMatch(t, resp.Items, []ui.Database{{
		Name:     "db1",
		Desc:     "test-description",
		Protocol: "test-protocol",
		Type:     types.DatabaseTypeSelfHosted,
		Labels:   []ui.Label{{Name: "test-field", Value: "test-value"}},
		Hostname: "test-uri",
		URI:      "test-uri:1234",
	}, {
		Name:     "db2",
		Type:     types.DatabaseTypeSelfHosted,
		Labels:   []ui.Label{},
		Protocol: "test-protocol",
		Hostname: "test-uri",
		URI:      "test-uri:1234",
	}})

	// Test with a role that defines database names and users.
	extraRole := &types.RoleV6{
		Metadata: types.Metadata{Name: "extra-role"},
		Spec: types.RoleSpecV6{
			Allow: types.RoleConditions{
				DatabaseNames: []string{"name1"},
				DatabaseUsers: []string{"user1"},
				DatabaseLabels: types.Labels{
					"*": []string{"*"},
				},
			},
		},
	}

	pack = proxy.authPack(t, "test-user2@example.com", services.NewRoleSet(extraRole))
	endpoint = pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases")
	re, err = pack.clt.Get(context.Background(), endpoint, query)
	require.NoError(t, err)

	resp = testResponse{}
	require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
	require.Len(t, resp.Items, 2)
	require.Equal(t, 2, resp.TotalCount)
	require.ElementsMatch(t, resp.Items, []ui.Database{{
		Name:          "db1",
		Desc:          "test-description",
		Protocol:      "test-protocol",
		Type:          types.DatabaseTypeSelfHosted,
		Labels:        []ui.Label{{Name: "test-field", Value: "test-value"}},
		Hostname:      "test-uri",
		DatabaseUsers: []string{"user1"},
		DatabaseNames: []string{"name1"},
		URI:           "test-uri:1234",
	}, {
		Name:          "db2",
		Type:          types.DatabaseTypeSelfHosted,
		Labels:        []ui.Label{},
		Protocol:      "test-protocol",
		Hostname:      "test-uri",
		DatabaseUsers: []string{"user1"},
		DatabaseNames: []string{"name1"},
		URI:           "test-uri:1234",
	}})
}

// TestClusterDatabaseGet fetches a single database by name under several role
// setups. Each case optionally registers a database named after the case
// (labeled env=prod) and logs in as a per-case user.
//
// A wrong name ("notfound"), a label mismatch ("notauthorized") and a missing
// database ("nodb") are all expected to surface as NotFound errors. The
// successful cases check the name, type, labels and allowed database
// users/names of the returned database.
func TestClusterDatabaseGet(t *testing.T) {
	env := newWebPack(t, 1)
	ctx := context.Background()

	proxy := env.proxies[0]

	dbNames := []string{"db1", "db2"}
	dbUsers := []string{"user1", "user2"}

	for _, tt := range []struct {
		name            string
		preRegisterDB   bool
		databaseName    string
		userRoles       func(*testing.T) []types.Role
		expectedDBUsers []string
		expectedDBNames []string
		requireError    require.ErrorAssertionFunc
	}{
		{
			name:          "valid",
			preRegisterDB: true,
			databaseName:  "valid",
			requireError:  require.NoError,
		},
		{
			name:          "notfound",
			preRegisterDB: true,
			databaseName:  "otherdb",
			requireError: func(tt require.TestingT, err error, i ...interface{}) {
				require.True(tt, trace.IsNotFound(err), "expected a not found error, got %v", err)
			},
		},
		{
			name:          "notauthorized",
			preRegisterDB: true,
			databaseName:  "notauthorized",
			// Only allows env=staging, while the registered database is env=prod.
			userRoles: func(tt *testing.T) []types.Role {
				role, err := types.NewRole(
					"myrole",
					types.RoleSpecV6{
						Allow: types.RoleConditions{
							DatabaseLabels: types.Labels{
								"env": apiutils.Strings{"staging"},
							},
						},
					},
				)
				require.NoError(tt, err)

				return []types.Role{role}
			},
			requireError: func(tt require.TestingT, err error, i ...interface{}) {
				require.True(tt, trace.IsNotFound(err), "expected a not found error, got %v", err)
			},
		},
		{
			name:          "nodb",
			preRegisterDB: false,
			databaseName:  "nodb",
			userRoles: func(tt *testing.T) []types.Role {
				roleWithDBName, err := types.NewRole(
					"myroleWithDBName",
					types.RoleSpecV6{
						Allow: types.RoleConditions{
							DatabaseLabels: types.Labels{
								"env": apiutils.Strings{"prod"},
							},
							DatabaseNames: dbNames,
						},
					},
				)
				require.NoError(tt, err)

				return []types.Role{roleWithDBName}
			},
			// NOTE(review): these expectations are never checked because this
			// case expects the request to fail and the subtest returns early.
			expectedDBNames: dbNames,
			expectedDBUsers: dbUsers,
			requireError: func(tt require.TestingT, err error, i ...interface{}) {
				require.True(tt, trace.IsNotFound(err), "expected a not found error, got %v", err)
			},
		},
		{
			name:          "authorizedDBNamesUsers",
			preRegisterDB: true,
			databaseName:  "authorizedDBNamesUsers",
			// Database names and users are granted by two separate roles; the
			// response is expected to combine them.
			userRoles: func(tt *testing.T) []types.Role {
				roleWithDBName, err := types.NewRole(
					"myroleWithDBName",
					types.RoleSpecV6{
						Allow: types.RoleConditions{
							DatabaseLabels: types.Labels{
								"env": apiutils.Strings{"prod"},
							},
							DatabaseNames: dbNames,
						},
					},
				)
				require.NoError(tt, err)

				roleWithDBUser, err := types.NewRole(
					"myroleWithDBUser",
					types.RoleSpecV6{
						Allow: types.RoleConditions{
							DatabaseLabels: types.Labels{
								"env": apiutils.Strings{"prod"},
							},
							DatabaseUsers: dbUsers,
						},
					},
				)
				require.NoError(tt, err)

				return []types.Role{roleWithDBUser, roleWithDBName}
			},
			expectedDBNames: dbNames,
			expectedDBUsers: dbUsers,
			requireError:    require.NoError,
		},
	} {
		// Capture the range variable for the parallel subtest below.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Register a database named after this case, labeled env=prod.
			if tt.preRegisterDB {
				db, err := types.NewDatabaseV3(types.Metadata{
					Name: tt.name,
					Labels: map[string]string{
						"env": "prod",
					},
				}, types.DatabaseSpecV3{
					Protocol: "test-protocol",
					URI:      "test-uri",
				})
				require.NoError(t, err)

				dbServer, err := types.NewDatabaseServerV3(types.Metadata{
					Name: tt.name,
				}, types.DatabaseServerSpecV3{
					Hostname: tt.name,
					Protocol: "test-protocol",
					URI:      "test-uri",
					HostID:   uuid.NewString(),
					Database: db,
				})
				require.NoError(t, err)

				_, err = env.server.Auth().UpsertDatabaseServer(context.Background(), dbServer)
				require.NoError(t, err)
			}

			var roles []types.Role
			if tt.userRoles != nil {
				roles = tt.userRoles(t)
			}

			pack := proxy.authPack(t, tt.name+"_user@example.com", roles)

			endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases", tt.databaseName)
			re, err := pack.clt.Get(ctx, endpoint, nil)
			tt.requireError(t, err)
			if err != nil {
				return
			}

			resp := ui.Database{}
			require.NoError(t, json.Unmarshal(re.Bytes(), &resp))

			require.Equal(t, tt.databaseName, resp.Name, "database name")
			require.Equal(t, types.DatabaseTypeSelfHosted, resp.Type, "database type")
			require.EqualValues(t, []ui.Label{{Name: "env", Value: "prod"}}, resp.Labels)
			require.ElementsMatch(t, tt.expectedDBUsers, resp.DatabaseUsers)
			require.ElementsMatch(t, tt.expectedDBNames, resp.DatabaseNames)
		})
	}
}

// TestClusterKubesGet registers two clusters:
//   - "test-kube1", served by three kube servers;
//   - "test-kube2", served by one.
//
// It then checks that the kubernetes listing endpoint returns each cluster
// exactly once. Kube users/groups are populated only for the user holding
// the extra role.
func TestClusterKubesGet(t *testing.T) {
	env := newWebPack(t, 1)

	proxy := env.proxies[0]

	extraRole := &types.RoleV6{
		Metadata: types.Metadata{Name: "extra-role"},
		Spec: types.RoleSpecV6{
			Allow: types.RoleConditions{
				KubeUsers:  []string{"user1"},
				KubeGroups: []string{"group1"},
				KubernetesLabels: types.Labels{
					"*": []string{"*"},
				},
			},
		},
	}

	cluster1, err := types.NewKubernetesClusterV3(
		types.Metadata{
			Name:   "test-kube1",
			Labels: map[string]string{"test-field": "test-value"},
		},
		types.KubernetesClusterSpecV3{},
	)
	require.NoError(t, err)

	// Serve the same cluster from several kube servers; the listing below
	// still expects it only once.
	for i := 0; i < 3; i++ {
		server, err := types.NewKubernetesServerV3FromCluster(
			cluster1,
			fmt.Sprintf("hostname-%d", i),
			fmt.Sprintf("uid-%d", i),
		)
		require.NoError(t, err)
		// Register a kube service.
		_, err = env.server.Auth().UpsertKubernetesServer(context.Background(), server)
		require.NoError(t, err)
	}

	cluster2, err := types.NewKubernetesClusterV3(
		types.Metadata{
			Name: "test-kube2",
		},
		types.KubernetesClusterSpecV3{},
	)
	require.NoError(t, err)
	server2, err := types.NewKubernetesServerV3FromCluster(
		cluster2,
		"test-kube2-hostname",
		"test-kube2-hostid",
	)
	require.NoError(t, err)
	_, err = env.server.Auth().UpsertKubernetesServer(context.Background(), server2)
	require.NoError(t, err)

	// testResponse mirrors the paginated JSON shape returned by the endpoint.
	type testResponse struct {
		Items      []ui.KubeCluster `json:"items"`
		TotalCount int              `json:"totalCount"`
	}

	tt := []struct {
		name             string
		user             string
		extraRoles       services.RoleSet
		expectedResponse []ui.KubeCluster
	}{
		{
			name: "user with no extra roles",
			user: "test-user@example.com",
			expectedResponse: []ui.KubeCluster{
				{
					Name:       "test-kube1",
					Labels:     []ui.Label{{Name: "test-field", Value: "test-value"}},
					KubeUsers:  nil,
					KubeGroups: nil,
				},
				{
					Name:       "test-kube2",
					Labels:     []ui.Label{},
					KubeUsers:  nil,
					KubeGroups: nil,
				},
			},
		},
		{
			name:       "user with extra roles",
			user:       "test-user2@example.com",
			extraRoles: services.NewRoleSet(extraRole),
			expectedResponse: []ui.KubeCluster{
				{
					Name:       "test-kube1",
					Labels:     []ui.Label{{Name: "test-field", Value: "test-value"}},
					KubeUsers:  []string{"user1"},
					KubeGroups: []string{"group1"},
				},
				{
					Name:       "test-kube2",
					Labels:     []ui.Label{},
					KubeUsers:  []string{"user1"},
					KubeGroups: []string{"group1"},
				},
			},
		},
	}

	// NOTE(review): tc.name is never used. The cases are not wrapped in t.Run,
	// so a failure does not identify which case failed.
	for _, tc := range tt {
		pack := proxy.authPack(t, tc.user, tc.extraRoles)

		endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "kubernetes")

		re, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
		require.NoError(t, err)

		resp := testResponse{}
		require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
		require.Len(t, resp.Items, 2)
		require.Equal(t, 2, resp.TotalCount)
		require.ElementsMatch(t, tc.expectedResponse, resp.Items)
	}
}

// TestClusterKubePodsGet starts a fake gRPC kube service via initGRPCServer and
// points the proxy's ProxyWebAddr at its listener. It then checks that the
// "pods" endpoint, queried with kubeCluster=kube_cluster, returns the two
// expected pods. Those are presumably the pods served by the fake service;
// see initGRPCServer.
func TestClusterKubePodsGet(t *testing.T) {
	t.Parallel()
	kubeClusterName := "kube_cluster"

	// roleWithFullAccess builds a per-user role that can see every kube
	// cluster and every pod in every namespace.
	//
	// NOTE(review): this closure calls require.NoError on the outer test's t
	// even when it is invoked from inside a subtest. testing.T.FailNow must be
	// called from the goroutine running that test, so prefer passing the
	// subtest's t in.
	roleWithFullAccess := func(username string) []types.Role {
		ret, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
			Allow: types.RoleConditions{
				Namespaces:       []string{apidefaults.Namespace},
				KubernetesLabels: types.Labels{types.Wildcard: []string{types.Wildcard}},
				Rules: []types.Rule{
					types.NewRule(types.KindConnectionDiagnostic, services.RW()),
				},
				KubeGroups: []string{"groups"},
				KubernetesResources: []types.KubernetesResource{
					{
						Kind:      types.KindKubePod,
						Namespace: types.Wildcard,
						Name:      types.Wildcard,
					},
				},
			},
		})
		require.NoError(t, err)
		return []types.Role{ret}
	}
	// NOTE(review): a function literal is never nil, so this check is a no-op.
	require.NotNil(t, roleWithFullAccess)

	env := newWebPack(t, 1)

	// testResponse mirrors the paginated JSON shape returned by the endpoint.
	type testResponse struct {
		Items      []ui.KubeResource `json:"items"`
		TotalCount int               `json:"totalCount"`
	}

	tt := []struct {
		name             string
		user             string
		expectedResponse []ui.KubeResource
	}{
		{
			name: "get pods from gRPC server",
			user: "test-user@example.com",
			expectedResponse: []ui.KubeResource{
				{
					Kind:        types.KindKubePod,
					Name:        "test-pod",
					Namespace:   "default",
					Labels:      []ui.Label{{Name: "app", Value: "test"}},
					KubeCluster: kubeClusterName,
				},
				{
					Kind:        types.KindKubePod,
					Name:        "test-pod2",
					Namespace:   "default",
					Labels:      []ui.Label{{Name: "app", Value: "test2"}},
					KubeCluster: kubeClusterName,
				},
			},
		},
	}

	proxy := env.proxies[0]

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	// Init fake grpc Kube service.
	initGRPCServer(t, env, listener)
	// Point the proxy's web address at the fake service's listener.
	addr := utils.MustParseAddr(listener.Addr().String())
	proxy.handler.handler.cfg.ProxyWebAddr = *addr

	for _, tc := range tt {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			pack := proxy.authPack(t, tc.user, roleWithFullAccess(tc.user))

			endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "pods")

			params := url.Values{}
			params.Add("kubeCluster", kubeClusterName)
			re, err := pack.clt.Get(context.Background(), endpoint, params)
			require.NoError(t, err)

			resp := testResponse{}
			require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
			require.Len(t, resp.Items, 2)
			require.Equal(t, 2, resp.TotalCount)
			require.ElementsMatch(t, tc.expectedResponse, resp.Items)
		})
	}
}

// TestClusterAppsGet registers two apps and checks that the "apps" endpoint
// returns both with the expected UI fields:
//   - "app1" has an AWS console URI and is linked to user group "ug1";
//   - "app2" is a plain app.
func TestClusterAppsGet(t *testing.T) {
	env := newWebPack(t, 1)

	proxy := env.proxies[0]
	pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

	// testResponse mirrors the paginated JSON shape returned by the endpoint.
	type testResponse struct {
		Items      []ui.App `json:"items"`
		TotalCount int      `json:"totalCount"`
	}

	// Add a user group that references app1.
	ug, err := types.NewUserGroup(types.Metadata{
		Name:        "ug1",
		Description: "ug1-description",
	}, types.UserGroupSpecV1{Applications: []string{"app1"}})
	require.NoError(t, err)
	err = env.server.Auth().CreateUserGroup(context.Background(), ug)
	require.NoError(t, err)

	resource := &types.AppServerV3{
		Metadata: types.Metadata{Name: "test-app"},
		Kind:     types.KindAppServer,
		Version:  types.V2,
		Spec: types.AppServerSpecV3{
			HostID: "hostid",
			App: &types.AppV3{
				Metadata: types.Metadata{
					Name:        "app1",
					Description: "description",
					Labels:      map[string]string{"test-field": "test-value"},
				},
				Spec: types.AppSpecV3{
					URI:        "https://console.aws.amazon.com", // sets field awsConsole to true
					PublicAddr: "publicaddrs",
					UserGroups: []string{"ug1"},
				},
			},
		},
	}

	resource2, err := types.NewAppServerV3(types.Metadata{Name: "server2"}, types.AppServerSpecV3{
		HostID: "hostid",
		App: &types.AppV3{
			Metadata: types.Metadata{Name: "app2"},
			Spec:     types.AppSpecV3{URI: "uri", PublicAddr: "publicaddrs"},
		},
	})
	require.NoError(t, err)

	// Register apps.
	_, err = env.server.Auth().UpsertApplicationServer(context.Background(), resource)
	require.NoError(t, err)
	_, err = env.server.Auth().UpsertApplicationServer(context.Background(), resource2)
	require.NoError(t, err)

	// Make the call.
	endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "apps")
	re, err := pack.clt.Get(context.Background(), endpoint, url.Values{"sort": []string{"name"}})
	require.NoError(t, err)

	// Test correct response.
	resp := testResponse{}
	require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
	require.Len(t, resp.Items, 2)
	require.Equal(t, 2, resp.TotalCount)
	require.ElementsMatch(t, resp.Items, []ui.App{{
		Name:        "app1",
		Description: resource.Spec.App.GetDescription(),
		URI:         resource.Spec.App.GetURI(),
		PublicAddr:  resource.Spec.App.GetPublicAddr(),
		Labels:      []ui.Label{{Name: "test-field", Value: "test-value"}},
		FQDN:        resource.Spec.App.GetPublicAddr(),
		ClusterID:   env.server.ClusterName(),
		AWSConsole:  true,
		UserGroups:  []ui.UserGroupAndDescription{{Name: "ug1", Description: "ug1-description"}},
	}, {
		Name:       "app2",
		URI:        "uri",
		Labels:     []ui.Label{},
		ClusterID:  env.server.ClusterName(),
		FQDN:       "publicaddrs",
		PublicAddr: "publicaddrs",
		AWSConsole: false,
	}})
}

// TestApplicationAccessDisabled makes sure application access can be disabled
// via modules: with the App feature turned off, creating an app session must
// fail with a licensing error.
func TestApplicationAccessDisabled(t *testing.T) {
	modules.SetTestModules(t, &modules.TestModules{
		TestFeatures: modules.Features{
			App: false,
		},
	})

	env := newWebPack(t, 1)

	proxy := env.proxies[0]
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	// Register an application.
	app, err := types.NewAppV3(types.Metadata{
		Name: "panel",
	}, types.AppSpecV3{
		URI:        "localhost",
		PublicAddr: "panel.example.com",
	})
	require.NoError(t, err)
	server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String())
	require.NoError(t, err)
	_, err = env.server.Auth().UpsertApplicationServer(context.Background(), server)
	require.NoError(t, err)

	endpoint := pack.clt.Endpoint("webapi", "sessions", "app")
	_, err = pack.clt.PostJSON(context.Background(), endpoint, &CreateAppSessionRequest{
		FQDNHint:    "panel.example.com",
		PublicAddr:  "panel.example.com",
		ClusterName: "localhost",
	})
	require.Error(t, err)
	require.Contains(t, err.Error(), "this Teleport cluster is not licensed for application access")
}

// TestApplicationWebSessionsDeletedAfterLogout makes sure user's application
// sessions are deleted after user logout.
func TestApplicationWebSessionsDeletedAfterLogout(t *testing.T) {
	env := newWebPack(t, 1)

	proxy := env.proxies[0]
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	// Register multiple applications.
	applications := []struct {
		name       string
		publicAddr string
	}{
		{name: "panel", publicAddr: "panel.example.com"},
		{name: "admin", publicAddr: "admin.example.com"},
		{name: "metrics", publicAddr: "metrics.example.com"},
	}

	// Register and create a session for each application.
	for _, application := range applications {
		// Register an application.
		app, err := types.NewAppV3(types.Metadata{
			Name: application.name,
		}, types.AppSpecV3{
			URI:        "localhost",
			PublicAddr: application.publicAddr,
		})
		require.NoError(t, err)
		server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String())
		require.NoError(t, err)
		_, err = env.server.Auth().UpsertApplicationServer(context.Background(), server)
		require.NoError(t, err)

		// Create an application session.
		endpoint := pack.clt.Endpoint("webapi", "sessions", "app")
		_, err = pack.clt.PostJSON(context.Background(), endpoint, &CreateAppSessionRequest{
			FQDNHint:    application.publicAddr,
			PublicAddr:  application.publicAddr,
			ClusterName: "localhost",
		})
		require.NoError(t, err)
	}

	// List sessions, should have one for each application.
	sessions, err := proxy.client.GetAppSessions(context.Background())
	require.NoError(t, err)
	require.Len(t, sessions, len(applications))

	// Log out of Teleport by deleting the web session.
	_, err = pack.clt.Delete(context.Background(), pack.clt.Endpoint("webapi", "sessions", "web"))
	require.NoError(t, err)

	// Check sessions after logout, should be empty.
	sessions, err = proxy.client.GetAppSessions(context.Background())
	require.NoError(t, err)
	require.Len(t, sessions, 0)
}

// TestGetWebConfig fetches /web/config.js three times and compares the
// decoded GRV_CONFIG object with the expected WebConfig:
//  1. with a passwordless auth preference, a MOTD and a GitHub connector;
//  2. after enabling Cloud, usage-based billing and automatic upgrades via
//     test modules, and mocking proxy settings with Assist enabled;
//  3. after swapping in a proxy client whose Ping always fails, at which point
//     AutomaticUpgrades is expected to be false.
func TestGetWebConfig(t *testing.T) {
	ctx := context.Background()
	env := newWebPack(t, 1)

	// Set auth preference with passwordless.
	const MOTD = "Welcome to cluster, your activity will be recorded."
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:          constants.Local,
		SecondFactor:  constants.SecondFactorOptional,
		ConnectorName: constants.PasswordlessConnector,
		Webauthn: &types.Webauthn{
			RPID: "localhost",
		},
		MessageOfTheDay: MOTD,
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, ap)
	require.NoError(t, err)

	// Add a test connector.
	github, err := types.NewGithubConnector("test-github", types.GithubConnectorSpecV3{
		TeamsToLogins: []types.TeamMapping{
			{
				Organization: "octocats",
				Team:         "dummy",
				Logins:       []string{"dummy"},
			},
		},
	})
	require.NoError(t, err)
	err = env.server.Auth().UpsertGithubConnector(ctx, github)
	require.NoError(t, err)

	expectedCfg := webclient.WebConfig{
		Auth: webclient.WebConfigAuthSettings{
			SecondFactor: constants.SecondFactorOptional,
			Providers: []webclient.WebConfigAuthProvider{{
				Name:      "test-github",
				Type:      constants.Github,
				WebAPIURL: webclient.WebConfigAuthProviderGitHubURL,
			}},
			LocalAuthEnabled:   true,
			AllowPasswordless:  true,
			AuthType:           constants.Local,
			PreferredLocalMFA:  constants.SecondFactorWebauthn,
			LocalConnectorName: constants.PasswordlessConnector,
			PrivateKeyPolicy:   keys.PrivateKeyPolicyNone,
			MOTD:               MOTD,
		},
		CanJoinSessions:   true,
		ProxyClusterName:  env.server.ClusterName(),
		IsCloud:           false,
		AssistEnabled:     false,
		AutomaticUpgrades: false,
	}

	// Make a request.
	clt := env.proxies[0].newClient(t)
	endpoint := clt.Endpoint("web", "config.js")
	re, err := clt.Get(ctx, endpoint, nil)
	require.NoError(t, err)
	require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))

	// The response is application/javascript. Strip off the variable name and
	// the trailing semicolon to be left with a JSON object.
	// NOTE(review): cfg is decoded into repeatedly below without being reset.
	// json.Unmarshal leaves fields that are absent from the JSON untouched, so
	// if the response omits a field (e.g. via omitempty; not visible here), the
	// previous decode's value persists. Consider resetting cfg before each
	// decode.
	var cfg webclient.WebConfig
	str := strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
	err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
	require.NoError(t, err)
	require.Equal(t, expectedCfg, cfg)

	// update features and assert that it is properly updated on the config object
	modules.SetTestModules(t, &modules.TestModules{
		TestFeatures: modules.Features{
			Cloud:               true,
			IsUsageBasedBilling: true,
			AutomaticUpgrades:   true,
		},
	})

	mockProxySetting := &mockProxySettings{
		mockedGetProxySettings: func(ctx context.Context) (*webclient.ProxySettings, error) {
			return &webclient.ProxySettings{AssistEnabled: true}, nil
		},
	}
	env.proxies[0].handler.handler.cfg.ProxySettings = mockProxySetting

	expectedCfg.IsCloud = true
	expectedCfg.IsUsageBasedBilling = true
	expectedCfg.AutomaticUpgrades = true
	expectedCfg.AssistEnabled = true

	// request and verify enabled features are enabled.
	re, err = clt.Get(ctx, endpoint, nil)
	require.NoError(t, err)
	require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))
	str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
	err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
	require.NoError(t, err)
	require.Equal(t, expectedCfg, cfg)

	// use mock client to assert that if ping returns an error, we'll default to
	// cluster config
	mockClient := mockedPingTestProxy{
		mockedPing: func(ctx context.Context) (authproto.PingResponse, error) {
			return authproto.PingResponse{}, errors.New("err")
		},
	}
	env.proxies[0].client = mockClient
	expectedCfg.AutomaticUpgrades = false

	// update modules but NOT the expected config
	modules.SetTestModules(t, &modules.TestModules{
		TestFeatures: modules.Features{
			Cloud:               false,
			IsUsageBasedBilling: false,
		},
	})

	// request and verify again
	re, err = clt.Get(ctx, endpoint, nil)
	require.NoError(t, err)
	require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG"))
	str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "")
	err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg)
	require.NoError(t, err)
	require.Equal(t, expectedCfg, cfg)
}

// TestCreatePrivilegeToken exchanges a TOTP code for a privilege token via the
// privilege token endpoint and checks that a non-empty token is returned.
func TestCreatePrivilegeToken(t *testing.T) {
	t.Parallel()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]

	// Create a user with second factor totp.
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	// Get a totp code.
	// NOTE(review): the code is generated 30s in the future, presumably so it
	// differs from the code used while logging in through authPack. Confirm.
	totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
	require.NoError(t, err)

	endpoint := pack.clt.Endpoint("webapi", "users", "privilege", "token")
	re, err := pack.clt.PostJSON(context.Background(), endpoint, &privilegeTokenRequest{
		SecondFactorToken: totpCode,
	})
	require.NoError(t, err)

	var privilegeToken string
	err = json.Unmarshal(re.Bytes(), &privilegeToken)
	require.NoError(t, err)
	require.NotEmpty(t, privilegeToken)
}

// TestAddMFADevice obtains a privilege token with a TOTP code. It then uses
// the token to register a new TOTP device and a new WebAuthn device through
// the MFA devices endpoint.
func TestAddMFADevice(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	// Enable second factor.
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:         constants.Local,
		SecondFactor: constants.SecondFactorOptional,
		Webauthn: &types.Webauthn{
			RPID: "localhost",
		},
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, ap)
	require.NoError(t, err)

	// Get a totp code to re-auth.
	totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
	require.NoError(t, err)

	// Obtain a privilege token.
	endpoint := pack.clt.Endpoint("webapi", "users", "privilege", "token")
	re, err := pack.clt.PostJSON(ctx, endpoint, &privilegeTokenRequest{
		SecondFactorToken: totpCode,
	})
	require.NoError(t, err)
	var privilegeToken string
	require.NoError(t, json.Unmarshal(re.Bytes(), &privilegeToken))

	// Each case sets exactly one of getTOTPCode / getWebauthnResp.
	// NOTE(review): both callbacks call require on the outer test's t, but
	// they run inside parallel subtests. testing.T.FailNow must be called from
	// the goroutine running that test.
	tests := []struct {
		name            string
		deviceName      string
		getTOTPCode     func() string
		getWebauthnResp func() *wanlib.CredentialCreationResponse
	}{
		{
			name:       "new TOTP device",
			deviceName: "new-totp",
			getTOTPCode: func() string {
				// Create totp secrets.
				res, err := env.server.Auth().CreateRegisterChallenge(ctx, &authproto.CreateRegisterChallengeRequest{
					TokenID:    privilegeToken,
					DeviceType: authproto.DeviceType_DEVICE_TYPE_TOTP,
				})
				require.NoError(t, err)

				_, regRes, err := auth.NewTestDeviceFromChallenge(res, auth.WithTestDeviceClock(env.clock))
				require.NoError(t, err)

				return regRes.GetTOTP().Code
			},
		},
		{
			name:       "new Webauthn device",
			deviceName: "new-webauthn",
			getWebauthnResp: func() *wanlib.CredentialCreationResponse {
				// Get webauthn register challenge.
				res, err := env.server.Auth().CreateRegisterChallenge(ctx, &authproto.CreateRegisterChallengeRequest{
					TokenID:    privilegeToken,
					DeviceType: authproto.DeviceType_DEVICE_TYPE_WEBAUTHN,
				})
				require.NoError(t, err)

				_, regRes, err := auth.NewTestDeviceFromChallenge(res)
				require.NoError(t, err)

				return wanlib.CredentialCreationResponseFromProto(regRes.GetWebauthn())
			},
		},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			var totpCode string
			var webauthnRegResp *wanlib.CredentialCreationResponse
			if tc.getWebauthnResp != nil {
				webauthnRegResp = tc.getWebauthnResp()
			} else {
				totpCode = tc.getTOTPCode()
			}

			// Add device.
			endpoint := pack.clt.Endpoint("webapi", "mfa", "devices")
			_, err := pack.clt.PostJSON(ctx, endpoint, addMFADeviceRequest{
				PrivilegeTokenID:         privilegeToken,
				DeviceName:               tc.deviceName,
				SecondFactorToken:        totpCode,
				WebauthnRegisterResponse: webauthnRegResp,
			})
			require.NoError(t, err)
		})
	}
}

// TestDeleteMFA registers TOTP devices whose names contain URL-reserved
// characters. It then deletes each one through the token-scoped delete
// endpoint, passing the path-escaped device name.
func TestDeleteMFA(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	// Set up the client manually (instead of reusing pack.clt) because we need
	// the sanitizer off.
	jar, err := cookiejar.New(nil)
	require.NoError(t, err)
	opts := []roundtrip.ClientParam{roundtrip.BearerAuth(pack.session.Token), roundtrip.CookieJar(jar), roundtrip.HTTPClient(client.NewInsecureWebClient())}
	rclt, err := roundtrip.NewClient(proxy.webURL.String(), teleport.WebAPIVersion, opts...)
	require.NoError(t, err)
	clt := client.WebClient{Client: rclt}
	jar.SetCookies(&proxy.webURL, pack.cookies)

	totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
	require.NoError(t, err)

	// Obtain a privilege token.
	endpoint := pack.clt.Endpoint("webapi", "users", "privilege", "token")
	re, err := pack.clt.PostJSON(ctx, endpoint, &privilegeTokenRequest{
		SecondFactorToken: totpCode,
	})
	require.NoError(t, err)
	var privilegeToken string
	require.NoError(t, json.Unmarshal(re.Bytes(), &privilegeToken))

	// Device names containing characters that are special in URL paths and
	// queries.
	names := []string{"x", "??", "%123/", "///", "my/device", "?/%&*1"}
	for _, devName := range names {
		devName := devName
		t.Run(devName, func(t *testing.T) {
			t.Parallel()
			otpSecret := base32.StdEncoding.EncodeToString([]byte(devName))
			dev, err := services.NewTOTPDevice(devName, otpSecret, env.clock.Now())
			require.NoError(t, err)
			err = env.server.Auth().UpsertMFADevice(ctx, pack.user, dev)
			require.NoError(t, err)

			enc := url.PathEscape(devName)
			_, err = clt.Delete(ctx, pack.clt.Endpoint("webapi", "mfa", "token", privilegeToken, "devices", enc))
			require.NoError(t, err)
		})
	}
}

// TestGetMFADevicesWithAuth checks that an authenticated user listing their
// MFA devices sees exactly one device, presumably the TOTP device enrolled by
// authPack.
func TestGetMFADevicesWithAuth(t *testing.T) {
	t.Parallel()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]
	pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

	endpoint := pack.clt.Endpoint("webapi", "mfa", "devices")
	re, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
	require.NoError(t, err)

	var devices []ui.MFADevice
	err = json.Unmarshal(re.Bytes(), &devices)
	require.NoError(t, err)
	require.Len(t, devices, 1)
}

// TestGetAndDeleteMFADevices_WithRecoveryApprovedToken lists and then deletes
// a user's MFA device. It authorizes both calls with a recovery-approved user
// token instead of a web session, using an unauthenticated client.
func TestGetAndDeleteMFADevices_WithRecoveryApprovedToken(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]

	// Create a user with a TOTP device.
	username := "llama"
	proxy.createUser(ctx, t, username, "root", "password", "some-otp-secret", nil /* roles */)

	// Enable second factor.
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:         constants.Local,
		SecondFactor: constants.SecondFactorOptional,
		Webauthn: &types.Webauthn{
			RPID: env.server.ClusterName(),
		},
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, ap)
	require.NoError(t, err)

	// Acquire an approved token.
	approvedToken, err := types.NewUserToken("some-token-id")
	require.NoError(t, err)
	approvedToken.SetUser(username)
	approvedToken.SetSubKind(auth.UserTokenTypeRecoveryApproved)
	approvedToken.SetExpiry(env.clock.Now().Add(5 * time.Minute))
	_, err = env.server.Auth().CreateUserToken(ctx, approvedToken)
	require.NoError(t, err)

	// Call the getter endpoint.
	clt := proxy.newClient(t)
	getDevicesEndpoint := clt.Endpoint("webapi", "mfa", "token", approvedToken.GetName(), "devices")
	res, err := clt.Get(ctx, getDevicesEndpoint, url.Values{})
	require.NoError(t, err)

	var devices []ui.MFADevice
	err = json.Unmarshal(res.Bytes(), &devices)
	require.NoError(t, err)
	require.Len(t, devices, 1)

	// Call the delete endpoint.
	_, err = clt.Delete(ctx, clt.Endpoint("webapi", "mfa", "token", approvedToken.GetName(), "devices", devices[0].Name))
	require.NoError(t, err)

	// Check device has been deleted.
	res, err = clt.Get(ctx, getDevicesEndpoint, url.Values{})
	require.NoError(t, err)

	// Unmarshal into the existing slice: encoding/json resets a slice's length
	// before appending, so an empty JSON array yields zero devices.
	err = json.Unmarshal(res.Bytes(), &devices)
	require.NoError(t, err)
	require.Len(t, devices, 0)
}

// TestCreateAuthenticateChallenge checks every endpoint that issues MFA
// authenticate challenges, using the authenticated or the public client as
// each endpoint requires. For a user whose only device is TOTP, each endpoint
// must return a TOTP challenge with no WebAuthn challenge.
func TestCreateAuthenticateChallenge(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]

	// Create a user with a TOTP device, with second factor preference to OTP only.
	authPack := proxy.authPack(t, "llama@example.com", nil /* roles */)

	// Authenticated client for private endpoints.
	authnClt := authPack.clt

	// Unauthenticated client for public endpoints.
	publicClt := proxy.newClient(t)

	// Acquire a start token for the request that requires it.
	startToken, err := types.NewUserToken("some-token-id")
	require.NoError(t, err)
	startToken.SetUser(authPack.user)
	startToken.SetSubKind(auth.UserTokenTypeRecoveryStart)
	startToken.SetExpiry(env.clock.Now().Add(5 * time.Minute))
	_, err = env.server.Auth().CreateUserToken(ctx, startToken)
	require.NoError(t, err)

	tests := []struct {
		name    string
		clt     *TestWebClient
		ep      []string
		reqBody client.MFAChallengeRequest
	}{
		{
			name: "/webapi/mfa/authenticatechallenge/password",
			clt:  authnClt,
			ep:   []string{"webapi", "mfa", "authenticatechallenge", "password"},
			reqBody: client.MFAChallengeRequest{
				Pass: authPack.password,
			},
		},
		{
			name: "/webapi/mfa/login/begin",
			clt:  publicClt,
			ep:   []string{"webapi", "mfa", "login", "begin"},
			reqBody: client.MFAChallengeRequest{
				User: authPack.user,
				Pass: authPack.password,
			},
		},
		{
			name: "/webapi/mfa/authenticatechallenge",
			clt:  authnClt,
			ep:   []string{"webapi", "mfa", "authenticatechallenge"},
		},
		{
			name: "/webapi/mfa/token/:token/authenticatechallenge",
			clt:  publicClt,
			ep:   []string{"webapi", "mfa", "token", startToken.GetName(), "authenticatechallenge"},
		},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			endpoint := tc.clt.Endpoint(tc.ep...)
			res, err := tc.clt.PostJSON(ctx, endpoint, tc.reqBody)
			require.NoError(t, err)

			var chal client.MFAAuthenticateChallenge
			err = json.Unmarshal(res.Bytes(), &chal)
			require.NoError(t, err)
			require.True(t, chal.TOTPChallenge)
			require.Empty(t, chal.WebauthnChallenge)
		})
	}
}

// TestCreateRegisterChallenge requests MFA register challenges for TOTP,
// WebAuthn and passwordless devices through the token-scoped endpoint, using a
// privilege token. It checks the matching challenge is present in each
// response and, for passwordless, that a resident key is required.
func TestCreateRegisterChallenge(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]
	clt := proxy.newClient(t)

	// Enable second factor.
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:         constants.Local,
		SecondFactor: constants.SecondFactorOn,
		Webauthn: &types.Webauthn{
			RPID: env.server.ClusterName(),
		},
	})
	require.NoError(t, err)
	require.NoError(t, env.server.Auth().SetAuthPreference(ctx, ap))

	// Acquire an accepted token.
	token, err := types.NewUserToken("some-token-id")
	require.NoError(t, err)
	token.SetUser("llama")
	token.SetSubKind(auth.UserTokenTypePrivilege)
	token.SetExpiry(env.clock.Now().Add(5 * time.Minute))
	_, err = env.server.Auth().CreateUserToken(ctx, token)
	require.NoError(t, err)

	tests := []struct {
		name            string
		req             *createRegisterChallengeRequest
		assertChallenge func(t *testing.T, c *client.MFARegisterChallenge)
	}{
		{
			name: "totp",
			req: &createRegisterChallengeRequest{
				DeviceType: "totp",
			},
		},
		{
			name: "webauthn",
			req: &createRegisterChallengeRequest{
				DeviceType: "webauthn",
			},
		},
		{
			name: "passwordless",
			req: &createRegisterChallengeRequest{
				DeviceType:  "webauthn",
				DeviceUsage: "passwordless",
			},
			assertChallenge: func(t *testing.T, c *client.MFARegisterChallenge) {
				// rrk=true is a good proxy for passwordless.
				require.NotNil(t, c.Webauthn.Response.AuthenticatorSelection.RequireResidentKey, "rrk cannot be nil")
				require.True(t, *c.Webauthn.Response.AuthenticatorSelection.RequireResidentKey, "rrk cannot be false")
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			endpoint := clt.Endpoint("webapi", "mfa", "token", token.GetName(), "registerchallenge")
			res, err := clt.PostJSON(ctx, endpoint, tc.req)
			require.NoError(t, err)

			var chal client.MFARegisterChallenge
			require.NoError(t, json.Unmarshal(res.Bytes(), &chal))

			switch tc.req.DeviceType {
			case "totp":
				require.NotNil(t, chal.TOTP.QRCode, "TOTP QR code cannot be nil")
			case "webauthn":
				require.NotNil(t, chal.Webauthn, "WebAuthn challenge cannot be nil")
			}

			if tc.assertChallenge != nil {
				tc.assertChallenge(t, &chal)
			}
		})
	}
}

// TestCreateAppSession verifies that an existing session to the Web UI can
// be exchanged for an application specific session.
func TestCreateAppSession(t *testing.T) {
	t.Parallel()
	s := newWebSuite(t)
	pack := s.authPack(t, "foo@example.com")

	// Register an application called "panel".
app, err := types.NewAppV3(types.Metadata{ Name: "panel", }, types.AppSpecV3{ URI: "http://127.0.0.1:8080", PublicAddr: "panel.example.com", }) require.NoError(t, err) server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) require.NoError(t, err) _, err = s.server.Auth().UpsertApplicationServer(s.ctx, server) require.NoError(t, err) // Extract the session ID and bearer token for the current session. rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) require.NoError(t, err) var sessionCookie websession.Cookie err = json.Unmarshal(cookieBytes, &sessionCookie) require.NoError(t, err) tests := []struct { name string inCreateRequest *CreateAppSessionRequest outError require.ErrorAssertionFunc outFQDN string outUsername string }{ { name: "Valid request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { name: "Valid request: without FQDN", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", ClusterName: "localhost", }, outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { name: "Valid request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", }, outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { name: "Invalid request: only public address", inCreateRequest: &CreateAppSessionRequest{ PublicAddr: "panel.example.com", }, outError: require.Error, }, { name: "Invalid request: only cluster name", inCreateRequest: &CreateAppSessionRequest{ ClusterName: "localhost", }, outError: require.Error, }, { name: "Invalid application", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "invalid.example.com", ClusterName: "localhost", }, outError: require.Error, }, { 
name: "Invalid cluster name", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com", PublicAddr: "panel.example.com", ClusterName: "example.com", }, outError: require.Error, }, { name: "Malicious request: all fields", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", PublicAddr: "panel.example.com", ClusterName: "localhost", }, outError: require.NoError, outFQDN: "panel.example.com", outUsername: "foo@example.com", }, { name: "Malicious request: only FQDN", inCreateRequest: &CreateAppSessionRequest{ FQDNHint: "panel.example.com@malicious.com", }, outError: require.Error, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() // Make a request to create an application session for "panel". endpoint := pack.clt.Endpoint("webapi", "sessions", "app") resp, err := pack.clt.PostJSON(s.ctx, endpoint, tt.inCreateRequest) tt.outError(t, err) if err != nil { return } // Unmarshal the response. var response *CreateAppSessionResponse require.NoError(t, json.Unmarshal(resp.Bytes(), &response)) require.Equal(t, tt.outFQDN, response.FQDN) // Verify that the application session was created. 
sess, err := s.server.Auth().GetAppSession(s.ctx, types.GetAppSessionRequest{ SessionID: response.CookieValue, }) require.NoError(t, err) require.Equal(t, tt.outUsername, sess.GetUser()) require.NotEmpty(t, response.CookieValue) require.Equal(t, response.CookieValue, sess.GetName()) require.NotEmpty(t, response.SubjectCookieValue, "every session should create a secret token") require.Equal(t, response.SubjectCookieValue, sess.GetBearerToken()) }) } } func TestCreateAppSessionHealthCheckAppServer(t *testing.T) { t.Parallel() validApp, err := types.NewAppV3(types.Metadata{ Name: "valid", }, types.AppSpecV3{ URI: "http://127.0.0.1:8080", PublicAddr: "valid.example.com", }) require.NoError(t, err) invalidApp, err := types.NewAppV3(types.Metadata{ Name: "invalid", }, types.AppSpecV3{ URI: "http://127.0.0.1:8080", PublicAddr: "invalid.example.com", }) require.NoError(t, err) s := newWebSuiteWithConfig(t, webSuiteConfig{ HealthCheckAppServer: func(_ context.Context, publicAddr string, _ string) error { // Can only serve "validApp". 
if publicAddr == validApp.GetPublicAddr() { return nil } return trace.ConnectionProblem(nil, "offline AppServer") }, }) for _, app := range []*types.AppV3{validApp, invalidApp} { server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) require.NoError(t, err) _, err = s.server.Auth().UpsertApplicationServer(s.ctx, server) require.NoError(t, err) } pack := s.authPack(t, "foo@example.com") rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) require.NoError(t, err) var sessionCookie websession.Cookie err = json.Unmarshal(cookieBytes, &sessionCookie) require.NoError(t, err) for _, tc := range []struct { desc string publicAddr string expectErr require.ErrorAssertionFunc }{ { desc: "request to application that can be served", publicAddr: validApp.GetPublicAddr(), expectErr: require.NoError, }, { desc: "request to application that cannot be served", publicAddr: invalidApp.GetPublicAddr(), expectErr: require.Error, }, } { t.Run(tc.desc, func(t *testing.T) { endpoint := pack.clt.Endpoint("webapi", "sessions", "app") _, err := pack.clt.PostJSON(s.ctx, endpoint, &CreateAppSessionRequest{ FQDNHint: tc.publicAddr, }) tc.expectErr(t, err) }) } } func TestNewSessionResponseWithRenewSession(t *testing.T) { t.Parallel() env := newWebPack(t, 1) // Set a web idle timeout. 
duration := time.Duration(5) * time.Minute cfg := types.DefaultClusterNetworkingConfig() cfg.SetWebIdleTimeout(duration) require.NoError(t, env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg)) proxy := env.proxies[0] pack := proxy.authPack(t, "foo", nil /* roles */) var ns *CreateSessionResponse resp := pack.renewSession(context.Background(), t) require.NoError(t, json.Unmarshal(resp.Bytes(), &ns)) require.Equal(t, int(duration.Milliseconds()), ns.SessionInactiveTimeoutMS) require.Equal(t, roundtrip.AuthBearer, ns.TokenType) require.NotEmpty(t, ns.SessionExpires) require.NotEmpty(t, ns.Token) require.NotEmpty(t, ns.TokenExpiresIn) } // TestWebSessionsRenewDoesNotBreakExistingTerminalSession validates that the // session renewed via one proxy does not force the terminals created by another // proxy to disconnect // // See https://github.com/gravitational/teleport/issues/5265 func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) { env := newWebPack(t, 2) proxy1, proxy2 := env.proxies[0], env.proxies[1] // Connect to both proxies pack1 := proxy1.authPack(t, "foo", nil /* roles */) pack2 := proxy2.authPackFromPack(t, pack1) ws, _ := proxy2.makeTerminal(t, pack2, "") // Advance the time before renewing the session. // This will allow the new session to have a more plausible // expiration const delta = 30 * time.Second env.clock.Advance(auth.BearerTokenTTL - delta) // Renew the session using the 1st proxy resp := pack1.renewSession(context.Background(), t) // Expire the old session and make sure it has been removed. 
// The bearer token is also removed after this point, so we have to // use the new session data for future connects env.clock.Advance(delta + 1*time.Second) pack2 = proxy2.authPackFromResponse(t, resp) // Verify that access via the 2nd proxy also works for the same session pack2.validateAPI(context.Background(), t) // Check whether the terminal session is still active validateTerminalStream(t, ws) } // TestWebSessionsRenewAllowsOldBearerTokenToLinger validates that the // bearer token bound to the previous session is still active after the // session renewal, if the renewal happens with a time margin. // // See https://github.com/gravitational/teleport/issues/5265 func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) { // Login to implicitly create a new web session env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "foo", nil /* roles */) delta := 30 * time.Second // Advance the time before renewing the session. // This will allow the new session to have a more plausible // expiration env.clock.Advance(auth.BearerTokenTTL - delta) // make sure we can use client to make authenticated requests // before we issue this request, we will recover session id and bearer token // prevSessionCookie := *pack.cookies[0] prevBearerToken := pack.session.Token resp := pack.renewSession(context.Background(), t) newPack := proxy.authPackFromResponse(t, resp) // new session is functioning newPack.validateAPI(context.Background(), t) sessionCookie := *newPack.cookies[0] bearerToken := newPack.session.Token require.NotEmpty(t, bearerToken) require.NotEmpty(t, cmp.Diff(bearerToken, prevBearerToken)) prevSessionID := decodeSessionCookie(t, prevSessionCookie.Value) activeSessionID := decodeSessionCookie(t, sessionCookie.Value) require.NotEmpty(t, cmp.Diff(prevSessionID, activeSessionID)) // old session is still valid jar, err := cookiejar.New(nil) require.NoError(t, err) oldClt := proxy.newClient(t, roundtrip.BearerAuth(prevBearerToken), 
roundtrip.CookieJar(jar)) jar.SetCookies(&proxy.webURL, []*http.Cookie{&prevSessionCookie}) _, err = oldClt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) require.NoError(t, err) // now expire the old session and make sure it has been removed env.clock.Advance(delta) _, err = proxy.client.GetWebSession(context.Background(), types.GetWebSessionRequest{ User: "foo", SessionID: prevSessionID, }) require.Regexp(t, "^key.*not found$", err.Error()) // now delete session _, err = newPack.clt.Delete( context.Background(), pack.clt.Endpoint("webapi", "sessions", "web")) require.NoError(t, err) // subsequent requests to use this session will fail _, err = newPack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites"), url.Values{}) require.True(t, trace.IsAccessDenied(err)) } // TestChangeUserAuthentication_recoveryCodesReturnedForCloud tests for following: // - Recovery codes are not returned for usernames that are not emails // - Recovery codes are returned for usernames that are valid emails func TestChangeUserAuthentication_recoveryCodesReturnedForCloud(t *testing.T) { env := newWebPack(t, 1) ctx := context.Background() // Enable second factor. ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) require.NoError(t, err) err = env.server.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) // Enable cloud feature. modules.SetTestModules(t, &modules.TestModules{ TestFeatures: modules.Features{ RecoveryCodes: true, }, }) // Creaet a username that is not a valid email format for recovery. teleUser, err := types.NewUser("invalid-name-for-recovery") require.NoError(t, err) require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. 
resetToken, err := env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ Name: "invalid-name-for-recovery", }) require.NoError(t, err) res, err := env.server.Auth().CreateRegisterChallenge(ctx, &authproto.CreateRegisterChallengeRequest{ TokenID: resetToken.GetName(), DeviceType: authproto.DeviceType_DEVICE_TYPE_TOTP, }) require.NoError(t, err) totpCode, err := totp.GenerateCode(res.GetTOTP().GetSecret(), env.clock.Now()) require.NoError(t, err) // Test invalid username does not receive codes. clt := env.proxies[0].client re, err := clt.ChangeUserAuthentication(ctx, &authproto.ChangeUserAuthenticationRequest{ TokenID: resetToken.GetName(), NewPassword: []byte("abc123"), NewMFARegisterResponse: &authproto.MFARegisterResponse{Response: &authproto.MFARegisterResponse_TOTP{ TOTP: &authproto.TOTPRegisterResponse{Code: totpCode}, }}, }) require.NoError(t, err) require.Nil(t, re.Recovery) require.False(t, re.PrivateKeyPolicyEnabled) // Create a user that is valid for recovery. teleUser, err = types.NewUser("valid-username@example.com") require.NoError(t, err) require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. resetToken, err = env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ Name: "valid-username@example.com", }) require.NoError(t, err) res, err = env.server.Auth().CreateRegisterChallenge(ctx, &authproto.CreateRegisterChallengeRequest{ TokenID: resetToken.GetName(), DeviceType: authproto.DeviceType_DEVICE_TYPE_TOTP, }) require.NoError(t, err) totpCode, err = totp.GenerateCode(res.GetTOTP().GetSecret(), env.clock.Now()) require.NoError(t, err) // Test valid username (email) returns codes. 
re, err = clt.ChangeUserAuthentication(ctx, &authproto.ChangeUserAuthenticationRequest{ TokenID: resetToken.GetName(), NewPassword: []byte("abc123"), NewMFARegisterResponse: &authproto.MFARegisterResponse{Response: &authproto.MFARegisterResponse_TOTP{ TOTP: &authproto.TOTPRegisterResponse{Code: totpCode}, }}, }) require.NoError(t, err) require.Len(t, re.Recovery.Codes, 3) require.NotEmpty(t, re.Recovery.Created) require.False(t, re.PrivateKeyPolicyEnabled) } // TestChangeUserAuthentication_WithPrivacyPolicyEnabledError tests // that when there is a privacy policy enabled error, we still get // a non error response with recovery codes and a privacy policy // flag set to true. func TestChangeUserAuthentication_WithPrivacyPolicyEnabledError(t *testing.T) { env := newWebPack(t, 1) ctx := context.Background() // Enable second factor required by cloud and a privacy policy. ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, RequireMFAType: types.RequireMFAType_HARDWARE_KEY_TOUCH, }) require.NoError(t, err) err = env.server.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) // Enable cloud feature. modules.SetTestModules(t, &modules.TestModules{ TestFeatures: modules.Features{ RecoveryCodes: true, }, MockAttestHardwareKey: func(_ context.Context, _ interface{}, policy keys.PrivateKeyPolicy, _ *keys.AttestationStatement, _ crypto.PublicKey, _ time.Duration) (keys.PrivateKeyPolicy, error) { return "", keys.NewPrivateKeyPolicyError(policy) }, }) // Create a user that is valid for recovery. teleUser, err := types.NewUser("valid-username@example.com") require.NoError(t, err) require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser)) // Create a reset password token and secrets. 
resetToken, err := env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{ Name: "valid-username@example.com", }) require.NoError(t, err) res, err := env.server.Auth().CreateRegisterChallenge(ctx, &authproto.CreateRegisterChallengeRequest{ TokenID: resetToken.GetName(), DeviceType: authproto.DeviceType_DEVICE_TYPE_TOTP, }) require.NoError(t, err) totpCode, err := totp.GenerateCode(res.GetTOTP().GetSecret(), env.clock.Now()) require.NoError(t, err) // Craft http request data. clt := env.proxies[0].newClient(t) req := changeUserAuthenticationRequest{ SecondFactorToken: totpCode, Password: []byte("abc123"), TokenID: resetToken.GetName(), } httpReqData, err := json.Marshal(req) require.NoError(t, err) // CSRF protected endpoint. csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" httpReq, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(httpReqData)) require.NoError(t, err) addCSRFCookieToReq(httpReq, csrfToken) httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set(csrf.HeaderName, csrfToken) httpRes, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { return clt.HTTPClient().Do(httpReq) })) require.NoError(t, err) var apiRes ui.ChangedUserAuthn require.NoError(t, json.Unmarshal(httpRes.Bytes(), &apiRes)) require.Len(t, apiRes.Recovery.Codes, 3) require.NotEmpty(t, apiRes.Recovery.Created) require.True(t, apiRes.PrivateKeyPolicyEnabled) } func TestChangeUserAuthentication_settingDefaultClusterAuthPreference(t *testing.T) { tt := []struct { name string cloud bool numberOfUsers int password []byte authPreferenceType string initialConnectorName string resultConnectorName string }{{ name: "first cloud sign-in changes connector to `passwordless`", cloud: true, numberOfUsers: 1, authPreferenceType: constants.Local, initialConnectorName: "", resultConnectorName: constants.PasswordlessConnector, }, { name: "first non-cloud sign-in 
doesn't change the connector", cloud: false, numberOfUsers: 1, authPreferenceType: constants.Local, initialConnectorName: "", resultConnectorName: "", }, { name: "second cloud sign-in doesn't change the connector", cloud: true, numberOfUsers: 2, authPreferenceType: constants.Local, initialConnectorName: "", resultConnectorName: "", }, { name: "first cloud sign-in does not change custom connector", cloud: true, numberOfUsers: 1, authPreferenceType: constants.OIDC, initialConnectorName: "custom", resultConnectorName: "custom", }, { name: "first cloud sign-in with password does not change connector", cloud: true, numberOfUsers: 1, password: []byte("abc123"), authPreferenceType: constants.Local, initialConnectorName: "", resultConnectorName: "", }} for _, tc := range tt { modules.SetTestModules(t, &modules.TestModules{ TestFeatures: modules.Features{ Cloud: tc.cloud, }, }) const RPID = "localhost" s := newWebSuiteWithConfig(t, webSuiteConfig{ authPreferenceSpec: &types.AuthPreferenceSpecV2{ Type: tc.authPreferenceType, ConnectorName: tc.initialConnectorName, SecondFactor: constants.SecondFactorOn, Webauthn: &types.Webauthn{ RPID: RPID, }, }, }) // user and role users := make([]types.User, tc.numberOfUsers) for i := 0; i < tc.numberOfUsers; i++ { user, err := types.NewUser(fmt.Sprintf("test_user_%v", i)) require.NoError(t, err) user.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "other_user"}, }) role := services.RoleForUser(user) err = s.server.Auth().UpsertRole(s.ctx, role) require.NoError(t, err) user.AddRole(role.GetName()) err = s.server.Auth().CreateUser(s.ctx, user) require.NoError(t, err) users[i] = user } initialUser := users[0] clt := s.client(t) // create register challenge token, err := s.server.Auth().CreateResetPasswordToken(s.ctx, auth.CreateUserTokenRequest{ Name: initialUser.GetName(), }) require.NoError(t, err) res, err := s.server.Auth().CreateRegisterChallenge(s.ctx, &authproto.CreateRegisterChallengeRequest{ TokenID: token.GetName(), 
DeviceType: authproto.DeviceType_DEVICE_TYPE_WEBAUTHN, DeviceUsage: authproto.DeviceUsage_DEVICE_USAGE_PASSWORDLESS, }) require.NoError(t, err) cc := wanlib.CredentialCreationFromProto(res.GetWebauthn()) // use passwordless as auth method device, err := mocku2f.Create() require.NoError(t, err) device.SetPasswordless() ccr, err := device.SignCredentialCreation("https://"+RPID, cc) require.NoError(t, err) // send sign-in response to server body, err := json.Marshal(changeUserAuthenticationRequest{ WebauthnCreationResponse: ccr, TokenID: token.GetName(), DeviceName: "passwordless-device", Password: tc.password, }) require.NoError(t, err) req, err := http.NewRequest("PUT", clt.Endpoint("webapi", "users", "password", "token"), bytes.NewBuffer(body)) require.NoError(t, err) csrfToken, err := csrf.GenerateToken() require.NoError(t, err) addCSRFCookieToReq(req, csrfToken) req.Header.Set(csrf.HeaderName, csrfToken) req.Header.Set("Content-Type", "application/json") re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) require.NoError(t, err) require.Equal(t, re.Code(), http.StatusOK) // check if auth preference connectorName is set authPreference, err := s.server.Auth().GetAuthPreference(s.ctx) require.NoError(t, err) require.Equal(t, authPreference.GetConnectorName(), tc.resultConnectorName, "Found unexpected auth connector name") } } func TestParseSSORequestParams(t *testing.T) { t.Parallel() token := "someMeaninglessTokenString" tests := []struct { name, url string wantErr bool expected *SSORequestParams }{ { name: "preserve redirect's query params (escaped)", url: "https://localhost/login?connector_id=oidc&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fnodes%3Fsearch=tunnel&sort=hostname:asc", expected: &SSORequestParams{ ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", ConnectorID: "oidc", CSRFToken: token, }, }, { name: 
"preserve redirect's query params (unescaped)", url: "https://localhost/login?connector_id=github&redirect_url=https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", expected: &SSORequestParams{ ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", ConnectorID: "github", CSRFToken: token, }, }, { name: "preserve various encoded chars", url: "https://localhost/login?connector_id=saml&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fapps%3Fquery=search(%2522watermelon%2522%252C%2520%2522this%2522)%2520%2526%2526%2520labels%255B%2522unique-id%2522%255D%2520%253D%253D%2520%2522hi%2522&sort=name:asc", expected: &SSORequestParams{ ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc", ConnectorID: "saml", CSRFToken: token, }, }, { name: "invalid redirect_url query param", url: "https://localhost/login?redirect=https://localhost/nodes&connector_id=oidc", wantErr: true, }, { name: "invalid connector_id query param", url: "https://localhost/login?redirect_url=https://localhost/nodes&connector=oidc", wantErr: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { req, err := http.NewRequest("", tc.url, nil) require.NoError(t, err) addCSRFCookieToReq(req, token) params, err := ParseSSORequestParams(req) switch { case tc.wantErr: require.Error(t, err) default: require.NoError(t, err) require.Equal(t, tc.expected, params) } }) } } func TestClusterDesktopsGet(t *testing.T) { t.Parallel() env := newWebPack(t, 1) proxy := env.proxies[0] pack := proxy.authPack(t, "test-user@example.com", nil /* roles */) type testResponse struct { Items []ui.Desktop `json:"items"` TotalCount int `json:"totalCount"` } // Add a few desktops. 
resource, err := types.NewWindowsDesktopV3("desktop1", map[string]string{"test-field": "test-value"}, types.WindowsDesktopSpecV3{ Addr: "addr:3389", // test stripping off rdp port HostID: "host", }) require.NoError(t, err) resource2, err := types.NewWindowsDesktopV3("desktop2", map[string]string{"test-field": "test-value2"}, types.WindowsDesktopSpecV3{ Addr: "addr", HostID: "host", }) require.NoError(t, err) err = env.server.Auth().UpsertWindowsDesktop(context.Background(), resource) require.NoError(t, err) err = env.server.Auth().UpsertWindowsDesktop(context.Background(), resource2) require.NoError(t, err) // Make the call. query := url.Values{"sort": []string{"name"}} endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "desktops") re, err := pack.clt.Get(context.Background(), endpoint, query) require.NoError(t, err) // Test correct response. resp := testResponse{} require.NoError(t, json.Unmarshal(re.Bytes(), &resp)) require.Len(t, resp.Items, 2) require.Equal(t, 2, resp.TotalCount) require.ElementsMatch(t, resp.Items, []ui.Desktop{{ OS: constants.WindowsOS, Name: "desktop1", Addr: "addr", Labels: []ui.Label{{Name: "test-field", Value: "test-value"}}, HostID: "host", }, { OS: constants.WindowsOS, Name: "desktop2", Addr: "addr", Labels: []ui.Label{{Name: "test-field", Value: "test-value2"}}, HostID: "host", }}) } func TestDesktopActive(t *testing.T) { desktopName := "rickey-rock" env := newWebPack(t, 1) ctx := context.Background() role, err := types.NewRole("admin", types.RoleSpecV6{ Allow: types.RoleConditions{ WindowsDesktopLabels: types.Labels{"environment": []string{"dev"}}, }, }) require.NoError(t, err) pack := env.proxies[0].authPack(t, "foo", []types.Role{role}) check := func(match string) { resp, err := pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "desktops", desktopName, "active"), url.Values{}) require.NoError(t, err) require.Contains(t, string(resp.Bytes()), match) } check("\"active\":false") 
desktop, err := types.NewWindowsDesktopV3(desktopName, map[string]string{"environment": "dev"}, types.WindowsDesktopSpecV3{
	Domain: "ad",
	Addr:   "foo",
	HostID: "bar",
})
require.NoError(t, err)
err = env.server.Auth().CreateWindowsDesktop(ctx, desktop)
require.NoError(t, err)

tracker, err := types.NewSessionTracker(types.SessionTrackerSpecV1{
	SessionID:   "foo",
	Kind:        string(types.WindowsDesktopSessionKind),
	State:       types.SessionState_SessionStateRunning,
	DesktopName: desktopName,
})
require.NoError(t, err)
_, err = env.server.Auth().CreateSessionTracker(ctx, tracker)
require.NoError(t, err)

check("\"active\":true")
}

// TestGetUserOrResetToken checks that a user can be fetched through the web API,
// that an existing reset-password token can be fetched by name, and that an
// unknown path under the password endpoint returns a not-found error.
func TestGetUserOrResetToken(t *testing.T) {
	env := newWebPack(t, 1)
	ctx := context.Background()
	username := "someuser"

	// Create a username.
	teleUser, err := types.NewUser(username)
	require.NoError(t, err)
	teleUser.SetLogins([]string{"login1"})
	require.NoError(t, env.server.Auth().CreateUser(ctx, teleUser))

	// Create a reset password token and secrets.
	resetToken, err := env.server.Auth().CreateResetPasswordToken(ctx, auth.CreateUserTokenRequest{
		Name: username,
		Type: auth.UserTokenTypeResetPasswordInvite,
	})
	require.NoError(t, err)

	pack := env.proxies[0].authPack(t, "foo", nil /* roles */)

	// The default role of foo lacks read access to users, which this test needs.
	fooRole, err := env.server.Auth().GetRole(ctx, "user:foo")
	require.NoError(t, err)
	fooAllowRules := fooRole.GetRules(types.Allow)
	fooAllowRules = append(fooAllowRules, types.NewRule(types.KindUser, services.RO()))
	fooRole.SetRules(types.Allow, fooAllowRules)
	require.NoError(t, env.server.Auth().UpsertRole(ctx, fooRole))

	resp, err := pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "users", username), url.Values{})
	require.NoError(t, err)
	require.Contains(t, string(resp.Bytes()), "login1")

	resp, err = pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "users", "password", "token", resetToken.GetName()), url.Values{})
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.Code())

	_, err = pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "users", "password", "notToken", resetToken.GetName()), url.Values{})
	require.True(t, trace.IsNotFound(err), "expected not found error, got %v", err)
}

// TestListConnectionsDiagnostic checks that fetching a connection diagnostic by
// name returns not found before it exists, returns the stored diagnostic after
// creation, and reflects later updates (message and appended traces).
func TestListConnectionsDiagnostic(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	username := "someuser"
	diagName := "diag1"
	roleROConnectionDiagnostics, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
		Allow: types.RoleConditions{
			Rules: []types.Rule{
				types.NewRule(types.KindConnectionDiagnostic, []string{types.VerbRead}),
			},
		},
	})
	require.NoError(t, err)

	env := newWebPack(t, 1)
	clusterName := env.server.ClusterName()
	pack := env.proxies[0].authPack(t, username, []types.Role{roleROConnectionDiagnostics})

	connectionsEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "diagnostics", "connections", diagName)

	// No connection diagnostics so far, should return not found
	_, err = pack.clt.Get(ctx, connectionsEndpoint, url.Values{})
	require.True(t, trace.IsNotFound(err), "expected not found error, got %v", err)

	connectionDiagnostic, err := types.NewConnectionDiagnosticV1(diagName, map[string]string{}, types.ConnectionDiagnosticSpecV1{
		Success: true,
		Message: "success for cd0",
	})
	require.NoError(t, err)
	require.NoError(t, env.server.Auth().CreateConnectionDiagnostic(ctx, connectionDiagnostic))

	resp, err := pack.clt.Get(ctx, connectionsEndpoint, url.Values{})
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.Code())

	var receivedConnectionDiagnostic ui.ConnectionDiagnostic
	require.NoError(t, json.Unmarshal(resp.Bytes(), &receivedConnectionDiagnostic))

	require.True(t, receivedConnectionDiagnostic.Success)
	require.Equal(t, diagName, receivedConnectionDiagnostic.ID)
	require.Equal(t, "success for cd0", receivedConnectionDiagnostic.Message)

	diag, err := env.server.Auth().GetConnectionDiagnostic(ctx, diagName)
	require.NoError(t, err)

	// Adding traces
	diag.AppendTrace(&types.ConnectionDiagnosticTrace{
		Type:    types.ConnectionDiagnosticTrace_RBAC_NODE,
		Status:  types.ConnectionDiagnosticTrace_SUCCESS,
		Details: "some details",
	})
	diag.SetMessage("after update")

	require.NoError(t, env.server.Auth().UpdateConnectionDiagnostic(ctx, diag))

	resp, err = pack.clt.Get(ctx, connectionsEndpoint, url.Values{})
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.Code())

	// Decode into a fresh value: json.Unmarshal keeps fields that are absent from
	// the input, so reusing the first struct could let stale values pass.
	var updatedConnectionDiagnostic ui.ConnectionDiagnostic
	require.NoError(t, json.Unmarshal(resp.Bytes(), &updatedConnectionDiagnostic))

	require.True(t, updatedConnectionDiagnostic.Success)
	require.Equal(t, diagName, updatedConnectionDiagnostic.ID)
	require.Equal(t, "after update", updatedConnectionDiagnostic.Message)
	require.Len(t, updatedConnectionDiagnostic.Traces, 1)
	require.NotNil(t, updatedConnectionDiagnostic.Traces[0])
	require.Equal(t, "some details", updatedConnectionDiagnostic.Traces[0].Details)
}

// TestDiagnoseSSHConnection runs the connection-diagnostics endpoint against an
// SSH node for a set of success and failure scenarios, comparing the returned
// traces with the expected ones, and finally checks a successful diagnostic
// while per-session MFA is required.
func TestDiagnoseSSHConnection(t *testing.T) {
	ctx := context.Background()

	osUser, err := user.Current()
	require.NoError(t, err)

	osUsername := osUser.Username
	require.NotEmpty(t, osUsername)

	// nodeRole builds a role for username that matches nodes by nodeLabels,
	// allows the given login and grants read/write on connection diagnostics.
	nodeRole := func(username string, nodeLabels types.Labels, login string) []types.Role {
		ret, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
			Allow: types.RoleConditions{
				Namespaces: []string{apidefaults.Namespace},
				NodeLabels: nodeLabels,
				Rules: []types.Rule{
					types.NewRule(types.KindConnectionDiagnostic, services.RW()),
				},
				Logins: []string{login},
			},
		})
		require.NoError(t, err)
		return []types.Role{ret}
	}

	// roleWithFullAccess matches every node via wildcard labels.
	roleWithFullAccess := func(username string, login string) []types.Role {
		return nodeRole(username, types.Labels{types.Wildcard: []string{types.Wildcard}}, login)
	}

	// rolesWithoutAccessToNode only matches nodes labeled forbidden=yes; the
	// "no access to node" case below expects the test node to be denied.
	rolesWithoutAccessToNode := func(username string, login string) []types.Role {
		return nodeRole(username, types.Labels{"forbidden": []string{"yes"}}, login)
	}

	env := newWebPack(t, 1)

	nodeName := env.node.GetInfo().GetHostname()
	// Wait for node to show up
	require.Eventually(t, func() bool {
		_, err := env.server.Auth().GetNode(ctx, apidefaults.Namespace, nodeName)
		if trace.IsNotFound(err) {
			return false
		}
		assert.NoError(t, err, "GetNode returned an unexpected error")
		return true
	}, 5*time.Second, 250*time.Millisecond)

	for _, tt := range []struct {
		name string

		teleportUser string
		roles        []types.Role
		resourceName string
		nodeUser     string
		stopNode     bool

		expectedSuccess bool
		expectedMessage string
		expectedTraces  []types.ConnectionDiagnosticTrace
	}{
		{
			name:            "success",
			roles:           roleWithFullAccess("success", osUsername),
			teleportUser:    "success",
			resourceName:    nodeName,
			nodeUser:        osUsername,
			expectedSuccess: true,
			expectedMessage: "success",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_NODE,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "You have access to the Node.",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Node is alive and reachable.",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "The requested principal is allowed.",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_NODE_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: fmt.Sprintf("%q user exists in target node", osUsername),
				},
			},
		},
		{
			name:            "node not found",
			roles:           roleWithFullAccess("nodenotfound", osUsername),
			teleportUser:    "nodenotfound",
			resourceName:    "notanode",
			nodeUser:        osUsername,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Failed to connect to the Node. Ensure teleport service is running using "systemctl status teleport".`,
					Error:   "Teleport proxy failed to connect to",
				},
			},
		},
		{
			name:            "node not reachable",
			teleportUser:    "nodenotreachable",
			roles:           roleWithFullAccess("nodenotreachable", osUsername),
			resourceName:    nodeName,
			nodeUser:        osUsername,
			stopNode:        true,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Failed to connect to the Node. Ensure teleport service is running using "systemctl status teleport".`,
					Error:   "Teleport proxy failed to connect to",
				},
			},
		},
		{
			name:            "no access to node",
			teleportUser:    "userwithoutaccess",
			roles:           rolesWithoutAccessToNode("userwithoutaccess", osUsername),
			resourceName:    nodeName,
			nodeUser:        osUsername,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_NODE,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: "You are not authorized to access this node. Ensure your role grants access by adding it to the 'node_labels' property.",
					Error:   fmt.Sprintf("user userwithoutaccess@localhost is not authorized to login as %s@localhost: access to node denied", osUsername),
				},
			},
		},
		{
			name:            "selected principal is not part of the allowed principals",
			teleportUser:    "deniedprincipal",
			roles:           roleWithFullAccess("deniedprincipal", "otherprincipal"),
			resourceName:    nodeName,
			nodeUser:        osUsername,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Principal "` + osUsername + `" is not allowed by this certificate. Ensure your roles grants access by adding it to the 'login' property.`,
					Error:   `ssh: principal "` + osUsername + `" not in the set of valid principals for given certificate: ["otherprincipal" "-teleport-internal-join"]`,
				},
			},
		},
		{
			name:            "principal doesnt exist in target host",
			teleportUser:    "principaldoesnotexist",
			roles:           roleWithFullAccess("principaldoesnotexist", "nonvalidlinuxuser"),
			resourceName:    nodeName,
			nodeUser:        "nonvalidlinuxuser",
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_NODE_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Invalid user. Please ensure the principal "nonvalidlinuxuser" is a valid Linux login in the target node. Output from Node: Failed to launch: user:`,
					Error:   "Process exited with status 255",
				},
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			localEnv := env
			if tt.stopNode {
				// Use a dedicated cluster so closing its node leaves the shared env untouched.
				localEnv = newWebPack(t, 1)
				require.NoError(t, localEnv.node.Close())
			}

			clusterName := localEnv.server.ClusterName()
			pack := localEnv.proxies[0].authPack(t, tt.teleportUser, tt.roles)

			createConnectionEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "diagnostics", "connections")

			resp, err := pack.clt.PostJSON(ctx, createConnectionEndpoint, conntest.TestConnectionRequest{
				ResourceKind: types.KindNode,
				ResourceName: tt.resourceName,
				SSHPrincipal: tt.nodeUser,
			})
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.Code())

			var connectionDiagnostic ui.ConnectionDiagnostic
			require.NoError(t, json.Unmarshal(resp.Bytes(), &connectionDiagnostic))

			gotFailedTraces := 0
			expectedFailedTraces := 0

			t.Log(tt.name)
			t.Log(connectionDiagnostic.Message, connectionDiagnostic.Success)
			// Named gotTrace (not trace) so the trace package is not shadowed.
			for i, gotTrace := range connectionDiagnostic.Traces {
				if gotTrace.Status == types.ConnectionDiagnosticTrace_FAILED.String() {
					gotFailedTraces++
				}

				t.Logf("%d status='%s' type='%s' details='%s' error='%s'", i, gotTrace.Status, gotTrace.TraceType, gotTrace.Details, gotTrace.Error)
			}

			require.Equal(t, tt.expectedSuccess, connectionDiagnostic.Success)
			require.Equal(t, tt.expectedMessage, connectionDiagnostic.Message)

			for _, expectedTrace := range tt.expectedTraces {
				if expectedTrace.Status == types.ConnectionDiagnosticTrace_FAILED {
					expectedFailedTraces++
				}

				foundTrace := false
				for _, returnedTrace := range connectionDiagnostic.Traces {
					if expectedTrace.Type.String() != returnedTrace.TraceType {
						continue
					}

					foundTrace = true
					require.Equal(t, expectedTrace.Status.String(), returnedTrace.Status)
					// Node-side messages may carry extra output, so only require a substring match.
					require.Contains(t, returnedTrace.Details, expectedTrace.Details)
					require.Contains(t, returnedTrace.Error, expectedTrace.Error)
				}

				require.True(t, foundTrace, "expected trace %v was not found", expectedTrace)
			}

			require.Equal(t, expectedFailedTraces, gotFailedTraces)
		})
	}

	// Test success with per-session MFA.
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:           constants.Local,
		SecondFactor:   constants.SecondFactorOTP,
		RequireMFAType: types.RequireMFAType_SESSION,
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, ap)
	require.NoError(t, err)

	// Get a totp code to re-auth. The role is named after the user it is granted
	// to, as in TestDiagnoseKubeConnection.
	pack := env.proxies[0].authPack(t, "llama", roleWithFullAccess("llama", osUsername))
	totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
	require.NoError(t, err)

	clusterName := env.server.ClusterName()
	createConnectionEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "diagnostics", "connections")
	resp, err := pack.clt.PostJSON(ctx, createConnectionEndpoint, conntest.TestConnectionRequest{
		ResourceKind: types.KindNode,
		ResourceName: nodeName,
		SSHPrincipal: osUsername,
		MFAResponse:  client.MFAChallengeResponse{TOTPCode: totpCode},
	})
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.Code())

	var connectionDiagnostic ui.ConnectionDiagnostic
	require.NoError(t, json.Unmarshal(resp.Bytes(), &connectionDiagnostic))
	require.True(t, connectionDiagnostic.Success)
}

// TestDiagnoseKubeConnection runs the connection-diagnostics endpoint against a
// Kubernetes cluster backed by a fake API server (which answers with an RBAC
// error when the invalid group is impersonated and with a pod list otherwise),
// comparing the returned traces with the expected ones, and finally checks a
// successful diagnostic while per-session MFA is required.
func TestDiagnoseKubeConnection(t *testing.T) {
	var (
		validKubeUsers              = []string{}
		multiKubeUsers              = []string{"user1", "user2"}
		validKubeGroups             = []string{"validKubeGroup"}
		invalidKubeGroups           = []string{"invalidKubeGroups"}
		kubeClusterName             = "kube_cluster"
		disconnectedKubeClustername = "dis_kube_cluster"
		ctx                         = context.Background()
	)

	// kubeRole builds a role for username that matches kube clusters by
	// kubeLabels, grants the given kube users/groups, allows every pod and
	// grants read/write on connection diagnostics.
	kubeRole := func(username string, kubeLabels types.Labels, kubeUsers, kubeGroups []string) []types.Role {
		ret, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
			Allow: types.RoleConditions{
				Namespaces:       []string{apidefaults.Namespace},
				KubernetesLabels: kubeLabels,
				Rules: []types.Rule{
					types.NewRule(types.KindConnectionDiagnostic, services.RW()),
				},
				KubeGroups: kubeGroups,
				KubeUsers:  kubeUsers,
				KubernetesResources: []types.KubernetesResource{
					{
						Kind: types.KindKubePod, Namespace: types.Wildcard, Name: types.Wildcard,
					},
				},
			},
		})
		require.NoError(t, err)
		return []types.Role{ret}
	}

	// roleWithFullAccess matches every kube cluster via wildcard labels.
	roleWithFullAccess := func(username string, kubeUsers, kubeGroups []string) []types.Role {
		return kubeRole(username, types.Labels{types.Wildcard: []string{types.Wildcard}}, kubeUsers, kubeGroups)
	}

	// rolesWithoutAccessToKubeCluster only matches clusters labeled forbidden=yes.
	rolesWithoutAccessToKubeCluster := func(username string, kubeUsers, kubeGroups []string) []types.Role {
		return kubeRole(username, types.Labels{"forbidden": []string{"yes"}}, kubeUsers, kubeGroups)
	}

	env := newWebPack(t, 1)

	rt := http.NewServeMux()
	rt.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		if slices.Contains(r.Header.Values("Impersonate-Group"), invalidKubeGroups[0]) {
			marshalRBACError(t, w)
			return
		}
		marshalValidPodList(t, w)
	})
	testKube := httptest.NewTLSServer(rt)
	t.Cleanup(testKube.Close)

	startKube(
		ctx,
		t,
		startKubeOptions{
			serviceType: kubeproxy.KubeService,
			authServer:  env.server.TLS,
			clusters: []kubeClusterConfig{
				{
					name:        kubeClusterName,
					apiEndpoint: testKube.URL,
				},
			},
		},
	)

	for _, tt := range []struct {
		name string

		teleportUser string
		roleFunc     func(string, []string, []string) []types.Role
		kubeUsers    []string
		kubeGroups   []string
		resourceName string

		selectedKubeUser   string
		selectedKubeGroups []string

		expectedSuccess  bool
		disconnectedKube bool
		expectedMessage  string
		expectedTraces   []types.ConnectionDiagnosticTrace
	}{
		{
			name:            "kube cluster not found",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      validKubeGroups,
			kubeUsers:       validKubeUsers,
			teleportUser:    "notfound",
			resourceName:    "notregistered",
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Failed to connect to Kubernetes cluster. Ensure the cluster is registered and online.`,
					Error:   "kubernetes cluster \"notregistered\" is not registered or is offline",
				},
			},
		},
		{
			name:             "kube cluster disconnected",
			roleFunc:         roleWithFullAccess,
			kubeGroups:       validKubeGroups,
			kubeUsers:        validKubeUsers,
			teleportUser:     "disconnected",
			resourceName:     disconnectedKubeClustername,
			disconnectedKube: true,
			expectedSuccess:  false,
			expectedMessage:  "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `Failed to connect to Kubernetes cluster. Ensure the cluster is registered and online.`,
					Error:   fmt.Sprintf("kubernetes cluster %q is not registered or is offline", disconnectedKubeClustername),
				},
			},
		},
		{
			name:            "no access to kube cluster",
			teleportUser:    "userwithoutaccess",
			roleFunc:        rolesWithoutAccessToKubeCluster,
			kubeGroups:      validKubeGroups,
			kubeUsers:       validKubeUsers,
			resourceName:    kubeClusterName,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "User-associated roles define valid Kubernetes principals.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_KUBE,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: "You are not authorized to access this Kubernetes Cluster. Ensure your role grants access by adding it to the 'kubernetes_labels' property.",
					Error:   "[00] access denied",
				},
			},
		},
		{
			name:            "no kube principals",
			teleportUser:    "userwithoutprincipals",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      nil,
			kubeUsers:       nil,
			resourceName:    kubeClusterName,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: "User-associated roles do not configure \"kubernetes_groups\" or \"kubernetes_users\". Make sure that at least one is configured for the user.",
					Error: "Your user's Teleport role does not allow Kubernetes access." +
						" Please ask cluster administrator to ensure your role has appropriate kubernetes_groups and kubernetes_users set.",
				},
			},
		},
		{
			name:            "teleport access but Kube RBAC fails",
			teleportUser:    "userbadrbac",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      invalidKubeGroups,
			kubeUsers:       validKubeUsers,
			resourceName:    kubeClusterName,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "User-associated roles define valid Kubernetes principals.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_KUBE_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: "You are not allowed to list pods in the \"default\" namespace. Make sure your \"kubernetes_groups\" or \"kubernetes_users\" exist in the cluster and grant you access to list pods.",
					Error:   "pods is forbidden: User \"USER\" cannot list resource \"pods\" in API group \"\" in the namespace \"default\"",
				},
			},
		},
		{
			name:            "user with multiple defined kube_users",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      validKubeGroups,
			kubeUsers:       multiKubeUsers,
			teleportUser:    "multiuser",
			resourceName:    kubeClusterName,
			expectedSuccess: false,
			expectedMessage: "failed",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `User-associated roles define multiple "kubernetes_users". Make sure that only one value is defined or that you select the target user.`,
					Error:   "please select a user to impersonate, refusing to select a user due to several kubernetes_users set up for this user",
				},
			},
		},
		{
			name:             "user chose to impersonate invalid kube_users",
			roleFunc:         roleWithFullAccess,
			kubeGroups:       validKubeGroups,
			kubeUsers:        multiKubeUsers,
			teleportUser:     "userwithWrongImpUser",
			resourceName:     kubeClusterName,
			expectedSuccess:  false,
			expectedMessage:  "failed",
			selectedKubeUser: "missingUser",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `User-associated roles do now allow the desired "kubernetes_user" impersonation. Please define a "kubernetes_user" that your roles allow to impersonate.`,
					Error:   `impersonation request has been denied, user header "missingUser" is not allowed in roles`,
				},
			},
		},
		{
			name:               "user chose to impersonate invalid kube_group",
			roleFunc:           roleWithFullAccess,
			kubeGroups:         validKubeGroups,
			kubeUsers:          multiKubeUsers,
			teleportUser:       "userwithWrongImpGroup",
			resourceName:       kubeClusterName,
			expectedSuccess:    false,
			expectedMessage:    "failed",
			selectedKubeUser:   "user1",
			selectedKubeGroups: []string{"missingGroup"},
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_FAILED,
					Details: `User-associated roles do now allow the desired "kubernetes_group" impersonation. Please define a "kubernetes_group" that your roles allow to impersonate.`,
					Error:   `impersonation request has been denied, group header "missingGroup" value is not allowed in roles`,
				},
			},
		},
		{
			// Previously shared its name with the "multiple defined kube_users" case
			// above. NOTE(review): the teleport user name suggests this was meant to
			// use multiKubeUsers with a selected user; confirm the intended coverage.
			name:            "success with empty kube_users",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      validKubeGroups,
			kubeUsers:       validKubeUsers,
			teleportUser:    "successwithmultiusers",
			resourceName:    kubeClusterName,
			expectedSuccess: true,
			expectedMessage: "success",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "User-associated roles define valid Kubernetes principals.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_KUBE,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "You are authorized to access this Kubernetes Cluster.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_KUBE_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Access to the Kubernetes Cluster granted.",
					Error:   "",
				},
			},
		},
		{
			name:            "success",
			roleFunc:        roleWithFullAccess,
			kubeGroups:      validKubeGroups,
			kubeUsers:       validKubeUsers,
			teleportUser:    "success",
			resourceName:    kubeClusterName,
			expectedSuccess: true,
			expectedMessage: "success",
			expectedTraces: []types.ConnectionDiagnosticTrace{
				{
					Type:    types.ConnectionDiagnosticTrace_CONNECTIVITY,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Kubernetes Cluster is registered in Teleport.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "User-associated roles define valid Kubernetes principals.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_RBAC_KUBE,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "You are authorized to access this Kubernetes Cluster.",
					Error:   "",
				},
				{
					Type:    types.ConnectionDiagnosticTrace_KUBE_PRINCIPAL,
					Status:  types.ConnectionDiagnosticTrace_SUCCESS,
					Details: "Access to the Kubernetes Cluster granted.",
					Error:   "",
				},
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			if tt.disconnectedKube {
				// Start a kube service for this cluster and shut it down right away,
				// leaving the cluster without a connected service.
				kubeServer, cleanup, _ := startKubeWithoutCleanup(ctx, t, startKubeOptions{
					serviceType: kubeproxy.KubeService,
					authServer:  env.server.TLS,
					clusters: []kubeClusterConfig{
						{
							name:        tt.resourceName,
							apiEndpoint: testKube.URL,
						},
					},
				})
				require.NoError(t, kubeServer.Close())
				require.NoError(t, cleanup())
			}

			clusterName := env.server.ClusterName()
			roles := tt.roleFunc(tt.teleportUser, tt.kubeUsers, tt.kubeGroups)
			pack := env.proxies[0].authPack(t, tt.teleportUser, roles)

			createConnectionEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "diagnostics", "connections")

			resp, err := pack.clt.PostJSON(ctx, createConnectionEndpoint, conntest.TestConnectionRequest{
				ResourceKind: types.KindKubernetesCluster,
				ResourceName: tt.resourceName,
				// Default is 30 seconds but since tests run locally, we can reduce this value to also improve test responsiveness
				DialTimeout: time.Second,
				KubernetesImpersonation: conntest.KubernetesImpersonation{
					KubernetesUser:   tt.selectedKubeUser,
					KubernetesGroups: tt.selectedKubeGroups,
				},
			})
			require.NoError(t, err)
			require.Equal(t, http.StatusOK, resp.Code())

			var connectionDiagnostic ui.ConnectionDiagnostic
			require.NoError(t, json.Unmarshal(resp.Bytes(), &connectionDiagnostic))

			gotFailedTraces := 0
			expectedFailedTraces := 0

			t.Log(tt.name)
			t.Log(connectionDiagnostic.Message, connectionDiagnostic.Success)
			// Named gotTrace (not trace) so the trace package is not shadowed.
			for i, gotTrace := range connectionDiagnostic.Traces {
				if gotTrace.Status == types.ConnectionDiagnosticTrace_FAILED.String() {
					gotFailedTraces++
				}

				t.Logf("%d status='%s' type='%s' details='%s' error='%s'", i, gotTrace.Status, gotTrace.TraceType, gotTrace.Details, gotTrace.Error)
			}

			require.Equal(t, tt.expectedSuccess, connectionDiagnostic.Success)
			require.Equal(t, tt.expectedMessage, connectionDiagnostic.Message)

			for _, expectedTrace := range tt.expectedTraces {
				if expectedTrace.Status == types.ConnectionDiagnosticTrace_FAILED {
					expectedFailedTraces++
				}

				foundTrace := false
				for _, returnedTrace := range connectionDiagnostic.Traces {
					if expectedTrace.Type.String() != returnedTrace.TraceType {
						continue
					}

					foundTrace = true
					require.Equal(t, expectedTrace.Status.String(), returnedTrace.Status)
					require.Equal(t, expectedTrace.Details, returnedTrace.Details)
					require.Contains(t, returnedTrace.Error, expectedTrace.Error)
				}

				require.True(t, foundTrace, "expected trace %v was not found", expectedTrace)
			}

			require.Equal(t, expectedFailedTraces, gotFailedTraces)
		})
	}

	// Test success with per-session MFA.
	ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:           constants.Local,
		SecondFactor:   constants.SecondFactorOTP,
		RequireMFAType: types.RequireMFAType_SESSION,
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, ap)
	require.NoError(t, err)

	// Get a totp code to re-auth.
	pack := env.proxies[0].authPack(t, "llama", roleWithFullAccess("llama", validKubeUsers, validKubeGroups))
	totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
	require.NoError(t, err)

	clusterName := env.server.ClusterName()
	createConnectionEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "diagnostics", "connections")
	resp, err := pack.clt.PostJSON(ctx, createConnectionEndpoint, conntest.TestConnectionRequest{
		ResourceKind: types.KindKubernetesCluster,
		ResourceName: kubeClusterName,
		// Default is 30 seconds but since tests run locally, we can reduce this value to also improve test responsiveness
		DialTimeout:             time.Second,
		KubernetesImpersonation: conntest.KubernetesImpersonation{},
		MFAResponse:             client.MFAChallengeResponse{TOTPCode: totpCode},
	})
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, resp.Code())

	var connectionDiagnostic ui.ConnectionDiagnostic
	require.NoError(t, json.Unmarshal(resp.Bytes(), &connectionDiagnostic))
	require.True(t, connectionDiagnostic.Success)
}

// TestCreateDatabase checks the create-database endpoint: required-field
// validation, rejection of a duplicate name with 409, persistence of the
// requested fields and labels, and the contents of a successful response.
func TestCreateDatabase(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	username := "someuser"
	roleCreateDatabase, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
		Allow: types.RoleConditions{
			DatabaseNames: []string{"name1"},
			DatabaseUsers: []string{"user1"},
			Rules: []types.Rule{
				types.NewRule(types.KindDatabase, []string{types.VerbCreate}),
			},
			DatabaseLabels: types.Labels{
				types.Wildcard: {types.Wildcard},
			},
		},
	})
	require.NoError(t, err)

	env := newWebPack(t, 1)
	clusterName := env.server.ClusterName()
	pack := env.proxies[0].authPack(t, username, []types.Role{roleCreateDatabase})

	createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")

	// Create an initial database to table test a duplicate creation
	_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
		Name:     "duplicatedb",
		Protocol: "mysql",
		URI:      "someuri:3306",
	})
	require.NoError(t, err)

	for _, tt := range []struct {
		name           string
		req            createDatabaseRequest
		expectedStatus int
		errAssert      require.ErrorAssertionFunc
	}{
		{
			name: "valid",
			req: createDatabaseRequest{
				Name:     "mydatabase",
				Protocol: "mysql",
				URI:      "someuri:3306",
				Labels: []ui.Label{
					{
						Name:  "teleport.dev/origin",
						Value: "dynamic",
					},
				},
			},
			expectedStatus: http.StatusOK,
			errAssert:      require.NoError,
		},
		{
			name: "valid with labels",
			req: createDatabaseRequest{
				Name:     "dbwithlabels",
				Protocol: "mysql",
				URI:      "someuri:3306",
				Labels: []ui.Label{
					{
						Name:  "env",
						Value: "prod",
					},
					{
						Name:  "teleport.dev/origin",
						Value: "dynamic",
					},
				},
			},
			expectedStatus: http.StatusOK,
			errAssert:      require.NoError,
		},
		{
			name: "empty name",
			req: createDatabaseRequest{
				Name:     "",
				Protocol: "mysql",
				URI:      "someuri:3306",
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing database name")
			},
		},
		{
			name: "empty protocol",
			req: createDatabaseRequest{
				Name:     "emptyprotocol",
				Protocol: "",
				URI:      "someuri:3306",
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing protocol")
			},
		},
		{
			name: "empty uri",
			req: createDatabaseRequest{
				Name:     "emptyuri",
				Protocol: "mysql",
				URI:      "",
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing uri")
			},
		},
		{
			name: "missing port",
			req: createDatabaseRequest{
				Name:     "missingport",
				Protocol: "mysql",
				URI:      "someuri",
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing port in address")
			},
		},
		{
			name: "duplicatedb",
			req: createDatabaseRequest{
				Name:     "duplicatedb",
				Protocol: "mysql",
				URI:      "someuri:3306",
			},
			expectedStatus: http.StatusConflict,
			// Assert through the TestingT passed in, not the enclosing test.
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.True(tb, trace.IsAlreadyExists(err), "expected already exists error, got %v", err)
				require.Contains(tb, err.Error(), `failed to create database ("duplicatedb" already exists), please use another name`)
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			// Create database
			resp, err := pack.clt.PostJSON(ctx, createDatabaseEndpoint, tt.req)
			tt.errAssert(t, err)

			require.Equal(t, tt.expectedStatus, resp.Code(), "invalid status code received")

			if err != nil {
				return
			}

			// Ensure database exists
			database, err := env.proxies[0].client.GetDatabase(ctx, tt.req.Name)
			require.NoError(t, err)

			require.Equal(t, tt.req.Name, database.GetName())
			require.Equal(t, tt.req.Protocol, database.GetProtocol())
			require.Equal(t, tt.req.URI, database.GetURI())

			// At least the provided labels exist in the database resource
			databaseLabels := database.GetAllLabels()
			for _, label := range tt.req.Labels {
				require.Contains(t, databaseLabels, label.Name, "label not found")
				require.Equal(t, label.Value, databaseLabels[label.Name], "label exists but has unexpected value")
			}

			// Check response value:
			if tt.expectedStatus == http.StatusOK {
				result := ui.Database{}
				require.NoError(t, json.Unmarshal(resp.Bytes(), &result))

				require.Equal(t, ui.Database{
					Name:          tt.req.Name,
					Protocol:      tt.req.Protocol,
					Type:          types.DatabaseTypeSelfHosted,
					Labels:        tt.req.Labels,
					Hostname:      "someuri",
					DatabaseUsers: []string{"user1"},
					DatabaseNames: []string{"name1"},
					URI:           "someuri:3306",
				}, result)
			}
		})
	}
}

// TestUpdateDatabase_Errors checks that the update-database endpoint rejects
// invalid or empty update requests with 400 and a descriptive error.
func TestUpdateDatabase_Errors(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	databaseName := "somedb"
	username := "someuser"
	roleCreateUpdateDatabase, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
		Allow: types.RoleConditions{
			Rules: []types.Rule{
				types.NewRule(types.KindDatabase, []string{types.VerbCreate, types.VerbUpdate, types.VerbRead}),
			},
			DatabaseLabels: types.Labels{
				types.Wildcard: {types.Wildcard},
			},
		},
	})
	require.NoError(t, err)

	env := newWebPack(t, 1)
	clusterName := env.server.ClusterName()
	pack := env.proxies[0].authPack(t, username, []types.Role{roleCreateUpdateDatabase})

	// Create database
	createDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases")
	_, err = pack.clt.PostJSON(ctx, createDatabaseEndpoint, createDatabaseRequest{
		Name:     databaseName,
		Protocol: "mysql",
		URI:      "someuri:3306",
	})
	require.NoError(t, err)

	// Every case targets the same database, so build the endpoint once.
	updateDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases", databaseName)

	for _, tt := range []struct {
		name           string
		req            updateDatabaseRequest
		expectedStatus int
		errAssert      require.ErrorAssertionFunc
	}{
		{
			name: "empty ca_cert",
			req: updateDatabaseRequest{
				CACert: strPtr(""),
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing CA certificate data")
			},
		},
		{
			name: "invalid certificate",
			req: updateDatabaseRequest{
				CACert: strPtr("Not a certificate"),
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "could not parse provided CA as X.509 PEM certificate")
			},
		},
		{
			name: "invalid awsRDS missing resourceID field",
			req: updateDatabaseRequest{
				AWSRDS: &awsRDS{
					AccountID: "123123123123",
				},
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing aws rds field resource id")
			},
		},
		{
			name: "invalid awsRDS missing accountID field",
			req: updateDatabaseRequest{
				AWSRDS: &awsRDS{
					ResourceID: "123123123123",
				},
			},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing aws rds field account id")
			},
		},
		{
			name:           "no fields defined",
			req:            updateDatabaseRequest{},
			expectedStatus: http.StatusBadRequest,
			errAssert: func(tb require.TestingT, err error, _ ...interface{}) {
				require.ErrorContains(tb, err, "missing fields to update the database")
			},
		},
	} {
		t.Run(tt.name, func(t *testing.T) {
			// Update database
			resp, err := pack.clt.PutJSON(ctx, updateDatabaseEndpoint, tt.req)
			tt.errAssert(t, err)

			require.Equal(t, tt.expectedStatus, resp.Code(), "invalid status code received")
		})
	}
}

func TestUpdateDatabase_NonErrors(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	databaseName := "somedb"
	username := "someuser"
	roleCreateUpdateDatabase, err := types.NewRole(services.RoleNameForUser(username), types.RoleSpecV6{
		Allow: types.RoleConditions{
			Rules: []types.Rule{
				types.NewRule(types.KindDatabase, []string{types.VerbCreate, types.VerbUpdate, types.VerbRead}),
			},
			DatabaseLabels: types.Labels{
				types.Wildcard: {types.Wildcard},
			},
		},
	})
	require.NoError(t, err)

	env := newWebPack(t, 1)
	clusterName := env.server.ClusterName()
	pack := env.proxies[0].authPack(t, username, []types.Role{roleCreateUpdateDatabase})

	// Create a database.
	dbProtocol := "mysql"
	database, err := getNewDatabaseResource(createDatabaseRequest{
		Name:     databaseName,
		Protocol: dbProtocol,
		URI:      "someuri:3306",
	})
	require.NoError(t, err)
	require.NoError(t, env.server.Auth().CreateDatabase(ctx, database))

	requiredOriginLabel := ui.Label{Name: types.OriginLabel, Value: types.OriginDynamic}

	// Each test case builds on top of each other.
for _, tt := range []struct { name string req updateDatabaseRequest expectedFields ui.Database expectedAWSRDS awsRDS }{ { name: "update caCert", req: updateDatabaseRequest{ CACert: &fakeValidTLSCert, }, expectedFields: ui.Database{ Name: databaseName, Protocol: dbProtocol, Type: "self-hosted", Hostname: "someuri", Labels: []ui.Label{requiredOriginLabel}, URI: "someuri:3306", }, }, { name: "update URI", req: updateDatabaseRequest{ URI: "something-else:3306", }, expectedFields: ui.Database{ Name: databaseName, Protocol: dbProtocol, Type: "self-hosted", Hostname: "something-else", Labels: []ui.Label{requiredOriginLabel}, URI: "something-else:3306", }, }, { name: "update aws rds fields", req: updateDatabaseRequest{ URI: "llama.cgi8.us-west-2.rds.amazonaws.com:3306", AWSRDS: &awsRDS{ AccountID: "123123123123", ResourceID: "db-1234", }, }, expectedAWSRDS: awsRDS{ AccountID: "123123123123", ResourceID: "db-1234", }, expectedFields: ui.Database{ Name: databaseName, Protocol: dbProtocol, Type: "rds", Hostname: "llama.cgi8.us-west-2.rds.amazonaws.com", Labels: []ui.Label{requiredOriginLabel}, URI: "llama.cgi8.us-west-2.rds.amazonaws.com:3306", AWS: &ui.AWS{ AWS: types.AWS{ Region: "us-west-2", AccountID: "123123123123", RDS: types.RDS{ ResourceID: "db-1234", InstanceID: "llama", }, }, }, }, }, { name: "update labels", req: updateDatabaseRequest{ Labels: []ui.Label{{Name: "env", Value: "prod"}}, }, expectedAWSRDS: awsRDS{ AccountID: "123123123123", ResourceID: "db-1234", }, expectedFields: ui.Database{ Name: databaseName, Protocol: dbProtocol, Type: "rds", Hostname: "llama.cgi8.us-west-2.rds.amazonaws.com", Labels: []ui.Label{{Name: "env", Value: "prod"}, requiredOriginLabel}, URI: "llama.cgi8.us-west-2.rds.amazonaws.com:3306", AWS: &ui.AWS{ AWS: types.AWS{ Region: "us-west-2", AccountID: "123123123123", RDS: types.RDS{ ResourceID: "db-1234", InstanceID: "llama", }, }, }, }, }, { name: "update multiple fields", req: updateDatabaseRequest{ URI: 
"alpaca.cgi8.us-east-1.rds.amazonaws.com:3306", AWSRDS: &awsRDS{ AccountID: "000000000000", ResourceID: "db-0000", }, }, expectedAWSRDS: awsRDS{ AccountID: "000000000000", ResourceID: "db-0000", }, expectedFields: ui.Database{ Name: databaseName, Protocol: dbProtocol, Type: "rds", Hostname: "alpaca.cgi8.us-east-1.rds.amazonaws.com", Labels: []ui.Label{{Name: "env", Value: "prod"}, requiredOriginLabel}, URI: "alpaca.cgi8.us-east-1.rds.amazonaws.com:3306", AWS: &ui.AWS{ AWS: types.AWS{ Region: "us-east-1", AccountID: "000000000000", RDS: types.RDS{ ResourceID: "db-0000", InstanceID: "alpaca", }, }, }, }, }, } { t.Run(tt.name, func(t *testing.T) { updateDatabaseEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "databases", databaseName) resp, err := pack.clt.PutJSON(ctx, updateDatabaseEndpoint, tt.req) require.NoError(t, err) var dbResp ui.Database require.NoError(t, json.Unmarshal(resp.Bytes(), &dbResp)) require.Equal(t, tt.expectedFields, dbResp) // Ensure database was updated database, err := env.proxies[0].client.GetDatabase(ctx, databaseName) require.NoError(t, err) require.Equal(t, database.GetCA(), fakeValidTLSCert) // should not have changed require.Equal(t, database.GetType(), tt.expectedFields.Type) require.Equal(t, database.GetProtocol(), tt.expectedFields.Protocol) require.Equal(t, database.GetURI(), fmt.Sprintf("%s:3306", tt.expectedFields.Hostname)) require.Equal(t, database.GetAWS().AccountID, tt.expectedAWSRDS.AccountID) require.Equal(t, database.GetAWS().RDS.ResourceID, tt.expectedAWSRDS.ResourceID) }) } } type authProviderMock struct { server types.ServerV2 } func (mock authProviderMock) GetNodes(ctx context.Context, n string) ([]types.Server, error) { return []types.Server{&mock.server}, nil } func (mock authProviderMock) GetSessionEvents(n string, s session.ID, c int) ([]events.EventFields, error) { return []events.EventFields{}, nil } func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID string) 
(types.SessionTracker, error) { return nil, trace.NotFound("foo") } func (mock authProviderMock) IsMFARequired(ctx context.Context, req *authproto.IsMFARequiredRequest) (*authproto.IsMFARequiredResponse, error) { return nil, nil } func (mock authProviderMock) GenerateUserSingleUseCerts(ctx context.Context) (authproto.AuthService_GenerateUserSingleUseCertsClient, error) { return nil, nil } func (mock authProviderMock) GenerateOpenSSHCert(ctx context.Context, req *authproto.OpenSSHCertRequest) (*authproto.OpenSSHCert, error) { return nil, nil } func (mock authProviderMock) MaintainSessionPresence(ctx context.Context) (authproto.AuthService_MaintainSessionPresenceClient, error) { return nil, nil } func (mock authProviderMock) GetUser(_ string, _ bool) (types.User, error) { return nil, nil } func (mock authProviderMock) GetRole(_ context.Context, _ string) (types.Role, error) { return nil, nil } type terminalOpt func(t *TerminalRequest) func withSessionID(sid session.ID) terminalOpt { return func(t *TerminalRequest) { t.SessionID = sid } } func withServer(target string) terminalOpt { return func(t *TerminalRequest) { t.Server = target } } func withKeepaliveInterval(d time.Duration) terminalOpt { return func(t *TerminalRequest) { t.KeepAliveInterval = d } } func withParticipantMode(m types.SessionParticipantMode) terminalOpt { return func(t *TerminalRequest) { t.ParticipantMode = m } } func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOpt) (*websocket.Conn, *session.Session, error) { req := TerminalRequest{ Server: s.srvID, Login: pack.login, Term: session.TerminalParams{ W: 100, H: 100, }, } for _, opt := range opts { opt(&req) } u := url.URL{ Host: s.url().Host, Scheme: client.WSS, Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), } data, err := json.Marshal(req) if err != nil { return nil, nil, err } q := u.Query() q.Set("params", string(data)) q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = 
q.Encode() dialer := websocket.Dialer{} dialer.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } header := http.Header{} header.Add("Origin", "http://localhost") for _, cookie := range pack.cookies { header.Add("Cookie", cookie.String()) } ws, resp, err := dialer.Dial(u.String(), header) if err != nil { return nil, nil, trace.Wrap(err) } ty, raw, err := ws.ReadMessage() if err != nil { return nil, nil, trace.Wrap(err) } require.Equal(t, websocket.BinaryMessage, ty) var env Envelope err = proto.Unmarshal(raw, &env) if err != nil { return nil, nil, trace.Wrap(err) } var sessResp siteSessionGenerateResponse err = json.Unmarshal([]byte(env.Payload), &sessResp) if err != nil { return nil, nil, trace.Wrap(err) } err = resp.Body.Close() if err != nil { return nil, nil, trace.Wrap(err) } return ws, &sessResp.Session, nil } func waitForOutput(r io.Reader, substr string) error { timeoutCh := time.After(10 * time.Second) var prev string out := make([]byte, int64(len(substr)*2)) for { select { case <-timeoutCh: return trace.BadParameter("timeout waiting on terminal for output: %v", substr) default: } n, err := r.Read(out) outStr := removeSpace(string(out[:n])) // Check for [substr] before checking the error, // as it's valid for n > 0 even when there is an error. // The [substr] is checked against the current and previous // output to account for scenarios where the [substr] is split // across two reads. While we try to prevent this by reading // twice the length of [substr] there are no guarantees the // whole thing will arrive in a single read. if n > 0 && strings.Contains(prev+outStr, substr) { return nil } if err != nil { return trace.Wrap(err) } prev = outStr } } func (s *WebSuite) client(t *testing.T, opts ...roundtrip.ClientParam) *TestWebClient { opts = append(opts, roundtrip.HTTPClient(client.NewInsecureWebClient())) wc, err := client.NewWebClient(s.url().String(), opts...) 
if err != nil { panic(err) } return &TestWebClient{wc, t} } type TestWebClient struct { *client.WebClient t *testing.T } // It is understood that implementing RoundTrip here will NOT result in calls from `Get`, or `PostJSON` from // client.WebClient getting this verification. Those functions would additionally need to be specified here. // Despite that, currently our use of RoundTrip directly is providing us enough broad coverage to verify these headers. func (c *TestWebClient) RoundTrip(fn roundtrip.RoundTripFn) (*roundtrip.Response, error) { c.t.Helper() resp, err := c.WebClient.RoundTrip(fn) verifySecurityResponseHeaders(c.t, resp.Headers()) return resp, err } func (s *WebSuite) login(clt *TestWebClient, cookieToken string, reqToken string, reqData interface{}) (*roundtrip.Response, error) { return httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { data, err := json.Marshal(reqData) if err != nil { return nil, err } req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(data)) if err != nil { return nil, err } addCSRFCookieToReq(req, cookieToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, reqToken) return clt.HTTPClient().Do(req) })) } func (s *WebSuite) loginMFA(clt *TestWebClient, reqData *client.MFAChallengeRequest, device *mocku2f.Key) (*roundtrip.Response, error) { resp, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { data, err := json.Marshal(reqData) if err != nil { return nil, trace.Wrap(err) } req, err := http.NewRequest("POST", clt.Endpoint("webapi", "mfa", "login", "begin"), bytes.NewBuffer(data)) if err != nil { return nil, trace.Wrap(err) } req.Header.Set("Content-Type", "application/json") resp, err := clt.HTTPClient().Do(req) return resp, trace.Wrap(err) })) if err != nil { return nil, trace.Wrap(err) } var challenge client.MFAAuthenticateChallenge err = json.Unmarshal(resp.Bytes(), &challenge) if err != nil { return 
nil, trace.Wrap(err) } car, err := device.SignAssertion("https://localhost", challenge.WebauthnChallenge) if err != nil { return nil, trace.Wrap(err) } return httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { respData := &client.AuthenticateWebUserRequest{ User: reqData.User, WebauthnAssertionResponse: car, } data, err := json.Marshal(respData) if err != nil { return nil, trace.Wrap(err) } req, err := http.NewRequest("POST", clt.Endpoint("webapi", "mfa", "login", "finishsession"), bytes.NewBuffer(data)) if err != nil { return nil, trace.Wrap(err) } req.Header.Set("Content-Type", "application/json") resp, err := clt.HTTPClient().Do(req) return resp, trace.Wrap(err) })) } func (s *WebSuite) url() *url.URL { u, err := url.Parse("https://" + s.webServer.Listener.Addr().String()) if err != nil { panic(err) } return u } func addCSRFCookieToReq(req *http.Request, token string) { cookie := &http.Cookie{ Name: csrf.CookieName, Value: token, } req.AddCookie(cookie) } func removeSpace(in string) string { for _, c := range []string{"\n", "\r", "\t"} { in = strings.Replace(in, c, " ", -1) } return strings.TrimSpace(in) } func decodeSessionCookie(t *testing.T, value string) (sessionID string) { sessionBytes, err := hex.DecodeString(value) require.NoError(t, err) var cookie struct { User string `json:"user"` SessionID string `json:"sid"` } require.NoError(t, json.Unmarshal(sessionBytes, &cookie)) return cookie.SessionID } func (r CreateSessionResponse) response() (*CreateSessionResponse, error) { return &CreateSessionResponse{TokenType: r.TokenType, Token: r.Token, TokenExpiresIn: r.TokenExpiresIn, SessionInactiveTimeoutMS: r.SessionInactiveTimeoutMS}, nil } func newWebPack(t *testing.T, numProxies int, opts ...proxyOption) *webPack { ctx := context.Background() clock := clockwork.NewFakeClockAt(time.Now()) server, err := auth.NewTestServer(auth.TestServerConfig{ Auth: auth.TestAuthServerConfig{ ClusterName: "localhost", Dir: t.TempDir(), Clock: clock, 
AuditLog: events.NewDiscardAuditLog(), }, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, server.Shutdown(ctx)) }) // use a sync recording mode because the disk-based uploader // that runs in the background introduces races with test cleanup recConfig := types.DefaultSessionRecordingConfig() recConfig.SetMode(types.RecordAtNodeSync) err = server.AuthServer.AuthServer.SetSessionRecordingConfig(context.Background(), recConfig) require.NoError(t, err) // Register the auth server, since test auth server doesn't start its own // heartbeat. err = server.Auth().UpsertAuthServer(ctx, &types.ServerV2{ Kind: types.KindAuthServer, Version: types.V2, Metadata: types.Metadata{ Namespace: apidefaults.Namespace, Name: "auth", }, Spec: types.ServerSpecV2{ Addr: server.TLS.Listener.Addr().String(), Hostname: "localhost", Version: teleport.Version, }, }) require.NoError(t, err) priv, pub, err := testauthority.New().GenerateKeyPair() require.NoError(t, err) tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) require.NoError(t, err) const nodeID = "node" // start auth server certs, err := server.Auth().GenerateHostCerts(ctx, &authproto.HostCertsRequest{ HostID: hostID, NodeName: nodeID, Role: types.RoleNode, PublicSSHKey: pub, PublicTLSKey: tlsPub, }) require.NoError(t, err) signer, err := sshutils.NewSigner(priv, certs.SSH) require.NoError(t, err) hostSigners := []ssh.Signer{signer} nodeClient, err := server.TLS.NewClient(auth.TestIdentity{ I: authz.BuiltinRole{ Role: types.RoleNode, Username: nodeID, }, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, nodeClient.Close()) }) nodeLockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentNode, Client: nodeClient, }, }) require.NoError(t, err) t.Cleanup(nodeLockWatcher.Close) nodeSessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ Semaphores: nodeClient, AccessPoint: 
nodeClient, LockEnforcer: nodeLockWatcher, Emitter: nodeClient, Component: teleport.ComponentNode, ServerID: nodeID, }) require.NoError(t, err) // create SSH service: nodeDataDir := t.TempDir() node, err := regular.New( ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, nodeID, hostSigners, nodeClient, nodeDataDir, "", utils.NetAddr{}, nodeClient, regular.SetUUID(nodeID), regular.SetNamespace(apidefaults.Namespace), regular.SetShell("/bin/sh"), regular.SetEmitter(nodeClient), regular.SetPAMConfig(&servicecfg.PAMConfig{Enabled: false}), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(nodeLockWatcher), regular.SetSessionController(nodeSessionController), ) require.NoError(t, err) require.NoError(t, node.Start()) t.Cleanup(func() { require.NoError(t, node.Close()) }) var proxies []*testProxy for p := 0; p < numProxies; p++ { proxyID := fmt.Sprintf("proxy%v", p) proxies = append(proxies, createProxy(ctx, t, proxyID, node, server.TLS, hostSigners, clock, opts...)) } // Wait for proxies to fully register before starting the test. 
for start := time.Now(); ; { proxies, err := proxies[0].client.GetProxies() require.NoError(t, err) if len(proxies) == numProxies { break } if time.Since(start) > 5*time.Second { t.Fatalf("Proxies didn't register within 5s after startup; registered: %d, want: %d", len(proxies), numProxies) } } return &webPack{ proxies: proxies, server: server, node: node, clock: clock, } } type proxyConfig struct { minimalHandler bool } type proxyOption func(cfg *proxyConfig) func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regular.Server, authServer *auth.TestTLSServer, hostSigners []ssh.Signer, clock clockwork.FakeClock, opts ...proxyOption, ) *testProxy { cfg := proxyConfig{} for _, opt := range opts { opt(&cfg) } // create reverse tunnel service: client, err := authServer.NewClient(auth.TestIdentity{ I: authz.BuiltinRole{ Role: types.RoleProxy, Username: proxyID, }, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, client.Close()) }) revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", authServer.ClusterName())) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunListener.Close()) }) proxyLockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: client, }, }) require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) proxyCAWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: client, }, Types: []types.CertAuthType{types.HostCA, types.UserCA}, }) require.NoError(t, err) t.Cleanup(proxyLockWatcher.Close) proxyNodeWatcher, err := services.NewNodeWatcher(ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, Client: client, }, }) require.NoError(t, err) t.Cleanup(proxyNodeWatcher.Close) revTunServer, 
err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, ClientTLS: client.TLSConfig(), ClusterName: authServer.ClusterName(), HostSigners: hostSigners, LocalAuthClient: client, LocalAccessPoint: client, Emitter: client, NewCachingAccessPoint: noCache, DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, CertAuthorityWatcher: proxyCAWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{authServer.Listener.Addr().String()}, }) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, revTunServer.Close()) }) router, err := proxy.NewRouter(proxy.RouterConfig{ ClusterName: authServer.ClusterName(), Log: utils.NewLoggerForTests().WithField(trace.Component, "test"), RemoteClusterGetter: client, SiteGetter: revTunServer, TracerProvider: tracing.NoopProvider(), }) require.NoError(t, err) sessionController, err := srv.NewSessionController(srv.SessionControllerConfig{ Semaphores: client, AccessPoint: client, LockEnforcer: proxyLockWatcher, Emitter: client, Component: teleport.ComponentProxy, ServerID: proxyID, }) require.NoError(t, err) proxyServer, err := regular.New( ctx, utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, authServer.ClusterName(), hostSigners, client, t.TempDir(), "", utils.NetAddr{AddrNetwork: "tcp", Addr: "proxy-1.example.com:443"}, client, regular.SetUUID(proxyID), regular.SetProxyMode("", revTunServer, client, router), regular.SetEmitter(client), regular.SetNamespace(apidefaults.Namespace), regular.SetBPF(&bpf.NOP{}), regular.SetRestrictedSessionManager(&restricted.NOP{}), regular.SetClock(clock), regular.SetLockWatcher(proxyLockWatcher), regular.SetNodeWatcher(proxyNodeWatcher), regular.SetSessionController(sessionController), regular.SetPublicAddrs([]utils.NetAddr{{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}}), ) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, proxyServer.Close()) }) fs, err := newDebugFileSystem() 
require.NoError(t, err) handler, err := NewHandler(Config{ Proxy: revTunServer, AuthServers: utils.FromAddr(authServer.Addr()), DomainName: authServer.ClusterName(), ProxyClient: client, ProxyPublicAddrs: utils.MustParseAddrList("proxy-1.example.com", "proxy-2.example.com"), CipherSuites: utils.DefaultCipherSuites(), AccessPoint: client, Context: ctx, HostUUID: proxyID, Emitter: client, StaticFS: fs, ProxySettings: &mockProxySettings{}, SessionControl: SessionControllerFunc(func(ctx context.Context, sctx *SessionContext, login, localAddr, remoteAddr string) (context.Context, error) { controller := srv.WebSessionController(sessionController) ctx, err := controller(ctx, sctx, login, localAddr, remoteAddr) return ctx, trace.Wrap(err) }), Router: router, HealthCheckAppServer: func(context.Context, string, string) error { return nil }, MinimalReverseTunnelRoutesOnly: cfg.minimalHandler, }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(clock)) require.NoError(t, err) webServer := httptest.NewTLSServer(handler) t.Cleanup(webServer.Close) require.NoError(t, proxyServer.Start()) proxyAddr := utils.MustParseAddr(proxyServer.Addr()) addr := utils.MustParseAddr(webServer.Listener.Addr().String()) handler.handler.cfg.ProxyWebAddr = *addr handler.handler.cfg.ProxySSHAddr = *proxyAddr _, sshPort, err := net.SplitHostPort(proxyAddr.String()) require.NoError(t, err) handler.handler.sshPort = sshPort kubeProxyAddr := startKube( ctx, t, startKubeOptions{ serviceType: kubeproxy.ProxyService, authServer: authServer, revTunnel: revTunServer, }, ) handler.handler.cfg.ProxyKubeAddr = utils.FromAddr(kubeProxyAddr) url, err := url.Parse("https://" + webServer.Listener.Addr().String()) require.NoError(t, err) handler.handler.cfg.PublicProxyAddr = url.String() return &testProxy{ clock: clock, auth: authServer, client: client, revTun: revTunServer, node: node, proxy: proxyServer, web: webServer, handler: handler, webURL: *url, } } // webPack represents the state of a single web 
test. // It replicates most of the WebSuite and serves to gradually // transition the test suite to use the testing package // directly. type webPack struct { proxies []*testProxy server *auth.TestServer node *regular.Server clock clockwork.FakeClock } type testProxy struct { clock clockwork.FakeClock client auth.ClientI auth *auth.TestTLSServer revTun reversetunnelclient.Server node *regular.Server proxy *regular.Server handler *APIHandler web *httptest.Server webURL url.URL } // authPack returns new authenticated package consisting of created valid // user, otp token, created web session and authenticated client. func (r *testProxy) authPack(t *testing.T, teleportUser string, roles []types.Role) *authPack { ctx := context.Background() const ( pass = "abc123" rawSecret = "def456" ) u, err := user.Current() require.NoError(t, err) loginUser := u.Username otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOTP, }) require.NoError(t, err) err = r.auth.Auth().SetAuthPreference(ctx, ap) require.NoError(t, err) r.createUser(context.Background(), t, teleportUser, loginUser, pass, otpSecret, roles) // create a valid otp token validToken, err := totp.GenerateCode(otpSecret, r.clock.Now()) require.NoError(t, err) clt := r.newClient(t) req := CreateSessionReq{ User: teleportUser, Pass: pass, SecondFactorToken: validToken, } csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" resp := login(t, clt, csrfToken, csrfToken, req) var rawSession *CreateSessionResponse require.NoError(t, json.Unmarshal(resp.Bytes(), &rawSession)) session, err := rawSession.response() require.NoError(t, err) jar, err := cookiejar.New(nil) require.NoError(t, err) clt = r.newClient(t, roundtrip.BearerAuth(session.Token), roundtrip.CookieJar(jar)) jar.SetCookies(&r.webURL, resp.Cookies()) return &authPack{ otpSecret: otpSecret, user: 
teleportUser, login: loginUser, session: session, clt: clt, cookies: resp.Cookies(), password: pass, } } func (r *testProxy) authPackFromPack(t *testing.T, pack *authPack) *authPack { jar, err := cookiejar.New(nil) require.NoError(t, err) clt := r.newClient(t, roundtrip.BearerAuth(pack.session.Token), roundtrip.CookieJar(jar)) jar.SetCookies(&r.webURL, pack.cookies) result := *pack result.clt = clt return &result } func (r *testProxy) authPackFromResponse(t *testing.T, httpResp *roundtrip.Response) *authPack { var resp *CreateSessionResponse require.NoError(t, json.Unmarshal(httpResp.Bytes(), &resp)) jar, err := cookiejar.New(nil) require.NoError(t, err) clt := r.newClient(t, roundtrip.BearerAuth(resp.Token), roundtrip.CookieJar(jar)) jar.SetCookies(&r.webURL, httpResp.Cookies()) session, err := resp.response() require.NoError(t, err) if session.TokenExpiresIn < 0 { t.Errorf("Expected expiry time to be in the future but got %v", session.TokenExpiresIn) } return &authPack{ session: session, clt: clt, cookies: httpResp.Cookies(), } } func defaultRoleForNewUser(teleUser types.User, login string) types.Role { role := services.RoleForUser(teleUser) role.SetLogins(types.Allow, []string{login}) role.SetWindowsDesktopLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}}) options := role.GetOptions() options.ForwardAgent = types.NewBool(true) role.SetOptions(options) return role } func (r *testProxy) createUser(ctx context.Context, t *testing.T, user, login, pass, otpSecret string, roles []types.Role) { teleUser, err := types.NewUser(user) require.NoError(t, err) if len(roles) == 0 { roles = []types.Role{defaultRoleForNewUser(teleUser, login)} } for _, role := range roles { err = r.auth.Auth().UpsertRole(ctx, role) require.NoError(t, err) teleUser.AddRole(role.GetName()) } teleUser.SetCreatedBy(types.CreatedBy{ User: types.UserRef{Name: "some-auth-user"}, }) err = r.auth.Auth().CreateUser(ctx, teleUser) require.NoError(t, err) err = 
r.auth.Auth().UpsertPassword(user, []byte(pass)) require.NoError(t, err) if otpSecret != "" { dev, err := services.NewTOTPDevice("otp", otpSecret, r.clock.Now()) require.NoError(t, err) err = r.auth.Auth().UpsertMFADevice(ctx, user, dev) require.NoError(t, err) } } func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *TestWebClient { opts = append(opts, roundtrip.HTTPClient(client.NewInsecureWebClient())) clt, err := client.NewWebClient(r.webURL.String(), opts...) require.NoError(t, err) return &TestWebClient{clt, t} } func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) (*websocket.Conn, session.Session) { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), } requestData := TerminalRequest{ Server: r.node.ID(), Login: pack.login, Term: session.TerminalParams{ W: 100, H: 100, }, } if sessionID != "" { requestData.SessionID = sessionID } data, err := json.Marshal(requestData) require.NoError(t, err) q := u.Query() q.Set("params", string(data)) q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} dialer.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } header := http.Header{} header.Add("Origin", "http://localhost") for _, cookie := range pack.cookies { header.Add("Cookie", cookie.String()) } ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) }) ty, raw, err := ws.ReadMessage() require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) var env Envelope require.NoError(t, proto.Unmarshal(raw, &env)) var sessResp siteSessionGenerateResponse require.NoError(t, json.Unmarshal([]byte(env.Payload), &sessResp)) return ws, sessResp.Session } func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { 
u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} dialer.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, } header := http.Header{} for _, cookie := range pack.cookies { header.Add("Cookie", cookie.String()) } ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) }) return ws } func login(t *testing.T, clt *TestWebClient, cookieToken, reqToken string, reqData interface{}) *roundtrip.Response { resp, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { data, err := json.Marshal(reqData) if err != nil { return nil, err } req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions", "web"), bytes.NewBuffer(data)) if err != nil { return nil, err } addCSRFCookieToReq(req, cookieToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, reqToken) return clt.HTTPClient().Do(req) })) require.NoError(t, err) return resp } func validateTerminalStream(t *testing.T, ws *websocket.Conn) { t.Helper() stream := NewTerminalStream(context.Background(), ws, utils.NewLoggerForTests()) // here we intentionally run a command where the output we're looking // for is not present in the command itself _, err := io.WriteString(stream, "echo txlxport | sed 's/x/e/g'\r\n") require.NoError(t, err) require.NoError(t, waitForOutput(stream, "teleport")) } type mockProxySettings struct { mockedGetProxySettings func(ctx context.Context) (*webclient.ProxySettings, error) } func (mock *mockProxySettings) GetProxySettings(ctx context.Context) (*webclient.ProxySettings, error) { if mock.mockedGetProxySettings != nil { return 
mock.mockedGetProxySettings(ctx)
	}
	return &webclient.ProxySettings{}, nil
}

// GetOpenAIAPIKey returns a dummy OpenAI API key.
func (mock *mockProxySettings) GetOpenAIAPIKey() string {
	return "test-key"
}

// TestUserContextWithAccessRequest checks that the userContext includes the ID of the
// access request after it has been consumed and the web session has been renewed.
func TestUserContextWithAccessRequest(t *testing.T) {
	t.Parallel()
	env := newWebPack(t, 1)
	proxy := env.proxies[0]
	ctx := context.Background()

	// Set user and role names.
	username := "user"
	baseRoleName := "role"
	requestableRolename := "requestable-role"

	// Create user's base role with the ability to request the requestable role.
	baseRole, err := types.NewRole(baseRoleName, types.RoleSpecV6{
		Allow: types.RoleConditions{
			Request: &types.AccessRequestConditions{
				Roles: []string{requestableRolename},
			},
		},
	})
	require.NoError(t, err)

	// Create user with the base role.
	pack := proxy.authPack(t, username, []types.Role{baseRole})

	// Create the requestable role.
	requestableRole, err := types.NewRole(requestableRolename, types.RoleSpecV6{})
	require.NoError(t, err)
	err = env.server.Auth().UpsertRole(ctx, requestableRole)
	require.NoError(t, err)

	identity := tlsca.Identity{
		Expires: env.clock.Now().Add(1 * time.Hour),
	}

	// Create and approve an access request for the requestable role.
	accessReq, err := services.NewAccessRequest(username, requestableRolename)
	require.NoError(t, err)
	accessReq.SetState(types.RequestState_APPROVED)
	err = env.server.Auth().CreateAccessRequest(ctx, accessReq, identity)
	require.NoError(t, err)

	// Get the ID of the created and approved access request.
	accessRequestID := accessReq.GetMetadata().Name

	// Make a request to renew the session with the ID of the access request.
	_, err = pack.clt.PostJSON(ctx, pack.clt.Endpoint("webapi", "sessions", "web", "renew"), renewSessionRequest{
		AccessRequestID: accessRequestID,
	})
	require.NoError(t, err)

	// Make a request to fetch the userContext.
	endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "context")
	response, err := pack.clt.Get(ctx, endpoint, url.Values{})
	require.NoError(t, err)

	// Process the JSON response of the request.
	var userContext ui.UserContext
	err = json.Unmarshal(response.Bytes(), &userContext)
	require.NoError(t, err)

	// Verify that the userContext returned contains the correct Access Request ID.
	require.Equal(t, accessRequestID, userContext.ConsumedAccessRequestID)
}

// TestIsMFARequired_AcceptedRequests mostly tests that requests
// are formatted correctly.
func TestIsMFARequired_AcceptedRequests(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	env := newWebPack(t, 1)

	proxy := env.proxies[0]
	pack := proxy.authPack(t, "llama", nil /* roles */)

	// Require MFA for every session cluster-wide, so every well-formed
	// request below is expected to report Required == true.
	cfg, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
		Type:           constants.Local,
		SecondFactor:   constants.SecondFactorOTP,
		RequireMFAType: types.RequireMFAType_SESSION,
	})
	require.NoError(t, err)
	err = env.server.Auth().SetAuthPreference(ctx, cfg)
	require.NoError(t, err)

	for _, test := range []struct {
		name       string
		errMsg     string
		getRequest func() isMFARequiredRequest
	}{
		{
			name: "valid db req",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{
					Database: &isMFARequiredDatabase{
						ServiceName: "name",
						Protocol:    "protocol",
					},
				}
			},
		},
		{
			name:   "invalid db req",
			errMsg: "missing service_name",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{Database: &isMFARequiredDatabase{}}
			},
		},
		{
			name: "valid node req",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{
					Node: &isMFARequiredNode{
						NodeName: "name",
						Login:    "login",
					},
				}
			},
		},
		{
			name:   "invalid node req",
			errMsg: "missing login",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{Node: &isMFARequiredNode{}}
			},
		},
		{
			name: "valid kube req",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{
					Kube: &isMFARequiredKube{
						ClusterName: "name",
					},
				}
			},
		},
		{
			name:   "invalid kube req",
			errMsg: "missing cluster_name",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{Kube: &isMFARequiredKube{}}
			},
		},
		{
			name: "valid windows desktop req",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{
					WindowsDesktop: &isMFARequiredWindowsDesktop{
						DesktopName: "name",
						Login:       "login",
					},
				}
			},
		},
		{
			name:   "invalid windows desktop req",
			errMsg: "missing desktop_name",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{WindowsDesktop: &isMFARequiredWindowsDesktop{}}
			},
		},
		{
			name:   "invalid empty req",
			errMsg: "missing target",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{}
			},
		},
		{
			name:   "invalid multi field",
			errMsg: "only one target is allowed",
			getRequest: func() isMFARequiredRequest {
				return isMFARequiredRequest{
					Kube: &isMFARequiredKube{
						ClusterName: "name",
					},
					Node: &isMFARequiredNode{
						NodeName: "name",
						Login:    "login",
					},
				}
			},
		},
	} {
		test := test
		t.Run(test.name, func(t *testing.T) {
			endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "mfa", "required")
			re, err := pack.clt.PostJSON(ctx, endpoint, test.getRequest())
			if test.errMsg != "" {
				require.True(t, trace.IsBadParameter(err), "isMFARequired returned err = %v (%T), wanted trace.BadParameter", err, err)
				require.ErrorContains(t, err, test.errMsg)
				return
			}

			require.NoError(t, err)
			resp := isMfaRequiredResponse{}
			require.NoError(t, json.Unmarshal(re.Bytes(), &resp))
			require.True(t, resp.Required, "isMFARequired returned response with unexpected value for Required field")
		})
	}
}

// TestWithLimiterHandlerFunc checks that a handler wrapped with
// WithLimiterHandlerFunc accepts a full burst from one IP and rejects the next
// request with a trace.LimitExceeded error.
func TestWithLimiterHandlerFunc(t *testing.T) {
	const burst = 20
	// Named rateLimiter so it does not shadow the limiter package.
	rateLimiter, err := limiter.NewRateLimiter(limiter.Config{
		Rates: []limiter.Rate{
			{
				Period:  time.Minute,
				Average: 10,
				Burst:   burst,
			},
		},
		// A frozen clock keeps the bucket from refilling during the test.
		Clock: &timetools.FreezedTime{
			CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
		},
	})
	require.NoError(t, err)
	h := &Handler{limiter: rateLimiter}

	hf := h.WithLimiterHandlerFunc(func(http.ResponseWriter, *http.Request, httprouter.Params) (interface{}, error) {
		return nil, nil
	})

	// Verify that a valid burst is allowed.
	r := &http.Request{}
	for i := 0; i < burst; i++ {
		r.RemoteAddr = fmt.Sprintf("127.0.0.1:%v", i)
		_, err = hf(nil, r, nil)
		require.NoError(t, err, "WithLimiterHandlerFunc failed unexpectedly")
	}

	// Verify that exceeding the limit causes errors.
	r.RemoteAddr = fmt.Sprintf("127.0.0.1:%v", burst)
	_, err = hf(nil, r, nil)
	require.True(t, trace.IsLimitExceeded(err), "WithLimiterHandlerFunc returned err = %T, want trace.LimitExceededError", err)
}

// kubeClusterConfig defines the cluster to be created
type kubeClusterConfig struct {
	name        string
	apiEndpoint string
}

// newKubeConfigFile writes a kubeconfig file to a fresh temporary directory and
// returns its path. Each entry in clusters gets a cluster (TLS verification
// disabled), an empty auth info and a context, all keyed by the cluster name.
// ctx is currently unused.
func newKubeConfigFile(ctx context.Context, t *testing.T, clusters ...kubeClusterConfig) string {
	tmpDir := t.TempDir()

	kubeConf := clientcmdapi.NewConfig()
	for _, cluster := range clusters {
		kubeConf.Clusters[cluster.name] = &clientcmdapi.Cluster{
			Server:                cluster.apiEndpoint,
			InsecureSkipTLSVerify: true,
		}
		kubeConf.AuthInfos[cluster.name] = &clientcmdapi.AuthInfo{}

		kubeConf.Contexts[cluster.name] = &clientcmdapi.Context{
			Cluster:  cluster.name,
			AuthInfo: cluster.name,
		}
	}
	kubeConfigLocation := filepath.Join(tmpDir, "kubeconfig")
	err := clientcmd.WriteToFile(*kubeConf, kubeConfigLocation)
	require.NoError(t, err)
	return kubeConfigLocation
}

// startKubeOptions configures startKube and startKubeWithoutCleanup.
type startKubeOptions struct {
	clusters    []kubeClusterConfig
	authServer  *auth.TestTLSServer
	revTunnel   reversetunnelclient.Server
	serviceType kubeproxy.KubeServiceType
}

// startKube starts a kube TLS server and registers a test cleanup that closes
// it and asserts that Serve returned without an unexpected error.
func startKube(ctx context.Context, t *testing.T, cfg startKubeOptions) net.Addr {
	server, cleanup, addr := startKubeWithoutCleanup(ctx, t, cfg)
	t.Cleanup(func() {
		err := server.Close()
		require.NoError(t, err)
		require.NoError(t, cleanup())
	})
	return addr
}

// cleanupFunc waits for a started server to stop and returns its Serve error.
type cleanupFunc func() error

// startKubeWithoutCleanup starts a kube TLS server (proxy or kube service
// depending on cfg.serviceType) on a random localhost port. It blocks until the
// kube server watcher is initialized and len(cfg.clusters) heartbeats have been
// observed. The caller must Close the returned server and then call the
// returned cleanupFunc, which yields the Serve error (nil when the server was
// closed normally).
func startKubeWithoutCleanup(ctx context.Context, t *testing.T, cfg startKubeOptions) (*kubeproxy.TLSServer, cleanupFunc, net.Addr) {
	role := types.RoleProxy
	if cfg.serviceType == kubeproxy.KubeService {
		role = types.RoleKube
	}

	var kubeConfigLocation string
	if len(cfg.clusters) > 0 {
		kubeConfigLocation = newKubeConfigFile(ctx, t, cfg.clusters...)
	}

	keyGen := tlsutils.New(ctx)
	hostID := uuid.New().String()
	// heartbeatsWaitChannel waits for clusters heartbeats to start.
	heartbeatsWaitChannel := make(chan struct{}, len(cfg.clusters))
	client, err := cfg.authServer.NewClient(auth.TestServerID(role, hostID))
	require.NoError(t, err)

	// Auth client, lock watcher and authorizer for Kube proxy.
	proxyAuthClient, err := cfg.authServer.NewClient(auth.TestBuiltin(types.RoleProxy))
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, proxyAuthClient.Close()) })

	proxyLockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{
		ResourceWatcherConfig: services.ResourceWatcherConfig{
			Component: teleport.ComponentProxy,
			Client:    proxyAuthClient,
		},
	})
	require.NoError(t, err)
	proxyAuthorizer, err := authz.NewAuthorizer(authz.AuthorizerOpts{
		ClusterName: cfg.authServer.ClusterName(),
		AccessPoint: proxyAuthClient,
		LockWatcher: proxyLockWatcher,
	})
	require.NoError(t, err)

	// TLS config for kube proxy and Kube service.
	authID := auth.IdentityID{
		Role:     role,
		HostUUID: hostID,
		NodeName: "kube_server",
	}
	dns := []string{"localhost", "127.0.0.1", "kube." + constants.APIDomain, "*" + constants.APIDomain}
	identity, err := auth.LocalRegister(authID, cfg.authServer.Auth(), nil, dns, "", nil)
	require.NoError(t, err)
	tlsConfig, err := identity.TLSConfig(nil)
	require.NoError(t, err)

	component := teleport.Component(teleport.ComponentProxy, teleport.ComponentProxyKube)
	if cfg.serviceType == kubeproxy.KubeService {
		component = teleport.ComponentKube
	}

	clock := clockwork.NewRealClock()
	watcher, err := services.NewKubeServerWatcher(ctx, services.KubeServerWatcherConfig{
		ResourceWatcherConfig: services.ResourceWatcherConfig{
			Component: component,
			Log:       log,
			Client:    client,
			Clock:     clock,
		},
	})
	require.NoError(t, err)

	forwarderConfig := kubeproxy.ForwarderConfig{
		Namespace:         apidefaults.Namespace,
		Keygen:            keyGen,
		ClusterName:       cfg.authServer.ClusterName(),
		Authz:             proxyAuthorizer,
		AuthClient:        client,
		Emitter:           client,
		DataDir:           t.TempDir(),
		CachingAuthClient: client,
		HostID:            hostID,
		Context:           ctx,
		KubeconfigPath:    kubeConfigLocation,
		KubeServiceType:   cfg.serviceType,
		Component:         component,
		LockWatcher:       proxyLockWatcher,
		ReverseTunnelSrv:  cfg.revTunnel,
		// skip Impersonation validation
		CheckImpersonationPermissions: func(ctx context.Context, clusterName string, sarClient authztypes.SelfSubjectAccessReviewInterface) error {
			return nil
		},
		ConnTLSConfig: tlsConfig,
		Clock:         clock,
		ClusterFeatures: func() authproto.Features {
			return authproto.Features{
				Kubernetes: true,
			}
		},
	}
	// The kube service must get no PROXY signer. Assign the field only for the
	// other service types: assigning a nil *mockPROXYSigner to the field would
	// store a typed nil pointer, which is a non-nil interface value.
	if cfg.serviceType != kubeproxy.KubeService {
		forwarderConfig.PROXYSigner = &mockPROXYSigner{}
	}

	kubeServer, err := kubeproxy.NewTLSServer(kubeproxy.TLSServerConfig{
		ForwarderConfig: forwarderConfig,
		TLS:             tlsConfig,
		AccessPoint:     client,
		DynamicLabels:   nil,
		LimiterConfig: limiter.Config{
			MaxConnections:   1000,
			MaxNumberOfUsers: 1000,
		},
		// each time heartbeat is called we insert data into the channel.
		// this is used to make sure that heartbeat started and the clusters
		// are registered in the auth server
		OnHeartbeat: func(err error) {
			select {
			case heartbeatsWaitChannel <- struct{}{}:
			default:
			}
		},
		GetRotation: func(role types.SystemRole) (*types.Rotation, error) {
			return &types.Rotation{}, nil
		},
		ResourceMatchers:         nil,
		OnReconcile:              func(kc types.KubeClusters) {},
		KubernetesServersWatcher: watcher,
	})
	require.NoError(t, err)

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	errChan := make(chan error, 1)
	go func() {
		// Closing errChan makes the cleanupFunc return nil after a normal shutdown.
		defer close(errChan)
		err := kubeServer.Serve(listener)
		// ignore server closed error returned when .Close is called.
		if errors.Is(err, http.ErrServerClosed) {
			return
		}
		errChan <- err
	}()

	// wait for the watcher to init or it may race with test cleanup.
	require.NoError(t, watcher.WaitInitialization())
	// Waits for len(clusters) heartbeats to start
	heartbeatsToExpect := len(cfg.clusters)

	for i := 0; i < heartbeatsToExpect; i++ {
		<-heartbeatsWaitChannel
	}

	return kubeServer, func() error { return <-errChan }, listener.Addr()
}

// marshalRBACError writes a metav1.Status describing a "pods is forbidden"
// error to w.
// NOTE(review): the body carries Code 403 / StatusReasonForbidden, but the HTTP
// status written is 500. Presumably this is deliberate, to exercise parsing of
// the error body independently of the status code; confirm with callers.
func marshalRBACError(t *testing.T, w http.ResponseWriter) {
	status := &metav1.Status{
		Message: "pods is forbidden: User \"USER\" cannot list resource \"pods\" in API group \"\" in the namespace \"default\"",
		Code:    http.StatusForbidden,
		Reason:  metav1.StatusReasonForbidden,
		Status:  metav1.StatusFailure,
	}
	data, err := runtime.Encode(statusCodecs.LegacyCodec(), status)
	require.NoError(t, err)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusInternalServerError)
	_, err = w.Write(data)
	require.NoError(t, err)
}

// marshalValidPodList writes an empty v1 PodList as JSON with status 200 OK.
func marshalValidPodList(t *testing.T, w http.ResponseWriter) {
	result := &corev1.PodList{
		TypeMeta: metav1.TypeMeta{
			Kind:       "PodList",
			APIVersion: "v1",
		},
		ListMeta: metav1.ListMeta{
			SelfLink:           "",
			ResourceVersion:    "1231415",
			Continue:           "",
			RemainingItemCount: nil,
		},
		Items: []corev1.Pod{},
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	err := json.NewEncoder(w).Encode(result)
	require.NoError(t, err)
}

// statusScheme is private scheme for the decoding here until someone fixes the TODO in NewConnection
var statusScheme = runtime.NewScheme()

// statusCodecs is the codec factory for statusScheme; marshalRBACError uses its
// LegacyCodec to encode metav1.Status bodies.
var statusCodecs = serializer.NewCodecFactory(statusScheme)

func init() {
	// Register metav1.Status as an unversioned type so statusCodecs can encode it.
	statusScheme.AddUnversionedTypes(metav1.SchemeGroupVersion,
		&metav1.Status{},
	)
}

// TestForwardingTraces checks that the web traces handler rejects empty or
// malformed payloads with 400 Bad Request. It also checks that valid spans,
// with either base64 (protojson) or hex (opentelemetry-js) encoded IDs, are
// forwarded to the configured trace client.
func TestForwardingTraces(t *testing.T) {
	t.Parallel()
	env := newWebPack(t, 1)
	p := env.proxies[0]

	newRequest := func(t *testing.T) *http.Request {
		req, err := http.NewRequest(http.MethodGet, "", nil)
		require.NoError(t, err)
		return req
	}

	// Span captured from the UI which was marshaled by opentelemetry-js.
	const rawSpan = `{"resourceSpans":[{"resource":{"attributes":[{"key":"service.name","value":{"stringValue":"web-ui"}},{"key":"telemetry.sdk.language","value":{"stringValue":"webjs"}},{"key":"telemetry.sdk.name","value":{"stringValue":"opentelemetry"}},{"key":"telemetry.sdk.version","value":{"stringValue":"1.7.0"}},{"key":"service.version","value":{"stringValue":"0.1.0"}}],"droppedAttributesCount":0},"scopeSpans":[{"scope":{"name":"@opentelemetry/instrumentation-fetch","version":"0.33.0"},"spans":[{"traceId":"255c8d876e7dbf3707ee8451ad518652","spanId":"d9edec516e598d8c","name":"HTTP GET","kind":3,"startTimeUnixNano":1668606426497000000,"endTimeUnixNano":1668502943215499800,"attributes":[{"key":"component","value":{"stringValue":"fetch"}},{"key":"http.method","value":{"stringValue":"GET"}},{"key":"http.url","value":{"stringValue":"https://proxy.example.com/v1/webapi/user/status"}},{"key":"http.status_code","value":{"intValue":0}},{"key":"http.status_text","value":{"stringValue":"Failed to fetch"}},{"key":"http.host","value":{"stringValue":"proxy.example.com"}},{"key":"http.scheme","value":{"stringValue":"https"}},{"key":"http.user_agent","value":{"stringValue":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36 "}},{"key":"http.response_content_length","value":{"intValue":0}}],"droppedAttributesCount":0,"events":[{"attributes":[],"name":"fetchStart","timeUnixNano":1668502943210900000,"droppedAttributesCount":0},{"attributes":[],"name":"domainLookupStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"domainLookupEnd","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"connectStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"secureConnectionStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"connectEnd","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"requestStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"responseStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"responseEnd","timeUnixNano":1668502943215100000,"droppedAttributesCount":0}],"droppedEventsCount":0,"status":{"code":0},"links":[],"droppedLinksCount":0}]}]}]}`

	// dummy span with arbitrary data, needed to be able to protojson.Marshal in tests
	span := &tracepb.TracesData{
		ResourceSpans: []*tracepb.ResourceSpans{
			{
				Resource: &resourcev1.Resource{
					Attributes: []*commonv1.KeyValue{
						{
							Key: "test",
							Value: &commonv1.AnyValue{
								Value: &commonv1.AnyValue_IntValue{
									IntValue: 0,
								},
							},
						},
					},
				},
				ScopeSpans: []*tracepb.ScopeSpans{
					{
						Spans: []*tracepb.Span{
							{
								TraceId:      []byte{1, 2, 3, 4},
								SpanId:       []byte{5, 6, 7, 8},
								TraceState:   "",
								ParentSpanId: []byte{9, 10, 11, 12},
								Name:         "test",
								Kind:         tracepb.Span_SPAN_KIND_CLIENT,
								// The *UnixNano fields hold nanoseconds since the epoch.
								StartTimeUnixNano: uint64(time.Now().Add(-1 * time.Minute).UnixNano()),
								EndTimeUnixNano:   uint64(time.Now().UnixNano()),
								Attributes: []*commonv1.KeyValue{
									{
										Key: "test",
										Value: &commonv1.AnyValue{
											Value: &commonv1.AnyValue_IntValue{
												IntValue: 11,
											},
										},
									},
								},
								Status: &tracepb.Status{
									Message: "success!",
									Code:    tracepb.Status_STATUS_CODE_OK,
								},
							},
						},
					},
				},
			},
		},
	}

	cases := []struct {
		name      string
		req       func(t *testing.T) *http.Request
		assertion func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int)
	}{
		{
			name: "no data",
			req: func(t *testing.T) *http.Request {
				r := newRequest(t)
				r.Body = io.NopCloser(&bytes.Buffer{})
				return r
			},
			assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) {
				require.NoError(t, err)
				require.Equal(t, http.StatusBadRequest, code)
				require.Empty(t, spans)
			},
		},
		{
			name: "invalid data",
			req: func(t *testing.T) *http.Request {
				r := newRequest(t)
				r.Body = io.NopCloser(strings.NewReader(`{"test": "abc"}`))
				return r
			},
			assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) {
				require.NoError(t, err)
				require.Equal(t, http.StatusBadRequest, code)
				require.Empty(t, spans)
			},
		},
		{
			name: "no traces",
			req: func(t *testing.T) *http.Request {
				r := newRequest(t)
				raw, err := protojson.Marshal(&tracepb.ResourceSpans{})
				require.NoError(t, err)
				r.Body = io.NopCloser(bytes.NewBuffer(raw))
				return r
			},
			assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) {
				require.NoError(t, err)
				require.Equal(t, http.StatusBadRequest, code)
				require.Empty(t, spans)
			},
		},
		{
			name: "traces with base64 encoded ids",
			req: func(t *testing.T) *http.Request {
				r := newRequest(t)

				// Since the id fields of the span are all []byte,
				// protojson will marshal them into base64
				raw, err := protojson.Marshal(span)
				require.NoError(t, err)
				r.Body = io.NopCloser(bytes.NewBuffer(raw))
				return r
			},
			assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) {
				require.NoError(t, err)
				require.Equal(t, http.StatusOK, code)
				require.Len(t, spans, 1)
				require.Empty(t, cmp.Diff(span.ResourceSpans[0], spans[0], protocmp.Transform()))
			},
		},
		{
			name: "traces with hex encoded ids",
			req: func(t *testing.T) *http.Request {
				r := newRequest(t)

				// The id fields are hex encoded instead of base64 encoded
				// by opentelemetry-js for the rawSpan
				r.Body = io.NopCloser(strings.NewReader(rawSpan))
				return r
			},
			assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) {
				require.NoError(t, err)
				require.Equal(t, http.StatusOK, code)
				require.Len(t, spans, 1)
				var data tracepb.TracesData
				require.NoError(t, protojson.Unmarshal([]byte(rawSpan), &data))

				// compare the spans, but ignore the ids since we know that the rawSpan
				// has hex encoded ids and protojson.Unmarshal will give us an invalid value
				require.Empty(t, cmp.Diff(data.ResourceSpans[0], spans[0], protocmp.Transform(), protocmp.IgnoreFields(&tracepb.Span{}, "span_id", "trace_id")))

				// compare the ids separately
				sid1 := spans[0].ScopeSpans[0].Spans[0].SpanId
				tid1 := spans[0].ScopeSpans[0].Spans[0].TraceId

				sid2 := data.ResourceSpans[0].ScopeSpans[0].Spans[0].SpanId
				tid2 := data.ResourceSpans[0].ScopeSpans[0].Spans[0].TraceId

				require.Equal(t, hex.EncodeToString(sid1), base64.StdEncoding.EncodeToString(sid2))
				require.Equal(t, hex.EncodeToString(tid1), base64.StdEncoding.EncodeToString(tid2))
			},
		},
	}

	// NOTE: resetting the tracing client prevents
	// the test cases from running in parallel
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			clt := &mockTraceClient{
				uploadReceived: make(chan struct{}),
			}
			p.handler.handler.cfg.TraceClient = clt

			recorder := httptest.NewRecorder()

			// use the handler directly because there is no easy way to pipe in our tracing
			// data using the pack client in a format that would match the ui.
			_, err := p.handler.handler.traces(recorder, tt.req(t), nil, nil)

			// if traces weren't uploaded perform the assertion
			// without waiting for traces to be forwarded
			if err != nil || recorder.Code != http.StatusOK {
				tt.assertion(t, clt.spans, err, recorder.Code)
				return
			}

			// traces are forwarded in a goroutine, wait for them
			// to be received by the trace client before doing the
			// assertion
			select {
			case <-clt.uploadReceived:
			case <-time.After(10 * time.Second):
				t.Fatal("Timed out waiting for traces to be uploaded")
			}

			tt.assertion(t, clt.spans, err, recorder.Code)
		})
	}
}

// mockPROXYSigner is a PROXY header signer that always returns a nil header.
type mockPROXYSigner struct{}

// SignPROXYHeader returns (nil, nil) regardless of the addresses given.
func (m *mockPROXYSigner) SignPROXYHeader(source, destination net.Addr) ([]byte, error) {
	return nil, nil
}

// mockTraceClient records uploaded spans, signals each upload on
// uploadReceived and returns uploadError from UploadTraces.
type mockTraceClient struct {
	uploadError    error
	uploadReceived chan struct{}
	spans          []*tracepb.ResourceSpans
}

// Start is a no-op.
func (m *mockTraceClient) Start(ctx context.Context) error {
	return nil
}

// Stop is a no-op.
func (m *mockTraceClient) Stop(ctx context.Context) error {
	return nil
}

// UploadTraces appends protoSpans to m.spans and then sends on
// m.uploadReceived. With an unbuffered channel, as TestForwardingTraces
// creates, the send blocks until the test receives it.
func (m *mockTraceClient) UploadTraces(ctx context.Context, protoSpans []*tracepb.ResourceSpans) error {
	m.spans = append(m.spans, protoSpans...)
	m.uploadReceived <- struct{}{}
	return m.uploadError
}

// TestLogout checks that logging out on one proxy invalidates the session there
// immediately. A second proxy sharing the session keeps it until its expiration
// loop purges it.
func TestLogout(t *testing.T) {
	ctx := context.Background()
	t.Parallel()
	env := newWebPack(t, 2)

	// create a logged in user for proxy 1
	pack := env.proxies[0].authPack(t, "llama", nil /* roles */)

	// ensure the client is authenticated
	re, err := pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "sites"), url.Values{})
	require.NoError(t, err)
	var clusters []ui.Cluster
	require.NoError(t, json.Unmarshal(re.Bytes(), &clusters))
	require.Len(t, clusters, 1)

	// create a client for proxy 2 with the token and cookies from proxy 1
	jar, err := cookiejar.New(nil)
	require.NoError(t, err)
	jar.SetCookies(&env.proxies[1].webURL, pack.cookies)
	clt2 := env.proxies[1].newClient(t, roundtrip.BearerAuth(pack.session.Token), roundtrip.CookieJar(jar))

	// ensure the second client is authenticated
	re, err = clt2.Get(ctx, clt2.Endpoint("webapi", "sites"), url.Values{})
	require.NoError(t, err)
	require.NoError(t, json.Unmarshal(re.Bytes(), &clusters))
	require.Len(t, clusters, 1)

	// logout from proxy 1
	_, err = pack.clt.Delete(ctx, pack.clt.Endpoint("webapi", "sessions", "web"))
	require.NoError(t, err)

	// ensure proxy 1 invalidated the session
	_, err = pack.clt.Get(ctx, pack.clt.Endpoint("webapi", "sites"), url.Values{})
	require.Error(t, err)
	require.ErrorIs(t, err, trace.AccessDenied("missing session cookie"))

	// should still be authenticated to proxy 2 until the expiration loop kicks in
	re, err = clt2.Get(ctx, clt2.Endpoint("webapi", "sites"), url.Values{})
	require.NoError(t, err)
	require.NoError(t, json.Unmarshal(re.Bytes(), &clusters))
	require.Len(t, clusters, 1)

	// advance the clock to fire the expiration ticker
	env.clock.Advance(time.Second)

	// wait for the expiration loop to purge the session
	require.Eventually(t, func() bool {
		return env.proxies[1].handler.handler.auth.ActiveSessions() == 0
	}, 5*time.Second, 100*time.Millisecond)

	// ensure proxy 2 invalidated the session
	_, err = clt2.Get(ctx, clt2.Endpoint("webapi", "sites"), url.Values{})
	require.True(t, trace.IsAccessDenied(err))
	require.ErrorIs(t, err, trace.AccessDenied("need auth"))
}

// TestGetIsDashboard checks isDashboard across every combination of the Cloud
// and RecoveryCodes feature flags.
func TestGetIsDashboard(t *testing.T) {
	tt := []struct {
		name     string
		features authproto.Features
		expected bool
	}{
		{
			name: "not cloud nor recovery codes is not dashboard",
			features: authproto.Features{
				Cloud:         false,
				RecoveryCodes: false,
			},
			expected: false,
		},
		{
			name: "not cloud, with recovery codes is dashboard",
			features: authproto.Features{
				Cloud:         false,
				RecoveryCodes: true,
			},
			expected: true,
		},
		{
			name: "cloud, with recovery codes is not dashboard",
			features: authproto.Features{
				Cloud:         true,
				RecoveryCodes: true,
			},
			expected: false,
		},
		{
			name: "cloud, without recovery codes is not dashboard",
			features: authproto.Features{
				Cloud:         true,
				RecoveryCodes: false,
			},
			expected: false,
		},
	}

	for _, tc := range tt {
		t.Run(tc.name, func(t *testing.T) {
			result := isDashboard(tc.features)
			require.Equal(t, tc.expected, result)
		})
	}
}

// initGRPCServer creates a grpc server serving on the provided listener.
// The server requires TLS client certificates, authenticates callers with
// auth.Middleware and registers fakeKubeService. A test cleanup stops it
// gracefully and asserts that Serve returned no error.
func initGRPCServer(t *testing.T, env *webPack, listener net.Listener) {
	clusterName := env.server.ClusterName()

	// Proxy auth client; copyAndConfigureTLS uses it to fetch the cluster CAs
	// that verify client certificates.
	proxyAuthClient, err := env.server.TLS.NewClient(auth.TestBuiltin(types.RoleProxy))
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, proxyAuthClient.Close()) })

	serverIdentity, err := auth.NewServerIdentity(env.server.Auth(), uuid.NewString(), types.RoleProxy)
	require.NoError(t, err)
	tlsConfig, err := serverIdentity.TLSConfig(nil)
	require.NoError(t, err)

	// Named connLimiter so it does not shadow the limiter package.
	connLimiter, err := limiter.NewLimiter(limiter.Config{MaxConnections: 100})
	require.NoError(t, err)

	// authMiddleware authenticates request assuming TLS client authentication
	// adds authentication information to the context
	// and passes it to the API server
	authMiddleware := &auth.Middleware{
		ClusterName:   clusterName,
		Limiter:       connLimiter,
		AcceptedUsage: []string{teleport.UsageKubeOnly},
	}

	tlsConf := copyAndConfigureTLS(tlsConfig, logrus.New(), proxyAuthClient, clusterName)
	creds, err := auth.NewTransportCredentials(auth.TransportCredentialsConfig{
		TransportCredentials: credentials.NewTLS(tlsConf),
		UserGetter:           authMiddleware,
	})
	require.NoError(t, err)

	grpcServer := grpc.NewServer(
		grpc.ChainUnaryInterceptor(
			authMiddleware.UnaryInterceptor(),
		),
		grpc.ChainStreamInterceptor(
			authMiddleware.StreamInterceptor(),
		),
		grpc.Creds(creds),
	)

	kubeproto.RegisterKubeServiceServer(grpcServer, &fakeKubeService{})

	errC := make(chan error, 1)
	t.Cleanup(func() {
		grpcServer.GracefulStop()
		require.NoError(t, <-errC)
	})

	go func() {
		err := grpcServer.Serve(listener)
		errC <- trace.Wrap(err)
	}()
}

// copyAndConfigureTLS can be used to copy and modify an existing *tls.Config
// for Teleport application proxy servers.
func copyAndConfigureTLS(config *tls.Config, log logrus.FieldLogger, accessPoint auth.AccessCache, clusterName string) *tls.Config {
	tlsConfig := config.Clone()

	// Require clients to present a certificate
	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert

	// Configure function that will be used to fetch the CA that signed the
	// client's certificate to verify the chain presented. If the client does not
	// pass in the cluster name, this function pulls back all CAs to try and
	// match the certificate presented against any CA.
	tlsConfig.GetConfigForClient = auth.WithClusterCAs(tlsConfig.Clone(), accessPoint, clusterName, log)

	return tlsConfig
}

// fakeKubeService is a KubeService gRPC server that serves a fixed set of
// Kubernetes resources.
type fakeKubeService struct {
	kubeproto.UnimplementedKubeServiceServer
}

// ListKubernetesResources ignores req and returns two fixed pods, test-pod and
// test-pod2, in the default namespace, with TotalCount 2.
func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kubeproto.ListKubernetesResourcesRequest) (*kubeproto.ListKubernetesResourcesResponse, error) {
	return &kubeproto.ListKubernetesResourcesResponse{
		Resources: []*types.KubernetesResourceV1{
			{
				Kind: types.KindKubePod,
				Metadata: types.Metadata{
					Name: "test-pod",
					Labels: map[string]string{
						"app": "test",
					},
				},
				Spec: types.KubernetesResourceSpecV1{
					Namespace: "default",
				},
			},
			{
				Kind: types.KindKubePod,
				Metadata: types.Metadata{
					Name: "test-pod2",
					Labels: map[string]string{
						"app": "test2",
					},
				},
				Spec: types.KubernetesResourceSpecV1{
					Namespace: "default",
				},
			},
		},
		TotalCount: 2,
	}, nil
}

// TestSimultaneousAuthenticateRequest ensures that multiple authenticated
// requests do not race to create a SessionContext. This would happen when
// Proxies were deployed behind a round-robin load balancer. Only the Proxy
// that handled the login will have initially created a SessionContext for
// the particular user+session. All subsequent requests to the other Proxies
// in the load balancer pool attempt to create a SessionContext in
// [Handler.AuthenticateRequest] if one didn't already exist. If the web UI
// makes enough requests fast enough it can result in the Proxy trying to
// create multiple SessionContext for a user+session. Since only one SessionContext
// is stored in the sessionCache all previous SessionContext and their underlying
// auth client get closed, which results in an ugly and unfriendly
// `grpc: the client connection is closing` error banner on the web UI.
func TestSimultaneousAuthenticateRequest(t *testing.T) { ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] // Authenticate to get a session token and cookies. pack := proxy.authPack(t, "test-user@example.com", nil) // Reset the sessions so that all future requests will race to create // a new SessionContext for the user + session pair to simulate multiple // proxies behind a load balancer. proxy.handler.handler.auth.sessions = map[string]*SessionContext{} // Create a request with the auth header and cookies for the session. endpoint := pack.clt.Endpoint("webapi", "sites") req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+pack.session.Token) for _, cookie := range pack.cookies { req.AddCookie(cookie) } // Spawn several requests in parallel and attempt to use the auth client. type res struct { domain string err error } const requests = 10 respC := make(chan res, requests) for i := 0; i < requests; i++ { go func() { sctx, err := proxy.handler.handler.AuthenticateRequest(httptest.NewRecorder(), req.Clone(ctx), false) if err != nil { respC <- res{err: err} return } clt, err := sctx.GetClient() if err != nil { respC <- res{err: err} return } domain, err := clt.GetDomainName(ctx) respC <- res{domain: domain, err: err} }() } // Assert that all requests were successful and each one was able to // get the domain name without its auth client being closed. 
for i := 0; i < requests; i++ { select { case res := <-respC: require.NoError(t, res.err) require.Equal(t, "localhost", res.domain) case <-time.After(5 * time.Second): t.Fatal("timed out waiting for responses") } } } // mockedPingTestProxy is a test proxy with a mocked Ping method type mockedPingTestProxy struct { auth.ClientI mockedPing func(ctx context.Context) (authproto.PingResponse, error) } func (m mockedPingTestProxy) Ping(ctx context.Context) (authproto.PingResponse, error) { return m.mockedPing(ctx) } // TestModeratedSession validates that peers are able to start Moderated // Sessions and remain in the waiting room until the required number of // moderators are present. Only when the moderator is present the peer // is allowed to access the host and start entering input and receiving // output until the moderator terminates the session. func TestModeratedSession(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) ctx := context.Background() s := newWebSuiteWithConfig(t, webSuiteConfig{disableDiskBasedRecording: true}) peerRole, err := types.NewRole("moderated", types.RoleSpecV6{ Allow: types.RoleConditions{ RequireSessionJoin: []*types.SessionRequirePolicy{ { Name: "moderated", Filter: "contains(user.roles, \"moderator\")", Kinds: []string{string(types.SSHSessionKind)}, Count: 1, Modes: []string{string(types.SessionModeratorMode)}, }, }, }, }) require.NoError(t, err) require.NoError(t, s.server.Auth().UpsertRole(s.ctx, peerRole)) moderatorRole, err := types.NewRole("moderator", types.RoleSpecV6{ Allow: types.RoleConditions{ JoinSessions: []*types.SessionJoinPolicy{ { Name: "moderated", Roles: []string{peerRole.GetName()}, Kinds: []string{string(types.SSHSessionKind)}, Modes: []string{string(types.SessionModeratorMode), string(types.SessionObserverMode)}, }, }, }, }) require.NoError(t, err) require.NoError(t, s.server.Auth().UpsertRole(s.ctx, moderatorRole)) peer := s.authPack(t, "foo", peerRole.GetName()) 
peerWS, sess, err := s.makeTerminal(t, peer) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, peerWS.Close()) }) peerStream := NewTerminalStream(ctx, peerWS, utils.NewLoggerForTests()) require.NoError(t, waitForOutput(peerStream, "Teleport > User foo joined the session with participant mode: peer.")) moderator := s.authPack(t, "bar", moderatorRole.GetName()) moderatorWS, _, err := s.makeTerminal(t, moderator, withSessionID(sess.ID), withParticipantMode(types.SessionModeratorMode)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, moderatorWS.Close()) }) moderatorStream := NewTerminalStream(ctx, moderatorWS, utils.NewLoggerForTests()) require.NoError(t, waitForOutput(peerStream, "Teleport > Connecting to node over SSH")) // here we intentionally run a command where the output we're looking // for is not present in the command itself _, err = io.WriteString(peerStream, "echo llxmx | sed 's/x/a/g'\r\n") require.NoError(t, err) require.NoError(t, waitForOutput(peerStream, "llama")) require.NoError(t, waitForOutput(moderatorStream, "llama")) // the moderator terminates the session _, err = io.WriteString(moderatorStream, "t") require.NoError(t, err) require.NoError(t, waitForOutput(moderatorStream, "Stopping session...")) require.NoError(t, waitForOutput(peerStream, "Process exited with status 255")) } // TestModeratedSessionWithMFA validates the same behavior as TestModeratedSession while // also ensuring that MFA is performed prior to accessing the host and that periodic // presence checks are performed by the moderator. When presence checks are not performed // the session is aborted. 
func TestModeratedSessionWithMFA(t *testing.T) { modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) ctx := context.Background() const RPID = "localhost" s := newWebSuiteWithConfig(t, webSuiteConfig{ disableDiskBasedRecording: true, authPreferenceSpec: &types.AuthPreferenceSpecV2{ Type: constants.Local, ConnectorName: constants.PasswordlessConnector, SecondFactor: constants.SecondFactorOn, RequireMFAType: types.RequireMFAType_SESSION, Webauthn: &types.Webauthn{ RPID: RPID, }, }, }) peerRole, err := types.NewRole("moderated", types.RoleSpecV6{ Allow: types.RoleConditions{ RequireSessionJoin: []*types.SessionRequirePolicy{ { Name: "moderated", Filter: "contains(user.roles, \"moderator\")", Kinds: []string{string(types.SSHSessionKind)}, Count: 1, Modes: []string{string(types.SessionModeratorMode)}, }, }, }, }) require.NoError(t, err) moderatorRole, err := types.NewRole("moderator", types.RoleSpecV6{ Allow: types.RoleConditions{ JoinSessions: []*types.SessionJoinPolicy{ { Name: "moderated", Roles: []string{peerRole.GetName()}, Kinds: []string{string(types.SSHSessionKind)}, Modes: []string{string(types.SessionModeratorMode), string(types.SessionObserverMode)}, }, }, }, }) require.NoError(t, err) peer := s.authPackWithMFA(t, "foo", peerRole) moderator := s.authPackWithMFA(t, "bar", moderatorRole) peerWS, sess, err := s.makeTerminal(t, peer) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, peerWS.Close()) }) handleMFAWebauthnChallenge(t, peerWS, peer.device) peerStream := NewTerminalStream(ctx, peerWS, utils.NewLoggerForTests()) require.NoError(t, waitForOutput(peerStream, "Teleport > User foo joined the session with participant mode: peer.")) moderatorWS, _, err := s.makeTerminal(t, moderator, withSessionID(sess.ID), withParticipantMode(types.SessionModeratorMode)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, moderatorWS.Close()) }) handleMFAWebauthnChallenge(t, moderatorWS, moderator.device) 
moderatorStream := NewTerminalStream(ctx, moderatorWS, utils.NewLoggerForTests()) require.NoError(t, waitForOutput(peerStream, "Teleport > Connecting to node over SSH")) // here we intentionally run a command where the output we're looking // for is not present in the command itself _, err = io.WriteString(peerStream, "echo llxmx | sed 's/x/a/g'\r\n") require.NoError(t, err) require.NoError(t, waitForOutput(peerStream, "llama")) require.NoError(t, waitForOutput(moderatorStream, "llama")) // run the presence check a few times for i := 0; i < 3; i++ { s.clock.Advance(30 * time.Second) require.NoError(t, waitForOutput(moderatorStream, "Teleport > Please tap your MFA key")) challenge, err := moderatorStream.readChallenge(protobufMFACodec{}) require.NoError(t, err) res, err := moderator.device.SolveAuthn(challenge) require.NoError(t, err) webauthnResBytes, err := json.Marshal(wanlib.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), } envelopeBytes, err := proto.Marshal(envelope) require.NoError(t, err) require.NoError(t, moderatorWS.WriteMessage(websocket.BinaryMessage, envelopeBytes)) } // advance the clock far enough in the future to make the moderator stale // which will terminate the session s.clock.Advance(180 * time.Second) require.NoError(t, waitForOutput(moderatorStream, "wait: remote command exited without exit status or exit signal")) require.NoError(t, waitForOutput(peerStream, "Process exited with status 255")) } func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.TestDevice) { // Wait for websocket authn challenge event. 
ty, raw, err := ws.ReadMessage() require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) var env Envelope require.NoError(t, proto.Unmarshal(raw, &env)) var challenge client.MFAAuthenticateChallenge require.NoError(t, json.Unmarshal([]byte(env.Payload), &challenge)) res, err := dev.SolveAuthn(&authproto.MFAAuthenticateChallenge{ WebauthnChallenge: wanlib.CredentialAssertionToProto(challenge.WebauthnChallenge), }) require.NoError(t, err) webauthnResBytes, err := json.Marshal(wanlib.CredentialAssertionResponseFromProto(res.GetWebauthn())) require.NoError(t, err) envelope := &Envelope{ Version: defaults.WebsocketVersion, Type: defaults.WebsocketWebauthnChallenge, Payload: string(webauthnResBytes), } envelopeBytes, err := proto.Marshal(envelope) require.NoError(t, err) require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, envelopeBytes)) }