Remove exported Webauthn functions (#30420)

* Add WebauthnLogin field to teleportClient and tsh for tests.

* Use custom WebauthnLogin func instead of test export.

* Remove HasPlatformSupport exported function.

* Add todo to remove lib/client/export.go.

* Parallelize affected tests.

* Apply suggestions from CR.
This commit is contained in:
Brian Joerger 2023-08-16 19:18:23 -07:00 committed by GitHub
parent 4e66a6fc67
commit fc6bcf3cfb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 76 additions and 91 deletions

View file

@ -8603,18 +8603,16 @@ func testModeratedSessions(t *testing.T, suite *integrationTestSuite) {
panic("this should not be called")
})
oldStdin, oldWebauthn := prompt.Stdin(), *client.PromptWebauthn
oldStdin := prompt.Stdin()
prompt.SetStdin(inputReader)
t.Cleanup(func() {
prompt.SetStdin(oldStdin)
*client.PromptWebauthn = oldWebauthn
})
device, err := mocku2f.Create()
require.NoError(t, err)
device.SetPasswordless()
prompt.SetStdin(inputReader)
*client.PromptWebauthn = func(ctx context.Context, realOrigin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
customWebauthnLogin := func(ctx context.Context, realOrigin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
car, err := device.SignAssertion("https://127.0.0.1", assertion) // use the fake origin to prevent a mismatch
if err != nil {
return nil, "", err
@ -8731,6 +8729,7 @@ func testModeratedSessions(t *testing.T, suite *integrationTestSuite) {
return
}
cl.WebauthnLogin = customWebauthnLogin
cl.Stdout = peerTerminal
cl.Stdin = peerTerminal
if err := cl.SSH(ctx, []string{}, false); err != nil {
@ -8761,6 +8760,7 @@ func testModeratedSessions(t *testing.T, suite *integrationTestSuite) {
return
}
cl.WebauthnLogin = customWebauthnLogin
cl.Stdout = moderatorTerminal
cl.Stdin = moderatorTerminal
if err := cl.Join(ctx, types.SessionModeratorMode, defaults.Namespace, session.ID(sessionID), moderatorTerminal); err != nil {

View file

@ -463,6 +463,10 @@ type Config struct {
// PromptMFAFunc allows tests to override the default MFA prompt function.
// Defaults to [mfa.NewPrompt().Run].
PromptMFAFunc PromptMFAFunc
// WebauthnLogin allows tests to override the Webauthn Login func.
// Defaults to [wancli.Login].
WebauthnLogin WebauthnLoginFunc
}
// CachePolicy defines cache policy for local clients
@ -3625,6 +3629,7 @@ func (tc *TeleportClient) pwdlessLoginWeb(ctx context.Context, priv *keys.Privat
User: user,
AuthenticatorAttachment: tc.AuthenticatorAttachment,
StderrOverride: tc.Stderr,
WebauthnLogin: tc.WebauthnLogin,
})
return clt, session, trace.Wrap(err)
}
@ -3903,6 +3908,7 @@ func (tc *TeleportClient) pwdlessLogin(ctx context.Context, priv *keys.PrivateKe
User: user,
AuthenticatorAttachment: tc.AuthenticatorAttachment,
StderrOverride: tc.Stderr,
WebauthnLogin: tc.WebauthnLogin,
})
return response, trace.Wrap(err)

View file

@ -55,6 +55,8 @@ import (
)
func TestTeleportClient_Login_local(t *testing.T) {
t.Parallel()
silenceLogger(t)
clock := clockwork.NewFakeClockAt(time.Now())
@ -80,17 +82,11 @@ func TestTeleportClient_Login_local(t *testing.T) {
cfg.InsecureSkipVerify = true
// Reset functions after tests.
oldStdin, oldWebauthn := prompt.Stdin(), *client.PromptWebauthn
oldHasPlatformSupport := *client.HasPlatformSupport
*client.HasPlatformSupport = func() bool {
return true
}
oldStdin := prompt.Stdin()
oldHasCredentials := *client.HasTouchIDCredentials
t.Cleanup(func() {
prompt.SetStdin(oldStdin)
*client.PromptWebauthn = oldWebauthn
*client.HasPlatformSupport = oldHasPlatformSupport
*client.HasTouchIDCredentials = oldHasCredentials
})
@ -262,14 +258,6 @@ func TestTeleportClient_Login_local(t *testing.T) {
defer cancel()
prompt.SetStdin(test.inputReader)
*client.PromptWebauthn = func(
ctx context.Context,
origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts,
) (*proto.MFAAuthenticateResponse, string, error) {
resp, err := test.solveWebauthn(ctx, origin, assertion, prompt)
return resp, "", err
}
*client.HasTouchIDCredentials = func(rpid, user string) bool {
return test.hasTouchIDCredentials
}
@ -288,6 +276,14 @@ func TestTeleportClient_Login_local(t *testing.T) {
tc.PreferOTP = test.preferOTP
tc.AuthenticatorAttachment = test.authenticatorAttachment
tc.WebauthnLogin = func(
ctx context.Context,
origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts,
) (*proto.MFAAuthenticateResponse, string, error) {
resp, err := test.solveWebauthn(ctx, origin, assertion, prompt)
return resp, "", err
}
clock.Advance(30 * time.Second)
_, err = tc.Login(ctx)
require.NoError(t, err)

View file

@ -14,5 +14,5 @@
package client
// TODO(Joerger): Remove this export once /e no longer depends on it.
var PromptWebauthn = &promptWebauthn
var HasPlatformSupport = &hasPlatformSupport

View file

@ -27,14 +27,14 @@ import (
// TODO(Joerger): remove this once the exported PromptWebauthn function is no longer used in tests.
// promptWebauthn provides indirection for tests.
var promptWebauthn func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error)
// hasPlatformSupport is used to mock wancli.HasPlatformSupport for tests.
var hasPlatformSupport = wancli.HasPlatformSupport
var promptWebauthn WebauthnLoginFunc
// PromptMFAFunc matches the signature of [mfa.Prompt.Run].
type PromptMFAFunc func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
// WebauthnLoginFunc matches the signature of [wancli.Login].
type WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error)
// NewMFAPrompt creates a new MFA prompt from client settings.
func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) PromptMFAFunc {
if tc.PromptMFAFunc != nil {
@ -52,6 +52,11 @@ func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) PromptMFAFunc {
prompt.WebauthnSupported = true
}
if tc.WebauthnLogin != nil {
prompt.WebauthnLogin = tc.WebauthnLogin
prompt.WebauthnSupported = true
}
for _, opt := range opts {
opt(prompt)
}

View file

@ -26,7 +26,6 @@ import (
wanpb "github.com/gravitational/teleport/api/types/webauthn"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/utils/prompt"
)
@ -36,12 +35,8 @@ import (
// See api_login_test.go and/or TeleportClient tests for more general
// authentication tests.
func TestPromptMFAChallenge_usingNonRegisteredDevice(t *testing.T) {
oldPromptWebauthn := *client.PromptWebauthn
oldHasPlatformSupport := *client.HasPlatformSupport
oldStdin := prompt.Stdin()
t.Cleanup(func() {
*client.PromptWebauthn = oldPromptWebauthn
*client.HasPlatformSupport = oldHasPlatformSupport
prompt.SetStdin(oldStdin)
})
@ -89,6 +84,9 @@ func TestPromptMFAChallenge_usingNonRegisteredDevice(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test := test
t.Parallel()
// Set a timeout so the test won't block forever.
// We don't expect to hit the timeout for any of the test cases.
ctx, cancel := context.WithTimeout(ctx, 1*time.Second)

View file

@ -279,6 +279,10 @@ type SSHLoginMFA struct {
type SSHLoginPasswordless struct {
SSHLogin
// WebauthnLogin is a customizable webauthn login function.
// Defaults to [wancli.Login]
WebauthnLogin WebauthnLoginFunc
// StderrOverride will override the default os.Stderr if provided.
StderrOverride io.Writer
@ -543,13 +547,12 @@ func SSHAgentPasswordlessLogin(ctx context.Context, login SSHLoginPasswordless)
prompt = wancli.NewDefaultPrompt(ctx, stderr)
}
// TODO(Joerger): remove this once the exported PromptWebauthn function is no longer used in tests.
webauthnLogin := wancli.Login
if promptWebauthn != nil {
webauthnLogin = promptWebauthn
promptWebauthn := login.WebauthnLogin
if promptWebauthn == nil {
promptWebauthn = wancli.Login
}
mfaResp, _, err := webauthnLogin(ctx, webURL.String(), challenge.WebauthnChallenge, prompt, &wancli.LoginOpts{
mfaResp, _, err := promptWebauthn(ctx, webURL.String(), challenge.WebauthnChallenge, prompt, &wancli.LoginOpts{
User: login.User,
AuthenticatorAttachment: login.AuthenticatorAttachment,
})
@ -882,13 +885,12 @@ func SSHAgentPasswordlessLoginWeb(ctx context.Context, login SSHLoginPasswordles
prompt = wancli.NewDefaultPrompt(ctx, stderr)
}
// TODO(Joerger): remove this once the exported PromptWebauthn function is no longer used in tests.
webauthnLogin := wancli.Login
if promptWebauthn != nil {
webauthnLogin = promptWebauthn
promptWebauthn := login.WebauthnLogin
if promptWebauthn == nil {
promptWebauthn = wancli.Login
}
mfaResp, _, err := webauthnLogin(ctx, webURL.String(), challenge.WebauthnChallenge, prompt, &wancli.LoginOpts{
mfaResp, _, err := promptWebauthn(ctx, webURL.String(), challenge.WebauthnChallenge, prompt, &wancli.LoginOpts{
User: login.User,
AuthenticatorAttachment: login.AuthenticatorAttachment,
})

View file

@ -115,6 +115,7 @@ func newServer(handler http.HandlerFunc, loopback bool) (*httptest.Server, error
}
func TestSSHAgentPasswordlessLogin(t *testing.T) {
t.Parallel()
silenceLogger(t)
clock := clockwork.NewFakeClockAt(time.Now())
@ -133,12 +134,6 @@ func TestSSHAgentPasswordlessLogin(t *testing.T) {
cfg.KeysDir = t.TempDir()
cfg.InsecureSkipVerify = true
// Reset functions after tests.
oldWebauthn := *client.PromptWebauthn
t.Cleanup(func() {
*client.PromptWebauthn = oldWebauthn
})
solvePwdless := func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt) (*proto.MFAAuthenticateResponse, error) {
car, err := device.SignAssertion(origin, assertion)
if err != nil {
@ -225,9 +220,9 @@ func TestSSHAgentPasswordlessLogin(t *testing.T) {
},
AuthenticatorAttachment: tc.AuthenticatorAttachment,
CustomPrompt: test.customPromptLogin,
WebauthnLogin: test.customPromptWebauthn,
}
*client.PromptWebauthn = test.customPromptWebauthn
_, err = client.SSHAgentPasswordlessLogin(ctx, req)
require.NoError(t, err)
require.True(t, customPromptCalled, "Custom prompt present but not called")

View file

@ -294,6 +294,7 @@ func (c *Cluster) passwordlessLogin(stream api.TerminalService_LoginPasswordless
},
AuthenticatorAttachment: c.clusterClient.AuthenticatorAttachment,
CustomPrompt: newPwdlessLoginPrompt(ctx, c.Log, stream),
WebauthnLogin: c.clusterClient.WebauthnLogin,
})
if err != nil {
return nil, trace.Wrap(err)

View file

@ -480,6 +480,10 @@ type CLIConf struct {
// Defaults to [dtauthn.NewCeremony().Run].
DTAuthnRunCeremony client.DTAuthnRunCeremonyFunc
// WebauthnLogin allows tests to override the Webauthn Login func.
// Defaults to [wancli.Login].
WebauthnLogin client.WebauthnLoginFunc
// LeafClusterName is the optional name of a leaf cluster to connect to instead
LeafClusterName string
}
@ -3738,6 +3742,7 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err
c.MockSSOLogin = cf.MockSSOLogin
c.MockHeadlessLogin = cf.MockHeadlessLogin
c.DTAuthnRunCeremony = cf.DTAuthnRunCeremony
c.WebauthnLogin = cf.WebauthnLogin
// pass along MySQL/Postgres path overrides (only used in tests).
c.OverrideMySQLOptionFilePath = cf.overrideMySQLOptionFilePath

View file

@ -1003,10 +1003,9 @@ func approveAllAccessRequests(ctx context.Context, approver accessApprover) erro
// sessions when set either via role or cluster auth preference.
// Sessions created via hostname and by matched labels are
// verified.
//
// NOTE: This test must NOT be run in parallel because it updates
// the global [client.PromptWebauthn] in multiple test cases.
func TestSSHOnMultipleNodes(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -1236,32 +1235,12 @@ func TestSSHOnMultipleNodes(t *testing.T) {
}
}
type mfaPrompt = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, _ *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error)
setupChallengeSolver := func(mfaPrompt mfaPrompt) func(t *testing.T) {
return func(t *testing.T) {
inputReader := prompt.NewFakeReader().
AddString(password).
AddReply(func(ctx context.Context) (string, error) {
panic("this should not be called")
})
oldStdin, oldWebauthn := prompt.Stdin(), *client.PromptWebauthn
t.Cleanup(func() {
prompt.SetStdin(oldStdin)
*client.PromptWebauthn = oldWebauthn
})
prompt.SetStdin(inputReader)
*client.PromptWebauthn = mfaPrompt
}
}
cases := []struct {
name string
target string
authPreference types.AuthPreference
roles []string
setup func(t *testing.T)
webauthnLogin client.WebauthnLoginFunc
errAssertion require.ErrorAssertionFunc
stdoutAssertion require.ValueAssertionFunc
stderrAssertion require.ValueAssertionFunc
@ -1329,7 +1308,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
},
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
target: "env=stage",
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
@ -1352,7 +1331,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
},
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
target: "env=prod",
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
@ -1375,7 +1354,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
},
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
target: "env=dev",
errAssertion: require.Error,
stderrAssertion: require.Empty,
@ -1395,7 +1374,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
roles: []string{"access", sshLoginRole.GetName(), perSessionMFARole.GetName()},
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
target: "env=stage",
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
@ -1431,12 +1410,12 @@ func TestSSHOnMultipleNodes(t *testing.T) {
errAssertion: require.Error,
},
{
name: "command runs on a hostname with mfa set via role",
target: sshHostID,
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
roles: []string{perSessionMFARole.GetName()},
setup: setupChallengeSolver(successfulChallenge("localhost")),
name: "command runs on a hostname with mfa set via role",
target: sshHostID,
proxyAddr: rootProxyAddr.String(),
auth: rootAuth.GetAuthServer(),
roles: []string{perSessionMFARole.GetName()},
webauthnLogin: successfulChallenge("localhost"),
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
},
@ -1459,7 +1438,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
auth: rootAuth.GetAuthServer(),
target: sshHostID,
roles: []string{perSessionMFARole.GetName()},
setup: setupChallengeSolver(failedChallenge("localhost")),
webauthnLogin: failedChallenge("localhost"),
stdoutAssertion: require.Empty,
stderrAssertion: func(t require.TestingT, v any, i ...any) {
out, ok := v.(string)
@ -1484,7 +1463,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
auth: rootAuth.GetAuthServer(),
target: sshHostID,
roles: []string{perSessionMFARole.GetName()},
setup: setupChallengeSolver(failedChallenge("localhost")),
webauthnLogin: failedChallenge("localhost"),
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
@ -1498,7 +1477,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
proxyAddr: leafProxyAddr,
auth: leafAuth.GetAuthServer(),
roles: []string{perSessionMFARole.GetName()},
setup: setupChallengeSolver(successfulChallenge("leafcluster")),
webauthnLogin: successfulChallenge("leafcluster"),
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
@ -1513,7 +1492,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
auth: rootAuth.GetAuthServer(),
cluster: "leafcluster",
roles: []string{sshLoginRole.GetName()},
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
@ -1527,7 +1506,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
auth: leafAuth.GetAuthServer(),
roles: []string{sshLoginRole.GetName()},
stderrAssertion: require.Empty,
setup: setupChallengeSolver(successfulChallenge("leafcluster")),
webauthnLogin: successfulChallenge("leafcluster"),
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
},
@ -1540,7 +1519,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
auth: rootAuth.GetAuthServer(),
cluster: "leafcluster",
roles: []string{perSessionMFARole.GetName()},
setup: setupChallengeSolver(successfulChallenge("localhost")),
webauthnLogin: successfulChallenge("localhost"),
stderrAssertion: require.Empty,
stdoutAssertion: func(t require.TestingT, i interface{}, i2 ...interface{}) {
require.Equal(t, "test\n", i, i2...)
@ -1600,10 +1579,6 @@ func TestSSHOnMultipleNodes(t *testing.T) {
})
}
if tt.setup != nil {
tt.setup(t)
}
if tt.roles != nil {
roles := user.GetRoles()
t.Cleanup(func() {
@ -1625,6 +1600,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
}, setHomePath(tmpHomePath),
func(cf *CLIConf) error {
cf.MockSSOLogin = mockSSOLogin(t, tt.auth, user)
cf.WebauthnLogin = tt.webauthnLogin
return nil
},
)
@ -1650,6 +1626,7 @@ func TestSSHOnMultipleNodes(t *testing.T) {
conf.OverrideStdout = stdout
conf.overrideStderr = stderr
conf.MockHeadlessLogin = mockHeadlessLogin(t, tt.auth, user)
conf.WebauthnLogin = tt.webauthnLogin
return nil
},
)