AuditLog/grpc server data race (#6170)

* Avoid test flake by ensuring the gRPC server is shutdown gracefully before closing the audit log

* Fix lint warnings. Nove tunnel server's Close to earlier to close the proxy watcher and release grpc traffic

* Use graceful shutdown selectively until all tests have improved support for it

* Move session recorder clean up to session.Close

* Always use graceful shutdown for TLS.
This commit is contained in:
a-palchikov 2021-05-19 02:57:57 +02:00 committed by GitHub
parent fc713b7216
commit ee6e2c85d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 149 additions and 75 deletions

View file

@ -52,8 +52,7 @@ var wildcardAllow = services.Labels{
type SrvCtx struct {
srv *regular.Server
signer ssh.Signer
server *auth.TestTLSServer
testServer *auth.TestAuthServer
server *auth.TestServer
clock clockwork.FakeClock
nodeClient *auth.Client
nodeID string
@ -152,26 +151,25 @@ func newSrvCtx(t *testing.T) *SrvCtx {
s := &SrvCtx{}
t.Cleanup(func() {
if s.server != nil {
require.NoError(t, s.server.Close())
}
if s.srv != nil {
require.NoError(t, s.srv.Close())
}
if s.server != nil {
require.NoError(t, s.server.Shutdown(context.Background()))
}
})
s.clock = clockwork.NewFakeClock()
tempdir := t.TempDir()
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: tempdir,
Clock: s.clock,
})
var err error
s.server, err = auth.NewTestServer(auth.TestServerConfig{
Auth: auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: tempdir,
Clock: s.clock,
}})
require.NoError(t, err)
s.server, err = authServer.NewTestTLSServer()
require.NoError(t, err)
s.testServer = authServer
// set up host private key and certificate
certs, err := s.server.Auth().GenerateServerKeys(auth.GenerateServerKeysRequest{
@ -266,7 +264,7 @@ func newUpack(s *SrvCtx, username string, allowedLogins []string, allowedLabels
if err != nil {
return nil, trace.Wrap(err)
}
ucert, err := s.testServer.GenerateUserCert(upub, user.GetName(), 5*time.Minute, teleport.CertificateFormatStandard)
ucert, err := s.server.AuthServer.GenerateUserCert(upub, user.GetName(), 5*time.Minute, teleport.CertificateFormatStandard)
if err != nil {
return nil, trace.Wrap(err)
}

View file

@ -96,6 +96,66 @@ func CreateUploaderDir(dir string) error {
return nil
}
// TestServer defines the set of server components for a test
type TestServer struct {
TLS *TestTLSServer
AuthServer *TestAuthServer
}
// TestServerConfig defines the configuration for all server components
type TestServerConfig struct {
// Auth specifies the auth server configuration
Auth TestAuthServerConfig
// TLS optionally specifies the configuration for the TLS server.
// If unspecified, will be generated automatically
TLS *TestTLSServerConfig
}
// NewTestServer creates a new test server configuration
func NewTestServer(cfg TestServerConfig) (*TestServer, error) {
authServer, err := NewTestAuthServer(cfg.Auth)
if err != nil {
return nil, trace.Wrap(err)
}
var tlsServer *TestTLSServer
if cfg.TLS != nil {
tlsServer, err = NewTestTLSServer(*cfg.TLS)
if err != nil {
return nil, trace.Wrap(err)
}
} else {
tlsServer, err = authServer.NewTestTLSServer()
if err != nil {
return nil, trace.Wrap(err)
}
}
return &TestServer{
AuthServer: authServer,
TLS: tlsServer,
}, nil
}
// Auth returns the underlying auth server instance
func (a *TestServer) Auth() *Server {
return a.AuthServer.AuthServer
}
func (a *TestServer) NewClient(identity TestIdentity) (*Client, error) {
return a.TLS.NewClient(identity)
}
func (a *TestServer) ClusterName() string {
return a.TLS.ClusterName()
}
// Shutdown stops this server instance gracefully
func (a *TestServer) Shutdown(ctx context.Context) error {
return trace.NewAggregate(
a.TLS.Shutdown(ctx),
a.AuthServer.Close(),
)
}
// TestAuthServer is auth server using local filesystem backend
// and test certificate authority key generation that speeds up
// keygen by using the same private key
@ -252,9 +312,9 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
func (a *TestAuthServer) Close() error {
return trace.NewAggregate(
a.AuthServer.Close(),
a.Backend.Close(),
a.AuditLog.Close(),
a.AuthServer.Close(),
)
}
@ -702,6 +762,18 @@ func (t *TestTLSServer) Close() error {
return err
}
// Shutdown closes the listener and HTTP server gracefully
func (t *TestTLSServer) Shutdown(ctx context.Context) error {
err := t.TLSServer.Shutdown(ctx)
if t.Listener != nil {
t.Listener.Close()
}
if t.AuthServer.Backend != nil {
t.AuthServer.Backend.Close()
}
return err
}
// Stop stops listening server, but does not close the auth backend
func (t *TestTLSServer) Stop() error {
err := t.TLSServer.Close()

View file

@ -553,11 +553,13 @@ func (s *server) Start() error {
func (s *server) Close() error {
s.cancel()
s.proxyWatcher.Close()
return s.srv.Close()
}
func (s *server) Shutdown(ctx context.Context) error {
s.cancel()
s.proxyWatcher.Close()
return s.srv.Shutdown(ctx)
}

View file

@ -70,14 +70,13 @@ type SrvSuite struct {
up *upack
signer ssh.Signer
user string
server *auth.TestTLSServer
proxyClient *auth.Client
proxyID string
nodeClient *auth.Client
nodeID string
adminClient *auth.Client
testServer *auth.TestAuthServer
clock clockwork.FakeClock
server *auth.TestServer
}
// teleportTestUser is additional user used for tests
@ -113,15 +112,14 @@ func (s *SrvSuite) SetUpTest(c *C) {
s.clock = clockwork.NewFakeClock()
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
s.server, err = auth.NewTestServer(auth.TestServerConfig{
Auth: auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
},
})
c.Assert(err, IsNil)
s.server, err = authServer.NewTestTLSServer()
c.Assert(err, IsNil)
s.testServer = authServer
// create proxy client used in some tests
s.proxyID = uuid.New()
@ -224,12 +222,12 @@ func (s *SrvSuite) TearDownTest(c *C) {
if s.clt != nil {
c.Assert(s.clt.Close(), IsNil)
}
if s.server != nil {
c.Assert(s.server.Close(), IsNil)
}
if s.srv != nil {
c.Assert(s.srv.Close(), IsNil)
}
if s.server != nil {
c.Assert(s.server.Shutdown(context.Background()), IsNil)
}
}
// TestDirectTCPIP ensures that the server can create a "direct-tcpip"
@ -722,6 +720,7 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
})
c.Assert(err, IsNil)
c.Assert(reverseTunnelServer.Start(), IsNil)
defer reverseTunnelServer.Close()
proxy, err := New(
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
@ -891,6 +890,7 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
logger.WithField("tun-addr", reverseTunnelAddress.String()).Info("Created reverse tunnel server.")
c.Assert(reverseTunnelServer.Start(), IsNil)
defer reverseTunnelServer.Close()
proxy, err := New(
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
@ -998,6 +998,9 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
})
c.Assert(err, IsNil)
c.Assert(reverseTunnelServer.Start(), IsNil)
defer reverseTunnelServer.Close()
proxy, err := New(
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
s.server.ClusterName(),
@ -1016,6 +1019,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
)
c.Assert(err, IsNil)
c.Assert(proxy.Start(), IsNil)
defer proxy.Close()
// set up SSH client using the user private key for signing
up, err := s.newUpack(s.user, []string{s.user}, wildcardAllow)
@ -1506,7 +1510,7 @@ func (s *SrvSuite) newUpack(username string, allowedLogins []string, allowedLabe
if err != nil {
return nil, trace.Wrap(err)
}
ucert, err := s.testServer.GenerateUserCert(upub, user.GetName(), 5*time.Minute, teleport.CertificateFormatStandard)
ucert, err := s.server.AuthServer.GenerateUserCert(upub, user.GetName(), 5*time.Minute, teleport.CertificateFormatStandard)
if err != nil {
return nil, trace.Wrap(err)
}

View file

@ -17,6 +17,7 @@ limitations under the License.
package srv
import (
"context"
"encoding/json"
"fmt"
"io"
@ -527,6 +528,9 @@ type session struct {
// hasEnhancedRecording returns true if this session has enhanced session
// recording events associated.
hasEnhancedRecording bool
// serverCtx is used to control clean up of internal resources
serverCtx context.Context
}
// newSession creates a new session with a given ID within a given context.
@ -598,6 +602,7 @@ func newSession(id rsession.ID, r *SessionRegistry, ctx *ServerContext) (*sessio
closeC: make(chan bool),
lingerTTL: defaults.SessionIdlePeriod,
startTime: startTime,
serverCtx: ctx.srv.Context(),
}
return sess, nil
}
@ -636,10 +641,10 @@ func (s *session) Close() error {
if s.term != nil {
s.term.Close()
}
if s.recorder != nil {
s.recorder.Close(s.serverCtx)
}
close(s.closeC)
// close all writers in our multi-writer
s.writer.Close()
}()
})
return nil
@ -1331,18 +1336,6 @@ func (m *multiWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
func (m *multiWriter) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
for writerName, writer := range m.writers {
logrus.Debugf("Closing session writer: %v.", writerName)
if closer, ok := writer.WriteCloser.(io.Closer); ok {
closer.Close()
}
}
return nil
}
func (m *multiWriter) getRecentWrites() []byte {
m.mu.Lock()
defer m.mu.Unlock()

View file

@ -102,7 +102,7 @@ type WebSuite struct {
webServer *httptest.Server
mockU2F *mocku2f.Key
server *auth.TestTLSServer
server *auth.TestServer
proxyClient *auth.Client
clock clockwork.FakeClock
}
@ -141,17 +141,18 @@ func (s *WebSuite) SetUpTest(c *C) {
s.user = u.Username
s.clock = clockwork.NewFakeClock()
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
s.server, err = auth.NewTestServer(auth.TestServerConfig{
Auth: auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: c.MkDir(),
Clock: s.clock,
},
})
c.Assert(err, IsNil)
s.server, err = authServer.NewTestTLSServer()
c.Assert(err, IsNil)
// Register the auth server, since test auth server doesn't start its own
// heartbeat.
err = authServer.AuthServer.UpsertAuthServer(&services.ServerV2{
err = s.server.Auth().UpsertAuthServer(&services.ServerV2{
Kind: services.KindAuthServer,
Version: services.V2,
Metadata: services.Metadata{
@ -159,7 +160,7 @@ func (s *WebSuite) SetUpTest(c *C) {
Name: "auth",
},
Spec: services.ServerSpecV2{
Addr: s.server.Listener.Addr().String(),
Addr: s.server.TLS.Listener.Addr().String(),
Hostname: "localhost",
Version: teleport.Version,
},
@ -266,7 +267,7 @@ func (s *WebSuite) SetUpTest(c *C) {
c.Assert(err, IsNil)
handler, err := NewHandler(Config{
Proxy: revTunServer,
AuthServers: utils.FromAddr(s.server.Addr()),
AuthServers: utils.FromAddr(s.server.TLS.Addr()),
DomainName: s.server.ClusterName(),
ProxyClient: s.proxyClient,
CipherSuites: utils.DefaultCipherSuites(),
@ -307,11 +308,17 @@ func (s *WebSuite) SetUpTest(c *C) {
}
func (s *WebSuite) TearDownTest(c *C) {
c.Assert(s.node.Close(), IsNil)
c.Assert(s.server.Close(), IsNil)
var errors []error
s.proxyTunnel.Close()
if err := s.node.Close(); err != nil {
errors = append(errors, err)
}
if err := s.server.Shutdown(context.Background()); err != nil {
errors = append(errors, err)
}
s.webServer.Close()
s.proxy.Close()
s.proxyTunnel.Close()
c.Assert(errors, HasLen, 0)
}
func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response {
@ -452,7 +459,7 @@ func (s *WebSuite) TestSAMLSuccess(c *C) {
err = s.server.Auth().CreateSAMLConnector(connector)
c.Assert(err, IsNil)
s.server.AuthServer.AuthServer.SetClock(clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)))
s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)))
clt := s.clientNoRedirects()
csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992"
@ -1435,8 +1442,8 @@ func testU2FLogin(t *testing.T, secondFactor constants.SecondFactorType) {
Type: teleport.Local,
SecondFactor: constants.SecondFactorU2F,
U2F: &services.U2F{
AppID: "https://" + env.server.ClusterName(),
Facets: []string{"https://" + env.server.ClusterName()},
AppID: "https://" + env.server.TLS.ClusterName(),
Facets: []string{"https://" + env.server.TLS.ClusterName()},
},
})
require.NoError(t, err)
@ -2454,21 +2461,19 @@ func (r CreateSessionResponse) response() (*CreateSessionResponse, error) {
func newWebPack(t *testing.T, numProxies int) *webPack {
clock := clockwork.NewFakeClock()
authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: t.TempDir(),
Clock: clock,
server, err := auth.NewTestServer(auth.TestServerConfig{
Auth: auth.TestAuthServerConfig{
ClusterName: "localhost",
Dir: t.TempDir(),
Clock: clock,
},
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, authServer.Close()) })
server, err := authServer.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, server.Close()) })
t.Cleanup(func() { require.NoError(t, server.Shutdown(context.Background())) })
// Register the auth server, since test auth server doesn't start its own
// heartbeat.
err = authServer.AuthServer.UpsertAuthServer(&services.ServerV2{
err = server.Auth().UpsertAuthServer(&services.ServerV2{
Kind: services.KindAuthServer,
Version: services.V2,
Metadata: services.Metadata{
@ -2476,7 +2481,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
Name: "auth",
},
Spec: services.ServerSpecV2{
Addr: server.Listener.Addr().String(),
Addr: server.TLS.Listener.Addr().String(),
Hostname: "localhost",
Version: teleport.Version,
},
@ -2486,7 +2491,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
// start auth server
certs, err := server.Auth().GenerateServerKeys(auth.GenerateServerKeysRequest{
HostID: hostID,
NodeName: server.ClusterName(),
NodeName: server.TLS.ClusterName(),
Roles: teleport.Roles{teleport.RoleNode},
})
require.NoError(t, err)
@ -2495,7 +2500,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
require.NoError(t, err)
const nodeID = "node"
nodeClient, err := server.NewClient(auth.TestIdentity{
nodeClient, err := server.TLS.NewClient(auth.TestIdentity{
I: auth.BuiltinRole{
Role: teleport.RoleNode,
Username: nodeID,
@ -2509,7 +2514,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
nodeDataDir := t.TempDir()
node, err := regular.New(
utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"},
server.ClusterName(),
server.TLS.ClusterName(),
hostSigners,
nodeClient,
nodeDataDir,
@ -2533,7 +2538,7 @@ func newWebPack(t *testing.T, numProxies int) *webPack {
var proxies []*proxy
for p := 0; p < numProxies; p++ {
proxyID := fmt.Sprintf("proxy%v", p)
proxies = append(proxies, createProxy(t, proxyID, node, server, hostSigners, clock))
proxies = append(proxies, createProxy(t, proxyID, node, server.TLS, hostSigners, clock))
}
// Wait for proxies to fully register before starting the test.
@ -2658,7 +2663,7 @@ func createProxy(t *testing.T, proxyID string, node *regular.Server, authServer
// directly.
type webPack struct {
proxies []*proxy
server *auth.TestTLSServer
server *auth.TestServer
node *regular.Server
clock clockwork.FakeClock
}