ALPN connection upgrade for MySQL behind ALB (#15669)

This commit is contained in:
STeve (Xin) Huang 2022-09-01 12:05:03 -04:00 committed by GitHub
parent 09cd4bfdd8
commit 8394f4fb48
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 1212 additions and 111 deletions

View file

@ -774,3 +774,15 @@ const UserSingleUseCertTTL = time.Minute
// StandardHTTPSPort is the default port used for the https URI scheme,
// cf. RFC 7230 § 2.7.2.
const StandardHTTPSPort = 443
const (
// WebAPIConnUpgrade is the HTTP web API to make the connection upgrade
// call.
WebAPIConnUpgrade = "/webapi/connectionupgrade"
// WebAPIConnUpgradeHeader is the header used to indicate the requested
// connection upgrade types in the connection upgrade API.
WebAPIConnUpgradeHeader = "Upgrade"
// WebAPIConnUpgradeTypeALPN is a connection upgrade type that specifies
// the upgraded connection should be handled by the ALPN handler.
WebAPIConnUpgradeTypeALPN = "alpn"
)

View file

@ -19,6 +19,8 @@ package integration
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509/pkix"
"encoding/json"
"net"
"net/http"
@ -33,11 +35,13 @@ import (
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
@ -470,20 +474,38 @@ func mustCreateKubeConfigFile(t *testing.T, config clientcmdapi.Config) string {
return configPath
}
func mustStartALPNLocalProxy(t *testing.T, addr string, protocol alpncommon.Protocol) *alpnproxy.LocalProxy {
func mustCreateListener(t *testing.T) net.Listener {
listener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
address, err := utils.ParseAddr(addr)
require.NoError(t, err)
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
t.Cleanup(func() {
listener.Close()
})
return listener
}
func mustStartALPNLocalProxy(t *testing.T, addr string, protocol alpncommon.Protocol) *alpnproxy.LocalProxy {
return mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
RemoteProxyAddr: addr,
Protocols: []alpncommon.Protocol{protocol},
InsecureSkipVerify: true,
Listener: listener,
ParentContext: context.Background(),
SNI: address.Host(),
})
}
func mustStartALPNLocalProxyWithConfig(t *testing.T, config alpnproxy.LocalProxyConfig) *alpnproxy.LocalProxy {
if config.Listener == nil {
config.Listener = mustCreateListener(t)
}
if config.SNI == "" {
address, err := utils.ParseAddr(config.RemoteProxyAddr)
require.NoError(t, err)
config.SNI = address.Host()
}
if config.ParentContext == nil {
config.ParentContext = context.TODO()
}
lp, err := alpnproxy.NewLocalProxy(config)
require.NoError(t, err)
t.Cleanup(func() {
lp.Close()
@ -512,3 +534,71 @@ func makeNodeConfig(nodeName, authAddr string) *service.Config {
nodeConfig.CircuitBreakerConfig = breaker.NoopBreakerConfig()
return nodeConfig
}
func mustCreateSelfSignedCert(t *testing.T) tls.Certificate {
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
CommonName: "localhost",
}, []string{"localhost"}, defaults.CATTL)
require.NoError(t, err)
cert, err := tls.X509KeyPair(caCert, caKey)
require.NoError(t, err)
return cert
}
// mockAWSALBProxy is a mock proxy server that simulates an AWS application
// load balancer where ALPN is not supported. Note that this mock does not
// actually balance traffic.
type mockAWSALBProxy struct {
net.Listener
proxyAddr string
cert tls.Certificate
}
func (m *mockAWSALBProxy) serve(ctx context.Context, t *testing.T) {
for {
select {
case <-ctx.Done():
return
default:
}
conn, err := m.Accept()
if err != nil {
if utils.IsOKNetworkError(err) {
return
}
require.NoError(t, err)
return
}
go func() {
// Handshake with incoming client and drops ALPN.
downstreamConn := tls.Server(conn, &tls.Config{
Certificates: []tls.Certificate{m.cert},
})
require.NoError(t, downstreamConn.HandshakeContext(ctx))
// Make a connection to the proxy server with ALPN protos.
upstreamConn, err := tls.Dial("tcp", m.proxyAddr, &tls.Config{
InsecureSkipVerify: true,
})
require.NoError(t, err)
utils.ProxyConn(ctx, downstreamConn, upstreamConn)
}()
}
}
func mustStartMockALBProxy(t *testing.T, proxyAddr string) *mockAWSALBProxy {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
m := &mockAWSALBProxy{
proxyAddr: proxyAddr,
Listener: mustCreateListener(t),
cert: mustCreateSelfSignedCert(t),
}
go m.serve(ctx, t)
return m
}

View file

@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb"
@ -434,8 +435,8 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
// Disconnect.
err = client.Close()
require.NoError(t, err)
})
t.Run("connect to leaf cluster via proxy", func(t *testing.T) {
client, err := mysql.MakeTestClient(common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
@ -620,6 +621,134 @@ func TestALPNSNIProxyDatabaseAccess(t *testing.T) {
require.NoError(t, err)
})
})
// Simulate situations where an AWS ALB is between client and the Teleport
// Proxy service, which drops ALPN along the way. The ALPN local proxy will
// need to make a connection upgrade first through a web API provided by
// the Proxy server and then tunnel the original ALPN/TLS routing traffic
// inside this tunnel.
t.Run("ALPN connection upgrade", func(t *testing.T) {
// Make a mock ALB which points to the Teleport Proxy Service. Then
// ALPN local proxies will point to this ALB instead.
albProxy := mustStartMockALBProxy(t, pack.root.cluster.Web)
// Test a protocol in the alpncommon.IsDBTLSProtocol list where
// the database client will perform a native TLS handshake.
//
// Packet layers:
// - HTTPS served by Teleport web server for connection upgrade
// - TLS routing with alpncommon.ProtocolMongoDB (no client cert)
// - TLS with client cert (provided by the database client)
// - MongoDB
t.Run("database client native TLS", func(t *testing.T) {
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
RemoteProxyAddr: albProxy.Addr().String(),
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMongoDB},
ALPNConnUpgradeRequired: true,
InsecureSkipVerify: true,
})
client, err := mongodb.MakeTestClient(context.Background(), common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: lp.GetAddr(),
Cluster: pack.root.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.root.mongoService.Name,
Protocol: pack.root.mongoService.Protocol,
Username: "admin",
},
})
require.NoError(t, err)
// Execute a query.
_, err = client.Database("test").Collection("test").Find(context.Background(), bson.M{})
require.NoError(t, err)
// Disconnect.
require.NoError(t, client.Disconnect(context.Background()))
})
// Test the case where the database client cert is terminated within
// the database protocol.
//
// Packet layers:
// - HTTPS served by Teleport web server for connection upgrade
// - TLS routing with alpncommon.ProtocolMySQL (no client cert)
// - MySQL handshake then upgrade to TLS with Teleport issued client cert
// - MySQL protocol
t.Run("MySQL custom TLS", func(t *testing.T) {
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
RemoteProxyAddr: albProxy.Addr().String(),
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL},
ALPNConnUpgradeRequired: true,
InsecureSkipVerify: true,
})
client, err := mysql.MakeTestClient(common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Address: lp.GetAddr(),
Cluster: pack.root.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: tlsca.RouteToDatabase{
ServiceName: pack.root.mysqlService.Name,
Protocol: pack.root.mysqlService.Protocol,
Username: "root",
},
})
require.NoError(t, err)
// Execute a query.
result, err := client.Execute("select 1")
require.NoError(t, err)
require.Equal(t, mysql.TestQueryResponse, result)
// Disconnect.
require.NoError(t, client.Close())
})
// Test the case where the client cert is terminated by Teleport and
// the database client sends data in plain database protocol.
//
// Packet layers:
// - HTTPS served by Teleport web server for connection upgrade
// - TLS routing with alpncommon.ProtocolMySQL (client cert provided by ALPN local proxy)
// - MySQL protocol
t.Run("authenticated tunnel", func(t *testing.T) {
routeToDatabase := tlsca.RouteToDatabase{
ServiceName: pack.root.mysqlService.Name,
Protocol: pack.root.mysqlService.Protocol,
Username: "root",
}
clientTLSConfig, err := common.MakeTestClientTLSConfig(common.TestClientConfig{
AuthClient: pack.root.cluster.GetSiteAPI(pack.root.cluster.Secrets.SiteName),
AuthServer: pack.root.cluster.Process.GetAuthServer(),
Cluster: pack.root.cluster.Secrets.SiteName,
Username: pack.root.user.GetName(),
RouteToDatabase: routeToDatabase,
})
require.NoError(t, err)
lp := mustStartALPNLocalProxyWithConfig(t, alpnproxy.LocalProxyConfig{
RemoteProxyAddr: albProxy.Addr().String(),
Protocols: []alpncommon.Protocol{alpncommon.ProtocolMySQL},
ALPNConnUpgradeRequired: true,
InsecureSkipVerify: true,
Certs: clientTLSConfig.Certificates,
})
client, err := mysql.MakeTestClientWithoutTLS(lp.GetAddr(), routeToDatabase)
require.NoError(t, err)
// Execute a query.
result, err := client.Execute("select 1")
require.NoError(t, err)
require.Equal(t, mysql.TestQueryResponse, result)
// Disconnect.
require.NoError(t, client.Close())
})
})
}
// TestALPNSNIProxyAppAccess tests application access via ALPN SNI proxy service.

