diff --git a/entitlements/entitlements.go b/entitlements/entitlements.go index e524b3037f0..63ca0f5b6e0 100644 --- a/entitlements/entitlements.go +++ b/entitlements/entitlements.go @@ -16,6 +16,8 @@ package entitlements +import "github.com/gravitational/teleport/api/client/proto" + type EntitlementKind string // The EntitlementKind list should be 1:1 with the Features & FeatureStrings in salescenter/product/product.go, @@ -57,3 +59,67 @@ var AllEntitlements = []EntitlementKind{ ExternalAuditStorage, FeatureHiding, HSM, Identity, JoinActiveSessions, K8s, MobileDeviceManagement, OIDC, OktaSCIM, OktaUserSync, Policy, SAML, SessionLocks, UpsellAlert, UsageReporting, LicenseAutoUpdate, } + +// BackfillFeatures ensures entitlements are backwards compatible. +// If Entitlements are present, there are no changes. +// If Entitlements are not present, it sets the entitlements based on legacy field values. +// TODO(michellescripts) DELETE IN 18.0.0 +func BackfillFeatures(features *proto.Features) { + if len(features.Entitlements) > 0 { + return + } + + features.Entitlements = getBaseEntitlements(features.GetEntitlements()) + + // Entitlements: All records are {enabled: false}; update to equal legacy feature value + features.Entitlements[string(ExternalAuditStorage)] = &proto.EntitlementInfo{Enabled: features.GetExternalAuditStorage()} + features.Entitlements[string(FeatureHiding)] = &proto.EntitlementInfo{Enabled: features.GetFeatureHiding()} + features.Entitlements[string(Identity)] = &proto.EntitlementInfo{Enabled: features.GetIdentityGovernance()} + features.Entitlements[string(JoinActiveSessions)] = &proto.EntitlementInfo{Enabled: features.GetJoinActiveSessions()} + features.Entitlements[string(MobileDeviceManagement)] = &proto.EntitlementInfo{Enabled: features.GetMobileDeviceManagement()} + features.Entitlements[string(OIDC)] = &proto.EntitlementInfo{Enabled: features.GetOIDC()} + features.Entitlements[string(Policy)] = &proto.EntitlementInfo{Enabled: features.GetPolicy().GetEnabled()} + features.Entitlements[string(SAML)] = &proto.EntitlementInfo{Enabled: features.GetSAML()} + features.Entitlements[string(K8s)] = &proto.EntitlementInfo{Enabled: features.GetKubernetes()} + features.Entitlements[string(App)] = &proto.EntitlementInfo{Enabled: features.GetApp()} + features.Entitlements[string(DB)] = &proto.EntitlementInfo{Enabled: features.GetDB()} + features.Entitlements[string(Desktop)] = &proto.EntitlementInfo{Enabled: features.GetDesktop()} + features.Entitlements[string(HSM)] = &proto.EntitlementInfo{Enabled: features.GetHSM()} + + // set default Identity fields to legacy feature value + features.Entitlements[string(AccessLists)] = &proto.EntitlementInfo{Enabled: true, Limit: features.GetAccessList().GetCreateLimit()} + features.Entitlements[string(AccessMonitoring)] = &proto.EntitlementInfo{Enabled: features.GetAccessMonitoring().GetEnabled(), Limit: features.GetAccessMonitoring().GetMaxReportRangeLimit()} + features.Entitlements[string(AccessRequests)] = &proto.EntitlementInfo{Enabled: features.GetAccessRequests().MonthlyRequestLimit > 0, Limit: features.GetAccessRequests().GetMonthlyRequestLimit()} + features.Entitlements[string(DeviceTrust)] = &proto.EntitlementInfo{Enabled: features.GetDeviceTrust().GetEnabled(), Limit: features.GetDeviceTrust().GetDevicesUsageLimit()} + // override Identity Package features if Identity is enabled: set true and clear limit + if features.GetIdentityGovernance() { + features.Entitlements[string(AccessLists)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(AccessMonitoring)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(AccessRequests)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(DeviceTrust)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(OktaSCIM)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(OktaUserSync)] = &proto.EntitlementInfo{Enabled: true} + features.Entitlements[string(SessionLocks)] = &proto.EntitlementInfo{Enabled: true} + } +} + +// getBaseEntitlements takes a cloud entitlement set and returns a modules Entitlement set +func getBaseEntitlements(protoEntitlements map[string]*proto.EntitlementInfo) map[string]*proto.EntitlementInfo { + all := AllEntitlements + result := make(map[string]*proto.EntitlementInfo, len(all)) + + for _, e := range all { + al, ok := protoEntitlements[string(e)] + if !ok { + result[string(e)] = &proto.EntitlementInfo{} + continue + } + + result[string(e)] = &proto.EntitlementInfo{ + Enabled: al.Enabled, + Limit: al.Limit, + } + } + + return result +} diff --git a/entitlements/entitlements_test.go b/entitlements/entitlements_test.go new file mode 100644 index 00000000000..2015f6efaf7 --- /dev/null +++ b/entitlements/entitlements_test.go @@ -0,0 +1,286 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package entitlements + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + apiutils "github.com/gravitational/teleport/api/utils" +) + +func TestBackfillFeatures(t *testing.T) { + tests := []struct { + name string + features *proto.Features + expected map[string]*proto.EntitlementInfo + }{ + { + name: "entitlements present; keeps entitlement values", + features: &proto.Features{ + DeviceTrust: nil, + AccessRequests: nil, + AccessList: nil, + AccessMonitoring: nil, + Policy: nil, + CustomTheme: "", + ProductType: 0, + SupportType: 0, + Kubernetes: false, + App: false, + DB: false, + OIDC: false, + SAML: false, + AccessControls: false, + AdvancedAccessWorkflows: false, + Cloud: false, + HSM: false, + Desktop: false, + RecoveryCodes: false, + Plugins: false, + AutomaticUpgrades: false, + IsUsageBased: false, + Assist: false, + FeatureHiding: false, + IdentityGovernance: false, + AccessGraph: false, + Questionnaire: false, + IsStripeManaged: false, + ExternalAuditStorage: false, + JoinActiveSessions: false, + MobileDeviceManagement: false, + AccessMonitoringConfigured: false, + Entitlements: map[string]*proto.EntitlementInfo{ + string(AccessLists): {Enabled: true, Limit: 111}, + string(AccessMonitoring): {Enabled: true, Limit: 2113}, + string(AccessRequests): {Enabled: true, Limit: 39}, + string(App): {Enabled: false}, + string(CloudAuditLogRetention): {Enabled: true}, + string(DB): {Enabled: true}, + string(Desktop): {Enabled: true}, + string(DeviceTrust): {Enabled: true, Limit: 103}, + string(ExternalAuditStorage): {Enabled: true}, + string(FeatureHiding): {Enabled: true}, + string(HSM): {Enabled: true}, + string(Identity): {Enabled: true}, + string(JoinActiveSessions): {Enabled: true}, + string(K8s): {Enabled: true}, + string(MobileDeviceManagement): {Enabled: true}, + string(OIDC): {Enabled: true}, + string(OktaSCIM): {Enabled: true}, + string(OktaUserSync): {Enabled: true}, + string(Policy): {Enabled: true}, + string(SAML): {Enabled: true}, + string(SessionLocks): {Enabled: true}, + string(UpsellAlert): {Enabled: true}, + string(UsageReporting): {Enabled: true}, + string(LicenseAutoUpdate): {Enabled: true}, + }, + }, + expected: map[string]*proto.EntitlementInfo{ + string(AccessLists): {Enabled: true, Limit: 111}, + string(AccessMonitoring): {Enabled: true, Limit: 2113}, + string(AccessRequests): {Enabled: true, Limit: 39}, + string(App): {Enabled: false}, + string(CloudAuditLogRetention): {Enabled: true}, + string(DB): {Enabled: true}, + string(Desktop): {Enabled: true}, + string(DeviceTrust): {Enabled: true, Limit: 103}, + string(ExternalAuditStorage): {Enabled: true}, + string(FeatureHiding): {Enabled: true}, + string(HSM): {Enabled: true}, + string(Identity): {Enabled: true}, + string(JoinActiveSessions): {Enabled: true}, + string(K8s): {Enabled: true}, + string(MobileDeviceManagement): {Enabled: true}, + string(OIDC): {Enabled: true}, + string(OktaSCIM): {Enabled: true}, + string(OktaUserSync): {Enabled: true}, + string(Policy): {Enabled: true}, + string(SAML): {Enabled: true}, + string(SessionLocks): {Enabled: true}, + string(UpsellAlert): {Enabled: true}, + string(UsageReporting): {Enabled: true}, + string(LicenseAutoUpdate): {Enabled: true}, + }, + }, + { + name: "entitlements not present; identity on - sets legacy fields & drops limits", + features: &proto.Features{ + DeviceTrust: &proto.DeviceTrustFeature{ + Enabled: true, + DevicesUsageLimit: 33, + }, + AccessRequests: &proto.AccessRequestsFeature{ + MonthlyRequestLimit: 22, + }, + AccessList: &proto.AccessListFeature{ + CreateLimit: 44, + }, + AccessMonitoring: &proto.AccessMonitoringFeature{ + Enabled: true, + MaxReportRangeLimit: 55, + }, + Policy: &proto.PolicyFeature{ + Enabled: true, + }, + CustomTheme: "", + ProductType: 0, + SupportType: 0, + Kubernetes: true, + App: true, + DB: true, + OIDC: true, + SAML: true, + AccessControls: true, + AdvancedAccessWorkflows: true, + Cloud: true, + HSM: true, + Desktop: true, + RecoveryCodes: true, + Plugins: true, + AutomaticUpgrades: true, + IsUsageBased: true, + Assist: true, + FeatureHiding: true, + IdentityGovernance: true, + AccessGraph: true, + Questionnaire: true, + IsStripeManaged: true, + ExternalAuditStorage: true, + JoinActiveSessions: true, + MobileDeviceManagement: true, + AccessMonitoringConfigured: true, + }, + expected: map[string]*proto.EntitlementInfo{ + string(AccessLists): {Enabled: true}, + string(AccessMonitoring): {Enabled: true}, + string(AccessRequests): {Enabled: true}, + string(App): {Enabled: true}, + string(DB): {Enabled: true}, + string(Desktop): {Enabled: true}, + string(DeviceTrust): {Enabled: true}, + string(ExternalAuditStorage): {Enabled: true}, + string(FeatureHiding): {Enabled: true}, + string(HSM): {Enabled: true}, + string(Identity): {Enabled: true}, + string(JoinActiveSessions): {Enabled: true}, + string(K8s): {Enabled: true}, + string(MobileDeviceManagement): {Enabled: true}, + string(OIDC): {Enabled: true}, + string(OktaSCIM): {Enabled: true}, + string(OktaUserSync): {Enabled: true}, + string(Policy): {Enabled: true}, + string(SAML): {Enabled: true}, + string(SessionLocks): {Enabled: true}, + // defaults, no legacy equivalent + string(UsageReporting): {Enabled: false}, + string(UpsellAlert): {Enabled: false}, + string(CloudAuditLogRetention): {Enabled: false}, + string(LicenseAutoUpdate): {Enabled: false}, + }, + }, + { + name: "entitlements not present; identity off - sets legacy fields", + features: &proto.Features{ + DeviceTrust: &proto.DeviceTrustFeature{ + Enabled: true, + DevicesUsageLimit: 33, + }, + AccessRequests: &proto.AccessRequestsFeature{ + MonthlyRequestLimit: 22, + }, + AccessList: &proto.AccessListFeature{ + CreateLimit: 44, + }, + AccessMonitoring: &proto.AccessMonitoringFeature{ + Enabled: true, + MaxReportRangeLimit: 55, + }, + Policy: &proto.PolicyFeature{ + Enabled: true, + }, + CustomTheme: "", + ProductType: 0, + SupportType: 0, + Kubernetes: true, + App: true, + DB: true, + OIDC: true, + SAML: true, + AccessControls: true, + AdvancedAccessWorkflows: true, + Cloud: true, + HSM: true, + Desktop: true, + RecoveryCodes: true, + Plugins: true, + AutomaticUpgrades: true, + IsUsageBased: true, + Assist: true, + FeatureHiding: true, + IdentityGovernance: false, + AccessGraph: true, + Questionnaire: true, + IsStripeManaged: true, + ExternalAuditStorage: true, + JoinActiveSessions: true, + MobileDeviceManagement: true, + AccessMonitoringConfigured: true, + }, + expected: map[string]*proto.EntitlementInfo{ + string(AccessLists): {Enabled: true, Limit: 44}, + string(AccessMonitoring): {Enabled: true, Limit: 55}, + string(AccessRequests): {Enabled: true, Limit: 22}, + string(DeviceTrust): {Enabled: true, Limit: 33}, + string(App): {Enabled: true}, + string(DB): {Enabled: true}, + string(Desktop): {Enabled: true}, + string(ExternalAuditStorage): {Enabled: true}, + string(FeatureHiding): {Enabled: true}, + string(HSM): {Enabled: true}, + string(JoinActiveSessions): {Enabled: true}, + string(K8s): {Enabled: true}, + string(MobileDeviceManagement): {Enabled: true}, + string(OIDC): {Enabled: true}, + string(Policy): {Enabled: true}, + string(SAML): {Enabled: true}, + + // defaults, no legacy equivalent + string(UsageReporting): {Enabled: false}, + string(UpsellAlert): {Enabled: false}, + string(CloudAuditLogRetention): {Enabled: false}, + string(LicenseAutoUpdate): {Enabled: false}, + // Identity off, fields false + string(Identity): {Enabled: false}, + string(SessionLocks): {Enabled: false}, + string(OktaSCIM): {Enabled: false}, + string(OktaUserSync): {Enabled: false}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cloned := apiutils.CloneProtoMsg(tt.features) + + BackfillFeatures(cloned) + require.Equal(t, tt.expected, cloned.Entitlements) + }) + } +} diff --git a/lib/service/connect.go b/lib/service/connect.go index 453e241b31a..40644a73763 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -1087,7 +1087,7 @@ func (process *TeleportProcess) getConnector(clientIdentity, serverIdentity *sta // Set cluster features and return successfully with a working connector. // TODO(michellescripts) remove clone & compatibility check in v18 cloned := apiutils.CloneProtoMsg(pingResponse.GetServerFeatures()) - supportEntitlementsCompatibility(cloned) + entitlements.BackfillFeatures(cloned) process.setClusterFeatures(cloned) process.setAuthSubjectiveAddr(pingResponse.RemoteAddr) process.logger.InfoContext(process.ExitContext(), "features loaded from auth server", "identity", clientIdentity.ID.Role, "features", pingResponse.GetServerFeatures()) @@ -1096,70 +1096,6 @@ func (process *TeleportProcess) getConnector(clientIdentity, serverIdentity *sta return newConn, nil } -// supportEntitlementsCompatibility ensures entitlements are backwards compatible -// If Entitlements are present, there are no changes -// If Entitlements are not present, sets the entitlements fields to legacy field values -// TODO(michellescripts) remove in v18 -func supportEntitlementsCompatibility(features *proto.Features) { - if len(features.Entitlements) > 0 { - return - } - - features.Entitlements = getBaseEntitlements(features.GetEntitlements()) - - // Entitlements: All records are {enabled: false}; update to equal legacy feature value - features.Entitlements[string(entitlements.ExternalAuditStorage)] = &proto.EntitlementInfo{Enabled: features.GetExternalAuditStorage()} - features.Entitlements[string(entitlements.FeatureHiding)] = &proto.EntitlementInfo{Enabled: features.GetFeatureHiding()} - features.Entitlements[string(entitlements.Identity)] = &proto.EntitlementInfo{Enabled: features.GetIdentityGovernance()} - features.Entitlements[string(entitlements.JoinActiveSessions)] = &proto.EntitlementInfo{Enabled: features.GetJoinActiveSessions()} - features.Entitlements[string(entitlements.MobileDeviceManagement)] = &proto.EntitlementInfo{Enabled: features.GetMobileDeviceManagement()} - features.Entitlements[string(entitlements.OIDC)] = &proto.EntitlementInfo{Enabled: features.GetOIDC()} - features.Entitlements[string(entitlements.Policy)] = &proto.EntitlementInfo{Enabled: features.GetPolicy().GetEnabled()} - features.Entitlements[string(entitlements.SAML)] = &proto.EntitlementInfo{Enabled: features.GetSAML()} - features.Entitlements[string(entitlements.K8s)] = &proto.EntitlementInfo{Enabled: features.GetKubernetes()} - features.Entitlements[string(entitlements.App)] = &proto.EntitlementInfo{Enabled: features.GetApp()} - features.Entitlements[string(entitlements.DB)] = &proto.EntitlementInfo{Enabled: features.GetDB()} - features.Entitlements[string(entitlements.Desktop)] = &proto.EntitlementInfo{Enabled: features.GetDesktop()} - features.Entitlements[string(entitlements.HSM)] = &proto.EntitlementInfo{Enabled: features.GetHSM()} - - // set default Identity fields to legacy feature value - features.Entitlements[string(entitlements.AccessLists)] = &proto.EntitlementInfo{Enabled: true, Limit: features.GetAccessList().GetCreateLimit()} - features.Entitlements[string(entitlements.AccessMonitoring)] = &proto.EntitlementInfo{Enabled: features.GetAccessMonitoring().GetEnabled(), Limit: features.GetAccessMonitoring().GetMaxReportRangeLimit()} - features.Entitlements[string(entitlements.AccessRequests)] = &proto.EntitlementInfo{Enabled: features.GetAccessRequests().MonthlyRequestLimit > 0, Limit: features.GetAccessRequests().GetMonthlyRequestLimit()} - features.Entitlements[string(entitlements.DeviceTrust)] = &proto.EntitlementInfo{Enabled: features.GetDeviceTrust().GetEnabled(), Limit: features.GetDeviceTrust().GetDevicesUsageLimit()} - // override Identity Package features if Identity is enabled: set true and clear limit - if features.GetIdentityGovernance() { - features.Entitlements[string(entitlements.AccessLists)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.AccessMonitoring)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.AccessRequests)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.DeviceTrust)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.OktaSCIM)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.OktaUserSync)] = &proto.EntitlementInfo{Enabled: true} - features.Entitlements[string(entitlements.SessionLocks)] = &proto.EntitlementInfo{Enabled: true} - } -} - -// getBaseEntitlements takes a cloud entitlement set and returns a modules Entitlement set -func getBaseEntitlements(protoEntitlements map[string]*proto.EntitlementInfo) map[string]*proto.EntitlementInfo { - all := entitlements.AllEntitlements - result := make(map[string]*proto.EntitlementInfo, len(all)) - - for _, e := range all { - al, ok := protoEntitlements[string(e)] - if !ok { - result[string(e)] = &proto.EntitlementInfo{} - continue - } - - result[string(e)] = &proto.EntitlementInfo{ - Enabled: al.Enabled, - Limit: al.Limit, - } - } - - return result -} - // newClient attempts to connect to either the proxy server or auth server // For config v3 and onwards, it will only connect to either the proxy (via tunnel) or the auth server (direct), // depending on what was specified in the config. diff --git a/lib/service/connect_test.go b/lib/service/connect_test.go deleted file mode 100644 index 72f74d40282..00000000000 --- a/lib/service/connect_test.go +++ /dev/null @@ -1,288 +0,0 @@ -// Teleport -// Copyright (C) 2024 Gravitational, Inc. -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package service - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/api/client/proto" - apiutils "github.com/gravitational/teleport/api/utils" - "github.com/gravitational/teleport/entitlements" -) - -func Test_supportEntitlementsCompatibility(t *testing.T) { - tests := []struct { - name string - features *proto.Features - expected map[string]*proto.EntitlementInfo - }{ - { - name: "entitlements present; keeps entitlement values", - features: &proto.Features{ - DeviceTrust: nil, - AccessRequests: nil, - AccessList: nil, - AccessMonitoring: nil, - Policy: nil, - CustomTheme: "", - ProductType: 0, - SupportType: 0, - Kubernetes: false, - App: false, - DB: false, - OIDC: false, - SAML: false, - AccessControls: false, - AdvancedAccessWorkflows: false, - Cloud: false, - HSM: false, - Desktop: false, - RecoveryCodes: false, - Plugins: false, - AutomaticUpgrades: false, - IsUsageBased: false, - Assist: false, - FeatureHiding: false, - IdentityGovernance: false, - AccessGraph: false, - Questionnaire: false, - IsStripeManaged: false, - ExternalAuditStorage: false, - JoinActiveSessions: false, - MobileDeviceManagement: false, - AccessMonitoringConfigured: false, - Entitlements: map[string]*proto.EntitlementInfo{ - string(entitlements.AccessLists): {Enabled: true, Limit: 111}, - string(entitlements.AccessMonitoring): {Enabled: true, Limit: 2113}, - string(entitlements.AccessRequests): {Enabled: true, Limit: 39}, - string(entitlements.App): {Enabled: false}, - string(entitlements.CloudAuditLogRetention): {Enabled: true}, - string(entitlements.DB): {Enabled: true}, - string(entitlements.Desktop): {Enabled: true}, - string(entitlements.DeviceTrust): {Enabled: true, Limit: 103}, - string(entitlements.ExternalAuditStorage): {Enabled: true}, - string(entitlements.FeatureHiding): {Enabled: true}, - string(entitlements.HSM): {Enabled: true}, - string(entitlements.Identity): {Enabled: true}, - string(entitlements.JoinActiveSessions): {Enabled: true}, - string(entitlements.K8s): {Enabled: true}, - string(entitlements.MobileDeviceManagement): {Enabled: true}, - string(entitlements.OIDC): {Enabled: true}, - string(entitlements.OktaSCIM): {Enabled: true}, - string(entitlements.OktaUserSync): {Enabled: true}, - string(entitlements.Policy): {Enabled: true}, - string(entitlements.SAML): {Enabled: true}, - string(entitlements.SessionLocks): {Enabled: true}, - string(entitlements.UpsellAlert): {Enabled: true}, - string(entitlements.UsageReporting): {Enabled: true}, - string(entitlements.LicenseAutoUpdate): {Enabled: true}, - }, - }, - expected: map[string]*proto.EntitlementInfo{ - string(entitlements.AccessLists): {Enabled: true, Limit: 111}, - string(entitlements.AccessMonitoring): {Enabled: true, Limit: 2113}, - string(entitlements.AccessRequests): {Enabled: true, Limit: 39}, - string(entitlements.App): {Enabled: false}, - string(entitlements.CloudAuditLogRetention): {Enabled: true}, - string(entitlements.DB): {Enabled: true}, - string(entitlements.Desktop): {Enabled: true}, - string(entitlements.DeviceTrust): {Enabled: true, Limit: 103}, - string(entitlements.ExternalAuditStorage): {Enabled: true}, - string(entitlements.FeatureHiding): {Enabled: true}, - string(entitlements.HSM): {Enabled: true}, - string(entitlements.Identity): {Enabled: true}, - string(entitlements.JoinActiveSessions): {Enabled: true}, - string(entitlements.K8s): {Enabled: true}, - string(entitlements.MobileDeviceManagement): {Enabled: true}, - string(entitlements.OIDC): {Enabled: true}, - string(entitlements.OktaSCIM): {Enabled: true}, - string(entitlements.OktaUserSync): {Enabled: true}, - string(entitlements.Policy): {Enabled: true}, - string(entitlements.SAML): {Enabled: true}, - string(entitlements.SessionLocks): {Enabled: true}, - string(entitlements.UpsellAlert): {Enabled: true}, - string(entitlements.UsageReporting): {Enabled: true}, - string(entitlements.LicenseAutoUpdate): {Enabled: true}, - }, - }, - { - name: "entitlements not present; identity on - sets legacy fields & drops limits", - features: &proto.Features{ - DeviceTrust: &proto.DeviceTrustFeature{ - Enabled: true, - DevicesUsageLimit: 33, - }, - AccessRequests: &proto.AccessRequestsFeature{ - MonthlyRequestLimit: 22, - }, - AccessList: &proto.AccessListFeature{ - CreateLimit: 44, - }, - AccessMonitoring: &proto.AccessMonitoringFeature{ - Enabled: true, - MaxReportRangeLimit: 55, - }, - Policy: &proto.PolicyFeature{ - Enabled: true, - }, - CustomTheme: "", - ProductType: 0, - SupportType: 0, - Kubernetes: true, - App: true, - DB: true, - OIDC: true, - SAML: true, - AccessControls: true, - AdvancedAccessWorkflows: true, - Cloud: true, - HSM: true, - Desktop: true, - RecoveryCodes: true, - Plugins: true, - AutomaticUpgrades: true, - IsUsageBased: true, - Assist: true, - FeatureHiding: true, - IdentityGovernance: true, - AccessGraph: true, - Questionnaire: true, - IsStripeManaged: true, - ExternalAuditStorage: true, - JoinActiveSessions: true, - MobileDeviceManagement: true, - AccessMonitoringConfigured: true, - }, - expected: map[string]*proto.EntitlementInfo{ - string(entitlements.AccessLists): {Enabled: true}, - string(entitlements.AccessMonitoring): {Enabled: true}, - string(entitlements.AccessRequests): {Enabled: true}, - string(entitlements.App): {Enabled: true}, - string(entitlements.DB): {Enabled: true}, - string(entitlements.Desktop): {Enabled: true}, - string(entitlements.DeviceTrust): {Enabled: true}, - string(entitlements.ExternalAuditStorage): {Enabled: true}, - string(entitlements.FeatureHiding): {Enabled: true}, - string(entitlements.HSM): {Enabled: true}, - string(entitlements.Identity): {Enabled: true}, - string(entitlements.JoinActiveSessions): {Enabled: true}, - string(entitlements.K8s): {Enabled: true}, - string(entitlements.MobileDeviceManagement): {Enabled: true}, - string(entitlements.OIDC): {Enabled: true}, - string(entitlements.OktaSCIM): {Enabled: true}, - string(entitlements.OktaUserSync): {Enabled: true}, - string(entitlements.Policy): {Enabled: true}, - string(entitlements.SAML): {Enabled: true}, - string(entitlements.SessionLocks): {Enabled: true}, - // defaults, no legacy equivalent - string(entitlements.UsageReporting): {Enabled: false}, - string(entitlements.UpsellAlert): {Enabled: false}, - string(entitlements.CloudAuditLogRetention): {Enabled: false}, - string(entitlements.LicenseAutoUpdate): {Enabled: false}, - }, - }, - { - name: "entitlements not present; identity off - sets legacy fields", - features: &proto.Features{ - DeviceTrust: &proto.DeviceTrustFeature{ - Enabled: true, - DevicesUsageLimit: 33, - }, - AccessRequests: &proto.AccessRequestsFeature{ - MonthlyRequestLimit: 22, - }, - AccessList: &proto.AccessListFeature{ - CreateLimit: 44, - }, - AccessMonitoring: &proto.AccessMonitoringFeature{ - Enabled: true, - MaxReportRangeLimit: 55, - }, - Policy: &proto.PolicyFeature{ - Enabled: true, - }, - CustomTheme: "", - ProductType: 0, - SupportType: 0, - Kubernetes: true, - App: true, - DB: true, - OIDC: true, - SAML: true, - AccessControls: true, - AdvancedAccessWorkflows: true, - Cloud: true, - HSM: true, - Desktop: true, - RecoveryCodes: true, - Plugins: true, - AutomaticUpgrades: true, - IsUsageBased: true, - Assist: true, - FeatureHiding: true, - IdentityGovernance: false, - AccessGraph: true, - Questionnaire: true, - IsStripeManaged: true, - ExternalAuditStorage: true, - JoinActiveSessions: true, - MobileDeviceManagement: true, - AccessMonitoringConfigured: true, - }, - expected: map[string]*proto.EntitlementInfo{ - string(entitlements.AccessLists): {Enabled: true, Limit: 44}, - string(entitlements.AccessMonitoring): {Enabled: true, Limit: 55}, - string(entitlements.AccessRequests): {Enabled: true, Limit: 22}, - string(entitlements.DeviceTrust): {Enabled: true, Limit: 33}, - string(entitlements.App): {Enabled: true}, - string(entitlements.DB): {Enabled: true}, - string(entitlements.Desktop): {Enabled: true}, - string(entitlements.ExternalAuditStorage): {Enabled: true}, - string(entitlements.FeatureHiding): {Enabled: true}, - string(entitlements.HSM): {Enabled: true}, - string(entitlements.JoinActiveSessions): {Enabled: true}, - string(entitlements.K8s): {Enabled: true}, - string(entitlements.MobileDeviceManagement): {Enabled: true}, - string(entitlements.OIDC): {Enabled: true}, - string(entitlements.Policy): {Enabled: true}, - string(entitlements.SAML): {Enabled: true}, - - // defaults, no legacy equivalent - string(entitlements.UsageReporting): {Enabled: false}, - string(entitlements.UpsellAlert): {Enabled: false}, - string(entitlements.CloudAuditLogRetention): {Enabled: false}, - string(entitlements.LicenseAutoUpdate): {Enabled: false}, - - // Identity off, fields false - string(entitlements.Identity): {Enabled: false}, - string(entitlements.SessionLocks): {Enabled: false}, - string(entitlements.OktaSCIM): {Enabled: false}, - string(entitlements.OktaUserSync): {Enabled: false}, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cloned := apiutils.CloneProtoMsg(tt.features) - - supportEntitlementsCompatibility(cloned) - require.Equal(t, tt.expected, cloned.Entitlements) - }) - } -} diff --git a/lib/service/service.go b/lib/service/service.go index c36e15e3342..54ec59037bf 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4595,6 +4595,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { TracerProvider: process.TracingProvider, AutomaticUpgradesChannels: cfg.Proxy.AutomaticUpgradesChannels, IntegrationAppHandler: connectionsHandler, + FeatureWatchInterval: utils.HalfJitter(web.DefaultFeatureWatchInterval * 2), } webHandler, err := web.NewHandler(webConfig) if err != nil { diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0e05fb23c9c..921abf02321 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -21,6 +21,7 @@ package web import ( + "cmp" "context" "crypto/tls" "encoding/base64" @@ -124,6 +125,9 @@ const ( IncludedResourceModeRequestable = "requestable" // IncludedResourceModeAll describes that all resources, requestable and available, should be returned. IncludedResourceModeAll = "all" + // DefaultFeatureWatchInterval is the default time in which the feature watcher + // should ping the auth server to check for updated features + DefaultFeatureWatchInterval = time.Minute * 5 ) // healthCheckAppServerFunc defines a function used to perform a health check @@ -153,12 +157,8 @@ type Handler struct { // userConns tracks amount of current active connections with user certificates. userConns atomic.Int32 - // ClusterFeatures contain flags for supported and unsupported features. - // Note: This field can become stale since it's only set on initial proxy - // startup. To get the latest feature flags you'll need to ping from the - // auth server. - // https://github.com/gravitational/teleport/issues/39161 - ClusterFeatures proto.Features + // clusterFeatures contain flags for supported and unsupported features. + clusterFeatures proto.Features // nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from // the proxy's cache and get nodes in real time. @@ -314,6 +314,10 @@ type Config struct { // IntegrationAppHandler handles App Access requests which use an Integration. IntegrationAppHandler app.ServerHandler + + // FeatureWatchInterval is the interval between pings to the auth server + // to fetch new cluster features + FeatureWatchInterval time.Duration } // SetDefaults ensures proper default values are set if @@ -328,6 +332,8 @@ func (c *Config) SetDefaults() { if c.PresenceChecker == nil { c.PresenceChecker = client.RunPresenceTask } + + c.FeatureWatchInterval = cmp.Or(c.FeatureWatchInterval, DefaultFeatureWatchInterval) } type APIHandler struct { @@ -451,7 +457,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { log: newPackageLogger(), logger: slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb), clock: clockwork.NewRealClock(), - ClusterFeatures: cfg.ClusterFeatures, + clusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb), wsIODeadline: wsIODeadline, @@ -682,6 +688,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } } + go h.startFeatureWatcher() + return &APIHandler{ handler: h, appHandler: appHandler, @@ -1164,17 +1172,12 @@ func (h *Handler) getUserContext(w http.ResponseWriter, r *http.Request, p httpr } desktopRecordingEnabled := recConfig.GetMode() != types.RecordOff - pingResp, err := clt.Ping(r.Context()) - if err != nil { - return nil, trace.Wrap(err) - } - - features := pingResp.GetServerFeatures() - entitlement := modules.GetProtoEntitlement(features, entitlements.AccessMonitoring) + features := h.GetClusterFeatures() + entitlement := modules.GetProtoEntitlement(&features, entitlements.AccessMonitoring) // ensure entitlement is set & feature is configured accessMonitoringEnabled := entitlement.Enabled && features.GetAccessMonitoringConfigured() - userContext, err := ui.NewUserContext(user, accessChecker.Roles(), *pingResp.ServerFeatures, desktopRecordingEnabled, accessMonitoringEnabled) + userContext, err := ui.NewUserContext(user, accessChecker.Roles(), features, desktopRecordingEnabled, accessMonitoringEnabled) if err != nil { return nil, trace.Wrap(err) } @@ -1692,14 +1695,7 @@ func (h *Handler) getWebConfig(w http.ResponseWriter, r *http.Request, p httprou } } - clusterFeatures := h.ClusterFeatures - // ping server to get cluster features since h.ClusterFeatures may be stale - pingResponse, err := h.GetProxyClient().Ping(r.Context()) - if err != nil { - h.log.WithError(err).Warn("Cannot retrieve cluster features, client may receive stale features") - } else { - clusterFeatures = *pingResponse.ServerFeatures - } + clusterFeatures := h.GetClusterFeatures() // get tunnel address to display on cloud instances tunnelPublicAddr := "" @@ -1813,7 +1809,6 @@ func setEntitlementsWithLegacyLogic(webCfg *webclient.WebConfig, clusterFeatures webCfg.Entitlements[string(entitlements.OIDC)] = webclient.EntitlementInfo{Enabled: clusterFeatures.GetOIDC()} webCfg.Entitlements[string(entitlements.Policy)] = webclient.EntitlementInfo{Enabled: clusterFeatures.GetPolicy() != nil && clusterFeatures.GetPolicy().Enabled} webCfg.Entitlements[string(entitlements.SAML)] = webclient.EntitlementInfo{Enabled: clusterFeatures.GetSAML()} - // set default Identity fields to legacy feature value webCfg.Entitlements[string(entitlements.AccessLists)] = webclient.EntitlementInfo{Enabled: true, Limit: clusterFeatures.GetAccessList().GetCreateLimit()} webCfg.Entitlements[string(entitlements.AccessMonitoring)] = webclient.EntitlementInfo{Enabled: clusterFeatures.GetAccessMonitoring().GetEnabled(), Limit: clusterFeatures.GetAccessMonitoring().GetMaxReportRangeLimit()} diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 1719dd66b5a..640b7395ea7 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -4561,6 +4561,7 @@ func TestApplicationWebSessionsDeletedAfterLogout(t *testing.T) { func TestGetWebConfig_WithEntitlements(t *testing.T) { ctx := context.Background() env := newWebPack(t, 1) + handler := env.proxies[0].handler.handler // Set auth preference with passwordless. const MOTD = "Welcome to cluster, your activity will be recorded." @@ -4591,6 +4592,9 @@ func TestGetWebConfig_WithEntitlements(t *testing.T) { _, err = env.server.Auth().UpsertGithubConnector(ctx, github) require.NoError(t, err) + // start the feature watcher so the web config gets new features + env.clock.Advance(DefaultFeatureWatchInterval * 2) + expectedCfg := webclient.WebConfig{ Auth: webclient.WebConfigAuthSettings{ SecondFactor: constants.SecondFactorOptional, @@ -4680,6 +4684,7 @@ func TestGetWebConfig_WithEntitlements(t *testing.T) { }, }, }) + env.clock.Advance(DefaultFeatureWatchInterval * 2) require.NoError(t, err) // This version is too high and MUST NOT be used @@ -4690,7 +4695,7 @@ func TestGetWebConfig_WithEntitlements(t *testing.T) { }, } require.NoError(t, channels.CheckAndSetDefaults()) - env.proxies[0].handler.handler.cfg.AutomaticUpgradesChannels = channels + handler.cfg.AutomaticUpgradesChannels = channels expectedCfg.IsCloud = true expectedCfg.IsUsageBasedBilling = true @@ -4706,14 +4711,20 @@ func TestGetWebConfig_WithEntitlements(t *testing.T) { expectedCfg.Entitlements[string(entitlements.JoinActiveSessions)] = webclient.EntitlementInfo{Enabled: false} expectedCfg.Entitlements[string(entitlements.K8s)] = webclient.EntitlementInfo{Enabled: false} - // request and verify enabled features are enabled. - re, err = clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) - str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + // request and verify enabled features are eventually enabled. + require.EventuallyWithT(t, func(t *assert.CollectT) { + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + + }, time.Second*5, time.Millisecond*50) // use mock client to assert that if ping returns an error, we'll default to // cluster config @@ -4736,15 +4747,22 @@ func TestGetWebConfig_WithEntitlements(t *testing.T) { IsUsageBasedBilling: false, }, }) + env.clock.Advance(DefaultFeatureWatchInterval * 2) // request and verify again - re, err = clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) - str = strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + require.EventuallyWithT(t, func(t *assert.CollectT) { + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + + }, time.Second*5, time.Millisecond*50) } func TestGetWebConfig_LegacyFeatureLimits(t *testing.T) { @@ -4764,6 +4782,8 @@ func TestGetWebConfig_LegacyFeatureLimits(t *testing.T) { }, }, }) + // start the feature watcher so the web config gets new features + env.clock.Advance(DefaultFeatureWatchInterval * 2) expectedCfg := webclient.WebConfig{ Auth: webclient.WebConfigAuthSettings{ @@ -4812,20 +4832,25 @@ func TestGetWebConfig_LegacyFeatureLimits(t *testing.T) { PlayableDatabaseProtocols: player.SupportedDatabaseProtocols, } - // Make a request. clt := env.proxies[0].newClient(t) - endpoint := clt.Endpoint("web", "config.js") - re, err := clt.Get(ctx, endpoint, nil) - require.NoError(t, err) - require.True(t, strings.HasPrefix(string(re.Bytes()), "var GRV_CONFIG")) + require.EventuallyWithT(t, func(t *assert.CollectT) { + // Make a request. + endpoint := clt.Endpoint("web", "config.js") + re, err := clt.Get(ctx, endpoint, nil) + if !assert.NoError(t, err) { + return + } + assert.True(t, bytes.HasPrefix(re.Bytes(), []byte("var GRV_CONFIG"))) - // Response is type application/javascript, we need to strip off the variable name - // and the semicolon at the end, then we are left with json like object. - var cfg webclient.WebConfig - str := strings.ReplaceAll(string(re.Bytes()), "var GRV_CONFIG = ", "") - err = json.Unmarshal([]byte(str[:len(str)-1]), &cfg) - require.NoError(t, err) - require.Equal(t, expectedCfg, cfg) + // Response is type application/javascript, we need to strip off the variable name + // and the semicolon at the end, then we are left with json like object. + var cfg webclient.WebConfig + res := bytes.ReplaceAll(re.Bytes(), []byte("var GRV_CONFIG = "), []byte{}) + err = json.Unmarshal(res[:len(res)-1], &cfg) + assert.NoError(t, err) + diff := cmp.Diff(expectedCfg, cfg) + assert.Empty(t, diff) + }, time.Second*5, time.Millisecond*50) } func TestCreatePrivilegeToken(t *testing.T) { diff --git a/lib/web/features.go b/lib/web/features.go new file mode 100644 index 00000000000..29798851aa7 --- /dev/null +++ b/lib/web/features.go @@ -0,0 +1,73 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package web + +import ( + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/entitlements" +) + +// SetClusterFeatures sets the flags for supported and unsupported features. +// TODO(mcbattirola): make method unexported, fix tests using it to set +// test modules instead. +func (h *Handler) SetClusterFeatures(features proto.Features) { + h.Mutex.Lock() + defer h.Mutex.Unlock() + + entitlements.BackfillFeatures(&features) + h.clusterFeatures = features +} + +// GetClusterFeatures returns flags for supported and unsupported features. +func (h *Handler) GetClusterFeatures() proto.Features { + h.Mutex.Lock() + defer h.Mutex.Unlock() + + return h.clusterFeatures +} + +// startFeatureWatcher periodically pings the auth server and updates `clusterFeatures`. +// Must be called only once per `handler`, otherwise it may close an already closed channel +// which will cause a panic. +// The watcher doesn't ping the auth server immediately upon start because features are +// already set by the config object in `NewHandler`. +func (h *Handler) startFeatureWatcher() { + ticker := h.clock.NewTicker(h.cfg.FeatureWatchInterval) + h.log.WithField("interval", h.cfg.FeatureWatchInterval).Info("Proxy handler features watcher has started") + ctx := h.cfg.Context + + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + h.log.Info("Pinging auth server for features") + pingResponse, err := h.GetProxyClient().Ping(ctx) + if err != nil { + h.log.WithError(err).Error("Auth server ping failed") + continue + } + + h.SetClusterFeatures(*pingResponse.ServerFeatures) + h.log.WithField("features", pingResponse.ServerFeatures).Info("Done updating proxy features") + case <-ctx.Done(): + h.log.Info("Feature service has stopped") + return + } + } +} diff --git a/lib/web/features_test.go b/lib/web/features_test.go new file mode 100644 index 00000000000..3798e819b46 --- /dev/null +++ b/lib/web/features_test.go @@ -0,0 +1,176 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package web + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/entitlements" + "github.com/gravitational/teleport/lib/auth/authclient" +) + +// mockedPingTestProxy is a test proxy with a mocked Ping method +// that returns the internal features +type mockedFeatureGetter struct { + authclient.ClientI + features proto.Features +} + +func (m *mockedFeatureGetter) Ping(ctx context.Context) (proto.PingResponse, error) { + return proto.PingResponse{ + ServerFeatures: utils.CloneProtoMsg(&m.features), + }, nil +} + +func (m *mockedFeatureGetter) setFeatures(f proto.Features) { + m.features = f +} + +func TestFeaturesWatcher(t *testing.T) { + clock := clockwork.NewFakeClock() + + mockClient := &mockedFeatureGetter{features: proto.Features{ + Kubernetes: true, + Entitlements: map[string]*proto.EntitlementInfo{}, + AccessRequests: &proto.AccessRequestsFeature{}, + }} + + ctx, cancel := context.WithCancel(context.Background()) + handler := &Handler{ + cfg: Config{ + FeatureWatchInterval: 100 * time.Millisecond, + ProxyClient: mockClient, + Context: ctx, + }, + clock: clock, + clusterFeatures: proto.Features{}, + log: newPackageLogger(), + logger: slog.Default().With(teleport.ComponentKey, teleport.ComponentWeb), + } + + // before running the watcher, features should match the value passed to the handler + requireFeatures(t, clock, proto.Features{}, handler.GetClusterFeatures) + + go handler.startFeatureWatcher() + clock.BlockUntil(1) + + // after starting the watcher, handler.GetClusterFeatures should return + // values matching the client's response + features := proto.Features{ + Kubernetes: true, + Entitlements: map[string]*proto.EntitlementInfo{}, + AccessRequests: &proto.AccessRequestsFeature{}, + } + entitlements.BackfillFeatures(&features) + expected := utils.CloneProtoMsg(&features) + requireFeatures(t, clock, *expected, handler.GetClusterFeatures) + + // update values once again and check if the features are properly updated + features = proto.Features{ + Kubernetes: false, + Entitlements: map[string]*proto.EntitlementInfo{}, + AccessRequests: &proto.AccessRequestsFeature{}, + } + entitlements.BackfillFeatures(&features) + mockClient.setFeatures(features) + expected = utils.CloneProtoMsg(&features) + requireFeatures(t, clock, *expected, handler.GetClusterFeatures) + + // test updating entitlements + features = proto.Features{ + Kubernetes: true, + Entitlements: map[string]*proto.EntitlementInfo{ + string(entitlements.ExternalAuditStorage): {Enabled: true}, + string(entitlements.AccessLists): {Enabled: true}, + string(entitlements.AccessMonitoring): {Enabled: true}, + string(entitlements.App): {Enabled: true}, + string(entitlements.CloudAuditLogRetention): {Enabled: true}, + }, + AccessRequests: &proto.AccessRequestsFeature{}, + } + entitlements.BackfillFeatures(&features) + mockClient.setFeatures(features) + + expected = &proto.Features{ + Kubernetes: true, + Entitlements: map[string]*proto.EntitlementInfo{ + string(entitlements.ExternalAuditStorage): {Enabled: true}, + string(entitlements.AccessLists): {Enabled: true}, + string(entitlements.AccessMonitoring): {Enabled: true}, + string(entitlements.App): {Enabled: true}, + string(entitlements.CloudAuditLogRetention): {Enabled: true}, + }, + AccessRequests: &proto.AccessRequestsFeature{}, + } + entitlements.BackfillFeatures(expected) + requireFeatures(t, clock, *expected, handler.GetClusterFeatures) + + // stop watcher and ensure it stops updating features + cancel() + features = proto.Features{ + Kubernetes: !features.Kubernetes, + App: !features.App, + DB: true, + Entitlements: map[string]*proto.EntitlementInfo{}, + AccessRequests: &proto.AccessRequestsFeature{}, + } + entitlements.BackfillFeatures(&features) + mockClient.setFeatures(features) + notExpected := utils.CloneProtoMsg(&features) + // assert the handler never get these last features as the watcher is stopped + neverFeatures(t, clock, *notExpected, handler.GetClusterFeatures) +} + +// requireFeatures is a helper function that advances the clock, then +// calls `getFeatures` every 100ms for up to 1 second, until it +// returns the expected result (`want`). +func requireFeatures(t *testing.T, fakeClock clockwork.FakeClock, want proto.Features, getFeatures func() proto.Features) { + t.Helper() + + // Advance the clock so the service fetch and stores features + fakeClock.Advance(1 * time.Second) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + diff := cmp.Diff(want, getFeatures()) + assert.Empty(t, diff) + }, 5*time.Second, time.Millisecond*100) +} + +// neverFeatures is a helper function that advances the clock, then +// calls `getFeatures` every 100ms for up to 1 second. If at some point `getFeatures` +// returns `doNotWant`, the test fails. +func neverFeatures(t *testing.T, fakeClock clockwork.FakeClock, doNotWant proto.Features, getFeatures func() proto.Features) { + t.Helper() + + fakeClock.Advance(1 * time.Second) + require.Never(t, func() bool { + return cmp.Diff(doNotWant, getFeatures()) == "" + }, 1*time.Second, time.Millisecond*100) +} diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 6252af38a3d..2ab0a00492d 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -148,7 +148,7 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p } teleportVersionTag := teleport.Version - if automaticUpgrades(h.ClusterFeatures) { + if automaticUpgrades(h.GetClusterFeatures()) { cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil { return "", trace.Wrap(err) @@ -201,7 +201,7 @@ func (h *Handler) awsOIDCDeployDatabaseServices(w http.ResponseWriter, r *http.R } teleportVersionTag := teleport.Version - if automaticUpgrades(h.ClusterFeatures) { + if automaticUpgrades(h.GetClusterFeatures()) { cloudStableVersion, err := h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil { return "", trace.Wrap(err) @@ -527,7 +527,7 @@ func (h *Handler) awsOIDCEnrollEKSClusters(w http.ResponseWriter, r *http.Reques return nil, trace.BadParameter("an integration name is required") } - agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.ClusterFeatures, h.cfg.AutomaticUpgradesChannels) + agentVersion, err := kubeutils.GetKubeAgentVersion(ctx, h.cfg.ProxyClient, h.GetClusterFeatures(), h.cfg.AutomaticUpgradesChannels) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index c5939e30953..7dcd657857a 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -377,7 +377,7 @@ func (h *Handler) createTokenForDiscoveryHandle(w http.ResponseWriter, r *http.R func (h *Handler) getAutoUpgrades(ctx context.Context) (bool, string, error) { var autoUpgradesVersion string var err error - autoUpgrades := automaticUpgrades(h.ClusterFeatures) + autoUpgrades := automaticUpgrades(h.GetClusterFeatures()) if autoUpgrades { autoUpgradesVersion, err = h.cfg.AutomaticUpgradesChannels.DefaultVersion(ctx) if err != nil {