teleport/lib/limiter/limiter_test.go

478 lines
11 KiB
Go
Raw Normal View History

2015-12-02 18:51:32 +00:00
/*
Copyright 2015 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
2015-12-02 18:51:32 +00:00
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
2015-12-03 09:26:34 +00:00
package limiter
2015-12-02 18:51:32 +00:00
import (
"context"
"errors"
"net"
"os"
2015-12-02 18:51:32 +00:00
"testing"
"time"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/trace"
2016-02-16 03:48:09 +00:00
"github.com/mailgun/timetools"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
2022-10-28 20:20:28 +00:00
"github.com/gravitational/teleport/lib/utils"
2015-12-02 18:51:32 +00:00
)
// TestMain sets up the shared test logger, runs every test in the package,
// and exits the process with the status code reported by m.Run.
func TestMain(m *testing.M) {
	utils.InitLoggerForTests()
	exitCode := m.Run()
	os.Exit(exitCode)
}
func TestConnectionsLimiter(t *testing.T) {
2015-12-03 09:26:34 +00:00
limiter, err := NewLimiter(
Config{
2015-12-02 18:51:32 +00:00
MaxConnections: 0,
},
)
require.NoError(t, err)
2015-12-02 18:51:32 +00:00
for i := 0; i < 10; i++ {
require.NoError(t, limiter.AcquireConnection("token1"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 10; i++ {
limiter.ReleaseConnection("token1")
}
for i := 0; i < 5; i++ {
limiter.ReleaseConnection("token2")
}
2015-12-03 09:26:34 +00:00
limiter, err = NewLimiter(
Config{
2015-12-02 18:51:32 +00:00
MaxConnections: 5,
},
)
require.NoError(t, err)
2015-12-02 18:51:32 +00:00
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token1"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 5; i++ {
require.Error(t, limiter.AcquireConnection("token2"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 10; i++ {
limiter.ReleaseConnection("token1")
require.NoError(t, limiter.AcquireConnection("token1"))
2015-12-02 18:51:32 +00:00
}
for i := 0; i < 5; i++ {
limiter.ReleaseConnection("token2")
}
for i := 0; i < 5; i++ {
require.NoError(t, limiter.AcquireConnection("token2"))
2015-12-02 18:51:32 +00:00
}
}
func TestRateLimiter(t *testing.T) {
// TODO: this test fails
2016-02-16 03:48:09 +00:00
clock := &timetools.FreezedTime{
CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
}
2015-12-03 09:26:34 +00:00
limiter, err := NewLimiter(
Config{
2016-02-16 03:48:09 +00:00
Clock: clock,
2015-12-02 18:51:32 +00:00
Rates: []Rate{
{
2015-12-02 18:51:32 +00:00
Period: 10 * time.Millisecond,
Average: 10,
Burst: 20,
},
{
2015-12-02 18:51:32 +00:00
Period: 40 * time.Millisecond,
Average: 10,
Burst: 40,
},
},
2016-02-16 03:48:09 +00:00
})
require.NoError(t, err)
2015-12-02 18:51:32 +00:00
2015-12-03 09:26:34 +00:00
for i := 0; i < 20; i++ {
require.NoError(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
}
2015-12-03 09:26:34 +00:00
for i := 0; i < 20; i++ {
require.NoError(t, limiter.RegisterRequest("token2"))
2015-12-03 09:26:34 +00:00
}
require.Error(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
2016-02-16 03:48:09 +00:00
clock.Sleep(10 * time.Millisecond)
2015-12-02 18:51:32 +00:00
for i := 0; i < 10; i++ {
require.NoError(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
}
require.Error(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
2016-02-16 03:48:09 +00:00
clock.Sleep(10 * time.Millisecond)
2015-12-02 18:51:32 +00:00
for i := 0; i < 10; i++ {
require.NoError(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
}
require.Error(t, limiter.RegisterRequest("token1"))
2015-12-02 18:51:32 +00:00
2016-02-16 03:48:09 +00:00
clock.Sleep(10 * time.Millisecond)
2015-12-02 18:51:32 +00:00
// the second rate is full
2015-12-03 09:26:34 +00:00
err = nil
for i := 0; i < 10; i++ {
err = limiter.RegisterRequest("token1")
if err != nil {
break
}
}
require.Error(t, err)
2015-12-02 18:51:32 +00:00
2016-02-16 03:48:09 +00:00
clock.Sleep(10 * time.Millisecond)
// Now the second rate has free space
require.NoError(t, limiter.RegisterRequest("token1"))
2015-12-03 09:26:34 +00:00
err = nil
2015-12-07 20:05:54 +00:00
for i := 0; i < 15; i++ {
2015-12-03 09:26:34 +00:00
err = limiter.RegisterRequest("token1")
if err != nil {
break
}
}
require.Error(t, err)
2015-12-02 18:51:32 +00:00
}
// TestCustomRate checks that RegisterRequestWithCustomRate enforces the
// supplied rate set independently of the limiter's configured rates. After
// the custom set starts rejecting a token, plain RegisterRequest calls for
// that same token are still admitted.
func TestCustomRate(t *testing.T) {
	frozen := &timetools.FreezedTime{
		CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
	}

	limiter, err := NewLimiter(Config{
		Clock: frozen,
		Rates: []Rate{
			// Default rate.
			{Period: 10 * time.Millisecond, Average: 10, Burst: 20},
		},
	})
	require.NoError(t, err)

	customRate := ratelimit.NewRateSet()
	require.NoError(t, customRate.Add(time.Minute, 1, 5))

	// Max out the custom rate: five requests pass and the sixth is rejected.
	for n := 0; n < 5; n++ {
		require.NoError(t, limiter.RegisterRequestWithCustomRate("token1", customRate))
	}
	require.Error(t, limiter.RegisterRequestWithCustomRate("token1", customRate))

	// Exhausting the custom rate must not block the default rate.
	for n := 0; n < 20; n++ {
		require.NoError(t, limiter.RegisterRequest("token1"))
	}
}
// mockAddr is a fixed address stub. It always reports the TCP endpoint
// 127.0.0.1:1234 and is used as the peer/remote address in these tests.
type mockAddr struct{}

// Network always reports "tcp".
func (mockAddr) Network() string { return "tcp" }

// String always reports "127.0.0.1:1234".
func (mockAddr) String() string { return "127.0.0.1:1234" }
// TestLimiter_UnaryServerInterceptor checks the unary gRPC interceptor in two
// configurations: the limiter's configured rates, and a custom per-endpoint
// rate. In each, the first request must pass and a later one must be
// rejected within a bounded number of attempts.
func TestLimiter_UnaryServerInterceptor(t *testing.T) {
	limiter, err := NewLimiter(Config{
		MaxConnections: 1,
		Rates: []Rate{
			{
				Period:  time.Minute,
				Average: 1,
				Burst:   1,
			},
		},
	})
	require.NoError(t, err)

	ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}})
	req := "request"
	info := &grpc.UnaryServerInfo{FullMethod: "/method"}
	noop := func(context.Context, interface{}) (interface{}, error) { return nil, nil }

	// expectPassThenReject requires one successful call, then keeps calling
	// (at most 10 more times) until the interceptor rejects a request. The
	// exact limiter behavior is not under test here.
	expectPassThenReject := func(interceptor grpc.UnaryServerInterceptor) {
		_, first := interceptor(ctx, req, info, noop)
		require.NoError(t, first)

		var rejected error
		for attempt := 0; attempt < 10 && rejected == nil; attempt++ {
			_, rejected = interceptor(ctx, req, info, noop)
		}
		require.Error(t, rejected)
	}

	expectPassThenReject(limiter.UnaryServerInterceptor())

	customRate := func(_ string) *ratelimit.RateSet {
		rates := ratelimit.NewRateSet()
		require.NoError(t, rates.Add(2*time.Minute, 1, 2))
		return rates
	}
	expectPassThenReject(limiter.UnaryServerInterceptorWithCustomRate(customRate))
}
// mockServerStream is a server-stream stub. Context returns the stored ctx.
// All other stream methods come from the embedded grpc.ServerStream, which
// is nil unless the caller sets it.
type mockServerStream struct {
	grpc.ServerStream

	ctx context.Context
}

// Context returns the context the stub was built with.
func (m mockServerStream) Context() context.Context { return m.ctx }
// TestLimiter_StreamServerInterceptor checks that the streaming gRPC
// interceptor admits the first stream and rejects a later one within a
// bounded number of attempts. The exact limiter behavior is not under test.
func TestLimiter_StreamServerInterceptor(t *testing.T) {
	limiter, err := NewLimiter(Config{
		MaxConnections: 1,
		Rates: []Rate{
			{
				Period:  time.Minute,
				Average: 1,
				Burst:   1,
			},
		},
	})
	require.NoError(t, err)

	stream := mockServerStream{
		ctx: peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}}),
	}
	info := &grpc.StreamServerInfo{}
	noop := func(srv interface{}, stream grpc.ServerStream) error { return nil }

	// The first stream must get through.
	require.NoError(t, limiter.StreamServerInterceptor(nil, stream, info, noop))

	// Keep opening streams (at most 10 more) until one is rejected.
	var rejected error
	for attempt := 0; rejected == nil && attempt < 10; attempt++ {
		rejected = limiter.StreamServerInterceptor(nil, stream, info, noop)
	}
	require.Error(t, rejected)
}
// TestListener verifies that a [Listener] only accepts connections while the
// connection limit has not been exceeded. It also checks that closing
// accepted connections frees their slots.
func TestListener(t *testing.T) {
	// connLimit is the number of Accept calls made in each phase of a case.
	const connLimit = 5
	failedAcceptErr := errors.New("failed accept")
	tooManyConnectionsErr := trace.LimitExceeded("too many connections from 127.0.0.1: 2, max is 2")

	tests := []struct {
		name             string
		config           Config
		listener         *fakeListener
		acceptAssertion  func(t *testing.T, iteration int, conn net.Conn, err error)
		numConnAssertion func(t *testing.T, num int64)
	}{
		{
			name:   "all connections allowed",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: mockAddr{},
				},
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.NoError(t, err)
				require.NotNil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				// MaxConnections == 0 prevents any connections from being accumulated
				require.Zero(t, num)
			},
		},
		{
			name:   "accept failure",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptError: failedAcceptErr,
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.ErrorIs(t, err, failedAcceptErr)
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Zero(t, num)
			},
		},
		{
			name:   "invalid remote address",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: &utils.NetAddr{
						Addr:        "abcd",
						AddrNetwork: "tcp",
					},
				},
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.Error(t, err)
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Zero(t, num)
			},
		},
		{
			name:   "max connections exceeded",
			config: Config{MaxConnections: 2},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: mockAddr{},
				},
			},
			acceptAssertion: func(t *testing.T, i int, conn net.Conn, err error) {
				if i < 2 {
					require.NoError(t, err)
					require.NotNil(t, conn)
					return
				}

				require.Error(t, err)
				require.ErrorIs(t, err, tooManyConnectionsErr)
				require.True(t, trace.IsLimitExceeded(err))
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Equal(t, int64(2), num)
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			limiter, err := NewConnectionsLimiter(test.config)
			require.NoError(t, err)

			ln := NewListener(test.listener, limiter)

			// open connections without closing to enforce limits
			conns := make([]net.Conn, 0, connLimit)
			for i := 0; i < connLimit; i++ {
				conn, err := ln.Accept()
				test.acceptAssertion(t, i, conn, err)
				if conn != nil {
					conns = append(conns, conn)
				}
			}

			// validate limits were enforced
			n, err := limiter.GetNumConnection("127.0.0.1")
			require.NoError(t, err)
			test.numConnAssertion(t, n)

			// close connections to reset limits
			for _, conn := range conns {
				require.NoError(t, conn.Close())
			}

			// ensure closing connections resets count
			n, err = limiter.GetNumConnection("127.0.0.1")
			if test.config.MaxConnections == 0 {
				require.NoError(t, err)
				require.Zero(t, n)
			} else {
				// With a limit configured, looking up an IP that has no open
				// connections is expected to fail with BadParameter and -1.
				require.True(t, trace.IsBadParameter(err))
				require.Equal(t, int64(-1), n)
			}

			// open connections again after closing to
			// ensure that closing reset limits
			for i := 0; i < connLimit; i++ {
				conn, err := ln.Accept()
				test.acceptAssertion(t, i, conn, err)
				if conn != nil {
					// Close each accepted connection when the subtest ends so
					// its slot in the limiter is released.
					t.Cleanup(func() {
						require.NoError(t, conn.Close())
					})
				}
			}
		})
	}
}
// fakeListener is a listener stub. Accept always returns the preconfigured
// acceptConn and acceptError. All other listener methods come from the
// embedded net.Listener, which is nil unless the caller sets it.
type fakeListener struct {
	net.Listener

	acceptConn  net.Conn
	acceptError error
}

// Accept returns the configured connection and error without blocking.
func (l *fakeListener) Accept() (net.Conn, error) {
	conn, err := l.acceptConn, l.acceptError
	return conn, err
}
// fakeConn is a connection stub. RemoteAddr reports the configured addr, and
// Close is a no-op that always succeeds. All other connection methods come
// from the embedded net.Conn, which is nil unless the caller sets it.
type fakeConn struct {
	net.Conn

	addr net.Addr
}

// RemoteAddr returns the address the stub was configured with.
func (c *fakeConn) RemoteAddr() net.Addr { return c.addr }

// Close does nothing and never fails.
func (c *fakeConn) Close() error { return nil }