diff --git a/lib/srv/alpnproxy/auth/auth_proxy.go b/lib/srv/alpnproxy/auth/auth_proxy.go index ec8740fae26..f7b5aae4115 100644 --- a/lib/srv/alpnproxy/auth/auth_proxy.go +++ b/lib/srv/alpnproxy/auth/auth_proxy.go @@ -18,6 +18,7 @@ package alpnproxyauth import ( "context" + "fmt" "io" "math/rand" "net" @@ -113,13 +114,23 @@ func (s *AuthProxyDialerService) dialLocalAuthServer(ctx context.Context) (net.C if len(authServers) == 0 { return nil, trace.NotFound("empty auth servers list") } - //TODO(smallinksy) Better support for HA. Add dial retry on auth network errors. - authServerIndex := rand.Intn(len(authServers)) - conn, err := net.Dial("tcp", authServers[authServerIndex].GetAddr()) - if err != nil { - return nil, trace.Wrap(err) + var errors []error + + // iterate over the addresses in random order + for len(authServers) > 0 { + l := len(authServers) + authServerIndex := rand.Intn(l) + addr := authServers[authServerIndex].GetAddr() + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", addr) + if err == nil { + return conn, nil + } + errors = append(errors, fmt.Errorf("%s: %w", addr, err)) + authServers[authServerIndex] = authServers[l-1] + authServers = authServers[:l-1] } - return conn, nil + return nil, trace.NewAggregate(errors...) } func (s *AuthProxyDialerService) dialRemoteAuthServer(ctx context.Context, clusterName string) (net.Conn, error) { diff --git a/lib/srv/alpnproxy/auth/auth_proxy_test.go b/lib/srv/alpnproxy/auth/auth_proxy_test.go new file mode 100644 index 00000000000..ff25f264121 --- /dev/null +++ b/lib/srv/alpnproxy/auth/auth_proxy_test.go @@ -0,0 +1,76 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package alpnproxyauth + +import ( + "context" + "net" + "testing" + "time" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/services" + "github.com/stretchr/testify/require" +) + +type mockAuthGetter struct { + servers []types.Server +} + +func (m mockAuthGetter) GetClusterName(...services.MarshalOption) (types.ClusterName, error) { + return nil, nil +} + +func (m mockAuthGetter) GetAuthServers() ([]types.Server, error) { + return m.servers, nil +} + +func TestDialLocalAuthServerNoServers(t *testing.T) { + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{}}) + _, err := s.dialLocalAuthServer(context.Background()) + require.Error(t, err, "dialLocalAuthServer expected to fail") + require.Equal(t, "empty auth servers list", err.Error()) +} + +func TestDialLocalAuthServerNoAvailableServers(t *testing.T) { + server1, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid:8000"}) + require.NoError(t, err) + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: []types.Server{server1}}) + _, err = s.dialLocalAuthServer(context.Background()) + require.Error(t, err, "dialLocalAuthServer expected to fail") + require.Contains(t, err.Error(), "invalid:8000:") +} + +func TestDialLocalAuthServerAvailableServers(t *testing.T) { + socket, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer socket.Close() + server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: socket.Addr().String()}) + require.NoError(t, err) + servers := []types.Server{server} + // multiple invalid servers to minimize chance that we select good one first try + for i := 0; i < 20; i++ { + server, err := types.NewServer("s1", "auth", types.ServerSpecV2{Addr: "invalid2:8000"}) + require.NoError(t, err) + servers = append(servers, server) + } + s := NewAuthProxyDialerService(nil, mockAuthGetter{servers: servers}) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err = s.dialLocalAuthServer(ctx) + require.NoError(t, err) +}