athena audit logs - query rate limiter (#24918)

This commit is contained in:
Tobiasz Heller 2023-04-28 15:46:04 +02:00 committed by GitHub
parent 4d0f6c58ea
commit 29497f5d85
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 430 additions and 41 deletions

View file

@ -16,7 +16,6 @@ package athena
import (
"context"
"math"
"net/url"
"regexp"
"strconv"
@ -81,9 +80,14 @@ type Config struct {
// GetQueryResultsInterval is used to define how long query will wait before
// checking again for results status if previous status was not ready (optional).
GetQueryResultsInterval time.Duration
// LimiterRate defines rate at which search_event rate limiter is filled (optional).
LimiterRate float64
// LimiterBurst defines rate limit bucket capacity (optional).
// LimiterRefillTime determines the duration of time between the addition of tokens to the bucket (optional).
LimiterRefillTime time.Duration
// LimiterRefillAmount is the number of tokens that are added to the bucket during interval
// specified by LimiterRefillTime (optional).
LimiterRefillAmount int
// LimiterBurst defines the number of available tokens. The bucket is initially full and is refilled
// based on LimiterRefillAmount and LimiterRefillTime (optional).
LimiterBurst int
// Batcher settings.
@ -198,19 +202,23 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error {
return trace.BadParameter("BatchMaxInterval too short, must be greater than 5s")
}
if cfg.LimiterRate < 0 {
return trace.BadParameter("LimiterRate cannot be negative")
if cfg.LimiterRefillAmount < 0 {
return trace.BadParameter("LimiterRefillAmount cannot be nagative")
}
if cfg.LimiterBurst < 0 {
return trace.BadParameter("LimiterBurst cannot be negative")
}
if cfg.LimiterRate > 0 && cfg.LimiterBurst == 0 {
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRate is used")
if cfg.LimiterRefillAmount > 0 && cfg.LimiterBurst == 0 {
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRefillAmount is used")
}
if cfg.LimiterBurst > 0 && math.Abs(cfg.LimiterRate) < 1e-9 {
return trace.BadParameter("LimiterRate must be greater than 0 if LimiterBurst is used")
if cfg.LimiterBurst > 0 && cfg.LimiterRefillAmount == 0 {
return trace.BadParameter("LimiterRefillAmount must be greater than 0 if LimiterBurst is used")
}
if cfg.LimiterRefillAmount > 0 && cfg.LimiterRefillTime == 0 {
cfg.LimiterRefillTime = time.Second
}
if cfg.Clock == nil {
@ -283,13 +291,21 @@ func (cfg *Config) SetFromURL(url *url.URL) error {
}
cfg.GetQueryResultsInterval = dur
}
rateInString := url.Query().Get("limiterRate")
if rateInString != "" {
rate, err := strconv.ParseFloat(rateInString, 32)
refillAmountInString := url.Query().Get("limiterRefillAmount")
if refillAmountInString != "" {
refillAmount, err := strconv.Atoi(refillAmountInString)
if err != nil {
return trace.BadParameter("invalid limiterRate value (it must be float32): %v", err)
return trace.BadParameter("invalid limiterRefillAmount value (it must be int): %v", err)
}
cfg.LimiterRate = rate
cfg.LimiterRefillAmount = refillAmount
}
refillTimeInString := url.Query().Get("limiterRefillTime")
if refillTimeInString != "" {
dur, err := time.ParseDuration(refillTimeInString)
if err != nil {
return trace.BadParameter("invalid limiterRefillTime value: %v", err)
}
cfg.LimiterRefillTime = dur
}
burstInString := url.Query().Get("limiterBurst")
if burstInString != "" {

View file

@ -74,12 +74,13 @@ func TestConfig_SetFromURL(t *testing.T) {
},
{
name: "params to querier - part 2",
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRate=0.642&limiterBurst=3",
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRefillAmount=2&&limiterRefillTime=2s&limiterBurst=3",
want: Config{
TableName: "tbl",
Database: "db",
GetQueryResultsInterval: 200 * time.Millisecond,
LimiterRate: 0.642,
LimiterRefillAmount: 2,
LimiterRefillTime: 2 * time.Second,
LimiterBurst: 3,
},
},
@ -100,9 +101,9 @@ func TestConfig_SetFromURL(t *testing.T) {
wantErr: "invalid athena address, supported format is 'athena://database.table'",
},
{
name: "invalid limiterRate format",
url: "athena://db.tbl/?limiterRate=abc",
wantErr: "invalid limiterRate value (it must be float32)",
name: "invalid limiterRefillAmount format",
url: "athena://db.tbl/?limiterRefillAmount=abc",
wantErr: "invalid limiterRefillAmount value (it must be int)",
},
}
for _, tt := range tests {
@ -163,6 +164,33 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
Backend: mockBackend{},
},
},
{
name: "valid config with limiter, check defaults refillTime",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 10
cfg.LimiterRefillAmount = 5
return cfg
},
want: Config{
Database: "db",
TableName: "tbl",
TopicARN: "arn:topic",
LargeEventsS3: "s3://large-payloads-bucket",
largeEventsBucket: "large-payloads-bucket",
LocationS3: "s3://events-bucket",
locationS3Bucket: "events-bucket",
QueueURL: "https://queue-url",
GetQueryResultsInterval: 100 * time.Millisecond,
BatchMaxItems: 20000,
BatchMaxInterval: 1 * time.Minute,
AWSConfig: &aws.Config{},
Backend: mockBackend{},
LimiterRefillTime: 1 * time.Second,
LimiterBurst: 10,
LimiterRefillAmount: 5,
},
},
{
name: "missing table name",
input: func() Config {
@ -227,24 +255,24 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
wantErr: "QueueURL must be valid url and start with https",
},
{
name: "invalid LimiterBurst and LimiterRate combination",
name: "invalid LimiterBurst and LimiterRefillAmount combination",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 0
cfg.LimiterRate = 2.5
cfg.LimiterRefillAmount = 2
return cfg
},
wantErr: "LimiterBurst must be greater than 0 if LimiterRate is used",
wantErr: "LimiterBurst must be greater than 0 if LimiterRefillAmount is used",
},
{
name: "invalid LimiterRate and LimiterBurst combination",
name: "invalid LimiterRefillAmount and LimiterBurst combination",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 3
cfg.LimiterRate = 0
cfg.LimiterRefillAmount = 0
return cfg
},
wantErr: "LimiterRate must be greater than 0 if LimiterBurst is used",
wantErr: "LimiterRefillAmount must be greater than 0 if LimiterBurst is used",
},
}
for _, tt := range tests {

View file

@ -117,7 +117,16 @@ func (q *querier) SearchEvents(fromUTC, toUTC time.Time, namespace string,
eventTypes []string, limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
filter := searchEventsFilter{eventTypes: eventTypes}
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, "")
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
fromUTC: fromUTC,
toUTC: toUTC,
limit: limit,
order: order,
startKey: startKey,
filter: filter,
sessionID: "",
})
return events, keyset, trace.Wrap(err)
}
func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
@ -133,12 +142,29 @@ func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
}
filter.condition = condFn
}
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, sessionID)
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
fromUTC: fromUTC,
toUTC: toUTC,
limit: limit,
order: order,
startKey: startKey,
filter: filter,
sessionID: sessionID,
})
return events, keyset, trace.Wrap(err)
}
func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, limit int,
order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string,
) ([]apievents.AuditEvent, string, error) {
// searchEventsRequest groups the arguments accepted by querier.searchEvents.
type searchEventsRequest struct {
// fromUTC and toUTC are forwarded to prepareQuery as the search time bounds.
fromUTC, toUTC time.Time
// limit is the requested number of events; searchEvents replaces
// non-positive values with defaults.EventsIterationLimit.
limit int
// order is forwarded to prepareQuery.
order types.EventOrder
// startKey, when non-empty, is decoded with fromKey into the starting keyset.
startKey string
// filter is forwarded to prepareQuery; its condition is also passed to fetchResults.
filter searchEventsFilter
// sessionID is forwarded to prepareQuery. SearchEvents leaves it empty.
sessionID string
}
func (q *querier) searchEvents(ctx context.Context, req searchEventsRequest) ([]apievents.AuditEvent, string, error) {
limit := req.limit
if limit <= 0 {
limit = defaults.EventsIterationLimit
}
@ -147,28 +173,28 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
}
var startKeyset *keyset
if startKey != "" {
if req.startKey != "" {
var err error
startKeyset, err = fromKey(startKey)
startKeyset, err = fromKey(req.startKey)
if err != nil {
return nil, "", trace.Wrap(err)
}
}
query, params := prepareQuery(searchParams{
fromUTC: fromUTC,
toUTC: toUTC,
order: order,
fromUTC: req.fromUTC,
toUTC: req.toUTC,
order: req.order,
limit: limit,
startKeyset: startKeyset,
filter: filter,
sessionID: sessionID,
filter: req.filter,
sessionID: req.sessionID,
tablename: q.tablename,
})
q.logger.WithField("query", query).
WithField("params", params).
WithField("startKey", startKey).
WithField("startKey", req.startKey).
Debug("Executing events query on Athena")
queryId, err := q.startQueryExecution(ctx, query, params)
@ -180,7 +206,7 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
return nil, "", trace.Wrap(err)
}
output, nextKey, err := q.fetchResults(ctx, queryId, limit, filter.condition)
output, nextKey, err := q.fetchResults(ctx, queryId, limit, req.filter.condition)
return output, nextKey, trace.Wrap(err)
}

