mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 17:23:22 +00:00
fix data races and remove sleep from tests
* fix data race with advertise ip * remove global variable * simplify pings logic and fix ping bug * fix potential bug in dynamic labels
This commit is contained in:
parent
f71fbb90bd
commit
519f07611b
4
Makefile
4
Makefile
|
@ -29,7 +29,7 @@ tsh:
|
|||
install: remove-temp-files
|
||||
go install github.com/gravitational/teleport/tool/teleport
|
||||
go install github.com/gravitational/teleport/tool/tctl
|
||||
go install github.com/gravitational/teleport/tool/t
|
||||
go install github.com/gravitational/teleport/tool/tsh
|
||||
|
||||
clean:
|
||||
rm -rf $(OUT)
|
||||
|
@ -45,7 +45,7 @@ production: clean
|
|||
# tests everything: called by Jenkins
|
||||
#
|
||||
test:
|
||||
go test -v github.com/gravitational/teleport/tool/t/...
|
||||
go test -v github.com/gravitational/teleport/tool/tsh/...
|
||||
go test -v github.com/gravitational/teleport/lib/... -cover
|
||||
go test -v github.com/gravitational/teleport/tool/teleport... -cover
|
||||
|
||||
|
|
|
@ -13,6 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package defaults contains default constants set in various parts of
|
||||
// teleport codebase
|
||||
package defaults
|
||||
|
||||
import (
|
||||
|
@ -64,9 +67,11 @@ const (
|
|||
// By default all users use /bin/bash
|
||||
DefaultShell = "/bin/bash"
|
||||
|
||||
// Median sleep time between node pings. Note that a random deviation is
|
||||
// added to this time
|
||||
SleepBetweenNodePings = time.Second * 5
|
||||
// ServerHeartbeatTTL is a period between heartbeats
|
||||
// Median sleep time between node pings is this value / 2 + random
|
||||
// deviation added to this time to avoid lots of simultaneous
|
||||
// heartbeats coming to auth server
|
||||
ServerHeartbeatTTL = 6 * time.Second
|
||||
)
|
||||
|
||||
// Default connection limits, they can be applied separately on any of the Teleport
|
||||
|
@ -84,27 +89,31 @@ const (
|
|||
MinCertDuration = time.Minute
|
||||
// MaxCertDuration limits maximum duration of validity of issued cert
|
||||
MaxCertDuration = 30 * time.Hour
|
||||
// Default certificate duration in hours
|
||||
CertDurationHours = 12
|
||||
CertDuration = CertDurationHours * time.Hour
|
||||
// CertDuration is a default certificate duration
|
||||
// 12 is default as it' longer than average working day (I hope so)
|
||||
CertDuration = 12 * time.Hour
|
||||
)
|
||||
|
||||
// list of roles teleport service can run as:
|
||||
const (
|
||||
RoleNode = "node"
|
||||
RoleProxy = "proxy"
|
||||
// RoleNode is SSH stateless node
|
||||
RoleNode = "node"
|
||||
// RoleProxy is a stateless SSH access proxy (bastion)
|
||||
RoleProxy = "proxy"
|
||||
// RoleAuthService is authentication and authorization service,
|
||||
// the only stateful role in the system
|
||||
RoleAuthService = "auth"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default path to teleport config file
|
||||
// ConfigFilePath is default path to teleport config file
|
||||
ConfigFilePath = "/etc/teleport.yaml"
|
||||
|
||||
// This is where all mutable data is stored (user keys, recorded sessions,
|
||||
// DataDir is where all mutable data is stored (user keys, recorded sessions,
|
||||
// registered SSH servers, etc):
|
||||
DataDir = "/var/lib/teleport"
|
||||
|
||||
// Default roles teleport assumes when started via 'start' command
|
||||
// StartRoles is default roles teleport assumes when started via 'start' command
|
||||
StartRoles = []string{RoleProxy, RoleNode, RoleAuthService}
|
||||
)
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package services
|
||||
|
||||
import (
|
||||
|
|
113
lib/srv/srv.go
113
lib/srv/srv.go
|
@ -24,6 +24,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -48,9 +49,6 @@ import (
|
|||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
// TestSleepDuration is used to shortcut the 'ping' looop during tests
|
||||
var TestSleepDuration = time.Duration(time.Hour)
|
||||
|
||||
// Server implements SSH server that uses configuration backend and
|
||||
// certificate-based authentication
|
||||
type Server struct {
|
||||
|
@ -231,74 +229,99 @@ func (s *Server) ID() string {
|
|||
return s.uuid
|
||||
}
|
||||
|
||||
// AdvertiseAddr() returns an address this server should be publicly accessible
|
||||
func (s *Server) setAdvertiseIP(ip net.IP) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.advertiseIP = ip
|
||||
}
|
||||
|
||||
func (s *Server) getAdvertiseIP() net.IP {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
return s.advertiseIP
|
||||
}
|
||||
|
||||
// AdvertiseAddr returns an address this server should be publicly accessible
|
||||
// as, in "ip:host" form
|
||||
func (s *Server) AdvertiseAddr() string {
|
||||
// se if we have explicit --advertise-ip option
|
||||
if s.advertiseIP == nil {
|
||||
if s.getAdvertiseIP() == nil {
|
||||
return s.addr.Addr
|
||||
}
|
||||
_, port, _ := net.SplitHostPort(s.addr.Addr)
|
||||
return net.JoinHostPort(s.advertiseIP.String(), port)
|
||||
return net.JoinHostPort(s.getAdvertiseIP().String(), port)
|
||||
}
|
||||
|
||||
// registerServer attempts to register server in the cluster
|
||||
func (s *Server) registerServer() error {
|
||||
srv := services.Server{
|
||||
ID: s.ID(),
|
||||
Addr: s.AdvertiseAddr(),
|
||||
Hostname: s.hostname,
|
||||
Labels: s.labels,
|
||||
CmdLabels: s.getCommandLabels(),
|
||||
}
|
||||
return trace.Wrap(s.authService.UpsertServer(srv, defaults.ServerHeartbeatTTL))
|
||||
}
|
||||
|
||||
// heartbeatPresence periodically calls into the auth server to let everyone
|
||||
// know we're up & alive
|
||||
func (s *Server) heartbeatPresence() {
|
||||
var sleepDuration time.Duration
|
||||
advertiseAddr := s.AdvertiseAddr()
|
||||
for {
|
||||
sleepDuration = utils.RandomizedDuration(defaults.SleepBetweenNodePings, 0.3)
|
||||
// shorten sleep time during tests
|
||||
if sleepDuration > TestSleepDuration {
|
||||
sleepDuration = TestSleepDuration
|
||||
if err := s.registerServer(); err != nil {
|
||||
log.Warningf("failed to announce %#v presence: %v", s, err)
|
||||
}
|
||||
func() {
|
||||
s.labelsMutex.Lock()
|
||||
defer s.labelsMutex.Unlock()
|
||||
srv := services.Server{
|
||||
ID: s.ID(),
|
||||
Addr: advertiseAddr,
|
||||
Hostname: s.hostname,
|
||||
Labels: s.labels,
|
||||
CmdLabels: s.cmdLabels,
|
||||
}
|
||||
if err := s.authService.UpsertServer(srv, sleepDuration*2); err != nil {
|
||||
log.Warningf("failed to announce %#v presence: %v", srv, err)
|
||||
// sleep longer on failures
|
||||
sleepDuration = sleepDuration * 2
|
||||
}
|
||||
}()
|
||||
log.Infof("[SSH] will ping auth service in %v", sleepDuration)
|
||||
time.Sleep(sleepDuration)
|
||||
log.Infof("[SSH] will ping auth service in %v", defaults.ServerHeartbeatTTL)
|
||||
time.Sleep(defaults.ServerHeartbeatTTL + utils.RandomDuration(defaults.ServerHeartbeatTTL/10))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateLabels() {
|
||||
for name, label := range s.cmdLabels {
|
||||
go s.updateLabel(name, label)
|
||||
go s.periodicUpdateLabel(name, label)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) syncUpdateLabels() {
|
||||
for name, label := range s.cmdLabels {
|
||||
s.updateLabel(name, label)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) updateLabel(name string, label services.CommandLabel) {
|
||||
out, err := exec.Command(label.Command[0], label.Command[1:]...).Output()
|
||||
if err != nil {
|
||||
log.Errorf(err.Error())
|
||||
label.Result = err.Error() + " output: " + string(out)
|
||||
} else {
|
||||
label.Result = strings.TrimSpace(string(out))
|
||||
}
|
||||
s.setCommandLabel(name, label)
|
||||
}
|
||||
|
||||
func (s *Server) periodicUpdateLabel(name string, label services.CommandLabel) {
|
||||
for {
|
||||
out, err := exec.Command(label.Command[0], label.Command[1:]...).Output()
|
||||
if err != nil {
|
||||
log.Errorf(err.Error())
|
||||
label.Result = err.Error() + " Output: " + string(out)
|
||||
} else {
|
||||
if out[len(out)-1] == 10 {
|
||||
out = out[:len(out)-1] // remove new line
|
||||
}
|
||||
label.Result = string(out)
|
||||
}
|
||||
s.labelsMutex.Lock()
|
||||
s.cmdLabels[name] = label
|
||||
s.labelsMutex.Unlock()
|
||||
s.updateLabel(name, label)
|
||||
time.Sleep(label.Period)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setCommandLabel(name string, value services.CommandLabel) {
|
||||
s.labelsMutex.Lock()
|
||||
defer s.labelsMutex.Unlock()
|
||||
s.cmdLabels[name] = value
|
||||
}
|
||||
|
||||
func (s *Server) getCommandLabels() map[string]services.CommandLabel {
|
||||
s.labelsMutex.Lock()
|
||||
defer s.labelsMutex.Unlock()
|
||||
out := make(map[string]services.CommandLabel, len(s.cmdLabels))
|
||||
for key, val := range s.cmdLabels {
|
||||
out[key] = val
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Server) checkPermissionToLogin(cert ssh.PublicKey, teleportUser, osUser string) error {
|
||||
// find cert authority by it's key
|
||||
cas, err := s.authService.GetCertAuthorities(services.UserCA)
|
||||
|
@ -566,7 +589,7 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, ch ssh.Channel, in
|
|||
return
|
||||
}
|
||||
if err := s.dispatch(sconn, ch, req, ctx); err != nil {
|
||||
ctx.Infof("error dispatching request: %v, closing channel", err)
|
||||
ctx.Infof("error dispatching request: %#v", err)
|
||||
replyError(ch, req, err)
|
||||
closeCh()
|
||||
return
|
||||
|
|
|
@ -73,7 +73,6 @@ type SrvSuite struct {
|
|||
var _ = Suite(&SrvSuite{})
|
||||
|
||||
func (s *SrvSuite) SetUpSuite(c *C) {
|
||||
TestSleepDuration = time.Millisecond * 50
|
||||
utils.InitLoggerCLI()
|
||||
}
|
||||
|
||||
|
@ -145,6 +144,7 @@ func (s *SrvSuite) SetUpTest(c *C) {
|
|||
s.srv = srv
|
||||
|
||||
c.Assert(s.srv.Start(), IsNil)
|
||||
c.Assert(s.srv.registerServer(), IsNil)
|
||||
|
||||
// set up SSH client using the user private key for signing
|
||||
up, err := newUpack(s.user, s.a)
|
||||
|
@ -190,9 +190,9 @@ func (s *SrvSuite) TestExec(c *C) {
|
|||
|
||||
func (s *SrvSuite) TestAdvertiseAddr(c *C) {
|
||||
c.Assert(strings.Index(s.srv.AdvertiseAddr(), "127.0.0.1:"), Equals, 0)
|
||||
s.srv.advertiseIP = net.ParseIP("10.10.10.1")
|
||||
s.srv.setAdvertiseIP(net.ParseIP("10.10.10.1"))
|
||||
c.Assert(strings.Index(s.srv.AdvertiseAddr(), "10.10.10.1:"), Equals, 0)
|
||||
s.srv.advertiseIP = nil
|
||||
s.srv.setAdvertiseIP(nil)
|
||||
}
|
||||
|
||||
// TestShell launches interactive shell session and executes a command
|
||||
|
@ -404,9 +404,10 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
srv2.uuid = bobAddr
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(srv2.Start(), IsNil)
|
||||
c.Assert(srv2.registerServer(), IsNil)
|
||||
defer srv2.Close()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
srv2.registerServer()
|
||||
|
||||
// test proxysites
|
||||
client, err := ssh.Dial("tcp", proxy.Addr(), sshConfig)
|
||||
|
@ -424,6 +425,12 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
close(done)
|
||||
}()
|
||||
|
||||
// to make sure labels have the right output
|
||||
s.srv.syncUpdateLabels()
|
||||
srv2.syncUpdateLabels()
|
||||
s.srv.registerServer()
|
||||
srv2.registerServer()
|
||||
|
||||
c.Assert(se3.RequestSubsystem("proxysites"), IsNil)
|
||||
<-done
|
||||
var sites map[string][]services.Server
|
||||
|
|
|
@ -19,19 +19,11 @@ func CryptoRandomHex(len int) (string, error) {
|
|||
return hex.EncodeToString(randomBytes), nil
|
||||
}
|
||||
|
||||
// RandomizedDuration returns duration which is within given deviation from a given
|
||||
// median.
|
||||
//
|
||||
// For example RandomizedDuration(time.Second * 10, 0.5) will return
|
||||
// a random duration between 6 and 15 seconds.
|
||||
func RandomizedDuration(median time.Duration, deviation float64) time.Duration {
|
||||
min := int64(float64(median) * (1 - deviation))
|
||||
max := int64(float64(median) * (1 + deviation))
|
||||
|
||||
ceiling := big.NewInt(max - min)
|
||||
randomDeviation, err := rand.Int(rand.Reader, ceiling)
|
||||
// RandomDuration returns a duration in a range [0, max)
|
||||
func RandomDuration(max time.Duration) time.Duration {
|
||||
randomVal, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
return median
|
||||
return max / 2
|
||||
}
|
||||
return time.Duration(min + randomDeviation.Int64())
|
||||
return time.Duration(randomVal.Int64())
|
||||
}
|
||||
|
|
|
@ -52,11 +52,11 @@ func (s *UtilsSuite) TestSelfSignedCert(c *check.C) {
|
|||
}
|
||||
|
||||
func (s *UtilsSuite) TestRandomDuration(c *check.C) {
|
||||
expectedMin := time.Second * 9
|
||||
expectedMax := time.Second * 11
|
||||
expectedMin := time.Duration(0)
|
||||
expectedMax := time.Second * 10
|
||||
for i := 0; i < 50; i++ {
|
||||
dur := RandomizedDuration(time.Second*10, 0.1)
|
||||
dur := RandomDuration(expectedMax)
|
||||
c.Assert(dur >= expectedMin, check.Equals, true)
|
||||
c.Assert(dur <= expectedMax, check.Equals, true)
|
||||
c.Assert(dur < expectedMax, check.Equals, true)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ func makeClient(cf *CLIConf) (tc *client.TeleportClient, err error) {
|
|||
cf.NodePort = defaults.SSHServerListenPort
|
||||
}
|
||||
if cf.MinsToLive == 0 {
|
||||
cf.MinsToLive = defaults.CertDurationHours * 60
|
||||
cf.MinsToLive = int32(defaults.CertDuration / time.Minute)
|
||||
}
|
||||
// split login & host
|
||||
parts := strings.Split(cf.UserHost, "@")
|
||||
|
|
Loading…
Reference in a new issue