mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 17:53:28 +00:00
Adds custom timeout message to SSH sessions (#7120)
* Adds the idle_timeout_message to the auth_service config file block * Plumbs the value through to the session monitor * Writes the message to stderr when a session times out due to inactivity * Adds some machinery to the test helpers to configure appropriate tests See-Also: #6091
This commit is contained in:
parent
59d39dee5a
commit
ca1e47bef0
|
@ -56,6 +56,15 @@ type ClusterNetworkingConfig interface {
|
|||
|
||||
// SetSessionControlTimeout sets the session control timeout.
|
||||
SetSessionControlTimeout(t time.Duration)
|
||||
|
||||
// GetClientIdleTimeoutMessage fetches the message to be sent to the client in
|
||||
// the event of an idle timeout. An empty string implies no message should
|
||||
// be sent.
|
||||
GetClientIdleTimeoutMessage() string
|
||||
|
||||
// SetClientIdleTimeoutMessage sets the inactivity timeout disconnection message
|
||||
// to be sent to the user.
|
||||
SetClientIdleTimeoutMessage(string)
|
||||
}
|
||||
|
||||
// NewClusterNetworkingConfigFromConfigFile is a convenience method to create
|
||||
|
@ -196,6 +205,14 @@ func (c *ClusterNetworkingConfigV2) SetSessionControlTimeout(d time.Duration) {
|
|||
c.Spec.SessionControlTimeout = Duration(d)
|
||||
}
|
||||
|
||||
func (c *ClusterNetworkingConfigV2) GetClientIdleTimeoutMessage() string {
|
||||
return c.Spec.ClientIdleTimeoutMessage
|
||||
}
|
||||
|
||||
func (c *ClusterNetworkingConfigV2) SetClientIdleTimeoutMessage(msg string) {
|
||||
c.Spec.ClientIdleTimeoutMessage = msg
|
||||
}
|
||||
|
||||
// setStaticFields sets static resource header and metadata fields.
|
||||
func (c *ClusterNetworkingConfigV2) setStaticFields() {
|
||||
c.Kind = KindClusterNetworkingConfig
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -688,6 +688,9 @@ message ClusterNetworkingConfigSpecV2 {
|
|||
// server before it begins terminating controlled sessions.
|
||||
int64 SessionControlTimeout = 4
|
||||
[ (gogoproto.jsontag) = "session_control_timeout", (gogoproto.casttype) = "Duration" ];
|
||||
|
||||
// ClientIdleTimeoutMessage is the message sent to the user when a connection times out.
|
||||
string ClientIdleTimeoutMessage = 5 [ (gogoproto.jsontag) = "idle_timeout_message" ];
|
||||
}
|
||||
|
||||
// SessionRecordingConfigV2 contains session recording configuration.
|
||||
|
|
|
@ -270,6 +270,11 @@ auth_service:
|
|||
# Examples: "30m", "1h" or "1h30m"
|
||||
client_idle_timeout: never
|
||||
|
||||
# Send a custom message to the client when they are disconnected due to
|
||||
# inactivity. The empty string indicates that no message will be sent.
|
||||
# (Currently only supported for SSH connections)
|
||||
client_idle_timeout_message: ""
|
||||
|
||||
# Determines if the clients will be forcefully disconnected when their
|
||||
# certificates expire in the middle of an active SSH session. (default is 'no')
|
||||
disconnect_expired_cert: no
|
||||
|
@ -298,6 +303,7 @@ auth_service:
|
|||
# the configured `data_dir` .
|
||||
license_file: /var/lib/teleport/license.pem
|
||||
|
||||
|
||||
# This section configures the 'node service':
|
||||
ssh_service:
|
||||
# Turns 'ssh' role on. Default is 'yes'
|
||||
|
|
|
@ -58,6 +58,9 @@ type TestAuthServerConfig struct {
|
|||
CipherSuites []uint16
|
||||
// Clock is used to control time in tests.
|
||||
Clock clockwork.FakeClock
|
||||
// ClusterNetworkingConfig allows a test to change the default
|
||||
// networking configuration.
|
||||
ClusterNetworkingConfig types.ClusterNetworkingConfig
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults checks and sets defaults
|
||||
|
@ -240,7 +243,12 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
|
|||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = srv.AuthServer.SetClusterNetworkingConfig(ctx, types.DefaultClusterNetworkingConfig())
|
||||
clusterNetworkingCfg := cfg.ClusterNetworkingConfig
|
||||
if clusterNetworkingCfg == nil {
|
||||
clusterNetworkingCfg = types.DefaultClusterNetworkingConfig()
|
||||
}
|
||||
|
||||
err = srv.AuthServer.SetClusterNetworkingConfig(ctx, clusterNetworkingCfg)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -759,7 +767,7 @@ func (t *TestTLSServer) Addr() net.Addr {
|
|||
return t.Listener.Addr()
|
||||
}
|
||||
|
||||
// Start starts TLS server on loopback address on the first lisenting socket
|
||||
// Start starts TLS server on loopback address on the first listening socket
|
||||
func (t *TestTLSServer) Start() error {
|
||||
go t.TLSServer.Serve()
|
||||
return nil
|
||||
|
|
|
@ -502,10 +502,11 @@ func applyAuthConfig(fc *FileConfig, cfg *service.Config) error {
|
|||
|
||||
// Set cluster networking configuration from file configuration.
|
||||
cfg.Auth.NetworkingConfig, err = types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{
|
||||
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
|
||||
KeepAliveInterval: fc.Auth.KeepAliveInterval,
|
||||
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
|
||||
SessionControlTimeout: fc.Auth.SessionControlTimeout,
|
||||
ClientIdleTimeout: fc.Auth.ClientIdleTimeout,
|
||||
ClientIdleTimeoutMessage: fc.Auth.ClientIdleTimeoutMessage,
|
||||
KeepAliveInterval: fc.Auth.KeepAliveInterval,
|
||||
KeepAliveCountMax: fc.Auth.KeepAliveCountMax,
|
||||
SessionControlTimeout: fc.Auth.SessionControlTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
|
|
@ -463,6 +463,11 @@ type Auth struct {
|
|||
// KeepAliveCountMax set the number of keep-alive messages that can be
|
||||
// missed before the server disconnects the client.
|
||||
KeepAliveCountMax int64 `yaml:"keep_alive_count_max,omitempty"`
|
||||
|
||||
// ClientIdleTimeoutMessage is sent to the client when the inactivity timeout
|
||||
// expires. The empty string implies no message should be sent prior to
|
||||
// disconnection.
|
||||
ClientIdleTimeoutMessage string `yaml:"client_idle_timeout_message,omitempty"`
|
||||
}
|
||||
|
||||
// TrustedCluster struct holds configuration values under "trusted_clusters" key
|
||||
|
|
|
@ -115,6 +115,12 @@ fake certificate
|
|||
// mutated programatically by test cases and then re-serialised to test the
|
||||
// config file loader
|
||||
const minimalConfigFile string = `
|
||||
teleport:
|
||||
nodename: testing
|
||||
|
||||
auth_service:
|
||||
enabled: yes
|
||||
|
||||
ssh_service:
|
||||
enabled: yes
|
||||
`
|
||||
|
@ -138,6 +144,73 @@ func editConfig(t *testing.T, mutate func(cfg cfgMap)) []byte {
|
|||
return text
|
||||
}
|
||||
|
||||
// requireEqual creates an assertion function with a bound `expected` value
|
||||
// for use with table-driven tests
|
||||
func requireEqual(expected interface{}) require.ValueAssertionFunc {
|
||||
return func(t require.TestingT, actual interface{}, msgAndArgs ...interface{}) {
|
||||
require.Equal(t, expected, actual, msgAndArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthSection tests the config parser for the `auth_service` config block
|
||||
func TestAuthSection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
mutate func(cfgMap)
|
||||
expectError require.ErrorAssertionFunc
|
||||
expectEnabled require.BoolAssertionFunc
|
||||
expectIdleMsg require.ValueAssertionFunc
|
||||
}{
|
||||
{
|
||||
desc: "Default",
|
||||
mutate: func(cfg cfgMap) {},
|
||||
expectError: require.NoError,
|
||||
expectEnabled: require.True,
|
||||
expectIdleMsg: require.Empty,
|
||||
}, {
|
||||
desc: "Enabled",
|
||||
mutate: func(cfg cfgMap) {
|
||||
cfg["auth_service"].(cfgMap)["enabled"] = "yes"
|
||||
},
|
||||
expectError: require.NoError,
|
||||
expectEnabled: require.True,
|
||||
}, {
|
||||
desc: "Disabled",
|
||||
mutate: func(cfg cfgMap) {
|
||||
cfg["auth_service"].(cfgMap)["enabled"] = "no"
|
||||
},
|
||||
expectError: require.NoError,
|
||||
expectEnabled: require.False,
|
||||
}, {
|
||||
desc: "Idle timeout message",
|
||||
mutate: func(cfg cfgMap) {
|
||||
cfg["auth_service"].(cfgMap)["client_idle_timeout_message"] = "Are you pondering what I'm pondering?"
|
||||
},
|
||||
expectError: require.NoError,
|
||||
expectIdleMsg: requireEqual("Are you pondering what I'm pondering?"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
t.Run(tt.desc, func(t *testing.T) {
|
||||
text := bytes.NewBuffer(editConfig(t, tt.mutate))
|
||||
|
||||
cfg, err := ReadConfig(text)
|
||||
tt.expectError(t, err)
|
||||
|
||||
if tt.expectEnabled != nil {
|
||||
tt.expectEnabled(t, cfg.Auth.Enabled())
|
||||
}
|
||||
|
||||
if tt.expectIdleMsg != nil {
|
||||
tt.expectIdleMsg(t, cfg.Auth.ClientIdleTimeoutMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHSection tests the config parser for the SSH config block
|
||||
func TestSSHSection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -543,6 +543,11 @@ type SSHConfig struct {
|
|||
|
||||
// AllowTCPForwarding indicates that TCP port forwarding is allowed on this node
|
||||
AllowTCPForwarding bool
|
||||
|
||||
// IdleTimeoutMessage is sent to the client when a session expires due to
|
||||
// the inactivity timeout expiring. The empty string indicates that no
|
||||
// timeout message will be sent.
|
||||
IdleTimeoutMessage string
|
||||
}
|
||||
|
||||
// KubeConfig specifies configuration for kubernetes service
|
||||
|
|
|
@ -277,6 +277,11 @@ type ServerContext struct {
|
|||
// port to connect to in a "direct-tcpip" request. This value is only
|
||||
// populated for port forwarding requests.
|
||||
DstAddr string
|
||||
|
||||
// Monitor is a handle to the idle timeout monitor that is watching this
|
||||
// session context. May be nil if there is no set idle timeout or we are
|
||||
// not monitoring certificate expiry.
|
||||
Monitor *Monitor
|
||||
}
|
||||
|
||||
// NewServerContext creates a new *ServerContext which is used to pass and
|
||||
|
@ -340,7 +345,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
|
|||
})
|
||||
|
||||
if !child.disconnectExpiredCert.IsZero() || child.clientIdleTimeout != 0 {
|
||||
mon, err := NewMonitor(MonitorConfig{
|
||||
child.Monitor, err = NewMonitor(MonitorConfig{
|
||||
DisconnectExpiredCert: child.disconnectExpiredCert,
|
||||
ClientIdleTimeout: child.clientIdleTimeout,
|
||||
Clock: child.srv.GetClock(),
|
||||
|
@ -357,7 +362,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
|
|||
child.Close()
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
go mon.Start()
|
||||
go child.Monitor.Start()
|
||||
}
|
||||
|
||||
// Create pipe used to send command to child process.
|
||||
|
|
|
@ -19,6 +19,7 @@ package srv
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -78,6 +79,8 @@ type MonitorConfig struct {
|
|||
Emitter apievents.Emitter
|
||||
// Entry is a logging entry
|
||||
Entry log.FieldLogger
|
||||
// A message sent to the client when the idle timeout expires
|
||||
IdleTimeoutMessage string
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults checks values and sets defaults
|
||||
|
@ -116,12 +119,16 @@ func NewMonitor(cfg MonitorConfig) (*Monitor, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Monitor monitors connection activity
|
||||
// and disconnects connections with expired certificates
|
||||
// or with periods of inactivity
|
||||
// Monitor monitors the activity on a single connection and disconnects
|
||||
// that connection if the certificate expires or after
|
||||
// periods of inactivity
|
||||
type Monitor struct {
|
||||
// MonitorConfig is a connection monitor configuration
|
||||
MonitorConfig
|
||||
|
||||
// MessageWriter wraps a channel to send text messages to the client. Use
|
||||
// for disconnection messages, etc.
|
||||
MessageWriter io.StringWriter
|
||||
}
|
||||
|
||||
// Start starts monitoring connection
|
||||
|
@ -194,6 +201,12 @@ func (w *Monitor) Start() {
|
|||
now.Sub(clientLastActive), w.ClientIdleTimeout)
|
||||
}
|
||||
w.Entry.Debugf("Disconnecting client: %v", event.Reason)
|
||||
|
||||
if w.MessageWriter != nil && w.IdleTimeoutMessage != "" {
|
||||
if _, err := w.MessageWriter.WriteString(w.IdleTimeoutMessage); err != nil {
|
||||
w.Entry.WithError(err).Warn("Failed to send idle timeout message.")
|
||||
}
|
||||
}
|
||||
w.Conn.Close()
|
||||
|
||||
if err := w.Emitter.EmitAuditEvent(w.Context, event); err != nil {
|
||||
|
|
|
@ -1097,6 +1097,17 @@ func (s *Server) canPortForward(scx *srv.ServerContext, channel ssh.Channel) err
|
|||
return nil
|
||||
}
|
||||
|
||||
// stderrWriter wraps an ssh.Channel in an implementation of io.StringWriter
|
||||
// that sends anything written back the client over its stderr stream
|
||||
type stderrWriter struct {
|
||||
channel ssh.Channel
|
||||
}
|
||||
|
||||
func (w *stderrWriter) WriteString(s string) (int, error) {
|
||||
writeStderr(w.channel, s)
|
||||
return len(s), nil
|
||||
}
|
||||
|
||||
// handleDirectTCPIPRequest handles port forwarding requests.
|
||||
func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ccx *sshutils.ConnectionContext, identityContext srv.IdentityContext, channel ssh.Channel, req *sshutils.DirectTCPIPReq) {
|
||||
// Create context for this channel. This context will be closed when
|
||||
|
@ -1250,6 +1261,11 @@ func (s *Server) handleSessionRequests(ctx context.Context, ccx *sshutils.Connec
|
|||
return
|
||||
}
|
||||
|
||||
if scx.Monitor != nil {
|
||||
scx.Monitor.IdleTimeoutMessage = netConfig.GetClientIdleTimeoutMessage()
|
||||
scx.Monitor.MessageWriter = &stderrWriter{channel: ch}
|
||||
}
|
||||
|
||||
// The keep-alive loop will keep pinging the remote server and after it has
|
||||
// missed a certain number of keep-alive requests it will cancel the
|
||||
// closeContext which signals the server to shutdown.
|
||||
|
|
|
@ -257,6 +257,80 @@ func newNodeClient(t *testing.T, testSvr *auth.TestServer) (*auth.Client, string
|
|||
|
||||
const hostID = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
func startReadAll(r io.Reader) <-chan []byte {
|
||||
ch := make(chan []byte)
|
||||
go func() {
|
||||
data, _ := ioutil.ReadAll(r)
|
||||
ch <- data
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func waitForBytes(ch <-chan []byte) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case data := <-ch:
|
||||
return data, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func TestInactivityTimeout(t *testing.T) {
|
||||
const timeoutMessage = "You snooze, you loose."
|
||||
|
||||
// Given
|
||||
// * a running auth server configured with a 5s inactivity timeout,
|
||||
// * a running SSH server configured with a given disconnection message
|
||||
// * a client connected to the SSH server,
|
||||
// * an SSH session running over the client connection
|
||||
mutateCfg := func(cfg *auth.TestServerConfig) {
|
||||
networkCfg := types.DefaultClusterNetworkingConfig()
|
||||
networkCfg.SetClientIdleTimeout(5 * time.Second)
|
||||
networkCfg.SetClientIdleTimeoutMessage(timeoutMessage)
|
||||
|
||||
cfg.Auth.ClusterNetworkingConfig = networkCfg
|
||||
}
|
||||
f := newCustomFixture(t, mutateCfg)
|
||||
|
||||
// If all goes well, the client will be closed by the time cleanup happens,
|
||||
// so change the assertion on closing the client to expect it to fail
|
||||
f.ssh.assertCltClose = require.Error
|
||||
|
||||
se, err := f.ssh.clt.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer se.Close()
|
||||
|
||||
stderr, err := se.StderrPipe()
|
||||
require.NoError(t, err)
|
||||
stdErrCh := startReadAll(stderr)
|
||||
|
||||
endCh := make(chan error)
|
||||
go func() { endCh <- f.ssh.clt.Wait() }()
|
||||
|
||||
// When I let the session idle (with the clock running at approx 10x speed)...
|
||||
sessionHasFinished := func() bool {
|
||||
f.clock.Advance(1 * time.Second)
|
||||
select {
|
||||
case <-endCh:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
require.Eventually(t, sessionHasFinished, 6*time.Second, 100*time.Millisecond,
|
||||
"Timed out waiting for session to finish")
|
||||
|
||||
// Expect that the idle timeout has been delivered via stderr
|
||||
text, err := waitForBytes(stdErrCh)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, timeoutMessage, string(text))
|
||||
}
|
||||
|
||||
// TestDirectTCPIP ensures that the server can create a "direct-tcpip"
|
||||
// channel to the target address. The "direct-tcpip" channel is what port
|
||||
// forwarding is built upon.
|
||||
|
|
Loading…
Reference in a new issue