Keep a device data cache in-process (#34597)

This commit is contained in:
Alan Parra 2023-11-17 10:27:23 -03:00 committed by GitHub
parent 0cef1a31ca
commit ffce86db1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 76 additions and 9 deletions

View file

@ -16,6 +16,13 @@ package native
import (
"runtime"
"sync"
"time"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1"
)
@ -34,24 +41,77 @@ const (
// CollectedDataNeverEscalate will never escalate privileges, even in the
// absence of cached data.
CollectedDataNeverEscalate CollectDataMode = iota
// CollectedDataAlwaysEscalate avoids using cached DMI data and instead will
// always escalate privileges if necessary.
//
// Used by `tsh device enroll`, `tsh device collect` and
// `tsh device asset-tag`.
CollectedDataAlwaysEscalate
// CollectedDataMaybeEscalate will attempt to use cached DMI data before
// privilege escalation, but it may choose to escalate if no cached data is
// available.
//
// Used by `tsh login` and similar operations (ie, device authn).
CollectedDataMaybeEscalate
// CollectedDataAlwaysEscalate avoids using cached DMI data and instead will
// always escalate privileges if necessary.
//
// Used by `tsh device enroll`, `tsh device collect` and
// `tsh device asset-tag`.
CollectedDataAlwaysEscalate
// IMPORTANT: CollectDataMode declarations must go from least to most strict.
)
var cachedDeviceData = struct {
skipCache bool // Set to true for testing.
mu sync.Mutex
mode CollectDataMode
value *devicepb.DeviceCollectedData
}{}
func readCachedDeviceDataUnderLock(mode CollectDataMode) (cdd *devicepb.DeviceCollectedData, ok bool) {
// Use cached data if present and the cached mode is at least as strict as the
// one requested.
// This can save some time, but mainly it avoids needless escalation attempts
// on Linux (past the first).
if cachedDeviceData.skipCache || cachedDeviceData.mode < mode || cachedDeviceData.value == nil {
return nil, false
}
// Default sudo cache is around 5m, so this seems like a resonable interval.
const maxAgeSeconds = 60
cdd = cachedDeviceData.value
now := time.Now()
if now.Unix()-cdd.CollectTime.Seconds > maxAgeSeconds {
// "Evict" cache.
cachedDeviceData.mode = 0
cachedDeviceData.value = nil
return nil, false
}
log.Debug("Device Trust: Using in-process cached device data")
cdd = proto.Clone(cachedDeviceData.value).(*devicepb.DeviceCollectedData)
cdd.CollectTime = timestamppb.Now()
return cdd, true
}
func writeCachedDeviceDataUnderLock(mode CollectDataMode, cdd *devicepb.DeviceCollectedData) {
cachedDeviceData.mode = mode
cachedDeviceData.value = proto.Clone(cdd).(*devicepb.DeviceCollectedData)
}
// CollectDeviceData collects OS-specific device data for device enrollment or
// device authentication ceremonies.
func CollectDeviceData(mode CollectDataMode) (*devicepb.DeviceCollectedData, error) {
return collectDeviceData(mode)
cachedDeviceData.mu.Lock()
defer cachedDeviceData.mu.Unlock()
if cdd, ok := readCachedDeviceDataUnderLock(mode); ok {
return cdd, nil
}
cdd, err := collectDeviceData(mode)
if err != nil {
return nil, trace.Wrap(err)
}
writeCachedDeviceDataUnderLock(mode, cdd)
return cdd, nil
}
// SignChallenge signs a device challenge for device enrollment or device

View file

@ -49,7 +49,7 @@ func enrollDeviceInit() (*devicepb.EnrollDeviceInit, error) {
return nil, trace.Wrap(err)
}
cd, err := collectDeviceData(CollectedDataAlwaysEscalate)
cd, err := CollectDeviceData(CollectedDataAlwaysEscalate)
if err != nil {
return nil, trace.Wrap(err, "collecting device data")
}

View file

@ -38,6 +38,13 @@ func TestCollectDeviceData_linux(t *testing.T) {
// Silence logging for tests.
log.SetLevel(log.PanicLevel)
// Do not cache data during testing.
skipCacheBefore := cachedDeviceData.skipCache
cachedDeviceData.skipCache = true
t.Cleanup(func() {
cachedDeviceData.skipCache = skipCacheBefore
})
u, err := user.Current()
require.NoError(t, err, "reading current user")

View file

@ -151,7 +151,7 @@ func (d *tpmDevice) enrollDeviceInit() (*devicepb.EnrollDeviceInit, error) {
}
defer ak.Close(tpm)
deviceData, err := collectDeviceData(CollectedDataAlwaysEscalate)
deviceData, err := CollectDeviceData(CollectedDataAlwaysEscalate)
if err != nil {
return nil, trace.Wrap(err, "collecting device data")
}