Kube Proxy Forwarder handles kube services with same name (#8362)

Update Proxy kube forwarder to attempt to dial through all available
endpoints in a random order.
This commit is contained in:
Brian Joerger 2021-10-06 16:01:08 -07:00 committed by GitHub
parent 700f9f71e5
commit c6f0a8a2fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 359 additions and 170 deletions

View file

@ -293,11 +293,12 @@ func (f *Forwarder) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// contains information about user, target cluster and authenticated groups
type authContext struct {
auth.Context
kubeGroups map[string]struct{}
kubeUsers map[string]struct{}
kubeCluster string
teleportCluster teleportClusterClient
recordingConfig types.SessionRecordingConfig
kubeGroups map[string]struct{}
kubeUsers map[string]struct{}
kubeCluster string
teleportCluster teleportClusterClient
teleportClusterEndpoints []endpoint
recordingConfig types.SessionRecordingConfig
// clientIdleTimeout sets information on client idle timeout
clientIdleTimeout time.Duration
// disconnectExpiredCert if set, controls the time when the connection
@ -307,6 +308,14 @@ type authContext struct {
sessionTTL time.Duration
}
type endpoint struct {
// addr is a direct network address.
addr string
// serverID is the server:cluster ID of the endpoint,
// which is used to find its corresponding reverse tunnel.
serverID string
}
func (c authContext) String() string {
return fmt.Sprintf("user: %v, users: %v, groups: %v, teleport cluster: %v, kube cluster: %v", c.User.GetName(), c.kubeUsers, c.kubeGroups, c.teleportCluster.name, c.kubeCluster)
}
@ -335,17 +344,14 @@ type teleportClusterClient struct {
dial dialFunc
// targetAddr is a direct network address.
targetAddr string
//serverID is an address reachable over a reverse tunnel.
// serverID is the server:cluster ID of the endpoint,
// which is used to find its corresponding reverse tunnel.
serverID string
isRemote bool
isRemoteClosed func() bool
}
func (c *teleportClusterClient) Dial(network, addr string) (net.Conn, error) {
return c.DialWithContext(context.Background(), network, addr)
}
func (c *teleportClusterClient) DialWithContext(ctx context.Context, network, addr string) (net.Conn, error) {
func (c *teleportClusterClient) DialWithContext(ctx context.Context, network, _ string) (net.Conn, error) {
return c.dial(ctx, network, c.targetAddr, c.serverID)
}
@ -1113,8 +1119,11 @@ func (f *Forwarder) setupForwardingHeaders(sess *clusterSession, req *http.Reque
// Setup scheme, override target URL to the destination address
req.URL.Scheme = "https"
req.URL.Host = sess.teleportCluster.targetAddr
req.RequestURI = req.URL.Path + "?" + req.URL.RawQuery
req.URL.Host = sess.teleportCluster.targetAddr
if sess.teleportCluster.targetAddr == "" {
req.URL.Host = reversetunnel.LocalKubernetes
}
// add origin headers so the service consuming the request on the other site
// is aware of where it came from
@ -1367,13 +1376,44 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
}
func (s *clusterSession) Dial(network, addr string) (net.Conn, error) {
return s.monitorConn(s.teleportCluster.Dial(network, addr))
return s.monitorConn(s.teleportCluster.DialWithContext(context.Background(), network, addr))
}
func (s *clusterSession) DialWithContext(ctx context.Context, network, addr string) (net.Conn, error) {
return s.monitorConn(s.teleportCluster.DialWithContext(ctx, network, addr))
}
func (s *clusterSession) DialWithEndpoints(network, addr string) (net.Conn, error) {
return s.monitorConn(s.dialWithEndpoints(context.Background(), network, addr))
}
// This is separated from DialWithEndpoints for testing without monitorConn.
func (s *clusterSession) dialWithEndpoints(ctx context.Context, network, addr string) (net.Conn, error) {
if len(s.teleportClusterEndpoints) == 0 {
return nil, trace.BadParameter("no endpoints to dial")
}
// Shuffle endpoints to balance load
shuffledEndpoints := make([]endpoint, len(s.teleportClusterEndpoints))
copy(shuffledEndpoints, s.teleportClusterEndpoints)
mathrand.Shuffle(len(shuffledEndpoints), func(i, j int) {
shuffledEndpoints[i], shuffledEndpoints[j] = shuffledEndpoints[j], shuffledEndpoints[i]
})
errs := []error{}
for _, endpoint := range shuffledEndpoints {
s.teleportCluster.targetAddr = endpoint.addr
s.teleportCluster.serverID = endpoint.serverID
conn, err := s.teleportCluster.DialWithContext(ctx, network, addr)
if err != nil {
errs = append(errs, err)
continue
}
return conn, nil
}
return nil, trace.NewAggregate(errs...)
}
// TODO(awly): unit test this
func (f *Forwarder) newClusterSession(ctx authContext) (*clusterSession, error) {
if ctx.teleportCluster.isRemote {
@ -1395,7 +1435,7 @@ func (f *Forwarder) newClusterSessionRemoteCluster(ctx authContext) (*clusterSes
}
// remote clusters use special hardcoded URL,
// and use a special dialer
sess.authContext.teleportCluster.targetAddr = reversetunnel.LocalKubernetes
sess.teleportCluster.targetAddr = reversetunnel.LocalKubernetes
transport := f.newTransport(sess.Dial, sess.tlsConfig)
sess.forwarder, err = forward.New(
@ -1416,11 +1456,13 @@ func (f *Forwarder) newClusterSessionSameCluster(ctx authContext) (*clusterSessi
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
}
if len(kubeServices) == 0 && ctx.kubeCluster == ctx.teleportCluster.name {
return f.newClusterSessionLocal(ctx)
}
// Validate that the requested kube cluster is registered.
var endpoints []types.Server
var endpoints []endpoint
outer:
for _, s := range kubeServices {
for _, k := range s.GetKubernetesClusters() {
@ -1428,7 +1470,10 @@ outer:
continue
}
// TODO(awly): check RBAC
endpoints = append(endpoints, s)
endpoints = append(endpoints, endpoint{
serverID: fmt.Sprintf("%s.%s", s.GetName(), ctx.teleportCluster.name),
addr: s.GetAddr(),
})
continue outer
}
}
@ -1439,12 +1484,7 @@ outer:
if _, ok := f.creds[ctx.kubeCluster]; ok {
return f.newClusterSessionLocal(ctx)
}
// Pick a random kubernetes_service to serve this request.
//
// Ideally, we should try a few of the endpoints at random until one
// succeeds. But this is simpler for now.
endpoint := endpoints[mathrand.Intn(len(endpoints))]
return f.newClusterSessionDirect(ctx, endpoint)
return f.newClusterSessionDirect(ctx, endpoints)
}
func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, error) {
@ -1489,11 +1529,13 @@ func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, er
return sess, nil
}
func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.Server) (*clusterSession, error) {
f.log.WithFields(log.Fields{
"kubernetes_service.name": kubeService.GetName(),
"kubernetes_service.addr": kubeService.GetAddr(),
}).Debugf("Kubernetes session for %v forwarded to remote kubernetes_service instance.", ctx)
func (f *Forwarder) newClusterSessionDirect(ctx authContext, endpoints []endpoint) (*clusterSession, error) {
if len(endpoints) == 0 {
return nil, trace.BadParameter("no kube cluster endpoints provided")
}
f.log.WithField("kube_service.endpoints", endpoints).Debugf("Kubernetes session for %v forwarded to remote kubernetes_service instance.", ctx)
sess := &clusterSession{
parent: f,
authContext: ctx,
@ -1501,10 +1543,7 @@ func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.S
// audit logging. Avoid duplicate logging.
noAuditEvents: true,
}
// Set both addr and serverID, in case this is a kubernetes_service
// connected over a tunnel.
sess.authContext.teleportCluster.targetAddr = kubeService.GetAddr()
sess.authContext.teleportCluster.serverID = fmt.Sprintf("%s.%s", kubeService.GetName(), ctx.teleportCluster.name)
sess.authContext.teleportClusterEndpoints = endpoints
var err error
sess.tlsConfig, err = f.getOrRequestClientCreds(ctx)
@ -1513,12 +1552,11 @@ func (f *Forwarder) newClusterSessionDirect(ctx authContext, kubeService types.S
return nil, trace.AccessDenied("access denied: failed to authenticate with auth server")
}
transport := f.newTransport(sess.Dial, sess.tlsConfig)
transport := f.newTransport(sess.DialWithEndpoints, sess.tlsConfig)
sess.forwarder, err = forward.New(
forward.FlushInterval(100*time.Millisecond),
forward.RoundTripper(transport),
forward.WebsocketDial(sess.Dial),
forward.WebsocketDial(sess.DialWithEndpoints),
forward.Logger(f.log),
forward.ErrorHandler(fwdutils.ErrorHandlerFunc(f.formatForwardResponseError)),
)

View file

@ -20,6 +20,8 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"net"
"net/http"
"sort"
"testing"
@ -35,6 +37,8 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/julienschmidt/httprouter"
"k8s.io/client-go/transport"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
@ -44,7 +48,6 @@ import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gopkg.in/check.v1"
"k8s.io/client-go/transport"
)
type ForwarderSuite struct{}
@ -55,6 +58,25 @@ func Test(t *testing.T) {
check.TestingT(t)
}
var (
identity = auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
})
unmappedIdentity = auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
})
)
func (s ForwarderSuite) TestRequestCertificate(c *check.C) {
cl, err := newMockCSRClient()
c.Assert(err, check.IsNil)
@ -72,23 +94,9 @@ func (s ForwarderSuite) TestRequestCertificate(c *check.C) {
name: "site a",
},
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
}
@ -583,138 +591,275 @@ func (s ForwarderSuite) TestSetupImpersonationHeaders(c *check.C) {
}
}
func (s ForwarderSuite) TestNewClusterSession(c *check.C) {
clientCreds, err := ttlmap.New(defaults.ClientCacheSize)
c.Assert(err, check.IsNil)
csrClient, err := newMockCSRClient()
c.Assert(err, check.IsNil)
f := &Forwarder{
log: logrus.New(),
cfg: ForwarderConfig{
Keygen: testauthority.New(),
AuthClient: csrClient,
CachingAuthClient: mockAccessPoint{},
},
clientCredentials: clientCreds,
ctx: context.TODO(),
activeRequests: make(map[string]context.Context),
}
user, err := types.NewUser("bob")
c.Assert(err, check.IsNil)
func TestNewClusterSession(t *testing.T) {
ctx := context.Background()
f := newMockForwader(ctx, t)
user, err := types.NewUser("bob")
require.NoError(t, err)
c.Log("newClusterSession for a local cluster without kubeconfig")
authCtx := authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
},
teleportCluster: teleportClusterClient{
name: "local",
},
sessionTTL: time.Minute,
}
_, err = f.newClusterSession(authCtx)
c.Assert(err, check.NotNil)
c.Assert(trace.IsNotFound(err), check.Equals, true)
c.Assert(f.clientCredentials.Len(), check.Equals, 0)
f.creds = map[string]*kubeCreds{
"local": {
targetAddr: "k8s.example.com",
tlsConfig: &tls.Config{},
transportConfig: &transport.Config{},
},
}
c.Log("newClusterSession for a local cluster")
authCtx = authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
teleportCluster: teleportClusterClient{
name: "local",
},
sessionTTL: time.Minute,
kubeCluster: "local",
kubeCluster: "public",
}
sess, err := f.newClusterSession(authCtx)
c.Assert(err, check.IsNil)
c.Assert(sess.authContext.teleportCluster.targetAddr, check.Equals, f.creds["local"].targetAddr)
c.Assert(sess.forwarder, check.NotNil)
// Make sure newClusterSession used f.creds instead of requesting a
// Teleport client cert.
c.Assert(sess.tlsConfig, check.Equals, f.creds["local"].tlsConfig)
c.Assert(csrClient.lastCert, check.IsNil)
c.Assert(f.clientCredentials.Len(), check.Equals, 0)
c.Log("newClusterSession for a remote cluster")
authCtx = authContext{
Context: auth.Context{
User: user,
Identity: auth.WrapIdentity(tlsca.Identity{
Username: "remote-bob",
Groups: []string{"remote group a", "remote group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"remote k8s group a", "remote k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
UnmappedIdentity: auth.WrapIdentity(tlsca.Identity{
Username: "bob",
Groups: []string{"group a", "group b"},
Usage: []string{"usage a", "usage b"},
Principals: []string{"principal a", "principal b"},
KubernetesGroups: []string{"k8s group a", "k8s group b"},
Traits: map[string][]string{"trait a": []string{"b", "c"}},
}),
},
teleportCluster: teleportClusterClient{
t.Run("newClusterSession for a local cluster without kubeconfig", func(t *testing.T) {
authCtx := authCtx
authCtx.kubeCluster = ""
_, err = f.newClusterSession(authCtx)
require.Error(t, err)
require.Equal(t, trace.IsNotFound(err), true)
require.Equal(t, f.clientCredentials.Len(), 0)
})
t.Run("newClusterSession for a local cluster", func(t *testing.T) {
authCtx := authCtx
authCtx.kubeCluster = "local"
// Set local creds for the following tests
f.creds = map[string]*kubeCreds{
"local": {
targetAddr: "k8s.example.com",
tlsConfig: &tls.Config{},
transportConfig: &transport.Config{},
},
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
require.Equal(t, f.creds["local"].targetAddr, sess.authContext.teleportCluster.targetAddr)
require.NotNil(t, sess.forwarder)
// Make sure newClusterSession used f.creds instead of requesting a
// Teleport client cert.
require.Equal(t, f.creds["local"].tlsConfig, sess.tlsConfig)
require.Nil(t, f.cfg.AuthClient.(*mockCSRClient).lastCert)
require.Equal(t, 0, f.clientCredentials.Len())
})
t.Run("newClusterSession for a remote cluster", func(t *testing.T) {
authCtx := authCtx
authCtx.kubeCluster = ""
authCtx.teleportCluster = teleportClusterClient{
name: "remote",
isRemote: true,
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
require.Equal(t, reversetunnel.LocalKubernetes, sess.authContext.teleportCluster.targetAddr)
require.NotNil(t, sess.forwarder)
// Make sure newClusterSession obtained a new client cert instead of using
// f.creds.
require.NotEqual(t, f.creds["local"].tlsConfig, sess.tlsConfig)
require.Equal(t, f.cfg.AuthClient.(*mockCSRClient).lastCert.Raw, sess.tlsConfig.Certificates[0].Certificate[0])
require.Equal(t, [][]byte{f.cfg.AuthClient.(*mockCSRClient).ca.Cert.RawSubject}, sess.tlsConfig.RootCAs.Subjects())
require.Equal(t, 1, f.clientCredentials.Len())
})
t.Run("newClusterSession with public kube_service endpoints", func(t *testing.T) {
publicKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "public-server",
},
Spec: types.ServerSpecV2{
Addr: "k8s.example.com:3026",
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}
reverseTunnelKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "reverse-tunnel-server",
},
Spec: types.ServerSpecV2{
Addr: reversetunnel.LocalKubernetes,
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
publicKubeServer,
reverseTunnelKubeServer,
},
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
expectedEndpoints := []endpoint{
{
addr: publicKubeServer.GetAddr(),
serverID: fmt.Sprintf("%v.local", publicKubeServer.GetName()),
},
{
addr: reverseTunnelKubeServer.GetAddr(),
serverID: fmt.Sprintf("%v.local", reverseTunnelKubeServer.GetName()),
},
}
require.Equal(t, expectedEndpoints, sess.authContext.teleportClusterEndpoints)
})
}
func TestDialWithEndpoints(t *testing.T) {
ctx := context.Background()
f := newMockForwader(ctx, t)
user, err := types.NewUser("bob")
require.NoError(t, err)
authCtx := authContext{
Context: auth.Context{
User: user,
Identity: identity,
UnmappedIdentity: unmappedIdentity,
},
sessionTTL: time.Minute,
teleportCluster: teleportClusterClient{
name: "local",
dial: func(ctx context.Context, network, addr, serverID string) (net.Conn, error) {
return &net.TCPConn{}, nil
},
},
sessionTTL: time.Minute,
kubeCluster: "public",
}
publicKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "public-server",
},
Spec: types.ServerSpecV2{
Addr: "k8s.example.com:3026",
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}
t.Run("Dial public endpoint", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
publicKubeServer,
},
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)
require.Equal(t, publicKubeServer.GetAddr(), sess.authContext.teleportCluster.targetAddr)
expectServerID := fmt.Sprintf("%v.%v", publicKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
})
reverseTunnelKubeServer := &types.ServerV2{
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "reverse-tunnel-server",
},
Spec: types.ServerSpecV2{
Addr: reversetunnel.LocalKubernetes,
Hostname: "",
KubernetesClusters: []*types.KubernetesCluster{{
Name: "public",
}},
},
}
t.Run("Dial reverse tunnel endpoint", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
reverseTunnelKubeServer,
},
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)
require.Equal(t, reverseTunnelKubeServer.GetAddr(), sess.authContext.teleportCluster.targetAddr)
expectServerID := fmt.Sprintf("%v.%v", reverseTunnelKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
})
t.Run("newClusterSession multiple kube clusters", func(t *testing.T) {
f.cfg.CachingAuthClient = mockAccessPoint{
kubeServices: []types.Server{
publicKubeServer,
reverseTunnelKubeServer,
},
}
sess, err := f.newClusterSession(authCtx)
require.NoError(t, err)
_, err = sess.dialWithEndpoints(ctx, "", "")
require.NoError(t, err)
// The endpoint used to dial will be chosen at random. Make sure we hit one of them.
switch sess.teleportCluster.targetAddr {
case publicKubeServer.GetAddr():
expectServerID := fmt.Sprintf("%v.%v", publicKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
case reverseTunnelKubeServer.GetAddr():
expectServerID := fmt.Sprintf("%v.%v", reverseTunnelKubeServer.GetName(), authCtx.teleportCluster.name)
require.Equal(t, expectServerID, sess.authContext.teleportCluster.serverID)
default:
t.Fatalf("Unexpected targetAddr: %v", sess.authContext.teleportCluster.targetAddr)
}
})
}
func newMockForwader(ctx context.Context, t *testing.T) *Forwarder {
clientCreds, err := ttlmap.New(defaults.ClientCacheSize)
require.NoError(t, err)
csrClient, err := newMockCSRClient()
require.NoError(t, err)
return &Forwarder{
log: logrus.New(),
router: *httprouter.New(),
cfg: ForwarderConfig{
Keygen: testauthority.New(),
AuthClient: csrClient,
CachingAuthClient: mockAccessPoint{},
Clock: clockwork.NewFakeClock(),
Context: ctx,
},
clientCredentials: clientCreds,
activeRequests: make(map[string]context.Context),
ctx: ctx,
}
sess, err = f.newClusterSession(authCtx)
c.Assert(err, check.IsNil)
c.Assert(sess.authContext.teleportCluster.targetAddr, check.Equals, reversetunnel.LocalKubernetes)
c.Assert(sess.forwarder, check.NotNil)
// Make sure newClusterSession obtained a new client cert instead of using
// f.creds.
c.Assert(sess.tlsConfig, check.Not(check.Equals), f.creds["local"].tlsConfig)
c.Assert(sess.tlsConfig.Certificates[0].Certificate[0], check.DeepEquals, csrClient.lastCert.Raw)
c.Assert(sess.tlsConfig.RootCAs.Subjects(), check.DeepEquals, [][]byte{csrClient.ca.Cert.RawSubject})
c.Assert(f.clientCredentials.Len(), check.Equals, 1)
}
// mockCSRClient to intercept ProcessKubeCSR requests, record them and return a

View file

@ -17,11 +17,17 @@ limitations under the License.
package main
import (
"math/rand"
"os"
"time"
"github.com/gravitational/teleport/tool/teleport/common"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
func main() {
common.Run(common.Options{
Args: os.Args[1:],