Adds custom timeout message to SSH sessions (#7120)

* Adds the idle_timeout_message to the auth_service config file block
* Plumbs the value through to the session monitor
* Writes the message to stderr when a session times out due to inactivity
* Adds some machinery to the test helpers to configure appropriate tests

See-Also: #6091
This commit is contained in:
Trent Clarke 2021-06-25 14:12:50 +10:00 committed by GitHub
parent 59d39dee5a
commit ca1e47bef0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 830 additions and 557 deletions

View file

@ -56,6 +56,15 @@ type ClusterNetworkingConfig interface {
// SetSessionControlTimeout sets the session control timeout.
SetSessionControlTimeout(t time.Duration)
// GetClientIdleTimeoutMessage fetches the message to be sent to the client in
// the event of an idle timeout. An empty string implies no message should
// be sent.
GetClientIdleTimeoutMessage() string
// SetClientIdleTimeoutMessage sets the inactivity timeout disconnection message
// to be sent to the user.
SetClientIdleTimeoutMessage(string)
}
// NewClusterNetworkingConfigFromConfigFile is a convenience method to create
@ -196,6 +205,14 @@ func (c *ClusterNetworkingConfigV2) SetSessionControlTimeout(d time.Duration) {
c.Spec.SessionControlTimeout = Duration(d)
}
func (c *ClusterNetworkingConfigV2) GetClientIdleTimeoutMessage() string {
return c.Spec.ClientIdleTimeoutMessage
}
func (c *ClusterNetworkingConfigV2) SetClientIdleTimeoutMessage(msg string) {
c.Spec.ClientIdleTimeoutMessage = msg
}
// setStaticFields sets static resource header and metadata fields.
func (c *ClusterNetworkingConfigV2) setStaticFields() {
c.Kind = KindClusterNetworkingConfig

File diff suppressed because it is too large Load diff

View file

@ -688,6 +688,9 @@ message ClusterNetworkingConfigSpecV2 {
// server before it begins terminating controlled sessions.
int64 SessionControlTimeout = 4
[ (gogoproto.jsontag) = "session_control_timeout", (gogoproto.casttype) = "Duration" ];
// ClientIdleTimeoutMessage is the message sent to the user when a connection times out.
string ClientIdleTimeoutMessage = 5 [ (gogoproto.jsontag) = "idle_timeout_message" ];
}
// SessionRecordingConfigV2 contains session recording configuration.

View file

@ -270,6 +270,11 @@ auth_service:
# Examples: "30m", "1h" or "1h30m"
client_idle_timeout: never
# Send a custom message to the client when they are disconnected due to
# inactivity. The empty string indicates that no message will be sent.
# (Currently only supported for SSH connections)
client_idle_timeout_message: ""
# Determines if the clients will be forcefully disconnected when their
# certificates expire in the middle of an active SSH session. (default is 'no')
disconnect_expired_cert: no
@ -298,6 +303,7 @@ auth_service:
# the configured `data_dir` .
license_file: /var/lib/teleport/license.pem
# This section configures the 'node service':
ssh_service:
# Turns 'ssh' role on. Default is 'yes'

View file

@ -58,6 +58,9 @@ type TestAuthServerConfig struct {
CipherSuites []uint16
// Clock is used to control time in tests.
Clock clockwork.FakeClock
// ClusterNetworkingConfig allows a test to change the default
// networking configuration.
ClusterNetworkingConfig types.ClusterNetworkingConfig
}
// CheckAndSetDefaults checks and sets defaults
@ -240,7 +243,12 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
return nil, trace.Wrap(err)
}
err = srv.AuthServer.SetClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig())
clusterNetworkingCfg := cfg.ClusterNetworkingConfig
if clusterNetworkingCfg == nil {
clusterNetworkingCfg = types.DefaultClusterNetworkingConfig()
}
err = srv.AuthServer.SetClusterNetworkingConfig(ctx, clusterNetworkingCfg)
if err != nil {
return nil, trace.Wrap(err)
}
@ -759,7 +767,7 @@ func (t *TestTLSServer) Addr() net.Addr {
return t.Listener.Addr()
}
// Start starts TLS server on loopback address on the first lisenting socket
// Start starts TLS server on loopback address on the first listening socket
func (t *TestTLSServer) Start() error {
go t.TLSServer.Serve()
return nil

View file

@ -502,10 +502,11 @@ func applyAuthConfig(fc *FileConfig, cfg *service.Config) error {
// Set cluster networking configuration from file configuration.
cfg.Auth.NetworkingConfig, err = types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
KeepAliveInterval: fc.Auth.KeepAliveInterval,
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
SessionControlTimeout: fc.Auth.SessionControlTimeout,
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
ClientIdleTimeoutMessage: fc.Auth.ClientIdleTimeoutMessage,
KeepAliveInterval: fc.Auth.KeepAliveInterval,
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
SessionControlTimeout: fc.Auth.SessionControlTimeout,
})
if err != nil {
return trace.Wrap(err)

View file

@ -463,6 +463,11 @@ type Auth struct {
// KeepAliveCountMax set the number of keep-alive messages that can be
// missed before the server disconnects the client.
KeepAliveCountMax int64 `yaml:"keep_alive_count_max,omitempty"`
// ClientIdleTimeoutMessage is sent to the client when the inactivity timeout
// expires. The empty string implies no message should be sent prior to
// disconnection.
ClientIdleTimeoutMessage string `yaml:"client_idle_timeout_message,omitempty"`
}
// TrustedCluster struct holds configuration values under "trusted_clusters" key

View file

@ -115,6 +115,12 @@ fake certificate
// mutated programatically by test cases and then re-serialised to test the
// config file loader
const minimalConfigFile string = `
teleport:
nodename: testing
auth_service:
enabled: yes
ssh_service:
enabled: yes
`
@ -138,6 +144,73 @@ func editConfig(t *testing.T, mutate func(cfg cfgMap)) []byte {
return text
}
// requireEqual creates an assertion function with a bound `expected` value
// for use with table-driven tests
func requireEqual(expected interface{}) require.ValueAssertionFunc {
return func(t require.TestingT, actual interface{}, msgAndArgs ...interface{}) {
require.Equal(t, expected, actual, msgAndArgs...)
}
}
// TestAuthSection tests the config parser for the `auth_service` config block
func TestAuthSection(t *testing.T) {
t.Parallel()
testCases := []struct {
desc string
mutate func(cfgMap)
expectError require.ErrorAssertionFunc
expectEnabled require.BoolAssertionFunc
expectIdleMsg require.ValueAssertionFunc
}{
{
desc: "Default",
mutate: func(cfg cfgMap) {},
expectError: require.NoError,
expectEnabled: require.True,
expectIdleMsg: require.Empty,
}, {
desc: "Enabled",
mutate: func(cfg cfgMap) {
cfg["auth_service"].(cfgMap)["enabled"] = "yes"
},
expectError: require.NoError,
expectEnabled: require.True,
}, {
desc: "Disabled",
mutate: func(cfg cfgMap) {
cfg["auth_service"].(cfgMap)["enabled"] = "no"
},
expectError: require.NoError,
expectEnabled: require.False,
}, {
desc: "Idle timeout message",
mutate: func(cfg cfgMap) {
cfg["auth_service"].(cfgMap)["client_idle_timeout_message"] = "Are you pondering what I'm pondering?"
},
expectError: require.NoError,
expectIdleMsg: requireEqual("Are you pondering what I'm pondering?"),
},
}
for _, tt := range testCases {
t.Run(tt.desc, func(t *testing.T) {
text := bytes.NewBuffer(editConfig(t, tt.mutate))
cfg, err := ReadConfig(text)
tt.expectError(t, err)
if tt.expectEnabled != nil {
tt.expectEnabled(t, cfg.Auth.Enabled())
}
if tt.expectIdleMsg != nil {
tt.expectIdleMsg(t, cfg.Auth.ClientIdleTimeoutMessage)
}
})
}
}
// TestSSHSection tests the config parser for the SSH config block
func TestSSHSection(t *testing.T) {
t.Parallel()

View file

@ -543,6 +543,11 @@ type SSHConfig struct {
// AllowTCPForwarding indicates that TCP port forwarding is allowed on this node
AllowTCPForwarding bool
// IdleTimeoutMessage is sent to the client when a session expires due to
// the inactivity timeout expiring. The empty string indicates that no
// timeout message will be sent.
IdleTimeoutMessage string
}
// KubeConfig specifies configuration for kubernetes service

View file

@ -277,6 +277,11 @@ type ServerContext struct {
// port to connect to in a "direct-tcpip" request. This value is only
// populated for port forwarding requests.
DstAddr string
// Monitor is a handle to the idle timeout monitor that is watching this
// session context. May be nil if there is no set idle timeout or we are
// not monitoring certificate expiry.
Monitor *Monitor
}
// NewServerContext creates a new *ServerContext which is used to pass and
@ -340,7 +345,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
})
if !child.disconnectExpiredCert.IsZero() || child.clientIdleTimeout != 0 {
mon, err := NewMonitor(MonitorConfig{
child.Monitor, err = NewMonitor(MonitorConfig{
DisconnectExpiredCert: child.disconnectExpiredCert,
ClientIdleTimeout: child.clientIdleTimeout,
Clock: child.srv.GetClock(),
@ -357,7 +362,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
child.Close()
return nil, nil, trace.Wrap(err)
}
go mon.Start()
go child.Monitor.Start()
}
// Create pipe used to send command to child process.

View file

@ -19,6 +19,7 @@ package srv
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
@ -78,6 +79,8 @@ type MonitorConfig struct {
Emitter apievents.Emitter
// Entry is a logging entry
Entry log.FieldLogger
// A message sent to the client when the idle timeout expires
IdleTimeoutMessage string
}
// CheckAndSetDefaults checks values and sets defaults
@ -116,12 +119,16 @@ func NewMonitor(cfg MonitorConfig) (*Monitor, error) {
}, nil
}
// Monitor monitors connection activity
// and disconnects connections with expired certificates
// or with periods of inactivity
// Monitor monitors the activity on a single connection and disconnects
// that connection if the certificate expires or after
// periods of inactivity
type Monitor struct {
// MonitorConfig is a connection monitor configuration
MonitorConfig
// MessageWriter wraps a channel to send text messages to the client. Use
// for disconnection messages, etc.
MessageWriter io.StringWriter
}
// Start starts monitoring connection
@ -194,6 +201,12 @@ func (w *Monitor) Start() {
now.Sub(clientLastActive), w.ClientIdleTimeout)
}
w.Entry.Debugf("Disconnecting client: %v", event.Reason)
if w.MessageWriter != nil && w.IdleTimeoutMessage != "" {
if _, err := w.MessageWriter.WriteString(w.IdleTimeoutMessage); err != nil {
w.Entry.WithError(err).Warn("Failed to send idle timeout message.")
}
}
w.Conn.Close()
if err := w.Emitter.EmitAuditEvent(w.Context, event); err != nil {

View file

@ -1097,6 +1097,17 @@ func (s *Server) canPortForward(scx *srv.ServerContext, channel ssh.Channel) err
return nil
}
// stderrWriter wraps an ssh.Channel in an implementation of io.StringWriter
// that sends anything written back the client over its stderr stream
type stderrWriter struct {
channel ssh.Channel
}
func (w *stderrWriter) WriteString(s string) (int, error) {
writeStderr(w.channel, s)
return len(s), nil
}
// handleDirectTCPIPRequest handles port forwarding requests.
func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) {
// Create context for this channel. This context will be closed when
@ -1250,6 +1261,11 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec
return
}
if scx.Monitor != nil {
scx.Monitor.IdleTimeoutMessage = netConfig.GetClientIdleTimeoutMessage()
scx.Monitor.MessageWriter = &stderrWriter{channel: ch}
}
// The keep-alive loop will keep pinging the remote server and after it has
// missed a certain number of keep-alive requests it will cancel the
// closeContext which signals the server to shutdown.

