diff --git a/constants.go b/constants.go
index 1a6b65e8ccf..4f77de534a0 100644
--- a/constants.go
+++ b/constants.go
@@ -1,6 +1,7 @@
 package teleport

 import (
+	"strings"
 	"time"
 )

@@ -79,6 +80,9 @@ const (
 	// ComponentProxy is SSH proxy (SSH server forwarding connections)
 	ComponentProxy = "proxy"

+	// ComponentDiagnostic is a diagnostic service
+	ComponentDiagnostic = "diagnostic"
+
 	// ComponentTunClient is a tunnel client
 	ComponentTunClient = "client:tunnel"

@@ -201,6 +205,12 @@ const (
 	Off = "off"
 )

+// Component generates "component:subcomponent1:subcomponent2" strings used
+// in debugging
+func Component(components ...string) string {
+	return strings.Join(components, ":")
+}
+
 const (
 	// AuthorizedKeys are public keys that check against User CAs.
 	AuthorizedKeys = "authorized_keys"
diff --git a/e b/e
index e18107a7413..66dea98c52b 160000
--- a/e
+++ b/e
@@ -1 +1 @@
-Subproject commit e18107a74135099d9ecd6caad7d9baa70f26efde
+Subproject commit 66dea98c52bf2d60bc9ac0d50df1849f2923c35f
diff --git a/integration/helpers.go b/integration/helpers.go
index 9ed51c48003..821f93fa95d 100644
--- a/integration/helpers.go
+++ b/integration/helpers.go
@@ -26,7 +26,7 @@ import (
 	"github.com/gravitational/teleport/lib/auth/native"
 	"github.com/gravitational/teleport/lib/auth/testauthority"
 	"github.com/gravitational/teleport/lib/backend"
-	"github.com/gravitational/teleport/lib/backend/boltbk"
+	"github.com/gravitational/teleport/lib/backend/dir"
 	"github.com/gravitational/teleport/lib/client"
 	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/reversetunnel"
@@ -363,7 +363,7 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic
 	tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortWeb())
 	tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.SSHAddr)
 	tconf.Auth.StorageConfig = backend.Config{
-		Type:   boltbk.GetName(),
+		Type:   dir.GetName(),
 		Params: backend.Params{"path": dataDir},
 	}
diff --git a/integration/integration_test.go b/integration/integration_test.go
index d8184e33572..c89bb87e681 100644
--- a/integration/integration_test.go
+++ b/integration/integration_test.go
@@ -212,7 +212,7 @@ func (s *IntSuite) TestAuditOn(c *check.C) {
 		select {
 		case <-tickCh:
 			nodesInSite, err := site.GetNodes(defaults.Namespace)
-			if err != nil {
+			if err != nil && !trace.IsNotFound(err) {
 				return trace.Wrap(err)
 			}
 			if got, want := len(nodesInSite), count; got == want {
@@ -497,6 +497,7 @@ func (s *IntSuite) TestInteractive(c *check.C) {
 		err = cl.SSH(context.TODO(), []string{}, false)
 		c.Assert(err, check.IsNil)
 		sessionEndC <- true
+
 	}

 	// PersonB: wait for a session to become available, then join:
@@ -535,6 +536,80 @@ func (s *IntSuite) TestInteractive(c *check.C) {
 	c.Assert(strings.Contains(outputOfA, outputOfB), check.Equals, true)
 }

+// TestShutdown tests that an established session keeps working during a
+// graceful shutdown, and that the shutdown completes once the session ends
+func (s *IntSuite) TestShutdown(c *check.C) {
+	t := s.newTeleport(c, nil, true)
+
+	// get a reference to site obj:
+	site := t.GetSiteAPI(Site)
+	c.Assert(site, check.NotNil)
+
+	person := NewTerminal(250)
+
+	// commandsC receives commands to type into the session
+	commandsC := make(chan string, 0)
+
+	// PersonA: SSH into the server, wait one second, then type some commands on stdin:
+	openSession := func() {
+		cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()})
+		c.Assert(err, check.IsNil)
+		cl.Stdout = &person
+		cl.Stdin = &person
+
+		go func() {
+			for command := range commandsC 
{
+				person.Type(command)
+			}
+		}()
+
+		err = cl.SSH(context.TODO(), []string{}, false)
+		c.Assert(err, check.IsNil)
+	}
+
+	go openSession()
+
+	retry := func(command, pattern string) {
+		person.Type(command)
+		// wait for the terminal output to match the pattern (for up to 10 seconds)
+		abortTime := time.Now().Add(10 * time.Second)
+		var matched bool
+		var output string
+		for {
+			output = string(replaceNewlines(person.Output(1000)))
+			matched, _ = regexp.MatchString(pattern, output)
+			if matched {
+				break
+			}
+			time.Sleep(time.Millisecond * 200)
+			if time.Now().After(abortTime) {
+				c.Fatalf("failed to capture output: %v", pattern)
+			}
+		}
+		if !matched {
+			c.Fatalf("output %q does not match pattern %q", output, pattern)
+		}
+	}
+
+	retry("echo start \r\n", ".*start.*")
+
+	// initiate shutdown
+	ctx := context.TODO()
+	shutdownContext := t.Process.StartShutdown(ctx)
+
+	// make sure that terminal still works
+	retry("echo howdy \r\n", ".*howdy.*")
+
+	// now type exit and wait for shutdown to complete
+	person.Type("exit\n\r")
+
+	select {
+	case <-shutdownContext.Done():
+	case <-time.After(5 * time.Second):
+		c.Fatalf("failed to shut down the server")
+	}
+}
+
 // TestInvalidLogins validates that you can't login with invalid login or
 // with invalid 'site' parameter
 func (s *IntSuite) TestEnvironmentVariables(c *check.C) {
diff --git a/lib/auth/auth.go b/lib/auth/auth.go
index fba69bd7798..4a04cc93ad7 100644
--- a/lib/auth/auth.go
+++ b/lib/auth/auth.go
@@ -24,7 +24,6 @@ limitations under the License.
 package auth

 import (
-	"context"
 	"crypto/x509"
 	"fmt"
 	"net/url"
@@ -79,7 +78,6 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
 	if cfg.AuditLog == nil {
 		cfg.AuditLog = events.NewDiscardAuditLog()
 	}
-	closeCtx, cancelFunc := context.WithCancel(context.TODO())
 	as := AuthServer{
 		clusterName: cfg.ClusterName,
 		bk:          cfg.Backend,
@@ -95,8 +93,6 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
 		oidcClients:   make(map[string]*oidcClient),
 		samlProviders: make(map[string]*samlProvider),
 		githubClients: make(map[string]*githubClient),
-		cancelFunc:    cancelFunc,
-		closeCtx:      closeCtx,
 	}
 	for _, o := range opts {
 		o(&as)
@@ -122,8 +118,6 @@ type AuthServer struct {
 	githubClients map[string]*githubClient
 	clock         clockwork.Clock
 	bk            backend.Backend
-	closeCtx      context.Context
-	cancelFunc    context.CancelFunc

 	sshca.Authority

@@ -144,7 +138,6 @@ type AuthServer struct {
 }

 func (a *AuthServer) Close() error {
-	a.cancelFunc()
 	if a.bk != nil {
 		return trace.Wrap(a.bk.Close())
 	}
diff --git a/lib/auth/tun.go b/lib/auth/tun.go
index e77c92afef3..13a7ced75ad 100644
--- a/lib/auth/tun.go
+++ b/lib/auth/tun.go
@@ -181,6 +181,14 @@ func (s *AuthTunnel) Close() error {
 	return nil
 }

+// Shutdown gracefully shuts down the auth server
+func (s *AuthTunnel) Shutdown(ctx context.Context) error {
+	if s != nil && s.sshServer != nil {
+		return s.sshServer.Shutdown(ctx)
+	}
+	return nil
+}
+
 // HandleNewChan implements NewChanHandler interface: it gets called every time a new SSH
 // connection is established
 func (s *AuthTunnel) HandleNewChan(_ net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
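The dir backend hunks below serialize key-file access with flock(2) advisory locks. As a minimal standalone sketch of the pattern (withWriteLock is a hypothetical combined helper; the diff itself uses separate writeLock/readLock/unlock functions, and advisory locks only guard against other cooperating flock users):

	package main

	import (
		"os"
		"syscall"
	)

	// withWriteLock takes an exclusive advisory lock on f, runs fn, then unlocks.
	func withWriteLock(f *os.File, fn func() error) error {
		if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
			return err
		}
		// released explicitly here; also dropped automatically when f is closed
		defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
		return fn()
	}

	func main() {
		f, _ := os.OpenFile("/tmp/key", os.O_WRONLY|os.O_CREATE, 0600)
		defer f.Close()
		_ = withWriteLock(f, func() error {
			_, err := f.Write([]byte("value"))
			return err
		})
	}

diff --git a/lib/backend/dir/impl.go b/lib/backend/dir/impl.go
index 41b5f7702b4..bc06a365680 100644
--- a/lib/backend/dir/impl.go
+++ b/lib/backend/dir/impl.go
@@ -22,6 +22,7 @@ import (
 	"os"
 	"path"
 	"path/filepath"
+	"syscall"
 	"time"

 	"github.com/gravitational/teleport/lib/backend"
@@ -160,6 +161,13 @@ func (bk *Backend) CreateVal(bucket []string, key string, val 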
[]byte, ttl time.D
 		return trace.ConvertSystemError(err)
 	}
 	defer f.Close()
+	if err := writeLock(f); err != nil {
+		return trace.Wrap(err)
+	}
+	defer unlock(f)
+	if err := f.Truncate(0); err != nil {
+		return trace.ConvertSystemError(err)
+	}
 	n, err := f.Write(val)
 	if err == nil && n < len(val) {
 		return trace.Wrap(io.ErrShortWrite)
 	}
@@ -167,6 +175,27 @@ func (bk *Backend) CreateVal(bucket []string, key string, val []byte, ttl time.D
 	return trace.Wrap(bk.applyTTL(dirPath, key, ttl))
 }

+func writeLock(f *os.File) error {
+	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	return nil
+}
+
+func readLock(f *os.File) error {
+	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	return nil
+}
+
+func unlock(f *os.File) error {
+	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	return nil
+}
+
 // UpsertVal updates or inserts value with a given TTL into a bucket
 // ForeverTTL for no TTL
 func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error {
@@ -176,17 +205,33 @@ func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.D
 	if err != nil {
 		return trace.Wrap(err)
 	}
-	// create the (or overwrite existing) file (AKA "key"):
-	err = ioutil.WriteFile(path.Join(dirPath, key), val, defaultFileMode)
+	filename := path.Join(dirPath, key)
+	f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, defaultFileMode)
 	if err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	defer f.Close()
+	if err := writeLock(f); err != nil {
 		return trace.Wrap(err)
 	}
+	defer unlock(f)
+	if err := f.Truncate(0); err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	n, err := f.Write(val)
+	if err != nil {
+		return trace.ConvertSystemError(err)
+	}
+	if n < len(val) {
+		return trace.Wrap(io.ErrShortWrite)
+	}
 	return trace.Wrap(bk.applyTTL(dirPath, key, ttl))
 }

 // GetVal return a value for a given key in the bucket
 func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {
 	dirPath := path.Join(path.Join(bk.RootDir, path.Join(bucket...)))
+	filename := path.Join(dirPath, key)
 	expired, err := bk.checkTTL(dirPath, key)
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
@@ -195,17 +240,26 @@ func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {
 		bk.DeleteKey(bucket, key)
 		return nil, trace.NotFound("key %q is not found", key)
 	}
-	fp := path.Join(dirPath, key)
-	bytes, err := ioutil.ReadFile(fp)
+	f, err := os.OpenFile(filename, os.O_RDONLY, defaultFileMode)
 	if err != nil {
 		// GetVal() on a bucket must return 'BadParameter' error:
-		if fi, _ := os.Stat(fp); fi != nil && fi.IsDir() {
+		if fi, _ := os.Stat(filename); fi != nil && fi.IsDir() {
 			return nil, trace.BadParameter("%q is not a valid key", key)
 		}
 		return nil, trace.ConvertSystemError(err)
 	}
-	// this could happen if we delete the file concurrently
-	// with the read, apparently we can read empty file back
+	defer f.Close()
+	if err := readLock(f); err != nil {
+		return nil, trace.Wrap(err)
+	}
+	defer unlock(f)
+	bytes, err := ioutil.ReadAll(f)
+	if err != nil {
+		return nil, trace.ConvertSystemError(err)
+	}
+	// this could happen when CreateVal or UpsertVal created the file
+	// but GetVal acquired the read lock before any contents were written,
+	// so the file is still empty
 	if len(bytes) == 0 {
 		return nil, trace.NotFound("key %q is not found", key)
 	}
@@ -215,13 +269,25 
@@ func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {

 // DeleteKey deletes a key in a bucket
 func (bk *Backend) DeleteKey(bucket []string, key string) error {
 	dirPath := path.Join(bk.RootDir, path.Join(bucket...))
+	filename := path.Join(dirPath, key)
+	f, err := os.OpenFile(filename, os.O_RDONLY, defaultFileMode)
+	if err != nil {
+		if fi, _ := os.Stat(filename); fi != nil && fi.IsDir() {
+			return trace.BadParameter("%q is not a valid key", key)
+		}
+		return trace.ConvertSystemError(err)
+	}
+	defer f.Close()
+	if err := writeLock(f); err != nil {
+		return trace.Wrap(err)
+	}
+	defer unlock(f)
 	if err := os.Remove(bk.ttlFile(dirPath, key)); err != nil {
 		if !os.IsNotExist(err) {
 			log.Warn(err)
 		}
 	}
-	return trace.ConvertSystemError(os.Remove(
-		path.Join(dirPath, key)))
+	return trace.ConvertSystemError(os.Remove(filename))
 }

 // DeleteBucket deletes the bucket by a given path
@@ -260,18 +326,39 @@ func removeFiles(dir string) error {
 				return err
 			}
 		} else if !fi.IsDir() {
-			err = os.Remove(path)
+			err = removeFile(path)
 			if err != nil {
-				err = trace.ConvertSystemError(err)
-				if !trace.IsNotFound(err) {
-					return err
-				}
+				return err
 			}
 		}
 	}
 	return nil
 }

+func removeFile(path string) error {
+	f, err := os.OpenFile(path, os.O_RDONLY, defaultFileMode)
+	err = trace.ConvertSystemError(err)
+	if err != nil {
+		if !trace.IsNotFound(err) {
+			return trace.Wrap(err)
+		}
+		return nil
+	}
+	defer f.Close()
+	if err := writeLock(f); err != nil {
+		return trace.Wrap(err)
+	}
+	defer unlock(f)
+	err = os.Remove(path)
+	if err != nil {
+		err = trace.ConvertSystemError(err)
+		if !trace.IsNotFound(err) {
+			return err
+		}
+	}
+	return nil
+}
+
 // AcquireLock grabs a lock that will be released automatically in TTL
 func (bk *Backend) AcquireLock(token string, ttl time.Duration) (err error) {
 	bk.Debugf("AcquireLock(%s)", token)
diff --git a/lib/backend/dir/impl_test.go b/lib/backend/dir/impl_test.go
index a3e991e4ba9..5248a7cb81a 100644
--- a/lib/backend/dir/impl_test.go
+++ b/lib/backend/dir/impl_test.go
@@ -57,18 +57,60 @@ func (s *Suite) SetUpSuite(c *check.C) {
 	s.suite.B = s.bk
 }

-func (s *Suite) TestConcurrentDeleteBucket(c *check.C) {
+func (s *Suite) BenchmarkOperations(c *check.C) {
+	bucket := []string{"bench", "bucket"}
+	keys := []string{"key1", "key2", "key3", "key4", "key5"}
+	value1 := "some backend value, not large enough, but not small enough"
+	for i := 0; i < c.N; i++ {
+		for _, key := range keys {
+			err := s.bk.UpsertVal(bucket, key, []byte(value1), time.Hour)
+			c.Assert(err, check.IsNil)
+			bytes, err := s.bk.GetVal(bucket, key)
+			c.Assert(err, check.IsNil)
+			c.Assert(string(bytes), check.Equals, value1)
+			err = s.bk.DeleteKey(bucket, key)
+			c.Assert(err, check.IsNil)
+		}
+	}
+}
+
+func (s *Suite) TestConcurrentOperations(c *check.C) {
 	bucket := []string{"concurrent", "bucket"}
+	value1 := "this first value should not be corrupted by concurrent ops"
+	value2 := "this second value should not be corrupted either"
 	const attempts = 50
-	resultsC := make(chan struct{}, attempts*2)
+	resultsC := make(chan struct{}, attempts*4)
 	for i := 0; i < attempts; i++ {
 		go func(cnt int) {
-			err := s.bk.UpsertVal(bucket, "key", []byte("new-value"), backend.Forever)
+			err := s.bk.UpsertVal(bucket, "key", []byte(value1), time.Hour)
 			resultsC <- struct{}{}
 			c.Assert(err, check.IsNil)
 		}(i)

+		go func(cnt int) {
+			err := s.bk.CreateVal(bucket, "key", []byte(value2), time.Hour)
+			resultsC <- struct{}{}
+			if err != nil && !trace.IsAlreadyExists(err) {
+				c.Assert(err, check.IsNil)
+			}
+		}(i)
+
+		go func(cnt 
int) {
+			bytes, err := s.bk.GetVal(bucket, "key")
+			resultsC <- struct{}{}
+			if err != nil && !trace.IsNotFound(err) {
+				c.Assert(err, check.IsNil)
+			}
+			// make sure data is not corrupted along the way
+			if err == nil {
+				val := string(bytes)
+				if val != value1 && val != value2 {
+					c.Fatalf("expected one of %q or %q and got %q", value1, value2, val)
+				}
+			}
+		}(i)
+
 		go func(cnt int) {
 			err := s.bk.DeleteBucket([]string{"concurrent"}, "bucket")
 			resultsC <- struct{}{}
@@ -76,7 +118,7 @@
 		}(i)
 	}
 	timeoutC := time.After(3 * time.Second)
-	for i := 0; i < attempts*2; i++ {
+	for i := 0; i < attempts*4; i++ {
 		select {
 		case <-resultsC:
 		case <-timeoutC:
diff --git a/lib/config/configuration.go b/lib/config/configuration.go
index 640438a3154..75053e876fa 100644
--- a/lib/config/configuration.go
+++ b/lib/config/configuration.go
@@ -75,15 +75,10 @@ type CommandLineFlags struct {
 	// --labels flag
 	Labels string
-	// --httpprofile hidden flag
-	HTTPProfileEndpoint bool
 	// --pid-file flag
 	PIDFile string
-	// Gops starts gops agent on a specified address
-	// if not specified, gops won't start
+	// Gops starts gops agent on the first available address
 	Gops bool
-	// GopsAddr specifies to gops addr to listen on
-	GopsAddr string
 	// DiagnosticAddr is listen address for diagnostic endpoint
 	DiagnosticAddr string
 	// PermitUserEnvironment enables reading of ~/.tsh/environment
@@ -666,6 +661,15 @@ func Configure(clf *CommandLineFlags, cfg *service.Config) error {
 		return trace.Wrap(err)
 	}

+	// apply diagnostic address flag
+	if clf.DiagnosticAddr != "" {
+		addr, err := utils.ParseAddr(clf.DiagnosticAddr)
+		if err != nil {
+			return trace.Wrap(err, "failed to parse diag-addr")
+		}
+		cfg.DiagnosticAddr = *addr
+	}
+
 	// apply --insecure-no-tls flag:
 	if clf.DisableTLS {
 		cfg.Proxy.DisableTLS = clf.DisableTLS
diff --git a/lib/defaults/defaults.go b/lib/defaults/defaults.go
index 9ade2c5acbc..2e985bbfa39 100644
--- a/lib/defaults/defaults.go
+++ b/lib/defaults/defaults.go
@@ -92,6 +92,9 @@ const (
 	// the SSH connection open if there are no reads/writes happening over it.
 	DefaultIdleConnectionDuration = 20 * time.Minute

+	// ShutdownPollPeriod is a polling period for graceful shutdowns of SSH servers
+	ShutdownPollPeriod = 500 * time.Millisecond
+
 	// ReadHeadersTimeout is a default TCP timeout when we wait
 	// for the response headers to arrive
 	ReadHeadersTimeout = time.Second
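ShutdownPollPeriod feeds the drain loop that sshutils.Server.Shutdown gains later in this diff. The essence of that polling pattern, condensed into a standalone sketch (assumes an atomically maintained connection counter; names are simplified and not from the diff):

	// drain polls an active-connection counter until it reaches zero or ctx expires.
	func drain(ctx context.Context, active *int32, period time.Duration) error {
		ticker := time.NewTicker(period)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if atomic.LoadInt32(active) == 0 {
					return nil // all connections finished
				}
			case <-ctx.Done():
				return ctx.Err() // deadline hit while connections were still open
			}
		}
	}

diff --git a/lib/reversetunnel/api.go b/lib/reversetunnel/api.go
index 4f07cd12f5b..420127b65f0 100644
--- a/lib/reversetunnel/api.go
+++ b/lib/reversetunnel/api.go
@@ -17,6 +17,7 @@ limitations under the License. 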
package reversetunnel

 import (
+	"context"
 	"net"
 	"time"

@@ -58,8 +59,10 @@ type Server interface {
 	RemoveSite(domainName string) error
 	// Start starts server
 	Start() error
-	// CLose closes server's socket
+	// Close closes server's operations immediately
 	Close() error
+	// Shutdown performs graceful server shutdown
+	Shutdown(context.Context) error
 	// Wait waits for server to close all outstanding operations
 	Wait()
 }
diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go
index 8e041acf4b5..c534bf6a90b 100644
--- a/lib/reversetunnel/srv.go
+++ b/lib/reversetunnel/srv.go
@@ -402,7 +402,7 @@ func (s *server) diffConns(newConns, existingConns map[string]services.TunnelCon
 }

 func (s *server) Wait() {
-	s.srv.Wait()
+	s.srv.Wait(context.TODO())
 }

 func (s *server) Start() error {
@@ -415,6 +415,11 @@ func (s *server) Close() error {
 	return s.srv.Close()
 }

+func (s *server) Shutdown(ctx context.Context) error {
+	s.cancel()
+	return s.srv.Shutdown(ctx)
+}
+
 func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
 	// apply read/write timeouts to the server connection
 	conn = utils.ObeyIdleTimeout(conn,
diff --git a/lib/service/cfg.go b/lib/service/cfg.go
index bba5b6d9fa0..6ab21128ca6 100644
--- a/lib/service/cfg.go
+++ b/lib/service/cfg.go
@@ -127,6 +127,9 @@ type Config struct {
 	// MACAlgorithms is a list of message authentication codes (MAC) that
 	// the server supports. If omitted the defaults will be used.
 	MACAlgorithms []string
+
+	// DiagnosticAddr is the listen address for the diagnostic and healthz endpoints
+	DiagnosticAddr utils.NetAddr
 }

 // ApplyToken assigns a given token to all internal services but only if token
@@ -202,12 +205,12 @@ func (c CachePolicy) String() string {
 		recentCachePolicy = fmt.Sprintf("will cache frequently accessed items for %v", c.GetRecentTTL())
 	}
 	if c.NeverExpires {
-		return fmt.Sprintf("cache will not expire in case if connection to database is lost, %v", recentCachePolicy)
+		return fmt.Sprintf("cache that will not expire if the connection to the database is lost, %v", recentCachePolicy)
 	}
 	if c.TTL == 0 {
-		return fmt.Sprintf("cache will expire after connection to database is lost after %v, %v", defaults.CacheTTL, recentCachePolicy)
+		return fmt.Sprintf("cache that will expire %v after the connection to the database is lost, %v", defaults.CacheTTL, recentCachePolicy)
 	}
-	return fmt.Sprintf("cache will expire after connection to database is lost after %v, %v", c.TTL, recentCachePolicy)
+	return fmt.Sprintf("cache that will expire %v after the connection to the database is lost, %v", c.TTL, recentCachePolicy)
 }

 // ProxyConfig configures proy service
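The interface change above encodes the convention the rest of this diff rolls out everywhere: Close tears down immediately, Shutdown(ctx) drains. A typical caller, sketched under that assumption (srv stands in for any server implementing both; not taken from the diff):

	// prefer the graceful path, fall back to a hard Close
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		srv.Close() // drain failed or timed out, tear down immediately
	}

diff --git a/lib/service/service.go b/lib/service/service.go
index 3bf12d4c9a4..87c4a6aa2ed 100644
--- a/lib/service/service.go
+++ b/lib/service/service.go
@@ -19,6 +19,7 @@ limitations under the License. 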
package service

 import (
+	"context"
 	"crypto/tls"
 	"fmt"
 	"io"
@@ -52,12 +53,15 @@ import (
 	"github.com/gravitational/teleport/lib/srv/regular"
 	"github.com/gravitational/teleport/lib/sshca"
 	"github.com/gravitational/teleport/lib/state"
+	"github.com/gravitational/teleport/lib/system"
 	"github.com/gravitational/teleport/lib/utils"
 	"github.com/gravitational/teleport/lib/web"

 	"github.com/gravitational/trace"
+	"github.com/gravitational/roundtrip"
 	"github.com/jonboulle/clockwork"
 	"github.com/pborman/uuid"
+	"github.com/prometheus/client_golang/prometheus"
 	"github.com/sirupsen/logrus"
 )

@@ -142,6 +146,12 @@ type TeleportProcess struct {
 	// identities of this process (credentials to auth sever, basically)
 	Identities map[teleport.Role]*auth.Identity

+	// registeredListeners keeps track of all listeners created by the process,
+	// so they can be passed on to child processes during a graceful restart
+	registeredListeners []RegisteredListener
+	// importedDescriptors is a list of imported file descriptors
+	// passed by the parent process
+	importedDescriptors []FileDescriptor
 }

 // GetAuthServer returns the process' auth server
@@ -286,7 +296,7 @@ func (process *TeleportProcess) connectToAuthService(role teleport.Role, additio
 // and starts them under a supervisor, returning the supervisor object
 func NewTeleport(cfg *Config) (*TeleportProcess, error) {
 	// before we do anything reset the SIGINT handler back to the default
-	utils.ResetInterruptSignalHandler()
+	system.ResetInterruptSignalHandler()

 	if err := validateConfig(cfg); err != nil {
 		return nil, trace.Wrap(err, "configuration error")
@@ -301,6 +311,11 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
 		}
 	}

+	importedDescriptors, err := importFileDescriptors()
+	if err != nil {
+		return nil, trace.Wrap(err)
+	}
+
 	// if there's no host uuid initialized yet, try to read one from the
 	// one of the identities
 	cfg.HostUUID, err = utils.ReadHostUUID(cfg.DataDir)
@@ -337,14 +352,23 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
 	}

 	process := &TeleportProcess{
-		Clock:      clockwork.NewRealClock(),
-		Supervisor: NewSupervisor(),
-		Config:     cfg,
-		Identities: make(map[teleport.Role]*auth.Identity),
+		Clock:               clockwork.NewRealClock(),
+		Supervisor:          NewSupervisor(),
+		Config:              cfg,
+		Identities:          make(map[teleport.Role]*auth.Identity),
+		importedDescriptors: importedDescriptors,
 	}

 	serviceStarted := false

+	if !cfg.DiagnosticAddr.IsEmpty() {
+		if err := process.initDiagnosticService(); err != nil {
+			return nil, trace.Wrap(err)
+		}
+	} else {
+		warnOnErr(process.closeImportedDescriptors(teleport.ComponentDiagnostic))
+	}
+
 	if cfg.Auth.Enabled {
 		if cfg.Keygen == nil {
 			cfg.Keygen = native.New()
@@ -353,6 +377,8 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
 			return nil, trace.Wrap(err)
 		}
 		serviceStarted = true
+	} else {
+		warnOnErr(process.closeImportedDescriptors(teleport.ComponentAuth))
 	}

 	if cfg.SSH.Enabled {
@@ -360,6 +386,8 @@
 			return nil, err
 		}
 		serviceStarted = true
+	} else {
+		warnOnErr(process.closeImportedDescriptors(teleport.ComponentNode))
 	}

 	if cfg.Proxy.Enabled {
@@ -367,6 +395,8 @@
 			return nil, err
 		}
 		serviceStarted = true
+	} else {
+		warnOnErr(process.closeImportedDescriptors(teleport.ComponentProxy))
 	}

 	if !serviceStarted {
@@ -514,57 +544,8 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error
 		return trace.Wrap(err)
 	}

-	// auth server listens on SSH and TLS, reusing the same socket
-	listener, err := net.Listen("tcp", 
cfg.Auth.SSHAddr.Addr)
-	if err != nil {
-		utils.Consolef(cfg.Console, "[AUTH] failed to bind to address %v, exiting", cfg.Auth.SSHAddr.Addr, err)
-		return trace.Wrap(err)
-	}
-	process.onExit(func(payload interface{}) {
-		log.Debugf("Closing listener: %v.", listener.Addr())
-		listener.Close()
-	})
-	if cfg.Auth.EnableProxyProtocol {
-		log.Infof("Starting Auth service with PROXY protocol support.")
-	}
-	mux, err := multiplexer.New(multiplexer.Config{
-		EnableProxyProtocol: cfg.Auth.EnableProxyProtocol,
-		Listener:            listener,
-	})
-	if err != nil {
-		return trace.Wrap(err)
-	}
-	go mux.Serve()
-
-	// Register an SSH endpoint which is used to create an SSH tunnel to send HTTP
-	// requests to the Auth API
-	var authTunnel *auth.AuthTunnel
-	process.RegisterFunc("auth.ssh", func() error {
-		utils.Consolef(cfg.Console, "[AUTH] Auth service is starting on %v", cfg.Auth.SSHAddr.Addr)
-		authTunnel, err = auth.NewTunnel(
-			cfg.Auth.SSHAddr,
-			identity.KeySigner,
-			apiConf,
-			auth.SetLimiter(sshLimiter),
-		)
-		if err != nil {
-			utils.Consolef(cfg.Console, "[AUTH] Error: %v", err)
-			return trace.Wrap(err)
-		}
-
-		// since authTunnel.Serve is a blocking call, we emit this even right before
-		// the service has started
-		process.BroadcastEvent(Event{Name: AuthSSHReady, Payload: nil})
-
-		if err := authTunnel.Serve(mux.SSH()); err != nil {
-			if askedToExit {
-				log.Infof("Auth tunnel exited.")
-				return nil
-			}
-			utils.Consolef(cfg.Console, "[AUTH] Error: %v", err)
-			return trace.Wrap(err)
-		}
-		return nil
+	log := logrus.WithFields(logrus.Fields{
+		trace.Component: teleport.ComponentAuth,
 	})

 	// Register TLS endpoint of the auth service
@@ -581,13 +562,66 @@
 	if err != nil {
 		return trace.Wrap(err)
 	}
+
+	// auth server listens on SSH and TLS, reusing the same socket
+	listener, err := process.importOrCreateListener(teleport.ComponentAuth, cfg.Auth.SSHAddr.Addr)
+	if err != nil {
+		log.Errorf("PID: %v Failed to bind to address %v: %v, exiting.", os.Getpid(), cfg.Auth.SSHAddr.Addr, err)
+		return trace.Wrap(err)
+	}
+	// clean up unused descriptors passed to the auth service, but not used by it
+	warnOnErr(process.closeImportedDescriptors(teleport.ComponentAuth))
+	if cfg.Auth.EnableProxyProtocol {
+		log.Infof("Starting Auth service with PROXY protocol support.")
+	}
+	mux, err := multiplexer.New(multiplexer.Config{
+		EnableProxyProtocol: cfg.Auth.EnableProxyProtocol,
+		Listener:            listener,
+	})
+	if err != nil {
+		listener.Close()
+		return trace.Wrap(err)
+	}
+	go mux.Serve()
+
+	// Register an SSH endpoint which is used to create an SSH tunnel to send HTTP
+	// requests to the Auth API
+	var authTunnel *auth.AuthTunnel
+	process.RegisterFunc("auth.ssh", func() error {
+		log.Infof("Auth SSH service is starting on %v", cfg.Auth.SSHAddr.Addr)
+		authTunnel, err = auth.NewTunnel(
+			cfg.Auth.SSHAddr,
+			identity.KeySigner,
+			apiConf,
+			auth.SetLimiter(sshLimiter),
+		)
+		if err != nil {
+			log.Errorf("Error: %v", err)
+			return trace.Wrap(err)
+		}
+
+		// since authTunnel.Serve is a blocking call, we emit this event right before
+		// the service has started
+		process.BroadcastEvent(Event{Name: AuthSSHReady, Payload: nil})
+
+		if err := authTunnel.Serve(mux.SSH()); err != nil {
+			if askedToExit {
+				log.Infof("Auth tunnel exited.")
+				return nil
+			}
+			log.Errorf("Error: %v", err)
+			return trace.Wrap(err)
+		}
+		return nil
+	})
+
 	process.RegisterFunc("auth.tls", func() error {
 		// since tlsServer.Serve is a blocking call, we emit this event right before
 		// 
the service has started process.BroadcastEvent(Event{Name: AuthTLSReady, Payload: nil}) err := tlsServer.Serve(mux.TLS()) - if err != nil { + if err != nil && err != http.ErrServerClosed { log.Warningf("TLS server exited with error: %v.", err) } return nil @@ -602,12 +636,14 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error } // External integrations rely on this event: process.BroadcastEvent(Event{Name: AuthIdentityEvent, Payload: connector}) - process.onExit(func(payload interface{}) { + process.onExit("auth.broadcast", func(payload interface{}) { connector.Client.Close() }) return nil }) + closeContext, signalClose := context.WithCancel(context.TODO()) + process.RegisterFunc("auth.heartbeat", func() error { srv := services.ServerV2{ Kind: services.KindAuthServer, @@ -642,42 +678,72 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error log.Warnf("Parameter advertise_ip is not set for this auth server. Trying to guess the IP this server can be reached at: %v.", srv.GetAddr()) } // immediately register, and then keep repeating in a loop: - for !askedToExit { + ticker := time.NewTicker(defaults.ServerHeartbeatTTL / 2) + defer ticker.Stop() + announce: + for { srv.SetTTL(process, defaults.ServerHeartbeatTTL) err := authServer.UpsertAuthServer(&srv) if err != nil { log.Warningf("Failed to announce presence: %v.", err) } - sleepTime := defaults.ServerHeartbeatTTL/2 + utils.RandomDuration(defaults.ServerHeartbeatTTL/10) - time.Sleep(sleepTime) + select { + case <-closeContext.Done(): + break announce + case <-ticker.C: + } } log.Infof("Heartbeat to other auth servers exited.") return nil }) // execute this when process is asked to exit: - process.onExit(func(payload interface{}) { - askedToExit = true - mux.Close() - authTunnel.Close() - tlsServer.Close() - log.Infof("Auth service exited.") + process.onExit("auth.shutdown", func(payload interface{}) { + // as a last resort, at least close listeners (e.g. panic) + if listener != nil { + defer listener.Close() + } + if mux != nil { + defer mux.Close() + } + signalClose() + if payload == nil { + log.Info("Shutting down immediately.") + warnOnErr(tlsServer.Close()) + warnOnErr(authTunnel.Close()) + } else { + log.Info("Shutting down gracefully.") + ctx := payloadContext(payload) + warnOnErr(tlsServer.Shutdown(ctx)) + warnOnErr(authTunnel.Shutdown(ctx)) + } + log.Info("Exited.") }) return nil } +func payloadContext(payload interface{}) context.Context { + ctx, ok := payload.(context.Context) + if ok { + return ctx + } + log.Errorf("expected context, got %T", payload) + return context.TODO() +} + // onExit allows individual services to register a callback function which will be // called when Teleport Process is asked to exit. 
Usually services terminate themselves when the callback is called
-func (process *TeleportProcess) onExit(callback func(interface{})) {
-	go func() {
+func (process *TeleportProcess) onExit(serviceName string, callback func(interface{})) {
+	process.RegisterFunc(serviceName, func() error {
 		eventC := make(chan Event)
 		process.WaitForEvent(TeleportExitEvent, eventC, make(chan struct{}))
 		select {
 		case event := <-eventC:
 			callback(event.Payload)
 		}
-	}()
+		return nil
+	})
 }

 // newLocalCache returns new local cache access point
@@ -690,7 +756,7 @@
 	if err := os.MkdirAll(path, teleport.SharedDirMode); err != nil {
 		return nil, trace.ConvertSystemError(err)
 	}
-	cacheBackend, err := boltbk.New(backend.Params{"path": path})
+	cacheBackend, err := dir.New(backend.Params{"path": path})
 	if err != nil {
 		return nil, trace.Wrap(err)
 	}
@@ -712,9 +778,13 @@

 	var s *regular.Server

+	log := logrus.WithFields(logrus.Fields{
+		trace.Component: teleport.ComponentNode,
+	})
+
 	process.RegisterFunc("ssh.node", func() error {
 		event := <-eventsC
-		log.Infof("SSH node received %v", &event)
+		log.Infof("Received event %q.", event.Name)
 		conn, ok := (event.Payload).(*Connector)
 		if !ok {
 			return trace.BadParameter("unsupported connector type: %T", event.Payload)
 		}
@@ -743,6 +813,13 @@
 			return trace.Wrap(err)
 		}

+		listener, err := process.importOrCreateListener(teleport.ComponentNode, cfg.SSH.Addr.Addr)
+		if err != nil {
+			return trace.Wrap(err)
+		}
+		// clean up unused descriptors passed to the node service, but not used by it
+		warnOnErr(process.closeImportedDescriptors(teleport.ComponentNode))
+
 		s, err = regular.New(cfg.SSH.Addr,
 			cfg.Hostname,
 			[]ssh.Signer{conn.Identity.KeySigner},
@@ -765,25 +842,31 @@
 			return trace.Wrap(err)
 		}

-		utils.Consolef(cfg.Console, "[SSH] Service is starting on %v using %v", cfg.SSH.Addr.Addr, process.Config.CachePolicy)
-		if err := s.Start(); err != nil {
-			utils.Consolef(cfg.Console, "[SSH] Error: %v", err)
-			return trace.Wrap(err)
-		}
+		log.Infof("Service is starting on %v using %v.", cfg.SSH.Addr.Addr, process.Config.CachePolicy)
+		go s.Serve(listener)

 		// broadcast that the node has started
 		process.BroadcastEvent(Event{Name: NodeSSHReady, Payload: nil})

 		// block and wait while the node is running
 		s.Wait()
-		log.Infof("[SSH] node service exited")
+		log.Infof("Exited.")
 		return nil
 	})

 	// execute this when process is asked to exit:
-	process.onExit(func(payload interface{}) {
-		if s != nil {
-			s.Close()
+	process.onExit("ssh.shutdown", func(payload interface{}) {
+		if payload == nil {
+			log.Infof("Shutting down immediately.")
+			if s != nil {
+				warnOnErr(s.Close())
+			}
+		} else {
+			log.Infof("Shutting down gracefully.")
+			if s != nil {
+				warnOnErr(s.Shutdown(payloadContext(payload)))
+			}
 		}
+		log.Infof("Exited.")
 	})
 	return nil
 }
@@ -809,7 +892,7 @@ func (process *TeleportProcess) RegisterWithAuthServer(token string, role telepo
 				return nil
 			}
 			if trace.IsConnectionProblem(err) {
-				utils.Consolef(cfg.Console, "[%v] connecting to auth server: %v", role, err)
+				log.Infof("%v failed to connect to the auth server: %v", role, err)
 				time.Sleep(retryTime)
 				continue
 			}
@@ -834,19 +917,65 @@
 				log.Errorf("Failed to join the cluster: %v.", err)
 				time.Sleep(retryTime)
 			} else {
-				utils.Consolef(cfg.Console, "[%v] Successfully 
registered with the cluster", role)
+				log.Infof("%v has successfully registered with the cluster.", role)
 				continue
 			}
 		}
 	})

-	process.onExit(func(interface{}) {
+	process.onExit("auth.client", func(interface{}) {
 		if authClient != nil {
 			authClient.Close()
 		}
 	})
 }

+// initDiagnosticService starts the diagnostic service, currently serving the
+// healthz and Prometheus metrics endpoints
+func (process *TeleportProcess) initDiagnosticService() error {
+	mux := http.NewServeMux()
+	mux.Handle("/metrics", prometheus.Handler())
+	mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
+		roundtrip.ReplyJSON(w, http.StatusOK, map[string]interface{}{"status": "ok"})
+	})
+	listener, err := process.importOrCreateListener(teleport.ComponentDiagnostic, process.Config.DiagnosticAddr.Addr)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	warnOnErr(process.closeImportedDescriptors(teleport.ComponentDiagnostic))
+
+	server := &http.Server{
+		Handler: mux,
+	}
+
+	log := logrus.WithFields(logrus.Fields{
+		trace.Component: teleport.ComponentDiagnostic,
+	})
+
+	log.Infof("Starting diagnostic service on %v.", process.Config.DiagnosticAddr.Addr)
+
+	process.RegisterFunc("diagnostic.service", func() error {
+		err := server.Serve(listener)
+		if err != nil && err != http.ErrServerClosed {
+			log.Warningf("Diagnostic server exited with error: %v.", err)
+		}
+		return nil
+	})
+
+	process.onExit("diagnostic.shutdown", func(payload interface{}) {
+		if payload == nil {
+			log.Infof("Shutting down immediately.")
+			warnOnErr(server.Close())
+		} else {
+			log.Infof("Shutting down gracefully.")
+			ctx := payloadContext(payload)
+			warnOnErr(server.Shutdown(ctx))
+		}
+		log.Infof("Exited.")
+	})
+	return nil
+}
+
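For manual verification, the new endpoints can be exercised against a process started with a diagnostic address (the "failed to parse diag-addr" error in configuration.go suggests a --diag-addr flag; the port below is hypothetical):

	package main

	import (
		"fmt"
		"io/ioutil"
		"log"
		"net/http"
	)

	func main() {
		// address is whatever was passed via the diagnostic flag
		resp, err := http.Get("http://127.0.0.1:3434/healthz")
		if err != nil {
			log.Fatal(err)
		}
		defer resp.Body.Close()
		body, _ := ioutil.ReadAll(resp.Body)
		fmt.Printf("%v %s\n", resp.StatusCode, body) // expect: 200 {"status":"ok"}
	}

 // initProxy gets called if teleport runs with 'proxy' role enabled.
 // this means it will do two things:
 // 1. 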
serve a web UI @@ -876,7 +1005,7 @@ func (process *TeleportProcess) initProxy() error { process.WaitForEvent(ProxyIdentityEvent, eventsC, make(chan struct{})) event := <-eventsC - log.Debugf("Received event %v.", &event) + log.Debugf("Received event %q.", event.Name) conn, ok := (event.Payload).(*Connector) if !ok { return trace.BadParameter("unsupported connector type: %T", event.Payload) @@ -916,7 +1045,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { return &listeners, nil case cfg.Proxy.ReverseTunnelListenAddr.Equals(cfg.Proxy.WebAddr): log.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy listen on the same port, multiplexing is on.") - listener, err := net.Listen("tcp", cfg.Proxy.WebAddr.Addr) + listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel", "web"), cfg.Proxy.WebAddr.Addr) if err != nil { return nil, trace.Wrap(err) } @@ -936,7 +1065,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { return &listeners, nil case cfg.Proxy.EnableProxyProtocol && !cfg.Proxy.DisableWebService: log.Debugf("Setup Proxy: Proxy protocol is enabled for web service, multiplexing is on.") - listener, err := net.Listen("tcp", cfg.Proxy.WebAddr.Addr) + listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "web"), cfg.Proxy.WebAddr.Addr) if err != nil { return nil, trace.Wrap(err) } @@ -951,7 +1080,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { return nil, trace.Wrap(err) } listeners.web = listeners.mux.TLS() - listeners.reverseTunnel, err = net.Listen("tcp", cfg.Proxy.ReverseTunnelListenAddr.Addr) + listeners.reverseTunnel, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel"), cfg.Proxy.ReverseTunnelListenAddr.Addr) if err != nil { listener.Close() listeners.Close() @@ -962,14 +1091,14 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { default: log.Debugf("Proxy reverse tunnel are listening on the separate ports") if !cfg.Proxy.DisableReverseTunnel { - listeners.reverseTunnel, err = net.Listen("tcp", cfg.Proxy.ReverseTunnelListenAddr.Addr) + listeners.reverseTunnel, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel"), cfg.Proxy.ReverseTunnelListenAddr.Addr) if err != nil { listeners.Close() return nil, trace.Wrap(err) } } if !cfg.Proxy.DisableWebService { - listeners.web, err = net.Listen("tcp", cfg.Proxy.WebAddr.Addr) + listeners.web, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "web"), cfg.Proxy.WebAddr.Addr) if err != nil { listeners.Close() return nil, trace.Wrap(err) @@ -980,10 +1109,9 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) { } func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { - var ( - askedToExit = true - err error - ) + // clean up unused descriptors passed for proxy, but not used by it + defer process.closeImportedDescriptors(teleport.ComponentProxy) + var err error cfg := process.Config proxyLimiter, err := limiter.NewLimiter(cfg.Proxy.Limiter) @@ -1024,6 +1152,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + log := logrus.WithFields(logrus.Fields{ + trace.Component: teleport.ComponentReverseTunnelServer, + }) + // register SSH reverse tunnel server that accepts connections // from remote teleport nodes var tsrv 
reversetunnel.Server
@@ -1053,59 +1185,55 @@
 		return trace.Wrap(err)
 	}
 	process.RegisterFunc("proxy.reveresetunnel.server", func() error {
-		utils.Consolef(cfg.Console, "Starting reverse tunnel service is starting on %v using %v", cfg.Proxy.ReverseTunnelListenAddr.Addr, process.Config.CachePolicy)
+		log.Infof("Starting on %v using %v.", cfg.Proxy.ReverseTunnelListenAddr.Addr, process.Config.CachePolicy)
 		if err := tsrv.Start(); err != nil {
-			utils.Consolef(cfg.Console, "Error: %v", err)
+			log.Error(err)
 			return trace.Wrap(err)
 		}

 		// notify parties that we've started reverse tunnel server
 		process.BroadcastEvent(Event{Name: ProxyReverseTunnelReady, Payload: tsrv})
-		tsrv.Wait()
-		if askedToExit {
-			log.Infof("Reverse tunnel exited.")
-		}
 		return nil
 	})
 	}

 	// Register web proxy server
+	var webServer *http.Server
 	if !process.Config.Proxy.DisableWebService {
-		process.RegisterFunc("proxy.web", func() error {
-			utils.Consolef(cfg.Console, "Web proxy service is starting on %v.", cfg.Proxy.WebAddr.Addr)
-			webHandler, err := web.NewHandler(
-				web.Config{
-					Proxy:        tsrv,
-					AuthServers:  cfg.AuthServers[0],
-					DomainName:   cfg.Hostname,
-					ProxyClient:  conn.Client,
-					DisableUI:    process.Config.Proxy.DisableWebInterface,
-					ProxySSHAddr: cfg.Proxy.SSHAddr,
-					ProxyWebAddr: cfg.Proxy.WebAddr,
-				})
+		webHandler, err := web.NewHandler(
+			web.Config{
+				Proxy:        tsrv,
+				AuthServers:  cfg.AuthServers[0],
+				DomainName:   cfg.Hostname,
+				ProxyClient:  conn.Client,
+				DisableUI:    process.Config.Proxy.DisableWebInterface,
+				ProxySSHAddr: cfg.Proxy.SSHAddr,
+				ProxyWebAddr: cfg.Proxy.WebAddr,
+			})
+		if err != nil {
+			return trace.Wrap(err)
+		}
+		proxyLimiter.WrapHandle(webHandler)
+		if !process.Config.Proxy.DisableTLS {
+			log.Infof("Using TLS cert %v, key %v.", cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
+			tlsConfig, err := utils.CreateTLSConfiguration(cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
 			if err != nil {
 				return trace.Wrap(err)
 			}
+			listeners.web = tls.NewListener(listeners.web, tlsConfig)
+		}
+		webServer = &http.Server{
+			Handler: proxyLimiter,
+		}
+		process.RegisterFunc("proxy.web", func() error {
+			log.Infof("Web proxy service is starting on %v.", cfg.Proxy.WebAddr.Addr)
 			defer webHandler.Close()
-
-			proxyLimiter.WrapHandle(webHandler)
 			process.BroadcastEvent(Event{Name: ProxyWebServerReady, Payload: webHandler})
-
-			if !process.Config.Proxy.DisableTLS {
-				tlsConfig, err := utils.CreateTLSConfiguration(cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
-				if err != nil {
-					return trace.Wrap(err)
-				}
-				listeners.web = tls.NewListener(listeners.web, tlsConfig)
-			}
-			if err = http.Serve(listeners.web, proxyLimiter); err != nil {
-				if askedToExit {
-					log.Infof("Proxy web server exited.")
-					return nil
-				}
-				log.Error(err)
+			if err := webServer.Serve(listeners.web); err != nil && err != http.ErrServerClosed {
+				log.Warningf("Error while serving web requests: %v", err)
 			}
+			log.Infof("Exited.")
 			return nil
 		})
 	} else {
@@ -1113,6 +1241,10 @@
 	}

 	// Register SSH proxy server - SSH jumphost proxy server
+	listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "ssh"), cfg.Proxy.SSHAddr.Addr)
+	if err != nil {
+		return trace.Wrap(err)
+	}
 	sshProxy, err := regular.New(cfg.Proxy.SSHAddr,
 		cfg.Hostname,
 		[]ssh.Signer{conn.Identity.KeySigner},
@@ -1134,26 +1266,20 @@
 	}

 	process.RegisterFunc("proxy.ssh", func() error {
-		
utils.Consolef(cfg.Console, "[PROXY] SSH proxy service is starting on %v", cfg.Proxy.SSHAddr.Addr)
-		if err := sshProxy.Start(); err != nil {
-			if askedToExit {
-				log.Infof("SSH proxy exited")
-				return nil
-			}
-			utils.Consolef(cfg.Console, "[PROXY] Error: %v", err)
-			return trace.Wrap(err)
-		}
-
+		log.Infof("SSH proxy service is starting on %v.", cfg.Proxy.SSHAddr.Addr)
+		go sshProxy.Serve(listener)
 		// broadcast that the proxy ssh server has started
 		process.BroadcastEvent(Event{Name: ProxySSHReady, Payload: nil})
-
 		return nil
 	})

 	process.RegisterFunc("proxy.reversetunnel.agent", func() error {
+		log := logrus.WithFields(logrus.Fields{
+			trace.Component: teleport.ComponentReverseTunnelAgent,
+		})
 		log.Infof("Starting reverse tunnel agent pool.")
 		if err := agentPool.Start(); err != nil {
-			log.Fatalf("Failed to start: %v.", err)
+			log.Errorf("Failed to start: %v.", err)
 			return trace.Wrap(err)
 		}
 		agentPool.Wait()
@@ -1161,18 +1287,40 @@

 	// execute this when process is asked to exit:
-	process.onExit(func(payload interface{}) {
-		listeners.Close()
-		if tsrv != nil {
-			tsrv.Close()
-		}
-		sshProxy.Close()
+	process.onExit("proxy.shutdown", func(payload interface{}) {
 		agentPool.Stop()
-		log.Infof("Proxy service exited.")
+		defer listeners.Close()
+		if payload == nil {
+			log.Infof("Shutting down immediately.")
+			if tsrv != nil {
+				warnOnErr(tsrv.Close())
+			}
+			if webServer != nil {
+				warnOnErr(webServer.Close())
+			}
+			warnOnErr(sshProxy.Close())
+		} else {
+			log.Infof("Shutting down gracefully.")
+			ctx := payloadContext(payload)
+			if tsrv != nil {
+				warnOnErr(tsrv.Shutdown(ctx))
+			}
+			warnOnErr(sshProxy.Shutdown(ctx))
+			if webServer != nil {
+				warnOnErr(webServer.Shutdown(ctx))
+			}
+		}
+		log.Infof("Exited.")
 	})
 	return nil
 }

+func warnOnErr(err error) {
+	if err != nil {
+		log.Errorf("Error while performing operation: %v", err)
+	}
+}
+
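Every service in this diff now registers the same two-mode exit handler: a nil payload means stop immediately, a context payload means drain gracefully. Distilled into one sketch using the diff's own onExit, payloadContext, and warnOnErr helpers (server stands in for tsrv, webServer, or sshProxy):

	process.onExit("example.shutdown", func(payload interface{}) {
		if payload == nil {
			log.Info("Shutting down immediately.")
			warnOnErr(server.Close())
			return
		}
		log.Info("Shutting down gracefully.")
		warnOnErr(server.Shutdown(payloadContext(payload)))
	})

 // initAuthStorage initializes the storage backend for the auth service. 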
func (process *TeleportProcess) initAuthStorage() (bk backend.Backend, err error) {
 	bc := &process.Config.Auth.StorageConfig
@@ -1184,7 +1332,7 @@
 	// filesystem backend:
 	case dir.GetName():
 		bk, err = dir.New(bc.Params)
-	// DynamoDB bakcend:
+	// DynamoDB backend:
 	case dynamo.GetName():
 		bk, err = dynamo.New(bc.Params)
 	// etcd backend:
@@ -1199,6 +1347,37 @@
 	return bk, nil
 }

+// StartShutdown launches a non-blocking graceful shutdown and returns a
+// context that is closed once the shutdown completes
+func (process *TeleportProcess) StartShutdown(ctx context.Context) context.Context {
+	process.BroadcastEvent(Event{Name: TeleportExitEvent, Payload: ctx})
+	localCtx, cancel := context.WithCancel(ctx)
+	go func() {
+		defer cancel()
+		process.Supervisor.Wait()
+		log.Debugf("All supervisor functions are completed.")
+		localAuth := process.getLocalAuth()
+		if localAuth != nil {
+			if err := localAuth.Close(); err != nil {
+				log.Warningf("Failed closing auth server: %v", trace.DebugReport(err))
+			}
+		}
+	}()
+	return localCtx
+}
+
+// Shutdown launches graceful shutdown process and waits
+// for it to complete
+func (process *TeleportProcess) Shutdown(ctx context.Context) {
+	localCtx := process.StartShutdown(ctx)
+	// block until the shutdown completes
+	select {
+	case <-localCtx.Done():
+		log.Debugf("Process completed.")
+	}
+}
+
+// Close broadcasts close signals and exits immediately
 func (process *TeleportProcess) Close() error {
 	process.BroadcastEvent(Event{Name: TeleportExitEvent})
 	localAuth := process.getLocalAuth()
@@ -1248,7 +1427,7 @@ func validateConfig(cfg *Config) error {
 // initSelfSignedHTTPSCert generates and self-signs a TLS key+cert pair for https connection
 // to the proxy server.
 func initSelfSignedHTTPSCert(cfg *Config) (err error) {
-	log.Warningf("[CONFIG] NO TLS Keys provided, using self signed certificate")
+	log.Warningf("No TLS Keys provided, using self signed certificate.")

 	keyPath := filepath.Join(cfg.DataDir, defaults.SelfSignedKeyPath)
 	certPath := filepath.Join(cfg.DataDir, defaults.SelfSignedCertPath)
@@ -1264,7 +1443,7 @@
 	if !os.IsNotExist(err) {
 		return trace.Wrap(err, "unrecognized error reading certs")
 	}
-	log.Warningf("[CONFIG] Generating self signed key and cert to %v %v", keyPath, certPath)
+	log.Warningf("Generating self signed key and cert to %v %v.", keyPath, certPath)

 	creds, err := utils.GenerateSelfSignedCert([]string{cfg.Hostname, "localhost"})
 	if err != nil {
diff --git a/lib/service/signals.go b/lib/service/signals.go
new file mode 100644
index 00000000000..290c04bbe6b
--- /dev/null
+++ b/lib/service/signals.go
@@ -0,0 +1,348 @@
+/*
+Copyright 2017 Gravitational, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. 
+*/
+
+package service
+
+import (
+	"context"
+	"encoding/json"
+	"net"
+	"os"
+	"os/exec"
+	"os/signal"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/gravitational/teleport/lib/utils"
+	"github.com/gravitational/trace"
+	"github.com/sirupsen/logrus"
+)
+
+// printShutdownStatus prints running services until shut down
+func (process *TeleportProcess) printShutdownStatus(ctx context.Context) {
+	t := time.NewTicker(5 * time.Second)
+	defer t.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-t.C:
+			log.Infof("Waiting for services: %v to finish.", process.Supervisor.Services())
+		}
+	}
+}
+
+// WaitForSignals waits for system signals and processes them.
+// Should not be called twice by the process.
+func (process *TeleportProcess) WaitForSignals(ctx context.Context) error {
+	sigC := make(chan os.Signal, 1024)
+	// note: SIGKILL cannot be trapped, so it is not subscribed to here
+	signal.Notify(sigC, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGUSR2, syscall.SIGCHLD, syscall.SIGHUP)
+
+	doneContext, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// Block until a signal is received or handler got an error.
+	// Notice how this handler is serialized - it will only receive
+	// signals in sequence and will not run in parallel.
+	for {
+		select {
+		case signal := <-sigC:
+			switch signal {
+			case syscall.SIGTERM, syscall.SIGINT:
+				go process.printShutdownStatus(doneContext)
+				process.Shutdown(ctx)
+				log.Infof("All services stopped, exiting.")
+				return nil
+			case syscall.SIGQUIT:
+				log.Infof("Got signal %q, exiting immediately.", signal)
+				process.Close()
+				return nil
+			case syscall.SIGUSR2:
+				log.Infof("Got signal %q, forking a new process.", signal)
+				if err := process.forkChild(); err != nil {
+					log.Infof("Failed to fork: %s", trace.DebugReport(err))
+				} else {
+					log.Infof("Successfully started new process.")
+				}
+			case syscall.SIGHUP:
+				log.Infof("Got signal %q, performing graceful restart.", signal)
+				if err := process.forkChild(); err != nil {
+					log.Infof("Failed to fork: %s", trace.DebugReport(err))
+				} else {
+					log.Infof("Successfully started new process.")
+				}
+				log.Infof("Shutting down gracefully.")
+				go process.printShutdownStatus(doneContext)
+				process.Shutdown(ctx)
+				log.Infof("All services stopped, exiting.")
+				return nil
+			case syscall.SIGCHLD:
+				log.Debugf("Child exited, got %q, collecting status.", signal)
+				var wait syscall.WaitStatus
+				syscall.Wait4(-1, &wait, syscall.WNOHANG, nil)
+			default:
+				log.Infof("Ignoring %q.", signal)
+			}
+		case <-ctx.Done():
+			process.Close()
+			process.Wait()
+			log.Info("Got request to shut down, context is closing.")
+			return nil
+		}
+	}
+}
+
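The SIGUSR2/SIGHUP paths hinge on handing live sockets to the child process. Reduced to its two halves, the handoff looks roughly like this (a sketch using exec.Command instead of the diff's os.StartProcess; fd numbering follows the convention described in filesToString below):

	// parent: export the listener's file so the child inherits it as fd 3
	lf, err := listener.(*net.TCPListener).File()
	if err != nil {
		return err
	}
	cmd := exec.Command(os.Args[0], os.Args[1:]...)
	cmd.ExtraFiles = []*os.File{lf} // ExtraFiles begin at fd 3 in the child

	// child: rebuild a net.Listener from the inherited descriptor
	l, err := net.FileListener(os.NewFile(3, "inherited-listener"))

+// closeImportedDescriptors closes imported but unused file descriptors,
+// which can happen if the new configuration no longer uses them
+func (process *TeleportProcess) closeImportedDescriptors(prefix string) error {
+	process.Lock()
+	defer process.Unlock()
+
+	var errors []error
+	for i := range process.importedDescriptors {
+		d := process.importedDescriptors[i]
+		if strings.HasPrefix(d.Type, prefix) {
+			log.Infof("Closing imported but unused descriptor %v %v.", d.Type, d.Address)
+			errors = append(errors, d.File.Close())
+		}
+	}
+	return trace.NewAggregate(errors...) 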
+} + +// importOrCreateListener imports listener passed by the parent process (happens during live reload) +// or creates a new listener if there was no listener registered +func (process *TeleportProcess) importOrCreateListener(listenerType, address string) (net.Listener, error) { + l, err := process.importListener(listenerType, address) + if err == nil { + log.Infof("Using file descriptor %v %v passed by the parent process.", listenerType, address) + return l, nil + } + if !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + log.Infof("Service %v is creating new listener on %v.", listenerType, address) + return process.createListener(listenerType, address) +} + +// importListener imports listener passed by the parent process, if no listener is found +// returns NotFound, otherwise removes the file from the list +func (process *TeleportProcess) importListener(listenerType, address string) (net.Listener, error) { + process.Lock() + defer process.Unlock() + + for i := range process.importedDescriptors { + d := process.importedDescriptors[i] + if d.Type == listenerType && d.Address == address { + l, err := d.ToListener() + if err != nil { + return nil, trace.Wrap(err) + } + process.importedDescriptors = append(process.importedDescriptors[:i], process.importedDescriptors[i+1:]...) + process.registeredListeners = append(process.registeredListeners, RegisteredListener{Type: listenerType, Address: address, Listener: l}) + return l, nil + } + } + + return nil, trace.NotFound("no file descriptor for type %v and address %v has been imported", listenerType, address) +} + +// createListener creates listener and adds to a list of tracked listeners +func (process *TeleportProcess) createListener(listenerType, address string) (net.Listener, error) { + listener, err := net.Listen("tcp", address) + if err != nil { + return nil, trace.Wrap(err) + } + process.Lock() + defer process.Unlock() + r := RegisteredListener{Type: listenerType, Address: address, Listener: listener} + process.registeredListeners = append(process.registeredListeners, r) + return listener, nil +} + +// exportFileDescriptors exports file descriptors to be passed to child process +func (process *TeleportProcess) exportFileDescriptors() ([]FileDescriptor, error) { + var out []FileDescriptor + process.Lock() + defer process.Unlock() + for _, r := range process.registeredListeners { + file, err := utils.GetListenerFile(r.Listener) + if err != nil { + return nil, trace.Wrap(err) + } + out = append(out, FileDescriptor{File: file, Type: r.Type, Address: r.Address}) + } + return out, nil +} + +// importFileDescriptors imports file descriptors from environment if there are any +func importFileDescriptors() ([]FileDescriptor, error) { + // These files may be passed in by the parent process + filesString := os.Getenv(teleportFilesEnvVar) + if filesString == "" { + return nil, nil + } + + files, err := filesFromString(filesString) + if err != nil { + return nil, trace.BadParameter("child process has failed to read files, error %q", err) + } + + if len(files) != 0 { + log.Infof("Child has been passed files: %v", files) + } + + return files, nil +} + +// RegisteredListener is a listener registered +// within teleport process, can be passed to child process +type RegisteredListener struct { + // Type is a listener type, e.g. auth:ssh + Type string + // Address is an address listener is serving on, e.g. 
127.0.0.1:3025
+	Address string
+	// Listener is a file listener object
+	Listener net.Listener
+}
+
+// FileDescriptor is a file descriptor associated
+// with a listener
+type FileDescriptor struct {
+	// Type is a listener type, e.g. auth:ssh
+	Type string
+	// Address is an address of the listener, e.g. 127.0.0.1:3025
+	Address string
+	// File is a file descriptor associated with the listener
+	File *os.File
+}
+
+func (fd *FileDescriptor) ToListener() (net.Listener, error) {
+	listener, err := net.FileListener(fd.File)
+	if err != nil {
+		return nil, err
+	}
+	fd.File.Close()
+	return listener, nil
+}
+
+type fileDescriptor struct {
+	Address  string `json:"addr"`
+	Type     string `json:"type"`
+	FileFD   int    `json:"fd"`
+	FileName string `json:"fileName"`
+}
+
+// filesToString serializes file descriptors as well as accompanying information (like socket host and port)
+func filesToString(files []FileDescriptor) (string, error) {
+	out := make([]fileDescriptor, len(files))
+	for i, f := range files {
+		out[i] = fileDescriptor{
+			// Once files are passed to the child process, their FDs will change:
+			// the first three passed files are stdin, stdout and stderr, and every file after that gets index + 3.
+			// That's why we rearrange the FDs for child processes to get the correct file descriptors.
+			FileFD:   i + 3,
+			FileName: f.File.Name(),
+			Address:  f.Address,
+			Type:     f.Type,
+		}
+	}
+	bytes, err := json.Marshal(out)
+	if err != nil {
+		return "", err
+	}
+	return string(bytes), nil
+}
+
+const teleportFilesEnvVar = "TELEPORT_OS_FILES"
+
+func execPath() (string, error) {
+	name, err := exec.LookPath(os.Args[0])
+	if err != nil {
+		return "", err
+	}
+	if _, err = os.Stat(name); err != nil {
+		return "", err
+	}
+	return name, err
+}
+
+// filesFromString de-serializes the file descriptors and turns them into os.Files
+func filesFromString(in string) ([]FileDescriptor, error) {
+	var out []fileDescriptor
+	if err := json.Unmarshal([]byte(in), &out); err != nil {
+		return nil, err
+	}
+	files := make([]FileDescriptor, len(out))
+	for i, o := range out {
+		files[i] = FileDescriptor{
+			File:    os.NewFile(uintptr(o.FileFD), o.FileName),
+			Address: o.Address,
+			Type:    o.Type,
+		}
+	}
+	return files, nil
+}
+
+func (process *TeleportProcess) forkChild() error {
+	path, err := execPath()
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	workingDir, err := os.Getwd()
+	if err != nil {
+		return err
+	}
+
+	log := log.WithFields(logrus.Fields{"path": path, "workingDir": workingDir})
+
+	log.Info("Forking child.")
+
+	listenerFiles, err := process.exportFileDescriptors()
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	// These files will be passed to the child process
+	files := []*os.File{os.Stdin, os.Stdout, os.Stderr}
+	for _, f := range listenerFiles {
+		files = append(files, f.File)
+	}
+
+	// Serialize files to JSON string representation
+	vals, err := filesToString(listenerFiles)
+	if err != nil {
+		return err
+	}
+
+	log.Infof("Passing %s to child", vals)
+	os.Setenv(teleportFilesEnvVar, vals)
+
+	p, err := os.StartProcess(path, os.Args, &os.ProcAttr{
+		Dir:   workingDir,
+		Env:   os.Environ(),
+		Files: files,
+		Sys:   &syscall.SysProcAttr{},
+	})
+
+	if err != nil {
+		return trace.ConvertSystemError(err)
+	}
+
+	log.WithFields(logrus.Fields{"pid": p.Pid}).Infof("Started new child process.")
+	return nil
+}
diff --git a/lib/service/supervisor.go b/lib/service/supervisor.go
index 52ed3244cb4..300c503dfb3 100644
--- a/lib/service/supervisor.go
+++ b/lib/service/supervisor.go
@@ -17,7 +17,7 @@ 
diff --git a/lib/service/supervisor.go b/lib/service/supervisor.go
index 52ed3244cb4..300c503dfb3 100644
--- a/lib/service/supervisor.go
+++ b/lib/service/supervisor.go
@@ -17,7 +17,7 @@ limitations under the License.
 
 package service
 
 import (
-	"fmt"
+	"context"
 	"sync"
 
 	"github.com/gravitational/teleport/lib/utils"
@@ -52,6 +52,9 @@ type Supervisor interface {
 	// it's a combinatioin Start() and Wait()
 	Run() error
 
+	// Services returns a list of running services
+	Services() []string
+
 	// BroadcastEvent generates event and broadcasts it to all
 	// interested parties
 	BroadcastEvent(Event)
@@ -71,18 +74,21 @@ type LocalSupervisor struct {
 	events       map[string]Event
 	eventsC      chan Event
 	eventWaiters map[string][]*waiter
-	closer       *utils.CloseBroadcaster
+	closeContext context.Context
+	signalClose  context.CancelFunc
 }
 
 // NewSupervisor returns new instance of initialized supervisor
 func NewSupervisor() Supervisor {
+	closeContext, cancel := context.WithCancel(context.TODO())
 	srv := &LocalSupervisor{
 		services:     []Service{},
 		wg:           &sync.WaitGroup{},
 		events:       map[string]Event{},
-		eventsC:      make(chan Event, 100),
+		eventsC:      make(chan Event, 1024),
 		eventWaiters: make(map[string][]*waiter),
-		closer:       utils.NewCloseBroadcaster(),
+		closeContext: closeContext,
+		signalClose:  cancel,
 	}
 	go srv.fanOut()
 	return srv
@@ -96,11 +102,11 @@ type Event struct {
 }
 
 func (e *Event) String() string {
-	return fmt.Sprintf("event(%v)", e.Name)
+	return e.Name
 }
 
 func (s *LocalSupervisor) Register(srv Service) {
-	log.WithFields(logrus.Fields{"service": srv.Name()}).Debugf("Adding service to supervisor")
+	log.WithFields(logrus.Fields{"service": srv.Name()}).Debugf("Adding service to supervisor.")
 	s.Lock()
 	defer s.Unlock()
 	s.services = append(s.services, srv)
@@ -125,7 +131,7 @@ func (s *LocalSupervisor) RegisterFunc(name string, fn ServiceFunc) {
 
 // RemoveService removes service from supervisor tracking list
 func (s *LocalSupervisor) RemoveService(srv Service) error {
-	log = log.WithFields(logrus.Fields{"service": srv.Name()})
+	log := log.WithFields(logrus.Fields{"service": srv.Name()})
 	s.Lock()
 	defer s.Unlock()
 	for i, el := range s.services {
@@ -169,8 +175,20 @@ func (s *LocalSupervisor) Start() error {
 	return nil
 }
 
+// Services returns the names of all registered services
+func (s *LocalSupervisor) Services() []string {
+	s.Lock()
+	defer s.Unlock()
+
+	out := make([]string, len(s.services))
+
+	for i, srv := range s.services {
+		out[i] = srv.Name()
+	}
+	return out
+}
+
 func (s *LocalSupervisor) Wait() error {
-	defer s.closer.Close()
+	defer s.signalClose()
 	s.wg.Wait()
 	return nil
 }
@@ -233,7 +251,7 @@ func (s *LocalSupervisor) fanOut() {
 			for _, waiter := range waiters {
 				go s.notifyWaiter(waiter, event)
 			}
-		case <-s.closer.C:
+		case <-s.closeContext.Done():
 			return
 		}
 	}
diff --git a/lib/services/local/presence.go b/lib/services/local/presence.go
index 77823fcb3ab..bc14487e2cf 100644
--- a/lib/services/local/presence.go
+++ b/lib/services/local/presence.go
@@ -195,7 +195,7 @@ func (s *PresenceService) GetNodes(namespace string) ([]services.Server, error)
 		}
 		servers = append(servers, server)
 	}
-	s.Infof("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
+	s.Debugf("GetServers(%v) in %v", len(servers), time.Since(start))
 	// sorting helps with tests and makes it all deterministic
 	sort.Sort(services.SortedServers(servers))
 	return servers, nil
@@ -218,7 +218,7 @@ func (s *PresenceService) batchGetNodes(namespace string) ([]services.Server, er
 		servers[i] = server
 	}
 
-	s.Infof("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
+	s.Debugf("GetServers(%v) in %v", len(servers), time.Since(start))
 	// sorting helps with tests and makes it all deterministic
 	sort.Sort(services.SortedServers(servers))
 	return servers, nil
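The CloseBroadcaster the supervisor used to carry is replaced here by a cancellable context: one call to the cancel function is a broadcast that every goroutine selecting on Done() observes at once. A minimal standalone sketch of the same pattern (the names here are illustrative, not from the patch):

package main

import (
	"context"
	"fmt"
	"sync"
)

func main() {
	closeContext, signalClose := context.WithCancel(context.Background())

	var wg sync.WaitGroup
	// any number of goroutines can wait on the same Done channel
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			<-closeContext.Done()
			fmt.Printf("service %v: shutting down\n", id)
		}(i)
	}

	// one call broadcasts "close" to every waiter at once
	signalClose()
	wg.Wait()
}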
diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go
index 51dae35b5cf..ae1f3c6f0a0 100644
--- a/lib/srv/regular/sshserver.go
+++ b/lib/srv/regular/sshserver.go
@@ -19,6 +19,7 @@ limitations under the License.
 package regular
 
 import (
+	"context"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -169,6 +170,15 @@ func (s *Server) Close() error {
 	return s.srv.Close()
 }
 
+// Shutdown performs a graceful shutdown
+func (s *Server) Shutdown(ctx context.Context) error {
+	// wait until active connections drain off
+	err := s.srv.Shutdown(ctx)
+	s.closer.Close()
+	s.reg.Close()
+	return err
+}
+
 // Start starts server
 func (s *Server) Start() error {
 	if len(s.getCommandLabels()) > 0 {
@@ -178,9 +188,18 @@ func (s *Server) Start() error {
 	return s.srv.Start()
 }
 
+// Serve starts serving on the provided listener
+func (s *Server) Serve(l net.Listener) error {
+	if len(s.getCommandLabels()) > 0 {
+		s.updateLabels()
+	}
+	go s.heartbeatPresence()
+	return s.srv.Serve(l)
+}
+
 // Wait waits until server stops
 func (s *Server) Wait() {
-	s.srv.Wait()
+	s.srv.Wait(context.TODO())
 }
 
 // SetShell sets default shell that will be executed for interactive
diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go
index 7555658cc49..7f3c162a779 100644
--- a/lib/sshutils/server.go
+++ b/lib/sshutils/server.go
@@ -29,6 +29,7 @@ import (
 	"io"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/gravitational/teleport"
@@ -67,6 +68,11 @@ type Server struct {
 
 	closeContext context.Context
 	closeFunc    context.CancelFunc
+
+	// conns tracks the number of active connections
+	conns int32
+	// shutdownPollPeriod sets the polling period for graceful shutdowns
+	shutdownPollPeriod time.Duration
 }
 
 const (
@@ -100,6 +106,14 @@ func SetLimiter(limiter *limiter.Limiter) ServerOption {
 	}
 }
 
+// SetShutdownPollPeriod sets the polling period for graceful shutdowns of SSH servers
+func SetShutdownPollPeriod(period time.Duration) ServerOption {
+	return func(s *Server) error {
+		s.shutdownPollPeriod = period
+		return nil
+	}
+}
+
 func NewServer(
 	component string,
 	a utils.NetAddr,
@@ -134,6 +148,10 @@ func NewServer(
 			return nil, err
 		}
 	}
+	if s.shutdownPollPeriod == 0 {
+		s.shutdownPollPeriod = defaults.ShutdownPollPeriod
+	}
+
 	for _, signer := range hostSigners {
 		(&s.cfg).AddHostKey(signer)
 	}
@@ -234,8 +252,43 @@ func (s *Server) setListener(l net.Listener) error {
 
 // Wait waits until server stops serving new connections
 // on the listener socket
-func (s *Server) Wait() {
-	<-s.closeContext.Done()
+func (s *Server) Wait(ctx context.Context) {
+	select {
+	case <-s.closeContext.Done():
+	case <-ctx.Done():
+	}
+}
+
+// Shutdown initiates a graceful shutdown, waiting until all active
+// connections are closed or the context is canceled
+func (s *Server) Shutdown(ctx context.Context) error {
+	// close listener to stop receiving new connections
+	err := s.Close()
+	s.Wait(ctx)
+	activeConnections := s.trackConnections(0)
+	if activeConnections == 0 {
+		return err
+	}
+	s.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
+	lastReport := time.Time{}
+	ticker := time.NewTicker(s.shutdownPollPeriod)
+	defer ticker.Stop()
+	for {
+		select {
+		case <-ticker.C:
+			activeConnections = s.trackConnections(0)
+			if activeConnections == 0 {
+				return err
+			}
+			if time.Since(lastReport) > 10*s.shutdownPollPeriod {
+				s.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
+				lastReport = time.Now()
+			}
+		case <-ctx.Done():
+			s.Infof("Shutdown: context canceled, returning.")
+			return trace.ConnectionProblem(err, "context canceled")
+		}
+	}
 }
 
 // Close closes listening socket
 // and stops accepting connections
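A caller is expected to bound Shutdown with a context deadline, since a stuck session would otherwise hold the server open forever. A sketch of that call site, assuming a 30-second budget (the patch itself does not pick one):

package example

import (
	"context"
	"time"

	"github.com/gravitational/teleport/lib/sshutils"

	log "github.com/sirupsen/logrus"
)

// stopGracefully drives the new Shutdown API with a hard deadline; the
// 30-second budget here is an illustrative choice, not part of the patch.
func stopGracefully(srv *sshutils.Server) {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	// Shutdown closes the listener immediately, then polls the active
	// connection count; it returns a trace.ConnectionProblem error if
	// the deadline expires while connections are still open.
	if err := srv.Shutdown(ctx); err != nil {
		log.Warningf("Graceful shutdown failed: %v.", err)
	}
}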
@@ -282,6 +335,10 @@ func (s *Server) acceptConnections() {
 	}
 }
 
+// trackConnections adjusts the active connection counter by delta and
+// returns the new value; a delta of 0 reads the counter
+func (s *Server) trackConnections(delta int32) int32 {
+	return atomic.AddInt32(&s.conns, delta)
+}
+
 // handleConnection is called every time an SSH server accepts a new
 // connection from a client.
 //
@@ -289,6 +346,8 @@
 // and proxies, proxies and servers, servers and auth, etc).
 //
 func (s *Server) handleConnection(conn net.Conn) {
+	s.trackConnections(1)
+	defer s.trackConnections(-1)
 	// initiate an SSH connection, note that we don't need to close the conn here
 	// in case of error as ssh server takes care of this
 	remoteAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String())
diff --git a/lib/sshutils/server_test.go b/lib/sshutils/server_test.go
index 7d2c6ac1f6b..bb04c639845 100644
--- a/lib/sshutils/server_test.go
+++ b/lib/sshutils/server_test.go
@@ -16,6 +16,7 @@ limitations under the License.
 package sshutils
 
 import (
+	"context"
 	"fmt"
 	"net"
 	"testing"
@@ -23,6 +24,7 @@ import (
 
 	"github.com/gravitational/teleport/lib/fixtures"
 	"github.com/gravitational/teleport/lib/utils"
+	"github.com/gravitational/trace"
 
 	"golang.org/x/crypto/ssh"
 	. "gopkg.in/check.v1"
@@ -73,6 +75,54 @@ func (s *ServerSuite) TestStartStop(c *C) {
 	c.Assert(called, Equals, true)
 }
 
+// TestShutdown tests the graceful shutdown feature
+func (s *ServerSuite) TestShutdown(c *C) {
+	closeContext, cancel := context.WithCancel(context.TODO())
+	fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
+		ch, _, err := nch.Accept()
+		c.Assert(err, IsNil)
+		defer ch.Close()
+		// hold the connection open until the test signals shutdown
+		<-closeContext.Done()
+		conn.Close()
+	})
+
+	srv, err := NewServer(
+		"test",
+		utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
+		fn,
+		s.signers,
+		AuthMethods{Password: pass("abc123")},
+		SetShutdownPollPeriod(10*time.Millisecond),
+	)
+	c.Assert(err, IsNil)
+	c.Assert(srv.Start(), IsNil)
+
+	clt, err := ssh.Dial("tcp", srv.Addr(), &ssh.ClientConfig{Auth: []ssh.AuthMethod{ssh.Password("abc123")}})
+	c.Assert(err, IsNil)
+	defer clt.Close()
+
+	// call NewSession to initiate opening a new channel
+	clt.NewSession()
+
+	// the context will time out because there is an active connection
+	ctx, ctxc := context.WithTimeout(context.TODO(), 50*time.Millisecond)
+	defer ctxc()
+	c.Assert(trace.IsConnectionProblem(srv.Shutdown(ctx)), Equals, true)
+
+	// now shutdown will return
+	cancel()
+	ctx2, ctxc2 := context.WithTimeout(context.TODO(), time.Second)
+	defer ctxc2()
+	c.Assert(srv.Shutdown(ctx2), IsNil)
+
+	// shutdown is re-entrant
+	ctx3, ctxc3 := context.WithTimeout(context.TODO(), time.Second)
+	defer ctxc3()
+	c.Assert(srv.Shutdown(ctx3), IsNil)
+}
+
 func (s *ServerSuite) TestConfigureCiphers(c *C) {
 	called := false
 	fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
@@ -117,7 +167,7 @@ func (s *ServerSuite) TestConfigureCiphers(c *C) {
 func wait(c *C, srv *Server) {
 	s := make(chan struct{})
 	go func() {
-		srv.Wait()
+		srv.Wait(context.TODO())
 		s <- struct{}{}
 	}()
 	select {
diff --git a/lib/utils/signal.go b/lib/system/signal.go
similarity index 97%
rename from lib/utils/signal.go
rename to lib/system/signal.go
index 3f860096292..4eddb1205d1 100644
--- a/lib/utils/signal.go
+++ b/lib/system/signal.go
@@ -1,4 +1,4 @@
-package utils
+package system
 
 /*
 #include
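The whole accounting behind Shutdown is the conns counter: one atomic increment on connection entry, a matching deferred decrement on exit, and reads via a delta of 0. A standalone sketch of the pattern (all names illustrative):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type tracker struct {
	conns int32
}

// track adds delta to the counter and returns the new value;
// track(0) is a pure read
func (t *tracker) track(delta int32) int32 {
	return atomic.AddInt32(&t.conns, delta)
}

func (t *tracker) handle(wg *sync.WaitGroup) {
	defer wg.Done()
	t.track(1)
	defer t.track(-1)
	// ... serve the connection ...
}

func main() {
	var t tracker
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go t.handle(&wg)
	}
	wg.Wait()
	fmt.Println("active connections:", t.track(0)) // 0
}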
diff --git a/lib/utils/addr.go b/lib/utils/addr.go
index 25cda097b39..31a2c6c43dd 100644
--- a/lib/utils/addr.go
+++ b/lib/utils/addr.go
@@ -129,6 +129,8 @@ func ParseAddr(a string) (*NetAddr, error) {
 		return &NetAddr{Addr: u.Host, AddrNetwork: u.Scheme, Path: u.Path}, nil
 	case "unix":
 		return &NetAddr{Addr: u.Path, AddrNetwork: u.Scheme}, nil
+	case "http", "https":
+		return &NetAddr{Addr: u.Host, AddrNetwork: u.Scheme, Path: u.Path}, nil
 	default:
 		return nil, trace.BadParameter("'%v': unsupported scheme: '%v'", a, u.Scheme)
 	}
diff --git a/lib/utils/addr_test.go b/lib/utils/addr_test.go
index 0daa334665e..ffb79697286 100644
--- a/lib/utils/addr_test.go
+++ b/lib/utils/addr_test.go
@@ -66,6 +66,16 @@ func (s *AddrTestSuite) TestParse(c *C) {
 	c.Assert(addr.IsEmpty(), Equals, false)
 }
 
+func (s *AddrTestSuite) TestParseHTTP(c *C) {
+	addr, err := ParseAddr("http://one:25/path")
+	c.Assert(err, IsNil)
+	c.Assert(addr, NotNil)
+	c.Assert(addr.Addr, Equals, "one:25")
+	c.Assert(addr.Path, Equals, "/path")
+	c.Assert(addr.FullAddress(), Equals, "http://one:25")
+	c.Assert(addr.IsEmpty(), Equals, false)
+}
+
 func (s *AddrTestSuite) TestParseDefaults(c *C) {
 	addr, err := ParseAddr("host:25")
 	c.Assert(err, IsNil)
diff --git a/lib/utils/listener.go b/lib/utils/listener.go
new file mode 100644
index 00000000000..6015ca2fd5e
--- /dev/null
+++ b/lib/utils/listener.go
@@ -0,0 +1,19 @@
+package utils
+
+import (
+	"net"
+	"os"
+
+	"github.com/gravitational/trace"
+)
+
+// GetListenerFile returns the file associated with the listener
+func GetListenerFile(listener net.Listener) (*os.File, error) {
+	switch t := listener.(type) {
+	case *net.TCPListener:
+		return t.File()
+	case *net.UnixListener:
+		return t.File()
+	}
+	return nil, trace.BadParameter("unsupported listener: %T", listener)
+}
diff --git a/lib/utils/tls.go b/lib/utils/tls.go
index f6eaa5c0f22..50ecf1900de 100644
--- a/lib/utils/tls.go
+++ b/lib/utils/tls.go
@@ -79,7 +79,6 @@ func CreateTLSConfiguration(certFile, keyFile string) (*tls.Config, error) {
 		return nil, trace.BadParameter("certificate is not accessible by '%v'", certFile)
 	}
 
-	log.Infof("[PROXY] TLS cert=%v key=%v", certFile, keyFile)
 	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
 	if err != nil {
 		return nil, trace.Wrap(err)
diff --git a/lib/web/sessions.go b/lib/web/sessions.go
index 0507771bb5c..c3b6f1ce006 100644
--- a/lib/web/sessions.go
+++ b/lib/web/sessions.go
@@ -681,7 +681,6 @@ func (s *sessionCache) ValidateSession(user, sid string) (*SessionContext, error
 	// this means that someone has just inserted the context, so
 	// close our extra context and return
 	if trace.IsAlreadyExists(err) {
-		log.Infof("just created, returning the existing one")
 		defer c.Close()
 		return out, nil
 	}
diff --git a/lib/web/stream.go b/lib/web/stream.go
index a17d86a6930..e8051f08352 100644
--- a/lib/web/stream.go
+++ b/lib/web/stream.go
@@ -95,7 +95,9 @@ func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
 	// ask for any events than happened since the last call:
 	re, err := clt.GetSessionEvents(w.namespace, w.sessionID, eventsCursor+1)
 	if err != nil {
-		log.Error(err)
+		if !trace.IsNotFound(err) {
+			log.Error(err)
+		}
 		return emptyEventList
 	}
 	batchLen := len(re)
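With the new "http"/"https" cases, addresses such as the diagnostic endpoint can be written as URLs. A quick illustration of what ParseAddr now accepts, mirroring TestParseHTTP above (the specific URL is illustrative):

package main

import (
	"fmt"

	"github.com/gravitational/teleport/lib/utils"
)

func main() {
	// previously only tcp:// and unix:// schemes were accepted
	addr, err := utils.ParseAddr("https://localhost:3080/healthz")
	if err != nil {
		panic(err)
	}
	// prints: https localhost:3080 /healthz
	fmt.Println(addr.AddrNetwork, addr.Addr, addr.Path)
}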
diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go
index 0b6a1988efa..88c76fa2619 100644
--- a/tool/teleport/common/teleport.go
+++ b/tool/teleport/common/teleport.go
@@ -17,8 +17,8 @@ limitations under the License.
 package common
 
 import (
+	"context"
 	"fmt"
-	"net/http"
 	"os"
 	"os/user"
 	"path/filepath"
@@ -33,9 +33,9 @@ import (
 	"github.com/gravitational/teleport/lib/utils"
 
 	"github.com/google/gops/agent"
-	"github.com/gravitational/roundtrip"
+	"github.com/gravitational/trace"
-	"github.com/prometheus/client_golang/prometheus"
+
 	log "github.com/sirupsen/logrus"
 )
@@ -109,14 +109,10 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
 		"Base64 encoded configuration string").Hidden().Envar(defaults.ConfigEnvar).
 		StringVar(&ccf.ConfigString)
 	start.Flag("labels", "List of labels for this node").StringVar(&ccf.Labels)
-	start.Flag("httpprofile",
-		"[DEPRECATED] Start profiling endpoint on localhost:6060").Hidden().BoolVar(&ccf.HTTPProfileEndpoint)
 	start.Flag("gops",
-		"Start gops endpoint on a given address").Hidden().BoolVar(&ccf.Gops)
-	start.Flag("gops-addr",
-		"Specify gops addr to listen on").Hidden().StringVar(&ccf.GopsAddr)
+		"Start the gops troubleshooting endpoint on the first available address.").BoolVar(&ccf.Gops)
 	start.Flag("diag-addr",
-		"Start diangonstic endpoint on this address").Hidden().StringVar(&ccf.DiagnosticAddr)
+		"Start the diagnostic prometheus and healthz endpoint on this address.").StringVar(&ccf.DiagnosticAddr)
 	start.Flag("permit-user-env",
 		"Enables reading of ~/.tsh/environment when creating a session").Hidden().BoolVar(&ccf.PermitUserEnvironment)
 	start.Flag("insecure",
@@ -155,30 +151,13 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
 	if !options.InitOnly {
 		log.Debug(conf.DebugDumpToYAML())
 	}
-	if ccf.HTTPProfileEndpoint {
-		log.Warningf("http profile endpoint is deprecated, use gops instead")
-	}
 	if ccf.Gops {
-		log.Debugf("starting gops agent")
-		err := agent.Listen(&agent.Options{Addr: ccf.GopsAddr})
+		log.Debug("Starting gops agent.")
+		err := agent.Listen(&agent.Options{})
 		if err != nil {
 			log.Warningf("failed to start gops agent %v", err)
 		}
 	}
-	// collect and expose diagnostic endpoint
-	if ccf.DiagnosticAddr != "" {
-		mux := http.NewServeMux()
-		mux.Handle("/metrics", prometheus.Handler())
-		mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
-			roundtrip.ReplyJSON(w, http.StatusOK, map[string]interface{}{"status": "ok"})
-		})
-		go func() {
-			err := http.ListenAndServe(ccf.DiagnosticAddr, mux)
-			if err != nil {
-				log.Warningf("diagnostic endpoint exited %v", err)
-			}
-		}()
-	}
 	if !options.InitOnly {
 		err = OnStart(conf)
 	}
@@ -194,7 +173,7 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
 	if err != nil {
 		utils.FatalError(err)
 	}
-	log.Info("teleport: clean exit")
+	log.Debug("Clean exit.")
 	return command, conf
 }
 
@@ -219,7 +198,7 @@ func OnStart(config *service.Config) error {
 		defer f.Close()
 	}
 
-	return trace.Wrap(srv.Wait())
+	return srv.WaitForSignals(context.TODO())
 }
 
 // onStatus is the handler for "status" CLI command
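WaitForSignals itself is not shown in this diff; presumably it blocks until a termination signal arrives and wires a restart signal into the forkChild path above. A plausible sketch under those assumptions — the function name, parameters, and signal set here are illustrative, not the actual implementation:

package example

import (
	"context"
	"io"
	"os"
	"os/signal"
	"syscall"
)

// waitForSignals blocks until a termination signal arrives or the context
// is canceled; SIGUSR2 stands in for whatever signal actually triggers the
// fork-based in-place restart.
func waitForSignals(ctx context.Context, process io.Closer, fork func() error) error {
	sigC := make(chan os.Signal, 1)
	signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM, syscall.SIGUSR2)
	defer signal.Stop(sigC)

	for {
		select {
		case sig := <-sigC:
			switch sig {
			case syscall.SIGUSR2:
				// start a child that inherits the listener sockets,
				// then keep serving in the parent
				if err := fork(); err != nil {
					return err
				}
			default:
				return process.Close()
			}
		case <-ctx.Done():
			return nil
		}
	}
}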