Mirror of https://github.com/gravitational/teleport, synced 2024-10-21 17:53:28 +00:00
Kube Proxy Forwarder handles kube services with same name (#8362)
Update Proxy kube forwarder to attempt to dial through all available endpoints in a random order.
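In short: instead of picking one kubernetes_service at random, the forwarder now records every matching endpoint and, when dialing, shuffles them and tries each in turn until a connection succeeds. The sketch below illustrates that strategy in isolation; the endpoint struct and dial callback are simplified stand-ins for the Forwarder/clusterSession types touched by this commit, not the exact code (the real implementation is clusterSession.dialWithEndpoints in the diff further down).

// Sketch of the random-order endpoint dialing strategy introduced by this commit.
// The endpoint type and dial callback are simplified stand-ins, not the real Teleport API.
package sketch

import (
    "context"
    "math/rand"
    "net"

    "github.com/gravitational/trace"
)

type endpoint struct {
    addr     string // direct network address
    serverID string // server:cluster ID, used to locate a reverse tunnel
}

type dialFunc func(ctx context.Context, network, addr, serverID string) (net.Conn, error)

// dialAny shuffles the endpoints and dials each one until a connection succeeds,
// aggregating the errors if every endpoint fails.
func dialAny(ctx context.Context, network string, endpoints []endpoint, dial dialFunc) (net.Conn, error) {
    if len(endpoints) == 0 {
        return nil, trace.BadParameter("no endpoints to dial")
    }
    // Shuffle a copy so repeated requests spread load across kubernetes_service instances.
    shuffled := make([]endpoint, len(endpoints))
    copy(shuffled, endpoints)
    rand.Shuffle(len(shuffled), func(i, j int) {
        shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
    })
    var errs []error
    for _, ep := range shuffled {
        conn, err := dial(ctx, network, ep.addr, ep.serverID)
        if err != nil {
            errs = append(errs, err)
            continue
        }
        return conn, nil
    }
    return nil, trace.NewAggregate(errs...)
}

The actual change additionally rewrites sess.teleportCluster.targetAddr and serverID on each attempt so the existing DialWithContext path is reused; see the forwarder diff below.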
parent 700f9f71e5
commit c6f0a8a2fe
@@ -293,11 +293,12 @@ func (f *Forwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// contains information about user, target cluster and authenticated groups
type authContext struct {
auth.Context
kubeGroups map[string]struct{}
kubeUsers map[string]struct{}
kubeCluster string
teleportCluster teleportClusterClient
recordingConfig types.SessionRecordingConfig
kubeGroups map[string]struct{}
kubeUsers map[string]struct{}
kubeCluster string
teleportCluster teleportClusterClient
teleportClusterEndpoints []endpoint
recordingConfig types.SessionRecordingConfig
// clientIdleTimeout sets information on client idle timeout
clientIdleTimeout time.Duration
// disconnectExpiredCert if set, controls the time when the connection
@@ -307,6 +308,14 @@ type authContext struct {
sessionTTL time.Duration
}

type endpoint struct {
// addr is a direct network address.
addr string
// serverID is the server:cluster ID of the endpoint,
// which is used to find its corresponding reverse tunnel.
serverID string
}

func (c authContext) String() string {
return fmt.Sprintf("user: %v, users: %v, groups: %v, teleport cluster: %v, kube cluster: %v", c.User.GetName(), c.kubeUsers, c.kubeGroups, c.teleportCluster.name, c.kubeCluster)
}
@@ -335,17 +344,14 @@ type teleportClusterClient struct {
dial dialFunc
// targetAddr is a direct network address.
targetAddr string
//serverID is an address reachable over a reverse tunnel.
// serverID is the server:cluster ID of the endpoint,
// which is used to find its corresponding reverse tunnel.
serverID string
isRemote bool
isRemoteClosed func() bool
}

func (c *teleportClusterClient) Dial(network, addr string) (net.Conn, error) {
return c.DialWithContext(context.Background(), network, addr)
}

func (c *teleportClusterClient) DialWithContext(ctx context.Context, network, addr string) (net.Conn, error) {
func (c *teleportClusterClient) DialWithContext(ctx context.Context, network, _ string) (net.Conn, error) {
return c.dial(ctx, network, c.targetAddr, c.serverID)
}

@@ -1113,8 +1119,11 @@ func (f *Forwarder) setupForwardingHeaders(sess *clusterSession, req *http.Reque

// Setup scheme, override target URL to the destination address
req.URL.Scheme = "https"
req.URL.Host = sess.teleportCluster.targetAddr
req.RequestURI = req.URL.Path + "?" + req.URL.RawQuery
req.URL.Host = sess.teleportCluster.targetAddr
if sess.teleportCluster.targetAddr == "" {
req.URL.Host = reversetunnel.LocalKubernetes
}

// add origin headers so the service consuming the request on the other site
// is aware of where it came from
@@ -1367,13 +1376,44 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
}

func (s *clusterSession) Dial(network, addr string) (net.Conn, error) {
return s.monitorConn(s.teleportCluster.Dial(network, addr))
return s.monitorConn(s.teleportCluster.DialWithContext(context.Background(), network, addr))
}

func (s *clusterSession) DialWithContext(ctx context.Context, network, addr string) (net.Conn, error) {
return s.monitorConn(s.teleportCluster.DialWithContext(ctx, network, addr))
}

func (s *clusterSession) DialWithEndpoints(network, addr string) (net.Conn, error) {
return s.monitorConn(s.dialWithEndpoints(context.Background(), network, addr))
}

// This is separated from DialWithEndpoints for testing without monitorConn.
func (s *clusterSession) dialWithEndpoints(ctx context.Context, network, addr string) (net.Conn, error) {
if len(s.teleportClusterEndpoints) == 0 {
return nil, trace.BadParameter("no endpoints to dial")
}

// Shuffle endpoints to balance load
shuffledEndpoints := make([]endpoint, len(s.teleportClusterEndpoints))
copy(shuffledEndpoints, s.teleportClusterEndpoints)
mathrand.Shuffle(len(shuffledEndpoints), func(i, j int) {
shuffledEndpoints[i], shuffledEndpoints[j] = shuffledEndpoints[j], shuffledEndpoints[i]
})

errs := []error{}
for _, endpoint := range shuffledEndpoints {
s.teleportCluster.targetAddr = endpoint.addr
s.teleportCluster.serverID = endpoint.serverID
conn, err := s.teleportCluster.DialWithContext(ctx, network, addr)
if err != nil {
errs = append(errs, err)
continue
}
return conn, nil
}
return nil, trace.NewAggregate(errs...)
}

// TODO(awly): unit test this
func (f *Forwarder) newClusterSession(ctx authContext) (*clusterSession, error) {
if ctx.teleportCluster.isRemote {
@@ -1395,7 +1435,7 @@ func (f *Forwarder) newClusterSessionRemoteCluster(ctx authContext) (*clusterSes
}
// remote clusters use special hardcoded URL,
// and use a special dialer
sess.authContext.teleportCluster.targetAddr = reversetunnel.LocalKubernetes
sess.teleportCluster.targetAddr = reversetunnel.LocalKubernetes
transport := f.newTransport(sess.Dial, sess.tlsConfig)

sess.forwarder, err = forward.New(
@@ -1416,11 +1456,13 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx authContext) (*clusterSessi
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
}

if len(kubeServices) == 0 && ctx.kubeCluster == ctx.teleportCluster.name {
return f.newClusterSessionLocal(ctx)
}

// Validate that the requested kube cluster is registered.
var endpoints []types.Server
var endpoints []endpoint
outer:
for _, s := range kubeServices {
for _, k := range s.GetKubernetesClusters() {
@@ -1428,7 +1470,10 @@ outer:
continue
}
// TODO(awly): check RBAC
endpoints = append(endpoints, s)
endpoints = append(endpoints, endpoint{
serverID: fmt.Sprintf("%s.%s", s.GetName(), ctx.teleportCluster.name),
addr: s.GetAddr(),
})
continue outer
}
}
@@ -1439,12 +1484,7 @@ outer:
if _, ok := f.creds[ctx.kubeCluster]; ok {
return f.newClusterSessionLocal(ctx)
}
// Pick a random kubernetes_service to serve this request.
//
// Ideally, we should try a few of the endpoints at random until one
// succeeds. But this is simpler for now.
endpoint := endpoints[mathrand.Intn(len(endpoints))]
return f.newClusterSessionDirect(ctx, endpoint)
return f.newClusterSessionDirect(ctx, endpoints)
}

func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, error) {
@@ -1489,11 +1529,13 @@ func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, er
return sess, nil
}

func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.Server) (*clusterSession, error) {
f.log.WithFields(log.Fields{
"kubernetes_service.name": kubeService.GetName(),
"kubernetes_service.addr": kubeService.GetAddr(),
}).Debugf("Kubernetes session for %v forwarded to remote kubernetes_service instance.", ctx)
func (f *Forwarder) newClusterSessionDirect(ctx authContext, endpoints []endpoint) (*clusterSession, error) {
if len(endpoints) == 0 {
return nil, trace.BadParameter("no kube cluster endpoints provided")
}

f.log.WithField("kube_service.endpoints", endpoints).Debugf("Kubernetes session for %v forwarded to remote kubernetes_service instance.", ctx)

sess := &clusterSession{
parent: f,
authContext: ctx,
@@ -1501,10 +1543,7 @@ func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.S
// audit logging. Avoid duplicate logging.
noAuditEvents: true,
}
// Set both addr and serverID, in case this is a kubernetes_service
// connected over a tunnel.
sess.authContext.teleportCluster.targetAddr = kubeService.GetAddr()
sess.authContext.teleportCluster.serverID = fmt.Sprintf("%s.%s", kubeService.GetName(), ctx.teleportCluster.name)
sess.authContext.teleportClusterEndpoints = endpoints

var err error
sess.tlsConfig, err = f.getOrRequestClientCreds(ctx)
@@ -1513,12 +1552,11 @@ func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.S
return nil, trace.AccessDenied("access denied: failed to authenticate with auth server")
}

transport := f.newTransport(sess.Dial, sess.tlsConfig)

transport := f.newTransport(sess.DialWithEndpoints, sess.tlsConfig)
sess.forwarder, err = forward.New(
forward.FlushInterval(100*time.Millisecond),
forward.RoundTripper(transport),
forward.WebsocketDial(sess.Dial),
forward.WebsocketDial(sess.DialWithEndpoints),
forward.Logger(f.log),
forward.ErrorHandler(fwdutils.ErrorHandlerFunc(f.formatForwardResponseError)),
)
@@ -20,6 +20,8 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"net/http"
"sort"
"testing"
@@ -35,6 +37,8 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/julienschmidt/httprouter"
"k8s.io/client-go/transport"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@@ -44,7 +48,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gopkg.in/check.v1"
"k8s.io/client-go/transport"
)

type ForwarderSuite struct{}
@@ -55,6 +58,25 @@ func Test(t *testing.T) {
check.TestingT(t)
}

var (
identity = auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
})
unmappedIdentity = auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
})
)

func (s ForwarderSuite) TestRequestCertificate(c *check.C) {
cl, err := newMockCSRClient()
c.Assert(err, check.IsNil)
@@ -72,23 +94,9 @@ func (s ForwarderSuite) TestRequestCertificate(c *check.C) {
name: "site a",
},
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
}

@@ -583,138 +591,275 @@ func (s ForwarderSuite) TestSetupImpersonationHeaders(c *check.C) {
}
}

func (s ForwarderSuite) TestNewClusterSession(c *check.C) {
clientCreds, err := ttlmap.New(defaults.ClientCacheSize)
c.Assert(err, check.IsNil)
csrClient, err := newMockCSRClient()
c.Assert(err, check.IsNil)
f := &Forwarder{
log: logrus.New(),
cfg: ForwarderConfig{
Keygen: testauthority.New(),
AuthClient: csrClient,
CachingAuthClient: mockAccessPoint{},
},
clientCredentials: clientCreds,
ctx: context.TODO(),
activeRequests: make(map[string]context.Context),
}
user, err := types.NewUser("bob")
c.Assert(err, check.IsNil)
func TestNewClusterSession(t *testing.T) {
ctx := context.Background()

f := newMockForwader(ctx, t)

user, err := types.NewUser("bob")
require.NoError(t, err)

c.Log("newClusterSession for a local cluster without kubeconfig")
authCtx := authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
},
teleportCluster: teleportClusterClient{
name: "local",
},
sessionTTL: time.Minute,
}
_, err = f.newClusterSession(authCtx)
c.Assert(err, check.NotNil)
c.Assert(trace.IsNotFound(err), check.Equals, true)
c.Assert(f.clientCredentials.Len(), check.Equals, 0)

f.creds = map[string]*kubeCreds{
"local": {
targetAddr: "k8s.example.com",
tlsConfig: &tls.Config{},
transportConfig: &transport.Config{},
},
}

c.Log("newClusterSession for a local cluster")
authCtx = authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
teleportCluster: teleportClusterClient{
name: "local",
},
sessionTTL: time.Minute,
kubeCluster: "local",
kubeCluster: "public",
}
sess, err := f.newClusterSession(authCtx)
c.Assert(err, check.IsNil)
c.Assert(sess.authContext.teleportCluster.targetAddr, check.Equals, f.creds["local"].targetAddr)
c.Assert(sess.forwarder, check.NotNil)
// Make sure newClusterSession used f.creds instead of requesting a
// Teleport client cert.
c.Assert(sess.tlsConfig, check.Equals, f.creds["local"].tlsConfig)
c.Assert(csrClient.lastCert, check.IsNil)
c.Assert(f.clientCredentials.Len(), check.Equals, 0)

c.Log("newClusterSession for a remote cluster")
authCtx = authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
},
teleportCluster: teleportClusterClient{
t.Run("newClusterSession for a local cluster without kubeconfig", func(t *testing.T) {
|
||||
authCtx := authCtx
|
||||
authCtx.kubeCluster = ""
|
||||
|
||||
_, err = f.newClusterSession(authCtx)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, trace.IsNotFound(err), true)
|
||||
require.Equal(t, f.clientCredentials.Len(), 0)
|
||||
})
|
||||
|
||||
t.Run("newClusterSession for a local cluster", func(t *testing.T) {
|
||||
authCtx := authCtx
|
||||
authCtx.kubeCluster = "local"
|
||||
|
||||
// Set local creds for the following tests
|
||||
f.creds = map[string]*kubeCreds{
|
||||
"local": {
|
||||
targetAddr: "k8s.example.com",
|
||||
tlsConfig: &tls.Config{},
|
||||
transportConfig: &transport.Config{},
|
||||
},
|
||||
}
|
||||
|
||||
sess, err := f.newClusterSession(authCtx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, f.creds["local"].targetAddr, sess.authContext.teleportCluster.targetAddr)
|
||||
require.NotNil(t, sess.forwarder)
|
||||
// Make sure newClusterSession used f.creds instead of requesting a
|
||||
// Teleport client cert.
|
||||
require.Equal(t, f.creds["local"].tlsConfig, sess.tlsConfig)
|
||||
require.Nil(t, f.cfg.AuthClient.(*mockCSRClient).lastCert)
|
||||
require.Equal(t, 0, f.clientCredentials.Len())
|
||||
})
|
||||
|
||||
t.Run("newClusterSession for a remote cluster", func(t *testing.T) {
|
||||
authCtx := authCtx
|
||||
authCtx.kubeCluster = ""
|
||||
authCtx.teleportCluster = teleportClusterClient{
|
||||
name: "remote",
|
||||
isRemote: true,
|
||||
}
|
||||
|
||||
sess, err := f.newClusterSession(authCtx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, reversetunnel.LocalKubernetes, sess.authContext.teleportCluster.targetAddr)
|
||||
require.NotNil(t, sess.forwarder)
|
||||
// Make sure newClusterSession obtained a new client cert instead of using
|
||||
// f.creds.
|
||||
require.NotEqual(t, f.creds["local"].tlsConfig, sess.tlsConfig)
|
||||
require.Equal(t, f.cfg.AuthClient.(*mockCSRClient).lastCert.Raw, sess.tlsConfig.Certificates[0].Certificate[0])
|
||||
require.Equal(t, [][]byte{f.cfg.AuthClient.(*mockCSRClient).ca.Cert.RawSubject}, sess.tlsConfig.RootCAs.Subjects())
|
||||
require.Equal(t, 1, f.clientCredentials.Len())
|
||||
})
|
||||
|
||||
t.Run("newClusterSession with public kube_service endpoints", func(t *testing.T) {
|
||||
publicKubeServer := &types.ServerV2{
|
||||
Kind: types.KindKubeService,
|
||||
Version: types.V2,
|
||||
Metadata: types.Metadata{
|
||||
Name: "public-server",
|
||||
},
|
||||
Spec: types.ServerSpecV2{
|
||||
Addr: "k8s.example.com:3026",
|
||||
Hostname: "",
|
||||
KubernetesClusters: []*types.KubernetesCluster{{
|
||||
Name: "public",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
reverseTunnelKubeServer := &types.ServerV2{
|
||||
Kind: types.KindKubeService,
|
||||
Version: types.V2,
|
||||
Metadata: types.Metadata{
|
||||
Name: "reverse-tunnel-server",
|
||||
},
|
||||
Spec: types.ServerSpecV2{
|
||||
Addr: reversetunnel.LocalKubernetes,
|
||||
Hostname: "",
|
||||
KubernetesClusters: []*types.KubernetesCluster{{
|
||||
Name: "public",
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
f.cfg.CachingAuthClient = mockAccessPoint{
|
||||
kubeServices: []types.Server{
|
||||
publicKubeServer,
|
||||
reverseTunnelKubeServer,
|
||||
},
|
||||
}
|
||||
|
||||
sess, err := f.newClusterSession(authCtx)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedEndpoints := []endpoint{
|
||||
{
|
||||
addr: publicKubeServer.GetAddr(),
|
||||
serverID: fmt.Sprintf("%v.local", publicKubeServer.GetName()),
|
||||
},
|
||||
{
|
||||
addr: reverseTunnelKubeServer.GetAddr(),
|
||||
serverID: fmt.Sprintf("%v.local", reverseTunnelKubeServer.GetName()),
|
||||
},
|
||||
}
|
||||
require.Equal(t, expectedEndpoints, sess.authContext.teleportClusterEndpoints)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDialWithEndpoints(t *testing.T) {
ctx := context.Background()

f := newMockForwader(ctx, t)

user, err := types.NewUser("bob")
require.NoError(t, err)

authCtx := authContext{
Context: auth.Context{
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
sessionTTL: time.Minute,
teleportCluster: teleportClusterClient{
name: "local",
dial: func(ctx context.Context, network, addr, serverID string) (net.Conn, error) {
return &net.TCPConn{}, nil
},
},
sessionTTL: time.Minute,
kubeCluster: "public",
}

publicKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "public-server",
},
Spec: types.ServerSpecV2{
Addr: "k8s.example.com:3026",
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}

t.Run("Dial public endpoint", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
publicKubeServer,
},
}

sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)

_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)

require.Equal(t, publicKubeServer.GetAddr(), sess.authContext.teleportCluster.targetAddr)
expectServerID := fmt.Sprintf("%v.%v", publicKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
})

reverseTunnelKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "reverse-tunnel-server",
},
Spec: types.ServerSpecV2{
Addr: reversetunnel.LocalKubernetes,
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}

t.Run("Dial reverse tunnel endpoint", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
reverseTunnelKubeServer,
},
}

sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)

_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)

require.Equal(t, reverseTunnelKubeServer.GetAddr(), sess.authContext.teleportCluster.targetAddr)
expectServerID := fmt.Sprintf("%v.%v", reverseTunnelKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
})

t.Run("newClusterSession multiple kube clusters", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
publicKubeServer,
reverseTunnelKubeServer,
},
}

sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)

_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)

// The endpoint used to dial will be chosen at random. Make sure we hit one of them.
switch sess.teleportCluster.targetAddr {
case publicKubeServer.GetAddr():
expectServerID := fmt.Sprintf("%v.%v", publicKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
case reverseTunnelKubeServer.GetAddr():
expectServerID := fmt.Sprintf("%v.%v", reverseTunnelKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
default:
t.Fatalf("Unexpected targetAddr: %v", sess.authContext.teleportCluster.targetAddr)
}
})
}

func newMockForwader(ctx context.Context, t *testing.T) *Forwarder {
clientCreds, err := ttlmap.New(defaults.ClientCacheSize)
require.NoError(t, err)

csrClient, err := newMockCSRClient()
require.NoError(t, err)

return &Forwarder{
log: logrus.New(),
router: *httprouter.New(),
cfg: ForwarderConfig{
Keygen: testauthority.New(),
AuthClient: csrClient,
CachingAuthClient: mockAccessPoint{},
Clock: clockwork.NewFakeClock(),
Context: ctx,
},
clientCredentials: clientCreds,
activeRequests: make(map[string]context.Context),
ctx: ctx,
}
sess, err = f.newClusterSession(authCtx)
c.Assert(err, check.IsNil)
c.Assert(sess.authContext.teleportCluster.targetAddr, check.Equals, reversetunnel.LocalKubernetes)
c.Assert(sess.forwarder, check.NotNil)
// Make sure newClusterSession obtained a new client cert instead of using
// f.creds.
c.Assert(sess.tlsConfig, check.Not(check.Equals), f.creds["local"].tlsConfig)
c.Assert(sess.tlsConfig.Certificates[0].Certificate[0], check.DeepEquals, csrClient.lastCert.Raw)
c.Assert(sess.tlsConfig.RootCAs.Subjects(), check.DeepEquals, [][]byte{csrClient.ca.Cert.RawSubject})
c.Assert(f.clientCredentials.Len(), check.Equals, 1)
}

// mockCSRClient to intercept ProcessKubeCSR requests, record them and return a
@@ -17,11 +17,17 @@ limitations under the License.
package main

import (
"math/rand"
"os"
"time"

"github.com/gravitational/teleport/tool/teleport/common"
)

func init() {
rand.Seed(time.Now().UnixNano())
}

func main() {
common.Run(common.Options{
Args: os.Args[1:],