View file

@ -0,0 +1,95 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package events
import (
"time"
"github.com/gravitational/trace"
"golang.org/x/time/rate"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
)
// SearchEventsLimiter wraps any AuditLogger and applies a rate limit to its
// search events endpoints.
// Note that a single limiter is shared by both SearchEvents and SearchSessionEvents.
type SearchEventsLimiter struct {
// limiter is consulted before every SearchEvents and SearchSessionEvents call.
limiter *rate.Limiter
// AuditLogger is the wrapped logger; it is embedded, so methods not
// overridden by SearchEventsLimiter are served by it directly.
AuditLogger
}
// SearchEventsLimiterConfig is the configuration for SearchEventsLimiter.
type SearchEventsLimiterConfig struct {
// RefillTime determines the duration of time between the addition of tokens to the bucket.
// A zero value is replaced with one second by CheckAndSetDefaults.
RefillTime time.Duration
// RefillAmount is the number of tokens that are added to the bucket during the interval
// specified by RefillTime. It must be greater than 0.
RefillAmount int
// Burst defines the number of available tokens. The bucket is initially full and is refilled
// based on RefillAmount and RefillTime. It must be greater than 0.
Burst int
// AuditLogger is the audit logger that will be wrapped with the limiter on search endpoints.
// It must not be nil.
AuditLogger AuditLogger
}
// CheckAndSetDefaults validates the configuration and fills in defaults.
//
// It returns a BadParameter error when AuditLogger is nil, when Burst or
// RefillAmount is not positive, or when RefillTime is negative. A zero
// RefillTime is replaced with one second.
func (cfg *SearchEventsLimiterConfig) CheckAndSetDefaults() error {
	if cfg.AuditLogger == nil {
		return trace.BadParameter("empty auditLogger")
	}
	if cfg.Burst <= 0 {
		return trace.BadParameter("Burst cannot be less or equal to 0")
	}
	if cfg.RefillAmount <= 0 {
		return trace.BadParameter("RefillAmount cannot be less or equal to 0")
	}
	if cfg.RefillTime < 0 {
		// A non-positive refill interval is interpreted by rate.Every as an
		// infinite rate, which would silently turn the limiter off.
		return trace.BadParameter("RefillTime cannot be negative")
	}
	if cfg.RefillTime == 0 {
		// Default to seconds so RefillAmount can be read as a plain per-second rate.
		cfg.RefillTime = time.Second
	}
	return nil
}
// NewSearchEventLimiter validates cfg and returns a new SearchEventsLimiter
// wrapping cfg.AuditLogger.
//
// The limiter starts with cfg.Burst tokens and is refilled at a rate of
// cfg.RefillAmount tokens per cfg.RefillTime. An error is returned when cfg
// fails CheckAndSetDefaults.
func NewSearchEventLimiter(cfg SearchEventsLimiterConfig) (*SearchEventsLimiter, error) {
	if err := cfg.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}
	// Compute tokens per second in floating point. Dividing the durations as
	// integers truncates the interval, and a quotient of zero would be turned
	// into an unlimited rate by rate.Every.
	tokensPerSecond := rate.Limit(float64(cfg.RefillAmount) / cfg.RefillTime.Seconds())
	return &SearchEventsLimiter{
		limiter:     rate.NewLimiter(tokensPerSecond, cfg.Burst),
		AuditLogger: cfg.AuditLogger,
	}, nil
}
// SearchEvents forwards the call to the wrapped AuditLogger when the shared
// limiter allows it. Otherwise it returns a LimitExceeded error without
// calling the wrapped logger.
func (s *SearchEventsLimiter) SearchEvents(fromUTC, toUTC time.Time, namespace string,
	eventTypes []string, limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
	if allowed := s.limiter.Allow(); !allowed {
		return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
	}
	found, nextKey, err := s.AuditLogger.SearchEvents(
		fromUTC, toUTC, namespace, eventTypes, limit, order, startKey,
	)
	return found, nextKey, trace.Wrap(err)
}
// SearchSessionEvents forwards the call to the wrapped AuditLogger when the
// shared limiter allows it. Otherwise it returns a LimitExceeded error without
// calling the wrapped logger.
func (s *SearchEventsLimiter) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
	order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string,
) ([]apievents.AuditEvent, string, error) {
	if allowed := s.limiter.Allow(); !allowed {
		return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
	}
	found, nextKey, err := s.AuditLogger.SearchSessionEvents(
		fromUTC, toUTC, limit, order, startKey, cond, sessionID,
	)
	return found, nextKey, trace.Wrap(err)
}

