Refactor tests under auth package.

Refactored all tests under "lib/auth" to use testify instead
of gocheck.

Switched backend for most tests to "memory".
This commit is contained in:
Russell Jones 2022-07-07 21:21:33 -07:00 committed by Russell Jones
parent 1cca898ffd
commit 2552b4ab25
6 changed files with 191 additions and 126 deletions

View file

@ -25,7 +25,7 @@ import (
"github.com/gravitational/teleport/api/types"
authority "github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/services"
@ -52,10 +52,9 @@ func setupGithubContext(ctx context.Context, t *testing.T) *githubContext {
tt.c = clockwork.NewFakeClockAt(time.Now())
var err error
tt.b, err = lite.NewWithConfig(context.Background(), lite.Config{
Path: t.TempDir(),
PollStreamPeriod: 200 * time.Millisecond,
Clock: tt.c,
tt.b, err = memory.New(memory.Config{
Context: context.Background(),
Clock: tt.c,
})
require.NoError(t, err)

View file

@ -18,10 +18,12 @@ package native
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
@ -29,15 +31,40 @@ import (
"github.com/gravitational/teleport/lib/auth/test"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
"github.com/stretchr/testify/require"
"github.com/jonboulle/clockwork"
"golang.org/x/crypto/ssh"
"gopkg.in/check.v1"
)
func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
}
type nativeContext struct {
suite *test.AuthSuite
}
func setupNativeContext(ctx context.Context, t *testing.T) *nativeContext {
var tt nativeContext
clock := clockwork.NewFakeClockAt(time.Date(2016, 9, 8, 7, 6, 5, 0, time.UTC))
tt.suite = &test.AuthSuite{
A: New(context.Background(), SetClock(clock)),
Keygen: GenerateKeyPair,
Clock: clock,
}
return &tt
}
// TestPrecomputeMode verifies that package enters precompute mode when
// PrecomputeKeys is called.
func TestPrecomputeMode(t *testing.T) {
t.Parallel()
PrecomputeKeys()
select {
@ -47,44 +74,25 @@ func TestPrecomputeMode(t *testing.T) {
}
}
func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
func TestGenerateKeypairEmptyPass(t *testing.T) {
t.Parallel()
tt := setupNativeContext(context.Background(), t)
tt.suite.GenerateKeypairEmptyPass(t)
}
func TestNative(t *testing.T) { check.TestingT(t) }
func TestGenerateHostCert(t *testing.T) {
t.Parallel()
type NativeSuite struct {
suite *test.AuthSuite
tt := setupNativeContext(context.Background(), t)
tt.suite.GenerateHostCert(t)
}
var _ = check.Suite(&NativeSuite{})
func TestGenerateUserCert(t *testing.T) {
t.Parallel()
func (s *NativeSuite) SetUpSuite(c *check.C) {
fakeClock := clockwork.NewFakeClockAt(time.Date(2016, 9, 8, 7, 6, 5, 0, time.UTC))
a := New(
context.TODO(),
SetClock(fakeClock),
)
s.suite = &test.AuthSuite{
A: a,
Keygen: GenerateKeyPair,
Clock: fakeClock,
}
}
func (s *NativeSuite) TestGenerateKeypairEmptyPass(c *check.C) {
s.suite.GenerateKeypairEmptyPass(c)
}
func (s *NativeSuite) TestGenerateHostCert(c *check.C) {
s.suite.GenerateHostCert(c)
}
func (s *NativeSuite) TestGenerateUserCert(c *check.C) {
s.suite.GenerateUserCert(c)
tt := setupNativeContext(context.Background(), t)
tt.suite.GenerateUserCert(t)
}
// TestBuildPrincipals makes sure that the list of principals for a host
@ -96,15 +104,19 @@ func (s *NativeSuite) TestGenerateUserCert(c *check.C) {
// * If both host ID and node name are given, then both should be included
// on the certificate.
// * If the host ID and node name are the same, only list one.
func (s *NativeSuite) TestBuildPrincipals(c *check.C) {
func TestBuildPrincipals(t *testing.T) {
t.Parallel()
tt := setupNativeContext(context.Background(), t)
caPrivateKey, _, err := GenerateKeyPair()
c.Assert(err, check.IsNil)
require.NoError(t, err)
caSigner, err := ssh.ParsePrivateKey(caPrivateKey)
c.Assert(err, check.IsNil)
require.NoError(t, err)
_, hostPublicKey, err := GenerateKeyPair()
c.Assert(err, check.IsNil)
require.NoError(t, err)
tests := []struct {
desc string
@ -169,35 +181,39 @@ func (s *NativeSuite) TestBuildPrincipals(c *check.C) {
}
// run tests
for _, tt := range tests {
c.Logf("Running test case: %q", tt.desc)
hostCertificateBytes, err := s.suite.A.GenerateHostCert(
for _, tc := range tests {
t.Logf("Running test case: %q", tc.desc)
hostCertificateBytes, err := tt.suite.A.GenerateHostCert(
services.HostCertParams{
CASigner: caSigner,
PublicHostKey: hostPublicKey,
HostID: tt.inHostID,
NodeName: tt.inNodeName,
ClusterName: tt.inClusterName,
Role: tt.inRole,
HostID: tc.inHostID,
NodeName: tc.inNodeName,
ClusterName: tc.inClusterName,
Role: tc.inRole,
TTL: time.Hour,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
hostCertificate, err := sshutils.ParseCertificate(hostCertificateBytes)
c.Assert(err, check.IsNil)
require.NoError(t, err)
c.Assert(hostCertificate.ValidPrincipals, check.DeepEquals, tt.outValidPrincipals)
require.Empty(t, cmp.Diff(hostCertificate.ValidPrincipals, tc.outValidPrincipals))
}
}
// TestUserCertCompatibility makes sure the compatibility flag can be used to
// add to remove roles from certificate extensions.
func (s *NativeSuite) TestUserCertCompatibility(c *check.C) {
func TestUserCertCompatibility(t *testing.T) {
t.Parallel()
tt := setupNativeContext(context.Background(), t)
priv, pub, err := GenerateKeyPair()
c.Assert(err, check.IsNil)
require.NoError(t, err)
caSigner, err := ssh.ParsePrivateKey(priv)
c.Assert(err, check.IsNil)
require.NoError(t, err)
tests := []struct {
inCompatibility string
@ -216,10 +232,10 @@ func (s *NativeSuite) TestUserCertCompatibility(c *check.C) {
}
// run tests
for i, tt := range tests {
comment := check.Commentf("Test %v", i)
for i, tc := range tests {
comment := fmt.Sprintf("Test %v", i)
userCertificateBytes, err := s.suite.A.GenerateUserCert(services.UserCertParams{
userCertificateBytes, err := tt.suite.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
PublicUserKey: pub,
Username: "user",
@ -233,19 +249,21 @@ func (s *NativeSuite) TestUserCertCompatibility(c *check.C) {
Value: "hello",
},
},
CertificateFormat: tt.inCompatibility,
CertificateFormat: tc.inCompatibility,
PermitAgentForwarding: true,
PermitPortForwarding: true,
})
c.Assert(err, check.IsNil, comment)
require.NoError(t, err, comment)
userCertificate, err := sshutils.ParseCertificate(userCertificateBytes)
c.Assert(err, check.IsNil, comment)
// check if we added the roles extension
require.NoError(t, err, comment)
// Check if we added the roles extension.
_, ok := userCertificate.Extensions[teleport.CertExtensionTeleportRoles]
c.Assert(ok, check.Equals, tt.outHasRoles, comment)
// check if users custom extension was added
require.Equal(t, ok, tc.outHasRoles, comment)
// Check if users custom extension was added.
extVal := userCertificate.Extensions["login@github.com"]
c.Assert(extVal, check.Equals, "hello")
require.Equal(t, extVal, "hello")
}
}

View file

@ -33,15 +33,16 @@ import (
"github.com/gravitational/teleport/api/types"
authority "github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/trace"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
directory "google.golang.org/api/admin/directory/v1"
@ -57,13 +58,14 @@ type OIDCSuite struct {
func setUpSuite(t *testing.T) *OIDCSuite {
s := OIDCSuite{}
ctx := context.Background()
s.c = clockwork.NewFakeClockAt(time.Now())
var err error
s.b, err = lite.NewWithConfig(context.Background(), lite.Config{
Path: t.TempDir(),
PollStreamPeriod: 200 * time.Millisecond,
Clock: s.c,
s.b, err = memory.New(memory.Config{
Context: ctx,
Clock: s.c,
})
require.NoError(t, err)
@ -100,6 +102,8 @@ func createInsecureOIDCClient(t *testing.T, connector types.OIDCConnector) *oidc
}
func TestCreateOIDCUser(t *testing.T) {
t.Parallel()
s := setUpSuite(t)
// Dry-run creation of OIDC user.
@ -140,6 +144,8 @@ func TestCreateOIDCUser(t *testing.T) {
// all claim information is already within the token and additional claim
// information does not need to be fetched.
func TestUserInfoBlockHTTP(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := setUpSuite(t)
// Create configurable IdP to use in tests.
@ -166,6 +172,8 @@ func TestUserInfoBlockHTTP(t *testing.T) {
// TestUserInfoBadStatus asserts that a 4xx response from userinfo results
// in AccessDenied.
func TestUserInfoBadStatus(t *testing.T) {
t.Parallel()
// Create configurable IdP to use in tests.
idp := newFakeIDP(t, true /* tls */)
@ -186,6 +194,8 @@ func TestUserInfoBadStatus(t *testing.T) {
}
func TestSSODiagnostic(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := setUpSuite(t)
// Create configurable IdP to use in tests.
@ -328,6 +338,8 @@ func TestSSODiagnostic(t *testing.T) {
// TestPingProvider confirms that the client_secret_post auth
// method was set for a oauthclient.
func TestPingProvider(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := setUpSuite(t)
// Create configurable IdP to use in tests.
@ -376,6 +388,8 @@ func TestPingProvider(t *testing.T) {
}
func TestOIDCClientProviderSync(t *testing.T) {
t.Parallel()
ctx := context.Background()
// Create configurable IdP to use in tests.
idp := newFakeIDP(t, false /* tls */)
@ -434,6 +448,8 @@ func TestOIDCClientProviderSync(t *testing.T) {
}
func TestOIDCClientCache(t *testing.T) {
t.Parallel()
ctx := context.Background()
s := setUpSuite(t)
// Create configurable IdP to use in tests.
@ -565,6 +581,8 @@ func (s *fakeIDP) configurationHandler(w http.ResponseWriter, r *http.Request) {
}
func TestOIDCGoogle(t *testing.T) {
t.Parallel()
directGroups := map[string][]string{
"alice@foo.example": {"group1@foo.example", "group2@sub.foo.example", "group3@bar.example"},
"bob@foo.example": {"group1@foo.example"},

View file

@ -34,18 +34,18 @@ import (
authority "github.com/gravitational/teleport/lib/auth/testauthority"
wanlib "github.com/gravitational/teleport/lib/auth/webauthn"
"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/backend/lite"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/suite"
"github.com/stretchr/testify/require"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/pquerna/otp/totp"
"github.com/stretchr/testify/require"
)
type passwordSuite struct {
@ -56,9 +56,18 @@ type passwordSuite struct {
func setupPasswordSuite(t *testing.T) *passwordSuite {
s := passwordSuite{}
ctx := context.Background()
clock := clockwork.NewFakeClockAt(time.Now())
var err error
s.bk, err = lite.New(context.TODO(), backend.Params{"path": t.TempDir()})
s.bk, err = memory.New(memory.Config{
Context: ctx,
Clock: clock,
})
require.NoError(t, err)
// set cluster name
clusterName, err := services.NewClusterNameWithRandomID(types.ClusterNameSpecV2{
ClusterName: "me.localhost",
@ -90,6 +99,8 @@ func setupPasswordSuite(t *testing.T) *passwordSuite {
}
func TestPasswordTimingAttack(t *testing.T) {
t.Parallel()
s := setupPasswordSuite(t)
username := "foo"
password := "barbaz"
@ -168,6 +179,7 @@ func TestPasswordTimingAttack(t *testing.T) {
func TestUserNotFound(t *testing.T) {
t.Parallel()
s := setupPasswordSuite(t)
username := "unknown-user"
password := "barbaz"
@ -180,6 +192,7 @@ func TestUserNotFound(t *testing.T) {
func TestChangePassword(t *testing.T) {
t.Parallel()
s := setupPasswordSuite(t)
req, err := s.prepareForPasswordChange("user1", []byte("abc123"), constants.SecondFactorOff)
require.NoError(t, err)
@ -204,6 +217,7 @@ func TestChangePassword(t *testing.T) {
func TestChangePasswordWithOTP(t *testing.T) {
t.Parallel()
s := setupPasswordSuite(t)
req, err := s.prepareForPasswordChange("user2", []byte("abc123"), constants.SecondFactorOTP)
require.NoError(t, err)
@ -242,6 +256,7 @@ func TestChangePasswordWithOTP(t *testing.T) {
func TestServer_ChangePassword(t *testing.T) {
t.Parallel()
srv := newTestTLSServer(t)
mfa := configureForMFA(t, srv)
@ -310,6 +325,7 @@ func TestServer_ChangePassword(t *testing.T) {
func TestChangeUserAuthentication(t *testing.T) {
t.Parallel()
srv := newTestTLSServer(t)
ctx := context.Background()
@ -551,6 +567,7 @@ func TestChangeUserAuthentication(t *testing.T) {
func TestChangeUserAuthenticationWithErrors(t *testing.T) {
t.Parallel()
s := setupPasswordSuite(t)
ctx := context.Background()
authPreference, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{

File diff suppressed because one or more lines are too long

View file

@ -18,9 +18,10 @@ limitations under the License.
package test
import (
"testing"
"time"
"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
@ -30,10 +31,11 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshca"
"golang.org/x/crypto/ssh"
"github.com/gravitational/trace"
"github.com/google/go-cmp/cmp"
"github.com/jonboulle/clockwork"
"gopkg.in/check.v1"
"github.com/stretchr/testify/require"
)
type AuthSuite struct {
@ -42,24 +44,24 @@ type AuthSuite struct {
Clock clockwork.Clock
}
func (s *AuthSuite) GenerateKeypairEmptyPass(c *check.C) {
func (s *AuthSuite) GenerateKeypairEmptyPass(t *testing.T) {
priv, pub, err := s.Keygen()
c.Assert(err, check.IsNil)
require.NoError(t, err)
// make sure we can parse the private and public key
_, err = ssh.ParsePrivateKey(priv)
c.Assert(err, check.IsNil)
require.NoError(t, err)
_, _, _, _, err = ssh.ParseAuthorizedKey(pub)
c.Assert(err, check.IsNil)
require.NoError(t, err)
}
func (s *AuthSuite) GenerateHostCert(c *check.C) {
func (s *AuthSuite) GenerateHostCert(t *testing.T) {
priv, pub, err := s.Keygen()
c.Assert(err, check.IsNil)
require.NoError(t, err)
caSigner, err := ssh.ParsePrivateKey(priv)
c.Assert(err, check.IsNil)
require.NoError(t, err)
cert, err := s.A.GenerateHostCert(
services.HostCertParams{
@ -71,26 +73,26 @@ func (s *AuthSuite) GenerateHostCert(c *check.C) {
Role: types.RoleAdmin,
TTL: time.Hour,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
certificate, err := sshutils.ParseCertificate(cert)
c.Assert(err, check.IsNil)
require.NoError(t, err)
// Check the valid time is not more than 1 minute before the current time.
validAfter := time.Unix(int64(certificate.ValidAfter), 0)
c.Assert(validAfter.Unix(), check.Equals, s.Clock.Now().UTC().Add(-1*time.Minute).Unix())
require.Equal(t, validAfter.Unix(), s.Clock.Now().UTC().Add(-1*time.Minute).Unix())
// Check the valid time is not more than 1 hour after the current time.
validBefore := time.Unix(int64(certificate.ValidBefore), 0)
c.Assert(validBefore.Unix(), check.Equals, s.Clock.Now().UTC().Add(1*time.Hour).Unix())
require.Equal(t, validBefore.Unix(), s.Clock.Now().UTC().Add(1*time.Hour).Unix())
}
func (s *AuthSuite) GenerateUserCert(c *check.C) {
func (s *AuthSuite) GenerateUserCert(t *testing.T) {
priv, pub, err := s.Keygen()
c.Assert(err, check.IsNil)
require.NoError(t, err)
caSigner, err := ssh.ParsePrivateKey(priv)
c.Assert(err, check.IsNil)
require.NoError(t, err)
cert, err := s.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
@ -102,12 +104,12 @@ func (s *AuthSuite) GenerateUserCert(c *check.C) {
PermitPortForwarding: true,
CertificateFormat: constants.CertificateFormatStandard,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
// Check the valid time is not more than 1 minute before and 1 hour after
// the current time.
err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(1*time.Hour))
c.Assert(err, check.IsNil)
require.NoError(t, err)
cert, err = s.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
@ -119,9 +121,9 @@ func (s *AuthSuite) GenerateUserCert(c *check.C) {
PermitPortForwarding: true,
CertificateFormat: constants.CertificateFormatStandard,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration))
c.Assert(err, check.IsNil)
require.NoError(t, err)
_, err = s.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
@ -133,9 +135,9 @@ func (s *AuthSuite) GenerateUserCert(c *check.C) {
PermitPortForwarding: true,
CertificateFormat: constants.CertificateFormatStandard,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration))
c.Assert(err, check.IsNil)
require.NoError(t, err)
_, err = s.A.GenerateUserCert(services.UserCertParams{
CASigner: caSigner,
@ -147,7 +149,7 @@ func (s *AuthSuite) GenerateUserCert(c *check.C) {
PermitPortForwarding: true,
CertificateFormat: constants.CertificateFormatStandard,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
inRoles := []string{"role-1", "role-2"}
impersonator := "alice"
@ -163,15 +165,15 @@ func (s *AuthSuite) GenerateUserCert(c *check.C) {
CertificateFormat: constants.CertificateFormatStandard,
Roles: inRoles,
})
c.Assert(err, check.IsNil)
require.NoError(t, err)
parsedCert, err := sshutils.ParseCertificate(cert)
c.Assert(err, check.IsNil)
require.NoError(t, err)
outRoles, err := services.UnmarshalCertRoles(parsedCert.Extensions[teleport.CertExtensionTeleportRoles])
c.Assert(err, check.IsNil)
c.Assert(outRoles, check.DeepEquals, inRoles)
require.NoError(t, err)
require.Empty(t, cmp.Diff(outRoles, inRoles))
outImpersonator := parsedCert.Extensions[teleport.CertExtensionImpersonator]
c.Assert(outImpersonator, check.DeepEquals, impersonator)
require.Empty(t, cmp.Diff(outImpersonator, impersonator))
}
func checkCertExpiry(cert []byte, after, before time.Time) error {