mirror of
https://github.com/gravitational/teleport
synced 2024-10-22 10:13:21 +00:00
Merge pull request #254 from gravitational/alexander/roundrobin
Alexander/roundrobin
This commit is contained in:
commit
85edd3e753
|
@ -48,7 +48,7 @@ func Register(dataDir, token string, id IdentityID, servers []utils.NetAddr) err
|
|||
}
|
||||
|
||||
client, err := NewTunClient(
|
||||
servers[0],
|
||||
servers,
|
||||
id.HostUUID,
|
||||
method)
|
||||
if err != nil {
|
||||
|
@ -74,7 +74,7 @@ func RegisterNewAuth(domainName, token string, servers []utils.NetAddr) error {
|
|||
}
|
||||
|
||||
client, err := NewTunClient(
|
||||
servers[0],
|
||||
servers,
|
||||
domainName,
|
||||
method)
|
||||
if err != nil {
|
||||
|
|
273
lib/auth/tun.go
273
lib/auth/tun.go
|
@ -27,6 +27,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
"github.com/gravitational/teleport/lib/defaults"
|
||||
"github.com/gravitational/teleport/lib/limiter"
|
||||
"github.com/gravitational/teleport/lib/services"
|
||||
"github.com/gravitational/teleport/lib/sshutils"
|
||||
|
@ -363,7 +364,7 @@ func (s *AuthTunnel) passwordAuth(
|
|||
switch ab.Type {
|
||||
case AuthWebPassword:
|
||||
if err := s.authServer.CheckPassword(conn.User(), ab.Pass, ab.HotpToken); err != nil {
|
||||
log.Warningf("password auth error: %v", err)
|
||||
log.Warningf("password auth error: %#v", err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
perms := &ssh.Permissions{
|
||||
|
@ -483,21 +484,51 @@ func NewHostAuth(key, cert []byte) ([]ssh.AuthMethod, error) {
|
|||
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
|
||||
}
|
||||
|
||||
type TunClient struct {
|
||||
Client
|
||||
dialer *TunDialer
|
||||
tr *http.Transport
|
||||
// TunClientOption is functional option for tunnel client
|
||||
type TunClientOption func(t *TunClient)
|
||||
|
||||
// TunClientStorage allows tun client to set local presence service
|
||||
// that it will use to sync up the latest information about auth servers
|
||||
func TunClientStorage(storage utils.AddrStorage) TunClientOption {
|
||||
return func(t *TunClient) {
|
||||
t.addrStorage = storage
|
||||
}
|
||||
}
|
||||
|
||||
func NewTunClient(addr utils.NetAddr, user string, auth []ssh.AuthMethod) (*TunClient, error) {
|
||||
// TunClient is HTTP client that works over SSH tunnel
|
||||
// This is done in order to authenticate various teleport roles
|
||||
// using existing SSH certificate infrastructure
|
||||
type TunClient struct {
|
||||
sync.Mutex
|
||||
Client
|
||||
user string
|
||||
authServers []utils.NetAddr
|
||||
authMethods []ssh.AuthMethod
|
||||
refreshTicker *time.Ticker
|
||||
closeC chan struct{}
|
||||
closeOnce sync.Once
|
||||
tr *http.Transport
|
||||
addrStorage utils.AddrStorage
|
||||
}
|
||||
|
||||
// NewTunClient returns an instance of new HTTP client to Auth server API
|
||||
// exposed over SSH tunnel, so client uses SSH credentials to dial and authenticate
|
||||
func NewTunClient(authServers []utils.NetAddr, user string, authMethods []ssh.AuthMethod, opts ...TunClientOption) (*TunClient, error) {
|
||||
if user == "" {
|
||||
return nil, trace.Wrap(teleport.BadParameter("user", "SSH connection requires a valid username"))
|
||||
}
|
||||
tc := &TunClient{
|
||||
dialer: &TunDialer{auth: auth, addr: addr, user: user},
|
||||
user: user,
|
||||
authServers: authServers,
|
||||
authMethods: authMethods,
|
||||
refreshTicker: time.NewTicker(defaults.AuthServersRefreshPeriod),
|
||||
closeC: make(chan struct{}),
|
||||
}
|
||||
for _, o := range opts {
|
||||
o(tc)
|
||||
}
|
||||
tr := &http.Transport{
|
||||
Dial: tc.dialer.Dial,
|
||||
Dial: tc.Dial,
|
||||
}
|
||||
clt, err := NewClient(
|
||||
"http://stub:0",
|
||||
|
@ -509,24 +540,177 @@ func NewTunClient(addr utils.NetAddr, user string, auth []ssh.AuthMethod) (*TunC
|
|||
}
|
||||
tc.Client = *clt
|
||||
tc.tr = tr
|
||||
|
||||
// use local information about auth servers if it's available
|
||||
if tc.addrStorage != nil {
|
||||
|
||||
authServers, err := tc.addrStorage.GetAddresses()
|
||||
if err != nil {
|
||||
if !teleport.IsNotFound(err) {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
log.Infof("local storage is provided, not initialized")
|
||||
} else {
|
||||
log.Infof("using auth servers from local storage: %v", authServers)
|
||||
tc.authServers = authServers
|
||||
}
|
||||
}
|
||||
go tc.syncAuthServers()
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func (c *TunClient) GetAgent() (AgentCloser, error) {
|
||||
return c.dialer.GetAgent()
|
||||
}
|
||||
|
||||
// Close releases all the resources allocated for this client
|
||||
func (c *TunClient) Close() error {
|
||||
c.tr.CloseIdleConnections()
|
||||
return c.dialer.Close()
|
||||
c.refreshTicker.Stop()
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closeC)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDialer returns dialer that will connect to auth server API
|
||||
func (c *TunClient) GetDialer() AccessPointDialer {
|
||||
return func() (net.Conn, error) {
|
||||
return c.dialer.Dial(c.dialer.addr.AddrNetwork, "accesspoint:0")
|
||||
return c.Dial(c.authServers[0].AddrNetwork, "accesspoint:0")
|
||||
}
|
||||
}
|
||||
|
||||
// GetAgent returns SSH agent that uses ReqWebSessionAgent Auth server extension
|
||||
func (c *TunClient) GetAgent() (AgentCloser, error) {
|
||||
client, err := c.getClient() // we need an established connection first
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
ch, _, err := client.OpenChannel(ReqWebSessionAgent, nil)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(
|
||||
teleport.ConnectionProblem(
|
||||
"failed to connect to remote API", err))
|
||||
}
|
||||
agentCloser := &tunAgent{client: client}
|
||||
agentCloser.Agent = agent.NewClient(ch)
|
||||
return agentCloser, nil
|
||||
}
|
||||
|
||||
// Dial dials to Auth server's HTTP API over SSH tunnel
|
||||
func (c *TunClient) Dial(network, address string) (net.Conn, error) {
|
||||
log.Debugf("TunDialer.Dial(%v, %v)", network, address)
|
||||
client, err := c.getClient()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
conn, err := client.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(
|
||||
teleport.ConnectionProblem("failed to connect to remote API", err))
|
||||
}
|
||||
tc := &tunConn{client: client}
|
||||
tc.Conn = conn
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func (c *TunClient) fetchAndSync() error {
|
||||
authServers, err := c.fetchAuthServers()
|
||||
if err != nil {
|
||||
log.Infof("failed to fetch auth servers")
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if len(authServers) == 0 {
|
||||
log.Warningf("no auth servers received")
|
||||
return trace.Wrap(teleport.NotFound("no auth servers"))
|
||||
}
|
||||
// set runtime information about auth servers
|
||||
c.setAuthServers(authServers)
|
||||
// populate local storage if it is supplied
|
||||
if c.addrStorage != nil {
|
||||
if err := c.addrStorage.SetAddresses(authServers); err != nil {
|
||||
return trace.Wrap(err, "failed to set local storage addresses")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TunClient) syncAuthServers() {
|
||||
for {
|
||||
select {
|
||||
case <-c.refreshTicker.C:
|
||||
err := c.fetchAndSync()
|
||||
if err != nil {
|
||||
log.Infof("failed to fetch and sync servers: %v", err)
|
||||
continue
|
||||
}
|
||||
case <-c.closeC:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TunClient) fetchAuthServers() ([]utils.NetAddr, error) {
|
||||
servers, err := c.GetAuthServers()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
authServers := make([]utils.NetAddr, 0, len(servers))
|
||||
for _, server := range servers {
|
||||
serverAddr, err := utils.ParseAddr(server.Addr)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
authServers = append(authServers, *serverAddr)
|
||||
}
|
||||
return authServers, nil
|
||||
}
|
||||
|
||||
func (c *TunClient) getAuthServers() []utils.NetAddr {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
out := make([]utils.NetAddr, len(c.authServers))
|
||||
for i := range c.authServers {
|
||||
out[i] = c.authServers[i]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *TunClient) setAuthServers(servers []utils.NetAddr) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
log.Infof("setAuthServers(%#v)", servers)
|
||||
c.authServers = servers
|
||||
}
|
||||
|
||||
func (c *TunClient) getClient() (*ssh.Client, error) {
|
||||
var client *ssh.Client
|
||||
var err error
|
||||
for _, authServer := range c.getAuthServers() {
|
||||
client, err = c.dialAuthServer(authServer)
|
||||
if err == nil {
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (c *TunClient) dialAuthServer(authServer utils.NetAddr) (*ssh.Client, error) {
|
||||
config := &ssh.ClientConfig{
|
||||
User: c.user,
|
||||
Auth: c.authMethods,
|
||||
}
|
||||
client, err := ssh.Dial(authServer.AddrNetwork, authServer.Addr, config)
|
||||
log.Debugf("TunDialer.getClient(%v)", authServer.String())
|
||||
if err != nil {
|
||||
log.Infof("TunDialer could not ssh.Dial: %v", err)
|
||||
if utils.IsHandshakeFailedError(err) {
|
||||
return nil, teleport.AccessDenied(
|
||||
fmt.Sprintf("access denied to '%v': bad username or credentials", c.user))
|
||||
}
|
||||
return nil, trace.Wrap(teleport.ConvertSystemError(err))
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type AgentCloser interface {
|
||||
io.Closer
|
||||
agent.Agent
|
||||
|
@ -542,51 +726,6 @@ func (ta *tunAgent) Close() error {
|
|||
return ta.client.Close()
|
||||
}
|
||||
|
||||
type TunDialer struct {
|
||||
sync.Mutex
|
||||
auth []ssh.AuthMethod
|
||||
user string
|
||||
addr utils.NetAddr
|
||||
}
|
||||
|
||||
func (t *TunDialer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TunDialer) GetAgent() (AgentCloser, error) {
|
||||
client, err := t.getClient() // we need an established connection first
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
ch, _, err := client.OpenChannel(ReqWebSessionAgent, nil)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(
|
||||
teleport.ConnectionProblem(
|
||||
"failed to connect to remote API", err))
|
||||
}
|
||||
agentCloser := &tunAgent{client: client}
|
||||
agentCloser.Agent = agent.NewClient(ch)
|
||||
return agentCloser, nil
|
||||
}
|
||||
|
||||
func (t *TunDialer) getClient() (*ssh.Client, error) {
|
||||
config := &ssh.ClientConfig{
|
||||
User: t.user,
|
||||
Auth: t.auth,
|
||||
}
|
||||
client, err := ssh.Dial(t.addr.AddrNetwork, t.addr.Addr, config)
|
||||
log.Debugf("TunDialer.getClient(%v)", t.addr.String())
|
||||
if err != nil {
|
||||
log.Infof("TunDialer could not ssh.Dial: %v", err)
|
||||
if utils.IsHandshakeFailedError(err) {
|
||||
return nil, teleport.AccessDenied(
|
||||
fmt.Sprintf("access denied to '%v': bad username or credentials", t.user))
|
||||
}
|
||||
return nil, trace.Wrap(teleport.ConvertSystemError(err))
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// DialerRetryAttempts is the amount of attempts for dialer to try and
|
||||
// connect to the remote destination
|
||||
|
@ -607,22 +746,6 @@ func (c *tunConn) Close() error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
func (t *TunDialer) Dial(network, address string) (net.Conn, error) {
|
||||
log.Debugf("TunDialer.Dial(%v, %v)", network, address)
|
||||
client, err := t.getClient()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
conn, err := client.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(
|
||||
teleport.ConnectionProblem("failed to connect to remote API", err))
|
||||
}
|
||||
tc := &tunConn{client: client}
|
||||
tc.Conn = conn
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
const (
|
||||
ReqWebSessionAgent = "web-session-agent@teleport"
|
||||
ReqProvision = "provision@teleport"
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
|
@ -148,7 +149,7 @@ func (s *TunSuite) TestUnixServerClient(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()},
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}},
|
||||
"test", authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
|
@ -176,7 +177,7 @@ func (s *TunSuite) TestSessions(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
|
@ -189,7 +190,7 @@ func (s *TunSuite) TestSessions(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
cltw, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer cltw.Close()
|
||||
|
||||
|
@ -229,7 +230,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt0, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod0)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod0)
|
||||
c.Assert(err, IsNil)
|
||||
_, _, _, err = clt0.GetSignupTokenData(token2)
|
||||
c.Assert(err, NotNil) // valid token, but invalid client
|
||||
|
@ -239,7 +240,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
|
@ -279,7 +280,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
|
|||
|
||||
// Saving new password
|
||||
clt2, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt2.Close()
|
||||
|
||||
|
@ -302,7 +303,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
|
|||
|
||||
// trying to connect to the auth server using used token
|
||||
clt0, err = NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil) // shouldn't accept such connection twice
|
||||
_, _, _, err = clt0.GetSignupTokenData(token2)
|
||||
c.Assert(err, NotNil) // valid token, but invalid client
|
||||
|
@ -313,7 +314,7 @@ func (s *TunSuite) TestWebCreatingNewUser(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt3, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod3)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod3)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt3.Close()
|
||||
|
||||
|
@ -341,7 +342,7 @@ func (s *TunSuite) TestPermissions(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
|
@ -362,7 +363,7 @@ func (s *TunSuite) TestPermissions(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
cltw, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer cltw.Close()
|
||||
|
||||
|
@ -404,7 +405,7 @@ func (s *TunSuite) TestSessionsBadPassword(c *C) {
|
|||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}, user, authMethod)
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: s.tsrv.Addr()}}, user, authMethod)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
|
@ -416,3 +417,57 @@ func (s *TunSuite) TestSessionsBadPassword(c *C) {
|
|||
c.Assert(err, NotNil)
|
||||
c.Assert(ws, IsNil)
|
||||
}
|
||||
|
||||
func (s *TunSuite) TestFailover(c *C) {
|
||||
node := services.Server{
|
||||
ID: "node1",
|
||||
Addr: "node.example.com:12345",
|
||||
Hostname: "node.example.com",
|
||||
}
|
||||
c.Assert(s.a.UpsertNode(node, backend.Forever), IsNil)
|
||||
|
||||
ports, err := utils.GetFreeTCPPorts(1)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
clt, err := NewTunClient(
|
||||
[]utils.NetAddr{
|
||||
{AddrNetwork: "tcp", Addr: fmt.Sprintf("127.0.0.1:%v", ports.Pop())},
|
||||
{AddrNetwork: "tcp", Addr: s.tsrv.Addr()},
|
||||
}, "localhost", []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
nodes, err := clt.GetNodes()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(nodes, DeepEquals, []services.Server{node})
|
||||
}
|
||||
|
||||
func (s *TunSuite) TestSync(c *C) {
|
||||
authServer := services.Server{
|
||||
ID: "node1",
|
||||
Addr: "node.example.com:12345",
|
||||
Hostname: "node.example.com",
|
||||
}
|
||||
c.Assert(s.a.UpsertAuthServer(authServer, backend.Forever), IsNil)
|
||||
|
||||
storage := utils.NewFileAddrStorage(filepath.Join(c.MkDir(), "addr.json"))
|
||||
|
||||
clt, err := NewTunClient(
|
||||
[]utils.NetAddr{
|
||||
{AddrNetwork: "tcp", Addr: s.tsrv.Addr()},
|
||||
}, "localhost", []ssh.AuthMethod{ssh.PublicKeys(s.signer)},
|
||||
TunClientStorage(storage),
|
||||
)
|
||||
c.Assert(err, IsNil)
|
||||
defer clt.Close()
|
||||
|
||||
err = clt.fetchAndSync()
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
expected := []utils.NetAddr{{Addr: "node.example.com:12345", AddrNetwork: "tcp"}}
|
||||
c.Assert(clt.getAuthServers(), DeepEquals, expected)
|
||||
|
||||
syncedServers, err := storage.GetAddresses()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(syncedServers, DeepEquals, expected)
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||
package etcdbk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -107,7 +108,7 @@ func (b *bk) GetKeys(path []string) ([]string, error) {
|
|||
func (b *bk) CreateVal(path []string, key string, val []byte, ttl time.Duration) error {
|
||||
_, err := b.api.Set(
|
||||
context.Background(),
|
||||
b.key(append(path, key)...), string(val),
|
||||
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
|
||||
&client.SetOptions{PrevExist: client.PrevNoExist, TTL: ttl})
|
||||
return trace.Wrap(convertErr(err))
|
||||
}
|
||||
|
@ -125,7 +126,7 @@ func (b *bk) TouchVal(path []string, key string, ttl time.Duration) error {
|
|||
}
|
||||
_, err = b.api.Set(
|
||||
context.Background(),
|
||||
b.key(append(path, key)...), string(re.Node.Value),
|
||||
b.key(append(path, key)...), re.Node.Value,
|
||||
&client.SetOptions{TTL: ttl, PrevValue: re.Node.Value, PrevExist: client.PrevExist})
|
||||
err = convertErr(err)
|
||||
if err == nil {
|
||||
|
@ -138,7 +139,7 @@ func (b *bk) TouchVal(path []string, key string, ttl time.Duration) error {
|
|||
func (b *bk) UpsertVal(path []string, key string, val []byte, ttl time.Duration) error {
|
||||
_, err := b.api.Set(
|
||||
context.Background(),
|
||||
b.key(append(path, key)...), string(val), &client.SetOptions{TTL: ttl})
|
||||
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val), &client.SetOptions{TTL: ttl})
|
||||
return convertErr(err)
|
||||
}
|
||||
|
||||
|
@ -148,20 +149,25 @@ func (b *bk) CompareAndSwap(path []string, key string, val []byte, ttl time.Dura
|
|||
if len(prevVal) != 0 {
|
||||
re, err = b.api.Set(
|
||||
context.Background(),
|
||||
b.key(append(path, key)...), string(val),
|
||||
&client.SetOptions{TTL: ttl, PrevValue: string(prevVal), PrevExist: client.PrevExist})
|
||||
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
|
||||
&client.SetOptions{TTL: ttl, PrevValue: base64.StdEncoding.EncodeToString(prevVal), PrevExist: client.PrevExist})
|
||||
} else {
|
||||
re, err = b.api.Set(
|
||||
context.Background(),
|
||||
b.key(append(path, key)...), string(val),
|
||||
b.key(append(path, key)...), base64.StdEncoding.EncodeToString(val),
|
||||
&client.SetOptions{TTL: ttl, PrevExist: client.PrevNoExist})
|
||||
}
|
||||
err = convertErr(err)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
if re.PrevNode != nil {
|
||||
return []byte(re.PrevNode.Value), nil
|
||||
value, err := base64.StdEncoding.DecodeString(re.PrevNode.Value)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -174,7 +180,11 @@ func (b *bk) GetVal(path []string, key string) ([]byte, error) {
|
|||
if re.Node.Dir {
|
||||
return nil, trace.Wrap(teleport.BadParameter(key, "trying to get value of bucket"))
|
||||
}
|
||||
return []byte(re.Node.Value), nil
|
||||
value, err := base64.StdEncoding.DecodeString(re.Node.Value)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func (b *bk) GetValAndTTL(path []string, key string) ([]byte, time.Duration, error) {
|
||||
|
@ -186,7 +196,11 @@ func (b *bk) GetValAndTTL(path []string, key string) ([]byte, time.Duration, err
|
|||
return nil, 0, trace.Wrap(
|
||||
teleport.BadParameter(key, "trying to get value of bucket"))
|
||||
}
|
||||
return []byte(re.Node.Value), time.Duration(re.Node.TTL) * time.Second, nil
|
||||
value, err := base64.StdEncoding.DecodeString(re.Node.Value)
|
||||
if err != nil {
|
||||
return nil, 0, trace.Wrap(err)
|
||||
}
|
||||
return value, time.Duration(re.Node.TTL) * time.Second, nil
|
||||
}
|
||||
|
||||
func (b *bk) DeleteKey(path []string, key string) error {
|
||||
|
|
|
@ -72,6 +72,10 @@ const (
|
|||
// deviation added to this time to avoid lots of simultaneous
|
||||
// heartbeats coming to auth server
|
||||
ServerHeartbeatTTL = 6 * time.Second
|
||||
|
||||
// AuthServersRefreshPeriod is a period for clients to refresh their
|
||||
// their stored list of auth servers
|
||||
AuthServersRefreshPeriod = 3 * time.Second
|
||||
)
|
||||
|
||||
// Default connection limits, they can be applied separately on any of the Teleport
|
||||
|
|
|
@ -89,12 +89,24 @@ func (process *TeleportProcess) connectToAuthService(role teleport.Role) (*conne
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
storage := utils.NewFileAddrStorage(
|
||||
filepath.Join(process.Config.DataDir, "authservers.json"))
|
||||
|
||||
var authServers []utils.NetAddr
|
||||
authServers, err = storage.GetAddresses()
|
||||
if err != nil && len(authServers) == 0 {
|
||||
log.Infof("no auth servers are available from the local storage")
|
||||
authServers = process.Config.AuthServers
|
||||
}
|
||||
|
||||
log.Infof("connecting to auth servers: %v", authServers)
|
||||
authUser := identity.Cert.ValidPrincipals[0]
|
||||
authClient, err := auth.NewTunClient(
|
||||
process.Config.AuthServers[0],
|
||||
authServers,
|
||||
authUser,
|
||||
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)})
|
||||
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)},
|
||||
auth.TunClientStorage(storage),
|
||||
)
|
||||
// success?
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
|
@ -283,7 +295,7 @@ func (process *TeleportProcess) initAuthService() error {
|
|||
// logic, consolidate it into auth package later
|
||||
process.RegisterFunc(func() error {
|
||||
authClient, err := auth.NewTunClient(
|
||||
cfg.Auth.SSHAddr,
|
||||
[]utils.NetAddr{cfg.Auth.SSHAddr},
|
||||
identity.Cert.ValidPrincipals[0],
|
||||
[]ssh.AuthMethod{ssh.PublicKeys(identity.KeySigner)})
|
||||
// success?
|
||||
|
|
|
@ -458,7 +458,7 @@ func (s *Server) keyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permiss
|
|||
logger.Warningf("authenticate err: %v", err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if err := s.certChecker.CheckCert(teleportUser, cert); err != nil {
|
||||
if err := s.certChecker.CheckCert(conn.User(), cert); err != nil {
|
||||
logger.Warningf("failed to authenticate user, err: %v", err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -357,7 +357,7 @@ func (s *SrvSuite) TestProxyReverseTunnel(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
|
@ -502,7 +502,7 @@ func (s *SrvSuite) TestProxyRoundRobin(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
|
@ -603,7 +603,7 @@ func (s *SrvSuite) TestProxyDirectAccess(c *C) {
|
|||
c.Assert(tsrv.Start(), IsNil)
|
||||
|
||||
tunClt, err := auth.NewTunClient(
|
||||
utils.NetAddr{AddrNetwork: "tcp", Addr: tsrv.Addr()}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
[]utils.NetAddr{{AddrNetwork: "tcp", Addr: tsrv.Addr()}}, s.domainName, []ssh.AuthMethod{ssh.PublicKeys(s.signer)})
|
||||
c.Assert(err, IsNil)
|
||||
defer tunClt.Close()
|
||||
|
||||
|
|
59
lib/utils/storage.go
Normal file
59
lib/utils/storage.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
)
|
||||
|
||||
// AddrStorage is used to store information locally for
|
||||
// every client that connects in the cluster, so it can always have
|
||||
// up-to-date info about auth servers
|
||||
type AddrStorage interface {
|
||||
// SetAddresses saves addresses
|
||||
SetAddresses([]NetAddr) error
|
||||
// GetAddresses
|
||||
GetAddresses() ([]NetAddr, error)
|
||||
}
|
||||
|
||||
// FileAddrStorage is a file based address storage
|
||||
type FileAddrStorage struct {
|
||||
filePath string
|
||||
}
|
||||
|
||||
// SetAddresses updates storage with new address list
|
||||
func (fs *FileAddrStorage) SetAddresses(addrs []NetAddr) error {
|
||||
bytes, err := json.Marshal(addrs)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
err = ioutil.WriteFile(fs.filePath, bytes, 0666)
|
||||
if err != nil {
|
||||
return trace.Wrap(teleport.ConvertSystemError(err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAddresses returns saved address list
|
||||
func (fs *FileAddrStorage) GetAddresses() ([]NetAddr, error) {
|
||||
bytes, err := ioutil.ReadFile(fs.filePath)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(teleport.ConvertSystemError(err))
|
||||
}
|
||||
var addrs []NetAddr
|
||||
err = json.Unmarshal(bytes, &addrs)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// NewFileAddrStorage returns new instance of file-based address storage
|
||||
func NewFileAddrStorage(filePath string) *FileAddrStorage {
|
||||
return &FileAddrStorage{
|
||||
filePath: filePath,
|
||||
}
|
||||
}
|
|
@ -174,7 +174,7 @@ func (s *sessionCache) Auth(user, pass string, hotpToken string) (*auth.Session,
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
clt, err := auth.NewTunClient(s.authServers[0], user, method)
|
||||
clt, err := auth.NewTunClient(s.authServers, user, method)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ func (s *sessionCache) GetCertificate(c createSSHCertReq) (*SSHLoginResponse, er
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
clt, err := auth.NewTunClient(s.authServers[0], c.User, method)
|
||||
clt, err := auth.NewTunClient(s.authServers, c.User, method)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -218,7 +218,7 @@ func (s *sessionCache) GetUserInviteInfo(token string) (user string,
|
|||
if err != nil {
|
||||
return "", nil, nil, trace.Wrap(err)
|
||||
}
|
||||
clt, err := auth.NewTunClient(s.authServers[0], "tokenAuth", method)
|
||||
clt, err := auth.NewTunClient(s.authServers, "tokenAuth", method)
|
||||
if err != nil {
|
||||
return "", nil, nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -231,7 +231,7 @@ func (s *sessionCache) CreateNewUser(token, password, hotpToken string) (*auth.S
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
clt, err := auth.NewTunClient(s.authServers[0], "tokenAuth", method)
|
||||
clt, err := auth.NewTunClient(s.authServers, "tokenAuth", method)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
@ -292,7 +292,7 @@ func (s *sessionCache) ValidateSession(user, sid string) (*sessionContext, error
|
|||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
clt, err := auth.NewTunClient(s.authServers[0], user, method)
|
||||
clt, err := auth.NewTunClient(s.authServers, user, method)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
|
|
@ -358,7 +358,7 @@ func connectToAuthService(cfg *service.Config) (client *auth.TunClient, err erro
|
|||
return nil, trace.Wrap(err)
|
||||
}
|
||||
client, err = auth.NewTunClient(
|
||||
cfg.AuthServers[0],
|
||||
cfg.AuthServers,
|
||||
cfg.HostUUID,
|
||||
[]ssh.AuthMethod{ssh.PublicKeys(i.KeySigner)})
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue