Merge pull request #254 from gravitational/alexander/roundrobin

Alexander/roundrobin
This commit is contained in:
Alexander Klizhentas 2016-03-15 14:34:56 -07:00
commit 85edd3e753
11 changed files with 377 additions and 110 deletions

View file

@ -48,7 +48,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr) err
}
client, err := NewTunClient(
servers[0],
servers,
id.HostUUID,
method)
if err != nil {
@ -74,7 +74,7 @@ func RegisterNewAuth(domainName, token string, servers []utils.NetAddr) error {
}
client, err := NewTunClient(
servers[0],
servers,
domainName,
method)
if err != nil {

View file

@ -27,6 +27,7 @@ import (
"time"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshutils"
@ -363,7 +364,7 @@ func (s *AuthTunnel) passwordAuth(
switch ab.Type {
case AuthWebPassword:
if err := s.authServer.CheckPassword(conn.User(), ab.Pass, ab.HotpToken); err != nil {
log.Warningf("password auth error: %v", err)
log.Warningf("password auth error: %#v", err)
return nil, trace.Wrap(err)
}
perms := &ssh.Permissions{
@ -483,21 +484,51 @@ func NewHostAuth(key, cert []byte) ([]ssh.AuthMethod, error) {
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}
type TunClient struct {
Client
dialer *TunDialer
tr *http.Transport
// TunClientOption is functional option for tunnel client
type TunClientOption func(t *TunClient)
// TunClientStorage allows tun client to set local presence service
// that it will use to sync up the latest information about auth servers
func TunClientStorage(storage utils.AddrStorage) TunClientOption {
return func(t *TunClient) {
t.addrStorage = storage
}
}
func NewTunClient(addr utils.NetAddr, user string, auth []ssh.AuthMethod) (*TunClient, error) {
// TunClient is HTTP client that works over SSH tunnel
// This is done in order to authenticate various teleport roles
// using existing SSH certificate infrastructure
type TunClient struct {
sync.Mutex
Client
user string
authServers []utils.NetAddr
authMethods []ssh.AuthMethod
refreshTicker *time.Ticker
closeC chan struct{}
closeOnce sync.Once
tr *http.Transport
addrStorage utils.AddrStorage
}
// NewTunClient returns an instance of new HTTP client to Auth server API
// exposed over SSH tunnel, so client uses SSH credentials to dial and authenticate
func NewTunClient(authServers []utils.NetAddr, user string, authMethods []ssh.AuthMethod, opts ...TunClientOption) (*TunClient, error) {
if user == "" {
return nil, trace.Wrap(teleport.BadParameter("user", "SSH connection requires a valid username"))
}
tc := &TunClient{
dialer: &TunDialer{auth: auth, addr: addr, user: user},
user: user,
authServers: authServers,
authMethods: authMethods,
refreshTicker: time.NewTicker(defaults.AuthServersRefreshPeriod),
closeC: make(chan struct{}),
}
for _, o := range opts {
o(tc)
}
tr := &http.Transport{
Dial: tc.dialer.Dial,
Dial: tc.Dial,
}
clt, err := NewClient(
"http://stub:0",
@ -509,24 +540,177 @@ func NewTunClient(addr utils.NetAddr, user string, auth []ssh.AuthMethod) (*TunC
}
tc.Client = *clt
tc.tr = tr
// use local information about auth servers if it's available
if tc.addrStorage != nil {
authServers, err := tc.addrStorage.GetAddresses()
if err != nil {
if !teleport.IsNotFound(err) {
return nil, trace.Wrap(err)
}
log.Infof("local storage is provided, not initialized")
} else {
log.Infof("using auth servers from local storage: %v", authServers)
tc.authServers = authServers
}
}
go tc.syncAuthServers()
return tc, nil
}
func (c *TunClient) GetAgent() (AgentCloser, error) {
return c.dialer.GetAgent()
}
// Close releases all the resources allocated for this client
func (c *TunClient) Close() error {
c.tr.CloseIdleConnections()
return c.dialer.Close()
c.refreshTicker.Stop()
c.closeOnce.Do(func() {
close(c.closeC)
})
return nil
}
// GetDialer returns dialer that will connect to auth server API
func (c *TunClient) GetDialer() AccessPointDialer {
return func() (net.Conn, error) {
return c.dialer.Dial(c.dialer.addr.AddrNetwork, "accesspoint:0")
return c.Dial(c.authServers[0].AddrNetwork, "accesspoint:0")
}
}
// GetAgent returns SSH agent that uses ReqWebSessionAgent Auth server extension
func (c *TunClient) GetAgent() (AgentCloser, error) {
client, err := c.getClient() // we need an established connection first
if err != nil {
return nil, trace.Wrap(err)
}
ch, _, err := client.OpenChannel(ReqWebSessionAgent, nil)
if err != nil {
return nil, trace.Wrap(
teleport.ConnectionProblem(
"failed to connect to remote API", err))
}
agentCloser := &tunAgent{client: client}
agentCloser.Agent = agent.NewClient(ch)
return agentCloser, nil
}
// Dial dials to Auth server's HTTP API over SSH tunnel
func (c *TunClient) Dial(network, address string) (net.Conn, error) {
log.Debugf("TunDialer.Dial(%v, %v)", network, address)
client, err := c.getClient()
if err != nil {
return nil, trace.Wrap(err)
}
conn, err := client.Dial(network, address)
if err != nil {
return nil, trace.Wrap(
teleport.ConnectionProblem("failed to connect to remote API", err))
}
tc := &tunConn{client: client}
tc.Conn = conn
return tc, nil
}
func (c *TunClient) fetchAndSync() error {
authServers, err := c.fetchAuthServers()
if err != nil {
log.Infof("failed to fetch auth servers")
return trace.Wrap(err)
}
if len(authServers) == 0 {
log.Warningf("no auth servers received")
return trace.Wrap(teleport.NotFound("no auth servers"))
}
// set runtime information about auth servers
c.setAuthServers(authServers)
// populate local storage if it is supplied
if c.addrStorage != nil {
if err := c.addrStorage.SetAddresses(authServers); err != nil {
return trace.Wrap(err, "failed to set local storage addresses")
}
}
return nil
}
func (c *TunClient) syncAuthServers() {
for {
select {
case <-c.refreshTicker.C:
err := c.fetchAndSync()
if err != nil {
log.Infof("failed to fetch and sync servers: %v", err)
continue
}
case <-c.closeC:
return
}
}
}
func (c *TunClient) fetchAuthServers() ([]utils.NetAddr, error) {
servers, err := c.GetAuthServers()
if err != nil {
return nil, trace.Wrap(err)
}
authServers := make([]utils.NetAddr, 0, len(servers))
for _, server := range servers {
serverAddr, err := utils.ParseAddr(server.Addr)
if err != nil {
return nil, trace.Wrap(err)
}
authServers = append(authServers, *serverAddr)
}
return authServers, nil
}
func (c *TunClient) getAuthServers() []utils.NetAddr {
c.Lock()
defer c.Unlock()
out := make([]utils.NetAddr, len(c.authServers))
for i := range c.authServers {
out[i] = c.authServers[i]
}
return out
}
func (c *TunClient) setAuthServers(servers []utils.NetAddr) {
c.Lock()
defer c.Unlock()
log.Infof("setAuthServers(%#v)", servers)
c.authServers = servers
}
func (c *TunClient) getClient() (*ssh.Client, error) {
var client *ssh.Client
var err error
for _, authServer := range c.getAuthServers() {
client, err = c.dialAuthServer(authServer)
if err == nil {
return client, nil
}
}
return nil, trace.Wrap(err)
}
func (c *TunClient) dialAuthServer(authServer utils.NetAddr) (*ssh.Client, error) {
config := &ssh.ClientConfig{
User: c.user,
Auth: c.authMethods,
}
client, err := ssh.Dial(authServer.AddrNetwork, authServer.Addr, config)
log.Debugf("TunDialer.getClient(%v)", authServer.String())
if err != nil {
log.Infof("TunDialer could not ssh.Dial: %v", err)
if utils.IsHandshakeFailedError(err) {
return nil, teleport.AccessDenied(
fmt.Sprintf("access denied to '%v': bad username or credentials", c.user))
}
return nil, trace.Wrap(teleport.ConvertSystemError(err))
}
return client, nil
}
type AgentCloser interface {
io.Closer
agent.Agent
@ -542,51 +726,6 @@ func (ta *tunAgent) Close() error {
return ta.client.Close()
}
type TunDialer struct {
sync.Mutex
auth []ssh.AuthMethod
user string
addr utils.NetAddr
}
func (t *TunDialer) Close() error {
return nil
}
func (t *TunDialer) GetAgent() (AgentCloser, error) {
client, err := t.getClient() // we need an established connection first
if err != nil {
return nil, trace.Wrap(err)
}
ch, _, err := client.OpenChannel(ReqWebSessionAgent, nil)
if err != nil {
return nil, trace.Wrap(
teleport.ConnectionProblem(
"failed to connect to remote API", err))
}
agentCloser := &tunAgent{client: client}
agentCloser.Agent = agent.NewClient(ch)
return agentCloser, nil
}
func (t *TunDialer) getClient() (*ssh.Client, error) {
config := &ssh.ClientConfig{
User: t.user,
Auth: t.auth,
}
client, err := ssh.Dial(t.addr.AddrNetwork, t.addr.Addr, config)
log.Debugf("TunDialer.getClient(%v)", t.addr.String())
if err != nil {
log.Infof("TunDialer could not ssh.Dial: %v", err)
if utils.IsHandshakeFailedError(err) {
return nil, teleport.AccessDenied(
fmt.Sprintf("access denied to '%v': bad username or credentials", t.user))
}
return nil, trace.Wrap(teleport.ConvertSystemError(err))
}
return client, nil
}
const (
// DialerRetryAttempts is the amount of attempts for dialer to try and
// connect to the remote destination
@ -607,22 +746,6 @@ func (c *tunConn) Close() error {
return trace.Wrap(err)
}
func (t *TunDialer) Dial(network, address string) (net.Conn, error) {
log.Debugf("TunDialer.Dial(%v, %v)", network, address)
client, err := t.getClient()
if err != nil {
return nil, trace.Wrap(err)
}
conn, err := client.Dial(network, address)
if err != nil {
return nil, trace.Wrap(
teleport.ConnectionProblem("failed to connect to remote API", err))
}
tc := &tunConn{client: client}
tc.Conn = conn
return tc, nil
}
const (
ReqWebSessionAgent = "web-session-agent@teleport"
ReqProvision = "provision@teleport"

View file

@ -17,6 +17,7 @@ limitations under the License.
package auth
import (
"fmt"
"path/filepath"
"time"
@ -148,7 +149,7 @@ func (s *TunSuite) TestUnixServerClient(c *C) {
c.Assert(err, IsNil)
clt, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()},
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}},
"test", authMethod)
c.Assert(err, IsNil)
@ -176,7 +177,7 @@ func (s *TunSuite) TestSessions(c *C) {
c.Assert(err, IsNil)
clt, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer clt.Close()
@ -189,7 +190,7 @@ func (s *TunSuite) TestSessions(c *C) {
c.Assert(err, IsNil)
cltw, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer cltw.Close()
@ -229,7 +230,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
c.Assert(err, IsNil)
clt0, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod0)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod0)
c.Assert(err, IsNil)
_, _, _, err = clt0.GetSignupTokenData(token2)
c.Assert(err, NotNil) // valid token, but invalid client
@ -239,7 +240,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
c.Assert(err, IsNil)
clt, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer clt.Close()
@ -279,7 +280,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
// Saving new password
clt2, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer clt2.Close()
@ -302,7 +303,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
// trying to connect to the auth server using used token
clt0, err = NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil) // shouldn't accept such connection twice
_, _, _, err = clt0.GetSignupTokenData(token2)
c.Assert(err, NotNil) // valid token, but invalid client
@ -313,7 +314,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
c.Assert(err, IsNil)
clt3, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod3)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod3)
c.Assert(err, IsNil)
defer clt3.Close()
@ -341,7 +342,7 @@ func (s *TunSuite) TestPermissions(c *C) {
c.Assert(err, IsNil)
clt, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer clt.Close()
@ -362,7 +363,7 @@ func (s *TunSuite) TestPermissions(c *C) {
c.Assert(err, IsNil)
cltw, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer cltw.Close()
@ -404,7 +405,7 @@ func (s *TunSuite) TestSessionsBadPassword(c *C) {
c.Assert(err, IsNil)
clt, err := NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
c.Assert(err, IsNil)
defer clt.Close()
@ -416,3 +417,57 @@ func (s *TunSuite) TestSessionsBadPassword(c *C) {
c.Assert(err, NotNil)
c.Assert(ws, IsNil)
}
func (s *TunSuite) TestFailover(c *C) {
node := services.Server{
ID: "node1",
Addr: "node.example.com:12345",
Hostname: "node.example.com",
}
c.Assert(s.a.UpsertNode(node, backend.Forever), IsNil)
ports, err := utils.GetFreeTCPPorts(1)
c.Assert(err, IsNil)
clt, err := NewTunClient(
[]utils.NetAddr{
{AddrNetwork: "tcp", Addr: fmt.Sprintf("127.0.0.1:%v", ports.Pop())},
{AddrNetwork: "tcp", Addr: s.tsrv.Addr()},
}, "localhost", []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer clt.Close()
nodes, err := clt.GetNodes()
c.Assert(err, IsNil)
c.Assert(nodes, DeepEquals, []services.Server{node})
}
func (s *TunSuite) TestSync(c *C) {
authServer := services.Server{
ID: "node1",
Addr: "node.example.com:12345",
Hostname: "node.example.com",
}
c.Assert(s.a.UpsertAuthServer(authServer, backend.Forever), IsNil)
storage := utils.NewFileAddrStorage(filepath.Join(c.MkDir(), "addr.json"))
clt, err := NewTunClient(
[]utils.NetAddr{
{AddrNetwork: "tcp", Addr: s.tsrv.Addr()},
}, "localhost", []ssh.AuthMethod{ssh.PublicKeys(s.signer)},
TunClientStorage(storage),
)
c.Assert(err, IsNil)
defer clt.Close()
err = clt.fetchAndSync()
c.Assert(err, IsNil)
expected := []utils.NetAddr{{Addr: "node.example.com:12345", AddrNetwork: "tcp"}}
c.Assert(clt.getAuthServers(), DeepEquals, expected)
syncedServers, err := storage.GetAddresses()
c.Assert(err, IsNil)
c.Assert(syncedServers, DeepEquals, expected)
}

View file

@ -18,6 +18,7 @@ limitations under the License.
package etcdbk
import (
"encoding/base64"
"sort"
"strings"
"time"
@ -107,7 +108,7 @@ func (b *bk) GetKeys(path []string) ([]string, error) {
func (b *bk) CreateVal(path []string, key string, val []byte, ttl time.Duration) error {
_, err := b.api.Set(
context.Background(),
b.key(append(path, key)...), string(val),
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
&client.SetOptions{PrevExist: client.PrevNoExist, TTL: ttl})
return trace.Wrap(convertErr(err))
}
@ -125,7 +126,7 @@ func (b *bk) TouchVal(path []string, key string, ttl time.Duration) error {
}
_, err = b.api.Set(
context.Background(),
b.key(append(path, key)...), string(re.Node.Value),
b.key(append(path, key)...), re.Node.Value,
&client.SetOptions{TTL: ttl, PrevValue: re.Node.Value, PrevExist: client.PrevExist})
err = convertErr(err)
if err == nil {
@ -138,7 +139,7 @@ func (b *bk) TouchVal(path []string, key string, ttl time.Duration) error {
func (b *bk) UpsertVal(path []string, key string, val []byte, ttl time.Duration) error {
_, err := b.api.Set(
context.Background(),
b.key(append(path, key)...), string(val), &client.SetOptions{TTL: ttl})
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val), &client.SetOptions{TTL: ttl})
return convertErr(err)
}
@ -148,20 +149,25 @@ func (b *bk) CompareAndSwap(path []string, key string, val []byte, ttl time.Dura
if len(prevVal) != 0 {
re, err = b.api.Set(
context.Background(),
b.key(append(path, key)...), string(val),
&client.SetOptions{TTL: ttl, PrevValue: string(prevVal), PrevExist: client.PrevExist})
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
&client.SetOptions{TTL: ttl, PrevValue: base64.StdEncoding.EncodeToString(prevVal), PrevExist: client.PrevExist})
} else {
re, err = b.api.Set(
context.Background(),
b.key(append(path, key)...), string(val),
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
&client.SetOptions{TTL: ttl, PrevExist: client.PrevNoExist})
}
err = convertErr(err)
if err != nil {
return nil, trace.Wrap(err)
}
if re.PrevNode != nil {
return []byte(re.PrevNode.Value), nil
value, err := base64.StdEncoding.DecodeString(re.PrevNode.Value)
if err != nil {
return nil, trace.Wrap(err)
}
return value, nil
}
return nil, nil
}
@ -174,7 +180,11 @@ func (b *bk) GetVal(path []string, key string) ([]byte, error) {
if re.Node.Dir {
return nil, trace.Wrap(teleport.BadParameter(key, "trying to get value of bucket"))
}
return []byte(re.Node.Value), nil
value, err := base64.StdEncoding.DecodeString(re.Node.Value)
if err != nil {
return nil, trace.Wrap(err)
}
return value, nil
}
func (b *bk) GetValAndTTL(path []string, key string) ([]byte, time.Duration, error) {
@ -186,7 +196,11 @@ func (b *bk) GetValAndTTL(path []string, key string) ([]byte, time.Duration, err
return nil, 0, trace.Wrap(
teleport.BadParameter(key, "trying to get value of bucket"))
}
return []byte(re.Node.Value), time.Duration(re.Node.TTL) * time.Second, nil
value, err := base64.StdEncoding.DecodeString(re.Node.Value)
if err != nil {
return nil, 0, trace.Wrap(err)
}
return value, time.Duration(re.Node.TTL) * time.Second, nil
}
func (b *bk) DeleteKey(path []string, key string) error {

View file

@ -72,6 +72,10 @@ const (
// deviation added to this time to avoid lots of simultaneous
// heartbeats coming to auth server
ServerHeartbeatTTL = 6 * time.Second
// AuthServersRefreshPeriod is a period for clients to refresh their
// their stored list of auth servers
AuthServersRefreshPeriod = 3 * time.Second
)
// Default connection limits, they can be applied separately on any of the Teleport

View file

@ -89,12 +89,24 @@ func (process *TeleportProcess) connectToAuthService(role teleport.Role) (*conne
if err != nil {
return nil, trace.Wrap(err)
}
storage := utils.NewFileAddrStorage(
filepath.Join(process.Config.DataDir, "authservers.json"))
var authServers []utils.NetAddr
authServers, err = storage.GetAddresses()
if err != nil && len(authServers) == 0 {
log.Infof("no auth servers are available from the local storage")
authServers = process.Config.AuthServers
}
log.Infof("connecting to auth servers: %v", authServers)
authUser := identity.Cert.ValidPrincipals[0]
authClient, err := auth.NewTunClient(
process.Config.AuthServers[0],
authServers,
authUser,
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)})
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)},
auth.TunClientStorage(storage),
)
// success?
if err != nil {
return nil, trace.Wrap(err)
@ -283,7 +295,7 @@ func (process *TeleportProcess) initAuthService() error {
// logic, consolidate it into auth package later
process.RegisterFunc(func() error {
authClient, err := auth.NewTunClient(
cfg.Auth.SSHAddr,
[]utils.NetAddr{cfg.Auth.SSHAddr},
identity.Cert.ValidPrincipals[0],
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)})
// success?

View file

@ -458,7 +458,7 @@ func (s *Server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permiss
logger.Warningf("authenticate err: %v", err)
return nil, trace.Wrap(err)
}
if err := s.certChecker.CheckCert(teleportUser, cert); err != nil {
if err := s.certChecker.CheckCert(conn.User(), cert); err != nil {
logger.Warningf("failed to authenticate user, err: %v", err)
return nil, trace.Wrap(err)
}

View file

@ -357,7 +357,7 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()
@ -502,7 +502,7 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()
@ -603,7 +603,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
c.Assert(tsrv.Start(), IsNil)
tunClt, err := auth.NewTunClient(
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
c.Assert(err, IsNil)
defer tunClt.Close()

59
lib/utils/storage.go Normal file
View file

@ -0,0 +1,59 @@
package utils
import (
"encoding/json"
"io/ioutil"
"github.com/gravitational/teleport"
"github.com/gravitational/trace"
)
// AddrStorage is used to store information locally for
// every client that connects in the cluster, so it can always have
// up-to-date info about auth servers
type AddrStorage interface {
// SetAddresses saves addresses
SetAddresses([]NetAddr) error
// GetAddresses
GetAddresses() ([]NetAddr, error)
}
// FileAddrStorage is a file based address storage
type FileAddrStorage struct {
filePath string
}
// SetAddresses updates storage with new address list
func (fs *FileAddrStorage) SetAddresses(addrs []NetAddr) error {
bytes, err := json.Marshal(addrs)
if err != nil {
return trace.Wrap(err)
}
err = ioutil.WriteFile(fs.filePath, bytes, 0666)
if err != nil {
return trace.Wrap(teleport.ConvertSystemError(err))
}
return nil
}
// GetAddresses returns saved address list
func (fs *FileAddrStorage) GetAddresses() ([]NetAddr, error) {
bytes, err := ioutil.ReadFile(fs.filePath)
if err != nil {
return nil, trace.Wrap(teleport.ConvertSystemError(err))
}
var addrs []NetAddr
err = json.Unmarshal(bytes, &addrs)
if err != nil {
return nil, trace.Wrap(err)
}
return addrs, nil
}
// NewFileAddrStorage returns new instance of file-based address storage
func NewFileAddrStorage(filePath string) *FileAddrStorage {
return &FileAddrStorage{
filePath: filePath,
}
}

View file

@ -174,7 +174,7 @@ func (s *sessionCache) Auth(user, pass string, hotpToken string) (*auth.Session,
if err != nil {
return nil, trace.Wrap(err)
}
clt, err := auth.NewTunClient(s.authServers[0], user, method)
clt, err := auth.NewTunClient(s.authServers, user, method)
if err != nil {
return nil, trace.Wrap(err)
}
@ -187,7 +187,7 @@ func (s *sessionCache) GetCertificate(c createSSHCertReq) (*SSHLoginResponse, er
if err != nil {
return nil, trace.Wrap(err)
}
clt, err := auth.NewTunClient(s.authServers[0], c.User, method)
clt, err := auth.NewTunClient(s.authServers, c.User, method)
if err != nil {
return nil, trace.Wrap(err)
}
@ -218,7 +218,7 @@ func (s *sessionCache) GetUserInviteInfo(token string) (user string,
if err != nil {
return "", nil, nil, trace.Wrap(err)
}
clt, err := auth.NewTunClient(s.authServers[0], "tokenAuth", method)
clt, err := auth.NewTunClient(s.authServers, "tokenAuth", method)
if err != nil {
return "", nil, nil, trace.Wrap(err)
}
@ -231,7 +231,7 @@ func (s *sessionCache) CreateNewUser(token, password, hotpToken string) (*auth.S
if err != nil {
return nil, trace.Wrap(err)
}
clt, err := auth.NewTunClient(s.authServers[0], "tokenAuth", method)
clt, err := auth.NewTunClient(s.authServers, "tokenAuth", method)
if err != nil {
return nil, trace.Wrap(err)
}
@ -292,7 +292,7 @@ func (s *sessionCache) ValidateSession(user, sid string) (*sessionContext, error
if err != nil {
return nil, trace.Wrap(err)
}
clt, err := auth.NewTunClient(s.authServers[0], user, method)
clt, err := auth.NewTunClient(s.authServers, user, method)
if err != nil {
return nil, trace.Wrap(err)
}

View file

@ -358,7 +358,7 @@ func connectToAuthService(cfg *service.Config) (client *auth.TunClient, err erro
return nil, trace.Wrap(err)
}
client, err = auth.NewTunClient(
cfg.AuthServers[0],
cfg.AuthServers,
cfg.HostUUID,
[]ssh.AuthMethod{ssh.PublicKeys(i.KeySigner)})
if err != nil {