View file

@ -783,11 +783,7 @@ func applyProxyConfig(fc *FileConfig, cfg *service.Config) error {
// was passed. If the certificate is not self-signed, verify the certificate
// chain from leaf to root with the trust store on the computer so browsers
// don't complain.
certificateChainBytes, err := utils.ReadPath(p.Certificate)
if err != nil {
return trace.Wrap(err)
}
certificateChain, err := utils.ReadCertificateChain(certificateChainBytes)
certificateChain, err := utils.ReadCertificatesFromPath(p.Certificate)
if err != nil {
return trace.Wrap(err)
}
@ -1372,11 +1368,7 @@ func applyMetricsConfig(fc *FileConfig, cfg *service.Config) error {
return trace.NotFound("metrics service cert does not exist: %s", p.Certificate)
}
certificateChainBytes, err := utils.ReadPath(p.Certificate)
if err != nil {
return trace.Wrap(err)
}
certificateChain, err := utils.ReadCertificateChain(certificateChainBytes)
certificateChain, err := utils.ReadCertificatesFromPath(p.Certificate)
if err != nil {
return trace.Wrap(err)
}

View file

@ -3402,6 +3402,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
// Register web proxy server
alpnHandlerForWeb := &alpnproxy.ConnectionHandlerWrapper{}
var webServer *http.Server
var webHandler *web.APIHandler
var minimalWebServer *http.Server
@ -3441,6 +3442,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
ClusterFeatures: process.getClusterFeatures(),
ProxySettings: proxySettings,
PublicProxyAddr: process.proxyPublicAddr().Addr,
ALPNHandler: alpnHandlerForWeb.HandleConnection,
}
webHandler, err = web.NewHandler(webConfig)
if err != nil {
@ -3810,6 +3812,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
if err != nil {
return trace.Wrap(err)
}
alpnTLSConfigForWeb := process.setupALPNTLSConfigForWeb(serverTLSConfig, accessPoint, clusterName)
alpnHandlerForWeb.Set(alpnServer.MakeConnectionHandler(alpnTLSConfigForWeb))
process.RegisterCriticalFunc("proxy.tls.alpn.sni.proxy", func() error {
log.Infof("Starting TLS ALPN SNI proxy server on %v.", listeners.alpn.Addr())
if err := alpnServer.Serve(process.ExitContext()); err != nil {
@ -4022,10 +4028,6 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, acme.ALPNProto))
}
// Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized
// the TLS handshake will fail.
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...))
for _, pair := range process.Config.Proxy.KeyPairs {
process.Config.Log.Infof("Loading TLS certificate %v and key %v.", pair.Certificate, pair.PrivateKey)
@ -4036,6 +4038,18 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
tlsConfig.Certificates = append(tlsConfig.Certificates, certificate)
}
setupTLSConfigALPNProtocols(tlsConfig)
setupTLSConfigClientCAsForCluster(tlsConfig, accessPoint, clusterName)
return tlsConfig, nil
}
func setupTLSConfigALPNProtocols(tlsConfig *tls.Config) {
// Go 1.17 introduced strict ALPN https://golang.org/doc/go1.17#ALPN If a client protocol is not recognized
// the TLS handshake will fail.
tlsConfig.NextProtos = apiutils.Deduplicate(append(tlsConfig.NextProtos, alpncommon.ProtocolsToString(alpncommon.SupportedProtocols)...))
}
func setupTLSConfigClientCAsForCluster(tlsConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) {
tlsConfig.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
tlsClone := tlsConfig.Clone()
@ -4061,10 +4075,18 @@ func (process *TeleportProcess) setupProxyTLSConfig(conn *Connector, tsrv revers
return tlsClone, nil
}
return tlsConfig, nil
}
func setupALPNRouter(listeners *proxyListeners, serverTLSConf *tls.Config, cfg *Config) *alpnproxy.Router {
func (process *TeleportProcess) setupALPNTLSConfigForWeb(serverTLSConfig *tls.Config, accessPoint auth.ReadProxyAccessPoint, clusterName string) *tls.Config {
tlsConfig := utils.TLSConfig(process.Config.CipherSuites)
tlsConfig.Certificates = serverTLSConfig.Certificates
setupTLSConfigALPNProtocols(tlsConfig)
setupTLSConfigClientCAsForCluster(tlsConfig, accessPoint, clusterName)
return tlsConfig
}
func setupALPNRouter(listeners *proxyListeners, serverTLSConfig *tls.Config, cfg *Config) *alpnproxy.Router {
if listeners.web == nil || cfg.Proxy.DisableTLS || cfg.Proxy.DisableALPNSNIListener {
return nil
}
@ -4112,7 +4134,7 @@ func setupALPNRouter(listeners *proxyListeners, serverTLSConf *tls.Config, cfg *
router.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByProtocol(alpncommon.ProtocolProxySSH),
Handler: sshProxyListener.HandleConnection,
TLSConfig: serverTLSConf,
TLSConfig: serverTLSConfig,
})
listeners.ssh = sshProxyListener

View file

@ -0,0 +1,148 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package alpnproxy
import (
"bufio"
"context"
"crypto/tls"
"net"
"net/http"
"net/url"
"time"
"github.com/gravitational/teleport"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
)
// IsALPNConnUpgradeRequired returns true if a tunnel is required through a HTTP
// connection upgrade for ALPN connections.
//
// The function makes a test connection to the Proxy Service and checks if the
// ALPN is supported. If not, the Proxy Service is likely behind an AWS ALB or
// some custom proxy services that strip out ALPN and SNI information on the
// way to our Proxy Service.
//
// In those cases, the Teleport client should make a HTTP "upgrade" call to the
// Proxy Service to establish a tunnel for the originally planned traffic to
// preserve the ALPN and SNI information.
func IsALPNConnUpgradeRequired(addr string, insecure bool) bool {
netDialer := &net.Dialer{
Timeout: defaults.DefaultDialTimeout,
}
tlsConfig := &tls.Config{
NextProtos: []string{string(common.ProtocolReverseTunnel)},
InsecureSkipVerify: insecure,
}
testConn, err := tls.DialWithDialer(netDialer, "tcp", addr, tlsConfig)
if err != nil {
// If dialing TLS fails for any reason, we assume connection upgrade is
// not required so it will fallback to original connection method.
//
// This includes handshake failures where both peers support ALPN but
// no common protocol is getting negotiated. We may have to revisit
// this situation or make it configurable if we have to get through a
// middleman with this behavior. For now, we are only interested in the
// case where the middleman does not support ALPN.
logrus.Infof("ALPN connection upgrade test failed for %q: %v.", addr, err)
return false
}
defer testConn.Close()
// Upgrade required when ALPN is not supported on the remote side so
// NegotiatedProtocol comes back as empty.
result := testConn.ConnectionState().NegotiatedProtocol == ""
logrus.Debugf("ALPN connection upgrade required for %q: %v.", addr, result)
return result
}
// alpnConnUpgradeDialer makes an "HTTP" upgrade call to the Proxy Service then
// tunnels the connection with this connection upgrade.
type alpnConnUpgradeDialer struct {
netDialer *net.Dialer
insecure bool
}
// newALPNConnUpgradeDialer creates a new alpnConnUpgradeDialer.
func newALPNConnUpgradeDialer(keepAlivePeriod, dialTimeout time.Duration, insecure bool) ContextDialer {
return &alpnConnUpgradeDialer{
insecure: insecure,
netDialer: &net.Dialer{
KeepAlive: keepAlivePeriod,
Timeout: dialTimeout,
},
}
}
// DialContext implements ContextDialer
func (d alpnConnUpgradeDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
logrus.Debugf("ALPN connection upgrade for %v.", addr)
tlsConn, err := tls.DialWithDialer(d.netDialer, network, addr, &tls.Config{
InsecureSkipVerify: d.insecure,
})
if err != nil {
return nil, trace.Wrap(err)
}
err = upgradeConnThroughWebAPI(tlsConn, url.URL{
Host: addr,
Scheme: "https",
Path: teleport.WebAPIConnUpgrade,
})
if err != nil {
defer tlsConn.Close()
return nil, trace.Wrap(err)
}
return tlsConn, nil
}
func upgradeConnThroughWebAPI(conn net.Conn, api url.URL) error {
req, err := http.NewRequest(http.MethodGet, api.String(), nil)
if err != nil {
return trace.Wrap(err)
}
// For now, only "alpn" is supported.
req.Header.Add(teleport.WebAPIConnUpgradeHeader, teleport.WebAPIConnUpgradeTypeALPN)
// Send the request and check if upgrade is successful.
if err = req.Write(conn); err != nil {
return trace.Wrap(err)
}
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
if err != nil {
return trace.Wrap(err)
}
defer resp.Body.Close()
if http.StatusSwitchingProtocols != resp.StatusCode {
if http.StatusNotFound == resp.StatusCode {
return trace.NotImplemented(
"connection upgrade call to %q failed with status code %v. Please upgrade the server and try again.",
teleport.WebAPIConnUpgrade,
resp.StatusCode,
)
}
return trace.BadParameter("failed to switch Protocols %v", resp.StatusCode)
}
return nil
}

View file

@ -0,0 +1,186 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package alpnproxy
import (
"context"
"crypto/tls"
"crypto/x509/pkix"
"errors"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/tlsca"
)
func TestIsALPNConnUpgradeRequired(t *testing.T) {
t.Parallel()
tests := []struct {
name string
serverProtos []string
expectedResult bool
}{
{
name: "upgrade required",
serverProtos: nil, // Use nil for NextProtos to simulate no ALPN support.
expectedResult: true,
},
{
name: "upgrade not required (proto negotiated)",
serverProtos: []string{string(common.ProtocolReverseTunnel)},
expectedResult: false,
},
{
name: "upgrade not required (handshake error)",
serverProtos: []string{"unknown"},
expectedResult: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := mustStartMockALPNServer(t, test.serverProtos)
require.Equal(t, test.expectedResult, IsALPNConnUpgradeRequired(server.Addr().String(), true))
})
}
}
func TestALPNConUpgradeDialer(t *testing.T) {
t.Parallel()
t.Run("connection upgraded", func(t *testing.T) {
server := httptest.NewTLSServer(mockConnUpgradeHandler(t, "alpn", []byte("hello")))
addr, err := url.Parse(server.URL)
require.NoError(t, err)
dialer := newALPNConnUpgradeDialer(0, 5*time.Second, true)
conn, err := dialer.DialContext(context.TODO(), "tcp", addr.Host)
require.NoError(t, err)
data := make([]byte, 100)
n, err := conn.Read(data)
require.NoError(t, err)
require.Equal(t, string(data[:n]), "hello")
})
t.Run("connection upgrade API not found", func(t *testing.T) {
server := httptest.NewTLSServer(http.NotFoundHandler())
addr, err := url.Parse(server.URL)
require.NoError(t, err)
dialer := newALPNConnUpgradeDialer(0, 5*time.Second, true)
_, err = dialer.DialContext(context.TODO(), "tcp", addr.Host)
require.Error(t, err)
})
}
type mockALPNServer struct {
net.Listener
cert tls.Certificate
supportedProtos []string
}
func (m *mockALPNServer) serve(ctx context.Context, t *testing.T) {
config := &tls.Config{
NextProtos: m.supportedProtos,
Certificates: []tls.Certificate{m.cert},
}
for {
select {
case <-ctx.Done():
return
default:
}
conn, err := m.Accept()
if errors.Is(err, net.ErrClosed) {
return
}
go func() {
clientConn := tls.Server(conn, config)
clientConn.HandshakeContext(ctx)
clientConn.Close()
}()
}
}
func mustStartMockALPNServer(t *testing.T, supportedProtos []string) *mockALPNServer {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
t.Cleanup(func() {
listener.Close()
})
caKey, caCert, err := tlsca.GenerateSelfSignedCA(pkix.Name{
CommonName: "localhost",
}, []string{"localhost"}, defaults.CATTL)
require.NoError(t, err)
cert, err := tls.X509KeyPair(caCert, caKey)
require.NoError(t, err)
m := &mockALPNServer{
Listener: listener,
cert: cert,
supportedProtos: supportedProtos,
}
go m.serve(ctx, t)
return m
}
// mockConnUpgradeHandler mocks the server side implementation to handle an
// upgrade request and sends back some data inside the tunnel.
func mockConnUpgradeHandler(t *testing.T, upgradeType string, write []byte) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, teleport.WebAPIConnUpgrade, r.URL.Path)
require.Equal(t, upgradeType, r.Header.Get(teleport.WebAPIConnUpgradeHeader))
hj, ok := w.(http.Hijacker)
require.True(t, ok)
conn, _, err := hj.Hijack()
require.NoError(t, err)
defer conn.Close()
// Upgrade response.
response := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
ProtoMajor: 1,
ProtoMinor: 1,
}
require.NoError(t, response.Write(conn))
// Upgraded.
_, err = conn.Write(write)
require.NoError(t, err)
})
}

