mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 08:43:58 +00:00
ALPN connection upgrade for MySQL behind ALB (#15669)
This commit is contained in:
parent
09cd4bfdd8
commit
8394f4fb48
12
constants.go
12
constants.go
|
@ -774,3 +774,15 @@ const UserSingleUseCertTTL = time.Minute
|
|||
// StandardHTTPSPort is the default port used for the https URI scheme,
|
||||
// cf. RFC 7230 § 2.7.2.
|
||||
const StandardHTTPSPort = 443
|
||||
|
||||
const (
|
||||
// WebAPIConnUpgrade is the HTTP web API to make the connection upgrade
|
||||
// call.
|
||||
WebAPIConnUpgrade = "/webapi/connectionupgrade"
|
||||
// WebAPIConnUpgradeHeader is the header used to indicate the requested
|
||||
// connection upgrade types in the connection upgrade API.
|
||||
WebAPIConnUpgradeHeader = "Upgrade"
|
||||
// WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies
|
||||
// the upgraded connection should be handled by the ALPN handler.
|
||||
WebAPIConnUpgradeTypeALPN = "alpn"
|
||||
)
|
||||
|
|
|
@ -19,6 +19,8 @@ package integration
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -33,11 +35,13 @@ import (
|
|||
"github.com/gravitational/teleport/api/types"
|
||||
apiutils "github.com/gravitational/teleport/api/utils"
|
||||
"github.com/gravitational/teleport/integration/helpers"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy"
|
||||
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/postgres"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jackc/pgconn"
|
||||
|
@ -470,20 +474,38 @@ func mustCreateKubeConfigFile(t *testing.T, config clientcmdapi.Config) string {
|
|||
return configPath
|
||||
}
|
||||
|
||||
func mustStartALPNLocalProxy(t *testing.T, addr string, protocol alpncommon.Protocol) *alpnproxy.LocalProxy {
|
||||
func mustCreateListener(t *testing.T) net.Listener {
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
address, err := utils.ParseAddr(addr)
|
||||
require.NoError(t, err)
|
||||
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
|
||||
t.Cleanup(func() {
|
||||
listener.Close()
|
||||
})
|
||||
return listener
|
||||
}
|
||||
|
||||
func mustStartALPNLocalProxy(t *testing.T, addr string, protocol alpncommon.Protocol) *alpnproxy.LocalProxy {
|
||||
return mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
|
||||
RemoteProxyAddr: addr,
|
||||
Protocols: []alpncommon.Protocol{protocol},
|
||||
InsecureSkipVerify: true,
|
||||
Listener: listener,
|
||||
ParentContext: context.Background(),
|
||||
SNI: address.Host(),
|
||||
})
|
||||
}
|
||||
|
||||
func mustStartALPNLocalProxyWithConfig(t *testing.T, config alpnproxy.LocalProxyConfig) *alpnproxy.LocalProxy {
|
||||
if config.Listener == nil {
|
||||
config.Listener = mustCreateListener(t)
|
||||
}
|
||||
if config.SNI == "" {
|
||||
address, err := utils.ParseAddr(config.RemoteProxyAddr)
|
||||
require.NoError(t, err)
|
||||
config.SNI = address.Host()
|
||||
}
|
||||
if config.ParentContext == nil {
|
||||
config.ParentContext = context.TODO()
|
||||
}
|
||||
|
||||
lp, err := alpnproxy.NewLocalProxy(config)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
lp.Close()
|
||||
|
@ -512,3 +534,71 @@ func makeNodeConfig(nodeName, authAddr string) *service.Config {
|
|||
nodeConfig.CircuitBreakerConfig = breaker.NoopBreakerConfig()
|
||||
return nodeConfig
|
||||
}
|
||||
|
||||
func mustCreateSelfSignedCert(t *testing.T) tls.Certificate {
|
||||
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
|
||||
CommonName: "localhost",
|
||||
}, []string{"localhost"}, defaults.CATTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := tls.X509KeyPair(caCert, caKey)
|
||||
require.NoError(t, err)
|
||||
return cert
|
||||
}
|
||||
|
||||
// mockAWSALBProxy is a mock proxy server that simulates an AWS application
|
||||
// load balancer where ALPN is not supported. Note that this mock does not
|
||||
// actually balance traffic.
|
||||
type mockAWSALBProxy struct {
|
||||
net.Listener
|
||||
proxyAddr string
|
||||
cert tls.Certificate
|
||||
}
|
||||
|
||||
func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := m.Accept()
|
||||
if err != nil {
|
||||
if utils.IsOKNetworkError(err) {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Handshake with incoming client and drops ALPN.
|
||||
downstreamConn := tls.Server(conn, &tls.Config{
|
||||
Certificates: []tls.Certificate{m.cert},
|
||||
})
|
||||
require.NoError(t, downstreamConn.HandshakeContext(ctx))
|
||||
|
||||
// Make a connection to the proxy server with ALPN protos.
|
||||
upstreamConn, err := tls.Dial("tcp", m.proxyAddr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
utils.ProxyConn(ctx, downstreamConn, upstreamConn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func mustStartMockALBProxy(t *testing.T, proxyAddr string) *mockAWSALBProxy {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
m := &mockAWSALBProxy{
|
||||
proxyAddr: proxyAddr,
|
||||
Listener: mustCreateListener(t),
|
||||
cert: mustCreateSelfSignedCert(t),
|
||||
}
|
||||
go m.serve(ctx, t)
|
||||
return m
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ import (
|
|||
"github.com/gravitational/teleport/lib/auth/testauthority"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/service"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy"
|
||||
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/common"
|
||||
"github.com/gravitational/teleport/lib/srv/db/mongodb"
|
||||
|
@ -434,8 +435,8 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
|
|||
// Disconnect.
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
})
|
||||
|
||||
t.Run("connect to leaf cluster via proxy", func(t *testing.T) {
|
||||
client, err := mysql.MakeTestClient(common.TestClientConfig{
|
||||
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
|
||||
|
@ -620,6 +621,134 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
// Simulate situations where an AWS ALB is between client and the Teleport
|
||||
// Proxy service, which drops ALPN along the way. The ALPN local proxy will
|
||||
// need to make a connection upgrade first through a web API provided by
|
||||
// the Proxy server and then tunnel the original ALPN/TLS routing traffic
|
||||
// inside this tunnel.
|
||||
t.Run("ALPN connection upgrade", func(t *testing.T) {
|
||||
// Make a mock ALB which points to the Teleport Proxy Service. Then
|
||||
// ALPN local proxies will point to this ALB instead.
|
||||
albProxy := mustStartMockALBProxy(t, pack.root.cluster.Web)
|
||||
|
||||
// Test a protocol in the alpncommon.IsDBTLSProtocol list where
|
||||
// the database client will perform a native TLS handshake.
|
||||
//
|
||||
// Packet layers:
|
||||
// - HTTPS served by Teleport web server for connection upgrade
|
||||
// - TLS routing with alpncommon.ProtocolMongoDB (no client cert)
|
||||
// - TLS with client cert (provided by the database client)
|
||||
// - MongoDB
|
||||
t.Run("database client native TLS", func(t *testing.T) {
|
||||
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
|
||||
RemoteProxyAddr: albProxy.Addr().String(),
|
||||
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMongoDB},
|
||||
ALPNConnUpgradeRequired: true,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{
|
||||
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
|
||||
AuthServer: pack.root.cluster.Process.GetAuthServer(),
|
||||
Address: lp.GetAddr(),
|
||||
Cluster: pack.root.cluster.Secrets.SiteName,
|
||||
Username: pack.root.user.GetName(),
|
||||
RouteToDatabase: tlsca.RouteToDatabase{
|
||||
ServiceName: pack.root.mongoService.Name,
|
||||
Protocol: pack.root.mongoService.Protocol,
|
||||
Username: "admin",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Execute a query.
|
||||
_, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Disconnect.
|
||||
require.NoError(t, client.Disconnect(context.Background()))
|
||||
})
|
||||
|
||||
// Test the case where the database client cert is terminated within
|
||||
// the database protocol.
|
||||
//
|
||||
// Packet layers:
|
||||
// - HTTPS served by Teleport web server for connection upgrade
|
||||
// - TLS routing with alpncommon.ProtocolMySQL (no client cert)
|
||||
// - MySQL handshake then upgrade to TLS with Teleport issued client cert
|
||||
// - MySQL protocol
|
||||
t.Run("MySQL custom TLS", func(t *testing.T) {
|
||||
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
|
||||
RemoteProxyAddr: albProxy.Addr().String(),
|
||||
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL},
|
||||
ALPNConnUpgradeRequired: true,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
client, err := mysql.MakeTestClient(common.TestClientConfig{
|
||||
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
|
||||
AuthServer: pack.root.cluster.Process.GetAuthServer(),
|
||||
Address: lp.GetAddr(),
|
||||
Cluster: pack.root.cluster.Secrets.SiteName,
|
||||
Username: pack.root.user.GetName(),
|
||||
RouteToDatabase: tlsca.RouteToDatabase{
|
||||
ServiceName: pack.root.mysqlService.Name,
|
||||
Protocol: pack.root.mysqlService.Protocol,
|
||||
Username: "root",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Execute a query.
|
||||
result, err := client.Execute("select 1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, mysql.TestQueryResponse, result)
|
||||
|
||||
// Disconnect.
|
||||
require.NoError(t, client.Close())
|
||||
})
|
||||
|
||||
// Test the case where the client cert is terminated by Teleport and
|
||||
// the database client sends data in plain database protocol.
|
||||
//
|
||||
// Packet layers:
|
||||
// - HTTPS served by Teleport web server for connection upgrade
|
||||
// - TLS routing with alpncommon.ProtocolMySQL (client cert provided by ALPN local proxy)
|
||||
// - MySQL protocol
|
||||
t.Run("authenticated tunnel", func(t *testing.T) {
|
||||
routeToDatabase := tlsca.RouteToDatabase{
|
||||
ServiceName: pack.root.mysqlService.Name,
|
||||
Protocol: pack.root.mysqlService.Protocol,
|
||||
Username: "root",
|
||||
}
|
||||
clientTLSConfig, err := common.MakeTestClientTLSConfig(common.TestClientConfig{
|
||||
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
|
||||
AuthServer: pack.root.cluster.Process.GetAuthServer(),
|
||||
Cluster: pack.root.cluster.Secrets.SiteName,
|
||||
Username: pack.root.user.GetName(),
|
||||
RouteToDatabase: routeToDatabase,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
|
||||
RemoteProxyAddr: albProxy.Addr().String(),
|
||||
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL},
|
||||
ALPNConnUpgradeRequired: true,
|
||||
InsecureSkipVerify: true,
|
||||
Certs: clientTLSConfig.Certificates,
|
||||
})
|
||||
|
||||
client, err := mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Execute a query.
|
||||
result, err := client.Execute("select 1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, mysql.TestQueryResponse, result)
|
||||
|
||||
// Disconnect.
|
||||
require.NoError(t, client.Close())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestALPNSNIProxyAppAccess tests application access via ALPN SNI proxy service.
|
||||
|
|
|
@ -783,11 +783,7 @@ func applyProxyConfig(fc *FileConfig, cfg *service.Config) error {
|
|||
// was passed. If the certificate is not self-signed, verify the certificate
|
||||
// chain from leaf to root with the trust store on the computer so browsers
|
||||
// don't complain.
|
||||
certificateChainBytes, err := utils.ReadPath(p.Certificate)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
certificateChain, err := utils.ReadCertificateChain(certificateChainBytes)
|
||||
certificateChain, err := utils.ReadCertificatesFromPath(p.Certificate)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -1372,11 +1368,7 @@ func applyMetricsConfig(fc *FileConfig, cfg *service.Config) error {
|
|||
return trace.NotFound("metrics service cert does not exist: %s", p.Certificate)
|
||||
}
|
||||
|
||||
certificateChainBytes, err := utils.ReadPath(p.Certificate)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
certificateChain, err := utils.ReadCertificateChain(certificateChainBytes)
|
||||
certificateChain, err := utils.ReadCertificatesFromPath(p.Certificate)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -3402,6 +3402,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
|
|||
}
|
||||
|
||||
// Register web proxy server
|
||||
alpnHandlerForWeb := &alpnproxy.ConnectionHandlerWrapper{}
|
||||
var webServer *http.Server
|
||||
var webHandler *web.APIHandler
|
||||
var minimalWebServer *http.Server
|
||||
|
@ -3441,6 +3442,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
|
|||
ClusterFeatures: process.getClusterFeatures(),
|
||||
ProxySettings: proxySettings,
|
||||
PublicProxyAddr: process.proxyPublicAddr().Addr,
|
||||
ALPNHandler: alpnHandlerForWeb.HandleConnection,
|
||||
}
|
||||
webHandler, err = web.NewHandler(webConfig)
|
||||
if err != nil {
|
||||
|
@ -3810,6 +3812,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
|
|||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
alpnTLSConfigForWeb := process.setupALPNTLSConfigForWeb(serverTLSConfig, accessPoint, clusterName)
|
||||
alpnHandlerForWeb.Set(alpnServer.MakeConnectionHandler(alpnTLSConfigForWeb))
|
||||
|
||||
process.RegisterCriticalFunc("proxy.tls.alpn.sni.proxy", func() error {
|
||||
log.Infof("Starting TLS ALPN SNI proxy server on %v.", listeners.alpn.Addr())
|
||||
if err := alpnServer.Serve(process.ExitContext()); err != nil {
|
||||
|
@ -4022,10 +4028,6 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
|
|||
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, acme.ALPNProto))
|
||||
}
|
||||
|
||||
// Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized
|
||||
// the TLS handshake will fail.
|
||||
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...))
|
||||
|
||||
for _, pair := range process.Config.Proxy.KeyPairs {
|
||||
process.Config.Log.Infof("Loading TLS certificate %v and key %v.", pair.Certificate, pair.PrivateKey)
|
||||
|
||||
|
@ -4036,6 +4038,18 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
|
|||
tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
|
||||
}
|
||||
|
||||
setupTLSConfigALPNProtocols(tlsConfig)
|
||||
setupTLSConfigClientCAsForCluster(tlsConfig, accessPoint, clusterName)
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func setupTLSConfigALPNProtocols(tlsConfig *tls.Config) {
|
||||
// Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized
|
||||
// the TLS handshake will fail.
|
||||
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...))
|
||||
}
|
||||
|
||||
func setupTLSConfigClientCAsForCluster(tlsConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) {
|
||||
tlsConfig.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
tlsClone := tlsConfig.Clone()
|
||||
|
||||
|
@ -4061,10 +4075,18 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
|
|||
|
||||
return tlsClone, nil
|
||||
}
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
func setupALPNRouter(listeners *proxyListeners, serverTLSConf *tls.Config, cfg *Config) *alpnproxy.Router {
|
||||
func (process *TeleportProcess) setupALPNTLSConfigForWeb(serverTLSConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) *tls.Config {
|
||||
tlsConfig := utils.TLSConfig(process.Config.CipherSuites)
|
||||
tlsConfig.Certificates = serverTLSConfig.Certificates
|
||||
|
||||
setupTLSConfigALPNProtocols(tlsConfig)
|
||||
setupTLSConfigClientCAsForCluster(tlsConfig, accessPoint, clusterName)
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
func setupALPNRouter(listeners *proxyListeners, serverTLSConfig *tls.Config, cfg *Config) *alpnproxy.Router {
|
||||
if listeners.web == nil || cfg.Proxy.DisableTLS || cfg.Proxy.DisableALPNSNIListener {
|
||||
return nil
|
||||
}
|
||||
|
@ -4112,7 +4134,7 @@ func setupALPNRouter(listeners *proxyListeners, serverTLSConf *tls.Config, cfg *
|
|||
router.Add(alpnproxy.HandlerDecs{
|
||||
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH),
|
||||
Handler: sshProxyListener.HandleConnection,
|
||||
TLSConfig: serverTLSConf,
|
||||
TLSConfig: serverTLSConfig,
|
||||
})
|
||||
listeners.ssh = sshProxyListener
|
||||
|
||||
|
|
148
lib/srv/alpnproxy/conn_upgrade.go
Normal file
148
lib/srv/alpnproxy/conn_upgrade.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package alpnproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/gravitational/teleport/api/defaults"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
)
|
||||
|
||||
// IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP
|
||||
// connection upgrade for ALPN connections.
|
||||
//
|
||||
// The function makes a test connection to the Proxy Service and checks if the
|
||||
// ALPN is supported. If not, the Proxy Service is likely behind an AWS ALB or
|
||||
// some custom proxy services that strip out ALPN and SNI information on the
|
||||
// way to our Proxy Service.
|
||||
//
|
||||
// In those cases, the Teleport client should make a HTTP "upgrade" call to the
|
||||
// Proxy Service to establish a tunnel for the originally planned traffic to
|
||||
// preserve the ALPN and SNI information.
|
||||
func IsALPNConnUpgradeRequired(addr string, insecure bool) bool {
|
||||
netDialer := &net.Dialer{
|
||||
Timeout: defaults.DefaultDialTimeout,
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: []string{string(common.ProtocolReverseTunnel)},
|
||||
InsecureSkipVerify: insecure,
|
||||
}
|
||||
testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
// If dialing TLS fails for any reason, we assume connection upgrade is
|
||||
// not required so it will fallback to original connection method.
|
||||
//
|
||||
// This includes handshake failures where both peers support ALPN but
|
||||
// no common protocol is getting negotiated. We may have to revisit
|
||||
// this situation or make it configurable if we have to get through a
|
||||
// middleman with this behavior. For now, we are only interested in the
|
||||
// case where the middleman does not support ALPN.
|
||||
logrus.Infof("ALPN connection upgrade test failed for %q: %v.", addr, err)
|
||||
return false
|
||||
}
|
||||
defer testConn.Close()
|
||||
|
||||
// Upgrade required when ALPN is not supported on the remote side so
|
||||
// NegotiatedProtocol comes back as empty.
|
||||
result := testConn.ConnectionState().NegotiatedProtocol == ""
|
||||
logrus.Debugf("ALPN connection upgrade required for %q: %v.", addr, result)
|
||||
return result
|
||||
}
|
||||
|
||||
// alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then
|
||||
// tunnels the connection with this connection upgrade.
|
||||
type alpnConnUpgradeDialer struct {
|
||||
netDialer *net.Dialer
|
||||
insecure bool
|
||||
}
|
||||
|
||||
// newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer.
|
||||
func newALPNConnUpgradeDialer(keepAlivePeriod, dialTimeout time.Duration, insecure bool) ContextDialer {
|
||||
return &alpnConnUpgradeDialer{
|
||||
insecure: insecure,
|
||||
netDialer: &net.Dialer{
|
||||
KeepAlive: keepAlivePeriod,
|
||||
Timeout: dialTimeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer
|
||||
func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
logrus.Debugf("ALPN connection upgrade for %v.", addr)
|
||||
|
||||
tlsConn, err := tls.DialWithDialer(d.netDialer, network, addr, &tls.Config{
|
||||
InsecureSkipVerify: d.insecure,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = upgradeConnThroughWebAPI(tlsConn, url.URL{
|
||||
Host: addr,
|
||||
Scheme: "https",
|
||||
Path: teleport.WebAPIConnUpgrade,
|
||||
})
|
||||
if err != nil {
|
||||
defer tlsConn.Close()
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error {
|
||||
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// For now, only "alpn" is supported.
|
||||
req.Header.Add(teleport.WebAPIConnUpgradeHeader, teleport.WebAPIConnUpgradeTypeALPN)
|
||||
|
||||
// Send the request and check if upgrade is successful.
|
||||
if err = req.Write(conn); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if http.StatusSwitchingProtocols != resp.StatusCode {
|
||||
if http.StatusNotFound == resp.StatusCode {
|
||||
return trace.NotImplemented(
|
||||
"connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.",
|
||||
teleport.WebAPIConnUpgrade,
|
||||
resp.StatusCode,
|
||||
)
|
||||
}
|
||||
return trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
186
lib/srv/alpnproxy/conn_upgrade_test.go
Normal file
186
lib/srv/alpnproxy/conn_upgrade_test.go
Normal file
|
@ -0,0 +1,186 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package alpnproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
)
|
||||
|
||||
func TestIsALPNConnUpgradeRequired(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
serverProtos []string
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "upgrade required",
|
||||
serverProtos: nil, // Use nil for NextProtos to simulate no ALPN support.
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "upgrade not required (proto negotiated)",
|
||||
serverProtos: []string{string(common.ProtocolReverseTunnel)},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "upgrade not required (handshake error)",
|
||||
serverProtos: []string{"unknown"},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
server := mustStartMockALPNServer(t, test.serverProtos)
|
||||
require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(server.Addr().String(), true))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestALPNConUpgradeDialer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("connection upgraded", func(t *testing.T) {
|
||||
server := httptest.NewTLSServer(mockConnUpgradeHandler(t, "alpn", []byte("hello")))
|
||||
addr, err := url.Parse(server.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := newALPNConnUpgradeDialer(0, 5*time.Second, true)
|
||||
conn, err := dialer.DialContext(context.TODO(), "tcp", addr.Host)
|
||||
require.NoError(t, err)
|
||||
|
||||
data := make([]byte, 100)
|
||||
n, err := conn.Read(data)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, string(data[:n]), "hello")
|
||||
})
|
||||
|
||||
t.Run("connection upgrade API not found", func(t *testing.T) {
|
||||
server := httptest.NewTLSServer(http.NotFoundHandler())
|
||||
addr, err := url.Parse(server.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := newALPNConnUpgradeDialer(0, 5*time.Second, true)
|
||||
_, err = dialer.DialContext(context.TODO(), "tcp", addr.Host)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
type mockALPNServer struct {
|
||||
net.Listener
|
||||
cert tls.Certificate
|
||||
supportedProtos []string
|
||||
}
|
||||
|
||||
func (m *mockALPNServer) serve(ctx context.Context, t *testing.T) {
|
||||
config := &tls.Config{
|
||||
NextProtos: m.supportedProtos,
|
||||
Certificates: []tls.Certificate{m.cert},
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
conn, err := m.Accept()
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
clientConn := tls.Server(conn, config)
|
||||
clientConn.HandshakeContext(ctx)
|
||||
clientConn.Close()
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
listener.Close()
|
||||
})
|
||||
|
||||
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
|
||||
CommonName: "localhost",
|
||||
}, []string{"localhost"}, defaults.CATTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := tls.X509KeyPair(caCert, caKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
m := &mockALPNServer{
|
||||
Listener: listener,
|
||||
cert: cert,
|
||||
supportedProtos: supportedProtos,
|
||||
}
|
||||
go m.serve(ctx, t)
|
||||
return m
|
||||
}
|
||||
|
||||
// mockConnUpgradeHandler mocks the server side implementation to handle an
|
||||
// upgrade request and sends back some data inside the tunnel.
|
||||
func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, teleport.WebAPIConnUpgrade, r.URL.Path)
|
||||
require.Equal(t, upgradeType, r.Header.Get(teleport.WebAPIConnUpgradeHeader))
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
require.True(t, ok)
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Upgrade response.
|
||||
response := &http.Response{
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
}
|
||||
require.NoError(t, response.Write(conn))
|
||||
|
||||
// Upgraded.
|
||||
_, err = conn.Write(write)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
100
lib/srv/alpnproxy/dialer.go
Normal file
100
lib/srv/alpnproxy/dialer.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package alpnproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// ContextDialer represents network dialer interface that uses context
|
||||
type ContextDialer interface {
|
||||
// DialContext is a function that dials the specified address
|
||||
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// ALPNDialerConfig is the config for ALPNDialer.
|
||||
type ALPNDialerConfig struct {
|
||||
// KeepAlivePeriod defines period between keep alives.
|
||||
KeepAlivePeriod time.Duration
|
||||
// DialTimeout defines how long to attempt dialing before timing out.
|
||||
DialTimeout time.Duration
|
||||
// TLSConfig is the TLS config used for the TLS connection.
|
||||
TLSConfig *tls.Config
|
||||
// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required.
|
||||
ALPNConnUpgradeRequired bool
|
||||
}
|
||||
|
||||
// ALPNDialer is a ContextDialer that dials a connection to the Proxy Service
|
||||
// with ALPN and SNI configured in the provided TLSConfig. An ALPN connection
|
||||
// upgrade is also performed at the initial connection, if an upgrade is
|
||||
// required.
|
||||
type ALPNDialer struct {
|
||||
cfg ALPNDialerConfig
|
||||
}
|
||||
|
||||
// NewALPNDialer creates a new ALPNDialer.
|
||||
func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer {
|
||||
return &ALPNDialer{
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// DialContext implements ContextDialer.
|
||||
func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if d.cfg.TLSConfig == nil {
|
||||
return nil, trace.BadParameter("missing TLS config")
|
||||
}
|
||||
|
||||
var dialer ContextDialer = &net.Dialer{
|
||||
KeepAlive: d.cfg.KeepAlivePeriod,
|
||||
Timeout: d.cfg.DialTimeout,
|
||||
}
|
||||
if d.cfg.ALPNConnUpgradeRequired {
|
||||
dialer = newALPNConnUpgradeDialer(d.cfg.KeepAlivePeriod, d.cfg.DialTimeout, d.cfg.TLSConfig.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(conn, d.cfg.TLSConfig)
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
defer tlsConn.Close()
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// DialALPN a helper to dial using an ALPNDialer and returns a tls.Conn if
|
||||
// successful.
|
||||
func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn, error) {
|
||||
conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
tlsConn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return nil, trace.BadParameter("failed to convert to tls.Conn")
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
|
@ -39,6 +39,7 @@ import (
|
|||
"github.com/gravitational/teleport/api/types"
|
||||
"github.com/gravitational/teleport/lib/auth"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
)
|
||||
|
||||
|
@ -106,9 +107,10 @@ func (s *Suite) GetCertPool() *x509.CertPool {
|
|||
return pool
|
||||
}
|
||||
|
||||
func (s *Suite) Start(t *testing.T) {
|
||||
func (s *Suite) CreateProxyServer(t *testing.T) *Proxy {
|
||||
serverCert := mustGenCertSignedWithCA(t, s.ca)
|
||||
tlsConfig := &tls.Config{
|
||||
NextProtos: common.ProtocolsToString(common.SupportedProtocols),
|
||||
ClientAuth: tls.VerifyClientCertIfGiven,
|
||||
ClientCAs: s.GetCertPool(),
|
||||
Certificates: []tls.Certificate{
|
||||
|
@ -130,6 +132,11 @@ func (s *Suite) Start(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
// Reset GetConfigForClient to simplify test setup.
|
||||
svr.cfg.IdentityTLSConfig.GetConfigForClient = nil
|
||||
return svr
|
||||
}
|
||||
|
||||
func (s *Suite) Start(t *testing.T) {
|
||||
svr := s.CreateProxyServer(t)
|
||||
|
||||
go func() {
|
||||
err := svr.Serve(context.Background())
|
||||
|
|
|
@ -19,7 +19,7 @@ package alpnproxy
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
@ -64,13 +64,14 @@ type LocalProxyConfig struct {
|
|||
SSHHostKeyCallback ssh.HostKeyCallback
|
||||
// SSHTrustedCluster allows selecting trusted cluster ssh subsystem request.
|
||||
SSHTrustedCluster string
|
||||
// ClientTLSConfig is a client TLS configuration used during establishing
|
||||
// connection to the RemoteProxyAddr.
|
||||
ClientTLSConfig *tls.Config
|
||||
// Certs are the client certificates used to connect to the remote Teleport Proxy.
|
||||
Certs []tls.Certificate
|
||||
// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
|
||||
AWSCredentials *credentials.Credentials
|
||||
// RootCAs overwrites the root CAs used in tls.Config if specified.
|
||||
RootCAs *x509.CertPool
|
||||
// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required.
|
||||
ALPNConnUpgradeRequired bool
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults verifies the constraints for LocalProxyConfig.
|
||||
|
@ -129,7 +130,7 @@ func (l *LocalProxy) Start(ctx context.Context) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
go func() {
|
||||
if err := l.handleDownstreamConnection(ctx, conn, l.cfg.SNI); err != nil {
|
||||
if err := l.handleDownstreamConnection(ctx, conn); err != nil {
|
||||
if utils.IsOKNetworkError(err) {
|
||||
return
|
||||
}
|
||||
|
@ -146,14 +147,18 @@ func (l *LocalProxy) GetAddr() string {
|
|||
|
||||
// handleDownstreamConnection proxies the downstreamConn (connection established to the local proxy) and forward the
|
||||
// traffic to the upstreamConn (TLS connection to remote host).
|
||||
func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamConn net.Conn, serverName string) error {
|
||||
func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamConn net.Conn) error {
|
||||
defer downstreamConn.Close()
|
||||
|
||||
tlsConn, err := tls.Dial("tcp", l.cfg.RemoteProxyAddr, &tls.Config{
|
||||
NextProtos: l.cfg.GetProtocols(),
|
||||
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
|
||||
ServerName: serverName,
|
||||
Certificates: l.cfg.Certs,
|
||||
tlsConn, err := DialALPN(ctx, l.cfg.RemoteProxyAddr, ALPNDialerConfig{
|
||||
ALPNConnUpgradeRequired: l.cfg.ALPNConnUpgradeRequired,
|
||||
TLSConfig: &tls.Config{
|
||||
NextProtos: l.cfg.GetProtocols(),
|
||||
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
|
||||
ServerName: l.cfg.SNI,
|
||||
Certificates: l.cfg.Certs,
|
||||
RootCAs: l.cfg.RootCAs,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -166,32 +171,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC
|
|||
upstreamConn = NewPingConn(tlsConn)
|
||||
}
|
||||
|
||||
errC := make(chan error, 2)
|
||||
go func() {
|
||||
defer downstreamConn.Close()
|
||||
defer upstreamConn.Close()
|
||||
_, err := io.Copy(downstreamConn, upstreamConn)
|
||||
errC <- err
|
||||
}()
|
||||
go func() {
|
||||
defer downstreamConn.Close()
|
||||
defer upstreamConn.Close()
|
||||
_, err := io.Copy(upstreamConn, downstreamConn)
|
||||
errC <- err
|
||||
}()
|
||||
|
||||
var errs []error
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return trace.NewAggregate(append(errs, ctx.Err())...)
|
||||
case err := <-errC:
|
||||
if err != nil && !utils.IsOKNetworkError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return trace.NewAggregate(errs...)
|
||||
return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn))
|
||||
}
|
||||
|
||||
func (l *LocalProxy) Close() error {
|
||||
|
|
|
@ -324,7 +324,7 @@ func (p *Proxy) Serve(ctx context.Context) error {
|
|||
// For example in ReverseTunnel handles connection asynchronously and closing conn after
|
||||
// service handler returned will break service logic.
|
||||
// https://github.com/gravitational/teleport/blob/master/lib/sshutils/server.go#L397
|
||||
if err := p.handleConn(ctx, clientConn); err != nil {
|
||||
if err := p.handleConn(ctx, clientConn, nil); err != nil {
|
||||
if cerr := clientConn.Close(); cerr != nil && !utils.IsOKNetworkError(cerr) {
|
||||
p.log.WithError(cerr).Warnf("Failed to close client connection.")
|
||||
}
|
||||
|
@ -361,7 +361,7 @@ type HandlerFuncWithInfo func(ctx context.Context, conn net.Conn, info Connectio
|
|||
// 5) For backward compatibility check RouteToDatabase identity field
|
||||
// was set if yes forward to the generic TLS DB handler.
|
||||
// 6) Forward connection to the handler obtained in step 2.
|
||||
func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn) error {
|
||||
func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOverride *tls.Config) error {
|
||||
hello, conn, err := p.readHelloMessageWithoutTLSTermination(clientConn)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -381,7 +381,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn) error {
|
|||
return trace.Wrap(handlerDesc.handle(ctx, conn, connInfo))
|
||||
}
|
||||
|
||||
tlsConn := tls.Server(conn, p.getTLSConfig(handlerDesc))
|
||||
tlsConn := tls.Server(conn, p.getTLSConfig(handlerDesc, defaultOverride))
|
||||
if err := tlsConn.SetReadDeadline(p.cfg.Clock.Now().Add(p.cfg.ReadDeadline)); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
@ -437,12 +437,17 @@ func (p *Proxy) handlePingConnection(ctx context.Context, conn *tls.Conn) net.Co
|
|||
return pingConn
|
||||
}
|
||||
|
||||
// getTLSConfig returns HandlerDesc.TLSConfig if custom TLS configuration was set for the handler
|
||||
// otherwise the ProxyConfig.WebTLSConfig is used.
|
||||
func (p *Proxy) getTLSConfig(desc *HandlerDecs) *tls.Config {
|
||||
// getTLSConfig picks the TLS config with the following priority:
|
||||
// - TLS config found in the provided handler.
|
||||
// - A default override.
|
||||
// - The default TLS config (cfg.WebTLSConfig).
|
||||
func (p *Proxy) getTLSConfig(desc *HandlerDecs, defaultOverride *tls.Config) *tls.Config {
|
||||
if desc.TLSConfig != nil {
|
||||
return desc.TLSConfig
|
||||
}
|
||||
if defaultOverride != nil {
|
||||
return defaultOverride
|
||||
}
|
||||
return p.cfg.WebTLSConfig
|
||||
}
|
||||
|
||||
|
@ -578,3 +583,33 @@ func (p *Proxy) Close() error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeConnectionHandler creates a ConnectionHandler which provides a callback
|
||||
// to handle incoming connections by this ALPN proxy server.
|
||||
func (p *Proxy) MakeConnectionHandler(defaultOverride *tls.Config) ConnectionHandler {
|
||||
return func(ctx context.Context, conn net.Conn) error {
|
||||
return p.handleConn(ctx, conn, defaultOverride)
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectionHandler defines a function for serving incoming connections.
|
||||
type ConnectionHandler func(ctx context.Context, conn net.Conn) error
|
||||
|
||||
// ConnectionHandlerWrapper is a wrapper of ConnectionHandler. This wrapper is
|
||||
// mainly used as a placeholder to resolve circular dependencies.
|
||||
type ConnectionHandlerWrapper struct {
|
||||
h ConnectionHandler
|
||||
}
|
||||
|
||||
// Set updates inner ConnectionHandler to use.
|
||||
func (w *ConnectionHandlerWrapper) Set(h ConnectionHandler) {
|
||||
w.h = h
|
||||
}
|
||||
|
||||
// HandleConnection implements ConnectionHandler.
|
||||
func (w *ConnectionHandlerWrapper) HandleConnection(ctx context.Context, conn net.Conn) error {
|
||||
if w.h == nil {
|
||||
return trace.NotFound("missing ConnectionHandler")
|
||||
}
|
||||
return w.h(ctx, conn)
|
||||
}
|
||||
|
|
|
@ -17,13 +17,16 @@ limitations under the License.
|
|||
package alpnproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/teleport/api/constants"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
|
@ -335,6 +338,73 @@ func TestProxyHTTPConnection(t *testing.T) {
|
|||
mustSuccessfullyCallHTTPSServer(t, suite.GetServerAddress(), client)
|
||||
}
|
||||
|
||||
// TestProxyMakeConnectionHandler creates a ConnectionHandler from the ALPN
|
||||
// proxy server, and verifies ALPN protocol is properly handled through the
|
||||
// ConnectionHandler.
|
||||
func TestProxyMakeConnectionHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
suite := NewSuite(t)
|
||||
|
||||
// Create a HTTP server and register the listener to ALPN server.
|
||||
lw := NewMuxListenerWrapper(nil, suite.serverListener)
|
||||
mustStartHTTPServer(t, lw)
|
||||
|
||||
suite.router = NewRouter()
|
||||
suite.router.Add(HandlerDecs{
|
||||
MatchFunc: MatchByProtocol(common.ProtocolHTTP),
|
||||
Handler: lw.HandleConnection,
|
||||
})
|
||||
|
||||
svr := suite.CreateProxyServer(t)
|
||||
customCA := mustGenSelfSignedCert(t)
|
||||
|
||||
// Create a ConnectionHandler from the proxy server.
|
||||
alpnConnHandler := svr.MakeConnectionHandler(&tls.Config{
|
||||
NextProtos: []string{string(common.ProtocolHTTP)},
|
||||
Certificates: []tls.Certificate{
|
||||
mustGenCertSignedWithCA(t, customCA),
|
||||
},
|
||||
})
|
||||
|
||||
// Prepare net.Conn to be used for the created alpnConnHandler.
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
// Let alpnConnHandler serve the connection in a separate go routine.
|
||||
go func() {
|
||||
alpnConnHandler(timeoutCtx, serverConn)
|
||||
require.NoError(t, timeoutCtx.Err())
|
||||
}()
|
||||
|
||||
// Send client request.
|
||||
req, err := http.NewRequest("GET", "https://localhost/test", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use the customCA to validate default TLS config override.
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(customCA.Cert)
|
||||
|
||||
clientTLSConn := tls.Client(clientConn, &tls.Config{
|
||||
NextProtos: []string{string(common.ProtocolHTTP)},
|
||||
RootCAs: pool,
|
||||
ServerName: "localhost",
|
||||
})
|
||||
defer clientTLSConn.Close()
|
||||
|
||||
require.NoError(t, clientTLSConn.Handshake())
|
||||
require.Equal(t, string(common.ProtocolHTTP), clientTLSConn.ConnectionState().NegotiatedProtocol)
|
||||
require.NoError(t, req.Write(clientTLSConn))
|
||||
|
||||
response, err := http.ReadResponse(bufio.NewReader(clientTLSConn), req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, response.StatusCode)
|
||||
}
|
||||
|
||||
// TestProxyALPNProtocolsRouting tests the routing based on client TLS NextProtos values.
|
||||
func TestProxyALPNProtocolsRouting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -24,6 +24,7 @@ import (
|
|||
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/srv/db/common"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
"github.com/go-mysql-org/go-mysql/client"
|
||||
|
@ -54,6 +55,20 @@ func MakeTestClient(config common.TestClientConfig) (*client.Conn, error) {
|
|||
return conn, nil
|
||||
}
|
||||
|
||||
// MakeTestClientWithoutTLS returns a MySQL client connection without setting
|
||||
// TLS config to the MySQL client.
|
||||
func MakeTestClientWithoutTLS(addr string, routeToDatabase tlsca.RouteToDatabase) (*client.Conn, error) {
|
||||
conn, err := client.Connect(addr,
|
||||
routeToDatabase.Username,
|
||||
"",
|
||||
routeToDatabase.Database,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// TestServer is a test MySQL server used in functional database
|
||||
// access tests.
|
||||
type TestServer struct {
|
||||
|
|
|
@ -217,13 +217,12 @@ func IsSelfSigned(certificateChain []*x509.Certificate) bool {
|
|||
return bytes.Equal(certificateChain[0].SubjectKeyId, certificateChain[0].AuthorityKeyId)
|
||||
}
|
||||
|
||||
// ReadCertificateChain parses PEM encoded bytes that can contain one or
|
||||
// ReadCertificates parses PEM encoded bytes that can contain one or
|
||||
// multiple certificates and returns a slice of x509.Certificate.
|
||||
func ReadCertificateChain(certificateChainBytes []byte) ([]*x509.Certificate, error) {
|
||||
// build the certificate chain next
|
||||
func ReadCertificates(certificateChainBytes []byte) ([]*x509.Certificate, error) {
|
||||
var (
|
||||
certificateBlock *pem.Block
|
||||
certificateChain [][]byte
|
||||
certificates [][]byte
|
||||
)
|
||||
remainingBytes := bytes.TrimSpace(certificateChainBytes)
|
||||
|
||||
|
@ -232,29 +231,59 @@ func ReadCertificateChain(certificateChainBytes []byte) ([]*x509.Certificate, er
|
|||
if certificateBlock == nil || certificateBlock.Type != pemBlockCertificate {
|
||||
return nil, trace.NotFound("no PEM data found")
|
||||
}
|
||||
certificateChain = append(certificateChain, certificateBlock.Bytes)
|
||||
certificates = append(certificates, certificateBlock.Bytes)
|
||||
|
||||
if len(remainingBytes) == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// build a concatenated certificate chain
|
||||
// build concatenated certificates into a buffer
|
||||
var buf bytes.Buffer
|
||||
for _, cc := range certificateChain {
|
||||
for _, cc := range certificates {
|
||||
_, err := buf.Write(cc)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
// parse the chain and get a slice of x509.Certificates.
|
||||
x509Chain, err := x509.ParseCertificates(buf.Bytes())
|
||||
// parse the buffer and get a slice of x509.Certificates.
|
||||
x509Certs, err := x509.ParseCertificates(buf.Bytes())
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return x509Chain, nil
|
||||
return x509Certs, nil
|
||||
}
|
||||
|
||||
// ReadCertificatesFromPath parses PEM encoded certificates from provided path.
|
||||
func ReadCertificatesFromPath(path string) ([]*x509.Certificate, error) {
|
||||
bytes, err := ReadPath(path)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
certs, err := ReadCertificates(bytes)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
// NewCertPoolFromPath creates a new x509.CertPool from provided path.
|
||||
func NewCertPoolFromPath(path string) (*x509.CertPool, error) {
|
||||
// x509.CertPool.AppendCertsFromPEM skips parse errors. Using our own
|
||||
// implementation here to be more strict.
|
||||
cas, err := ReadCertificatesFromPath(path)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
for _, ca := range cas {
|
||||
pool.AddCert(ca)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
const pemBlockCertificate = "CERTIFICATE"
|
||||
|
|
|
@ -17,7 +17,6 @@ limitations under the License.
|
|||
package utils
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
|
@ -31,17 +30,14 @@ import (
|
|||
func TestRejectsInvalidPEMData(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := ReadCertificateChain([]byte("no data"))
|
||||
_, err := ReadCertificates([]byte("no data"))
|
||||
require.IsType(t, trace.Unwrap(err), &trace.NotFoundError{})
|
||||
}
|
||||
|
||||
func TestRejectsSelfSignedCertificate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
certificateChainBytes, err := os.ReadFile("../../fixtures/certs/ca.pem")
|
||||
require.NoError(t, err)
|
||||
|
||||
certificateChain, err := ReadCertificateChain(certificateChainBytes)
|
||||
certificateChain, err := ReadCertificatesFromPath("../../fixtures/certs/ca.pem")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = VerifyCertificateChain(certificateChain)
|
||||
|
@ -52,3 +48,11 @@ func TestRejectsSelfSignedCertificate(t *testing.T) {
|
|||
require.ErrorContains(t, err, "x509: certificate signed by unknown authority")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCertPoolFromPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool, err := NewCertPoolFromPath("../../fixtures/certs/ca.pem")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pool.Subjects(), 1)
|
||||
}
|
||||
|
|
|
@ -200,6 +200,10 @@ type Config struct {
|
|||
// PublicProxyAddr is used to template the public proxy address
|
||||
// into the installer script responses
|
||||
PublicProxyAddr string
|
||||
|
||||
// ALPNHandler is the ALPN connection handler for handling upgraded ALPN
|
||||
// connection through a HTTP upgrade call.
|
||||
ALPNHandler ConnectionHandler
|
||||
}
|
||||
|
||||
type APIHandler struct {
|
||||
|
@ -209,6 +213,9 @@ type APIHandler struct {
|
|||
appHandler *app.Handler
|
||||
}
|
||||
|
||||
// ConnectionHandler defines a function for serving incoming connections.
|
||||
type ConnectionHandler func(ctx context.Context, conn net.Conn) error
|
||||
|
||||
// Check if this request should be forwarded to an application handler to
|
||||
// be handled by the UI and handle the request appropriately.
|
||||
func (h *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -609,6 +616,9 @@ func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) {
|
|||
h.GET("/webapi/sites/:site/diagnostics/connections/:connectionid", h.WithClusterAuth(h.getConnectionDiagnostic))
|
||||
// Diagnose a Connection
|
||||
h.POST("/webapi/sites/:site/diagnostics/connections", h.WithClusterAuth(h.diagnoseConnection))
|
||||
|
||||
// Connection upgrades.
|
||||
h.GET("/webapi/connectionupgrade", httplib.MakeHandler(h.connectionUpgrade))
|
||||
}
|
||||
|
||||
// GetProxyClient returns authenticated auth server client
|
||||
|
|
131
lib/web/conn_upgrade.go
Normal file
131
lib/web/conn_upgrade.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
)
|
||||
|
||||
// selectConnectionUpgrade selects the requested upgrade type and returns the
|
||||
// corresponding handler.
|
||||
func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHandler, error) {
|
||||
upgrades := r.Header.Values(teleport.WebAPIConnUpgradeHeader)
|
||||
for _, upgradeType := range upgrades {
|
||||
switch upgradeType {
|
||||
case teleport.WebAPIConnUpgradeTypeALPN:
|
||||
return upgradeType, h.upgradeALPN, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil, trace.BadParameter("unsupported upgrade types: %v", upgrades)
|
||||
}
|
||||
|
||||
// connectionUpgrade handles connection upgrades.
|
||||
func (h *Handler) connectionUpgrade(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
|
||||
upgradeType, upgradeHandler, err := h.selectConnectionUpgrade(r)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, trace.BadParameter("failed to hijack connection")
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Since w is hijacked, there is no point returning an error for response
|
||||
// starting at this point.
|
||||
if err := writeUpgradeResponse(conn, upgradeType); err != nil {
|
||||
h.log.WithError(err).Error("Failed to write upgrade response.")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if err := upgradeHandler(r.Context(), conn); err != nil && !utils.IsOKNetworkError(err) {
|
||||
h.log.WithError(err).Errorf("Failed to handle %v upgrade request.", upgradeType)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error {
|
||||
if h.cfg.ALPNHandler == nil {
|
||||
return trace.BadParameter("missing ALPNHandler")
|
||||
}
|
||||
|
||||
// ALPNHandler may handle some connections asynchronously. Here we want to
|
||||
// block until the handling is done by waiting until the connection is
|
||||
// closed.
|
||||
waitConn := newWaitConn(ctx, conn)
|
||||
defer waitConn.WaitForClose()
|
||||
|
||||
return h.cfg.ALPNHandler(ctx, waitConn)
|
||||
}
|
||||
|
||||
func writeUpgradeResponse(w io.Writer, upgradeType string) error {
|
||||
header := make(http.Header)
|
||||
header.Add(teleport.WebAPIConnUpgradeHeader, upgradeType)
|
||||
response := &http.Response{
|
||||
Status: http.StatusText(http.StatusSwitchingProtocols),
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: header,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
}
|
||||
return response.Write(w)
|
||||
}
|
||||
|
||||
// waitConn is a net.Conn that provides a "WaitForClose" function to wait until
|
||||
// the connection is closed.
|
||||
type waitConn struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// newWaitConn creates a new waitConn.
|
||||
func newWaitConn(ctx context.Context, conn net.Conn) *waitConn {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &waitConn{
|
||||
Conn: conn,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForClose blocks until the Close() function of this connection is called.
|
||||
func (conn *waitConn) WaitForClose() {
|
||||
<-conn.ctx.Done()
|
||||
}
|
||||
|
||||
// Close implements net.Conn.
|
||||
func (conn *waitConn) Close() error {
|
||||
err := conn.Conn.Close()
|
||||
conn.cancel()
|
||||
return trace.Wrap(err)
|
||||
}
|
126
lib/web/conn_upgrade_test.go
Normal file
126
lib/web/conn_upgrade_test.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
/*
|
||||
Copyright 2022 Gravitational, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteUpgradeResponse(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
require.NoError(t, writeUpgradeResponse(&buf, "custom"))
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(&buf), nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols)
|
||||
require.Equal(t, "custom", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
|
||||
func TestHandlerConnectionUpgrade(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
expectedPayload := "hello@"
|
||||
alpnHandler := func(_ context.Context, conn net.Conn) error {
|
||||
// Handles connection asynchronously to verify web handler waits until
|
||||
// connection is closed.
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
n, err := conn.Write([]byte(expectedPayload))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(expectedPayload), n)
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cherry picked some attributes to create a Handler to test only the
|
||||
// connection upgrade portion.
|
||||
h := &Handler{
|
||||
cfg: Config{
|
||||
ALPNHandler: alpnHandler,
|
||||
},
|
||||
log: newPackageLogger(),
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
|
||||
t.Run("unsupported type", func(t *testing.T) {
|
||||
r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Add("Upgrade", "unsupported-protocol")
|
||||
|
||||
_, err = h.connectionUpgrade(httptest.NewRecorder(), r, nil)
|
||||
require.True(t, trace.IsBadParameter(err))
|
||||
})
|
||||
|
||||
t.Run("upgraded to ALPN", func(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil)
|
||||
require.NoError(t, err)
|
||||
r.Header.Add("Upgrade", "alpn")
|
||||
|
||||
go func() {
|
||||
// serverConn will be hijacked.
|
||||
w := newResponseWriterHijacker(nil, serverConn)
|
||||
_, err = h.connectionUpgrade(w, r, nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Verify clientConn receives http.StatusSwitchingProtocols.
|
||||
clientConnReader := bufio.NewReader(clientConn)
|
||||
response, err := http.ReadResponse(clientConnReader, r)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, response.StatusCode)
|
||||
|
||||
// Verify clientConn receives data sent by Config.ALPNHandler.
|
||||
receive, err := clientConnReader.ReadString(byte('@'))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expectedPayload, receive)
|
||||
})
|
||||
}
|
||||
|
||||
// responseWriterHijacker is a mock http.ResponseWriter that also serves a
|
||||
// net.Conn for http.Hijacker.
|
||||
type responseWriterHijacker struct {
|
||||
http.ResponseWriter
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func newResponseWriterHijacker(w http.ResponseWriter, conn net.Conn) *responseWriterHijacker {
|
||||
if w == nil {
|
||||
w = httptest.NewRecorder()
|
||||
}
|
||||
return &responseWriterHijacker{
|
||||
ResponseWriter: w,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (h responseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return h.conn, nil, nil
|
||||
}
|
|
@ -35,6 +35,7 @@ import (
|
|||
"github.com/gravitational/teleport/lib/client/db/dbcmd"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy"
|
||||
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
|
||||
"github.com/gravitational/teleport/lib/tlsca"
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
@ -555,7 +556,7 @@ func serializeDatabaseConfig(configInfo *dbConfigInfo, format string) (string, e
|
|||
// connection scenario and returns a list of options to use in the connect
|
||||
// command.
|
||||
func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, db *tlsca.RouteToDatabase,
|
||||
database types.Database, cluster string,
|
||||
database types.Database, rootClusterName string,
|
||||
) ([]dbcmd.ConnectCommandFunc, error) {
|
||||
if !isLocalProxyRequiredForDatabase(tc, db) {
|
||||
return []dbcmd.ConnectCommandFunc{}, nil
|
||||
|
@ -582,6 +583,7 @@ func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *clien
|
|||
database: database,
|
||||
listener: listener,
|
||||
localProxyTunnel: localProxyTunnel,
|
||||
rootClusterName: rootClusterName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
@ -609,7 +611,7 @@ func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *clien
|
|||
// validation, so connect to localhost.
|
||||
host := "localhost"
|
||||
return []dbcmd.ConnectCommandFunc{
|
||||
dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(cluster)),
|
||||
dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(rootClusterName)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -625,6 +627,7 @@ type localProxyConfig struct {
|
|||
// it's always true for Snowflake database. Value is copied here to not modify
|
||||
// cli arguments directly.
|
||||
localProxyTunnel bool
|
||||
rootClusterName string
|
||||
}
|
||||
|
||||
// prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig.
|
||||
|
@ -639,12 +642,23 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) {
|
|||
}
|
||||
|
||||
opts := localProxyOpts{
|
||||
proxyAddr: arg.teleportClient.WebProxyAddr,
|
||||
listener: arg.listener,
|
||||
protocols: []common.Protocol{common.Protocol(arg.routeToDatabase.Protocol)},
|
||||
insecure: arg.cliConf.InsecureSkipVerify,
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
proxyAddr: arg.teleportClient.WebProxyAddr,
|
||||
listener: arg.listener,
|
||||
protocols: []common.Protocol{common.Protocol(arg.routeToDatabase.Protocol)},
|
||||
insecure: arg.cliConf.InsecureSkipVerify,
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
alpnConnUpgradeRequired: alpnproxy.IsALPNConnUpgradeRequired(arg.teleportClient.WebProxyAddr, arg.cliConf.InsecureSkipVerify),
|
||||
}
|
||||
|
||||
// If ALPN connection upgrade is required, explicitly use the profile CAs
|
||||
// since the tunneled TLS routing connection serves the Host cert.
|
||||
if opts.alpnConnUpgradeRequired {
|
||||
profileCAs, err := utils.NewCertPoolFromPath(arg.profile.CACertPathForCluster(arg.rootClusterName))
|
||||
if err != nil {
|
||||
return localProxyOpts{}, trace.Wrap(err)
|
||||
}
|
||||
opts.rootCAs = profileCAs
|
||||
}
|
||||
|
||||
// For SQL Server connections, local proxy must be configured with the
|
||||
|
@ -705,11 +719,7 @@ func onDatabaseConnect(cf *CLIConf) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
key, err := tc.LocalAgent().GetCoreKey()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
rootClusterName, err := key.RootClusterName()
|
||||
rootClusterName, err := tc.RootClusterName(cf.Context)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -363,6 +363,7 @@ func onProxyCommandDB(cf *CLIConf) error {
|
|||
routeToDatabase: routeToDatabase,
|
||||
listener: listener,
|
||||
localProxyTunnel: cf.LocalProxyTunnel,
|
||||
rootClusterName: rootCluster,
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -422,12 +423,14 @@ func onProxyCommandDB(cf *CLIConf) error {
|
|||
}
|
||||
|
||||
type localProxyOpts struct {
|
||||
proxyAddr string
|
||||
listener net.Listener
|
||||
protocols []alpncommon.Protocol
|
||||
insecure bool
|
||||
certFile string
|
||||
keyFile string
|
||||
proxyAddr string
|
||||
listener net.Listener
|
||||
protocols []alpncommon.Protocol
|
||||
insecure bool
|
||||
certFile string
|
||||
keyFile string
|
||||
rootCAs *x509.CertPool
|
||||
alpnConnUpgradeRequired bool
|
||||
}
|
||||
|
||||
// protocol returns the first protocol or string if configuration doesn't contain any protocols.
|
||||
|
@ -458,13 +461,15 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro
|
|||
}
|
||||
|
||||
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
|
||||
InsecureSkipVerify: opts.insecure,
|
||||
RemoteProxyAddr: opts.proxyAddr,
|
||||
Protocols: protocols,
|
||||
Listener: opts.listener,
|
||||
ParentContext: ctx,
|
||||
SNI: address.Host(),
|
||||
Certs: certs,
|
||||
InsecureSkipVerify: opts.insecure,
|
||||
RemoteProxyAddr: opts.proxyAddr,
|
||||
Protocols: protocols,
|
||||
Listener: opts.listener,
|
||||
ParentContext: ctx,
|
||||
SNI: address.Host(),
|
||||
Certs: certs,
|
||||
RootCAs: opts.rootCAs,
|
||||
ALPNConnUpgradeRequired: opts.alpnConnUpgradeRequired,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
|
Loading…
Reference in a new issue