Fix an issue where "context canceled" spams the Proxy log for database connections (#31441)

* poc tracking read conn context canceled

* fix lint

* remove old Cancel func

* update comments
STeve (Xin) Huang 2023-09-08 09:00:48 -04:00 committed by GitHub
parent 176237ddb7
commit 0233855193
6 changed files with 79 additions and 23 deletions
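
The change is built on the Go 1.20+ context "cause" API: the connection monitor contexts switch from context.WithCancel to context.WithCancelCause, a clean shutdown cancels with io.EOF, and callers inspect context.Cause instead of logging a generic "context canceled" error. A minimal standalone sketch of that mechanism (illustrative only, not code from this commit):

package main

import (
    "context"
    "errors"
    "fmt"
    "io"
)

func main() {
    ctx, cancel := context.WithCancelCause(context.Background())

    // A clean shutdown (e.g. the tracking connection being closed) cancels with io.EOF.
    cancel(io.EOF)

    <-ctx.Done()
    fmt.Println(ctx.Err())                             // context canceled
    fmt.Println(context.Cause(ctx))                    // EOF
    fmt.Println(errors.Is(context.Cause(ctx), io.EOF)) // true: nothing worth logging
}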


@@ -2162,13 +2162,13 @@ type clusterSession struct {
     // connCtx is the context used to monitor the connection.
     connCtx context.Context
     // connMonitorCancel is the conn monitor connMonitorCancel function.
-    connMonitorCancel context.CancelFunc
+    connMonitorCancel context.CancelCauseFunc
 }

 // close cancels the connection monitor context if available.
 func (s *clusterSession) close() {
     if s.connMonitorCancel != nil {
-        s.connMonitorCancel()
+        s.connMonitorCancel(io.EOF)
     }
 }
@@ -2184,7 +2184,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
         Cancel: s.connMonitorCancel,
     })
     if err != nil {
-        s.connMonitorCancel()
+        s.connMonitorCancel(err)
         return nil, trace.Wrap(err)
     }
@@ -2203,8 +2203,7 @@ func (s *clusterSession) monitorConn(conn net.Conn, err error) (net.Conn, error)
         Emitter: s.parent.cfg.AuthClient,
     })
     if err != nil {
-        tc.Close()
-        s.connMonitorCancel()
+        tc.CloseWithCause(err)
         return nil, trace.Wrap(err)
     }
     return tc, nil
@@ -2264,7 +2263,7 @@ func (f *Forwarder) newClusterSession(ctx context.Context, authCtx authContext)
 func (f *Forwarder) newClusterSessionRemoteCluster(ctx context.Context, authCtx authContext) (*clusterSession, error) {
     f.log.Debugf("Forwarding kubernetes session for %v to remote cluster.", authCtx)
-    connCtx, cancel := context.WithCancel(ctx)
+    connCtx, cancel := context.WithCancelCause(ctx)
     return &clusterSession{
         parent: f,
         authContext: authCtx,
@@ -2312,7 +2311,7 @@ func (f *Forwarder) newClusterSessionLocal(ctx context.Context, authCtx authCont
         return nil, trace.Wrap(err)
     }
-    connCtx, cancel := context.WithCancel(ctx)
+    connCtx, cancel := context.WithCancelCause(ctx)
     f.log.Debugf("Handling kubernetes session for %v using local credentials.", authCtx)
     return &clusterSession{
         parent: f,
@@ -2328,7 +2327,7 @@ func (f *Forwarder) newClusterSessionLocal(ctx context.Context, authCtx authCont
 }

 func (f *Forwarder) newClusterSessionDirect(ctx context.Context, authCtx authContext) (*clusterSession, error) {
-    connCtx, cancel := context.WithCancel(ctx)
+    connCtx, cancel := context.WithCancelCause(ctx)
     return &clusterSession{
         parent: f,
         authContext: authCtx,


@@ -702,7 +702,7 @@ func (s *Server) HandleConnection(conn net.Conn) {
 }

 func (s *Server) handleConnection(conn net.Conn) (func(), error) {
-    ctx, cancel := context.WithCancel(s.closeContext)
+    ctx, cancel := context.WithCancelCause(s.closeContext)
     tc, err := srv.NewTrackingReadConn(srv.TrackingReadConnConfig{
         Conn: conn,
         Clock: s.c.Clock,


@@ -20,7 +20,9 @@ import (
     "context"
     "crypto/tls"
     "crypto/x509"
+    "errors"
     "fmt"
+    "io"
     "math/rand"
     "net"
     "sort"
@@ -529,7 +531,16 @@ func (s *ProxyServer) Proxy(ctx context.Context, proxyCtx *common.ProxyContext,
     activeConnections.With(labels).Inc()
     defer activeConnections.With(labels).Dec()
-    return trace.Wrap(utils.ProxyConn(ctx, clientConn, serviceConn))
+    err = utils.ProxyConn(ctx, clientConn, serviceConn)
+    // The clientConn is closed by utils.ProxyConn on successful io.Copy thus
+    // possibly causing utils.ProxyConn to return io.EOF from
+    // context.Cause(ctx), as monitor context is closed when
+    // TrackingReadConn.Close() is called.
+    if errors.Is(err, io.EOF) {
+        return nil
+    }
+    return trace.Wrap(err)
 }

 // Authorize authorizes the provided client TLS connection.
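
To make the comment in the hunk above concrete, here is a stripped-down, self-contained sketch (not utils.ProxyConn itself) of the chain it describes: the tracking connection's Close cancels the monitor context with io.EOF, the proxy loop reports that cause when the context is done, and the database proxy treats it as a clean shutdown rather than an error worth logging.

package main

import (
    "context"
    "errors"
    "fmt"
    "io"
)

// proxyLoop stands in for the copy loop: it blocks until the monitor
// context ends and reports why it ended.
func proxyLoop(ctx context.Context) error {
    <-ctx.Done()
    return context.Cause(ctx)
}

func main() {
    ctx, cancel := context.WithCancelCause(context.Background())

    // Stand-in for TrackingReadConn.Close after a successful copy.
    go cancel(io.EOF)

    if err := proxyLoop(ctx); err != nil && !errors.Is(err, io.EOF) {
        fmt.Println("proxy error:", err)
        return
    }
    fmt.Println("clean shutdown, nothing to log")
}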


@@ -157,7 +157,7 @@ func (c *ConnectionMonitor) MonitorConn(ctx context.Context, authzCtx *authz.Con
     tconn, ok := getTrackingReadConn(conn)
     if !ok {
-        tctx, cancel := context.WithCancel(ctx)
+        tctx, cancel := context.WithCancelCause(ctx)
         tconn, err = NewTrackingReadConn(TrackingReadConnConfig{
             Conn: conn,
             Clock: c.cfg.Clock,
@@ -405,6 +405,10 @@ func (w *Monitor) disconnectClientOnExpiredCert() {
     w.disconnectClient(reason)
 }

+type withCauseCloser interface {
+    CloseWithCause(cause error) error
+}
+
 func (w *Monitor) disconnectClient(reason string) {
     w.Entry.Debugf("Disconnecting client: %v", reason)
     // Emit Audit event first to make sure that that underlying context will not be canceled during
@@ -412,8 +416,15 @@ func (w *Monitor) disconnectClient(reason string) {
     if err := w.emitDisconnectEvent(reason); err != nil {
         w.Entry.WithError(err).Warn("Failed to emit audit event.")
     }
-    if err := w.Conn.Close(); err != nil {
-        w.Entry.WithError(err).Error("Failed to close connection.")
+    if connWithCauseCloser, ok := w.Conn.(withCauseCloser); ok {
+        if err := connWithCauseCloser.CloseWithCause(trace.AccessDenied(reason)); err != nil {
+            w.Entry.WithError(err).Error("Failed to close connection.")
+        }
+    } else {
+        if err := w.Conn.Close(); err != nil {
+            w.Entry.WithError(err).Error("Failed to close connection.")
+        }
     }
 }
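
The disconnectClient change above uses an optional-interface upgrade: w.Conn is only known to be closeable, so the monitor type-asserts it to the new withCauseCloser interface and falls back to a plain Close when the wrapped connection cannot record a cause. A self-contained sketch of that pattern using hypothetical types (not Teleport code):

package main

import (
    "errors"
    "fmt"
    "io"
)

// causeCloser mirrors the withCauseCloser interface from the hunk above.
type causeCloser interface {
    CloseWithCause(cause error) error
}

// trackingConn is a hypothetical connection that records why it was closed.
type trackingConn struct{ cause error }

func (c *trackingConn) Close() error                     { return c.CloseWithCause(io.EOF) }
func (c *trackingConn) CloseWithCause(cause error) error { c.cause = cause; return nil }

// plainConn has no notion of a close cause.
type plainConn struct{}

func (plainConn) Close() error { return nil }

// closeWithReason prefers CloseWithCause when available and falls back to Close.
func closeWithReason(c io.Closer, reason error) error {
    if cc, ok := c.(causeCloser); ok {
        return cc.CloseWithCause(reason)
    }
    return c.Close()
}

func main() {
    tc := &trackingConn{}
    _ = closeWithReason(tc, errors.New("access denied: lock in force"))
    fmt.Println(tc.cause) // access denied: lock in force

    _ = closeWithReason(plainConn{}, errors.New("ignored"))
    fmt.Println("plain connection closed without a cause")
}
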
@@ -482,7 +493,7 @@ type TrackingReadConnConfig struct {
     // Context is an external context to cancel the operation.
     Context context.Context
     // Cancel is called whenever client context is closed.
-    Cancel context.CancelFunc
+    Cancel context.CancelCauseFunc
 }

 // CheckAndSetDefaults checks and sets defaults.
@@ -536,8 +547,16 @@ func (t *TrackingReadConn) Read(b []byte) (int, error) {
     return n, err
 }

+// Close cancels the context with io.EOF and closes the underlying connection.
 func (t *TrackingReadConn) Close() error {
-    t.cfg.Cancel()
+    t.cfg.Cancel(io.EOF)
     return t.Conn.Close()
 }
+
+// CloseWithCause cancels the context with provided cause and closes the
+// underlying connection.
+func (t *TrackingReadConn) CloseWithCause(cause error) error {
+    t.cfg.Cancel(cause)
+    return t.Conn.Close()
+}


@@ -23,6 +23,7 @@ import (
     "testing"
     "time"

+    "github.com/gravitational/trace"
     "github.com/jonboulle/clockwork"
     "github.com/sirupsen/logrus"
     "github.com/stretchr/testify/require"
@@ -116,8 +117,13 @@ func TestConnectionMonitorLockInForce(t *testing.T) {
         t.Fatal("Timeout waiting for connection close.")
     }

-    // Assert that the context was canceled.
+    // Assert that the context was canceled and verify the cause.
     require.Error(t, monitorCtx.Err())
+    cause := context.Cause(monitorCtx)
+    require.True(t, trace.IsAccessDenied(cause))
+    for _, contains := range []string{"lock", "in force"} {
+        require.Contains(t, cause.Error(), contains)
+    }

     // Validate that the disconnect event was logged.
     require.Equal(t, services.LockInForceAccessDenied(lock).Error(), (<-emitter.C()).(*apievents.ClientDisconnect).Reason)
@@ -282,7 +288,7 @@ func TestMonitorDisconnectExpiredCertBeforeTimeNow(t *testing.T) {
     }
 }

-func TestTrackingReadConnEOF(t *testing.T) {
+func TestTrackingReadConn(t *testing.T) {
     server, client := net.Pipe()
     defer client.Close()
@@ -290,7 +296,7 @@ func TestTrackingReadConnEOF(t *testing.T) {
     require.NoError(t, server.Close())

     // Wrap the client in a TrackingReadConn.
-    ctx, cancel := context.WithCancel(context.Background())
+    ctx, cancel := context.WithCancelCause(context.Background())
     tc, err := NewTrackingReadConn(TrackingReadConnConfig{
         Conn: client,
         Clock: clockwork.NewFakeClock(),
@@ -299,10 +305,30 @@ func TestTrackingReadConnEOF(t *testing.T) {
     })
     require.NoError(t, err)

-    // Make sure it returns an EOF and not a wrapped exception.
-    buf := make([]byte, 64)
-    _, err = tc.Read(buf)
-    require.Equal(t, io.EOF, err)
+    t.Run("Read EOF", func(t *testing.T) {
+        // Make sure it returns an EOF and not a wrapped exception.
+        buf := make([]byte, 64)
+        _, err = tc.Read(buf)
+        require.Equal(t, io.EOF, err)
+    })
+
+    t.Run("CloseWithCause", func(t *testing.T) {
+        require.NoError(t, tc.CloseWithCause(trace.AccessDenied("fake problem")))
+        require.ErrorIs(t, context.Cause(ctx), trace.AccessDenied("fake problem"))
+    })
+
+    t.Run("Close", func(t *testing.T) {
+        ctx, cancel := context.WithCancelCause(context.Background())
+        tc, err := NewTrackingReadConn(TrackingReadConnConfig{
+            Conn: client,
+            Clock: clockwork.NewFakeClock(),
+            Context: ctx,
+            Cancel: cancel,
+        })
+        require.NoError(t, err)
+        require.NoError(t, tc.Close())
+        require.ErrorIs(t, context.Cause(ctx), io.EOF)
+    })
 }

 type mockChecker struct {


@@ -86,7 +86,8 @@ func ProxyConn(ctx context.Context, client, server io.ReadWriteCloser) error {
             errors = append(errors, err)
         }
     case <-ctx.Done():
-        return ctx.Err()
+        // Cause(ctx) returns ctx.Err() if no cause is provided.
+        return trace.Wrap(context.Cause(ctx))
     }
 }
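
The added comment relies on a documented property of the cause API: when a context is canceled without an explicit cause (for example via plain context.WithCancel), context.Cause returns ctx.Err(), so returning the wrapped cause is a drop-in replacement for returning ctx.Err(). A tiny sketch (illustrative only):

package main

import (
    "context"
    "fmt"
)

func main() {
    // Canceled without a cause: Cause falls back to ctx.Err().
    ctx, cancel := context.WithCancel(context.Background())
    cancel()
    fmt.Println(context.Cause(ctx) == ctx.Err()) // true, both are context.Canceled
}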