100
lib/srv/alpnproxy/dialer.go Normal file
View file

@ -0,0 +1,100 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package alpnproxy
import (
"context"
"crypto/tls"
"net"
"time"
"github.com/gravitational/trace"
)
// ContextDialer represents network dialer interface that uses context
type ContextDialer interface {
// DialContext is a function that dials the specified address
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
// ALPNDialerConfig is the config for ALPNDialer.
type ALPNDialerConfig struct {
// KeepAlivePeriod defines period between keep alives.
KeepAlivePeriod time.Duration
// DialTimeout defines how long to attempt dialing before timing out.
DialTimeout time.Duration
// TLSConfig is the TLS config used for the TLS connection.
TLSConfig *tls.Config
// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required.
ALPNConnUpgradeRequired bool
}
// ALPNDialer is a ContextDialer that dials a connection to the Proxy Service
// with ALPN and SNI configured in the provided TLSConfig. An ALPN connection
// upgrade is also performed at the initial connection, if an upgrade is
// required.
type ALPNDialer struct {
cfg ALPNDialerConfig
}
// NewALPNDialer creates a new ALPNDialer.
func NewALPNDialer(cfg ALPNDialerConfig) ContextDialer {
return &ALPNDialer{
cfg: cfg,
}
}
// DialContext implements ContextDialer.
func (d ALPNDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if d.cfg.TLSConfig == nil {
return nil, trace.BadParameter("missing TLS config")
}
var dialer ContextDialer = &net.Dialer{
KeepAlive: d.cfg.KeepAlivePeriod,
Timeout: d.cfg.DialTimeout,
}
if d.cfg.ALPNConnUpgradeRequired {
dialer = newALPNConnUpgradeDialer(d.cfg.KeepAlivePeriod, d.cfg.DialTimeout, d.cfg.TLSConfig.InsecureSkipVerify)
}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, trace.Wrap(err)
}
tlsConn := tls.Client(conn, d.cfg.TLSConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
defer tlsConn.Close()
return nil, trace.Wrap(err)
}
return tlsConn, nil
}
// DialALPN a helper to dial using an ALPNDialer and returns a tls.Conn if
// successful.
func DialALPN(ctx context.Context, addr string, cfg ALPNDialerConfig) (*tls.Conn, error) {
conn, err := NewALPNDialer(cfg).DialContext(ctx, "tcp", addr)
if err != nil {
return nil, trace.Wrap(err)
}
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return nil, trace.BadParameter("failed to convert to tls.Conn")
}
return tlsConn, nil
}