View file

@ -0,0 +1,168 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package events_test
import (
"context"
"testing"
"time"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events"
)
// TestSearchEventsLimiter checks that SearchEventsLimiter leaves event
// emission unlimited and that SearchEvents and SearchSessionEvents draw from
// one shared token bucket.
// NOTE(review): the "with limiter" subtest relies on wall-clock timing (20ms
// refill, 40ms Eventually window) — confirm it is stable on slow CI runners.
func TestSearchEventsLimiter(t *testing.T) {
t.Parallel()
t.Run("emitting events happen without any limiting", func(t *testing.T) {
// Burst of 1 would reject the second call if emission were rate limited.
s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
RefillAmount: 1,
Burst: 1,
AuditLogger: &mockAuditLogger{
emitAuditEventRespFn: func() error { return nil },
},
})
require.NoError(t, err)
for i := 0; i < 20; i++ {
require.NoError(t, s.EmitAuditEvent(context.Background(), &apievents.AccessRequestCreate{}))
}
})
t.Run("with limiter", func(t *testing.T) {
burst := 20
s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
RefillTime: 20 * time.Millisecond,
RefillAmount: 1,
Burst: burst,
AuditLogger: &mockAuditLogger{
searchEventsRespFn: func() ([]apievents.AuditEvent, string, error) { return nil, "", nil },
},
})
require.NoError(t, err)
someDate := clockwork.NewFakeClock().Now().UTC()
// searchEvents and searchSessionEvents are helper functions to avoid copying those methods,
// which take a large number of arguments, multiple times in this test case.
searchEvents := func() ([]apievents.AuditEvent, string, error) {
return s.SearchEvents(someDate, someDate, "default", nil /* eventTypes */, 100 /* limit */, types.EventOrderAscending, "" /* startKey */)
}
searchSessionEvents := func() ([]apievents.AuditEvent, string, error) {
return s.SearchSessionEvents(someDate, someDate, 100 /* limit */, types.EventOrderAscending, "" /* startKey */, nil /* cond */, "" /* sessionID */)
}
for i := 0; i < burst; i++ {
var err error
// The rate limit is shared between both search endpoints.
if i%2 == 0 {
_, _, err = searchEvents()
} else {
_, _, err = searchSessionEvents()
}
require.NoError(t, err)
}
// Now all tokens from the rate limiter should be used.
_, _, err = searchEvents()
require.True(t, trace.IsLimitExceeded(err))
// The same applies to SearchSessionEvents.
_, _, err = searchSessionEvents()
require.True(t, trace.IsLimitExceeded(err))
// After 20ms 1 token should be added according to the configured rate.
require.Eventually(t, func() bool {
_, _, err := searchEvents()
return err == nil
}, 40*time.Millisecond, 5*time.Millisecond)
})
}
// TestSearchEventsLimiterConfig verifies the validation and defaulting done by
// SearchEventsLimiterConfig.CheckAndSetDefaults.
func TestSearchEventsLimiterConfig(t *testing.T) {
	type checkFn func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig)

	// failsWith builds a check asserting that validation failed with msg.
	failsWith := func(msg string) checkFn {
		return func(t *testing.T, err error, _ events.SearchEventsLimiterConfig) {
			require.ErrorContains(t, err, msg)
		}
	}

	cases := []struct {
		name   string
		cfg    events.SearchEventsLimiterConfig
		wantFn checkFn
	}{
		{
			name: "valid config",
			cfg: events.SearchEventsLimiterConfig{
				AuditLogger:  &mockAuditLogger{},
				RefillAmount: 1,
				Burst:        1,
			},
			wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) {
				require.NoError(t, err)
				// RefillTime was left unset, so the default must have been applied.
				require.Equal(t, time.Second, cfg.RefillTime)
			},
		},
		{
			name: "empty rate in config",
			cfg: events.SearchEventsLimiterConfig{
				AuditLogger: &mockAuditLogger{},
				Burst:       1,
			},
			wantFn: failsWith("RefillAmount cannot be less or equal to 0"),
		},
		{
			name: "empty burst in config",
			cfg: events.SearchEventsLimiterConfig{
				AuditLogger:  &mockAuditLogger{},
				RefillAmount: 1,
			},
			wantFn: failsWith("Burst cannot be less or equal to 0"),
		},
		{
			name: "empty logger",
			cfg: events.SearchEventsLimiterConfig{
				RefillAmount: 1,
				Burst:        1,
			},
			wantFn: failsWith("empty auditLogger"),
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			// Validate a local copy so the check sees any defaults that were set.
			cfg := tc.cfg
			err := cfg.CheckAndSetDefaults()
			tc.wantFn(t, err, cfg)
		})
	}
}
// mockAuditLogger is a test double for events.AuditLogger whose search and
// emit responses are supplied by the test through function fields.
type mockAuditLogger struct {
// searchEventsRespFn produces the result of both SearchEvents and SearchSessionEvents.
searchEventsRespFn func() ([]apievents.AuditEvent, string, error)
// emitAuditEventRespFn produces the result of EmitAuditEvent.
emitAuditEventRespFn func() error
// AuditLogger is embedded so the mock satisfies the whole interface; the
// tests in this file never set it.
events.AuditLogger
}
// SearchEvents ignores its arguments and returns whatever searchEventsRespFn produces.
func (m *mockAuditLogger) SearchEvents(
	fromUTC, toUTC time.Time, namespace string, eventTypes []string,
	limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
	respond := m.searchEventsRespFn
	return respond()
}
// SearchSessionEvents ignores its arguments and returns whatever
// searchEventsRespFn produces, the same source as SearchEvents.
func (m *mockAuditLogger) SearchSessionEvents(
	fromUTC, toUTC time.Time, limit int, order types.EventOrder,
	startKey string, cond *types.WhereExpr, sessionID string,
) ([]apievents.AuditEvent, string, error) {
	respond := m.searchEventsRespFn
	return respond()
}
// EmitAuditEvent ignores its arguments and returns whatever emitAuditEventRespFn produces.
func (m *mockAuditLogger) EmitAuditEvent(context.Context, apievents.AuditEvent) error {
	respond := m.emitAuditEventRespFn
	return respond()
}

