tsh scp to use target directory correctly (#5501)

* Fixes the scp logic to take target directory into account in sink mode.
Also expose channel error in scp client so the error is more visible to
the user. Old behavior will only output the 'exit code n' if anything
breaks.

Fixes https://github.com/gravitational/teleport/issues/5497.

* Silence 'wait: remote command exited without exit status or exit signal' error when interrupting the scp session. Leave a TODO to fix properly in a future PR

* Address review comments
This commit is contained in:
a-palchikov 2021-02-11 19:35:40 +01:00 committed by GitHub
parent 68bc78f10e
commit e65eac59b0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 331 additions and 188 deletions

View file

@ -1426,7 +1426,7 @@ func (tc *TeleportClient) ExecuteSCP(ctx context.Context, cmd scp.Command) (err
return trace.Wrap(err)
}
err = nodeClient.ExecuteSCP(cmd)
err = nodeClient.ExecuteSCP(ctx, cmd)
if err != nil {
// converts SSH error code to tc.ExitStatus
exitError, _ := trace.Unwrap(err).(*ssh.ExitError)
@ -1463,21 +1463,24 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag
}
defer proxyClient.Close()
var progressWriter io.Writer
if !quiet {
progressWriter = tc.Stdout
}
// helper function connects to the src/target node:
connectToNode := func(addr string) (*NodeClient, error) {
connectToNode := func(addr, hostLogin string) (*NodeClient, error) {
// determine which cluster we're connecting to:
siteInfo, err := proxyClient.currentCluster()
if err != nil {
return nil, trace.Wrap(err)
}
if hostLogin == "" {
hostLogin = tc.Config.HostLogin
}
return proxyClient.ConnectToNode(ctx,
NodeAddr{Addr: addr, Namespace: tc.Namespace, Cluster: siteInfo.Name},
tc.HostLogin, false)
}
var progressWriter io.Writer
if !quiet {
progressWriter = tc.Stdout
hostLogin, false)
}
// gets called to convert SSH error code to tc.ExitStatus
@ -1488,88 +1491,102 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag
}
return err
}
tpl := scp.Config{
User: tc.Username,
ProgressWriter: progressWriter,
Flags: flags,
}
var config *scpConfig
// upload:
if isRemoteDest(last) {
filesToUpload := args[:len(args)-1]
// If more than a single file were provided, scp must be in directory mode
// and the target on the remote host needs to be a directory.
var directoryMode bool
if len(filesToUpload) > 1 {
directoryMode = true
}
dest, err := scp.ParseSCPDestination(last)
config, err = tc.uploadConfig(ctx, tpl, port, args)
if err != nil {
return trace.Wrap(err)
}
if dest.Login != "" {
tc.HostLogin = dest.Login
}
addr := net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
client, err := connectToNode(addr)
if err != nil {
return trace.Wrap(err)
}
// copy everything except the last arg (that's destination)
for _, src := range filesToUpload {
scpConfig := scp.Config{
User: tc.Username,
ProgressWriter: progressWriter,
RemoteLocation: dest.Path,
Flags: flags,
}
scpConfig.Flags.Target = []string{src}
scpConfig.Flags.DirectoryMode = directoryMode
cmd, err := scp.CreateUploadCommand(scpConfig)
if err != nil {
return trace.Wrap(err)
}
err = client.ExecuteSCP(cmd)
if err != nil {
return onError(err)
}
}
} else {
// download:
src, err := scp.ParseSCPDestination(first)
config, err = tc.downloadConfig(ctx, tpl, port, args)
if err != nil {
return trace.Wrap(err)
}
addr := net.JoinHostPort(src.Host.Host(), strconv.Itoa(port))
if src.Login != "" {
tc.HostLogin = src.Login
}
client, err := connectToNode(addr)
if err != nil {
return trace.Wrap(err)
}
// copy everything except the last arg (that's destination)
for _, dest := range args[1:] {
scpConfig := scp.Config{
User: tc.Username,
Flags: flags,
RemoteLocation: src.Path,
ProgressWriter: progressWriter,
}
scpConfig.Flags.Target = []string{dest}
cmd, err := scp.CreateDownloadCommand(scpConfig)
if err != nil {
return trace.Wrap(err)
}
err = client.ExecuteSCP(cmd)
if err != nil {
return onError(err)
}
}
}
return nil
client, err := connectToNode(config.addr, config.hostLogin)
if err != nil {
return trace.Wrap(err)
}
return onError(client.ExecuteSCP(ctx, config.cmd))
}
func (tc *TeleportClient) uploadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) {
filesToUpload := args[:len(args)-1]
// copy everything except the last arg (the destination)
destPath := args[len(args)-1]
// If more than a single file were provided, scp must be in directory mode
// and the target on the remote host needs to be a directory.
var directoryMode bool
if len(filesToUpload) > 1 {
directoryMode = true
}
dest, addr, err := getSCPDestination(destPath, port)
if err != nil {
return nil, trace.Wrap(err)
}
tpl.RemoteLocation = dest.Path
tpl.Flags.Target = filesToUpload
tpl.Flags.DirectoryMode = directoryMode
cmd, err := scp.CreateUploadCommand(tpl)
if err != nil {
return nil, trace.Wrap(err)
}
return &scpConfig{
cmd: cmd,
addr: addr,
hostLogin: dest.Login,
}, nil
}
func (tc *TeleportClient) downloadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) {
src, addr, err := getSCPDestination(args[0], port)
if err != nil {
return nil, trace.Wrap(err)
}
tpl.RemoteLocation = src.Path
tpl.Flags.Target = args[1:]
cmd, err := scp.CreateDownloadCommand(tpl)
if err != nil {
return nil, trace.Wrap(err)
}
return &scpConfig{
cmd: cmd,
addr: addr,
hostLogin: src.Login,
}, nil
}
type scpConfig struct {
cmd scp.Command
addr string
hostLogin string
}
func getSCPDestination(target string, port int) (dest *scp.Destination, addr string, err error) {
dest, err = scp.ParseSCPDestination(target)
if err != nil {
return nil, "", trace.Wrap(err)
}
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
return dest, addr, nil
}
func isRemoteDest(name string) bool {

View file

@ -20,9 +20,11 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"os"
"strconv"
"strings"
"time"
@ -854,7 +856,7 @@ func (proxy *ProxyClient) Close() error {
// ExecuteSCP runs remote scp command(shellCmd) on the remote server and
// runs local scp handler using SCP Command
func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error {
shellCmd, err := cmd.GetRemoteShellCmd()
if err != nil {
return trace.Wrap(err)
@ -876,6 +878,14 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
return trace.Wrap(err)
}
// Stream scp's stderr so tsh gets the verbose remote error
// if the command fails
stderr, err := s.StderrPipe()
if err != nil {
return trace.Wrap(err)
}
go io.Copy(os.Stderr, stderr)
ch := utils.NewPipeNetConn(
stdout,
stdin,
@ -884,18 +894,40 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
&net.IPAddr{},
)
closeC := make(chan error, 1)
execC := make(chan error, 1)
go func() {
err := cmd.Execute(ch)
if err != nil {
log.Error(err)
if err != nil && !trace.IsEOF(err) {
log.WithError(err).Warn("Failed to execute SCP command.")
}
stdin.Close()
closeC <- err
execC <- err
}()
runErr := s.Run(shellCmd)
err = <-closeC
runC := make(chan error, 1)
go func() {
err := s.Run(shellCmd)
if err != nil && errors.Is(err, &ssh.ExitMissingError{}) {
// TODO(dmitri): currently, if the session is aborted with (*session).Close,
// the remote side cannot send exit-status and this error results.
// To abort the session properly, Teleport needs to support `signal` request
err = nil
}
runC <- err
}()
var runErr error
select {
case <-ctx.Done():
if err := s.Close(); err != nil {
log.WithError(err).Debug("Failed to close the SSH session.")
}
err, runErr = <-execC, <-runC
case err = <-execC:
runErr = <-runC
case runErr = <-runC:
err = <-execC
}
if runErr != nil && (err == nil || trace.IsEOF(err)) {
err = runErr

View file

@ -639,7 +639,7 @@ func (c *ServerContext) SendExecResult(r ExecResult) {
select {
case c.ExecResultCh <- r:
default:
log.Infof("blocked on sending exec result %v", r)
c.Infof("Blocked on sending exec result %v.", r)
}
}
@ -649,7 +649,7 @@ func (c *ServerContext) SendSubsystemResult(r SubsystemResult) {
select {
case c.SubsystemResultCh <- r:
default:
c.Infof("blocked on sending subsystem result")
c.Info("Blocked on sending subsystem result.")
}
}

View file

@ -149,20 +149,14 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
// Connect stdout and stderr to the channel so the user can interact with
// the command.
e.Cmd.Stderr = channel.Stderr()
e.Cmd.Stdout = channel
e.Cmd.Stderr = io.MultiWriter(os.Stderr, channel.Stderr())
e.Cmd.Stdout = io.MultiWriter(os.Stdout, channel)
// Copy from the channel (client input) into stdin of the process.
inputWriter, err := e.Cmd.StdinPipe()
if err != nil {
return nil, trace.Wrap(err)
}
go func() {
if _, err := io.Copy(inputWriter, channel); err != nil {
e.Ctx.Warningf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err)
}
inputWriter.Close()
}()
// Start the command.
err = e.Cmd.Start()
@ -178,6 +172,13 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
}, trace.ConvertSystemError(err)
}
go func() {
if _, err := io.Copy(inputWriter, channel); err != nil {
e.Ctx.Warnf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err)
}
inputWriter.Close()
}()
e.Ctx.Infof("Started local command execution: %q", e.Command)
return nil, nil
@ -186,7 +187,7 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
// Wait will block while the command executes.
func (e *localExec) Wait() *ExecResult {
if e.Cmd.Process == nil {
e.Ctx.Errorf("no process")
e.Ctx.Error("No process.")
}
// Block until the command is finished executing.

View file

@ -90,6 +90,8 @@ type Config struct {
// RunOnServer is low level API flag that indicates that
// this command will be run on the server
RunOnServer bool
// Log optionally specifies the logger
Log log.FieldLogger
}
// Command is an API that describes command operations
@ -167,10 +169,25 @@ func CreateUploadCommand(cfg Config) (Command, error) {
// CheckAndSetDefaults checks and sets default values
func (c *Config) CheckAndSetDefaults() error {
logger := c.Log
if logger == nil {
logger = log.StandardLogger()
}
c.Log = logger.WithFields(log.Fields{
trace.Component: "SCP",
trace.ComponentFields: log.Fields{
"LocalAddr": c.Flags.LocalAddr,
"RemoteAddr": c.Flags.RemoteAddr,
"Target": c.Flags.Target,
"PreserveAttrs": c.Flags.PreserveAttrs,
"User": c.User,
"RunOnServer": c.RunOnServer,
"RemoteLocation": c.RemoteLocation,
},
})
if c.FileSystem == nil {
c.FileSystem = &localFileSystem{}
}
if c.User == "" {
return trace.BadParameter("missing User parameter")
}
@ -186,31 +203,17 @@ func CreateCommand(cfg Config) (Command, error) {
return nil, trace.Wrap(err)
}
cmd := command{
return &command{
Config: cfg,
}
cmd.log = log.WithFields(log.Fields{
trace.Component: "SCP",
trace.ComponentFields: log.Fields{
"LocalAddr": cfg.Flags.LocalAddr,
"RemoteAddr": cfg.Flags.RemoteAddr,
"Target": cfg.Flags.Target,
"PreserveAttrs": cfg.Flags.PreserveAttrs,
"User": cfg.User,
"RunOnServer": cfg.RunOnServer,
"RemoteLocation": cfg.RemoteLocation,
},
})
return &cmd, nil
log: cfg.Log,
}, nil
}
// Command mimics behavior of SCP command line tool
// to teleport can pretend it launches real SCP behind the scenes
type command struct {
Config
log *log.Entry
log log.FieldLogger
}
// Execute implements SSH file copy (SCP). It is called on both tsh (client)
@ -333,7 +336,7 @@ func (cmd *command) sendDir(r *reader, ch io.ReadWriter, fileInfo FileInfo) erro
if _, err = fmt.Fprintf(ch, "E\n"); err != nil {
return trace.Wrap(err)
}
return r.read()
return trace.Wrap(r.read())
}
func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) error {
@ -375,7 +378,7 @@ func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) err
func (cmd *command) sendErr(ch io.Writer, err error) {
out := fmt.Sprintf("%c%s\n", byte(ErrByte), err)
if _, err := ch.Write([]byte(out)); err != nil {
cmd.log.Debugf("failed sending SCP error message to the remote side: %v", err)
cmd.log.Debugf("Failed sending SCP error message to the remote side: %v.", err)
}
}
@ -386,23 +389,25 @@ func (cmd *command) serveSink(ch io.ReadWriter) error {
// directory.
if cmd.Flags.DirectoryMode {
if len(cmd.Flags.Target) != 1 {
return trace.BadParameter("in directory mode, only single upload target is allowed but %v provided", len(cmd.Flags.Target))
return trace.BadParameter("in directory mode, only single upload target is allowed but %q provided",
cmd.Flags.Target)
}
fi, err := os.Stat(cmd.Flags.Target[0])
if err != nil {
return trace.Wrap(err)
}
if mode := fi.Mode(); !mode.IsDir() {
if !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) {
return trace.BadParameter("target path must be a directory")
}
}
rootDir := localDir
if cmd.hasTargetDir() {
rootDir = newPathFromDir(cmd.Flags.Target[0])
}
if err := sendOK(ch); err != nil {
return trace.Wrap(err)
}
var st state
st.path = localDir
st.path = rootDir
var b [1]byte
scanner := bufio.NewScanner(ch)
for {
@ -478,14 +483,13 @@ func (cmd *command) processCommand(ch io.ReadWriter, st *state, b byte, line str
func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) error {
cmd.log.Debugf("scp.receiveFile(%v): %v", cmd.Flags.Target, fc.Name)
// if the destination path is a folder, we should save the file to that folder, but
// only if 'recursive' is set
path := cmd.Flags.Target[0]
if cmd.Flags.Recursive || cmd.FileSystem.IsDir(path) {
path = st.makePath(fc.Name)
// Unless target specifies a file, use the file name from the command
filename := fc.Name
if !cmd.Flags.Recursive && !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) {
filename = cmd.Flags.Target[0]
}
path := st.makePath(filename)
writer, err := cmd.FileSystem.CreateFile(path, fc.Length)
if err != nil {
return trace.Wrap(err)
@ -504,7 +508,6 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro
n, err := io.CopyN(writer, ch, int64(fc.Length))
if err != nil {
cmd.log.Error(err)
return trace.Wrap(err)
}
@ -528,15 +531,9 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro
func (cmd *command) receiveDir(st *state, fc newFileCmd, ch io.ReadWriter) error {
cmd.log.Debugf("scp.receiveDir(%v): %v", cmd.Flags.Target, fc.Name)
targetDir := cmd.Flags.Target[0]
// copying into an existing directory? append to it:
if cmd.FileSystem.IsDir(targetDir) {
targetDir = st.makePath(fc.Name)
}
st.push(fc.Name, st.stat)
err := cmd.FileSystem.MkDir(targetDir, int(fc.Mode))
err := cmd.FileSystem.MkDir(st.path.join(), int(fc.Mode))
if err != nil {
return trace.ConvertSystemError(err)
}
@ -599,6 +596,14 @@ func (cmd *command) updateDirTimes(path pathSegments) error {
return nil
}
func (cmd *command) hasTargetDir() bool {
return len(cmd.Flags.Target) != 0 && cmd.FileSystem.IsDir(cmd.Flags.Target[0])
}
func (r newFileCmd) String() string {
return fmt.Sprintf("newFileCmd(mode=%o,len=%d,name=%v)", r.Mode, r.Length, r.Name)
}
type newFileCmd struct {
Mode int64
Length uint64
@ -675,15 +680,19 @@ type state struct {
stat *mtimeCmd
}
func (r pathSegments) join() string {
func (r pathSegments) join(elems ...string) string {
path := make([]string, 0, len(r))
for _, s := range r {
path = append(path, s.dir)
}
return filepath.Join(path...)
return filepath.Join(append(path, elems...)...)
}
var localDir = pathSegments{{dir: "."}}
var localDir = newPathFromDir(".")
func newPathFromDir(dir string) pathSegments {
return pathSegments{{dir: dir}}
}
type pathSegments []pathSegment
@ -710,7 +719,7 @@ func (st *state) pop() pathSegments {
}
func (st *state) makePath(filename string) string {
return filepath.Join(st.path.join(), filename)
return st.path.join(filename)
}
func newReader(r io.Reader) *reader {

View file

@ -30,6 +30,7 @@ import (
"time"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"
"github.com/google/go-cmp/cmp"
@ -38,7 +39,7 @@ import (
)
func TestHTTPSendFile(t *testing.T) {
outDir := tempDir(t)
outDir := t.TempDir()
expectedBytes := []byte("hello")
buf := bytes.NewReader(expectedBytes)
@ -65,7 +66,7 @@ func TestHTTPSendFile(t *testing.T) {
}
func TestHTTPReceiveFile(t *testing.T) {
source := filepath.Join(tempDir(t), "target")
source := filepath.Join(t.TempDir(), "target")
contents := []byte("hello, file contents!")
err := ioutil.WriteFile(source, contents, 0666)
@ -112,7 +113,7 @@ func TestSend(t *testing.T) {
desc: "regular file preserving the attributes",
config: newSourceConfig("file", Flags{PreserveAttrs: true}),
args: args("-v", "-t", "-p"),
fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")),
fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")),
},
{
desc: "directory preserving the attributes",
@ -121,10 +122,10 @@ func TestSend(t *testing.T) {
fs: newTestFS(
logger,
// Use timestamps extending backwards to test time application
newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
newDir("dir/dir2", dirModtime, dirAtime,
newFile("dir/dir2/file2", modtime, atime, "file2 contents")),
newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
newDirTimes("dir/dir2", dirModtime, dirAtime,
newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")),
),
),
},
@ -136,7 +137,7 @@ func TestSend(t *testing.T) {
cmd, err := CreateCommand(tt.config)
require.NoError(t, err)
targetDir := tempDir(t)
targetDir := t.TempDir()
target := filepath.Join(targetDir, tt.config.Flags.Target[0])
args := append(tt.args, target)
@ -153,7 +154,7 @@ func TestSend(t *testing.T) {
fs := newEmptyTestFS(logger)
fromOS(t, targetDir, &fs)
validateSCP(t, fs, tt.fs)
validateSCPTimes(t, fs, tt.fs)
validateSCPContents(t, fs, tt.fs)
})
}
@ -177,7 +178,7 @@ func TestReceive(t *testing.T) {
desc: "regular file preserving the attributes",
config: newTargetConfig("file", Flags{PreserveAttrs: true}),
args: args("-v", "-f", "-p"),
fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")),
fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")),
},
{
desc: "directory preserving the attributes",
@ -186,10 +187,10 @@ func TestReceive(t *testing.T) {
fs: newTestFS(
logger,
// Use timestamps extending backwards to test time application
newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
newDir("dir/dir2", dirModtime, dirAtime,
newFile("dir/dir2/file2", modtime, atime, "file2 contents")),
newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
newDirTimes("dir/dir2", dirModtime, dirAtime,
newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")),
),
),
},
@ -201,7 +202,7 @@ func TestReceive(t *testing.T) {
cmd, err := CreateCommand(tt.config)
require.NoError(t, err)
sourceDir := tempDir(t)
sourceDir := t.TempDir()
source := filepath.Join(sourceDir, tt.config.Flags.Target[0])
args := append(tt.args, source)
@ -209,8 +210,7 @@ func TestReceive(t *testing.T) {
err = runSCP(cmd, args...)
require.Regexp(t, ".*No such file or directory", err)
fs := newEmptyTestFS(logger)
tt.config.FileSystem = fs
tt.config.FileSystem = newEmptyTestFS(logger)
cmd, err = CreateCommand(tt.config)
require.NoError(t, err)
@ -221,12 +221,58 @@ func TestReceive(t *testing.T) {
err = runSCP(cmd, args...)
require.NoError(t, err)
validateSCP(t, tt.fs, fs)
validateSCPContents(t, tt.fs, fs)
validateSCPTimes(t, tt.fs, tt.config.FileSystem)
validateSCPContents(t, tt.fs, tt.config.FileSystem)
})
}
}
// TestReceiveIntoExistingDirectory validates that the target remote directory
// is respected during copy.
//
// See https://github.com/gravitational/teleport/issues/5497
func TestReceiveIntoExistingDirectory(t *testing.T) {
utils.InitLoggerForTests(testing.Verbose())
logger := logrus.WithField("test", t.Name())
config := newTargetConfigWithFS("dir",
Flags{PreserveAttrs: true, Recursive: true},
newTestFS(logger, newDir("dir")),
)
sourceFS := newTestFS(
logger,
newDir("dir",
newFile("dir/file", "file contents"),
newDir("dir/dir2",
newFile("dir/dir2/file2", "file2 contents")),
),
)
expectedFS := newTestFS(
logger,
// Source is copied into an existing directory
newDir("dir/dir",
newFile("dir/dir/file", "file contents"),
newDir("dir/dir/dir2",
newFile("dir/dir/dir2/file2", "file2 contents")),
),
)
sourceDir := t.TempDir()
source := filepath.Join(sourceDir, config.Flags.Target[0])
args := append(args("-v", "-f", "-r", "-p"), source)
cmd, err := CreateCommand(config)
require.NoError(t, err)
writeData(t, sourceDir, sourceFS)
writeFileTimes(t, sourceDir, sourceFS)
err = runSCP(cmd, args...)
require.NoError(t, err)
validateSCP(t, expectedFS, config.FileSystem)
validateSCPContents(t, expectedFS, config.FileSystem)
}
func TestInvalidDir(t *testing.T) {
t.Parallel()
@ -295,7 +341,7 @@ func TestInvalidDir(t *testing.T) {
// directory.
func TestVerifyDir(t *testing.T) {
// Create temporary directory with a file "target" in it.
dir := tempDir(t)
dir := t.TempDir()
target := filepath.Join(dir, "target")
err := ioutil.WriteFile(target, []byte{}, 0666)
require.NoError(t, err)
@ -492,11 +538,24 @@ func validateSCPContents(t *testing.T, expected testFS, actual FileSystem) {
}
// validateSCP verifies that the specified pair of FileSystems match.
// FileSystem match if their contents match incl. access/modification times
func validateSCP(t *testing.T, expected testFS, actual FileSystem) {
for path, fileinfo := range expected.fs {
targetFileinfo, err := actual.GetFileInfo(path)
require.NoError(t, err)
require.NoError(t, err, "expected %v", path)
if fileinfo.IsDir() {
require.True(t, targetFileinfo.IsDir())
} else {
require.True(t, targetFileinfo.GetModePerm().IsRegular())
}
}
}
// validateSCPTimes verifies that the specified pair of FileSystems match.
// FileSystem match if their contents match incl. access/modification times
func validateSCPTimes(t *testing.T, expected testFS, actual FileSystem) {
for path, fileinfo := range expected.fs {
targetFileinfo, err := actual.GetFileInfo(path)
require.NoError(t, err, "expected %v", path)
if fileinfo.IsDir() {
require.True(t, targetFileinfo.IsDir())
} else {
@ -691,13 +750,6 @@ type nopReadCloser struct {
var errMissingFile = fmt.Errorf("no such file or directory")
func tempDir(t *testing.T) (dir string) {
path, err := ioutil.TempDir("", "test")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(path) })
return path
}
func newSourceConfig(path string, flags Flags) Config {
flags.Source = true
flags.Target = []string{path}
@ -707,6 +759,12 @@ func newSourceConfig(path string, flags Flags) Config {
}
}
func newTargetConfigWithFS(path string, flags Flags, fs testFS) Config {
config := newTargetConfig(path, flags)
config.FileSystem = &fs
return config
}
func newTargetConfig(path string, flags Flags) Config {
flags.Sink = true
flags.Target = []string{path}
@ -716,7 +774,25 @@ func newTargetConfig(path string, flags Flags) Config {
}
}
func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo {
func newDir(name string, ents ...*testFileInfo) *testFileInfo {
return &testFileInfo{
path: name,
ents: ents,
dir: true,
perms: 0755,
}
}
func newFile(name string, contents string) *testFileInfo {
return &testFileInfo{
path: name,
perms: 0666,
size: int64(len(contents)),
contents: bytes.NewBufferString(contents),
}
}
func newDirTimes(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo {
return &testFileInfo{
path: name,
ents: ents,
@ -727,7 +803,7 @@ func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testF
}
}
func newFile(name string, modtime, atime time.Time, contents string) *testFileInfo {
func newFileTimes(name string, modtime, atime time.Time, contents string) *testFileInfo {
return &testFileInfo{
path: name,
modtime: modtime,

View file

@ -75,14 +75,23 @@ func (r *RemoveDirCloser) Close() error {
}
// IsDir is a helper function to quickly check if a given path is a valid directory
func IsDir(dirPath string) bool {
fi, err := os.Stat(dirPath)
func IsDir(path string) bool {
fi, err := os.Stat(path)
if err == nil {
return fi.IsDir()
}
return false
}
// IsFile is a convenience helper to check if the given path is a regular file
func IsFile(path string) bool {
fi, err := os.Stat(path)
if err == nil {
return fi.Mode().IsRegular()
}
return false
}
// NormalizePath normalises path, evaluating symlinks and converting local
// paths to absolute
func NormalizePath(path string) (string, error) {

View file

@ -1410,18 +1410,17 @@ func onSCP(cf *CLIConf) error {
PreserveAttrs: cf.PreserveAttrs,
}
err = client.RetryWithRelogin(cf.Context, tc, func() error {
return tc.SCP(context.TODO(), cf.CopySpec, int(cf.NodePort), flags, cf.Quiet)
return tc.SCP(cf.Context, cf.CopySpec, int(cf.NodePort), flags, cf.Quiet)
})
if err != nil {
// exit with the same exit status as the failed command:
if tc.ExitStatus != 0 {
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err))
os.Exit(tc.ExitStatus)
} else {
return trace.Wrap(err)
}
if err == nil {
return nil
}
return nil
// exit with the same exit status as the failed command:
if tc.ExitStatus != 0 {
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err))
os.Exit(tc.ExitStatus)
}
return trace.Wrap(err)
}
// makeClient takes the command-line configuration and constructs & returns