mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 17:53:28 +00:00
athena audit logs - query rate limiter (#24918)
This commit is contained in:
parent
4d0f6c58ea
commit
29497f5d85
|
@ -16,7 +16,6 @@ package athena
|
|||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
@ -81,9 +80,14 @@ type Config struct {
|
|||
// GetQueryResultsInterval is used to define how long query will wait before
|
||||
// checking again for results status if previous status was not ready (optional).
|
||||
GetQueryResultsInterval time.Duration
|
||||
// LimiterRate defines rate at which search_event rate limiter is filled (optional).
|
||||
LimiterRate float64
|
||||
// LimiterBurst defines rate limit bucket capacity (optional).
|
||||
|
||||
// LimiterRefillTime determines the duration of time between the addition of tokens to the bucket (optional).
|
||||
LimiterRefillTime time.Duration
|
||||
// LimiterRefillAmount is the number of tokens that are added to the bucket during interval
|
||||
// specified by LimiterRefillTime (optional).
|
||||
LimiterRefillAmount int
|
||||
// Burst defines number of available tokens. It's initially full and refilled
|
||||
// based on LimiterRefillAmount and LimiterRefillTime (optional).
|
||||
LimiterBurst int
|
||||
|
||||
// Batcher settings.
|
||||
|
@ -198,19 +202,23 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error {
|
|||
return trace.BadParameter("BatchMaxInterval too short, must be greater than 5s")
|
||||
}
|
||||
|
||||
if cfg.LimiterRate < 0 {
|
||||
return trace.BadParameter("LimiterRate cannot be negative")
|
||||
if cfg.LimiterRefillAmount < 0 {
|
||||
return trace.BadParameter("LimiterRefillAmount cannot be nagative")
|
||||
}
|
||||
if cfg.LimiterBurst < 0 {
|
||||
return trace.BadParameter("LimiterBurst cannot be negative")
|
||||
}
|
||||
|
||||
if cfg.LimiterRate > 0 && cfg.LimiterBurst == 0 {
|
||||
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRate is used")
|
||||
if cfg.LimiterRefillAmount > 0 && cfg.LimiterBurst == 0 {
|
||||
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRefillAmount is used")
|
||||
}
|
||||
|
||||
if cfg.LimiterBurst > 0 && math.Abs(cfg.LimiterRate) < 1e-9 {
|
||||
return trace.BadParameter("LimiterRate must be greater than 0 if LimiterBurst is used")
|
||||
if cfg.LimiterBurst > 0 && cfg.LimiterRefillAmount == 0 {
|
||||
return trace.BadParameter("LimiterRefillAmount must be greater than 0 if LimiterBurst is used")
|
||||
}
|
||||
|
||||
if cfg.LimiterRefillAmount > 0 && cfg.LimiterRefillTime == 0 {
|
||||
cfg.LimiterRefillTime = time.Second
|
||||
}
|
||||
|
||||
if cfg.Clock == nil {
|
||||
|
@ -283,13 +291,21 @@ func (cfg *Config) SetFromURL(url *url.URL) error {
|
|||
}
|
||||
cfg.GetQueryResultsInterval = dur
|
||||
}
|
||||
rateInString := url.Query().Get("limiterRate")
|
||||
if rateInString != "" {
|
||||
rate, err := strconv.ParseFloat(rateInString, 32)
|
||||
refillAmountInString := url.Query().Get("limiterRefillAmount")
|
||||
if refillAmountInString != "" {
|
||||
refillAmount, err := strconv.Atoi(refillAmountInString)
|
||||
if err != nil {
|
||||
return trace.BadParameter("invalid limiterRate value (it must be float32): %v", err)
|
||||
return trace.BadParameter("invalid limiterRefillAmount value (it must be int): %v", err)
|
||||
}
|
||||
cfg.LimiterRate = rate
|
||||
cfg.LimiterRefillAmount = refillAmount
|
||||
}
|
||||
refillTimeInString := url.Query().Get("limiterRefillTime")
|
||||
if refillTimeInString != "" {
|
||||
dur, err := time.ParseDuration(refillTimeInString)
|
||||
if err != nil {
|
||||
return trace.BadParameter("invalid limiterRefillTime value: %v", err)
|
||||
}
|
||||
cfg.LimiterRefillTime = dur
|
||||
}
|
||||
burstInString := url.Query().Get("limiterBurst")
|
||||
if burstInString != "" {
|
||||
|
|
|
@ -74,12 +74,13 @@ func TestConfig_SetFromURL(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "params to querier - part 2",
|
||||
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRate=0.642&limiterBurst=3",
|
||||
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRefillAmount=2&&limiterRefillTime=2s&limiterBurst=3",
|
||||
want: Config{
|
||||
TableName: "tbl",
|
||||
Database: "db",
|
||||
GetQueryResultsInterval: 200 * time.Millisecond,
|
||||
LimiterRate: 0.642,
|
||||
LimiterRefillAmount: 2,
|
||||
LimiterRefillTime: 2 * time.Second,
|
||||
LimiterBurst: 3,
|
||||
},
|
||||
},
|
||||
|
@ -100,9 +101,9 @@ func TestConfig_SetFromURL(t *testing.T) {
|
|||
wantErr: "invalid athena address, supported format is 'athena://database.table'",
|
||||
},
|
||||
{
|
||||
name: "invalid limiterRate format",
|
||||
url: "athena://db.tbl/?limiterRate=abc",
|
||||
wantErr: "invalid limiterRate value (it must be float32)",
|
||||
name: "invalid limiterRefillAmount format",
|
||||
url: "athena://db.tbl/?limiterRefillAmount=abc",
|
||||
wantErr: "invalid limiterRefillAmount value (it must be int)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -163,6 +164,33 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
|
|||
Backend: mockBackend{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid config with limiter, check defaults refillTime",
|
||||
input: func() Config {
|
||||
cfg := validConfig
|
||||
cfg.LimiterBurst = 10
|
||||
cfg.LimiterRefillAmount = 5
|
||||
return cfg
|
||||
},
|
||||
want: Config{
|
||||
Database: "db",
|
||||
TableName: "tbl",
|
||||
TopicARN: "arn:topic",
|
||||
LargeEventsS3: "s3://large-payloads-bucket",
|
||||
largeEventsBucket: "large-payloads-bucket",
|
||||
LocationS3: "s3://events-bucket",
|
||||
locationS3Bucket: "events-bucket",
|
||||
QueueURL: "https://queue-url",
|
||||
GetQueryResultsInterval: 100 * time.Millisecond,
|
||||
BatchMaxItems: 20000,
|
||||
BatchMaxInterval: 1 * time.Minute,
|
||||
AWSConfig: &aws.Config{},
|
||||
Backend: mockBackend{},
|
||||
LimiterRefillTime: 1 * time.Second,
|
||||
LimiterBurst: 10,
|
||||
LimiterRefillAmount: 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing table name",
|
||||
input: func() Config {
|
||||
|
@ -227,24 +255,24 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
|
|||
wantErr: "QueueURL must be valid url and start with https",
|
||||
},
|
||||
{
|
||||
name: "invalid LimiterBurst and LimiterRate combination",
|
||||
name: "invalid LimiterBurst and LimiterRefillAmount combination",
|
||||
input: func() Config {
|
||||
cfg := validConfig
|
||||
cfg.LimiterBurst = 0
|
||||
cfg.LimiterRate = 2.5
|
||||
cfg.LimiterRefillAmount = 2
|
||||
return cfg
|
||||
},
|
||||
wantErr: "LimiterBurst must be greater than 0 if LimiterRate is used",
|
||||
wantErr: "LimiterBurst must be greater than 0 if LimiterRefillAmount is used",
|
||||
},
|
||||
{
|
||||
name: "invalid LimiterRate and LimiterBurst combination",
|
||||
name: "invalid LimiterRefillAmount and LimiterBurst combination",
|
||||
input: func() Config {
|
||||
cfg := validConfig
|
||||
cfg.LimiterBurst = 3
|
||||
cfg.LimiterRate = 0
|
||||
cfg.LimiterRefillAmount = 0
|
||||
return cfg
|
||||
},
|
||||
wantErr: "LimiterRate must be greater than 0 if LimiterBurst is used",
|
||||
wantErr: "LimiterRefillAmount must be greater than 0 if LimiterBurst is used",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
|
|
@ -117,7 +117,16 @@ func (q *querier) SearchEvents(fromUTC, toUTC time.Time, namespace string,
|
|||
eventTypes []string, limit int, order types.EventOrder, startKey string,
|
||||
) ([]apievents.AuditEvent, string, error) {
|
||||
filter := searchEventsFilter{eventTypes: eventTypes}
|
||||
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, "")
|
||||
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
|
||||
fromUTC: fromUTC,
|
||||
toUTC: toUTC,
|
||||
limit: limit,
|
||||
order: order,
|
||||
startKey: startKey,
|
||||
filter: filter,
|
||||
sessionID: "",
|
||||
})
|
||||
return events, keyset, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
|
||||
|
@ -133,12 +142,29 @@ func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
|
|||
}
|
||||
filter.condition = condFn
|
||||
}
|
||||
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, sessionID)
|
||||
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
|
||||
fromUTC: fromUTC,
|
||||
toUTC: toUTC,
|
||||
limit: limit,
|
||||
order: order,
|
||||
startKey: startKey,
|
||||
filter: filter,
|
||||
sessionID: sessionID,
|
||||
})
|
||||
return events, keyset, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, limit int,
|
||||
order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string,
|
||||
) ([]apievents.AuditEvent, string, error) {
|
||||
type searchEventsRequest struct {
|
||||
fromUTC, toUTC time.Time
|
||||
limit int
|
||||
order types.EventOrder
|
||||
startKey string
|
||||
filter searchEventsFilter
|
||||
sessionID string
|
||||
}
|
||||
|
||||
func (q *querier) searchEvents(ctx context.Context, req searchEventsRequest) ([]apievents.AuditEvent, string, error) {
|
||||
limit := req.limit
|
||||
if limit <= 0 {
|
||||
limit = defaults.EventsIterationLimit
|
||||
}
|
||||
|
@ -147,28 +173,28 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
|
|||
}
|
||||
|
||||
var startKeyset *keyset
|
||||
if startKey != "" {
|
||||
if req.startKey != "" {
|
||||
var err error
|
||||
startKeyset, err = fromKey(startKey)
|
||||
startKeyset, err = fromKey(req.startKey)
|
||||
if err != nil {
|
||||
return nil, "", trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
query, params := prepareQuery(searchParams{
|
||||
fromUTC: fromUTC,
|
||||
toUTC: toUTC,
|
||||
order: order,
|
||||
fromUTC: req.fromUTC,
|
||||
toUTC: req.toUTC,
|
||||
order: req.order,
|
||||
limit: limit,
|
||||
startKeyset: startKeyset,
|
||||
filter: filter,
|
||||
sessionID: sessionID,
|
||||
filter: req.filter,
|
||||
sessionID: req.sessionID,
|
||||
tablename: q.tablename,
|
||||
})
|
||||
|
||||
q.logger.WithField("query", query).
|
||||
WithField("params", params).
|
||||
WithField("startKey", startKey).
|
||||
WithField("startKey", req.startKey).
|
||||
Debug("Executing events query on Athena")
|
||||
|
||||
queryId, err := q.startQueryExecution(ctx, query, params)
|
||||
|
@ -180,7 +206,7 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
|
|||
return nil, "", trace.Wrap(err)
|
||||
}
|
||||
|
||||
output, nextKey, err := q.fetchResults(ctx, queryId, limit, filter.condition)
|
||||
output, nextKey, err := q.fetchResults(ctx, queryId, limit, req.filter.condition)
|
||||
return output, nextKey, trace.Wrap(err)
|
||||
}
|
||||
|
||||
|
|
95
lib/events/search_limiter.go
Normal file
95
lib/events/search_limiter.go
Normal file
|
@ -0,0 +1,95 @@
|
|||
// Copyright 2023 Gravitational, Inc
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package events
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
apievents "github.com/gravitational/teleport/api/types/events"
|
||||
)
|
||||
|
||||
// SearchEventsLimiter allows to wrap any AuditLogger with rate limit on
|
||||
// search events endpoints.
|
||||
// Note it share limiter for both SearchEvents and SearchSessionEvents.
|
||||
type SearchEventsLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
AuditLogger
|
||||
}
|
||||
|
||||
// SearchEventsLimiterConfig is configuration for SearchEventsLimiter.
|
||||
type SearchEventsLimiterConfig struct {
|
||||
// RefillTime determines the duration of time between the addition of tokens to the bucket.
|
||||
RefillTime time.Duration
|
||||
// RefillAmount is the number of tokens that are added to the bucket during interval
|
||||
// specified by RefillTime.
|
||||
RefillAmount int
|
||||
// Burst defines number of available tokens. It's initially full and refilled
|
||||
// based on RefillAmount and RefillTime.
|
||||
Burst int
|
||||
// AuditLogger is auditLogger that will be wrapped with limiter on search endpoints.
|
||||
AuditLogger AuditLogger
|
||||
}
|
||||
|
||||
func (cfg *SearchEventsLimiterConfig) CheckAndSetDefaults() error {
|
||||
if cfg.AuditLogger == nil {
|
||||
return trace.BadParameter("empty auditLogger")
|
||||
}
|
||||
if cfg.Burst <= 0 {
|
||||
return trace.BadParameter("Burst cannot be less or equal to 0")
|
||||
}
|
||||
if cfg.RefillAmount <= 0 {
|
||||
return trace.BadParameter("RefillAmount cannot be less or equal to 0")
|
||||
}
|
||||
if cfg.RefillTime == 0 {
|
||||
// Default to seconds so it can be just used as rate.
|
||||
cfg.RefillTime = time.Second
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewSearchEventLimiter returns instance of new SearchEventsLimiter.
|
||||
func NewSearchEventLimiter(cfg SearchEventsLimiterConfig) (*SearchEventsLimiter, error) {
|
||||
if err := cfg.CheckAndSetDefaults(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return &SearchEventsLimiter{
|
||||
limiter: rate.NewLimiter(rate.Every(cfg.RefillTime/time.Duration(cfg.RefillAmount)), cfg.Burst),
|
||||
AuditLogger: cfg.AuditLogger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SearchEventsLimiter) SearchEvents(fromUTC, toUTC time.Time, namespace string,
|
||||
eventTypes []string, limit int, order types.EventOrder, startKey string,
|
||||
) ([]apievents.AuditEvent, string, error) {
|
||||
if !s.limiter.Allow() {
|
||||
return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
|
||||
}
|
||||
out, keyset, err := s.AuditLogger.SearchEvents(fromUTC, toUTC, namespace, eventTypes, limit, order, startKey)
|
||||
return out, keyset, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (s *SearchEventsLimiter) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
|
||||
order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string,
|
||||
) ([]apievents.AuditEvent, string, error) {
|
||||
if !s.limiter.Allow() {
|
||||
return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
|
||||
}
|
||||
out, keyset, err := s.AuditLogger.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID)
|
||||
return out, keyset, trace.Wrap(err)
|
||||
}
|
168
lib/events/search_limiter_test.go
Normal file
168
lib/events/search_limiter_test.go
Normal file
|
@ -0,0 +1,168 @@
|
|||
// Copyright 2023 Gravitational, Inc
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
apievents "github.com/gravitational/teleport/api/types/events"
|
||||
"github.com/gravitational/teleport/lib/events"
|
||||
)
|
||||
|
||||
func TestSearchEventsLimiter(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("emitting events happen without any limiting", func(t *testing.T) {
|
||||
s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
|
||||
RefillAmount: 1,
|
||||
Burst: 1,
|
||||
AuditLogger: &mockAuditLogger{
|
||||
emitAuditEventRespFn: func() error { return nil },
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
for i := 0; i < 20; i++ {
|
||||
require.NoError(t, s.EmitAuditEvent(context.Background(), &apievents.AccessRequestCreate{}))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with limiter", func(t *testing.T) {
|
||||
burst := 20
|
||||
s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
|
||||
RefillTime: 20 * time.Millisecond,
|
||||
RefillAmount: 1,
|
||||
Burst: burst,
|
||||
AuditLogger: &mockAuditLogger{
|
||||
searchEventsRespFn: func() ([]apievents.AuditEvent, string, error) { return nil, "", nil },
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
someDate := clockwork.NewFakeClock().Now().UTC()
|
||||
// searchEvents and searchSessionEvents are helper fn to avoid coping those methods with huge
|
||||
// number of attributes multiple times in that test case.
|
||||
searchEvents := func() ([]apievents.AuditEvent, string, error) {
|
||||
return s.SearchEvents(someDate, someDate, "default", nil /* eventTypes */, 100 /* limit */, types.EventOrderAscending, "" /* startKey */)
|
||||
}
|
||||
searchSessionEvents := func() ([]apievents.AuditEvent, string, error) {
|
||||
return s.SearchSessionEvents(someDate, someDate, 100 /* limit */, types.EventOrderAscending, "" /* startKey */, nil /* cond */, "" /* sessionID */)
|
||||
}
|
||||
|
||||
for i := 0; i < burst; i++ {
|
||||
var err error
|
||||
// rate limit is shared between both search endpoints.
|
||||
if i%2 == 0 {
|
||||
_, _, err = searchEvents()
|
||||
} else {
|
||||
_, _, err = searchSessionEvents()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
// Now all tokens from rate limit should be used
|
||||
_, _, err = searchEvents()
|
||||
require.True(t, trace.IsLimitExceeded(err))
|
||||
// Also on SearchSessionEvents
|
||||
_, _, err = searchSessionEvents()
|
||||
require.True(t, trace.IsLimitExceeded(err))
|
||||
|
||||
// After 20ms 1 token should be added according to rate.
|
||||
require.Eventually(t, func() bool {
|
||||
_, _, err := searchEvents()
|
||||
return err == nil
|
||||
}, 40*time.Millisecond, 5*time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSearchEventsLimiterConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg events.SearchEventsLimiterConfig
|
||||
wantFn func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig)
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
cfg: events.SearchEventsLimiterConfig{
|
||||
AuditLogger: &mockAuditLogger{},
|
||||
RefillAmount: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, time.Second, cfg.RefillTime)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty rate in config",
|
||||
cfg: events.SearchEventsLimiterConfig{
|
||||
AuditLogger: &mockAuditLogger{},
|
||||
Burst: 1,
|
||||
},
|
||||
wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) {
|
||||
require.ErrorContains(t, err, "RefillAmount cannot be less or equal to 0")
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "empty burst in config",
|
||||
cfg: events.SearchEventsLimiterConfig{
|
||||
AuditLogger: &mockAuditLogger{},
|
||||
RefillAmount: 1,
|
||||
},
|
||||
wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) {
|
||||
require.ErrorContains(t, err, "Burst cannot be less or equal to 0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty logger",
|
||||
cfg: events.SearchEventsLimiterConfig{
|
||||
RefillAmount: 1,
|
||||
Burst: 1,
|
||||
},
|
||||
wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) {
|
||||
require.ErrorContains(t, err, "empty auditLogger")
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.cfg.CheckAndSetDefaults()
|
||||
tt.wantFn(t, err, tt.cfg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockAuditLogger struct {
|
||||
searchEventsRespFn func() ([]apievents.AuditEvent, string, error)
|
||||
emitAuditEventRespFn func() error
|
||||
events.AuditLogger
|
||||
}
|
||||
|
||||
func (m *mockAuditLogger) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) {
|
||||
return m.searchEventsRespFn()
|
||||
}
|
||||
|
||||
func (m *mockAuditLogger) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) {
|
||||
return m.searchEventsRespFn()
|
||||
}
|
||||
|
||||
func (m *mockAuditLogger) EmitAuditEvent(context.Context, apievents.AuditEvent) error {
|
||||
return m.emitAuditEventRespFn()
|
||||
}
|
|
@ -1414,10 +1414,23 @@ func initAuthExternalAuditLog(ctx context.Context, auditConfig types.ClusterAudi
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
logger, err := athena.New(ctx, cfg)
|
||||
var logger events.AuditLogger
|
||||
logger, err = athena.New(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if cfg.LimiterBurst > 0 {
|
||||
// Wrap athena logger with rate limiter on search events.
|
||||
logger, err = events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{
|
||||
RefillTime: cfg.LimiterRefillTime,
|
||||
RefillAmount: cfg.LimiterRefillAmount,
|
||||
Burst: cfg.LimiterBurst,
|
||||
AuditLogger: logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
loggers = append(loggers, logger)
|
||||
case teleport.SchemeFile:
|
||||
if uri.Path == "" {
|
||||
|
|
|
@ -52,6 +52,8 @@ import (
|
|||
"github.com/gravitational/teleport/lib/backend/memory"
|
||||
"github.com/gravitational/teleport/lib/cloud"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/events"
|
||||
"github.com/gravitational/teleport/lib/events/athena"
|
||||
"github.com/gravitational/teleport/lib/limiter"
|
||||
"github.com/gravitational/teleport/lib/modules"
|
||||
"github.com/gravitational/teleport/lib/reversetunnel"
|
||||
|
@ -297,6 +299,47 @@ func TestServiceInitExternalLog(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestAthenaAuditLogSetup(t *testing.T) {
|
||||
sampleValidConfig := "athena://db.table?topicArn=arn:aws:sns:eu-central-1:accnr:topicName&queryResultsS3=s3://testbucket/query-result/&workgroup=workgroup&locationS3=s3://testbucket/events-location&queueURL=https://sqs.eu-central-1.amazonaws.com/accnr/sqsname&largeEventsS3=s3://testbucket/largeevents"
|
||||
tests := []struct {
|
||||
name string
|
||||
uri string
|
||||
wantFn func(*testing.T, events.AuditLogger, error)
|
||||
}{
|
||||
{
|
||||
name: "valid athena config",
|
||||
uri: sampleValidConfig,
|
||||
wantFn: func(t *testing.T, alog events.AuditLogger, err error) {
|
||||
require.NoError(t, err)
|
||||
v, ok := alog.(*athena.Log)
|
||||
require.True(t, ok, "invalid logger type, got %T", v)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "config with rate limit - should use events.SearchEventsLimiter",
|
||||
uri: sampleValidConfig + "&limiterRefillAmount=3&limiterBurst=2",
|
||||
wantFn: func(t *testing.T, alog events.AuditLogger, err error) {
|
||||
require.NoError(t, err)
|
||||
_, ok := alog.(*events.SearchEventsLimiter)
|
||||
require.True(t, ok, "invalid logger type, got %T", alog)
|
||||
},
|
||||
},
|
||||
}
|
||||
backend, err := memory.New(memory.Config{})
|
||||
require.NoError(t, err)
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{
|
||||
AuditEventsURI: []string{tt.uri},
|
||||
AuditSessionsURI: "s3://testbucket/sessions-rec",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
log, err := initAuthExternalAuditLog(context.Background(), auditConfig, backend)
|
||||
tt.wantFn(t, log, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAdditionalPrincipals(t *testing.T) {
|
||||
p := &TeleportProcess{
|
||||
Config: &servicecfg.Config{
|
||||
|
|
Loading…
Reference in a new issue