Teleport signal handling and live reload.

This commit introduces signal handling.
The parent teleport process is now capable of forking
a child process and passing listener file descriptors
to it.

The parent process can then shut down gracefully
by tracking the number of active connections and
closing its listeners once that number drops to 0.
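For context, the core mechanism of handing a bound listener to a
re-executed child looks roughly like the sketch below. It is a
simplified illustration, not the code from this commit: the
LISTENER_FD environment variable and the inline handler are made up,
while the actual implementation (lib/service/signals.go) serializes
descriptor metadata as JSON into the TELEPORT_OS_FILES variable.

    package main

    import (
        "fmt"
        "net"
        "net/http"
        "os"
        "os/exec"
    )

    func main() {
        if os.Getenv("LISTENER_FD") != "" {
            // Child: rebuild the listener from the inherited descriptor.
            // Files passed via ExtraFiles start at fd 3 in the child.
            ln, err := net.FileListener(os.NewFile(3, "inherited-listener"))
            if err != nil {
                panic(err)
            }
            http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                fmt.Fprintf(w, "served by pid %d\n", os.Getpid())
            }))
            return
        }

        // Parent: bind the socket, then fork a child that inherits it.
        ln, err := net.Listen("tcp", "127.0.0.1:8080")
        if err != nil {
            panic(err)
        }
        f, err := ln.(*net.TCPListener).File() // duplicate the descriptor
        if err != nil {
            panic(err)
        }
        cmd := exec.Command(os.Args[0])
        cmd.Env = append(os.Environ(), "LISTENER_FD=3")
        cmd.ExtraFiles = []*os.File{f} // becomes fd 3 in the child
        cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
        if err := cmd.Start(); err != nil {
            panic(err)
        }
        fmt.Printf("forked child %d, which now serves on the inherited socket\n", cmd.Process.Pid)
        cmd.Wait()
    }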

Here are the signals handled (a minimal handling sketch follows the list):

* USR2 causes the parent to fork
a child process and pass listener file descriptors to it.
The child process closes the unused file descriptors
and binds to the ones it needs.

At this point two processes - the parent
and the forked child - are serving requests.
After inspecting the traffic and the log files,
the administrator can shut down the parent process,
or shut down the child process if it is not functioning
as expected.

* TERM and INT trigger a graceful process shutdown.
Auth, node and proxy services wait until the number
of active connections drops to 0 and exit after that.

* KILL and QUIT cause an immediate, non-graceful
shutdown.

* HUP combines USR2 and TERM in a convenient
way: the parent process forks a child process and
self-initiates a graceful shutdown. This is more convenient
than the USR2/TERM sequence, but less flexible and robust:
if the connection to the parent process drops and
the new process exits with an error, administrators
can lock themselves out of the environment.
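Taken together, the handler added in lib/service/signals.go behaves
roughly like the sketch below. forkChild, shutdownGracefully and
closeImmediately are illustrative stand-ins for the real process
methods; note that SIGKILL cannot actually be intercepted, the kernel
terminates the process outright.

    package main

    import (
        "context"
        "log"
        "os"
        "os/signal"
        "syscall"
    )

    // Illustrative stand-ins for the real TeleportProcess methods.
    func forkChild() error                       { return nil }
    func shutdownGracefully(ctx context.Context) {}
    func closeImmediately()                      {}

    func waitForSignals(ctx context.Context) {
        sigC := make(chan os.Signal, 1024)
        signal.Notify(sigC, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT,
            syscall.SIGUSR2, syscall.SIGHUP, syscall.SIGCHLD)
        for {
            select {
            case sig := <-sigC:
                switch sig {
                case syscall.SIGUSR2:
                    // Fork a child that inherits the listeners; both keep serving.
                    if err := forkChild(); err != nil {
                        log.Printf("failed to fork: %v", err)
                    }
                case syscall.SIGTERM, syscall.SIGINT:
                    // Drain active connections, then exit.
                    shutdownGracefully(ctx)
                    return
                case syscall.SIGQUIT:
                    // Exit without waiting for connections to drain.
                    closeImmediately()
                    return
                case syscall.SIGHUP:
                    // USR2 + TERM in one step: fork a replacement, then drain and exit.
                    if err := forkChild(); err != nil {
                        log.Printf("failed to fork: %v", err)
                    }
                    shutdownGracefully(ctx)
                    return
                case syscall.SIGCHLD:
                    // Reap exited children so they do not linger as zombies.
                    var status syscall.WaitStatus
                    syscall.Wait4(-1, &status, syscall.WNOHANG, nil)
                }
            case <-ctx.Done():
                closeImmediately()
                return
            }
        }
    }

    func main() {
        waitForSignals(context.Background())
    }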

Additionally, the boltdb backend has to be phased out,
as it does not support reads and writes by two concurrent
processes. This required refactoring the dir
backend to use file locking so that multiple processes
can cooperate on read/write operations.
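The locking itself is plain advisory flock(2) on the key files,
roughly as in the simplified sketch below (the real writeLock,
readLock and unlock helpers in lib/backend/dir wrap the same calls
with trace error conversion):

    package main

    import (
        "fmt"
        "io/ioutil"
        "os"
        "syscall"
    )

    // writeLocked writes val to path while holding an exclusive flock,
    // so readers and writers in other processes never observe a
    // partially written file.
    func writeLocked(path string, val []byte) error {
        f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0600)
        if err != nil {
            return err
        }
        defer f.Close()
        if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
            return err
        }
        defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
        if err := f.Truncate(0); err != nil {
            return err
        }
        _, err = f.Write(val)
        return err
    }

    // readLocked reads the whole file while holding a shared flock,
    // which blocks writers but allows other readers.
    func readLocked(path string) ([]byte, error) {
        f, err := os.Open(path)
        if err != nil {
            return nil, err
        }
        defer f.Close()
        if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
            return nil, err
        }
        defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
        return ioutil.ReadAll(f)
    }

    func main() {
        if err := writeLocked("/tmp/dir-backend-key", []byte("value")); err != nil {
            panic(err)
        }
        out, err := readLocked("/tmp/dir-backend-key")
        if err != nil {
            panic(err)
        }
        fmt.Println(string(out))
    }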
Sasha Klizhentas 2018-02-07 18:32:50 -08:00
parent 5a2f5e861c
commit 68b65f5b24
28 changed files with 1155 additions and 239 deletions

View file

@ -1,6 +1,7 @@
package teleport
import (
"strings"
"time"
)
@ -79,6 +80,9 @@ const (
// ComponentProxy is SSH proxy (SSH server forwarding connections)
ComponentProxy = "proxy"
// ComponentDiagnostic is a diagnostic service
ComponentDiagnostic = "diagnostic"
// ComponentTunClient is a tunnel client
ComponentTunClient = "client:tunnel"
@ -201,6 +205,12 @@ const (
Off = "off"
)
// Component generates "component:subcomponent1:subcomponent2" strings used
// in debugging
func Component(components ...string) string {
return strings.Join(components, ":")
}
const (
// AuthorizedKeys are public keys that check against User CAs.
AuthorizedKeys = "authorized_keys"


@ -1 +1 @@
Subproject commit e18107a74135099d9ecd6caad7d9baa70f26efde
Subproject commit 66dea98c52bf2d60bc9ac0d50df1849f2923c35f

View file

@ -26,7 +26,7 @@ import (
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/boltbk"
"github.com/gravitational/teleport/lib/backend/dir"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/reversetunnel"
@ -363,7 +363,7 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic
tconf.Proxy.WebAddr.Addr = net.JoinHostPort(i.Hostname, i.GetPortWeb())
tconf.AuthServers = append(tconf.AuthServers, tconf.Auth.SSHAddr)
tconf.Auth.StorageConfig = backend.Config{
Type: boltbk.GetName(),
Type: dir.GetName(),
Params: backend.Params{"path": dataDir},
}

View file

@ -212,7 +212,7 @@ func (s *IntSuite) TestAuditOn(c *check.C) {
select {
case <-tickCh:
nodesInSite, err := site.GetNodes(defaults.Namespace)
if err != nil {
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)
}
if got, want := len(nodesInSite), count; got == want {
@ -497,6 +497,7 @@ func (s *IntSuite) TestInteractive(c *check.C) {
err = cl.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
sessionEndC <- true
}
// PersonB: wait for a session to become available, then join:
@ -535,6 +536,80 @@ func (s *IntSuite) TestInteractive(c *check.C) {
c.Assert(strings.Contains(outputOfA, outputOfB), check.Equals, true)
}
// TestShutdown tests scenario with a graceful shutdown,
// that session will be working after
func (s *IntSuite) TestShutdown(c *check.C) {
t := s.newTeleport(c, nil, true)
// get a reference to site obj:
site := t.GetSiteAPI(Site)
c.Assert(site, check.NotNil)
person := NewTerminal(250)
// commandsC receive commands
commandsC := make(chan string, 0)
// PersonA: SSH into the server, wait one second, then type some commands on stdin:
openSession := func() {
cl, err := t.NewClient(ClientConfig{Login: s.me.Username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()})
c.Assert(err, check.IsNil)
cl.Stdout = &person
cl.Stdin = &person
go func() {
for command := range commandsC {
person.Type(command)
}
}()
err = cl.SSH(context.TODO(), []string{}, false)
c.Assert(err, check.IsNil)
}
go openSession()
retry := func(command, pattern string) {
person.Type(command)
// wait for both sites to see each other via their reverse tunnels (for up to 10 seconds)
abortTime := time.Now().Add(10 * time.Second)
var matched bool
var output string
for {
output = string(replaceNewlines(person.Output(1000)))
matched, _ = regexp.MatchString(pattern, output)
if matched {
break
}
time.Sleep(time.Millisecond * 200)
if time.Now().After(abortTime) {
c.Fatalf("failed to capture output: %v", pattern)
}
}
if !matched {
c.Fatalf("output %q does not match pattern %q", output, pattern)
}
}
retry("echo start \r\n", ".*start.*")
// initiate shutdown
ctx := context.TODO()
shutdownContext := t.Process.StartShutdown(ctx)
// make sure that terminal still works
retry("echo howdy \r\n", ".*howdy.*")
// now type exit and wait for shutdown to complete
person.Type("exit\n\r")
select {
case <-shutdownContext.Done():
case <-time.After(5 * time.Second):
c.Fatalf("failed to shut down the server")
}
}
// TestInvalidLogins validates that you can't login with invalid login or
// with invalid 'site' parameter
func (s *IntSuite) TestEnvironmentVariables(c *check.C) {

View file

@ -24,7 +24,6 @@ limitations under the License.
package auth
import (
"context"
"crypto/x509"
"fmt"
"net/url"
@ -79,7 +78,6 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
if cfg.AuditLog == nil {
cfg.AuditLog = events.NewDiscardAuditLog()
}
closeCtx, cancelFunc := context.WithCancel(context.TODO())
as := AuthServer{
clusterName: cfg.ClusterName,
bk: cfg.Backend,
@ -95,8 +93,6 @@ func NewAuthServer(cfg *InitConfig, opts ...AuthServerOption) (*AuthServer, erro
oidcClients: make(map[string]*oidcClient),
samlProviders: make(map[string]*samlProvider),
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
}
for _, o := range opts {
o(&as)
@ -122,8 +118,6 @@ type AuthServer struct {
githubClients map[string]*githubClient
clock clockwork.Clock
bk backend.Backend
closeCtx context.Context
cancelFunc context.CancelFunc
sshca.Authority
@ -144,7 +138,6 @@ type AuthServer struct {
}
func (a *AuthServer) Close() error {
a.cancelFunc()
if a.bk != nil {
return trace.Wrap(a.bk.Close())
}

View file

@ -181,6 +181,14 @@ func (s *AuthTunnel) Close() error {
return nil
}
// Shutdown gracefully shuts down auth server
func (s *AuthTunnel) Shutdown(ctx context.Context) error {
if s != nil && s.sshServer != nil {
return s.sshServer.Shutdown(ctx)
}
return nil
}
// HandleNewChan implements NewChanHandler interface: it gets called every time a new SSH
// connection is established
func (s *AuthTunnel) HandleNewChan(_ net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {

View file

@ -22,6 +22,7 @@ import (
"os"
"path"
"path/filepath"
"syscall"
"time"
"github.com/gravitational/teleport/lib/backend"
@ -160,6 +161,13 @@ func (bk *Backend) CreateVal(bucket []string, key string, val []byte, ttl time.D
return trace.ConvertSystemError(err)
}
defer f.Close()
if err := writeLock(f); err != nil {
return trace.Wrap(err)
}
defer unlock(f)
if err := f.Truncate(0); err != nil {
return trace.ConvertSystemError(err)
}
n, err := f.Write(val)
if err == nil && n < len(val) {
return trace.Wrap(io.ErrShortWrite)
@ -167,6 +175,27 @@ func (bk *Backend) CreateVal(bucket []string, key string, val []byte, ttl time.D
return trace.Wrap(bk.applyTTL(dirPath, key, ttl))
}
func writeLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}
func readLock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_SH); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}
func unlock(f *os.File) error {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}
// UpsertVal updates or inserts value with a given TTL into a bucket
// ForeverTTL for no TTL
func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.Duration) error {
@ -176,17 +205,33 @@ func (bk *Backend) UpsertVal(bucket []string, key string, val []byte, ttl time.D
if err != nil {
return trace.Wrap(err)
}
// create the (or overwrite existing) file (AKA "key"):
err = ioutil.WriteFile(path.Join(dirPath, key), val, defaultFileMode)
filename := path.Join(dirPath, key)
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, defaultFileMode)
if err != nil {
if os.IsExist(err) {
return trace.AlreadyExists("%s/%s already exists", dirPath, key)
}
return trace.ConvertSystemError(err)
}
defer f.Close()
if err := writeLock(f); err != nil {
return trace.Wrap(err)
}
defer unlock(f)
if err := f.Truncate(0); err != nil {
return trace.ConvertSystemError(err)
}
n, err := f.Write(val)
if err == nil && n < len(val) {
return trace.Wrap(io.ErrShortWrite)
}
return trace.Wrap(bk.applyTTL(dirPath, key, ttl))
}
// GetVal return a value for a given key in the bucket
func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {
dirPath := path.Join(path.Join(bk.RootDir, path.Join(bucket...)))
filename := path.Join(dirPath, key)
expired, err := bk.checkTTL(dirPath, key)
if err != nil {
return nil, trace.Wrap(err)
@ -195,17 +240,26 @@ func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {
bk.DeleteKey(bucket, key)
return nil, trace.NotFound("key %q is not found", key)
}
fp := path.Join(dirPath, key)
bytes, err := ioutil.ReadFile(fp)
f, err := os.OpenFile(filename, os.O_RDONLY, defaultFileMode)
if err != nil {
// GetVal() on a bucket must return 'BadParameter' error:
if fi, _ := os.Stat(fp); fi != nil && fi.IsDir() {
if fi, _ := os.Stat(filename); fi != nil && fi.IsDir() {
return nil, trace.BadParameter("%q is not a valid key", key)
}
return nil, trace.ConvertSystemError(err)
}
// this could happen if we delete the file concurrently
// with the read, apparently we can read empty file back
defer f.Close()
if err := readLock(f); err != nil {
return nil, trace.Wrap(err)
}
defer unlock(f)
bytes, err := ioutil.ReadAll(f)
if err != nil {
return nil, trace.ConvertSystemError(err)
}
// this could happen when CreateKey or UpsertKey created a file
// but, GetVal managed to get readLock right after it,
// so there are no contents there
if len(bytes) == 0 {
return nil, trace.NotFound("key %q is not found", key)
}
@ -215,13 +269,25 @@ func (bk *Backend) GetVal(bucket []string, key string) ([]byte, error) {
// DeleteKey deletes a key in a bucket
func (bk *Backend) DeleteKey(bucket []string, key string) error {
dirPath := path.Join(bk.RootDir, path.Join(bucket...))
filename := path.Join(dirPath, key)
f, err := os.OpenFile(filename, os.O_RDONLY, defaultFileMode)
if err != nil {
if fi, _ := os.Stat(filename); fi != nil && fi.IsDir() {
return trace.BadParameter("%q is not a valid key", key)
}
return trace.ConvertSystemError(err)
}
defer f.Close()
if err := writeLock(f); err != nil {
return trace.Wrap(err)
}
defer unlock(f)
if err := os.Remove(bk.ttlFile(dirPath, key)); err != nil {
if !os.IsNotExist(err) {
log.Warn(err)
}
}
return trace.ConvertSystemError(os.Remove(
path.Join(dirPath, key)))
return trace.ConvertSystemError(os.Remove(filename))
}
// DeleteBucket deletes the bucket by a given path
@ -260,18 +326,39 @@ func removeFiles(dir string) error {
return err
}
} else if !fi.IsDir() {
err = os.Remove(path)
err = removeFile(path)
if err != nil {
err = trace.ConvertSystemError(err)
if !trace.IsNotFound(err) {
return err
}
return err
}
}
}
return nil
}
func removeFile(path string) error {
f, err := os.OpenFile(path, os.O_RDONLY, defaultFileMode)
err = trace.ConvertSystemError(err)
if err != nil {
if !trace.IsNotFound(err) {
return trace.Wrap(err)
}
return nil
}
defer f.Close()
if err := writeLock(f); err != nil {
return trace.Wrap(err)
}
defer unlock(f)
err = os.Remove(path)
if err != nil {
err = trace.ConvertSystemError(err)
if !trace.IsNotFound(err) {
return err
}
}
return nil
}
// AcquireLock grabs a lock that will be released automatically in TTL
func (bk *Backend) AcquireLock(token string, ttl time.Duration) (err error) {
bk.Debugf("AcquireLock(%s)", token)

View file

@ -57,18 +57,60 @@ func (s *Suite) SetUpSuite(c *check.C) {
s.suite.B = s.bk
}
func (s *Suite) TestConcurrentDeleteBucket(c *check.C) {
func (s *Suite) BenchmarkOperations(c *check.C) {
bucket := []string{"bench", "bucket"}
keys := []string{"key1", "key2", "key3", "key4", "key5"}
value1 := "some backend value, not large enough, but not small enought"
for i := 0; i < c.N; i++ {
for _, key := range keys {
err := s.bk.UpsertVal(bucket, key, []byte(value1), time.Hour)
c.Assert(err, check.IsNil)
bytes, err := s.bk.GetVal(bucket, key)
c.Assert(err, check.IsNil)
c.Assert(string(bytes), check.Equals, value1)
err = s.bk.DeleteKey(bucket, key)
c.Assert(err, check.IsNil)
}
}
}
func (s *Suite) TestConcurrentOperations(c *check.C) {
bucket := []string{"concurrent", "bucket"}
value1 := "this first value should not be corrupted by concurrent ops"
value2 := "this second value should not be corrupted too"
const attempts = 50
resultsC := make(chan struct{}, attempts*2)
resultsC := make(chan struct{}, attempts*4)
for i := 0; i < attempts; i++ {
go func(cnt int) {
err := s.bk.UpsertVal(bucket, "key", []byte("new-value"), backend.Forever)
err := s.bk.UpsertVal(bucket, "key", []byte(value1), time.Hour)
resultsC <- struct{}{}
c.Assert(err, check.IsNil)
}(i)
go func(cnt int) {
err := s.bk.CreateVal(bucket, "key", []byte(value2), time.Hour)
resultsC <- struct{}{}
if err != nil && !trace.IsAlreadyExists(err) {
c.Assert(err, check.IsNil)
}
}(i)
go func(cnt int) {
bytes, err := s.bk.GetVal(bucket, "key")
resultsC <- struct{}{}
if err != nil && !trace.IsNotFound(err) {
c.Assert(err, check.IsNil)
}
// make sure data is not corrupted along the way
if err == nil {
val := string(bytes)
if val != value1 && val != value2 {
c.Fatalf("expected one of %q or %q and got %q", value1, value2, val)
}
}
}(i)
go func(cnt int) {
err := s.bk.DeleteBucket([]string{"concurrent"}, "bucket")
resultsC <- struct{}{}
@ -76,7 +118,7 @@ func (s *Suite) TestConcurrentDeleteBucket(c *check.C) {
}(i)
}
timeoutC := time.After(3 * time.Second)
for i := 0; i < attempts*2; i++ {
for i := 0; i < attempts*4; i++ {
select {
case <-resultsC:
case <-timeoutC:

View file

@ -75,15 +75,10 @@ type CommandLineFlags struct {
// --labels flag
Labels string
// --httpprofile hidden flag
HTTPProfileEndpoint bool
// --pid-file flag
PIDFile string
// Gops starts gops agent on a specified address
// if not specified, gops won't start
// Gops starts gops agent on a first available address
Gops bool
// GopsAddr specifies to gops addr to listen on
GopsAddr string
// DiagnosticAddr is listen address for diagnostic endpoint
DiagnosticAddr string
// PermitUserEnvironment enables reading of ~/.tsh/environment
@ -666,6 +661,15 @@ func Configure(clf *CommandLineFlags, cfg *service.Config) error {
return trace.Wrap(err)
}
// apply diangostic address flag
if clf.DiagnosticAddr != "" {
addr, err := utils.ParseAddr(clf.DiagnosticAddr)
if err != nil {
return trace.Wrap(err, "failed to parse diag-addr")
}
cfg.DiagnosticAddr = *addr
}
// apply --insecure-no-tls flag:
if clf.DisableTLS {
cfg.Proxy.DisableTLS = clf.DisableTLS

View file

@ -92,6 +92,9 @@ const (
// the SSH connection open if there are no reads/writes happening over it.
DefaultIdleConnectionDuration = 20 * time.Minute
// ShutdownPollPeriod is a polling period for graceful shutdowns of SSH servers
ShutdownPollPeriod = 500 * time.Millisecond
// ReadHeadersTimeout is a default TCP timeout when we wait
// for the response headers to arrive
ReadHeadersTimeout = time.Second

View file

@ -17,6 +17,7 @@ limitations under the License.
package reversetunnel
import (
"context"
"net"
"time"
@ -58,8 +59,10 @@ type Server interface {
RemoveSite(domainName string) error
// Start starts server
Start() error
// CLose closes server's socket
// Close closes server's operations immediately
Close() error
// Shutdown performs graceful server shutdown
Shutdown(context.Context) error
// Wait waits for server to close all outstanding operations
Wait()
}

View file

@ -402,7 +402,7 @@ func (s *server) diffConns(newConns, existingConns map[string]services.TunnelCon
}
func (s *server) Wait() {
s.srv.Wait()
s.srv.Wait(context.TODO())
}
func (s *server) Start() error {
@ -415,6 +415,11 @@ func (s *server) Close() error {
return s.srv.Close()
}
func (s *server) Shutdown(ctx context.Context) error {
s.cancel()
return s.srv.Shutdown(ctx)
}
func (s *server) HandleNewChan(conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel) {
// apply read/write timeouts to the server connection
conn = utils.ObeyIdleTimeout(conn,

View file

@ -127,6 +127,9 @@ type Config struct {
// MACAlgorithms is a list of message authentication codes (MAC) that
// the server supports. If omitted the defaults will be used.
MACAlgorithms []string
// DiagnosticAddr is an address for diagnostic and healthz endpoint service
DiagnosticAddr utils.NetAddr
}
// ApplyToken assigns a given token to all internal services but only if token
@ -202,12 +205,12 @@ func (c CachePolicy) String() string {
recentCachePolicy = fmt.Sprintf("will cache frequently accessed items for %v", c.GetRecentTTL())
}
if c.NeverExpires {
return fmt.Sprintf("cache will not expire in case if connection to database is lost, %v", recentCachePolicy)
return fmt.Sprintf("cache that will not expire in case if connection to database is lost, %v", recentCachePolicy)
}
if c.TTL == 0 {
return fmt.Sprintf("cache will expire after connection to database is lost after %v, %v", defaults.CacheTTL, recentCachePolicy)
return fmt.Sprintf("cache that will expire after connection to database is lost after %v, %v", defaults.CacheTTL, recentCachePolicy)
}
return fmt.Sprintf("cache will expire after connection to database is lost after %v, %v", c.TTL, recentCachePolicy)
return fmt.Sprintf("cache that will expire after connection to database is lost after %v, %v", c.TTL, recentCachePolicy)
}
// ProxyConfig configures proy service

View file

@ -19,6 +19,7 @@ limitations under the License.
package service
import (
"context"
"crypto/tls"
"fmt"
"io"
@ -52,12 +53,15 @@ import (
"github.com/gravitational/teleport/lib/srv/regular"
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/state"
"github.com/gravitational/teleport/lib/system"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/web"
"github.com/gravitational/trace"
"github.com/gravitational/roundtrip"
"github.com/jonboulle/clockwork"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
)
@ -142,6 +146,12 @@ type TeleportProcess struct {
// identities of this process (credentials to auth sever, basically)
Identities map[teleport.Role]*auth.Identity
// registeredListeners keeps track of all listeners created by the process
// used to pass through the
registeredListeners []RegisteredListener
// importedDescriptors is a list of imported file descriptors
// passed by the parent process
importedDescriptors []FileDescriptor
}
// GetAuthServer returns the process' auth server
@ -286,7 +296,7 @@ func (process *TeleportProcess) connectToAuthService(role teleport.Role, additio
// and starts them under a supervisor, returning the supervisor object
func NewTeleport(cfg *Config) (*TeleportProcess, error) {
// before we do anything reset the SIGINT handler back to the default
utils.ResetInterruptSignalHandler()
system.ResetInterruptSignalHandler()
if err := validateConfig(cfg); err != nil {
return nil, trace.Wrap(err, "configuration error")
@ -301,6 +311,11 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
}
}
importedDescriptors, err := importFileDescriptors()
if err != nil {
return nil, trace.Wrap(err)
}
// if there's no host uuid initialized yet, try to read one from the
// one of the identities
cfg.HostUUID, err = utils.ReadHostUUID(cfg.DataDir)
@ -337,14 +352,23 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
}
process := &TeleportProcess{
Clock: clockwork.NewRealClock(),
Supervisor: NewSupervisor(),
Config: cfg,
Identities: make(map[teleport.Role]*auth.Identity),
Clock: clockwork.NewRealClock(),
Supervisor: NewSupervisor(),
Config: cfg,
Identities: make(map[teleport.Role]*auth.Identity),
importedDescriptors: importedDescriptors,
}
serviceStarted := false
if !cfg.DiagnosticAddr.IsEmpty() {
if err := process.initDiagnosticService(); err != nil {
return nil, trace.Wrap(err)
}
} else {
warnOnErr(process.closeImportedDescriptors(teleport.ComponentDiagnostic))
}
if cfg.Auth.Enabled {
if cfg.Keygen == nil {
cfg.Keygen = native.New()
@ -353,6 +377,8 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
return nil, trace.Wrap(err)
}
serviceStarted = true
} else {
warnOnErr(process.closeImportedDescriptors(teleport.ComponentAuth))
}
if cfg.SSH.Enabled {
@ -360,6 +386,8 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
return nil, err
}
serviceStarted = true
} else {
warnOnErr(process.closeImportedDescriptors(teleport.ComponentNode))
}
if cfg.Proxy.Enabled {
@ -367,6 +395,8 @@ func NewTeleport(cfg *Config) (*TeleportProcess, error) {
return nil, err
}
serviceStarted = true
} else {
warnOnErr(process.closeImportedDescriptors(teleport.ComponentProxy))
}
if !serviceStarted {
@ -514,57 +544,8 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error
return trace.Wrap(err)
}
// auth server listens on SSH and TLS, reusing the same socket
listener, err := net.Listen("tcp", cfg.Auth.SSHAddr.Addr)
if err != nil {
utils.Consolef(cfg.Console, "[AUTH] failed to bind to address %v, exiting", cfg.Auth.SSHAddr.Addr, err)
return trace.Wrap(err)
}
process.onExit(func(payload interface{}) {
log.Debugf("Closing listener: %v.", listener.Addr())
listener.Close()
})
if cfg.Auth.EnableProxyProtocol {
log.Infof("Starting Auth service with PROXY protocol support.")
}
mux, err := multiplexer.New(multiplexer.Config{
EnableProxyProtocol: cfg.Auth.EnableProxyProtocol,
Listener: listener,
})
if err != nil {
return trace.Wrap(err)
}
go mux.Serve()
// Register an SSH endpoint which is used to create an SSH tunnel to send HTTP
// requests to the Auth API
var authTunnel *auth.AuthTunnel
process.RegisterFunc("auth.ssh", func() error {
utils.Consolef(cfg.Console, "[AUTH] Auth service is starting on %v", cfg.Auth.SSHAddr.Addr)
authTunnel, err = auth.NewTunnel(
cfg.Auth.SSHAddr,
identity.KeySigner,
apiConf,
auth.SetLimiter(sshLimiter),
)
if err != nil {
utils.Consolef(cfg.Console, "[AUTH] Error: %v", err)
return trace.Wrap(err)
}
// since authTunnel.Serve is a blocking call, we emit this even right before
// the service has started
process.BroadcastEvent(Event{Name: AuthSSHReady, Payload: nil})
if err := authTunnel.Serve(mux.SSH()); err != nil {
if askedToExit {
log.Infof("Auth tunnel exited.")
return nil
}
utils.Consolef(cfg.Console, "[AUTH] Error: %v", err)
return trace.Wrap(err)
}
return nil
log := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentAuth,
})
// Register TLS endpoint of the auth service
@ -581,13 +562,66 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error
if err != nil {
return trace.Wrap(err)
}
// auth server listens on SSH and TLS, reusing the same socket
listener, err := process.importOrCreateListener(teleport.ComponentAuth, cfg.Auth.SSHAddr.Addr)
if err != nil {
log.Errorf("PID: %v Failed to bind to address %v: %v, exiting.", os.Getpid(), cfg.Auth.SSHAddr.Addr, err)
return trace.Wrap(err)
}
// clean up unused descriptors passed for proxy, but not used by it
warnOnErr(process.closeImportedDescriptors(teleport.ComponentAuth))
if cfg.Auth.EnableProxyProtocol {
log.Infof("Starting Auth service with PROXY protocol support.")
}
mux, err := multiplexer.New(multiplexer.Config{
EnableProxyProtocol: cfg.Auth.EnableProxyProtocol,
Listener: listener,
})
if err != nil {
listener.Close()
return trace.Wrap(err)
}
go mux.Serve()
// Register an SSH endpoint which is used to create an SSH tunnel to send HTTP
// requests to the Auth API
var authTunnel *auth.AuthTunnel
process.RegisterFunc("auth.ssh", func() error {
log.Infof("Auth SSH service is starting on %v", cfg.Auth.SSHAddr.Addr)
authTunnel, err = auth.NewTunnel(
cfg.Auth.SSHAddr,
identity.KeySigner,
apiConf,
auth.SetLimiter(sshLimiter),
)
if err != nil {
log.Errorf("Error: %v", err)
return trace.Wrap(err)
}
// since authTunnel.Serve is a blocking call, we emit this even right before
// the service has started
process.BroadcastEvent(Event{Name: AuthSSHReady, Payload: nil})
if err := authTunnel.Serve(mux.SSH()); err != nil {
if askedToExit {
log.Infof("Auth tunnel exited.")
return nil
}
log.Errorf("Error: %v", err)
return trace.Wrap(err)
}
return nil
})
process.RegisterFunc("auth.tls", func() error {
// since tlsServer.Serve is a blocking call, we emit this even right before
// the service has started
process.BroadcastEvent(Event{Name: AuthTLSReady, Payload: nil})
err := tlsServer.Serve(mux.TLS())
if err != nil {
if err != nil && err != http.ErrServerClosed {
log.Warningf("TLS server exited with error: %v.", err)
}
return nil
@ -602,12 +636,14 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error
}
// External integrations rely on this event:
process.BroadcastEvent(Event{Name: AuthIdentityEvent, Payload: connector})
process.onExit(func(payload interface{}) {
process.onExit("auth.broadcast", func(payload interface{}) {
connector.Client.Close()
})
return nil
})
closeContext, signalClose := context.WithCancel(context.TODO())
process.RegisterFunc("auth.heartbeat", func() error {
srv := services.ServerV2{
Kind: services.KindAuthServer,
@ -642,42 +678,72 @@ func (process *TeleportProcess) initAuthService(authority sshca.Authority) error
log.Warnf("Parameter advertise_ip is not set for this auth server. Trying to guess the IP this server can be reached at: %v.", srv.GetAddr())
}
// immediately register, and then keep repeating in a loop:
for !askedToExit {
ticker := time.NewTicker(defaults.ServerHeartbeatTTL / 2)
defer ticker.Stop()
announce:
for {
srv.SetTTL(process, defaults.ServerHeartbeatTTL)
err := authServer.UpsertAuthServer(&srv)
if err != nil {
log.Warningf("Failed to announce presence: %v.", err)
}
sleepTime := defaults.ServerHeartbeatTTL/2 + utils.RandomDuration(defaults.ServerHeartbeatTTL/10)
time.Sleep(sleepTime)
select {
case <-closeContext.Done():
break announce
case <-ticker.C:
}
}
log.Infof("Heartbeat to other auth servers exited.")
return nil
})
// execute this when process is asked to exit:
process.onExit(func(payload interface{}) {
askedToExit = true
mux.Close()
authTunnel.Close()
tlsServer.Close()
log.Infof("Auth service exited.")
process.onExit("auth.shutdown", func(payload interface{}) {
// as a last resort, at least close listeners (e.g. panic)
if listener != nil {
defer listener.Close()
}
if mux != nil {
defer mux.Close()
}
signalClose()
if payload == nil {
log.Info("Shutting down immediately.")
warnOnErr(tlsServer.Close())
warnOnErr(authTunnel.Close())
} else {
log.Info("Shutting down gracefully.")
ctx := payloadContext(payload)
warnOnErr(tlsServer.Shutdown(ctx))
warnOnErr(authTunnel.Shutdown(ctx))
}
log.Info("Exited.")
})
return nil
}
func payloadContext(payload interface{}) context.Context {
ctx, ok := payload.(context.Context)
if ok {
return ctx
}
log.Errorf("expected context, got %T", payload)
return context.TODO()
}
// onExit allows individual services to register a callback function which will be
// called when Teleport Process is asked to exit. Usually services terminate themselves
// when the callback is called
func (process *TeleportProcess) onExit(callback func(interface{})) {
go func() {
func (process *TeleportProcess) onExit(serviceName string, callback func(interface{})) {
process.RegisterFunc(serviceName, func() error {
eventC := make(chan Event)
process.WaitForEvent(TeleportExitEvent, eventC, make(chan struct{}))
select {
case event := <-eventC:
callback(event.Payload)
}
}()
return nil
})
}
// newLocalCache returns new local cache access point
@ -690,7 +756,7 @@ func (process *TeleportProcess) newLocalCache(clt auth.ClientI, cacheName []stri
if err := os.MkdirAll(path, teleport.SharedDirMode); err != nil {
return nil, trace.ConvertSystemError(err)
}
cacheBackend, err := boltbk.New(backend.Params{"path": path})
cacheBackend, err := dir.New(backend.Params{"path": path})
if err != nil {
return nil, trace.Wrap(err)
}
@ -712,9 +778,13 @@ func (process *TeleportProcess) initSSH() error {
var s *regular.Server
log := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentNode,
})
process.RegisterFunc("ssh.node", func() error {
event := <-eventsC
log.Infof("SSH node received %v", &event)
log.Infof("Received event %q.", event.Name)
conn, ok := (event.Payload).(*Connector)
if !ok {
return trace.BadParameter("unsupported connector type: %T", event.Payload)
@ -743,6 +813,13 @@ func (process *TeleportProcess) initSSH() error {
return trace.Wrap(err)
}
listener, err := process.importOrCreateListener(teleport.ComponentNode, cfg.SSH.Addr.Addr)
if err != nil {
return trace.Wrap(err)
}
// clean up unused descriptors passed for proxy, but not used by it
warnOnErr(process.closeImportedDescriptors(teleport.ComponentNode))
s, err = regular.New(cfg.SSH.Addr,
cfg.Hostname,
[]ssh.Signer{conn.Identity.KeySigner},
@ -765,25 +842,31 @@ func (process *TeleportProcess) initSSH() error {
return trace.Wrap(err)
}
utils.Consolef(cfg.Console, "[SSH] Service is starting on %v using %v", cfg.SSH.Addr.Addr, process.Config.CachePolicy)
if err := s.Start(); err != nil {
utils.Consolef(cfg.Console, "[SSH] Error: %v", err)
return trace.Wrap(err)
}
log.Infof("Service is starting on %v %v.", cfg.SSH.Addr.Addr, process.Config.CachePolicy)
go s.Serve(listener)
// broadcast that the node has started
process.BroadcastEvent(Event{Name: NodeSSHReady, Payload: nil})
// block and wait while the node is running
s.Wait()
log.Infof("[SSH] node service exited")
log.Infof("Exited.")
return nil
})
// execute this when process is asked to exit:
process.onExit(func(payload interface{}) {
if s != nil {
s.Close()
process.onExit("ssh.shutdown", func(payload interface{}) {
if payload == nil {
log.Infof("Shutting down immediately.")
if s != nil {
warnOnErr(s.Close())
}
} else {
log.Infof("Shutting down gracefully.")
if s != nil {
warnOnErr(s.Shutdown(payloadContext(payload)))
}
}
log.Infof("Exited.")
})
return nil
}
@ -809,7 +892,7 @@ func (process *TeleportProcess) RegisterWithAuthServer(token string, role telepo
return nil
}
if trace.IsConnectionProblem(err) {
utils.Consolef(cfg.Console, "[%v] connecting to auth server: %v", role, err)
log.Infof("%v failed attempt connecting to auth server: %v", role, err)
time.Sleep(retryTime)
continue
}
@ -834,19 +917,65 @@ func (process *TeleportProcess) RegisterWithAuthServer(token string, role telepo
log.Errorf("Failed to join the cluster: %v.", err)
time.Sleep(retryTime)
} else {
utils.Consolef(cfg.Console, "[%v] Successfully registered with the cluster", role)
log.Infof("%v has successfully registered with the cluster.", role)
continue
}
}
})
process.onExit(func(interface{}) {
process.onExit("auth.client", func(interface{}) {
if authClient != nil {
authClient.Close()
}
})
}
// initDiagnosticService starts diagnostic service currently serving healthz
// and prometheus endpoints
func (process *TeleportProcess) initDiagnosticService() error {
mux := http.NewServeMux()
mux.Handle("/metrics", prometheus.Handler())
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
roundtrip.ReplyJSON(w, http.StatusOK, map[string]interface{}{"status": "ok"})
})
listener, err := process.importOrCreateListener(teleport.ComponentDiagnostic, process.Config.DiagnosticAddr.Addr)
if err != nil {
return trace.Wrap(err)
}
warnOnErr(process.closeImportedDescriptors(teleport.ComponentDiagnostic))
server := &http.Server{
Handler: mux,
}
log := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentDiagnostic,
})
log.Infof("Starting diagnostic service on %v.", process.Config.DiagnosticAddr.Addr)
process.RegisterFunc("diagnostic.service", func() error {
err := server.Serve(listener)
if err != nil && err != http.ErrServerClosed {
log.Warningf("Diagnostic server exited with error: %v.", err)
}
return nil
})
process.onExit("diagnostic.shutdown", func(payload interface{}) {
if payload == nil {
log.Infof("Shutting down immediately.")
warnOnErr(server.Close())
} else {
log.Infof("Shutting down gracefully.")
ctx := payloadContext(payload)
warnOnErr(server.Shutdown(ctx))
}
log.Infof("Exited.")
})
return nil
}
// initProxy gets called if teleport runs with 'proxy' role enabled.
// this means it will do two things:
// 1. serve a web UI
@ -876,7 +1005,7 @@ func (process *TeleportProcess) initProxy() error {
process.WaitForEvent(ProxyIdentityEvent, eventsC, make(chan struct{}))
event := <-eventsC
log.Debugf("Received event %v.", &event)
log.Debugf("Received event %q.", event.Name)
conn, ok := (event.Payload).(*Connector)
if !ok {
return trace.BadParameter("unsupported connector type: %T", event.Payload)
@ -916,7 +1045,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) {
return &listeners, nil
case cfg.Proxy.ReverseTunnelListenAddr.Equals(cfg.Proxy.WebAddr):
log.Debugf("Setup Proxy: Reverse tunnel proxy and web proxy listen on the same port, multiplexing is on.")
listener, err := net.Listen("tcp", cfg.Proxy.WebAddr.Addr)
listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel", "web"), cfg.Proxy.WebAddr.Addr)
if err != nil {
return nil, trace.Wrap(err)
}
@ -936,7 +1065,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) {
return &listeners, nil
case cfg.Proxy.EnableProxyProtocol && !cfg.Proxy.DisableWebService:
log.Debugf("Setup Proxy: Proxy protocol is enabled for web service, multiplexing is on.")
listener, err := net.Listen("tcp", cfg.Proxy.WebAddr.Addr)
listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "web"), cfg.Proxy.WebAddr.Addr)
if err != nil {
return nil, trace.Wrap(err)
}
@ -951,7 +1080,7 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) {
return nil, trace.Wrap(err)
}
listeners.web = listeners.mux.TLS()
listeners.reverseTunnel, err = net.Listen("tcp", cfg.Proxy.ReverseTunnelListenAddr.Addr)
listeners.reverseTunnel, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel"), cfg.Proxy.ReverseTunnelListenAddr.Addr)
if err != nil {
listener.Close()
listeners.Close()
@ -962,14 +1091,14 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) {
default:
log.Debugf("Proxy reverse tunnel are listening on the separate ports")
if !cfg.Proxy.DisableReverseTunnel {
listeners.reverseTunnel, err = net.Listen("tcp", cfg.Proxy.ReverseTunnelListenAddr.Addr)
listeners.reverseTunnel, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "tunnel"), cfg.Proxy.ReverseTunnelListenAddr.Addr)
if err != nil {
listeners.Close()
return nil, trace.Wrap(err)
}
}
if !cfg.Proxy.DisableWebService {
listeners.web, err = net.Listen("tcp", cfg.Proxy.WebAddr.Addr)
listeners.web, err = process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "web"), cfg.Proxy.WebAddr.Addr)
if err != nil {
listeners.Close()
return nil, trace.Wrap(err)
@ -980,10 +1109,9 @@ func (process *TeleportProcess) setupProxyListeners() (*proxyListeners, error) {
}
func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
var (
askedToExit = true
err error
)
// clean up unused descriptors passed for proxy, but not used by it
defer process.closeImportedDescriptors(teleport.ComponentProxy)
var err error
cfg := process.Config
proxyLimiter, err := limiter.NewLimiter(cfg.Proxy.Limiter)
@ -1024,6 +1152,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}
log := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentReverseTunnelServer,
})
// register SSH reverse tunnel server that accepts connections
// from remote teleport nodes
var tsrv reversetunnel.Server
@ -1053,59 +1185,55 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
return trace.Wrap(err)
}
process.RegisterFunc("proxy.reveresetunnel.server", func() error {
utils.Consolef(cfg.Console, "Starting reverse tunnel service is starting on %v using %v", cfg.Proxy.ReverseTunnelListenAddr.Addr, process.Config.CachePolicy)
log.Infof("Starting on %v using %v", cfg.Proxy.ReverseTunnelListenAddr.Addr, process.Config.CachePolicy)
if err := tsrv.Start(); err != nil {
utils.Consolef(cfg.Console, "Error: %v", err)
log.Error(err)
return trace.Wrap(err)
}
// notify parties that we've started reverse tunnel server
process.BroadcastEvent(Event{Name: ProxyReverseTunnelReady, Payload: tsrv})
tsrv.Wait()
if askedToExit {
log.Infof("Reverse tunnel exited.")
}
return nil
})
}
// Register web proxy server
var webServer *http.Server
if !process.Config.Proxy.DisableWebService {
process.RegisterFunc("proxy.web", func() error {
utils.Consolef(cfg.Console, "Web proxy service is starting on %v.", cfg.Proxy.WebAddr.Addr)
webHandler, err := web.NewHandler(
web.Config{
Proxy: tsrv,
AuthServers: cfg.AuthServers[0],
DomainName: cfg.Hostname,
ProxyClient: conn.Client,
DisableUI: process.Config.Proxy.DisableWebInterface,
ProxySSHAddr: cfg.Proxy.SSHAddr,
ProxyWebAddr: cfg.Proxy.WebAddr,
})
webHandler, err := web.NewHandler(
web.Config{
Proxy: tsrv,
AuthServers: cfg.AuthServers[0],
DomainName: cfg.Hostname,
ProxyClient: conn.Client,
DisableUI: process.Config.Proxy.DisableWebInterface,
ProxySSHAddr: cfg.Proxy.SSHAddr,
ProxyWebAddr: cfg.Proxy.WebAddr,
})
if err != nil {
return trace.Wrap(err)
}
proxyLimiter.WrapHandle(webHandler)
if !process.Config.Proxy.DisableTLS {
log.Infof("Using TLS cert %v, key %v", cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
tlsConfig, err := utils.CreateTLSConfiguration(cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
if err != nil {
return trace.Wrap(err)
}
listeners.web = tls.NewListener(listeners.web, tlsConfig)
}
webServer = &http.Server{
Handler: proxyLimiter,
}
process.RegisterFunc("proxy.web", func() error {
log.Infof("Web proxy service is starting on %v.", cfg.Proxy.WebAddr.Addr)
defer webHandler.Close()
proxyLimiter.WrapHandle(webHandler)
process.BroadcastEvent(Event{Name: ProxyWebServerReady, Payload: webHandler})
if !process.Config.Proxy.DisableTLS {
tlsConfig, err := utils.CreateTLSConfiguration(cfg.Proxy.TLSCert, cfg.Proxy.TLSKey)
if err != nil {
return trace.Wrap(err)
}
listeners.web = tls.NewListener(listeners.web, tlsConfig)
}
if err = http.Serve(listeners.web, proxyLimiter); err != nil {
if askedToExit {
log.Infof("Proxy web server exited.")
return nil
}
log.Error(err)
if err := webServer.Serve(listeners.web); err != nil && err != http.ErrServerClosed {
log.Warningf("Error while serving web requests: %v", err)
}
log.Infof("Exited.")
return nil
})
} else {
@ -1113,6 +1241,10 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
// Register SSH proxy server - SSH jumphost proxy server
listener, err := process.importOrCreateListener(teleport.Component(teleport.ComponentProxy, "ssh"), cfg.Proxy.SSHAddr.Addr)
if err != nil {
return trace.Wrap(err)
}
sshProxy, err := regular.New(cfg.Proxy.SSHAddr,
cfg.Hostname,
[]ssh.Signer{conn.Identity.KeySigner},
@ -1134,26 +1266,20 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}
process.RegisterFunc("proxy.ssh", func() error {
utils.Consolef(cfg.Console, "[PROXY] SSH proxy service is starting on %v", cfg.Proxy.SSHAddr.Addr)
if err := sshProxy.Start(); err != nil {
if askedToExit {
log.Infof("SSH proxy exited")
return nil
}
utils.Consolef(cfg.Console, "[PROXY] Error: %v", err)
return trace.Wrap(err)
}
log.Infof("SSH proxy service is starting on %v", cfg.Proxy.SSHAddr.Addr)
go sshProxy.Serve(listener)
// broadcast that the proxy ssh server has started
process.BroadcastEvent(Event{Name: ProxySSHReady, Payload: nil})
return nil
})
process.RegisterFunc("proxy.reversetunnel.agent", func() error {
log := logrus.WithFields(logrus.Fields{
trace.Component: teleport.ComponentReverseTunnelAgent,
})
log.Infof("Starting reverse tunnel agent pool.")
if err := agentPool.Start(); err != nil {
log.Fatalf("Failed to start: %v.", err)
log.Errorf("Failed to start: %v.", err)
return trace.Wrap(err)
}
agentPool.Wait()
@ -1161,18 +1287,40 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
})
// execute this when process is asked to exit:
process.onExit(func(payload interface{}) {
listeners.Close()
if tsrv != nil {
tsrv.Close()
}
sshProxy.Close()
process.onExit("proxy.shutdown", func(payload interface{}) {
agentPool.Stop()
log.Infof("Proxy service exited.")
defer listeners.Close()
if payload == nil {
log.Infof("Shutting down immediately.")
if tsrv != nil {
warnOnErr(tsrv.Close())
}
if webServer != nil {
warnOnErr(webServer.Close())
}
warnOnErr(sshProxy.Close())
} else {
log.Infof("Shutting down gracefully.")
ctx := payloadContext(payload)
if tsrv != nil {
warnOnErr(tsrv.Shutdown(ctx))
}
warnOnErr(sshProxy.Shutdown(ctx))
if webServer != nil {
warnOnErr(webServer.Shutdown(ctx))
}
}
log.Infof("Exited.")
})
return nil
}
func warnOnErr(err error) {
if err != nil {
log.Errorf("Error while performing operation: %v", err)
}
}
// initAuthStorage initializes the storage backend for the auth service.
func (process *TeleportProcess) initAuthStorage() (bk backend.Backend, err error) {
bc := &process.Config.Auth.StorageConfig
@ -1184,7 +1332,7 @@ func (process *TeleportProcess) initAuthStorage() (bk backend.Backend, err error
// filesystem backend:
case dir.GetName():
bk, err = dir.New(bc.Params)
// DynamoDB bakcend:
// DynamoDB backend:
case dynamo.GetName():
bk, err = dynamo.New(bc.Params)
// etcd backend:
@ -1199,6 +1347,37 @@ func (process *TeleportProcess) initAuthStorage() (bk backend.Backend, err error
return bk, nil
}
// StartShutdown launches non-blocking graceful shutdown process that signals
// completion, returns context that will be closed once the shutdown is done
func (process *TeleportProcess) StartShutdown(ctx context.Context) context.Context {
process.BroadcastEvent(Event{Name: TeleportExitEvent, Payload: ctx})
localCtx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
process.Supervisor.Wait()
log.Debugf("All supervisor functions are completed.")
localAuth := process.getLocalAuth()
if localAuth != nil {
if err := process.localAuth.Close(); err != nil {
log.Warningf("Failed closing auth server: %v", trace.DebugReport(err))
}
}
}()
return localCtx
}
// Shutdown launches graceful shutdown process and waits
// for it to complete
func (process *TeleportProcess) Shutdown(ctx context.Context) {
localCtx := process.StartShutdown(ctx)
// wait until parent context closes
select {
case <-localCtx.Done():
log.Debugf("Process completed.")
}
}
// Close broadcasts close signals and exits immediately
func (process *TeleportProcess) Close() error {
process.BroadcastEvent(Event{Name: TeleportExitEvent})
localAuth := process.getLocalAuth()
@ -1248,7 +1427,7 @@ func validateConfig(cfg *Config) error {
// initSelfSignedHTTPSCert generates and self-signs a TLS key+cert pair for https connection
// to the proxy server.
func initSelfSignedHTTPSCert(cfg *Config) (err error) {
log.Warningf("[CONFIG] NO TLS Keys provided, using self signed certificate")
log.Warningf("No TLS Keys provided, using self signed certificate.")
keyPath := filepath.Join(cfg.DataDir, defaults.SelfSignedKeyPath)
certPath := filepath.Join(cfg.DataDir, defaults.SelfSignedCertPath)
@ -1264,7 +1443,7 @@ func initSelfSignedHTTPSCert(cfg *Config) (err error) {
if !os.IsNotExist(err) {
return trace.Wrap(err, "unrecognized error reading certs")
}
log.Warningf("[CONFIG] Generating self signed key and cert to %v %v", keyPath, certPath)
log.Warningf("Generating self signed key and cert to %v %v.", keyPath, certPath)
creds, err := utils.GenerateSelfSignedCert([]string{cfg.Hostname, "localhost"})
if err != nil {

348
lib/service/signals.go Normal file
View file

@ -0,0 +1,348 @@
/*
Copyright 2017 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package service
import (
"context"
"encoding/json"
"net"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"
)
// printShutdownStatus prints running services until shut down
func (process *TeleportProcess) printShutdownStatus(ctx context.Context) {
t := time.NewTicker(5 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
log.Infof("Waiting for services: %v to finish.", process.Supervisor.Services())
}
}
}
// WaitForSignals waits for system signals and processes them.
// Should not be called twice by the process.
func (process *TeleportProcess) WaitForSignals(ctx context.Context) error {
sigC := make(chan os.Signal, 1024)
signal.Notify(sigC, os.Interrupt, os.Kill, syscall.SIGTERM, syscall.SIGUSR2, syscall.SIGCHLD, syscall.SIGHUP)
doneContext, cancel := context.WithCancel(ctx)
defer cancel()
// Block until a signal is received or handler got an error.
// Notice how this handler is serialized - it will only receive
// signals in sequence and will not run in parallel.
for {
select {
case signal := <-sigC:
switch signal {
case syscall.SIGTERM, syscall.SIGINT:
go process.printShutdownStatus(doneContext)
process.Shutdown(ctx)
log.Infof("All services stopped, exiting.")
return nil
case syscall.SIGKILL, syscall.SIGQUIT:
log.Infof("Got signal %q, exiting immediately.", signal)
process.Close()
return nil
case syscall.SIGUSR2:
log.Infof("Got signal %q, forking a new process.", signal)
if err := process.forkChild(); err != nil {
log.Infof("Failed to fork: %s", trace.DebugReport(err))
} else {
log.Infof("Successfully started new process.")
}
case syscall.SIGHUP:
log.Infof("Got signal %q, performing graceful restart.", signal)
if err := process.forkChild(); err != nil {
log.Infof("Failed to fork: %s", trace.DebugReport(err))
} else {
log.Infof("Successfully started new process.")
}
log.Infof("Shutting down gracefully.")
go process.printShutdownStatus(doneContext)
process.Shutdown(ctx)
log.Infof("All services stopped, exiting.")
return nil
case syscall.SIGCHLD:
log.Debugf("Child exited, got %q, collecting status.", signal)
var wait syscall.WaitStatus
syscall.Wait4(-1, &wait, syscall.WNOHANG, nil)
default:
log.Infof("Ignoring %q.", signal)
}
case <-ctx.Done():
process.Close()
process.Wait()
log.Info("Got request to shutdown, context is closing")
return nil
}
}
}
// closeImportedDescriptors closes imported but unused file descriptors,
// what could happen if service has updated configuration
func (process *TeleportProcess) closeImportedDescriptors(prefix string) error {
process.Lock()
defer process.Unlock()
var errors []error
for i := range process.importedDescriptors {
d := process.importedDescriptors[i]
if strings.HasPrefix(d.Type, prefix) {
log.Infof("Closing imported but unused descriptor %v %v.", d.Type, d.Address)
errors = append(errors, d.File.Close())
}
}
return trace.NewAggregate(errors...)
}
// importOrCreateListener imports listener passed by the parent process (happens during live reload)
// or creates a new listener if there was no listener registered
func (process *TeleportProcess) importOrCreateListener(listenerType, address string) (net.Listener, error) {
l, err := process.importListener(listenerType, address)
if err == nil {
log.Infof("Using file descriptor %v %v passed by the parent process.", listenerType, address)
return l, nil
}
if !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
}
log.Infof("Service %v is creating new listener on %v.", listenerType, address)
return process.createListener(listenerType, address)
}
// importListener imports listener passed by the parent process, if no listener is found
// returns NotFound, otherwise removes the file from the list
func (process *TeleportProcess) importListener(listenerType, address string) (net.Listener, error) {
process.Lock()
defer process.Unlock()
for i := range process.importedDescriptors {
d := process.importedDescriptors[i]
if d.Type == listenerType && d.Address == address {
l, err := d.ToListener()
if err != nil {
return nil, trace.Wrap(err)
}
process.importedDescriptors = append(process.importedDescriptors[:i], process.importedDescriptors[i+1:]...)
process.registeredListeners = append(process.registeredListeners, RegisteredListener{Type: listenerType, Address: address, Listener: l})
return l, nil
}
}
return nil, trace.NotFound("no file descriptor for type %v and address %v has been imported", listenerType, address)
}
// createListener creates listener and adds to a list of tracked listeners
func (process *TeleportProcess) createListener(listenerType, address string) (net.Listener, error) {
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, trace.Wrap(err)
}
process.Lock()
defer process.Unlock()
r := RegisteredListener{Type: listenerType, Address: address, Listener: listener}
process.registeredListeners = append(process.registeredListeners, r)
return listener, nil
}
// exportFileDescriptors exports file descriptors to be passed to child process
func (process *TeleportProcess) exportFileDescriptors() ([]FileDescriptor, error) {
var out []FileDescriptor
process.Lock()
defer process.Unlock()
for _, r := range process.registeredListeners {
file, err := utils.GetListenerFile(r.Listener)
if err != nil {
return nil, trace.Wrap(err)
}
out = append(out, FileDescriptor{File: file, Type: r.Type, Address: r.Address})
}
return out, nil
}
// importFileDescriptors imports file descriptors from environment if there are any
func importFileDescriptors() ([]FileDescriptor, error) {
// These files may be passed in by the parent process
filesString := os.Getenv(teleportFilesEnvVar)
if filesString == "" {
return nil, nil
}
files, err := filesFromString(filesString)
if err != nil {
return nil, trace.BadParameter("child process has failed to read files, error %q", err)
}
if len(files) != 0 {
log.Infof("Child has been passed files: %v", files)
}
return files, nil
}
// RegisteredListener is a listener registered
// within teleport process, can be passed to child process
type RegisteredListener struct {
// Type is a listener type, e.g. auth:ssh
Type string
// Address is an address listener is serving on, e.g. 127.0.0.1:3025
Address string
// Listener is a file listener object
Listener net.Listener
}
// FileDescriptor is a file descriptor associated
// with a listener
type FileDescriptor struct {
// Type is a listener type, e.g. auth:ssh
Type string
// Address is an addresss of the listener, e.g. 127.0.0.1:3025
Address string
// File is a file descriptor associated with the listener
File *os.File
}
func (fd *FileDescriptor) ToListener() (net.Listener, error) {
listener, err := net.FileListener(fd.File)
if err != nil {
return nil, err
}
fd.File.Close()
return listener, nil
}
type fileDescriptor struct {
Address string `json:"addr"`
Type string `json:"type"`
FileFD int `json:"fd"`
FileName string `json:"fileName"`
}
// filesToString serializes file descriptors as well as accompanying information (like socket host and port)
func filesToString(files []FileDescriptor) (string, error) {
out := make([]fileDescriptor, len(files))
for i, f := range files {
out[i] = fileDescriptor{
// Once files will be passed to the child process and their FDs will change.
// The first three passed files are stdin, stdout and stderr, every next file will have the index + 3
// That's why we rearrange the FDs for child processes to get the correct file descriptors.
FileFD: i + 3,
FileName: f.File.Name(),
Address: f.Address,
Type: f.Type,
}
}
bytes, err := json.Marshal(out)
if err != nil {
return "", err
}
return string(bytes), nil
}
const teleportFilesEnvVar = "TELEPORT_OS_FILES"
func execPath() (string, error) {
name, err := exec.LookPath(os.Args[0])
if err != nil {
return "", err
}
if _, err = os.Stat(name); nil != err {
return "", err
}
return name, err
}
// filesFromString de-serializes the file descriptors and turns them in the os.Files
func filesFromString(in string) ([]FileDescriptor, error) {
var out []fileDescriptor
if err := json.Unmarshal([]byte(in), &out); err != nil {
return nil, err
}
files := make([]FileDescriptor, len(out))
for i, o := range out {
files[i] = FileDescriptor{
File: os.NewFile(uintptr(o.FileFD), o.FileName),
Address: o.Address,
Type: o.Type,
}
}
return files, nil
}
func (process *TeleportProcess) forkChild() error {
path, err := execPath()
if err != nil {
return trace.Wrap(err)
}
workingDir, err := os.Getwd()
if nil != err {
return err
}
log := log.WithFields(logrus.Fields{"path": path, "workingDir": workingDir})
log.Info("Forking child.")
listenerFiles, err := process.exportFileDescriptors()
if err != nil {
return trace.Wrap(err)
}
// These files will be passed to the child process
files := []*os.File{os.Stdin, os.Stdout, os.Stderr}
for _, f := range listenerFiles {
files = append(files, f.File)
}
// Serialize files to JSON string representation
vals, err := filesToString(listenerFiles)
if err != nil {
return err
}
log.Infof("Passing %s to child", vals)
os.Setenv(teleportFilesEnvVar, vals)
p, err := os.StartProcess(path, os.Args, &os.ProcAttr{
Dir: workingDir,
Env: os.Environ(),
Files: files,
Sys: &syscall.SysProcAttr{},
})
if err != nil {
return trace.ConvertSystemError(err)
}
log.WithFields(logrus.Fields{"pid": p.Pid}).Infof("Started new child process.")
return nil
}

View file

@ -17,7 +17,7 @@ limitations under the License.
package service
import (
"fmt"
"context"
"sync"
"github.com/gravitational/teleport/lib/utils"
@ -52,6 +52,9 @@ type Supervisor interface {
// it's a combinatioin Start() and Wait()
Run() error
// Services returns list of running services
Services() []string
// BroadcastEvent generates event and broadcasts it to all
// interested parties
BroadcastEvent(Event)
@ -71,18 +74,21 @@ type LocalSupervisor struct {
events map[string]Event
eventsC chan Event
eventWaiters map[string][]*waiter
closer *utils.CloseBroadcaster
closeContext context.Context
signalClose context.CancelFunc
}
// NewSupervisor returns new instance of initialized supervisor
func NewSupervisor() Supervisor {
closeContext, cancel := context.WithCancel(context.TODO())
srv := &LocalSupervisor{
services: []Service{},
wg: &sync.WaitGroup{},
events: map[string]Event{},
eventsC: make(chan Event, 100),
eventsC: make(chan Event, 1024),
eventWaiters: make(map[string][]*waiter),
closer: utils.NewCloseBroadcaster(),
closeContext: closeContext,
signalClose: cancel,
}
go srv.fanOut()
return srv
@ -96,11 +102,11 @@ type Event struct {
}
func (e *Event) String() string {
return fmt.Sprintf("event(%v)", e.Name)
return e.Name
}
func (s *LocalSupervisor) Register(srv Service) {
log.WithFields(logrus.Fields{"service": srv.Name()}).Debugf("Adding service to supervisor")
log.WithFields(logrus.Fields{"service": srv.Name()}).Debugf("Adding service to supervisor.")
s.Lock()
defer s.Unlock()
s.services = append(s.services, srv)
@ -125,7 +131,7 @@ func (s *LocalSupervisor) RegisterFunc(name string, fn ServiceFunc) {
// RemoveService removes service from supervisor tracking list
func (s *LocalSupervisor) RemoveService(srv Service) error {
log = log.WithFields(logrus.Fields{"service": srv.Name()})
log := log.WithFields(logrus.Fields{"service": srv.Name()})
s.Lock()
defer s.Unlock()
for i, el := range s.services {
@@ -169,8 +175,20 @@ func (s *LocalSupervisor) Start() error {
return nil
}
func (s *LocalSupervisor) Services() []string {
s.Lock()
defer s.Unlock()
out := make([]string, len(s.services))
for i, srv := range s.services {
out[i] = srv.Name()
}
return out
}
func (s *LocalSupervisor) Wait() error {
defer s.closer.Close()
defer s.signalClose()
s.wg.Wait()
return nil
}
@@ -233,7 +251,7 @@ func (s *LocalSupervisor) fanOut() {
for _, waiter := range waiters {
go s.notifyWaiter(waiter, event)
}
case <-s.closer.C:
case <-s.closeContext.Done():
return
}
}

View file

@@ -195,7 +195,7 @@ func (s *PresenceService) GetNodes(namespace string) ([]services.Server, error)
}
servers = append(servers, server)
}
s.Infof("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
s.Debugf("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
// sorting helps with tests and makes it all deterministic
sort.Sort(services.SortedServers(servers))
return servers, nil
@@ -218,7 +218,7 @@ func (s *PresenceService) batchGetNodes(namespace string) ([]services.Server, er
servers[i] = server
}
s.Infof("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
s.Debugf("GetServers(%v) in %v", len(servers), time.Now().Sub(start))
// sorting helps with tests and makes it all deterministic
sort.Sort(services.SortedServers(servers))
return servers, nil

View file

@@ -19,6 +19,7 @@ limitations under the License.
package regular
import (
"context"
"fmt"
"io"
"io/ioutil"
@@ -169,6 +170,15 @@ func (s *Server) Close() error {
return s.srv.Close()
}
// Shutdown performs graceful shutdown
func (s *Server) Shutdown(ctx context.Context) error {
// wait until connections drain off
err := s.srv.Shutdown(ctx)
s.closer.Close()
s.reg.Close()
return err
}
// Start starts server
func (s *Server) Start() error {
if len(s.getCommandLabels()) > 0 {
@@ -178,9 +188,18 @@ func (s *Server) Start() error {
return s.srv.Start()
}
// Serve starts serving the service on the provided listener
func (s *Server) Serve(l net.Listener) error {
if len(s.getCommandLabels()) > 0 {
s.updateLabels()
}
go s.heartbeatPresence()
return s.srv.Serve(l)
}
// Wait waits until server stops
func (s *Server) Wait() {
s.srv.Wait()
s.srv.Wait(context.TODO())
}
// SetShell sets default shell that will be executed for interactive

View file

@@ -29,6 +29,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/gravitational/teleport"
@@ -67,6 +68,11 @@ type Server struct {
closeContext context.Context
closeFunc context.CancelFunc
// conns tracks the number of currently active connections
conns int32
// shutdownPollPeriod sets polling period for shutdown
shutdownPollPeriod time.Duration
}
const (
@@ -100,6 +106,14 @@ func SetLimiter(limiter *limiter.Limiter) ServerOption {
}
}
// SetShutdownPollPeriod sets a polling period for graceful shutdowns of SSH servers
func SetShutdownPollPeriod(period time.Duration) ServerOption {
return func(s *Server) error {
s.shutdownPollPeriod = period
return nil
}
}
func NewServer(
component string,
a utils.NetAddr,
@@ -134,6 +148,10 @@ func NewServer(
return nil, err
}
}
if s.shutdownPollPeriod == 0 {
s.shutdownPollPeriod = defaults.ShutdownPollPeriod
}
for _, signer := range hostSigners {
(&s.cfg).AddHostKey(signer)
}
@@ -234,8 +252,43 @@ func (s *Server) setListener(l net.Listener) error {
// Wait waits until server stops serving new connections
// on the listener socket
func (s *Server) Wait() {
<-s.closeContext.Done()
func (s *Server) Wait(ctx context.Context) {
select {
case <-s.closeContext.Done():
case <-ctx.Done():
}
}
// Shutdown initiates a graceful shutdown, waiting until all active
// connections are closed
func (s *Server) Shutdown(ctx context.Context) error {
// close listener to stop receiving new connections
err := s.Close()
s.Wait(ctx)
activeConnections := s.trackConnections(0)
if activeConnections == 0 {
return err
}
s.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
lastReport := time.Time{}
ticker := time.NewTicker(s.shutdownPollPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
activeConnections = s.trackConnections(0)
if activeConnections == 0 {
return err
}
if time.Now().Sub(lastReport) > 10*s.shutdownPollPeriod {
s.Infof("Shutdown: waiting for %v connections to finish.", activeConnections)
lastReport = time.Now()
}
case <-ctx.Done():
s.Infof("Context cancelled wait, returning")
return trace.ConnectionProblem(err, "context cancelled")
}
}
}
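As a usage note, a caller of Shutdown would normally bound the wait with a deadline so that a stuck connection cannot block the process forever. A minimal sketch under that assumption; the 30-second deadline and the wrapper function are illustrative, not part of this change.

func shutdownWithDeadline(srv *Server) error {
	// Give active connections up to 30 seconds to drain.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		// The deadline expired before all connections drained. The listener
		// is already closed at this point, so remaining connections are
		// simply abandoned.
		return trace.Wrap(err)
	}
	return nil
}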
// Close closes listening socket and stops accepting connections
@@ -282,6 +335,10 @@ func (s *Server) acceptConnections() {
}
}
func (s *Server) trackConnections(delta int32) int32 {
return atomic.AddInt32(&s.conns, delta)
}
// handleConnection is called every time an SSH server accepts a new
// connection from a client.
//
@@ -289,6 +346,8 @@ func (s *Server) acceptConnections() {
// and proxies, proxies and servers, servers and auth, etc).
//
func (s *Server) handleConnection(conn net.Conn) {
s.trackConnections(1)
defer s.trackConnections(-1)
// initiate an SSH connection, note that we don't need to close the conn here
// in case of error as ssh server takes care of this
remoteAddr, _, err := net.SplitHostPort(conn.RemoteAddr().String())

View file

@@ -16,6 +16,7 @@ limitations under the License.
package sshutils
import (
"context"
"fmt"
"net"
"testing"
@@ -23,6 +24,7 @@ import (
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
. "gopkg.in/check.v1"
@@ -73,6 +75,54 @@ func (s *ServerSuite) TestStartStop(c *C) {
c.Assert(called, Equals, true)
}
// TestShutdown tests the graceful shutdown feature
func (s *ServerSuite) TestShutdown(c *C) {
closeContext, cancel := context.WithCancel(context.TODO())
fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
ch, _, err := nch.Accept()
c.Assert(err, IsNil)
defer ch.Close()
<-closeContext.Done()
conn.Close()
})
srv, err := NewServer(
"test",
utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"},
fn,
s.signers,
AuthMethods{Password: pass("abc123")},
SetShutdownPollPeriod(10*time.Millisecond),
)
c.Assert(err, IsNil)
c.Assert(srv.Start(), IsNil)
clt, err := ssh.Dial("tcp", srv.Addr(), &ssh.ClientConfig{Auth: []ssh.AuthMethod{ssh.Password("abc123")}})
c.Assert(err, IsNil)
defer clt.Close()
// call new session to initiate opening new channel
clt.NewSession()
// the context will time out because a connection is still active
ctx, ctxc := context.WithTimeout(context.TODO(), 50*time.Millisecond)
defer ctxc()
c.Assert(trace.IsConnectionProblem(srv.Shutdown(ctx)), Equals, true)
// now shutdown will return
cancel()
ctx2, ctxc2 := context.WithTimeout(context.TODO(), time.Second)
defer ctxc2()
c.Assert(srv.Shutdown(ctx2), IsNil)
// shutdown is re-entrant, calling it again is safe
ctx3, ctxc3 := context.WithTimeout(context.TODO(), time.Second)
defer ctxc3()
c.Assert(srv.Shutdown(ctx3), IsNil)
}
func (s *ServerSuite) TestConfigureCiphers(c *C) {
called := false
fn := NewChanHandlerFunc(func(_ net.Conn, conn *ssh.ServerConn, nch ssh.NewChannel) {
@@ -117,7 +167,7 @@ func (s *ServerSuite) TestConfigureCiphers(c *C) {
func wait(c *C, srv *Server) {
s := make(chan struct{})
go func() {
srv.Wait()
srv.Wait(context.TODO())
s <- struct{}{}
}()
select {

View file

@@ -1,4 +1,4 @@
package utils
package system
/*
#include <signal.h>

View file

@@ -129,6 +129,8 @@ func ParseAddr(a string) (*NetAddr, error) {
return &NetAddr{Addr: u.Host, AddrNetwork: u.Scheme, Path: u.Path}, nil
case "unix":
return &NetAddr{Addr: u.Path, AddrNetwork: u.Scheme}, nil
case "http", "https":
return &NetAddr{Addr: u.Host, AddrNetwork: u.Scheme, Path: u.Path}, nil
default:
return nil, trace.BadParameter("'%v': unsupported scheme: '%v'", a, u.Scheme)
}

View file

@@ -66,6 +66,16 @@ func (s *AddrTestSuite) TestParse(c *C) {
c.Assert(addr.IsEmpty(), Equals, false)
}
func (s *AddrTestSuite) TestParseHTTP(c *C) {
addr, err := ParseAddr("http://one:25/path")
c.Assert(err, IsNil)
c.Assert(addr, NotNil)
c.Assert(addr.Addr, Equals, "one:25")
c.Assert(addr.Path, Equals, "/path")
c.Assert(addr.FullAddress(), Equals, "http://one:25")
c.Assert(addr.IsEmpty(), Equals, false)
}
func (s *AddrTestSuite) TestParseDefaults(c *C) {
addr, err := ParseAddr("host:25")
c.Assert(err, IsNil)

lib/utils/listener.go (new file, 19 lines)
View file

@@ -0,0 +1,19 @@
package utils
import (
"net"
"os"
"github.com/gravitational/trace"
)
// GetListenerFile returns the underlying os.File associated with the listener
func GetListenerFile(listener net.Listener) (*os.File, error) {
switch t := listener.(type) {
case *net.TCPListener:
return t.File()
case *net.UnixListener:
return t.File()
}
return nil, trace.BadParameter("unsupported listener: %T", listener)
}
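A round-trip sketch of how this helper is meant to be used: the file extracted here can be handed to a child process (for example via os.ProcAttr.Files, as forkChild does), and net.FileListener rebuilds a listener from it on the other side. The function below keeps both halves in a single process purely for illustration; it is not part of this change.

// listenerRoundTrip demonstrates listener -> *os.File -> listener.
// Purely illustrative; in the real flow the file crosses a fork boundary.
func listenerRoundTrip() error {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return trace.Wrap(err)
	}
	defer listener.Close()
	file, err := GetListenerFile(listener)
	if err != nil {
		return trace.Wrap(err)
	}
	defer file.Close()
	// Rebuild a listener from the duplicated descriptor.
	rebuilt, err := net.FileListener(file)
	if err != nil {
		return trace.Wrap(err)
	}
	return rebuilt.Close()
}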

View file

@@ -79,7 +79,6 @@ func CreateTLSConfiguration(certFile, keyFile string) (*tls.Config, error) {
return nil, trace.BadParameter("certificate is not accessible by '%v'", certFile)
}
log.Infof("[PROXY] TLS cert=%v key=%v", certFile, keyFile)
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, trace.Wrap(err)

View file

@@ -681,7 +681,6 @@ func (s *sessionCache) ValidateSession(user, sid string) (*SessionContext, error
// this means that someone has just inserted the context, so
// close our extra context and return
if trace.IsAlreadyExists(err) {
log.Infof("just created, returning the existing one")
defer c.Close()
return out, nil
}

View file

@@ -95,7 +95,9 @@ func (w *sessionStreamHandler) stream(ws *websocket.Conn) error {
// ask for any events that happened since the last call:
re, err := clt.GetSessionEvents(w.namespace, w.sessionID, eventsCursor+1)
if err != nil {
log.Error(err)
if !trace.IsNotFound(err) {
log.Error(err)
}
return emptyEventList
}
batchLen := len(re)

View file

@@ -17,8 +17,8 @@ limitations under the License.
package common
import (
"context"
"fmt"
"net/http"
"os"
"os/user"
"path/filepath"
@@ -33,9 +33,9 @@ import (
"github.com/gravitational/teleport/lib/utils"
"github.com/google/gops/agent"
"github.com/gravitational/roundtrip"
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
)
@@ -109,14 +109,10 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
"Base64 encoded configuration string").Hidden().Envar(defaults.ConfigEnvar).
StringVar(&ccf.ConfigString)
start.Flag("labels", "List of labels for this node").StringVar(&ccf.Labels)
start.Flag("httpprofile",
"[DEPRECATED] Start profiling endpoint on localhost:6060").Hidden().BoolVar(&ccf.HTTPProfileEndpoint)
start.Flag("gops",
"Start gops endpoint on a given address").Hidden().BoolVar(&ccf.Gops)
start.Flag("gops-addr",
"Specify gops addr to listen on").Hidden().StringVar(&ccf.GopsAddr)
"Start gops troubleshooting endpoint on a first available adress.").BoolVar(&ccf.Gops)
start.Flag("diag-addr",
"Start diangonstic endpoint on this address").Hidden().StringVar(&ccf.DiagnosticAddr)
"Start diangonstic prometheus and healthz endpoint.").StringVar(&ccf.DiagnosticAddr)
start.Flag("permit-user-env",
"Enables reading of ~/.tsh/environment when creating a session").Hidden().BoolVar(&ccf.PermitUserEnvironment)
start.Flag("insecure",
@@ -155,30 +151,13 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
if !options.InitOnly {
log.Debug(conf.DebugDumpToYAML())
}
if ccf.HTTPProfileEndpoint {
log.Warningf("http profile endpoint is deprecated, use gops instead")
}
if ccf.Gops {
log.Debugf("starting gops agent")
err := agent.Listen(&agent.Options{Addr: ccf.GopsAddr})
log.Debug("Starting gops agent.")
err := agent.Listen(&agent.Options{})
if err != nil {
log.Warningf("failed to start gops agent %v", err)
}
}
// collect and expose diagnostic endpoint
if ccf.DiagnosticAddr != "" {
mux := http.NewServeMux()
mux.Handle("/metrics", prometheus.Handler())
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
roundtrip.ReplyJSON(w, http.StatusOK, map[string]interface{}{"status": "ok"})
})
go func() {
err := http.ListenAndServe(ccf.DiagnosticAddr, mux)
if err != nil {
log.Warningf("diagnostic endpoint exited %v", err)
}
}()
}
if !options.InitOnly {
err = OnStart(conf)
}
@@ -194,7 +173,7 @@ func Run(options Options) (executedCommand string, conf *service.Config) {
if err != nil {
utils.FatalError(err)
}
log.Info("teleport: clean exit")
log.Debug("Clean exit.")
return command, conf
}
@@ -219,7 +198,7 @@ func OnStart(config *service.Config) error {
defer f.Close()
}
return trace.Wrap(srv.Wait())
return srv.WaitForSignals(context.TODO())
}
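The start command now blocks in WaitForSignals instead of the plain Wait. The real signal handling lives in the service package and is not shown in this hunk; the loop below is only a generic sketch of a context-aware signal wait, and the chosen signal set, the shutdown callback and the helper name are assumptions rather than the service API. It presumes os, os/signal and syscall are imported next to the existing context and log imports.

// waitForSignalsSketch is an illustrative stand-in, not the service API:
// it blocks until a termination signal arrives or the context is cancelled.
func waitForSignalsSketch(ctx context.Context, shutdown func() error) error {
	sigC := make(chan os.Signal, 1)
	signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigC)
	select {
	case sig := <-sigC:
		log.Debugf("Received signal %v, shutting down.", sig)
		return shutdown()
	case <-ctx.Done():
		return nil
	}
}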
// onStatus is the handler for "status" CLI command