View file

@ -1414,10 +1414,23 @@ func initAuthExternalAuditLog(ctx context.Context, auditConfig types.ClusterAudi
if err != nil {
return nil, trace.Wrap(err)
}
logger, err := athena.New(ctx, cfg)
var logger events.AuditLogger
logger, err = athena.New(ctx, cfg)
if err != nil {
return nil, trace.Wrap(err)
}
if cfg.LimiterBurst > 0 {
// Wrap athena logger with rate limiter on search events.
logger, err = events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
RefillTime: cfg.LimiterRefillTime,
RefillAmount: cfg.LimiterRefillAmount,
Burst: cfg.LimiterBurst,
AuditLogger: logger,
})
if err != nil {
return nil, trace.Wrap(err)
}
}
loggers = append(loggers, logger)
case teleport.SchemeFile:
if uri.Path == "" {

View file

@ -52,6 +52,8 @@ import (
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/athena"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/reversetunnel"
@ -297,6 +299,47 @@ func TestServiceInitExternalLog(t *testing.T) {
}
}
// TestAthenaAuditLogSetup checks which logger type initAuthExternalAuditLog
// returns for an athena audit events URI: the plain athena logger by default,
// and an events.SearchEventsLimiter wrapper when limiter parameters are set.
func TestAthenaAuditLogSetup(t *testing.T) {
	sampleValidConfig := "athena://db.table?topicArn=arn:aws:sns:eu-central-1:accnr:topicName&queryResultsS3=s3://testbucket/query-result/&workgroup=workgroup&locationS3=s3://testbucket/events-location&queueURL=https://sqs.eu-central-1.amazonaws.com/accnr/sqsname&largeEventsS3=s3://testbucket/largeevents"
	tests := []struct {
		name   string
		uri    string
		wantFn func(*testing.T, events.AuditLogger, error)
	}{
		{
			name: "valid athena config",
			uri:  sampleValidConfig,
			wantFn: func(t *testing.T, alog events.AuditLogger, err error) {
				require.NoError(t, err)
				// Report the dynamic type of alog: the result of a failed type
				// assertion always has the asserted type and would hide the
				// actual logger type in the failure message.
				_, ok := alog.(*athena.Log)
				require.True(t, ok, "invalid logger type, got %T", alog)
			},
		},
		{
			name: "config with rate limit - should use events.SearchEventsLimiter",
			uri:  sampleValidConfig + "&limiterRefillAmount=3&limiterBurst=2",
			wantFn: func(t *testing.T, alog events.AuditLogger, err error) {
				require.NoError(t, err)
				_, ok := alog.(*events.SearchEventsLimiter)
				require.True(t, ok, "invalid logger type, got %T", alog)
			},
		},
	}
	backend, err := memory.New(memory.Config{})
	require.NoError(t, err)
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{
				AuditEventsURI:   []string{tt.uri},
				AuditSessionsURI: "s3://testbucket/sessions-rec",
			})
			require.NoError(t, err)
			log, err := initAuthExternalAuditLog(context.Background(), auditConfig, backend)
			tt.wantFn(t, log, err)
		})
	}
}
func TestGetAdditionalPrincipals(t *testing.T) {
p := &TeleportProcess{
Config: &servicecfg.Config{