Update profile credential loader to work with tsh v6.0. (#7142)

This commit is contained in:
Brian Joerger 2021-06-07 18:01:12 -07:00 committed by GitHub
parent 23c77d894c
commit a6ed4d3d22
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 78 additions and 30 deletions

View file

@ -128,34 +128,54 @@ func TestLoadKeyPair(t *testing.T) {
func TestLoadProfile(t *testing.T) {
t.Parallel()
profileName := "proxy.example.com"
t.Run("normal profile", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
writeProfile(t, &profile.Profile{
WebProxyAddr: profileName + ":3080",
SiteName: "example.com",
Username: "testUser",
Dir: dir,
}, false)
testProfileContents(t, dir, profileName)
})
// DELETE IN 8.0.0
t.Run("old profile", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
writeProfile(t, &profile.Profile{
WebProxyAddr: profileName + ":3080",
SiteName: "example.com",
Username: "testUser",
Dir: dir,
}, true)
testProfileContents(t, dir, profileName)
})
t.Run("non existent profile", func(t *testing.T) {
t.Parallel()
// Load non existent profile.
creds := LoadProfile("invalid_dir", "invalid_name")
_, err := creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
_, err = creds.Dialer(Config{})
require.Error(t, err)
})
}
func testProfileContents(t *testing.T, dir, name string) {
// Load expected tls.Config and ssh.ClientConfig.
expectedTLSConfig := getExpectedTLSConfig(t)
expectedSSHConfig := getExpectedSSHConfig(t)
// Write identity file to disk.
dir := t.TempDir()
name := "proxy.example.com"
p := &profile.Profile{
WebProxyAddr: "proxy.example.com:3080",
SiteName: "example.com",
Username: "testUser",
Dir: dir,
}
// Save profile and keys to disk.
require.NoError(t, p.SaveToDir(dir, true))
require.NoError(t, os.MkdirAll(p.KeyDir(), 0700))
require.NoError(t, os.MkdirAll(p.ProxyKeyDir(), 0700))
require.NoError(t, os.MkdirAll(p.SSHDir(), 0700))
require.NoError(t, ioutil.WriteFile(p.UserKeyPath(), keyPEM, 0600))
require.NoError(t, ioutil.WriteFile(p.TLSCertPath(), tlsCert, 0600))
require.NoError(t, ioutil.WriteFile(p.TLSCAsPath(), tlsCACert, 0600))
require.NoError(t, ioutil.WriteFile(p.SSHCertPath(), sshCert, 0600))
require.NoError(t, ioutil.WriteFile(p.KnownHostsPath(), sshCACert, 0600))
// Load profile from disk.
creds := LoadProfile(dir, name)
// Build tls.Config and compare to expected tls.Config.
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
@ -167,15 +187,25 @@ func TestLoadProfile(t *testing.T) {
// Build Dialer
_, err = creds.Dialer(Config{})
require.NoError(t, err)
}
// Load invalid profile.
creds = LoadProfile("invalid_dir", "invalid_name")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
_, err = creds.Dialer(Config{})
require.Error(t, err)
func writeProfile(t *testing.T, p *profile.Profile, oldSSHPath bool) {
// Save profile and keys to disk.
require.NoError(t, p.SaveToDir(p.Dir, true))
require.NoError(t, os.MkdirAll(p.KeyDir(), 0700))
require.NoError(t, os.MkdirAll(p.ProxyKeyDir(), 0700))
require.NoError(t, ioutil.WriteFile(p.UserKeyPath(), keyPEM, 0600))
require.NoError(t, ioutil.WriteFile(p.TLSCertPath(), tlsCert, 0600))
require.NoError(t, ioutil.WriteFile(p.TLSCAsPath(), tlsCACert, 0600))
require.NoError(t, ioutil.WriteFile(p.KnownHostsPath(), sshCACert, 0600))
// If oldSSHPath is specified, write the sshCert to the old ssh cert path.
// DELETE IN 8.0.0
if oldSSHPath {
require.NoError(t, ioutil.WriteFile(p.OldSSHCertPath(), sshCert, 0600))
return
}
require.NoError(t, os.MkdirAll(p.SSHDir(), 0700))
require.NoError(t, ioutil.WriteFile(p.SSHCertPath(), sshCert, 0600))
}
func getExpectedTLSConfig(t *testing.T) *tls.Config {

View file

@ -122,7 +122,13 @@ func (p *Profile) TLSConfig() (*tls.Config, error) {
func (p *Profile) SSHClientConfig() (*ssh.ClientConfig, error) {
cert, err := ioutil.ReadFile(p.SSHCertPath())
if err != nil {
return nil, trace.Wrap(err)
// Try reading SSHCert from old cert path, return original error otherwise
// DELETE IN 8.0.0
var err2 error
cert, err2 = ioutil.ReadFile(p.OldSSHCertPath())
if err2 != nil {
return nil, trace.Wrap(err)
}
}
key, err := ioutil.ReadFile(p.UserKeyPath())
@ -322,6 +328,12 @@ func (p *Profile) SSHCertPath() string {
return keypaths.SSHCertPath(p.Dir, p.Name(), p.Username, p.SiteName)
}
// OldSSHCertPath returns the old (before v6.1) path to the profile's ssh certificate.
// DELETE IN 8.0.0
func (p *Profile) OldSSHCertPath() string {
return keypaths.OldSSHCertPath(p.Dir, p.Name(), p.Username)
}
// KnownHostsPath returns the path to the profile's ssh certificate authorities.
func (p *Profile) KnownHostsPath() string {
return keypaths.KnownHostsPath(p.Dir)

View file

@ -148,6 +148,12 @@ func SSHCertPath(baseDir, proxy, username, cluster string) string {
return filepath.Join(SSHDir(baseDir, proxy, username), cluster+fileExtSSHCert)
}
// OldSSHCertPath returns the old (before v6.1) path to the profile's ssh certificate.
// DELETE IN 8.0.0
func OldSSHCertPath(baseDir, proxy, username string) string {
return filepath.Join(ProxyKeyDir(baseDir, proxy), username+fileExtSSHCert)
}
// AppDir returns the path to the user's app directory
// for the given proxy.
//