mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 01:03:40 +00:00
Use first available auth server (#11229)
Currently we use random auth server from the list but if it's unavailable (for example it was restarted but there's still entry in cache, dynamodb backend etc) we return error. This change tries all servers (in random order) and uses first that is available. Closes #10019
This commit is contained in:
parent
22fe05db56
commit
35a9bbc887
|
@ -18,6 +18,7 @@ package alpnproxyauth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
@ -113,13 +114,23 @@ func (s *AuthProxyDialerService) dialLocalAuthServer(ctx context.Context) (net.C
|
|||
if len(authServers) == 0 {
|
||||
return nil, trace.NotFound("empty auth servers list")
|
||||
}
|
||||
//TODO(smallinksy) Better support for HA. Add dial retry on auth network errors.
|
||||
authServerIndex := rand.Intn(len(authServers))
|
||||
conn, err := net.Dial("tcp", authServers[authServerIndex].GetAddr())
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
var errors []error
|
||||
|
||||
// iterate over the addresses in random order
|
||||
for len(authServers) > 0 {
|
||||
l := len(authServers)
|
||||
authServerIndex := rand.Intn(l)
|
||||
addr := authServers[authServerIndex].GetAddr()
|
||||
var d net.Dialer
|
||||
conn, err := d.DialContext(ctx, "tcp", addr)
|
||||
if err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
errors = append(errors, fmt.Errorf("%s: %w", addr, err))
|
||||
authServers[authServerIndex] = authServers[l-1]
|
||||
authServers = authServers[:l-1]
|
||||
}
|
||||
return conn, nil
|
||||
return nil, trace.NewAggregate(errors...)
|
||||
}
|
||||
|
||||
func (s *AuthProxyDialerService) dialRemoteAuthServer(ctx context.Context, clusterName string) (net.Conn, error) {
|
||||
|
|
76
lib/srv/alpnproxy/auth/auth_proxy_test.go
Normal file
76
lib/srv/alpnproxy/auth/auth_proxy_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
Copyright 2021 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package alpnproxyauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockAuthGetter struct {
|
||||
servers []types.Server
|
||||
}
|
||||
|
||||
func (m mockAuthGetter) GetClusterName(...services.MarshalOption) (types.ClusterName, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m mockAuthGetter) GetAuthServers() ([]types.Server, error) {
|
||||
return m.servers, nil
|
||||
}
|
||||
|
||||
func TestDialLocalAuthServerNoServers(t *testing.T) {
|
||||
s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{}})
|
||||
_, err := s.dialLocalAuthServer(context.Background())
|
||||
require.Error(t, err, "dialLocalAuthServer expected to fail")
|
||||
require.Equal(t, "empty auth servers list", err.Error())
|
||||
}
|
||||
|
||||
func TestDialLocalAuthServerNoAvailableServers(t *testing.T) {
|
||||
server1, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid:8000"})
|
||||
require.NoError(t, err)
|
||||
s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{server1}})
|
||||
_, err = s.dialLocalAuthServer(context.Background())
|
||||
require.Error(t, err, "dialLocalAuthServer expected to fail")
|
||||
require.Contains(t, err.Error(), "invalid:8000:")
|
||||
}
|
||||
|
||||
func TestDialLocalAuthServerAvailableServers(t *testing.T) {
|
||||
socket, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer socket.Close()
|
||||
server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: socket.Addr().String()})
|
||||
require.NoError(t, err)
|
||||
servers := []types.Server{server}
|
||||
// multiple invalid servers to minimize chance that we select good one first try
|
||||
for i := 0; i < 20; i++ {
|
||||
server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid2:8000"})
|
||||
require.NoError(t, err)
|
||||
servers = append(servers, server)
|
||||
}
|
||||
s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: servers})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
_, err = s.dialLocalAuthServer(ctx)
|
||||
require.NoError(t, err)
|
||||
}
|
Loading…
Reference in a new issue