View file

@ -257,6 +257,80 @@ func newNodeClient(t *testing.T, testSvr *auth.TestServer) (*auth.Client, string
const hostID = "00000000-0000-0000-0000-000000000000"
func startReadAll(r io.Reader) <-chan []byte {
ch := make(chan []byte)
go func() {
data, _ := ioutil.ReadAll(r)
ch <- data
}()
return ch
}
func waitForBytes(ch <-chan []byte) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case data := <-ch:
return data, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func TestInactivityTimeout(t *testing.T) {
const timeoutMessage = "You snooze, you loose."
// Given
// * a running auth server configured with a 5s inactivity timeout,
// * a running SSH server configured with a given disconnection message
// * a client connected to the SSH server,
// * an SSH session running over the client connection
mutateCfg := func(cfg *auth.TestServerConfig) {
networkCfg := types.DefaultClusterNetworkingConfig()
networkCfg.SetClientIdleTimeout(5 * time.Second)
networkCfg.SetClientIdleTimeoutMessage(timeoutMessage)
cfg.Auth.ClusterNetworkingConfig = networkCfg
}
f := newCustomFixture(t, mutateCfg)
// If all goes well, the client will be closed by the time cleanup happens,
// so change the assertion on closing the client to expect it to fail
f.ssh.assertCltClose = require.Error
se, err := f.ssh.clt.NewSession()
require.NoError(t, err)
defer se.Close()
stderr, err := se.StderrPipe()
require.NoError(t, err)
stdErrCh := startReadAll(stderr)
endCh := make(chan error)
go func() { endCh <- f.ssh.clt.Wait() }()
// When I let the session idle (with the clock running at approx 10x speed)...
sessionHasFinished := func() bool {
f.clock.Advance(1 * time.Second)
select {
case <-endCh:
return true
default:
return false
}
}
require.Eventually(t, sessionHasFinished, 6*time.Second, 100*time.Millisecond,
"Timed out waiting for session to finish")
// Expect that the idle timeout has been delivered via stderr
text, err := waitForBytes(stdErrCh)
require.NoError(t, err)
require.Equal(t, timeoutMessage, string(text))
}
// TestDirectTCPIP ensures that the server can create a "direct-tcpip"
// channel to the target address. The "direct-tcpip" channel is what port
// forwarding is built upon.