mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 09:44:51 +00:00
3d5557d947
Introduces `limiter.Listener` to provide a consistent and reusable mechanism for limiting incoming connections per client. The new listener is used by `sshutils/server.go` instead of manually applying limits in `HandleConnection`. This is particularly important now that the Proxy SSH port multiplexes both SSH and gRPC. Each listener is now wrapped in a `limiter.Listener` that uses the same `limiter.ConnectionsListener` to ensure that the connection limits for the Proxy are enforced for all traffic on the port.
478 lines
11 KiB
Go
478 lines
11 KiB
Go
/*
|
|
Copyright 2015 Gravitational, Inc.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
package limiter
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gravitational/oxy/ratelimit"
|
|
"github.com/gravitational/trace"
|
|
"github.com/mailgun/timetools"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/peer"
|
|
|
|
"github.com/gravitational/teleport/lib/utils"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
utils.InitLoggerForTests()
|
|
os.Exit(m.Run())
|
|
}
|
|
|
|
func TestConnectionsLimiter(t *testing.T) {
|
|
limiter, err := NewLimiter(
|
|
Config{
|
|
MaxConnections: 0,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, limiter.AcquireConnection("token1"))
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
require.NoError(t, limiter.AcquireConnection("token2"))
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
limiter.ReleaseConnection("token1")
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
limiter.ReleaseConnection("token2")
|
|
}
|
|
|
|
limiter, err = NewLimiter(
|
|
Config{
|
|
MaxConnections: 5,
|
|
},
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
for i := 0; i < 5; i++ {
|
|
require.NoError(t, limiter.AcquireConnection("token1"))
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
require.NoError(t, limiter.AcquireConnection("token2"))
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
require.Error(t, limiter.AcquireConnection("token2"))
|
|
}
|
|
|
|
for i := 0; i < 10; i++ {
|
|
limiter.ReleaseConnection("token1")
|
|
require.NoError(t, limiter.AcquireConnection("token1"))
|
|
}
|
|
|
|
for i := 0; i < 5; i++ {
|
|
limiter.ReleaseConnection("token2")
|
|
}
|
|
for i := 0; i < 5; i++ {
|
|
require.NoError(t, limiter.AcquireConnection("token2"))
|
|
}
|
|
}
|
|
|
|
func TestRateLimiter(t *testing.T) {
|
|
// TODO: this test fails
|
|
clock := &timetools.FreezedTime{
|
|
CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
|
|
}
|
|
|
|
limiter, err := NewLimiter(
|
|
Config{
|
|
Clock: clock,
|
|
Rates: []Rate{
|
|
{
|
|
Period: 10 * time.Millisecond,
|
|
Average: 10,
|
|
Burst: 20,
|
|
},
|
|
{
|
|
Period: 40 * time.Millisecond,
|
|
Average: 10,
|
|
Burst: 40,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
for i := 0; i < 20; i++ {
|
|
require.NoError(t, limiter.RegisterRequest("token1"))
|
|
}
|
|
for i := 0; i < 20; i++ {
|
|
require.NoError(t, limiter.RegisterRequest("token2"))
|
|
}
|
|
|
|
require.Error(t, limiter.RegisterRequest("token1"))
|
|
|
|
clock.Sleep(10 * time.Millisecond)
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, limiter.RegisterRequest("token1"))
|
|
}
|
|
require.Error(t, limiter.RegisterRequest("token1"))
|
|
|
|
clock.Sleep(10 * time.Millisecond)
|
|
for i := 0; i < 10; i++ {
|
|
require.NoError(t, limiter.RegisterRequest("token1"))
|
|
}
|
|
require.Error(t, limiter.RegisterRequest("token1"))
|
|
|
|
clock.Sleep(10 * time.Millisecond)
|
|
// the second rate is full
|
|
err = nil
|
|
for i := 0; i < 10; i++ {
|
|
err = limiter.RegisterRequest("token1")
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
require.Error(t, err)
|
|
|
|
clock.Sleep(10 * time.Millisecond)
|
|
// Now the second rate has free space
|
|
require.NoError(t, limiter.RegisterRequest("token1"))
|
|
err = nil
|
|
for i := 0; i < 15; i++ {
|
|
err = limiter.RegisterRequest("token1")
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestCustomRate(t *testing.T) {
|
|
clock := &timetools.FreezedTime{
|
|
CurrentTime: time.Date(2016, 6, 5, 4, 3, 2, 1, time.UTC),
|
|
}
|
|
|
|
limiter, err := NewLimiter(
|
|
Config{
|
|
Clock: clock,
|
|
Rates: []Rate{
|
|
// Default rate
|
|
{
|
|
Period: 10 * time.Millisecond,
|
|
Average: 10,
|
|
Burst: 20,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
customRate := ratelimit.NewRateSet()
|
|
err = customRate.Add(time.Minute, 1, 5)
|
|
require.NoError(t, err)
|
|
|
|
// Max out custom rate.
|
|
for i := 0; i < 5; i++ {
|
|
require.NoError(t, limiter.RegisterRequestWithCustomRate("token1", customRate))
|
|
}
|
|
|
|
// Test rate limit exceeded with custom rate.
|
|
require.Error(t, limiter.RegisterRequestWithCustomRate("token1", customRate))
|
|
|
|
// Test default rate still works.
|
|
for i := 0; i < 20; i++ {
|
|
require.NoError(t, limiter.RegisterRequest("token1"))
|
|
}
|
|
}
|
|
|
|
// mockAddr is a fixed net.Addr stand-in that always reports the TCP
// endpoint 127.0.0.1:1234.
type mockAddr struct{}

// Network reports the address family ("tcp").
func (mockAddr) Network() string { return "tcp" }

// String reports the fixed host:port pair.
func (mockAddr) String() string { return "127.0.0.1:1234" }
|
|
|
|
func TestLimiter_UnaryServerInterceptor(t *testing.T) {
|
|
limiter, err := NewLimiter(Config{
|
|
MaxConnections: 1,
|
|
Rates: []Rate{
|
|
{
|
|
Period: time.Minute,
|
|
Average: 1,
|
|
Burst: 1,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}})
|
|
req := "request"
|
|
serverInfo := &grpc.UnaryServerInfo{
|
|
FullMethod: "/method",
|
|
}
|
|
handler := func(context.Context, interface{}) (interface{}, error) { return nil, nil }
|
|
|
|
unaryInterceptor := limiter.UnaryServerInterceptor()
|
|
|
|
// pass at least once
|
|
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
|
|
require.NoError(t, err)
|
|
|
|
// should eventually fail, not testing the limiter behavior here
|
|
for i := 0; i < 10; i++ {
|
|
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
require.Error(t, err)
|
|
|
|
getCustomRate := func(endpoint string) *ratelimit.RateSet {
|
|
rates := ratelimit.NewRateSet()
|
|
err := rates.Add(2*time.Minute, 1, 2)
|
|
require.NoError(t, err)
|
|
return rates
|
|
}
|
|
|
|
unaryInterceptor = limiter.UnaryServerInterceptorWithCustomRate(getCustomRate)
|
|
|
|
// should pass at least once
|
|
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
|
|
require.NoError(t, err)
|
|
|
|
// should eventually fail, not testing the limiter behavior here
|
|
for i := 0; i < 10; i++ {
|
|
_, err = unaryInterceptor(ctx, req, serverInfo, handler)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
require.Error(t, err)
|
|
}
|
|
|
|
type mockServerStream struct {
|
|
grpc.ServerStream
|
|
ctx context.Context
|
|
}
|
|
|
|
func (s mockServerStream) Context() context.Context {
|
|
return s.ctx
|
|
}
|
|
|
|
func TestLimiter_StreamServerInterceptor(t *testing.T) {
|
|
limiter, err := NewLimiter(Config{
|
|
MaxConnections: 1,
|
|
Rates: []Rate{
|
|
{
|
|
Period: time.Minute,
|
|
Average: 1,
|
|
Burst: 1,
|
|
},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
ctx := peer.NewContext(context.Background(), &peer.Peer{Addr: mockAddr{}})
|
|
ss := mockServerStream{
|
|
ctx: ctx,
|
|
}
|
|
info := &grpc.StreamServerInfo{}
|
|
handler := func(srv interface{}, stream grpc.ServerStream) error { return nil }
|
|
|
|
// pass at least once
|
|
err = limiter.StreamServerInterceptor(nil, ss, info, handler)
|
|
require.NoError(t, err)
|
|
|
|
// should eventually fail, not testing the limiter behavior here
|
|
for i := 0; i < 10; i++ {
|
|
err = limiter.StreamServerInterceptor(nil, ss, info, handler)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// TestListener verifies that a [Listener] only accepts
|
|
// connections if the connection limit has not been exceeded.
|
|
func TestListener(t *testing.T) {
|
|
const connLimit = 5
|
|
failedAcceptErr := errors.New("failed accept")
|
|
tooManyConnectionsErr := trace.LimitExceeded("too many connections from 127.0.0.1: 2, max is 2")
|
|
|
|
tests := []struct {
|
|
name string
|
|
config Config
|
|
listener *fakeListener
|
|
acceptAssertion func(t *testing.T, iteration int, conn net.Conn, err error)
|
|
numConnAssertion func(t *testing.T, num int64)
|
|
}{
|
|
{
|
|
name: "all connections allowed",
|
|
config: Config{MaxConnections: 0},
|
|
listener: &fakeListener{
|
|
acceptConn: &fakeConn{
|
|
addr: mockAddr{},
|
|
},
|
|
},
|
|
acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, conn)
|
|
},
|
|
numConnAssertion: func(t *testing.T, num int64) {
|
|
// MaxConnections == 0 prevents any connections from being accumulated
|
|
require.Zero(t, num)
|
|
},
|
|
},
|
|
{
|
|
name: "accept failure",
|
|
config: Config{MaxConnections: 0},
|
|
listener: &fakeListener{
|
|
acceptError: failedAcceptErr,
|
|
},
|
|
acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
|
|
require.ErrorIs(t, err, failedAcceptErr)
|
|
require.Nil(t, conn)
|
|
},
|
|
numConnAssertion: func(t *testing.T, num int64) {
|
|
require.Zero(t, num)
|
|
},
|
|
},
|
|
{
|
|
name: "invalid remote address",
|
|
config: Config{MaxConnections: 0},
|
|
listener: &fakeListener{
|
|
acceptConn: &fakeConn{
|
|
addr: &utils.NetAddr{
|
|
Addr: "abcd",
|
|
AddrNetwork: "tcp",
|
|
},
|
|
},
|
|
},
|
|
acceptAssertion: func(t *testing.T, _ int, conn net.Conn, err error) {
|
|
require.Error(t, err)
|
|
require.Nil(t, conn)
|
|
},
|
|
numConnAssertion: func(t *testing.T, num int64) {
|
|
require.Zero(t, num)
|
|
},
|
|
},
|
|
{
|
|
name: "max connections exceeded",
|
|
config: Config{MaxConnections: 2},
|
|
listener: &fakeListener{
|
|
acceptConn: &fakeConn{
|
|
addr: mockAddr{},
|
|
},
|
|
},
|
|
acceptAssertion: func(t *testing.T, i int, conn net.Conn, err error) {
|
|
if i < 2 {
|
|
require.NoError(t, err)
|
|
require.NotNil(t, conn)
|
|
return
|
|
}
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, tooManyConnectionsErr)
|
|
require.True(t, trace.IsLimitExceeded(err))
|
|
require.Nil(t, conn)
|
|
},
|
|
numConnAssertion: func(t *testing.T, num int64) {
|
|
require.Equal(t, int64(2), num)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
limiter, err := NewConnectionsLimiter(test.config)
|
|
require.NoError(t, err)
|
|
|
|
ln := NewListener(test.listener, limiter)
|
|
|
|
// open connections without closing to enforce limits
|
|
conns := make([]net.Conn, 0, connLimit)
|
|
for i := 0; i < connLimit; i++ {
|
|
conn, err := ln.Accept()
|
|
test.acceptAssertion(t, i, conn, err)
|
|
|
|
if conn != nil {
|
|
conns = append(conns, conn)
|
|
}
|
|
}
|
|
|
|
// validate limits were enforced
|
|
n, err := limiter.GetNumConnection("127.0.0.1")
|
|
require.NoError(t, err)
|
|
test.numConnAssertion(t, n)
|
|
|
|
// close connections to reset limits
|
|
for _, conn := range conns {
|
|
require.NoError(t, conn.Close())
|
|
}
|
|
|
|
// ensure closing connections resets count
|
|
n, err = limiter.GetNumConnection("127.0.0.1")
|
|
if test.config.MaxConnections == 0 {
|
|
require.NoError(t, err)
|
|
require.Zero(t, n)
|
|
} else {
|
|
require.True(t, trace.IsBadParameter(err))
|
|
require.Equal(t, int64(-1), n)
|
|
}
|
|
|
|
// open connections again after closing to
|
|
// ensure that closing reset limits
|
|
for i := 0; i < 5; i++ {
|
|
conn, err := ln.Accept()
|
|
test.acceptAssertion(t, i, conn, err)
|
|
|
|
if conn != nil {
|
|
t.Cleanup(func() {
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// fakeListener is a net.Listener stand-in whose Accept returns the same
// preconfigured connection and error on every call. Methods other than
// Accept come from the embedded (nil) interface and would panic if called.
type fakeListener struct {
	net.Listener

	acceptConn  net.Conn
	acceptError error
}

// Accept hands back the configured connection/error pair unchanged.
func (l *fakeListener) Accept() (net.Conn, error) {
	conn, err := l.acceptConn, l.acceptError
	return conn, err
}
|
|
|
|
// fakeConn is a net.Conn stand-in that reports a configurable remote address
// and whose Close always succeeds. Other net.Conn methods come from the
// embedded (nil) interface and would panic if called.
type fakeConn struct {
	net.Conn

	addr net.Addr
}

// RemoteAddr returns the configured remote address (nil if unset).
func (c *fakeConn) RemoteAddr() net.Addr { return c.addr }

// Close is a no-op that never fails.
func (c *fakeConn) Close() error { return nil }
|