diff --git a/lib/client/api.go b/lib/client/api.go index 64816a9fc9a..81c48a3afc3 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1426,7 +1426,7 @@ func (tc *TeleportClient) ExecuteSCP(ctx context.Context, cmd scp.Command) (err return trace.Wrap(err) } - err = nodeClient.ExecuteSCP(cmd) + err = nodeClient.ExecuteSCP(ctx, cmd) if err != nil { // converts SSH error code to tc.ExitStatus exitError, _ := trace.Unwrap(err).(*ssh.ExitError) @@ -1463,21 +1463,24 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag } defer proxyClient.Close() + var progressWriter io.Writer + if !quiet { + progressWriter = tc.Stdout + } + // helper function connects to the src/target node: - connectToNode := func(addr string) (*NodeClient, error) { + connectToNode := func(addr, hostLogin string) (*NodeClient, error) { // determine which cluster we're connecting to: siteInfo, err := proxyClient.currentCluster() if err != nil { return nil, trace.Wrap(err) } + if hostLogin == "" { + hostLogin = tc.Config.HostLogin + } return proxyClient.ConnectToNode(ctx, NodeAddr{Addr: addr, Namespace: tc.Namespace, Cluster: siteInfo.Name}, - tc.HostLogin, false) - } - - var progressWriter io.Writer - if !quiet { - progressWriter = tc.Stdout + hostLogin, false) } // gets called to convert SSH error code to tc.ExitStatus @@ -1488,88 +1491,102 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag } return err } + + tpl := scp.Config{ + User: tc.Username, + ProgressWriter: progressWriter, + Flags: flags, + } + + var config *scpConfig // upload: if isRemoteDest(last) { - filesToUpload := args[:len(args)-1] - - // If more than a single file were provided, scp must be in directory mode - // and the target on the remote host needs to be a directory. - var directoryMode bool - if len(filesToUpload) > 1 { - directoryMode = true - } - - dest, err := scp.ParseSCPDestination(last) + config, err = tc.uploadConfig(ctx, tpl, port, args) if err != nil { return trace.Wrap(err) } - if dest.Login != "" { - tc.HostLogin = dest.Login - } - addr := net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port)) - - client, err := connectToNode(addr) - if err != nil { - return trace.Wrap(err) - } - - // copy everything except the last arg (that's destination) - for _, src := range filesToUpload { - scpConfig := scp.Config{ - User: tc.Username, - ProgressWriter: progressWriter, - RemoteLocation: dest.Path, - Flags: flags, - } - scpConfig.Flags.Target = []string{src} - scpConfig.Flags.DirectoryMode = directoryMode - - cmd, err := scp.CreateUploadCommand(scpConfig) - if err != nil { - return trace.Wrap(err) - } - - err = client.ExecuteSCP(cmd) - if err != nil { - return onError(err) - } - } } else { - // download: - src, err := scp.ParseSCPDestination(first) + config, err = tc.downloadConfig(ctx, tpl, port, args) if err != nil { return trace.Wrap(err) } - addr := net.JoinHostPort(src.Host.Host(), strconv.Itoa(port)) - if src.Login != "" { - tc.HostLogin = src.Login - } - client, err := connectToNode(addr) - if err != nil { - return trace.Wrap(err) - } - // copy everything except the last arg (that's destination) - for _, dest := range args[1:] { - scpConfig := scp.Config{ - User: tc.Username, - Flags: flags, - RemoteLocation: src.Path, - ProgressWriter: progressWriter, - } - scpConfig.Flags.Target = []string{dest} - - cmd, err := scp.CreateDownloadCommand(scpConfig) - if err != nil { - return trace.Wrap(err) - } - - err = client.ExecuteSCP(cmd) - if err != nil { - return onError(err) - } - } } - return nil + + client, err := connectToNode(config.addr, config.hostLogin) + if err != nil { + return trace.Wrap(err) + } + + return onError(client.ExecuteSCP(ctx, config.cmd)) +} + +func (tc *TeleportClient) uploadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) { + filesToUpload := args[:len(args)-1] + // copy everything except the last arg (the destination) + destPath := args[len(args)-1] + + // If more than a single file were provided, scp must be in directory mode + // and the target on the remote host needs to be a directory. + var directoryMode bool + if len(filesToUpload) > 1 { + directoryMode = true + } + + dest, addr, err := getSCPDestination(destPath, port) + if err != nil { + return nil, trace.Wrap(err) + } + + tpl.RemoteLocation = dest.Path + tpl.Flags.Target = filesToUpload + tpl.Flags.DirectoryMode = directoryMode + + cmd, err := scp.CreateUploadCommand(tpl) + if err != nil { + return nil, trace.Wrap(err) + } + + return &scpConfig{ + cmd: cmd, + addr: addr, + hostLogin: dest.Login, + }, nil +} + +func (tc *TeleportClient) downloadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) { + src, addr, err := getSCPDestination(args[0], port) + if err != nil { + return nil, trace.Wrap(err) + } + + tpl.RemoteLocation = src.Path + tpl.Flags.Target = args[1:] + + cmd, err := scp.CreateDownloadCommand(tpl) + if err != nil { + return nil, trace.Wrap(err) + } + + return &scpConfig{ + cmd: cmd, + addr: addr, + hostLogin: src.Login, + }, nil +} + +type scpConfig struct { + cmd scp.Command + addr string + hostLogin string +} + +func getSCPDestination(target string, port int) (dest *scp.Destination, addr string, err error) { + dest, err = scp.ParseSCPDestination(target) + if err != nil { + return nil, "", trace.Wrap(err) + } + addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port)) + return dest, addr, nil } func isRemoteDest(name string) bool { diff --git a/lib/client/client.go b/lib/client/client.go index c20ff10e6de..0ed56d5edc7 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -20,9 +20,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "io/ioutil" "net" + "os" "strconv" "strings" "time" @@ -854,7 +856,7 @@ func (proxy *ProxyClient) Close() error { // ExecuteSCP runs remote scp command(shellCmd) on the remote server and // runs local scp handler using SCP Command -func (c *NodeClient) ExecuteSCP(cmd scp.Command) error { +func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error { shellCmd, err := cmd.GetRemoteShellCmd() if err != nil { return trace.Wrap(err) @@ -876,6 +878,14 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error { return trace.Wrap(err) } + // Stream scp's stderr so tsh gets the verbose remote error + // if the command fails + stderr, err := s.StderrPipe() + if err != nil { + return trace.Wrap(err) + } + go io.Copy(os.Stderr, stderr) + ch := utils.NewPipeNetConn( stdout, stdin, @@ -884,18 +894,40 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error { &net.IPAddr{}, ) - closeC := make(chan error, 1) + execC := make(chan error, 1) go func() { err := cmd.Execute(ch) - if err != nil { - log.Error(err) + if err != nil && !trace.IsEOF(err) { + log.WithError(err).Warn("Failed to execute SCP command.") } stdin.Close() - closeC <- err + execC <- err }() - runErr := s.Run(shellCmd) - err = <-closeC + runC := make(chan error, 1) + go func() { + err := s.Run(shellCmd) + if err != nil && errors.Is(err, &ssh.ExitMissingError{}) { + // TODO(dmitri): currently, if the session is aborted with (*session).Close, + // the remote side cannot send exit-status and this error results. + // To abort the session properly, Teleport needs to support `signal` request + err = nil + } + runC <- err + }() + + var runErr error + select { + case <-ctx.Done(): + if err := s.Close(); err != nil { + log.WithError(err).Debug("Failed to close the SSH session.") + } + err, runErr = <-execC, <-runC + case err = <-execC: + runErr = <-runC + case runErr = <-runC: + err = <-execC + } if runErr != nil && (err == nil || trace.IsEOF(err)) { err = runErr diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 9a8207262d7..bfb1ae18981 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -639,7 +639,7 @@ func (c *ServerContext) SendExecResult(r ExecResult) { select { case c.ExecResultCh <- r: default: - log.Infof("blocked on sending exec result %v", r) + c.Infof("Blocked on sending exec result %v.", r) } } @@ -649,7 +649,7 @@ func (c *ServerContext) SendSubsystemResult(r SubsystemResult) { select { case c.SubsystemResultCh <- r: default: - c.Infof("blocked on sending subsystem result") + c.Info("Blocked on sending subsystem result.") } } diff --git a/lib/srv/exec.go b/lib/srv/exec.go index 4c078f9dea3..e8cf6232d44 100644 --- a/lib/srv/exec.go +++ b/lib/srv/exec.go @@ -149,20 +149,14 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) { // Connect stdout and stderr to the channel so the user can interact with // the command. - e.Cmd.Stderr = channel.Stderr() - e.Cmd.Stdout = channel + e.Cmd.Stderr = io.MultiWriter(os.Stderr, channel.Stderr()) + e.Cmd.Stdout = io.MultiWriter(os.Stdout, channel) // Copy from the channel (client input) into stdin of the process. inputWriter, err := e.Cmd.StdinPipe() if err != nil { return nil, trace.Wrap(err) } - go func() { - if _, err := io.Copy(inputWriter, channel); err != nil { - e.Ctx.Warningf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err) - } - inputWriter.Close() - }() // Start the command. err = e.Cmd.Start() @@ -178,6 +172,13 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) { }, trace.ConvertSystemError(err) } + go func() { + if _, err := io.Copy(inputWriter, channel); err != nil { + e.Ctx.Warnf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err) + } + inputWriter.Close() + }() + e.Ctx.Infof("Started local command execution: %q", e.Command) return nil, nil @@ -186,7 +187,7 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) { // Wait will block while the command executes. func (e *localExec) Wait() *ExecResult { if e.Cmd.Process == nil { - e.Ctx.Errorf("no process") + e.Ctx.Error("No process.") } // Block until the command is finished executing. diff --git a/lib/sshutils/scp/scp.go b/lib/sshutils/scp/scp.go index 46ed5a914f8..3e91bff5c74 100644 --- a/lib/sshutils/scp/scp.go +++ b/lib/sshutils/scp/scp.go @@ -90,6 +90,8 @@ type Config struct { // RunOnServer is low level API flag that indicates that // this command will be run on the server RunOnServer bool + // Log optionally specifies the logger + Log log.FieldLogger } // Command is an API that describes command operations @@ -167,10 +169,25 @@ func CreateUploadCommand(cfg Config) (Command, error) { // CheckAndSetDefaults checks and sets default values func (c *Config) CheckAndSetDefaults() error { + logger := c.Log + if logger == nil { + logger = log.StandardLogger() + } + c.Log = logger.WithFields(log.Fields{ + trace.Component: "SCP", + trace.ComponentFields: log.Fields{ + "LocalAddr": c.Flags.LocalAddr, + "RemoteAddr": c.Flags.RemoteAddr, + "Target": c.Flags.Target, + "PreserveAttrs": c.Flags.PreserveAttrs, + "User": c.User, + "RunOnServer": c.RunOnServer, + "RemoteLocation": c.RemoteLocation, + }, + }) if c.FileSystem == nil { c.FileSystem = &localFileSystem{} } - if c.User == "" { return trace.BadParameter("missing User parameter") } @@ -186,31 +203,17 @@ func CreateCommand(cfg Config) (Command, error) { return nil, trace.Wrap(err) } - cmd := command{ + return &command{ Config: cfg, - } - - cmd.log = log.WithFields(log.Fields{ - trace.Component: "SCP", - trace.ComponentFields: log.Fields{ - "LocalAddr": cfg.Flags.LocalAddr, - "RemoteAddr": cfg.Flags.RemoteAddr, - "Target": cfg.Flags.Target, - "PreserveAttrs": cfg.Flags.PreserveAttrs, - "User": cfg.User, - "RunOnServer": cfg.RunOnServer, - "RemoteLocation": cfg.RemoteLocation, - }, - }) - - return &cmd, nil + log: cfg.Log, + }, nil } // Command mimics behavior of SCP command line tool // to teleport can pretend it launches real SCP behind the scenes type command struct { Config - log *log.Entry + log log.FieldLogger } // Execute implements SSH file copy (SCP). It is called on both tsh (client) @@ -333,7 +336,7 @@ func (cmd *command) sendDir(r *reader, ch io.ReadWriter, fileInfo FileInfo) erro if _, err = fmt.Fprintf(ch, "E\n"); err != nil { return trace.Wrap(err) } - return r.read() + return trace.Wrap(r.read()) } func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) error { @@ -375,7 +378,7 @@ func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) err func (cmd *command) sendErr(ch io.Writer, err error) { out := fmt.Sprintf("%c%s\n", byte(ErrByte), err) if _, err := ch.Write([]byte(out)); err != nil { - cmd.log.Debugf("failed sending SCP error message to the remote side: %v", err) + cmd.log.Debugf("Failed sending SCP error message to the remote side: %v.", err) } } @@ -386,23 +389,25 @@ func (cmd *command) serveSink(ch io.ReadWriter) error { // directory. if cmd.Flags.DirectoryMode { if len(cmd.Flags.Target) != 1 { - return trace.BadParameter("in directory mode, only single upload target is allowed but %v provided", len(cmd.Flags.Target)) + return trace.BadParameter("in directory mode, only single upload target is allowed but %q provided", + cmd.Flags.Target) } - - fi, err := os.Stat(cmd.Flags.Target[0]) - if err != nil { - return trace.Wrap(err) - } - if mode := fi.Mode(); !mode.IsDir() { + if !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) { return trace.BadParameter("target path must be a directory") } } + rootDir := localDir + if cmd.hasTargetDir() { + rootDir = newPathFromDir(cmd.Flags.Target[0]) + } + if err := sendOK(ch); err != nil { return trace.Wrap(err) } + var st state - st.path = localDir + st.path = rootDir var b [1]byte scanner := bufio.NewScanner(ch) for { @@ -478,14 +483,13 @@ func (cmd *command) processCommand(ch io.ReadWriter, st *state, b byte, line str func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) error { cmd.log.Debugf("scp.receiveFile(%v): %v", cmd.Flags.Target, fc.Name) - // if the destination path is a folder, we should save the file to that folder, but - // only if 'recursive' is set - - path := cmd.Flags.Target[0] - if cmd.Flags.Recursive || cmd.FileSystem.IsDir(path) { - path = st.makePath(fc.Name) + // Unless target specifies a file, use the file name from the command + filename := fc.Name + if !cmd.Flags.Recursive && !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) { + filename = cmd.Flags.Target[0] } + path := st.makePath(filename) writer, err := cmd.FileSystem.CreateFile(path, fc.Length) if err != nil { return trace.Wrap(err) @@ -504,7 +508,6 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro n, err := io.CopyN(writer, ch, int64(fc.Length)) if err != nil { - cmd.log.Error(err) return trace.Wrap(err) } @@ -528,15 +531,9 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro func (cmd *command) receiveDir(st *state, fc newFileCmd, ch io.ReadWriter) error { cmd.log.Debugf("scp.receiveDir(%v): %v", cmd.Flags.Target, fc.Name) - targetDir := cmd.Flags.Target[0] - // copying into an existing directory? append to it: - if cmd.FileSystem.IsDir(targetDir) { - targetDir = st.makePath(fc.Name) - } st.push(fc.Name, st.stat) - - err := cmd.FileSystem.MkDir(targetDir, int(fc.Mode)) + err := cmd.FileSystem.MkDir(st.path.join(), int(fc.Mode)) if err != nil { return trace.ConvertSystemError(err) } @@ -599,6 +596,14 @@ func (cmd *command) updateDirTimes(path pathSegments) error { return nil } +func (cmd *command) hasTargetDir() bool { + return len(cmd.Flags.Target) != 0 && cmd.FileSystem.IsDir(cmd.Flags.Target[0]) +} + +func (r newFileCmd) String() string { + return fmt.Sprintf("newFileCmd(mode=%o,len=%d,name=%v)", r.Mode, r.Length, r.Name) +} + type newFileCmd struct { Mode int64 Length uint64 @@ -675,15 +680,19 @@ type state struct { stat *mtimeCmd } -func (r pathSegments) join() string { +func (r pathSegments) join(elems ...string) string { path := make([]string, 0, len(r)) for _, s := range r { path = append(path, s.dir) } - return filepath.Join(path...) + return filepath.Join(append(path, elems...)...) } -var localDir = pathSegments{{dir: "."}} +var localDir = newPathFromDir(".") + +func newPathFromDir(dir string) pathSegments { + return pathSegments{{dir: dir}} +} type pathSegments []pathSegment @@ -710,7 +719,7 @@ func (st *state) pop() pathSegments { } func (st *state) makePath(filename string) string { - return filepath.Join(st.path.join(), filename) + return st.path.join(filename) } func newReader(r io.Reader) *reader { diff --git a/lib/sshutils/scp/scp_test.go b/lib/sshutils/scp/scp_test.go index ed9774b969b..884650ebf0d 100644 --- a/lib/sshutils/scp/scp_test.go +++ b/lib/sshutils/scp/scp_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" "github.com/google/go-cmp/cmp" @@ -38,7 +39,7 @@ import ( ) func TestHTTPSendFile(t *testing.T) { - outDir := tempDir(t) + outDir := t.TempDir() expectedBytes := []byte("hello") buf := bytes.NewReader(expectedBytes) @@ -65,7 +66,7 @@ func TestHTTPSendFile(t *testing.T) { } func TestHTTPReceiveFile(t *testing.T) { - source := filepath.Join(tempDir(t), "target") + source := filepath.Join(t.TempDir(), "target") contents := []byte("hello, file contents!") err := ioutil.WriteFile(source, contents, 0666) @@ -112,7 +113,7 @@ func TestSend(t *testing.T) { desc: "regular file preserving the attributes", config: newSourceConfig("file", Flags{PreserveAttrs: true}), args: args("-v", "-t", "-p"), - fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")), + fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")), }, { desc: "directory preserving the attributes", @@ -121,10 +122,10 @@ func TestSend(t *testing.T) { fs: newTestFS( logger, // Use timestamps extending backwards to test time application - newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second), - newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"), - newDir("dir/dir2", dirModtime, dirAtime, - newFile("dir/dir2/file2", modtime, atime, "file2 contents")), + newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second), + newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"), + newDirTimes("dir/dir2", dirModtime, dirAtime, + newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")), ), ), }, @@ -136,7 +137,7 @@ func TestSend(t *testing.T) { cmd, err := CreateCommand(tt.config) require.NoError(t, err) - targetDir := tempDir(t) + targetDir := t.TempDir() target := filepath.Join(targetDir, tt.config.Flags.Target[0]) args := append(tt.args, target) @@ -153,7 +154,7 @@ func TestSend(t *testing.T) { fs := newEmptyTestFS(logger) fromOS(t, targetDir, &fs) - validateSCP(t, fs, tt.fs) + validateSCPTimes(t, fs, tt.fs) validateSCPContents(t, fs, tt.fs) }) } @@ -177,7 +178,7 @@ func TestReceive(t *testing.T) { desc: "regular file preserving the attributes", config: newTargetConfig("file", Flags{PreserveAttrs: true}), args: args("-v", "-f", "-p"), - fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")), + fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")), }, { desc: "directory preserving the attributes", @@ -186,10 +187,10 @@ func TestReceive(t *testing.T) { fs: newTestFS( logger, // Use timestamps extending backwards to test time application - newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second), - newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"), - newDir("dir/dir2", dirModtime, dirAtime, - newFile("dir/dir2/file2", modtime, atime, "file2 contents")), + newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second), + newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"), + newDirTimes("dir/dir2", dirModtime, dirAtime, + newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")), ), ), }, @@ -201,7 +202,7 @@ func TestReceive(t *testing.T) { cmd, err := CreateCommand(tt.config) require.NoError(t, err) - sourceDir := tempDir(t) + sourceDir := t.TempDir() source := filepath.Join(sourceDir, tt.config.Flags.Target[0]) args := append(tt.args, source) @@ -209,8 +210,7 @@ func TestReceive(t *testing.T) { err = runSCP(cmd, args...) require.Regexp(t, ".*No such file or directory", err) - fs := newEmptyTestFS(logger) - tt.config.FileSystem = fs + tt.config.FileSystem = newEmptyTestFS(logger) cmd, err = CreateCommand(tt.config) require.NoError(t, err) @@ -221,12 +221,58 @@ func TestReceive(t *testing.T) { err = runSCP(cmd, args...) require.NoError(t, err) - validateSCP(t, tt.fs, fs) - validateSCPContents(t, tt.fs, fs) + validateSCPTimes(t, tt.fs, tt.config.FileSystem) + validateSCPContents(t, tt.fs, tt.config.FileSystem) + }) } } +// TestReceiveIntoExistingDirectory validates that the target remote directory +// is respected during copy. +// +// See https://github.com/gravitational/teleport/issues/5497 +func TestReceiveIntoExistingDirectory(t *testing.T) { + utils.InitLoggerForTests(testing.Verbose()) + logger := logrus.WithField("test", t.Name()) + config := newTargetConfigWithFS("dir", + Flags{PreserveAttrs: true, Recursive: true}, + newTestFS(logger, newDir("dir")), + ) + sourceFS := newTestFS( + logger, + newDir("dir", + newFile("dir/file", "file contents"), + newDir("dir/dir2", + newFile("dir/dir2/file2", "file2 contents")), + ), + ) + expectedFS := newTestFS( + logger, + // Source is copied into an existing directory + newDir("dir/dir", + newFile("dir/dir/file", "file contents"), + newDir("dir/dir/dir2", + newFile("dir/dir/dir2/file2", "file2 contents")), + ), + ) + sourceDir := t.TempDir() + source := filepath.Join(sourceDir, config.Flags.Target[0]) + args := append(args("-v", "-f", "-r", "-p"), source) + + cmd, err := CreateCommand(config) + require.NoError(t, err) + + writeData(t, sourceDir, sourceFS) + writeFileTimes(t, sourceDir, sourceFS) + + err = runSCP(cmd, args...) + require.NoError(t, err) + + validateSCP(t, expectedFS, config.FileSystem) + validateSCPContents(t, expectedFS, config.FileSystem) +} + func TestInvalidDir(t *testing.T) { t.Parallel() @@ -295,7 +341,7 @@ func TestInvalidDir(t *testing.T) { // directory. func TestVerifyDir(t *testing.T) { // Create temporary directory with a file "target" in it. - dir := tempDir(t) + dir := t.TempDir() target := filepath.Join(dir, "target") err := ioutil.WriteFile(target, []byte{}, 0666) require.NoError(t, err) @@ -492,11 +538,24 @@ func validateSCPContents(t *testing.T, expected testFS, actual FileSystem) { } // validateSCP verifies that the specified pair of FileSystems match. -// FileSystem match if their contents match incl. access/modification times func validateSCP(t *testing.T, expected testFS, actual FileSystem) { for path, fileinfo := range expected.fs { targetFileinfo, err := actual.GetFileInfo(path) - require.NoError(t, err) + require.NoError(t, err, "expected %v", path) + if fileinfo.IsDir() { + require.True(t, targetFileinfo.IsDir()) + } else { + require.True(t, targetFileinfo.GetModePerm().IsRegular()) + } + } +} + +// validateSCPTimes verifies that the specified pair of FileSystems match. +// FileSystem match if their contents match incl. access/modification times +func validateSCPTimes(t *testing.T, expected testFS, actual FileSystem) { + for path, fileinfo := range expected.fs { + targetFileinfo, err := actual.GetFileInfo(path) + require.NoError(t, err, "expected %v", path) if fileinfo.IsDir() { require.True(t, targetFileinfo.IsDir()) } else { @@ -691,13 +750,6 @@ type nopReadCloser struct { var errMissingFile = fmt.Errorf("no such file or directory") -func tempDir(t *testing.T) (dir string) { - path, err := ioutil.TempDir("", "test") - require.NoError(t, err) - t.Cleanup(func() { os.RemoveAll(path) }) - return path -} - func newSourceConfig(path string, flags Flags) Config { flags.Source = true flags.Target = []string{path} @@ -707,6 +759,12 @@ func newSourceConfig(path string, flags Flags) Config { } } +func newTargetConfigWithFS(path string, flags Flags, fs testFS) Config { + config := newTargetConfig(path, flags) + config.FileSystem = &fs + return config +} + func newTargetConfig(path string, flags Flags) Config { flags.Sink = true flags.Target = []string{path} @@ -716,7 +774,25 @@ func newTargetConfig(path string, flags Flags) Config { } } -func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo { +func newDir(name string, ents ...*testFileInfo) *testFileInfo { + return &testFileInfo{ + path: name, + ents: ents, + dir: true, + perms: 0755, + } +} + +func newFile(name string, contents string) *testFileInfo { + return &testFileInfo{ + path: name, + perms: 0666, + size: int64(len(contents)), + contents: bytes.NewBufferString(contents), + } +} + +func newDirTimes(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo { return &testFileInfo{ path: name, ents: ents, @@ -727,7 +803,7 @@ func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testF } } -func newFile(name string, modtime, atime time.Time, contents string) *testFileInfo { +func newFileTimes(name string, modtime, atime time.Time, contents string) *testFileInfo { return &testFileInfo{ path: name, modtime: modtime, diff --git a/lib/utils/fs.go b/lib/utils/fs.go index 640f0af3b66..b143dcf9fd2 100644 --- a/lib/utils/fs.go +++ b/lib/utils/fs.go @@ -75,14 +75,23 @@ func (r *RemoveDirCloser) Close() error { } // IsDir is a helper function to quickly check if a given path is a valid directory -func IsDir(dirPath string) bool { - fi, err := os.Stat(dirPath) +func IsDir(path string) bool { + fi, err := os.Stat(path) if err == nil { return fi.IsDir() } return false } +// IsFile is a convenience helper to check if the given path is a regular file +func IsFile(path string) bool { + fi, err := os.Stat(path) + if err == nil { + return fi.Mode().IsRegular() + } + return false +} + // NormalizePath normalises path, evaluating symlinks and converting local // paths to absolute func NormalizePath(path string) (string, error) { diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 0fba7e1bd80..ba52a97018d 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -1410,18 +1410,17 @@ func onSCP(cf *CLIConf) error { PreserveAttrs: cf.PreserveAttrs, } err = client.RetryWithRelogin(cf.Context, tc, func() error { - return tc.SCP(context.TODO(), cf.CopySpec, int(cf.NodePort), flags, cf.Quiet) + return tc.SCP(cf.Context, cf.CopySpec, int(cf.NodePort), flags, cf.Quiet) }) - if err != nil { - // exit with the same exit status as the failed command: - if tc.ExitStatus != 0 { - fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) - os.Exit(tc.ExitStatus) - } else { - return trace.Wrap(err) - } + if err == nil { + return nil } - return nil + // exit with the same exit status as the failed command: + if tc.ExitStatus != 0 { + fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err)) + os.Exit(tc.ExitStatus) + } + return trace.Wrap(err) } // makeClient takes the command-line configuration and constructs & returns