View file

@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/tlsca"
)
@ -106,9 +107,10 @@ func (s *Suite) GetCertPool() *x509.CertPool {
return pool
}
func (s *Suite) Start(t *testing.T) {
func (s *Suite) CreateProxyServer(t *testing.T) *Proxy {
serverCert := mustGenCertSignedWithCA(t, s.ca)
tlsConfig := &tls.Config{
NextProtos: common.ProtocolsToString(common.SupportedProtocols),
ClientAuth: tls.VerifyClientCertIfGiven,
ClientCAs: s.GetCertPool(),
Certificates: []tls.Certificate{
@ -130,6 +132,11 @@ func (s *Suite) Start(t *testing.T) {
require.NoError(t, err)
// Reset GetConfigForClient to simplify test setup.
svr.cfg.IdentityTLSConfig.GetConfigForClient = nil
return svr
}
func (s *Suite) Start(t *testing.T) {
svr := s.CreateProxyServer(t)
go func() {
err := svr.Serve(context.Background())

View file

@ -19,7 +19,7 @@ package alpnproxy
import (
"context"
"crypto/tls"
"io"
"crypto/x509"
"net"
"net/http"
"net/http/httputil"
@ -64,13 +64,14 @@ type LocalProxyConfig struct {
SSHHostKeyCallback ssh.HostKeyCallback
// SSHTrustedCluster allows selecting trusted cluster ssh subsystem request.
SSHTrustedCluster string
// ClientTLSConfig is a client TLS configuration used during establishing
// connection to the RemoteProxyAddr.
ClientTLSConfig *tls.Config
// Certs are the client certificates used to connect to the remote Teleport Proxy.
Certs []tls.Certificate
// AWSCredentials are AWS Credentials used by LocalProxy for request's signature verification.
AWSCredentials *credentials.Credentials
// RootCAs overwrites the root CAs used in tls.Config if specified.
RootCAs *x509.CertPool
// ALPNConnUpgradeRequired specifies if ALPN connection upgrade is required.
ALPNConnUpgradeRequired bool
}
// CheckAndSetDefaults verifies the constraints for LocalProxyConfig.
@ -129,7 +130,7 @@ func (l *LocalProxy) Start(ctx context.Context) error {
return trace.Wrap(err)
}
go func() {
if err := l.handleDownstreamConnection(ctx, conn, l.cfg.SNI); err != nil {
if err := l.handleDownstreamConnection(ctx, conn); err != nil {
if utils.IsOKNetworkError(err) {
return
}
@ -146,14 +147,18 @@ func (l *LocalProxy) GetAddr() string {
// handleDownstreamConnection proxies the downstreamConn (connection established to the local proxy) and forward the
// traffic to the upstreamConn (TLS connection to remote host).
func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamConn net.Conn, serverName string) error {
func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamConn net.Conn) error {
defer downstreamConn.Close()
tlsConn, err := tls.Dial("tcp", l.cfg.RemoteProxyAddr, &tls.Config{
NextProtos: l.cfg.GetProtocols(),
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
ServerName: serverName,
Certificates: l.cfg.Certs,
tlsConn, err := DialALPN(ctx, l.cfg.RemoteProxyAddr, ALPNDialerConfig{
ALPNConnUpgradeRequired: l.cfg.ALPNConnUpgradeRequired,
TLSConfig: &tls.Config{
NextProtos: l.cfg.GetProtocols(),
InsecureSkipVerify: l.cfg.InsecureSkipVerify,
ServerName: l.cfg.SNI,
Certificates: l.cfg.Certs,
RootCAs: l.cfg.RootCAs,
},
})
if err != nil {
return trace.Wrap(err)
@ -166,32 +171,7 @@ func (l *LocalProxy) handleDownstreamConnection(ctx context.Context, downstreamC
upstreamConn = NewPingConn(tlsConn)
}
errC := make(chan error, 2)
go func() {
defer downstreamConn.Close()
defer upstreamConn.Close()
_, err := io.Copy(downstreamConn, upstreamConn)
errC <- err
}()
go func() {
defer downstreamConn.Close()
defer upstreamConn.Close()
_, err := io.Copy(upstreamConn, downstreamConn)
errC <- err
}()
var errs []error
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
return trace.NewAggregate(append(errs, ctx.Err())...)
case err := <-errC:
if err != nil && !utils.IsOKNetworkError(err) {
errs = append(errs, err)
}
}
}
return trace.NewAggregate(errs...)
return trace.Wrap(utils.ProxyConn(ctx, downstreamConn, upstreamConn))
}
func (l *LocalProxy) Close() error {

View file

@ -324,7 +324,7 @@ func (p *Proxy) Serve(ctx context.Context) error {
// For example in ReverseTunnel handles connection asynchronously and closing conn after
// service handler returned will break service logic.
// https://github.com/gravitational/teleport/blob/master/lib/sshutils/server.go#L397
if err := p.handleConn(ctx, clientConn); err != nil {
if err := p.handleConn(ctx, clientConn, nil); err != nil {
if cerr := clientConn.Close(); cerr != nil && !utils.IsOKNetworkError(cerr) {
p.log.WithError(cerr).Warnf("Failed to close client connection.")
}
@ -361,7 +361,7 @@ type HandlerFuncWithInfo func(ctx context.Context, conn net.Conn, info Connectio
// 5) For backward compatibility check RouteToDatabase identity field
// was set if yes forward to the generic TLS DB handler.
// 6) Forward connection to the handler obtained in step 2.
func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn) error {
func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn, defaultOverride *tls.Config) error {
hello, conn, err := p.readHelloMessageWithoutTLSTermination(clientConn)
if err != nil {
return trace.Wrap(err)
@ -381,7 +381,7 @@ func (p *Proxy) handleConn(ctx context.Context, clientConn net.Conn) error {
return trace.Wrap(handlerDesc.handle(ctx, conn, connInfo))
}
tlsConn := tls.Server(conn, p.getTLSConfig(handlerDesc))
tlsConn := tls.Server(conn, p.getTLSConfig(handlerDesc, defaultOverride))
if err := tlsConn.SetReadDeadline(p.cfg.Clock.Now().Add(p.cfg.ReadDeadline)); err != nil {
return trace.Wrap(err)
}
@ -437,12 +437,17 @@ func (p *Proxy) handlePingConnection(ctx context.Context, conn *tls.Conn) net.Co
return pingConn
}
// getTLSConfig returns HandlerDesc.TLSConfig if custom TLS configuration was set for the handler
// otherwise the ProxyConfig.WebTLSConfig is used.
func (p *Proxy) getTLSConfig(desc *HandlerDecs) *tls.Config {
// getTLSConfig picks the TLS config with the following priority:
// - TLS config found in the provided handler.
// - A default override.
// - The default TLS config (cfg.WebTLSConfig).
func (p *Proxy) getTLSConfig(desc *HandlerDecs, defaultOverride *tls.Config) *tls.Config {
if desc.TLSConfig != nil {
return desc.TLSConfig
}
if defaultOverride != nil {
return defaultOverride
}
return p.cfg.WebTLSConfig
}
@ -578,3 +583,33 @@ func (p *Proxy) Close() error {
}
return nil
}
// MakeConnectionHandler creates a ConnectionHandler which provides a callback
// to handle incoming connections by this ALPN proxy server.
func (p *Proxy) MakeConnectionHandler(defaultOverride *tls.Config) ConnectionHandler {
return func(ctx context.Context, conn net.Conn) error {
return p.handleConn(ctx, conn, defaultOverride)
}
}
// ConnectionHandler defines a function for serving incoming connections.
type ConnectionHandler func(ctx context.Context, conn net.Conn) error
// ConnectionHandlerWrapper is a wrapper of ConnectionHandler. This wrapper is
// mainly used as a placeholder to resolve circular dependencies.
type ConnectionHandlerWrapper struct {
h ConnectionHandler
}
// Set updates inner ConnectionHandler to use.
func (w *ConnectionHandlerWrapper) Set(h ConnectionHandler) {
w.h = h
}
// HandleConnection implements ConnectionHandler.
func (w *ConnectionHandlerWrapper) HandleConnection(ctx context.Context, conn net.Conn) error {
if w.h == nil {
return trace.NotFound("missing ConnectionHandler")
}
return w.h(ctx, conn)
}

View file

@ -17,13 +17,16 @@ limitations under the License.
package alpnproxy
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net"
"net/http"
"testing"
"time"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
@ -335,6 +338,73 @@ func TestProxyHTTPConnection(t *testing.T) {
mustSuccessfullyCallHTTPSServer(t, suite.GetServerAddress(), client)
}
// TestProxyMakeConnectionHandler creates a ConnectionHandler from the ALPN
// proxy server, and verifies ALPN protocol is properly handled through the
// ConnectionHandler.
func TestProxyMakeConnectionHandler(t *testing.T) {
t.Parallel()
suite := NewSuite(t)
// Create a HTTP server and register the listener to ALPN server.
lw := NewMuxListenerWrapper(nil, suite.serverListener)
mustStartHTTPServer(t, lw)
suite.router = NewRouter()
suite.router.Add(HandlerDecs{
MatchFunc: MatchByProtocol(common.ProtocolHTTP),
Handler: lw.HandleConnection,
})
svr := suite.CreateProxyServer(t)
customCA := mustGenSelfSignedCert(t)
// Create a ConnectionHandler from the proxy server.
alpnConnHandler := svr.MakeConnectionHandler(&tls.Config{
NextProtos: []string{string(common.ProtocolHTTP)},
Certificates: []tls.Certificate{
mustGenCertSignedWithCA(t, customCA),
},
})
// Prepare net.Conn to be used for the created alpnConnHandler.
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
defer serverConn.Close()
timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
// Let alpnConnHandler serve the connection in a separate go routine.
go func() {
alpnConnHandler(timeoutCtx, serverConn)
require.NoError(t, timeoutCtx.Err())
}()
// Send client request.
req, err := http.NewRequest("GET", "https://localhost/test", nil)
require.NoError(t, err)
// Use the customCA to validate default TLS config override.
pool := x509.NewCertPool()
pool.AddCert(customCA.Cert)
clientTLSConn := tls.Client(clientConn, &tls.Config{
NextProtos: []string{string(common.ProtocolHTTP)},
RootCAs: pool,
ServerName: "localhost",
})
defer clientTLSConn.Close()
require.NoError(t, clientTLSConn.Handshake())
require.Equal(t, string(common.ProtocolHTTP), clientTLSConn.ConnectionState().NegotiatedProtocol)
require.NoError(t, req.Write(clientTLSConn))
response, err := http.ReadResponse(bufio.NewReader(clientTLSConn), req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, response.StatusCode)
}
// TestProxyALPNProtocolsRouting tests the routing based on client TLS NextProtos values.
func TestProxyALPNProtocolsRouting(t *testing.T) {
t.Parallel()

View file

@ -24,6 +24,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/go-mysql-org/go-mysql/client"
@ -54,6 +55,20 @@ func MakeTestClient(config common.TestClientConfig) (*client.Conn, error) {
return conn, nil
}
// MakeTestClientWithoutTLS returns a MySQL client connection without setting
// TLS config to the MySQL client.
func MakeTestClientWithoutTLS(addr string, routeToDatabase tlsca.RouteToDatabase) (*client.Conn, error) {
conn, err := client.Connect(addr,
routeToDatabase.Username,
"",
routeToDatabase.Database,
)
if err != nil {
return nil, trace.Wrap(err)
}
return conn, nil
}
// TestServer is a test MySQL server used in functional database
// access tests.
type TestServer struct {

View file

@ -217,13 +217,12 @@ func IsSelfSigned(certificateChain []*x509.Certificate) bool {
return bytes.Equal(certificateChain[0].SubjectKeyId, certificateChain[0].AuthorityKeyId)
}
// ReadCertificateChain parses PEM encoded bytes that can contain one or
// ReadCertificates parses PEM encoded bytes that can contain one or
// multiple certificates and returns a slice of x509.Certificate.
func ReadCertificateChain(certificateChainBytes []byte) ([]*x509.Certificate, error) {
// build the certificate chain next
func ReadCertificates(certificateChainBytes []byte) ([]*x509.Certificate, error) {
var (
certificateBlock *pem.Block
certificateChain [][]byte
certificates [][]byte
)
remainingBytes := bytes.TrimSpace(certificateChainBytes)
@ -232,29 +231,59 @@ func ReadCertificateChain(certificateChainBytes []byte) ([]*x509.Certificate, er
if certificateBlock == nil || certificateBlock.Type != pemBlockCertificate {
return nil, trace.NotFound("no PEM data found")
}
certificateChain = append(certificateChain, certificateBlock.Bytes)
certificates = append(certificates, certificateBlock.Bytes)
if len(remainingBytes) == 0 {
break
}
}
// build a concatenated certificate chain
// build concatenated certificates into a buffer
var buf bytes.Buffer
for _, cc := range certificateChain {
for _, cc := range certificates {
_, err := buf.Write(cc)
if err != nil {
return nil, trace.Wrap(err)
}
}
// parse the chain and get a slice of x509.Certificates.
x509Chain, err := x509.ParseCertificates(buf.Bytes())
// parse the buffer and get a slice of x509.Certificates.
x509Certs, err := x509.ParseCertificates(buf.Bytes())
if err != nil {
return nil, trace.Wrap(err)
}
return x509Chain, nil
return x509Certs, nil
}
// ReadCertificatesFromPath parses PEM encoded certificates from provided path.
func ReadCertificatesFromPath(path string) ([]*x509.Certificate, error) {
bytes, err := ReadPath(path)
if err != nil {
return nil, trace.Wrap(err)
}
certs, err := ReadCertificates(bytes)
if err != nil {
return nil, trace.Wrap(err)
}
return certs, nil
}
// NewCertPoolFromPath creates a new x509.CertPool from provided path.
func NewCertPoolFromPath(path string) (*x509.CertPool, error) {
// x509.CertPool.AppendCertsFromPEM skips parse errors. Using our own
// implementation here to be more strict.
cas, err := ReadCertificatesFromPath(path)
if err != nil {
return nil, trace.Wrap(err)
}
pool := x509.NewCertPool()
for _, ca := range cas {
pool.AddCert(ca)
}
return pool, nil
}
const pemBlockCertificate = "CERTIFICATE"

View file

@ -17,7 +17,6 @@ limitations under the License.
package utils
import (
"os"
"runtime"
"testing"
@ -31,17 +30,14 @@ import (
func TestRejectsInvalidPEMData(t *testing.T) {
t.Parallel()
_, err := ReadCertificateChain([]byte("no data"))
_, err := ReadCertificates([]byte("no data"))
require.IsType(t, trace.Unwrap(err), &trace.NotFoundError{})
}
func TestRejectsSelfSignedCertificate(t *testing.T) {
t.Parallel()
certificateChainBytes, err := os.ReadFile("../../fixtures/certs/ca.pem")
require.NoError(t, err)
certificateChain, err := ReadCertificateChain(certificateChainBytes)
certificateChain, err := ReadCertificatesFromPath("../../fixtures/certs/ca.pem")
require.NoError(t, err)
err = VerifyCertificateChain(certificateChain)
@ -52,3 +48,11 @@ func TestRejectsSelfSignedCertificate(t *testing.T) {
require.ErrorContains(t, err, "x509: certificate signed by unknown authority")
}
}
func TestNewCertPoolFromPath(t *testing.T) {
t.Parallel()
pool, err := NewCertPoolFromPath("../../fixtures/certs/ca.pem")
require.NoError(t, err)
require.Len(t, pool.Subjects(), 1)
}

View file

@ -200,6 +200,10 @@ type Config struct {
// PublicProxyAddr is used to template the public proxy address
// into the installer script responses
PublicProxyAddr string
// ALPNHandler is the ALPN connection handler for handling upgraded ALPN
// connection through a HTTP upgrade call.
ALPNHandler ConnectionHandler
}
type APIHandler struct {
@ -209,6 +213,9 @@ type APIHandler struct {
appHandler *app.Handler
}
// ConnectionHandler defines a function for serving incoming connections.
type ConnectionHandler func(ctx context.Context, conn net.Conn) error
// Check if this request should be forwarded to an application handler to
// be handled by the UI and handle the request appropriately.
func (h *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -609,6 +616,9 @@ func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) {
h.GET("/webapi/sites/:site/diagnostics/connections/:connectionid", h.WithClusterAuth(h.getConnectionDiagnostic))
// Diagnose a Connection
h.POST("/webapi/sites/:site/diagnostics/connections", h.WithClusterAuth(h.diagnoseConnection))
// Connection upgrades.
h.GET("/webapi/connectionupgrade", httplib.MakeHandler(h.connectionUpgrade))
}
// GetProxyClient returns authenticated auth server client

131
lib/web/conn_upgrade.go Normal file
View file

@ -0,0 +1,131 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package web
import (
"context"
"io"
"net"
"net/http"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
)
// selectConnectionUpgrade selects the requested upgrade type and returns the
// corresponding handler.
func (h *Handler) selectConnectionUpgrade(r *http.Request) (string, ConnectionHandler, error) {
upgrades := r.Header.Values(teleport.WebAPIConnUpgradeHeader)
for _, upgradeType := range upgrades {
switch upgradeType {
case teleport.WebAPIConnUpgradeTypeALPN:
return upgradeType, h.upgradeALPN, nil
}
}
return "", nil, trace.BadParameter("unsupported upgrade types: %v", upgrades)
}
// connectionUpgrade handles connection upgrades.
func (h *Handler) connectionUpgrade(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
upgradeType, upgradeHandler, err := h.selectConnectionUpgrade(r)
if err != nil {
return nil, trace.Wrap(err)
}
hj, ok := w.(http.Hijacker)
if !ok {
return nil, trace.BadParameter("failed to hijack connection")
}
conn, _, err := hj.Hijack()
if err != nil {
return nil, trace.Wrap(err)
}
defer conn.Close()
// Since w is hijacked, there is no point returning an error for response
// starting at this point.
if err := writeUpgradeResponse(conn, upgradeType); err != nil {
h.log.WithError(err).Error("Failed to write upgrade response.")
return nil, nil
}
if err := upgradeHandler(r.Context(), conn); err != nil && !utils.IsOKNetworkError(err) {
h.log.WithError(err).Errorf("Failed to handle %v upgrade request.", upgradeType)
}
return nil, nil
}
func (h *Handler) upgradeALPN(ctx context.Context, conn net.Conn) error {
if h.cfg.ALPNHandler == nil {
return trace.BadParameter("missing ALPNHandler")
}
// ALPNHandler may handle some connections asynchronously. Here we want to
// block until the handling is done by waiting until the connection is
// closed.
waitConn := newWaitConn(ctx, conn)
defer waitConn.WaitForClose()
return h.cfg.ALPNHandler(ctx, waitConn)
}
func writeUpgradeResponse(w io.Writer, upgradeType string) error {
header := make(http.Header)
header.Add(teleport.WebAPIConnUpgradeHeader, upgradeType)
response := &http.Response{
Status: http.StatusText(http.StatusSwitchingProtocols),
StatusCode: http.StatusSwitchingProtocols,
Header: header,
ProtoMajor: 1,
ProtoMinor: 1,
}
return response.Write(w)
}
// waitConn is a net.Conn that provides a "WaitForClose" function to wait until
// the connection is closed.
type waitConn struct {
net.Conn
ctx context.Context
cancel context.CancelFunc
}
// newWaitConn creates a new waitConn.
func newWaitConn(ctx context.Context, conn net.Conn) *waitConn {
ctx, cancel := context.WithCancel(ctx)
return &waitConn{
Conn: conn,
ctx: ctx,
cancel: cancel,
}
}
// WaitForClose blocks until the Close() function of this connection is called.
func (conn *waitConn) WaitForClose() {
<-conn.ctx.Done()
}
// Close implements net.Conn.
func (conn *waitConn) Close() error {
err := conn.Conn.Close()
conn.cancel()
return trace.Wrap(err)
}

View file

@ -0,0 +1,126 @@
/*
Copyright 2022 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package web
import (
"bufio"
"bytes"
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
)
func TestWriteUpgradeResponse(t *testing.T) {
var buf bytes.Buffer
require.NoError(t, writeUpgradeResponse(&buf, "custom"))
resp, err := http.ReadResponse(bufio.NewReader(&buf), nil)
require.NoError(t, err)
require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols)
require.Equal(t, "custom", resp.Header.Get("Upgrade"))
}
func TestHandlerConnectionUpgrade(t *testing.T) {
t.Parallel()
expectedPayload := "hello@"
alpnHandler := func(_ context.Context, conn net.Conn) error {
// Handles connection asynchronously to verify web handler waits until
// connection is closed.
go func() {
defer conn.Close()
n, err := conn.Write([]byte(expectedPayload))
require.NoError(t, err)
require.Equal(t, len(expectedPayload), n)
}()
return nil
}
// Cherry picked some attributes to create a Handler to test only the
// connection upgrade portion.
h := &Handler{
cfg: Config{
ALPNHandler: alpnHandler,
},
log: newPackageLogger(),
clock: clockwork.NewRealClock(),
}
t.Run("unsupported type", func(t *testing.T) {
r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil)
require.NoError(t, err)
r.Header.Add("Upgrade", "unsupported-protocol")
_, err = h.connectionUpgrade(httptest.NewRecorder(), r, nil)
require.True(t, trace.IsBadParameter(err))
})
t.Run("upgraded to ALPN", func(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer serverConn.Close()
defer clientConn.Close()
r, err := http.NewRequest("GET", "http://localhost/webapi/connectionupgrade", nil)
require.NoError(t, err)
r.Header.Add("Upgrade", "alpn")
go func() {
// serverConn will be hijacked.
w := newResponseWriterHijacker(nil, serverConn)
_, err = h.connectionUpgrade(w, r, nil)
require.NoError(t, err)
}()
// Verify clientConn receives http.StatusSwitchingProtocols.
clientConnReader := bufio.NewReader(clientConn)
response, err := http.ReadResponse(clientConnReader, r)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, response.StatusCode)
// Verify clientConn receives data sent by Config.ALPNHandler.
receive, err := clientConnReader.ReadString(byte('@'))
require.NoError(t, err)
require.Equal(t, expectedPayload, receive)
})
}
// responseWriterHijacker is a mock http.ResponseWriter that also serves a
// net.Conn for http.Hijacker.
type responseWriterHijacker struct {
http.ResponseWriter
conn net.Conn
}
func newResponseWriterHijacker(w http.ResponseWriter, conn net.Conn) *responseWriterHijacker {
if w == nil {
w = httptest.NewRecorder()
}
return &responseWriterHijacker{
ResponseWriter: w,
conn: conn,
}
}
func (h responseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return h.conn, nil, nil
}

View file

@ -35,6 +35,7 @@ import (
"github.com/gravitational/teleport/lib/client/db/dbcmd"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/alpnproxy"
"github.com/gravitational/teleport/lib/srv/alpnproxy/common"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
@ -555,7 +556,7 @@ func serializeDatabaseConfig(configInfo *dbConfigInfo, format string) (string, e
// connection scenario and returns a list of options to use in the connect
// command.
func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *client.ProfileStatus, db *tlsca.RouteToDatabase,
database types.Database, cluster string,
database types.Database, rootClusterName string,
) ([]dbcmd.ConnectCommandFunc, error) {
if !isLocalProxyRequiredForDatabase(tc, db) {
return []dbcmd.ConnectCommandFunc{}, nil
@ -582,6 +583,7 @@ func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *clien
database: database,
listener: listener,
localProxyTunnel: localProxyTunnel,
rootClusterName: rootClusterName,
})
if err != nil {
return nil, trace.Wrap(err)
@ -609,7 +611,7 @@ func maybeStartLocalProxy(cf *CLIConf, tc *client.TeleportClient, profile *clien
// validation, so connect to localhost.
host := "localhost"
return []dbcmd.ConnectCommandFunc{
dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(cluster)),
dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(rootClusterName)),
}, nil
}
@ -625,6 +627,7 @@ type localProxyConfig struct {
// it's always true for Snowflake database. Value is copied here to not modify
// cli arguments directly.
localProxyTunnel bool
rootClusterName string
}
// prepareLocalProxyOptions created localProxyOpts needed to create local proxy from localProxyConfig.
@ -639,12 +642,23 @@ func prepareLocalProxyOptions(arg *localProxyConfig) (localProxyOpts, error) {
}
opts := localProxyOpts{
proxyAddr: arg.teleportClient.WebProxyAddr,
listener: arg.listener,
protocols: []common.Protocol{common.Protocol(arg.routeToDatabase.Protocol)},
insecure: arg.cliConf.InsecureSkipVerify,
certFile: certFile,
keyFile: keyFile,
proxyAddr: arg.teleportClient.WebProxyAddr,
listener: arg.listener,
protocols: []common.Protocol{common.Protocol(arg.routeToDatabase.Protocol)},
insecure: arg.cliConf.InsecureSkipVerify,
certFile: certFile,
keyFile: keyFile,
alpnConnUpgradeRequired: alpnproxy.IsALPNConnUpgradeRequired(arg.teleportClient.WebProxyAddr, arg.cliConf.InsecureSkipVerify),
}
// If ALPN connection upgrade is required, explicitly use the profile CAs
// since the tunneled TLS routing connection serves the Host cert.
if opts.alpnConnUpgradeRequired {
profileCAs, err := utils.NewCertPoolFromPath(arg.profile.CACertPathForCluster(arg.rootClusterName))
if err != nil {
return localProxyOpts{}, trace.Wrap(err)
}
opts.rootCAs = profileCAs
}
// For SQL Server connections, local proxy must be configured with the
@ -705,11 +719,7 @@ func onDatabaseConnect(cf *CLIConf) error {
return trace.Wrap(err)
}
key, err := tc.LocalAgent().GetCoreKey()
if err != nil {
return trace.Wrap(err)
}
rootClusterName, err := key.RootClusterName()
rootClusterName, err := tc.RootClusterName(cf.Context)
if err != nil {
return trace.Wrap(err)
}

View file

@ -363,6 +363,7 @@ func onProxyCommandDB(cf *CLIConf) error {
routeToDatabase: routeToDatabase,
listener: listener,
localProxyTunnel: cf.LocalProxyTunnel,
rootClusterName: rootCluster,
})
if err != nil {
return trace.Wrap(err)
@ -422,12 +423,14 @@ func onProxyCommandDB(cf *CLIConf) error {
}
type localProxyOpts struct {
proxyAddr string
listener net.Listener
protocols []alpncommon.Protocol
insecure bool
certFile string
keyFile string
proxyAddr string
listener net.Listener
protocols []alpncommon.Protocol
insecure bool
certFile string
keyFile string
rootCAs *x509.CertPool
alpnConnUpgradeRequired bool
}
// protocol returns the first protocol or string if configuration doesn't contain any protocols.
@ -458,13 +461,15 @@ func mkLocalProxy(ctx context.Context, opts localProxyOpts) (*alpnproxy.LocalPro
}
lp, err := alpnproxy.NewLocalProxy(alpnproxy.LocalProxyConfig{
InsecureSkipVerify: opts.insecure,
RemoteProxyAddr: opts.proxyAddr,
Protocols: protocols,
Listener: opts.listener,
ParentContext: ctx,
SNI: address.Host(),
Certs: certs,
InsecureSkipVerify: opts.insecure,
RemoteProxyAddr: opts.proxyAddr,
Protocols: protocols,
Listener: opts.listener,
ParentContext: ctx,
SNI: address.Host(),
Certs: certs,
RootCAs: opts.rootCAs,
ALPNConnUpgradeRequired: opts.alpnConnUpgradeRequired,
})
if err != nil {
return nil, trace.Wrap(err)