fix data races and remove sleep from tests

* fix data race with advertise ip
* remove global variable
* simplify pings logic and fix ping bug
* fix potential bug in dynamic labels
This commit is contained in:
klizhentas 2016-03-08 18:41:05 -08:00
parent f71fbb90bd
commit 519f07611b
8 changed files with 112 additions and 80 deletions

View file

@ -29,7 +29,7 @@ tsh:
install: remove-temp-files
go install github.com/gravitational/teleport/tool/teleport
go install github.com/gravitational/teleport/tool/tctl
go install github.com/gravitational/teleport/tool/t
go install github.com/gravitational/teleport/tool/tsh
clean:
rm -rf $(OUT)
@ -45,7 +45,7 @@ production: clean
# tests everything: called by Jenkins
#
test:
go test -v github.com/gravitational/teleport/tool/t/...
go test -v github.com/gravitational/teleport/tool/tsh/...
go test -v github.com/gravitational/teleport/lib/... -cover
go test -v github.com/gravitational/teleport/tool/teleport... -cover

View file

@ -13,6 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package defaults contains default constants set in various parts of
// teleport codebase
package defaults
import (
@ -64,9 +67,11 @@ const (
// By default all users use /bin/bash
DefaultShell = "/bin/bash"
// Median sleep time between node pings. Note that a random deviation is
// added to this time
SleepBetweenNodePings = time.Second * 5
// ServerHeartbeatTTL is a period between heartbeats
// Median sleep time between node pings is this value / 2 + random
// deviation added to this time to avoid lots of simultaneous
// heartbeats coming to auth server
ServerHeartbeatTTL = 6 * time.Second
)
// Default connection limits, they can be applied separately on any of the Teleport
@ -84,27 +89,31 @@ const (
MinCertDuration = time.Minute
// MaxCertDuration limits maximum duration of validity of issued cert
MaxCertDuration = 30 * time.Hour
// Default certificate duration in hours
CertDurationHours = 12
CertDuration = CertDurationHours * time.Hour
// CertDuration is a default certificate duration
// 12 is default as it' longer than average working day (I hope so)
CertDuration = 12 * time.Hour
)
// list of roles teleport service can run as:
const (
RoleNode = "node"
RoleProxy = "proxy"
// RoleNode is SSH stateless node
RoleNode = "node"
// RoleProxy is a stateless SSH access proxy (bastion)
RoleProxy = "proxy"
// RoleAuthService is authentication and authorization service,
// the only stateful role in the system
RoleAuthService = "auth"
)
var (
// Default path to teleport config file
// ConfigFilePath is default path to teleport config file
ConfigFilePath = "/etc/teleport.yaml"
// This is where all mutable data is stored (user keys, recorded sessions,
// DataDir is where all mutable data is stored (user keys, recorded sessions,
// registered SSH servers, etc):
DataDir = "/var/lib/teleport"
// Default roles teleport assumes when started via 'start' command
// StartRoles is default roles teleport assumes when started via 'start' command
StartRoles = []string{RoleProxy, RoleNode, RoleAuthService}
)

View file

@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package services
import (

View file

@ -24,6 +24,7 @@ import (
"io"
"net"
"os/exec"
"strings"
"sync"
"time"
@ -48,9 +49,6 @@ import (
"golang.org/x/crypto/ssh/agent"
)
// TestSleepDuration is used to shortcut the 'ping' looop during tests
var TestSleepDuration = time.Duration(time.Hour)
// Server implements SSH server that uses configuration backend and
// certificate-based authentication
type Server struct {
@ -231,74 +229,99 @@ func (s *Server) ID() string {
return s.uuid
}
// AdvertiseAddr() returns an address this server should be publicly accessible
func (s *Server) setAdvertiseIP(ip net.IP) {
s.Lock()
defer s.Unlock()
s.advertiseIP = ip
}
func (s *Server) getAdvertiseIP() net.IP {
s.Lock()
defer s.Unlock()
return s.advertiseIP
}
// AdvertiseAddr returns an address this server should be publicly accessible
// as, in "ip:host" form
func (s *Server) AdvertiseAddr() string {
// se if we have explicit --advertise-ip option
if s.advertiseIP == nil {
if s.getAdvertiseIP() == nil {
return s.addr.Addr
}
_, port, _ := net.SplitHostPort(s.addr.Addr)
return net.JoinHostPort(s.advertiseIP.String(), port)
return net.JoinHostPort(s.getAdvertiseIP().String(), port)
}
// registerServer attempts to register server in the cluster
func (s *Server) registerServer() error {
srv := services.Server{
ID: s.ID(),
Addr: s.AdvertiseAddr(),
Hostname: s.hostname,
Labels: s.labels,
CmdLabels: s.getCommandLabels(),
}
return trace.Wrap(s.authService.UpsertServer(srv, defaults.ServerHeartbeatTTL))
}
// heartbeatPresence periodically calls into the auth server to let everyone
// know we're up & alive
func (s *Server) heartbeatPresence() {
var sleepDuration time.Duration
advertiseAddr := s.AdvertiseAddr()
for {
sleepDuration = utils.RandomizedDuration(defaults.SleepBetweenNodePings, 0.3)
// shorten sleep time during tests
if sleepDuration > TestSleepDuration {
sleepDuration = TestSleepDuration
if err := s.registerServer(); err != nil {
log.Warningf("failed to announce %#v presence: %v", s, err)
}
func() {
s.labelsMutex.Lock()
defer s.labelsMutex.Unlock()
srv := services.Server{
ID: s.ID(),
Addr: advertiseAddr,
Hostname: s.hostname,
Labels: s.labels,
CmdLabels: s.cmdLabels,
}
if err := s.authService.UpsertServer(srv, sleepDuration*2); err != nil {
log.Warningf("failed to announce %#v presence: %v", srv, err)
// sleep longer on failures
sleepDuration = sleepDuration * 2
}
}()
log.Infof("[SSH] will ping auth service in %v", sleepDuration)
time.Sleep(sleepDuration)
log.Infof("[SSH] will ping auth service in %v", defaults.ServerHeartbeatTTL)
time.Sleep(defaults.ServerHeartbeatTTL + utils.RandomDuration(defaults.ServerHeartbeatTTL/10))
}
}
func (s *Server) updateLabels() {
for name, label := range s.cmdLabels {
go s.updateLabel(name, label)
go s.periodicUpdateLabel(name, label)
}
}
func (s *Server) syncUpdateLabels() {
for name, label := range s.cmdLabels {
s.updateLabel(name, label)
}
}
func (s *Server) updateLabel(name string, label services.CommandLabel) {
out, err := exec.Command(label.Command[0], label.Command[1:]...).Output()
if err != nil {
log.Errorf(err.Error())
label.Result = err.Error() + " output: " + string(out)
} else {
label.Result = strings.TrimSpace(string(out))
}
s.setCommandLabel(name, label)
}
func (s *Server) periodicUpdateLabel(name string, label services.CommandLabel) {
for {
out, err := exec.Command(label.Command[0], label.Command[1:]...).Output()
if err != nil {
log.Errorf(err.Error())
label.Result = err.Error() + " Output: " + string(out)
} else {
if out[len(out)-1] == 10 {
out = out[:len(out)-1] // remove new line
}
label.Result = string(out)
}
s.labelsMutex.Lock()
s.cmdLabels[name] = label
s.labelsMutex.Unlock()
s.updateLabel(name, label)
time.Sleep(label.Period)
}
}
func (s *Server) setCommandLabel(name string, value services.CommandLabel) {
s.labelsMutex.Lock()
defer s.labelsMutex.Unlock()
s.cmdLabels[name] = value
}
func (s *Server) getCommandLabels() map[string]services.CommandLabel {
s.labelsMutex.Lock()
defer s.labelsMutex.Unlock()
out := make(map[string]services.CommandLabel, len(s.cmdLabels))
for key, val := range s.cmdLabels {
out[key] = val
}
return out
}
func (s *Server) checkPermissionToLogin(cert ssh.PublicKey, teleportUser, osUser string) error {
// find cert authority by it's key
cas, err := s.authService.GetCertAuthorities(services.UserCA)
@ -566,7 +589,7 @@ func (s *Server) handleSessionRequests(sconn *ssh.ServerConn, ch ssh.Channel, in
return
}
if err := s.dispatch(sconn, ch, req, ctx); err != nil {
ctx.Infof("error dispatching request: %v, closing channel", err)
ctx.Infof("error dispatching request: %#v", err)
replyError(ch, req, err)
closeCh()
return

View file

@ -73,7 +73,6 @@ type SrvSuite struct {
var _ = Suite(&SrvSuite{})
func (s *SrvSuite) SetUpSuite(c *C) {
TestSleepDuration = time.Millisecond * 50
utils.InitLoggerCLI()
}
@ -145,6 +144,7 @@ func (s *SrvSuite) SetUpTest(c *C) {
s.srv = srv
c.Assert(s.srv.Start(), IsNil)
c.Assert(s.srv.registerServer(), IsNil)
// set up SSH client using the user private key for signing
up, err := newUpack(s.user, s.a)
@ -190,9 +190,9 @@ func (s *SrvSuite) TestExec(c *C) {
func (s *SrvSuite) TestAdvertiseAddr(c *C) {
c.Assert(strings.Index(s.srv.AdvertiseAddr(), "127.0.0.1:"), Equals, 0)
s.srv.advertiseIP = net.ParseIP("10.10.10.1")
s.srv.setAdvertiseIP(net.ParseIP("10.10.10.1"))
c.Assert(strings.Index(s.srv.AdvertiseAddr(), "10.10.10.1:"), Equals, 0)
s.srv.advertiseIP = nil
s.srv.setAdvertiseIP(nil)
}
// TestShell launches interactive shell session and executes a command
@ -404,9 +404,10 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
srv2.uuid = bobAddr
c.Assert(err, IsNil)
c.Assert(srv2.Start(), IsNil)
c.Assert(srv2.registerServer(), IsNil)
defer srv2.Close()
time.Sleep(200 * time.Millisecond)
srv2.registerServer()
// test proxysites
client, err := ssh.Dial("tcp", proxy.Addr(), sshConfig)
@ -424,6 +425,12 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
close(done)
}()
// to make sure labels have the right output
s.srv.syncUpdateLabels()
srv2.syncUpdateLabels()
s.srv.registerServer()
srv2.registerServer()
c.Assert(se3.RequestSubsystem("proxysites"), IsNil)
<-done
var sites map[string][]services.Server

View file

@ -19,19 +19,11 @@ func CryptoRandomHex(len int) (string, error) {
return hex.EncodeToString(randomBytes), nil
}
// RandomizedDuration returns duration which is within given deviation from a given
// median.
//
// For example RandomizedDuration(time.Second * 10, 0.5) will return
// a random duration between 6 and 15 seconds.
func RandomizedDuration(median time.Duration, deviation float64) time.Duration {
min := int64(float64(median) * (1 - deviation))
max := int64(float64(median) * (1 + deviation))
ceiling := big.NewInt(max - min)
randomDeviation, err := rand.Int(rand.Reader, ceiling)
// RandomDuration returns a duration in a range [0, max)
func RandomDuration(max time.Duration) time.Duration {
randomVal, err := rand.Int(rand.Reader, big.NewInt(int64(max)))
if err != nil {
return median
return max / 2
}
return time.Duration(min + randomDeviation.Int64())
return time.Duration(randomVal.Int64())
}

View file

@ -52,11 +52,11 @@ func (s *UtilsSuite) TestSelfSignedCert(c *check.C) {
}
func (s *UtilsSuite) TestRandomDuration(c *check.C) {
expectedMin := time.Second * 9
expectedMax := time.Second * 11
expectedMin := time.Duration(0)
expectedMax := time.Second * 10
for i := 0; i < 50; i++ {
dur := RandomizedDuration(time.Second*10, 0.1)
dur := RandomDuration(expectedMax)
c.Assert(dur >= expectedMin, check.Equals, true)
c.Assert(dur <= expectedMax, check.Equals, true)
c.Assert(dur < expectedMax, check.Equals, true)
}
}

View file

@ -148,7 +148,7 @@ func makeClient(cf *CLIConf) (tc *client.TeleportClient, err error) {
cf.NodePort = defaults.SSHServerListenPort
}
if cf.MinsToLive == 0 {
cf.MinsToLive = defaults.CertDurationHours * 60
cf.MinsToLive = int32(defaults.CertDuration / time.Minute)
}
// split login & host
parts := strings.Split(cf.UserHost, "@")