mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 01:34:01 +00:00
Fix an issue context canceled
spams Proxy log for database connections (#31441)
* poc tracking read conn context canceled * fix lint * remove old Cancel func * update comments
This commit is contained in:
parent
176237ddb7
commit
0233855193
|
@ -2162,13 +2162,13 @@ type clusterSession struct {
|
|||
// connCtx is the context used to monitor the connection.
|
||||
connCtx context.Context
|
||||
// connMonitorCancel is the conn monitor connMonitorCancel function.
|
||||
connMonitorCancel context.CancelFunc
|
||||
connMonitorCancel context.CancelCauseFunc
|
||||
}
|
||||
|
||||
// close cancels the connection monitor context if available.
|
||||
func (s *clusterSession) close() {
|
||||
if s.connMonitorCancel != nil {
|
||||
s.connMonitorCancel()
|
||||
s.connMonitorCancel(io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2184,7 +2184,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
|
|||
Cancel: s.connMonitorCancel,
|
||||
})
|
||||
if err != nil {
|
||||
s.connMonitorCancel()
|
||||
s.connMonitorCancel(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
|
@ -2203,8 +2203,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
|
|||
Emitter: s.parent.cfg.AuthClient,
|
||||
})
|
||||
if err != nil {
|
||||
tc.Close()
|
||||
s.connMonitorCancel()
|
||||
tc.CloseWithCause(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return tc, nil
|
||||
|
@ -2264,7 +2263,7 @@ func (f *Forwarder) newClusterSession(ctx context.Context, authCtx authContext)
|
|||
|
||||
func (f *Forwarder) newClusterSessionRemoteCluster(ctx context.Context, authCtx authContext) (*clusterSession, error) {
|
||||
f.log.Debugf("Forwarding kubernetes session for %v to remote cluster.", authCtx)
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connCtx, cancel := context.WithCancelCause(ctx)
|
||||
return &clusterSession{
|
||||
parent: f,
|
||||
authContext: authCtx,
|
||||
|
@ -2312,7 +2311,7 @@ func (f *Forwarder) newClusterSessionLocal(ctx context.Context, authCtx authCont
|
|||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connCtx, cancel := context.WithCancelCause(ctx)
|
||||
f.log.Debugf("Handling kubernetes session for %v using local credentials.", authCtx)
|
||||
return &clusterSession{
|
||||
parent: f,
|
||||
|
@ -2328,7 +2327,7 @@ func (f *Forwarder) newClusterSessionLocal(ctx context.Context, authCtx authCont
|
|||
}
|
||||
|
||||
func (f *Forwarder) newClusterSessionDirect(ctx context.Context, authCtx authContext) (*clusterSession, error) {
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connCtx, cancel := context.WithCancelCause(ctx)
|
||||
return &clusterSession{
|
||||
parent: f,
|
||||
authContext: authCtx,
|
||||
|
|
|
@ -702,7 +702,7 @@ func (s *Server) HandleConnection(conn net.Conn) {
|
|||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) (func(), error) {
|
||||
ctx, cancel := context.WithCancel(s.closeContext)
|
||||
ctx, cancel := context.WithCancelCause(s.closeContext)
|
||||
tc, err := srv.NewTrackingReadConn(srv.TrackingReadConnConfig{
|
||||
Conn: conn,
|
||||
Clock: s.c.Clock,
|
||||
|
|
|
@ -20,7 +20,9 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sort"
|
||||
|
@ -529,7 +531,16 @@ func (s *ProxyServer) Proxy(ctx context.Context, proxyCtx *common.ProxyContext,
|
|||
activeConnections.With(labels).Inc()
|
||||
defer activeConnections.With(labels).Dec()
|
||||
|
||||
return trace.Wrap(utils.ProxyConn(ctx, clientConn, serviceConn))
|
||||
err = utils.ProxyConn(ctx, clientConn, serviceConn)
|
||||
|
||||
// The clientConn is closed by utils.ProxyConn on successful io.Copy thus
|
||||
// possibly causing utils.ProxyConn to return io.EOF from
|
||||
// context.Cause(ctx), as monitor context is closed when
|
||||
// TrackingReadConn.Close() is called.
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Authorize authorizes the provided client TLS connection.
|
||||
|
|
|
@ -157,7 +157,7 @@ func (c *ConnectionMonitor) MonitorConn(ctx context.Context, authzCtx *authz.Con
|
|||
|
||||
tconn, ok := getTrackingReadConn(conn)
|
||||
if !ok {
|
||||
tctx, cancel := context.WithCancel(ctx)
|
||||
tctx, cancel := context.WithCancelCause(ctx)
|
||||
tconn, err = NewTrackingReadConn(TrackingReadConnConfig{
|
||||
Conn: conn,
|
||||
Clock: c.cfg.Clock,
|
||||
|
@ -405,6 +405,10 @@ func (w *Monitor) disconnectClientOnExpiredCert() {
|
|||
w.disconnectClient(reason)
|
||||
}
|
||||
|
||||
type withCauseCloser interface {
|
||||
CloseWithCause(cause error) error
|
||||
}
|
||||
|
||||
func (w *Monitor) disconnectClient(reason string) {
|
||||
w.Entry.Debugf("Disconnecting client: %v", reason)
|
||||
// Emit Audit event first to make sure that that underlying context will not be canceled during
|
||||
|
@ -412,8 +416,15 @@ func (w *Monitor) disconnectClient(reason string) {
|
|||
if err := w.emitDisconnectEvent(reason); err != nil {
|
||||
w.Entry.WithError(err).Warn("Failed to emit audit event.")
|
||||
}
|
||||
if err := w.Conn.Close(); err != nil {
|
||||
w.Entry.WithError(err).Error("Failed to close connection.")
|
||||
|
||||
if connWithCauseCloser, ok := w.Conn.(withCauseCloser); ok {
|
||||
if err := connWithCauseCloser.CloseWithCause(trace.AccessDenied(reason)); err != nil {
|
||||
w.Entry.WithError(err).Error("Failed to close connection.")
|
||||
}
|
||||
} else {
|
||||
if err := w.Conn.Close(); err != nil {
|
||||
w.Entry.WithError(err).Error("Failed to close connection.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -482,7 +493,7 @@ type TrackingReadConnConfig struct {
|
|||
// Context is an external context to cancel the operation.
|
||||
Context context.Context
|
||||
// Cancel is called whenever client context is closed.
|
||||
Cancel context.CancelFunc
|
||||
Cancel context.CancelCauseFunc
|
||||
}
|
||||
|
||||
// CheckAndSetDefaults checks and sets defaults.
|
||||
|
@ -536,8 +547,16 @@ func (t *TrackingReadConn) Read(b []byte) (int, error) {
|
|||
return n, err
|
||||
}
|
||||
|
||||
// Close cancels the context with io.EOF and closes the underlying connection.
|
||||
func (t *TrackingReadConn) Close() error {
|
||||
t.cfg.Cancel()
|
||||
t.cfg.Cancel(io.EOF)
|
||||
return t.Conn.Close()
|
||||
}
|
||||
|
||||
// CloseWithCause cancels the context with provided cause and closes the
|
||||
// underlying connection.
|
||||
func (t *TrackingReadConn) CloseWithCause(cause error) error {
|
||||
t.cfg.Cancel(cause)
|
||||
return t.Conn.Close()
|
||||
}
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -116,8 +117,13 @@ func TestConnectionMonitorLockInForce(t *testing.T) {
|
|||
t.Fatal("Timeout waiting for connection close.")
|
||||
}
|
||||
|
||||
// Assert that the context was canceled.
|
||||
// Assert that the context was canceled and verify the cause.
|
||||
require.Error(t, monitorCtx.Err())
|
||||
cause := context.Cause(monitorCtx)
|
||||
require.True(t, trace.IsAccessDenied(cause))
|
||||
for _, contains := range []string{"lock", "in force"} {
|
||||
require.Contains(t, cause.Error(), contains)
|
||||
}
|
||||
|
||||
// Validate that the disconnect event was logged.
|
||||
require.Equal(t, services.LockInForceAccessDenied(lock).Error(), (<-emitter.C()).(*apievents.ClientDisconnect).Reason)
|
||||
|
@ -282,7 +288,7 @@ func TestMonitorDisconnectExpiredCertBeforeTimeNow(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestTrackingReadConnEOF(t *testing.T) {
|
||||
func TestTrackingReadConn(t *testing.T) {
|
||||
server, client := net.Pipe()
|
||||
defer client.Close()
|
||||
|
||||
|
@ -290,7 +296,7 @@ func TestTrackingReadConnEOF(t *testing.T) {
|
|||
require.NoError(t, server.Close())
|
||||
|
||||
// Wrap the client in a TrackingReadConn.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
tc, err := NewTrackingReadConn(TrackingReadConnConfig{
|
||||
Conn: client,
|
||||
Clock: clockwork.NewFakeClock(),
|
||||
|
@ -299,10 +305,30 @@ func TestTrackingReadConnEOF(t *testing.T) {
|
|||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Make sure it returns an EOF and not a wrapped exception.
|
||||
buf := make([]byte, 64)
|
||||
_, err = tc.Read(buf)
|
||||
require.Equal(t, io.EOF, err)
|
||||
t.Run("Read EOF", func(t *testing.T) {
|
||||
// Make sure it returns an EOF and not a wrapped exception.
|
||||
buf := make([]byte, 64)
|
||||
_, err = tc.Read(buf)
|
||||
require.Equal(t, io.EOF, err)
|
||||
})
|
||||
|
||||
t.Run("CloseWithCause", func(t *testing.T) {
|
||||
require.NoError(t, tc.CloseWithCause(trace.AccessDenied("fake problem")))
|
||||
require.ErrorIs(t, context.Cause(ctx), trace.AccessDenied("fake problem"))
|
||||
})
|
||||
|
||||
t.Run("Close", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
tc, err := NewTrackingReadConn(TrackingReadConnConfig{
|
||||
Conn: client,
|
||||
Clock: clockwork.NewFakeClock(),
|
||||
Context: ctx,
|
||||
Cancel: cancel,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tc.Close())
|
||||
require.ErrorIs(t, context.Cause(ctx), io.EOF)
|
||||
})
|
||||
}
|
||||
|
||||
type mockChecker struct {
|
||||
|
|
|
@ -86,7 +86,8 @@ func ProxyConn(ctx context.Context, client, server io.ReadWriteCloser) error {
|
|||
errors = append(errors, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
// Cause(ctx) returns ctx.Err() if no cause is provided.
|
||||
return trace.Wrap(context.Cause(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue