teleport/lib/sshutils/server.go
rosstimothy 22f28130ab
Prevent zombie sessions being left behind for web sessions (#32141)
The ssh session was not being closed for web sessions, which resulted
in zombie sessions being left around until the ssh service was
restarted. TestTerminal was updated to assert that the session
tracker eventually transitions to the terminated state when the
client terminates the web socket.

Fixes #32120
2023-09-20 13:45:55 +00:00

/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package sshutils contains the implementation of the base SSH
// server used throughout Teleport.
package sshutils
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/observability/tracing"
tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/srv/ingress"
"github.com/gravitational/teleport/lib/utils"
)
var proxyConnectionLimitHitCount = prometheus.NewCounter(
prometheus.CounterOpts{
Name: teleport.MetricProxyConnectionLimitHit,
Help: "Number of times the proxy connection limit was exceeded",
},
)
// Server is a generic implementation of an SSH server. All Teleport
// services (auth, proxy, ssh) use this as a base to accept SSH connections.
type Server struct {
sync.RWMutex
log logrus.FieldLogger
// component is the name of the facility which uses this server,
// used for logging/debugging. Typically it's "proxy" or "auth api", etc.
component string
// addr is the address this server binds to and listens on
addr utils.NetAddr
// listener is usually the listening TCP/IP socket
listener net.Listener
newChanHandler NewChanHandler
reqHandler RequestHandler
newConnHandler NewConnHandler
cfg ssh.ServerConfig
limiter *limiter.Limiter
closeContext context.Context
closeFunc context.CancelFunc
// userConns tracks the number of active connections with user certificates.
userConns int32
// shutdownPollPeriod sets the polling period for graceful shutdown
shutdownPollPeriod time.Duration
// insecureSkipHostValidation skips validating that the host signers
// are valid certificates. Used in tests.
insecureSkipHostValidation bool
// fips means Teleport started in a FedRAMP/FIPS 140-2 compliant
// configuration.
fips bool
// tracerProvider is used to create tracers capable
// of starting spans.
tracerProvider oteltrace.TracerProvider
// clock is used to control time.
clock clockwork.Clock
clusterName string
// ingressReporter reports new and active connections.
ingressReporter *ingress.Reporter
// ingressService is the service name passed to the ingress reporter.
ingressService string
}
const (
// SSHVersionPrefix is the prefix of the "server version" string which
// begins every SSH handshake. It MUST start with "SSH-2.0" according to
// https://tools.ietf.org/html/rfc4253#page-4
SSHVersionPrefix = "SSH-2.0-Teleport"
// MaxVersionStringBytes is the maximum number of bytes allowed for an
// SSH version string
// https://tools.ietf.org/html/rfc4253
MaxVersionStringBytes = 255
)
// ServerOption is a functional argument for server
type ServerOption func(cfg *Server) error
// SetIngressReporter sets the reporter for reporting new and active connections.
func SetIngressReporter(service string, r *ingress.Reporter) ServerOption {
return func(s *Server) error {
s.ingressReporter = r
s.ingressService = service
return nil
}
}
// SetLogger sets the logger for the server
func SetLogger(logger logrus.FieldLogger) ServerOption {
return func(s *Server) error {
s.log = logger.WithField(trace.Component, "ssh:"+s.component)
return nil
}
}
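// SetLimiter sets the limiter used to enforce connection and request limits.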
func SetLimiter(limiter *limiter.Limiter) ServerOption {
return func(s *Server) error {
s.limiter = limiter
return nil
}
}
// SetShutdownPollPeriod sets a polling period for graceful shutdowns of SSH servers
func SetShutdownPollPeriod(period time.Duration) ServerOption {
return func(s *Server) error {
s.shutdownPollPeriod = period
return nil
}
}
// SetInsecureSkipHostValidation disables validating that the host signers
// are valid certificates. Used in tests.
func SetInsecureSkipHostValidation() ServerOption {
return func(s *Server) error {
s.insecureSkipHostValidation = true
return nil
}
}
// SetTracerProvider sets the tracer provider for the server.
func SetTracerProvider(provider oteltrace.TracerProvider) ServerOption {
return func(s *Server) error {
s.tracerProvider = provider
return nil
}
}
// SetClock sets the server's clock.
func SetClock(clock clockwork.Clock) ServerOption {
return func(s *Server) error {
s.clock = clock
return nil
}
}
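// SetClusterName sets the name of the cluster this server belongs to.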
func SetClusterName(clusterName string) ServerOption {
return func(s *Server) error {
s.clusterName = clusterName
return nil
}
}
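// NewServer returns a new, unstarted SSH server that serves the provided
// channel handler and authenticates clients with the given host signers
// and auth methods.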
func NewServer(
component string,
a utils.NetAddr,
h NewChanHandler,
hostSigners []ssh.Signer,
ah AuthMethods,
opts ...ServerOption,
) (*Server, error) {
err := metrics.RegisterPrometheusCollectors(proxyConnectionLimitHitCount)
if err != nil {
return nil, trace.Wrap(err)
}
closeContext, cancel := context.WithCancel(context.TODO())
s := &Server{
log: logrus.WithFields(logrus.Fields{
trace.Component: "ssh:" + component,
}),
addr: a,
newChanHandler: h,
component: component,
closeContext: closeContext,
closeFunc: cancel,
}
s.limiter, err = limiter.NewLimiter(limiter.Config{})
if err != nil {
return nil, trace.Wrap(err)
}
for _, o := range opts {
if err := o(s); err != nil {
return nil, err
}
}
if s.shutdownPollPeriod == 0 {
s.shutdownPollPeriod = defaults.ShutdownPollPeriod
}
if s.tracerProvider == nil {
s.tracerProvider = tracing.DefaultProvider()
}
err = s.checkArguments(a, h, hostSigners, ah)
if err != nil {
return nil, err
}
for _, signer := range hostSigners {
s.cfg.AddHostKey(signer)
}
s.cfg.PublicKeyCallback = ah.PublicKey
s.cfg.PasswordCallback = ah.Password
s.cfg.NoClientAuth = ah.NoClient
// Teleport servers need to identify as such to allow passing of the client
// IP from the client to the proxy to the destination node.
s.cfg.ServerVersion = SSHVersionPrefix
return s, nil
}
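// SetSSHConfig sets the ssh.ServerConfig used for incoming connections.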
func SetSSHConfig(cfg ssh.ServerConfig) ServerOption {
return func(s *Server) error {
s.cfg = cfg
return nil
}
}
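// SetRequestHandler sets the handler for global (out-of-band) SSH requests.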
func SetRequestHandler(req RequestHandler) ServerOption {
return func(s *Server) error {
s.reqHandler = req
return nil
}
}
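// SetNewConnHandler sets the handler invoked once for every new incoming connection.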
func SetNewConnHandler(handler NewConnHandler) ServerOption {
return func(s *Server) error {
s.newConnHandler = handler
return nil
}
}
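// SetCiphers sets the list of ciphers the server will accept.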
func SetCiphers(ciphers []string) ServerOption {
return func(s *Server) error {
s.log.Debugf("Supported ciphers: %q.", ciphers)
if ciphers != nil {
s.cfg.Ciphers = ciphers
}
return nil
}
}
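// SetKEXAlgorithms sets the list of key exchange algorithms the server will accept.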
func SetKEXAlgorithms(kexAlgorithms []string) ServerOption {
return func(s *Server) error {
s.log.Debugf("Supported KEX algorithms: %q.", kexAlgorithms)
if kexAlgorithms != nil {
s.cfg.KeyExchanges = kexAlgorithms
}
return nil
}
}
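// SetMACAlgorithms sets the list of MAC algorithms the server will accept.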
func SetMACAlgorithms(macAlgorithms []string) ServerOption {
return func(s *Server) error {
s.log.Debugf("Supported MAC algorithms: %q.", macAlgorithms)
if macAlgorithms != nil {
s.cfg.MACs = macAlgorithms
}
return nil
}
}
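// SetFIPS sets whether the server runs in FedRAMP/FIPS 140-2 compliant mode.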
func SetFIPS(fips bool) ServerOption {
return func(s *Server) error {
s.fips = fips
return nil
}
}
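// Addr returns the address the server listens on, or an empty string if
// no listener has been set.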
func (s *Server) Addr() string {
s.RLock()
defer s.RUnlock()
if s.listener == nil {
return ""
}
return s.listener.Addr().String()
}
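// Serve accepts and handles connections on the provided listener. It blocks
// until the listener is closed.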
func (s *Server) Serve(listener net.Listener) error {
if err := s.SetListener(listener); err != nil {
return trace.Wrap(err)
}
s.acceptConnections()
return nil
}
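// Start begins listening on the configured address, unless a listener was
// already set, and accepts connections in a background goroutine.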
func (s *Server) Start() error {
if s.listener == nil {
listener, err := net.Listen(s.addr.AddrNetwork, s.addr.Addr)
if err != nil {
return trace.ConvertSystemError(err)
}
if err := s.SetListener(s.limiter.WrapListener(listener)); err != nil {
return trace.Wrap(err)
}
}
s.log.WithField("addr", s.listener.Addr().String()).Debug("Server start.")
go s.acceptConnections()
return nil
}
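// SetListener sets the server's listener. It fails if one is already set.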
func (s *Server) SetListener(l net.Listener) error {
s.Lock()
defer s.Unlock()
if s.listener != nil {
return trace.BadParameter("listener is already set to %v", s.listener.Addr())
}
s.listener = l
return nil
}
// Wait waits until the server stops serving new connections
// on the listener socket
func (s *Server) Wait(ctx context.Context) {
select {
case <-s.closeContext.Done():
case <-ctx.Done():
}
}
// Shutdown initiates a graceful shutdown, waiting until all active
// connections are closed
func (s *Server) Shutdown(ctx context.Context) error {
// close listener to stop receiving new connections
err := s.Close()
s.Wait(ctx)
activeConnections := s.trackUserConnections(0)
if activeConnections == 0 {
return err
}
s.log.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
lastReport := time.Time{}
ticker := time.NewTicker(s.shutdownPollPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
activeConnections = s.trackUserConnections(0)
if activeConnections == 0 {
return err
}
if time.Since(lastReport) > 10*s.shutdownPollPeriod {
s.log.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
lastReport = time.Now()
}
case <-ctx.Done():
s.log.Infof("Context canceled wait, returning.")
return trace.ConnectionProblem(err, "context canceled")
}
}
}
// Close closes listening socket and stops accepting connections
func (s *Server) Close() error {
s.Lock()
defer s.Unlock()
defer s.closeFunc()
if s.listener != nil {
err := s.listener.Close()
if utils.IsUseOfClosedNetworkError(err) {
return nil
}
return trace.Wrap(err)
}
return nil
}
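// acceptConnections accepts connections from the listener in a loop, handling
// each in its own goroutine, until the listener closes or the server shuts down.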
func (s *Server) acceptConnections() {
defer s.closeFunc()
addr := s.Addr()
s.log.Debugf("Listening on %v.", addr)
for {
conn, err := s.listener.Accept()
if err != nil {
if trace.IsLimitExceeded(err) {
proxyConnectionLimitHitCount.Inc()
s.log.Error(err.Error())
continue
}
if utils.IsUseOfClosedNetworkError(err) {
s.log.Debugf("Server %v has closed.", addr)
return
}
select {
case <-s.closeContext.Done():
s.log.Debugf("Server %v has closed.", addr)
return
case <-time.After(5 * time.Second):
s.log.Debugf("Backoff on network error: %v.", err)
}
} else {
go s.HandleConnection(conn)
}
}
}
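// trackUserConnections adds delta to the number of active user connections
// and returns the new count.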
func (s *Server) trackUserConnections(delta int32) int32 {
return atomic.AddInt32(&s.userConns, delta)
}
// ActiveConnections returns the number of connections that are
// being served.
func (s *Server) ActiveConnections() int32 {
return atomic.LoadInt32(&s.userConns)
}
// HandleConnection is called every time an SSH server accepts a new
// connection from a client.
//
// This is the foundation of all SSH connections in Teleport (between clients
// and proxies, proxies and servers, servers and auth, etc), except for the
// forwarding SSH proxy that is used when "recording on proxy" is enabled.
func (s *Server) HandleConnection(conn net.Conn) {
if s.ingressReporter != nil {
s.ingressReporter.ConnectionAccepted(s.ingressService, conn)
defer s.ingressReporter.ConnectionClosed(s.ingressService, conn)
}
// apply idle read/write timeout to this connection.
conn = utils.ObeyIdleTimeout(conn,
defaults.DefaultIdleConnectionDuration,
s.component)
// Wrap connection with a tracker used to monitor how much data was
// transmitted and received over the connection.
wconn := utils.NewTrackingConn(conn)
sconn, chans, reqs, err := ssh.NewServerConn(wconn, &s.cfg)
if err != nil {
// Ignore EOF as these are triggered by load balancer health checks
if !errors.Is(err, io.EOF) {
s.log.
WithError(err).
WithField("remote_addr", conn.RemoteAddr()).
Warn("Error occurred in handshake for new SSH conn")
}
// The handshake failed, so this connection will never be used; close it
// to avoid leaking the underlying socket.
conn.Close()
return
}
if s.ingressReporter != nil {
s.ingressReporter.ConnectionAuthenticated(s.ingressService, conn)
defer s.ingressReporter.AuthenticatedConnectionClosed(s.ingressService, conn)
}
certType := "unknown"
if sconn.Permissions != nil {
certType = sconn.Permissions.Extensions[utils.ExtIntCertType]
}
if certType == utils.ExtIntCertTypeUser {
s.trackUserConnections(1)
defer s.trackUserConnections(-1)
}
user := sconn.User()
if err := s.limiter.RegisterRequest(user); err != nil {
s.log.Error(err.Error())
sconn.Close()
conn.Close()
return
}
// Connection successfully initiated
s.log.Debugf("Incoming connection %v -> %v version: %v, certtype: %q",
sconn.RemoteAddr(), sconn.LocalAddr(), string(sconn.ClientVersion()), certType)
// connClosed will be called when the connection is closed
connClosed := func() {
s.log.Debugf("Closed connection %v.", sconn.RemoteAddr())
}
// The keepalive ticker ensures SSH keepalive requests are sent to the
// client at an interval much shorter than the idle connection timeout.
keepAliveTick := time.NewTicker(defaults.DefaultIdleConnectionDuration / 3)
defer keepAliveTick.Stop()
keepAlivePayload := [8]byte{0}
// NOTE: we deliberately don't use s.closeContext here because the server's
// closeContext field is used to trigger starvation on cancellation by halting
// the acceptance of new connections; it is not intended to halt in-progress
// connection handling, and is therefore orthogonal to the role of ConnectionContext.
ctx, ccx := NewConnectionContext(context.Background(), wconn, sconn, SetConnectionContextClock(s.clock))
defer ccx.Close()
if s.newConnHandler != nil {
// if newConnHandler was set, then we have additional setup work
// to do before we can begin serving normally. Errors returned
// from a NewConnHandler are rejections.
ctx, err = s.newConnHandler.HandleNewConn(ctx, ccx)
if err != nil {
s.log.Warnf("Dropping inbound ssh connection due to error: %v", err)
// Immediately dropping the ssh connection results in an
// EOF error for the client. We therefore wait briefly
// to see if the client opens a channel or sends any global
// requests, which will give us the opportunity to respond
// with a human-readable error.
waitCtx, waitCancel := context.WithTimeout(s.closeContext, time.Second)
defer waitCancel()
for {
select {
case req := <-reqs:
if req == nil {
connClosed()
break
}
// wait for a request that wants a reply to send the error
if !req.WantReply {
continue
}
if err := req.Reply(false, []byte(err.Error())); err != nil {
s.log.WithError(err).Warnf("failed to reply to request %s", req.Type)
}
case firstChan := <-chans:
// channel was closed, terminate the connection
if firstChan == nil {
break
}
if err := firstChan.Reject(ssh.Prohibited, err.Error()); err != nil {
s.log.WithError(err).Warnf("failed to reject channel %s", firstChan.ChannelType())
}
case <-waitCtx.Done():
}
// Stop waiting after a single reply, rejection, closure, or timeout;
// the connection is closed below regardless.
break
}
if err := sconn.Close(); err != nil && !utils.IsOKNetworkError(err) {
s.log.WithError(err).Warn("failed to close ssh server connection")
}
if err := conn.Close(); err != nil && !utils.IsOKNetworkError(err) {
s.log.WithError(err).Warn("failed to close ssh client connection")
}
return
}
}
for {
select {
// handle out of band ssh requests
case req := <-reqs:
if req == nil {
connClosed()
return
}
s.log.Debugf("Received out-of-band request: %+v.", req)
reqCtx := tracessh.ContextFromRequest(req)
ctx, span := s.tracerProvider.Tracer("ssh").Start(
oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(reqCtx)),
fmt.Sprintf("ssh.GlobalRequest/%s", req.Type),
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(
semconv.RPCServiceKey.String("ssh.Server"),
semconv.RPCMethodKey.String("GlobalRequest"),
semconv.RPCSystemKey.String("ssh"),
),
)
if s.reqHandler != nil {
go func(span oteltrace.Span) {
defer span.End()
s.reqHandler.HandleRequest(ctx, req)
}(span)
} else {
span.End()
}
// handle channels:
case nch := <-chans:
if nch == nil {
connClosed()
return
}
// This is a request from clients to determine if tracing is enabled.
// Handle here so that we always alert clients that we can handle tracing envelopes.
if nch.ChannelType() == tracessh.TracingChannel {
ch, _, err := nch.Accept()
if err != nil {
s.log.Warnf("Unable to accept channel: %v", err)
if err := nch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)); err != nil {
s.log.Warnf("Failed to reject channel: %v", err)
}
continue
}
if err := ch.Close(); err != nil {
s.log.Warnf("Unable to close %q channel: %v", nch.ChannelType(), err)
}
continue
}
chanCtx, nch := tracessh.ContextFromNewChannel(nch)
ctx, span := s.tracerProvider.Tracer("ssh").Start(
oteltrace.ContextWithRemoteSpanContext(ctx, oteltrace.SpanContextFromContext(chanCtx)),
fmt.Sprintf("ssh.OpenChannel/%s", nch.ChannelType()),
oteltrace.WithSpanKind(oteltrace.SpanKindServer),
oteltrace.WithAttributes(
semconv.RPCServiceKey.String("ssh.Server"),
semconv.RPCMethodKey.String("OpenChannel"),
semconv.RPCSystemKey.String("ssh"),
),
)
go func(span oteltrace.Span) {
defer span.End()
s.newChanHandler.HandleNewChan(ctx, ccx, nch)
}(span)
// send keepalive pings to the clients
case <-keepAliveTick.C:
const wantReply = true
_, _, err = sconn.SendRequest(teleport.KeepAliveReqType, wantReply, keepAlivePayload[:])
if err != nil {
s.log.Errorf("Failed sending keepalive request: %v", err)
}
case <-ctx.Done():
s.log.Debugf("Connection context canceled: %v -> %v", conn.RemoteAddr(), conn.LocalAddr())
return
}
}
}
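// RequestHandler processes global (out-of-band) SSH requests.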
type RequestHandler interface {
HandleRequest(ctx context.Context, r *ssh.Request)
}
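// NewChanHandler handles new channels opened on an SSH connection.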
type NewChanHandler interface {
HandleNewChan(context.Context, *ConnectionContext, ssh.NewChannel)
}
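// NewChanHandlerFunc is an adapter that allows ordinary functions to be
// used as NewChanHandlers.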
type NewChanHandlerFunc func(context.Context, *ConnectionContext, ssh.NewChannel)
func (f NewChanHandlerFunc) HandleNewChan(ctx context.Context, ccx *ConnectionContext, ch ssh.NewChannel) {
f(ctx, ccx, ch)
}
// NewConnHandler is called once per incoming connection.
// Errors terminate the incoming connection. The returned context
// must be the same as, or a child of, the passed in context.
type NewConnHandler interface {
HandleNewConn(ctx context.Context, ccx *ConnectionContext) (context.Context, error)
}
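// AuthMethods describes the authentication methods offered by the server.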
type AuthMethods struct {
PublicKey PublicKeyFunc
Password PasswordFunc
NoClient bool
}
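// checkArguments validates the arguments provided to NewServer.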
func (s *Server) checkArguments(a utils.NetAddr, h NewChanHandler, hostSigners []ssh.Signer, ah AuthMethods) error {
// If the server is not in tunnel mode, an address must be specified.
if s.listener != nil {
if a.Addr == "" || a.AddrNetwork == "" {
return trace.BadParameter("addr: specify network and the address for listening socket")
}
}
if h == nil {
return trace.BadParameter("missing NewChanHandler")
}
if len(hostSigners) == 0 {
return trace.BadParameter("need at least one signer")
}
for _, signer := range hostSigners {
if signer == nil {
return trace.BadParameter("host signer can not be nil")
}
if !s.insecureSkipHostValidation {
err := validateHostSigner(s.fips, signer)
if err != nil {
return trace.Wrap(err)
}
}
}
if ah.PublicKey == nil && ah.Password == nil && !ah.NoClient {
return trace.BadParameter("need at least one auth method")
}
return nil
}
// validateHostSigner makes sure the signer is a valid certificate.
func validateHostSigner(fips bool, signer ssh.Signer) error {
cert, ok := signer.PublicKey().(*ssh.Certificate)
if !ok {
return trace.BadParameter("only host certificates supported")
}
if len(cert.ValidPrincipals) == 0 {
return trace.BadParameter("at least one valid principal is required in host certificate")
}
certChecker := sshutils.CertChecker{
FIPS: fips,
}
err := certChecker.CheckCert(cert.ValidPrincipals[0], cert)
if err != nil {
return trace.Wrap(err)
}
return nil
}
type (
// PublicKeyFunc authenticates an SSH connection with a public key.
PublicKeyFunc func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error)
// PasswordFunc authenticates an SSH connection with a password.
PasswordFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error)
)
// ClusterDetails specifies information about a cluster
type ClusterDetails struct {
RecordingProxy bool
FIPSEnabled bool
}