teleport/lib/limiter/limiter_test.go
rosstimothy 3d5557d947
Add a connection limiting listener (#20130)
Introduces `limiter.Listener` to provide a consistent and reusable
mechanism for limiting incoming connections per client. The new
listener is used by `sshutils/server.go` instead of manually applying
limits in `HandleConnection`.

This is particularly important now that the Proxy SSH port multiplexes
both SSH and gRPC. Each listener is now wrapped in a `limiter.Listener`
that uses the same `limiter.ConnectionsListener` to ensure that the
connection limits for the Proxy are enforced for all traffic on the
port.
2023-01-19 15:10:11 +00:00

478 lines
11 KiB
Go

/*
Copyright 2015 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package limiter
import (
"context"
"errors"
"net"
"os"
"testing"
"time"
"github.com/gravitational/oxy/ratelimit"
"github.com/gravitational/trace"
"github.com/mailgun/timetools"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/peer"
"github.com/gravitational/teleport/lib/utils"
)
// TestMain calls utils.InitLoggerForTests before running the package's
// tests and exits the process with the status code returned by m.Run.
func TestMain(m *testing.M) {
	utils.InitLoggerForTests()
	exitCode := m.Run()
	os.Exit(exitCode)
}
// TestConnectionsLimiter exercises AcquireConnection/ReleaseConnection on two
// limiters: one built with MaxConnections 0, where every acquisition is
// expected to succeed, and one built with MaxConnections 5, where a sixth
// acquisition for the same token is expected to be rejected until a
// connection for that token is released.
func TestConnectionsLimiter(t *testing.T) {
	unbounded, err := NewLimiter(Config{MaxConnections: 0})
	require.NoError(t, err)

	// With MaxConnections 0, acquisitions for either token must all succeed.
	for n := 0; n < 10; n++ {
		require.NoError(t, unbounded.AcquireConnection("token1"))
	}
	for n := 0; n < 5; n++ {
		require.NoError(t, unbounded.AcquireConnection("token2"))
	}

	// Hand back everything that was acquired above.
	for n := 0; n < 10; n++ {
		unbounded.ReleaseConnection("token1")
	}
	for n := 0; n < 5; n++ {
		unbounded.ReleaseConnection("token2")
	}

	capped, err := NewLimiter(Config{MaxConnections: 5})
	require.NoError(t, err)

	// Each token may hold up to five connections.
	for n := 0; n < 5; n++ {
		require.NoError(t, capped.AcquireConnection("token1"))
	}
	for n := 0; n < 5; n++ {
		require.NoError(t, capped.AcquireConnection("token2"))
	}

	// token2 is at its limit, so further acquisitions must be rejected.
	for n := 0; n < 5; n++ {
		require.Error(t, capped.AcquireConnection("token2"))
	}

	// Releasing one token1 connection frees room for exactly one more.
	for n := 0; n < 10; n++ {
		capped.ReleaseConnection("token1")
		require.NoError(t, capped.AcquireConnection("token1"))
	}

	// After releasing all of token2's connections it can acquire five again.
	for n := 0; n < 5; n++ {
		capped.ReleaseConnection("token2")
	}
	for n := 0; n < 5; n++ {
		require.NoError(t, capped.AcquireConnection("token2"))
	}
}
// TestRateLimiter registers requests against a limiter configured with two
// rates (10ms and 40ms periods) while advancing a frozen clock, and checks
// at each step whether requests for a token are accepted or rejected.
func TestRateLimiter(t *testing.T) {
	// TODO: this test fails
	frozen := &timetools.FreezedTime{
		CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
	}

	shortRate := Rate{Period: 10 * time.Millisecond, Average: 10, Burst: 20}
	longRate := Rate{Period: 40 * time.Millisecond, Average: 10, Burst: 40}

	rl, err := NewLimiter(Config{
		Clock: frozen,
		Rates: []Rate{shortRate, longRate},
	})
	require.NoError(t, err)

	// firstRejection registers up to `attempts` requests for token1 and
	// returns the first error encountered, or nil if all were accepted.
	firstRejection := func(attempts int) error {
		for n := 0; n < attempts; n++ {
			if regErr := rl.RegisterRequest("token1"); regErr != nil {
				return regErr
			}
		}
		return nil
	}

	// Both tokens can register twenty requests up front.
	for n := 0; n < 20; n++ {
		require.NoError(t, rl.RegisterRequest("token1"))
	}
	for n := 0; n < 20; n++ {
		require.NoError(t, rl.RegisterRequest("token2"))
	}
	require.Error(t, rl.RegisterRequest("token1"))
	frozen.Sleep(10 * time.Millisecond)

	// Two rounds: after each 10ms step ten more requests are accepted and
	// the eleventh is rejected.
	for round := 0; round < 2; round++ {
		for n := 0; n < 10; n++ {
			require.NoError(t, rl.RegisterRequest("token1"))
		}
		require.Error(t, rl.RegisterRequest("token1"))
		frozen.Sleep(10 * time.Millisecond)
	}

	// the second rate is full
	require.Error(t, firstRejection(10))
	frozen.Sleep(10 * time.Millisecond)

	// Now the second rate has free space
	require.NoError(t, rl.RegisterRequest("token1"))
	require.Error(t, firstRejection(15))
}
// TestCustomRate checks that exhausting a custom rate set passed to
// RegisterRequestWithCustomRate causes further custom-rate requests for the
// token to be rejected, while RegisterRequest for the same token still
// accepts twenty requests under the limiter's default rate.
func TestCustomRate(t *testing.T) {
	frozen := &timetools.FreezedTime{
		CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
	}

	// Default rate
	defaultRate := Rate{
		Period:  10 * time.Millisecond,
		Average: 10,
		Burst:   20,
	}
	rl, err := NewLimiter(Config{
		Clock: frozen,
		Rates: []Rate{defaultRate},
	})
	require.NoError(t, err)

	perMinute := ratelimit.NewRateSet()
	err = perMinute.Add(time.Minute, 1, 5)
	require.NoError(t, err)

	// Max out custom rate.
	for n := 0; n < 5; n++ {
		require.NoError(t, rl.RegisterRequestWithCustomRate("token1", perMinute))
	}

	// Test rate limit exceeded with custom rate.
	require.Error(t, rl.RegisterRequestWithCustomRate("token1", perMinute))

	// Test default rate still works.
	for n := 0; n < 20; n++ {
		require.NoError(t, rl.RegisterRequest("token1"))
	}
}
// mockAddr is a stateless address stub whose methods return fixed values;
// the tests in this file use it as a peer address and as a connection's
// remote address.
type mockAddr struct{}

// Network always reports "tcp".
func (mockAddr) Network() string { return "tcp" }

// String always reports "127.0.0.1:1234".
func (mockAddr) String() string { return "127.0.0.1:1234" }
func TestLimiter_UnaryServerInterceptor(t *testing.T) {
limiter, err := NewLimiter(Config{
MaxConnections: 1,
Rates: []Rate{
{
Period: time.Minute,
Average: 1,
Burst: 1,
},
},
})
require.NoError(t, err)
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}})
req := "request"
serverInfo := &grpc.UnaryServerInfo{
FullMethod: "/method",
}
handler := func(context.Context, interface{}) (interface{}, error) { return nil, nil }
unaryInterceptor := limiter.UnaryServerInterceptor()
// pass at least once
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
require.NoError(t, err)
// should eventually fail, not testing the limiter behavior here
for i := 0; i < 10; i++ {
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
if err != nil {
break
}
}
require.Error(t, err)
getCustomRate := func(endpoint string) *ratelimit.RateSet {
rates := ratelimit.NewRateSet()
err := rates.Add(2*time.Minute, 1, 2)
require.NoError(t, err)
return rates
}
unaryInterceptor = limiter.UnaryServerInterceptorWithCustomRate(getCustomRate)
// should pass at least once
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
require.NoError(t, err)
// should eventually fail, not testing the limiter behavior here
for i := 0; i < 10; i++ {
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
if err != nil {
break
}
}
require.Error(t, err)
}
// mockServerStream embeds grpc.ServerStream and overrides Context so tests
// can supply the context that is handed to a stream interceptor.
type mockServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the context stored on the mock.
func (m mockServerStream) Context() context.Context { return m.ctx }
func TestLimiter_StreamServerInterceptor(t *testing.T) {
limiter, err := NewLimiter(Config{
MaxConnections: 1,
Rates: []Rate{
{
Period: time.Minute,
Average: 1,
Burst: 1,
},
},
})
require.NoError(t, err)
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}})
ss := mockServerStream{
ctx: ctx,
}
info := &grpc.StreamServerInfo{}
handler := func(srv interface{}, stream grpc.ServerStream) error { return nil }
// pass at least once
err = limiter.StreamServerInterceptor(nil, ss, info, handler)
require.NoError(t, err)
// should eventually fail, not testing the limiter behavior here
for i := 0; i < 10; i++ {
err = limiter.StreamServerInterceptor(nil, ss, info, handler)
if err != nil {
break
}
}
require.Error(t, err)
}
// TestListener verifies that a [Listener] only accepts
// connections if the connection limit has not been exceeded.
//
// Each case wraps a fakeListener with NewListener and then:
//  1. accepts connLimit connections, checking every result with the case's
//     acceptAssertion;
//  2. checks the count returned by GetNumConnection for 127.0.0.1 with the
//     case's numConnAssertion;
//  3. closes every accepted connection and checks GetNumConnection again;
//  4. accepts connLimit connections a second time, expecting the same
//     per-iteration outcomes as in step 1.
func TestListener(t *testing.T) {
	const connLimit = 5

	failedAcceptErr := errors.New("failed accept")
	tooManyConnectionsErr := trace.LimitExceeded("too many connections from 127.0.0.1: 2, max is 2")

	tests := []struct {
		name string
		// config is passed to NewConnectionsLimiter.
		config Config
		// listener is the fake that the Listener under test wraps.
		listener *fakeListener
		// acceptAssertion validates the result of the Accept call made on
		// the given iteration.
		acceptAssertion func(t *testing.T, iteration int, conn net.Conn, err error)
		// numConnAssertion validates the connection count reported for
		// 127.0.0.1 after the first round of accepts.
		numConnAssertion func(t *testing.T, num int64)
	}{
		{
			name:   "all connections allowed",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: mockAddr{},
				},
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.NoError(t, err)
				require.NotNil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				// MaxConnections == 0 prevents any connections from being accumulated
				require.Zero(t, num)
			},
		},
		{
			name:   "accept failure",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptError: failedAcceptErr,
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.ErrorIs(t, err, failedAcceptErr)
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Zero(t, num)
			},
		},
		{
			name:   "invalid remote address",
			config: Config{MaxConnections: 0},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: &utils.NetAddr{
						Addr:        "abcd",
						AddrNetwork: "tcp",
					},
				},
			},
			acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
				require.Error(t, err)
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Zero(t, num)
			},
		},
		{
			name:   "max connections exceeded",
			config: Config{MaxConnections: 2},
			listener: &fakeListener{
				acceptConn: &fakeConn{
					addr: mockAddr{},
				},
			},
			acceptAssertion: func(t *testing.T, i int, conn net.Conn, err error) {
				// The first two accepts fit within MaxConnections.
				if i < 2 {
					require.NoError(t, err)
					require.NotNil(t, conn)
					return
				}
				require.Error(t, err)
				require.ErrorIs(t, err, tooManyConnectionsErr)
				require.True(t, trace.IsLimitExceeded(err))
				require.Nil(t, conn)
			},
			numConnAssertion: func(t *testing.T, num int64) {
				require.Equal(t, int64(2), num)
			},
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			limiter, err := NewConnectionsLimiter(test.config)
			require.NoError(t, err)

			ln := NewListener(test.listener, limiter)

			// open connections without closing to enforce limits
			conns := make([]net.Conn, 0, connLimit)
			for i := 0; i < connLimit; i++ {
				conn, err := ln.Accept()
				test.acceptAssertion(t, i, conn, err)

				if conn != nil {
					conns = append(conns, conn)
				}
			}

			// validate limits were enforced
			n, err := limiter.GetNumConnection("127.0.0.1")
			require.NoError(t, err)
			test.numConnAssertion(t, n)

			// close connections to reset limits
			for _, conn := range conns {
				require.NoError(t, conn.Close())
			}

			// ensure closing connections resets count
			n, err = limiter.GetNumConnection("127.0.0.1")
			if test.config.MaxConnections == 0 {
				require.NoError(t, err)
				require.Zero(t, n)
			} else {
				require.True(t, trace.IsBadParameter(err))
				require.Equal(t, int64(-1), n)
			}

			// open connections again after closing to
			// ensure that closing reset limits
			for i := 0; i < connLimit; i++ {
				conn, err := ln.Accept()
				test.acceptAssertion(t, i, conn, err)

				if conn != nil {
					// Close the re-accepted connection when the subtest
					// finishes so nothing accepted here is left open.
					t.Cleanup(func() {
						require.NoError(t, conn.Close())
					})
				}
			}
		})
	}
}
// fakeListener is a net.Listener stub whose Accept result is fixed by its
// acceptConn and acceptError fields. Methods other than Accept come from
// the embedded net.Listener.
type fakeListener struct {
	net.Listener

	// acceptConn is the connection returned by every Accept call.
	acceptConn net.Conn
	// acceptError is the error returned by every Accept call.
	acceptError error
}

// Accept returns the configured connection and error unchanged.
func (l *fakeListener) Accept() (net.Conn, error) {
	conn, err := l.acceptConn, l.acceptError
	return conn, err
}
// fakeConn is a net.Conn stub that reports a configurable remote address
// and whose Close always succeeds. Methods other than RemoteAddr and Close
// come from the embedded net.Conn.
type fakeConn struct {
	net.Conn

	// addr is the value returned by RemoteAddr.
	addr net.Addr
}

// RemoteAddr returns the configured address.
func (c *fakeConn) RemoteAddr() net.Addr { return c.addr }

// Close does nothing and reports success.
func (*fakeConn) Close() error { return nil }