API client tunnel address discovery fix (#7533)

This commit is contained in:
Brian Joerger 2021-08-11 14:34:50 -07:00 committed by GitHub
parent aa68b80301
commit 25c9c982db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 302 additions and 139 deletions

View file

@ -56,17 +56,12 @@ func NewDirectDialer(keepAlivePeriod, dialTimeout time.Duration) ContextDialer {
func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
// Ping web proxy to retrieve tunnel proxy address.
pr, err := webclient.Find(ctx, discoveryAddr, insecure, nil)
tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil)
if err != nil {
return nil, trace.Wrap(err)
}
if pr.Proxy.SSH.TunnelPublicAddr == "" {
return nil, trace.BadParameter("reverse tunnel address not discoverable, 'tunnel_public_addr' is not set")
}
conn, err = dialer.DialContext(ctx, network, pr.Proxy.SSH.TunnelPublicAddr)
conn, err = dialer.DialContext(ctx, network, tunnelAddr)
if err != nil {
return nil, trace.Wrap(err)
}

View file

@ -23,9 +23,15 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/trace"
)
@ -90,6 +96,21 @@ func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP
return pr, nil
}
// GetTunnelAddr returns the tunnel address either set in an environment variable or retrieved from the web proxy.
func GetTunnelAddr(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (string, error) {
// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
return extractHostPort(tunnelAddr)
}
// Ping web proxy to retrieve tunnel proxy address.
pr, err := Find(ctx, proxyAddr, insecure, nil)
if err != nil {
return "", trace.Wrap(err)
}
return tunnelAddr(proxyAddr, pr.Proxy.SSH)
}
func GetMOTD(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*MotD, error) {
clt := newWebClient(insecure, pool)
defer clt.CloseIdleConnections()
@ -234,3 +255,93 @@ type GithubSettings struct {
// Display is the connector display name
Display string `json:"display"`
}
// The tunnel addr is retrieved in the following preference order:
// 1. Reverse Tunnel Public Address.
// 2. SSH Proxy Public Address Host + Tunnel Port.
// 3. HTTP Proxy Public Address Host + Tunnel Port.
// 4. Proxy Address Host + Tunnel Port.
func tunnelAddr(proxyAddr string, settings SSHProxySettings) (string, error) {
// If a tunnel public address is set, nothing else has to be done, return it.
if settings.TunnelPublicAddr != "" {
return extractHostPort(settings.TunnelPublicAddr)
}
// Extract the port the tunnel server is listening on.
tunnelPort := strconv.Itoa(defaults.SSHProxyTunnelListenPort)
if settings.TunnelListenAddr != "" {
if port, err := extractPort(settings.TunnelListenAddr); err == nil {
tunnelPort = port
}
}
// If a tunnel public address has not been set, but a related HTTP or SSH
// public address has been set, extract the hostname but use the port from
// the tunnel listen address.
if settings.SSHPublicAddr != "" {
if host, err := extractHost(settings.SSHPublicAddr); err == nil {
return net.JoinHostPort(host, tunnelPort), nil
}
}
if settings.PublicAddr != "" {
if host, err := extractHost(settings.PublicAddr); err == nil {
return net.JoinHostPort(host, tunnelPort), nil
}
}
// If nothing is set, fallback to the address dialed with tunnel port.
host, err := extractHost(proxyAddr)
if err != nil {
return "", trace.Wrap(err, "failed to parse the given proxy address")
}
return net.JoinHostPort(host, tunnelPort), nil
}
// extractHostPort takes addresses like "tcp://host:port/path" and returns "host:port".
func extractHostPort(addr string) (string, error) {
if addr == "" {
return "", trace.BadParameter("missing parameter address")
}
if !strings.Contains(addr, "://") {
addr = "tcp://" + addr
}
u, err := url.Parse(addr)
if err != nil {
return "", trace.BadParameter("failed to parse %q: %v", addr, err)
}
switch u.Scheme {
case "tcp", "http", "https":
return u.Host, nil
default:
return "", trace.BadParameter("'%v': unsupported scheme: '%v'", addr, u.Scheme)
}
}
// extractHost takes addresses like "tcp://host:port/path" and returns "host".
func extractHost(addr string) (ra string, err error) {
parsed, err := extractHostPort(addr)
if err != nil {
return "", trace.Wrap(err)
}
host, _, err := net.SplitHostPort(parsed)
if err != nil {
if strings.Contains(err.Error(), "missing port in address") {
return addr, nil
}
return "", trace.Wrap(err)
}
return host, nil
}
// extractPort takes addresses like "tcp://host:port/path" and returns "port".
func extractPort(addr string) (string, error) {
parsed, err := extractHostPort(addr)
if err != nil {
return "", trace.Wrap(err)
}
_, port, err := net.SplitHostPort(parsed)
if err != nil {
return "", trace.Wrap(err)
}
return port, nil
}

View file

@ -0,0 +1,164 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package webclient
import (
"context"
"os"
"testing"
"github.com/gravitational/teleport/api/defaults"
"github.com/stretchr/testify/require"
)
func TestGetTunnelAddr(t *testing.T) {
ctx := context.Background()
t.Run("should use TELEPORT_TUNNEL_PUBLIC_ADDR", func(t *testing.T) {
os.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024")
t.Cleanup(func() { os.Unsetenv(defaults.TunnelPublicAddrEnvar) })
tunnelAddr, err := GetTunnelAddr(ctx, "", true, nil)
require.NoError(t, err)
require.Equal(t, "tunnel.example.com:4024", tunnelAddr)
})
}
func TestTunnelAddr(t *testing.T) {
type testCase struct {
proxyAddr string
settings SSHProxySettings
expectedTunnelAddr string
}
testTunnelAddr := func(tc testCase) func(*testing.T) {
return func(t *testing.T) {
t.Parallel()
tunnelAddr, err := tunnelAddr(tc.proxyAddr, tc.settings)
require.NoError(t, err)
require.Equal(t, tc.expectedTunnelAddr, tunnelAddr)
}
}
t.Run("should use TunnelPublicAddr", testTunnelAddr(testCase{
proxyAddr: "proxy.example.com",
settings: SSHProxySettings{
TunnelPublicAddr: "tunnel.example.com:4024",
PublicAddr: "public.example.com",
SSHPublicAddr: "ssh.example.com",
TunnelListenAddr: "[::]:5024",
},
expectedTunnelAddr: "tunnel.example.com:4024",
}))
t.Run("should use SSHPublicAddr and TunnelListenAddr", testTunnelAddr(testCase{
proxyAddr: "proxy.example.com",
settings: SSHProxySettings{
SSHPublicAddr: "ssh.example.com",
PublicAddr: "public.example.com",
TunnelListenAddr: "[::]:5024",
},
expectedTunnelAddr: "ssh.example.com:5024",
}))
t.Run("should use PublicAddr and TunnelListenAddr", testTunnelAddr(testCase{
proxyAddr: "proxy.example.com",
settings: SSHProxySettings{
PublicAddr: "public.example.com",
TunnelListenAddr: "[::]:5024",
},
expectedTunnelAddr: "public.example.com:5024",
}))
t.Run("should use PublicAddr and SSHProxyTunnelListenPort", testTunnelAddr(testCase{
proxyAddr: "proxy.example.com",
settings: SSHProxySettings{
PublicAddr: "public.example.com",
},
expectedTunnelAddr: "public.example.com:3024",
}))
t.Run("should use proxyAddr and SSHProxyTunnelListenPort", testTunnelAddr(testCase{
proxyAddr: "proxy.example.com",
settings: SSHProxySettings{},
expectedTunnelAddr: "proxy.example.com:3024",
}))
}
func TestExtract(t *testing.T) {
testCases := []struct {
addr string
hostPort string
host string
port string
}{
{
addr: "example.com",
hostPort: "example.com",
host: "example.com",
port: "",
}, {
addr: "example.com:443",
hostPort: "example.com:443",
host: "example.com",
port: "443",
}, {
addr: "http://example.com:443",
hostPort: "example.com:443",
host: "example.com",
port: "443",
}, {
addr: "https://example.com:443",
hostPort: "example.com:443",
host: "example.com",
port: "443",
}, {
addr: "tcp://example.com:443",
hostPort: "example.com:443",
host: "example.com",
port: "443",
}, {
addr: "file://host/path",
hostPort: "",
host: "",
port: "",
}, {
addr: "[::]:443",
hostPort: "[::]:443",
host: "::",
port: "443",
}, {
addr: "https://example.com:443/path?query=query#fragment",
hostPort: "example.com:443",
host: "example.com",
port: "443",
},
}
for _, tc := range testCases {
t.Run(tc.addr, func(t *testing.T) {
hostPort, err := extractHostPort(tc.addr)
// Expect err if expected value is empty
require.True(t, (tc.hostPort == "") == (err != nil))
require.Equal(t, tc.hostPort, hostPort)
host, err := extractHost(tc.addr)
// Expect err if expected value is empty
require.True(t, (tc.host == "") == (err != nil))
require.Equal(t, tc.host, host)
port, err := extractPort(tc.addr)
// Expect err if expected value is empty
require.True(t, (tc.port == "") == (err != nil))
require.Equal(t, tc.port, port)
})
}
}

View file

@ -73,3 +73,15 @@ const (
// DefaultChunkSize is the default chunk size for paginated endpoints.
DefaultChunkSize = 1000
)
const (
// When running in "SSH Proxy" role this port will be used for incoming
// connections from SSH nodes who wish to use "reverse tunnell" (when they
// run behind an environment/firewall which only allows outgoing connections)
SSHProxyTunnelListenPort = 3024
)
const (
// TunnelPublicAddrEnvar optionally specifies the alternative reverse tunnel address.
TunnelPublicAddrEnvar = "TELEPORT_TUNNEL_PUBLIC_ADDR"
)

View file

@ -869,11 +869,8 @@ func testCustomReverseTunnel(t *testing.T, suite *integrationTestSuite) {
nodeConf.Auth.Enabled = false
nodeConf.Proxy.Enabled = false
nodeConf.SSH.Enabled = true
nodeConf.SSH.ProxyReverseTunnelFallbackAddr = &utils.NetAddr{
// Configure the original proxy address as a fallback so the node is able to connect
Addr: main.Secrets.WebProxyAddr,
AddrNetwork: "tcp",
}
os.Setenv(apidefaults.TunnelPublicAddrEnvar, main.Secrets.WebProxyAddr)
t.Cleanup(func() { os.Unsetenv(apidefaults.TunnelPublicAddrEnvar) })
// verify the node is able to join the cluster
_, err = main.StartReverseTunnelNode(nodeConf)

View file

@ -839,13 +839,6 @@ func applySSHConfig(fc *FileConfig, cfg *service.Config) (err error) {
cfg.SSH.RestrictedSession = rs
}
if proxyAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); proxyAddr != "" {
cfg.SSH.ProxyReverseTunnelFallbackAddr, err = utils.ParseHostPortAddr(proxyAddr, defaults.SSHProxyTunnelListenPort)
if err != nil {
return trace.Wrap(err, "invalid reverse tunnel address format %q", proxyAddr)
}
}
cfg.SSH.AllowTCPForwarding = fc.SSH.AllowTCPForwarding()
return nil

View file

@ -48,10 +48,7 @@ const (
// one of many SSH nodes
SSHProxyListenPort = 3023
// When running in "SSH Proxy" role this port will be used for incoming
// connections from SSH nodes who wish to use "reverse tunnell" (when they
// run behind an environment/firewall which only allows outgoing connections)
SSHProxyTunnelListenPort = 3024
SSHProxyTunnelListenPort = defaults.SSHProxyTunnelListenPort
// KubeListenPort is a default port for kubernetes proxies
KubeListenPort = 3026
@ -498,9 +495,6 @@ var (
// the Teleport configuration file that tctl reads on use
ConfigFileEnvar = "TELEPORT_CONFIG_FILE"
// TunnelPublicAddrEnvar optionally specifies the alternative reverse tunnel address.
TunnelPublicAddrEnvar = "TELEPORT_TUNNEL_PUBLIC_ADDR"
// LicenseFile is the default name of the license file
LicenseFile = "license.pem"

View file

@ -541,14 +541,6 @@ type SSHConfig struct {
// RestrictedSession holds kernel objects restrictions for Teleport.
RestrictedSession *restricted.Config
// ProxyReverseTunnelFallbackAddr optionall specifies the address of the proxy if reverse tunnel
// discovered proxy fails.
// This configuration is not exposed directly but can be set from environment via
// defaults.ProxyFallbackAddrEnvar.
//
// See github.com/gravitational/teleport/issues/4141 for details.
ProxyReverseTunnelFallbackAddr *utils.NetAddr
// AllowTCPForwarding indicates that TCP port forwarding is allowed on this node
AllowTCPForwarding bool

View file

@ -18,9 +18,7 @@ package service
import (
"crypto/tls"
"net"
"path/filepath"
"strconv"
"time"
"golang.org/x/crypto/ssh"
@ -799,7 +797,7 @@ func (process *TeleportProcess) rotate(conn *Connector, localState auth.StateV2,
// newClient attempts to connect directly to the Auth Server. If it fails, it
// falls back to trying to connect to the Auth Server through the proxy.
// The proxy address might be configured in process environment as defaults.TunnelPublicAddrEnvar
// The proxy address might be configured in process environment as apidefaults.TunnelPublicAddrEnvar
// in which case, no attempt at discovering the reverse tunnel address is made.
func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity *auth.Identity) (*auth.Client, error) {
tlsConfig, err := identity.TLSConfig(process.Config.CipherSuites)
@ -807,8 +805,6 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity
return nil, trace.Wrap(err)
}
// Try and connect to the Auth Server. If the request fails, try and
// connect through a tunnel.
logger := process.log.WithField("auth-addrs", utils.NetAddrsToStrings(authServers))
logger.Debug("Attempting to connect to Auth Server directly.")
directClient, err := process.newClientDirect(authServers, tlsConfig)
@ -826,17 +822,11 @@ func (process *TeleportProcess) newClient(authServers []utils.NetAddr, identity
}
logger.Debug("Attempting to discover reverse tunnel address.")
var proxyAddr string
if process.Config.SSH.ProxyReverseTunnelFallbackAddr != nil {
proxyAddr = process.Config.SSH.ProxyReverseTunnelFallbackAddr.String()
} else {
// Discover address of SSH reverse tunnel server.
proxyAddr, err = process.findReverseTunnel(authServers)
if err != nil {
directErrLogger.Debug("Failed to connect to Auth Server directly.")
logger.WithError(err).Debug("Failed to discover reverse tunnel address.")
return nil, trace.Errorf("Failed to connect to Auth Server directly or over tunnel, no methods remaining.")
}
proxyAddr, err := process.findReverseTunnel(authServers)
if err != nil {
directErrLogger.Debug("Failed to connect to Auth Server directly.")
logger.WithError(err).Debug("Failed to discover reverse tunnel address.")
return nil, trace.Errorf("Failed to connect to Auth Server directly or over tunnel, no methods remaining.")
}
logger = process.log.WithField("proxy-addr", proxyAddr)
@ -859,58 +849,15 @@ func (process *TeleportProcess) findReverseTunnel(addrs []utils.NetAddr) (string
for _, addr := range addrs {
// In insecure mode, any certificate is accepted. In secure mode the hosts
// CAs are used to validate the certificate on the proxy.
resp, err := webclient.Find(process.ExitContext(),
addr.String(),
lib.IsInsecureDevMode(),
nil)
tunnelAddr, err := webclient.GetTunnelAddr(process.ExitContext(), addr.String(), lib.IsInsecureDevMode(), nil)
if err == nil {
return tunnelAddr(resp.Proxy)
return tunnelAddr, nil
}
errs = append(errs, err)
}
return "", trace.NewAggregate(errs...)
}
// tunnelAddr returns the tunnel address in the following preference order:
// 1. Reverse Tunnel Public Address.
// 2. SSH Proxy Public Address.
// 3. HTTP Proxy Public Address.
// 4. Tunnel Listen Address.
func tunnelAddr(settings webclient.ProxySettings) (string, error) {
// Extract the port the tunnel server is listening on.
netAddr, err := utils.ParseHostPortAddr(settings.SSH.TunnelListenAddr, defaults.SSHProxyTunnelListenPort)
if err != nil {
return "", trace.Wrap(err)
}
tunnelPort := netAddr.Port(defaults.SSHProxyTunnelListenPort)
// If a tunnel public address is set, nothing else has to be done, return it.
if settings.SSH.TunnelPublicAddr != "" {
return settings.SSH.TunnelPublicAddr, nil
}
// If a tunnel public address has not been set, but a related HTTP or SSH
// public address has been set, extract the hostname but use the port from
// the tunnel listen address.
if settings.SSH.SSHPublicAddr != "" {
addr, err := utils.ParseHostPortAddr(settings.SSH.SSHPublicAddr, tunnelPort)
if err != nil {
return "", trace.Wrap(err)
}
return net.JoinHostPort(addr.Host(), strconv.Itoa(tunnelPort)), nil
}
if settings.SSH.PublicAddr != "" {
addr, err := utils.ParseHostPortAddr(settings.SSH.PublicAddr, tunnelPort)
if err != nil {
return "", trace.Wrap(err)
}
return net.JoinHostPort(addr.Host(), strconv.Itoa(tunnelPort)), nil
}
// If nothing is set, fallback to the tunnel listen address.
return settings.SSH.TunnelListenAddr, nil
}
func (process *TeleportProcess) newClientThroughTunnel(proxyAddr string, tlsConfig *tls.Config, sshConfig *ssh.ClientConfig) (*auth.Client, error) {
clt, err := auth.NewClient(apiclient.Config{
Dialer: &reversetunnel.TunnelAuthDialer{

View file

@ -20,10 +20,8 @@ import (
"context"
"crypto/tls"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"github.com/gravitational/teleport"
apiclient "github.com/gravitational/teleport/api/client"
@ -262,55 +260,15 @@ func findReverseTunnel(ctx context.Context, addrs []utils.NetAddr, insecureTLS b
for _, addr := range addrs {
// In insecure mode, any certificate is accepted. In secure mode the hosts
// CAs are used to validate the certificate on the proxy.
resp, err := webclient.Find(ctx, addr.String(), insecureTLS, nil)
tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil)
if err == nil {
return tunnelAddr(addr, resp.Proxy)
return tunnelAddr, nil
}
errs = append(errs, err)
}
return "", trace.NewAggregate(errs...)
}
// tunnelAddr returns the tunnel address in the following preference order:
// 1. Reverse Tunnel Public Address.
// 2. SSH Proxy Public Address.
// 3. HTTP Proxy Public Address.
// 4. Tunnel Listen Address.
func tunnelAddr(webAddr utils.NetAddr, settings webclient.ProxySettings) (string, error) {
// Extract the port the tunnel server is listening on.
netAddr, err := utils.ParseHostPortAddr(settings.SSH.TunnelListenAddr, defaults.SSHProxyTunnelListenPort)
if err != nil {
return "", trace.Wrap(err)
}
tunnelPort := netAddr.Port(defaults.SSHProxyTunnelListenPort)
// If a tunnel public address is set, nothing else has to be done, return it.
if settings.SSH.TunnelPublicAddr != "" {
return settings.SSH.TunnelPublicAddr, nil
}
// If a tunnel public address has not been set, but a related HTTP or SSH
// public address has been set, extract the hostname but use the port from
// the tunnel listen address.
if settings.SSH.SSHPublicAddr != "" {
addr, err := utils.ParseHostPortAddr(settings.SSH.SSHPublicAddr, tunnelPort)
if err != nil {
return "", trace.Wrap(err)
}
return net.JoinHostPort(addr.Host(), strconv.Itoa(tunnelPort)), nil
}
if settings.SSH.PublicAddr != "" {
addr, err := utils.ParseHostPortAddr(settings.SSH.PublicAddr, tunnelPort)
if err != nil {
return "", trace.Wrap(err)
}
return net.JoinHostPort(addr.Host(), strconv.Itoa(tunnelPort)), nil
}
// If nothing is set, fallback to the address we dialed.
return net.JoinHostPort(webAddr.Host(), strconv.Itoa(tunnelPort)), nil
}
// applyConfig takes configuration values from the config file and applies
// them to 'service.Config' object.
//