mirror of
https://github.com/gravitational/teleport
synced 2024-10-22 02:03:24 +00:00
tsh scp to use target directory correctly (#5501)
* Fixes the scp logic to take target directory into account in sink mode. Also expose channel error in scp client so the error is more visible to the user. Old behavior will only output the 'exit code n' if anything breaks. Fixes https://github.com/gravitational/teleport/issues/5497. * Silence 'wait: remote command exited without exit status or exit signal' error when interrupting the scp session. Leave a TODO to fix properly in a future PR * Address review comments
This commit is contained in:
parent
68bc78f10e
commit
e65eac59b0
|
@ -1426,7 +1426,7 @@ func (tc *TeleportClient) ExecuteSCP(ctx context.Context, cmd scp.Command) (err
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = nodeClient.ExecuteSCP(cmd)
|
||||
err = nodeClient.ExecuteSCP(ctx, cmd)
|
||||
if err != nil {
|
||||
// converts SSH error code to tc.ExitStatus
|
||||
exitError, _ := trace.Unwrap(err).(*ssh.ExitError)
|
||||
|
@ -1463,21 +1463,24 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag
|
|||
}
|
||||
defer proxyClient.Close()
|
||||
|
||||
var progressWriter io.Writer
|
||||
if !quiet {
|
||||
progressWriter = tc.Stdout
|
||||
}
|
||||
|
||||
// helper function connects to the src/target node:
|
||||
connectToNode := func(addr string) (*NodeClient, error) {
|
||||
connectToNode := func(addr, hostLogin string) (*NodeClient, error) {
|
||||
// determine which cluster we're connecting to:
|
||||
siteInfo, err := proxyClient.currentCluster()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
if hostLogin == "" {
|
||||
hostLogin = tc.Config.HostLogin
|
||||
}
|
||||
return proxyClient.ConnectToNode(ctx,
|
||||
NodeAddr{Addr: addr, Namespace: tc.Namespace, Cluster: siteInfo.Name},
|
||||
tc.HostLogin, false)
|
||||
}
|
||||
|
||||
var progressWriter io.Writer
|
||||
if !quiet {
|
||||
progressWriter = tc.Stdout
|
||||
hostLogin, false)
|
||||
}
|
||||
|
||||
// gets called to convert SSH error code to tc.ExitStatus
|
||||
|
@ -1488,9 +1491,39 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
tpl := scp.Config{
|
||||
User: tc.Username,
|
||||
ProgressWriter: progressWriter,
|
||||
Flags: flags,
|
||||
}
|
||||
|
||||
var config *scpConfig
|
||||
// upload:
|
||||
if isRemoteDest(last) {
|
||||
config, err = tc.uploadConfig(ctx, tpl, port, args)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
} else {
|
||||
config, err = tc.downloadConfig(ctx, tpl, port, args)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := connectToNode(config.addr, config.hostLogin)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
return onError(client.ExecuteSCP(ctx, config.cmd))
|
||||
}
|
||||
|
||||
func (tc *TeleportClient) uploadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) {
|
||||
filesToUpload := args[:len(args)-1]
|
||||
// copy everything except the last arg (the destination)
|
||||
destPath := args[len(args)-1]
|
||||
|
||||
// If more than a single file were provided, scp must be in directory mode
|
||||
// and the target on the remote host needs to be a directory.
|
||||
|
@ -1499,77 +1532,61 @@ func (tc *TeleportClient) SCP(ctx context.Context, args []string, port int, flag
|
|||
directoryMode = true
|
||||
}
|
||||
|
||||
dest, err := scp.ParseSCPDestination(last)
|
||||
dest, addr, err := getSCPDestination(destPath, port)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if dest.Login != "" {
|
||||
tc.HostLogin = dest.Login
|
||||
}
|
||||
addr := net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
|
||||
|
||||
client, err := connectToNode(addr)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
// copy everything except the last arg (that's destination)
|
||||
for _, src := range filesToUpload {
|
||||
scpConfig := scp.Config{
|
||||
User: tc.Username,
|
||||
ProgressWriter: progressWriter,
|
||||
RemoteLocation: dest.Path,
|
||||
Flags: flags,
|
||||
}
|
||||
scpConfig.Flags.Target = []string{src}
|
||||
scpConfig.Flags.DirectoryMode = directoryMode
|
||||
tpl.RemoteLocation = dest.Path
|
||||
tpl.Flags.Target = filesToUpload
|
||||
tpl.Flags.DirectoryMode = directoryMode
|
||||
|
||||
cmd, err := scp.CreateUploadCommand(scpConfig)
|
||||
cmd, err := scp.CreateUploadCommand(tpl)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = client.ExecuteSCP(cmd)
|
||||
if err != nil {
|
||||
return onError(err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// download:
|
||||
src, err := scp.ParseSCPDestination(first)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
addr := net.JoinHostPort(src.Host.Host(), strconv.Itoa(port))
|
||||
if src.Login != "" {
|
||||
tc.HostLogin = src.Login
|
||||
}
|
||||
client, err := connectToNode(addr)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
// copy everything except the last arg (that's destination)
|
||||
for _, dest := range args[1:] {
|
||||
scpConfig := scp.Config{
|
||||
User: tc.Username,
|
||||
Flags: flags,
|
||||
RemoteLocation: src.Path,
|
||||
ProgressWriter: progressWriter,
|
||||
}
|
||||
scpConfig.Flags.Target = []string{dest}
|
||||
return &scpConfig{
|
||||
cmd: cmd,
|
||||
addr: addr,
|
||||
hostLogin: dest.Login,
|
||||
}, nil
|
||||
}
|
||||
|
||||
cmd, err := scp.CreateDownloadCommand(scpConfig)
|
||||
func (tc *TeleportClient) downloadConfig(ctx context.Context, tpl scp.Config, port int, args []string) (config *scpConfig, err error) {
|
||||
src, addr, err := getSCPDestination(args[0], port)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
err = client.ExecuteSCP(cmd)
|
||||
tpl.RemoteLocation = src.Path
|
||||
tpl.Flags.Target = args[1:]
|
||||
|
||||
cmd, err := scp.CreateDownloadCommand(tpl)
|
||||
if err != nil {
|
||||
return onError(err)
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
return &scpConfig{
|
||||
cmd: cmd,
|
||||
addr: addr,
|
||||
hostLogin: src.Login,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type scpConfig struct {
|
||||
cmd scp.Command
|
||||
addr string
|
||||
hostLogin string
|
||||
}
|
||||
|
||||
func getSCPDestination(target string, port int) (dest *scp.Destination, addr string, err error) {
|
||||
dest, err = scp.ParseSCPDestination(target)
|
||||
if err != nil {
|
||||
return nil, "", trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
addr = net.JoinHostPort(dest.Host.Host(), strconv.Itoa(port))
|
||||
return dest, addr, nil
|
||||
}
|
||||
|
||||
func isRemoteDest(name string) bool {
|
||||
|
|
|
@ -20,9 +20,11 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -854,7 +856,7 @@ func (proxy *ProxyClient) Close() error {
|
|||
|
||||
// ExecuteSCP runs remote scp command(shellCmd) on the remote server and
|
||||
// runs local scp handler using SCP Command
|
||||
func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
|
||||
func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error {
|
||||
shellCmd, err := cmd.GetRemoteShellCmd()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -876,6 +878,14 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Stream scp's stderr so tsh gets the verbose remote error
|
||||
// if the command fails
|
||||
stderr, err := s.StderrPipe()
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
go io.Copy(os.Stderr, stderr)
|
||||
|
||||
ch := utils.NewPipeNetConn(
|
||||
stdout,
|
||||
stdin,
|
||||
|
@ -884,18 +894,40 @@ func (c *NodeClient) ExecuteSCP(cmd scp.Command) error {
|
|||
&net.IPAddr{},
|
||||
)
|
||||
|
||||
closeC := make(chan error, 1)
|
||||
execC := make(chan error, 1)
|
||||
go func() {
|
||||
err := cmd.Execute(ch)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if err != nil && !trace.IsEOF(err) {
|
||||
log.WithError(err).Warn("Failed to execute SCP command.")
|
||||
}
|
||||
stdin.Close()
|
||||
closeC <- err
|
||||
execC <- err
|
||||
}()
|
||||
|
||||
runErr := s.Run(shellCmd)
|
||||
err = <-closeC
|
||||
runC := make(chan error, 1)
|
||||
go func() {
|
||||
err := s.Run(shellCmd)
|
||||
if err != nil && errors.Is(err, &ssh.ExitMissingError{}) {
|
||||
// TODO(dmitri): currently, if the session is aborted with (*session).Close,
|
||||
// the remote side cannot send exit-status and this error results.
|
||||
// To abort the session properly, Teleport needs to support `signal` request
|
||||
err = nil
|
||||
}
|
||||
runC <- err
|
||||
}()
|
||||
|
||||
var runErr error
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := s.Close(); err != nil {
|
||||
log.WithError(err).Debug("Failed to close the SSH session.")
|
||||
}
|
||||
err, runErr = <-execC, <-runC
|
||||
case err = <-execC:
|
||||
runErr = <-runC
|
||||
case runErr = <-runC:
|
||||
err = <-execC
|
||||
}
|
||||
|
||||
if runErr != nil && (err == nil || trace.IsEOF(err)) {
|
||||
err = runErr
|
||||
|
|
|
@ -639,7 +639,7 @@ func (c *ServerContext) SendExecResult(r ExecResult) {
|
|||
select {
|
||||
case c.ExecResultCh <- r:
|
||||
default:
|
||||
log.Infof("blocked on sending exec result %v", r)
|
||||
c.Infof("Blocked on sending exec result %v.", r)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -649,7 +649,7 @@ func (c *ServerContext) SendSubsystemResult(r SubsystemResult) {
|
|||
select {
|
||||
case c.SubsystemResultCh <- r:
|
||||
default:
|
||||
c.Infof("blocked on sending subsystem result")
|
||||
c.Info("Blocked on sending subsystem result.")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -149,20 +149,14 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
|
|||
|
||||
// Connect stdout and stderr to the channel so the user can interact with
|
||||
// the command.
|
||||
e.Cmd.Stderr = channel.Stderr()
|
||||
e.Cmd.Stdout = channel
|
||||
e.Cmd.Stderr = io.MultiWriter(os.Stderr, channel.Stderr())
|
||||
e.Cmd.Stdout = io.MultiWriter(os.Stdout, channel)
|
||||
|
||||
// Copy from the channel (client input) into stdin of the process.
|
||||
inputWriter, err := e.Cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
go func() {
|
||||
if _, err := io.Copy(inputWriter, channel); err != nil {
|
||||
e.Ctx.Warningf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err)
|
||||
}
|
||||
inputWriter.Close()
|
||||
}()
|
||||
|
||||
// Start the command.
|
||||
err = e.Cmd.Start()
|
||||
|
@ -178,6 +172,13 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
|
|||
}, trace.ConvertSystemError(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(inputWriter, channel); err != nil {
|
||||
e.Ctx.Warnf("Failed to forward data from SSH channel to local command %q stdin: %v", e.GetCommand(), err)
|
||||
}
|
||||
inputWriter.Close()
|
||||
}()
|
||||
|
||||
e.Ctx.Infof("Started local command execution: %q", e.Command)
|
||||
|
||||
return nil, nil
|
||||
|
@ -186,7 +187,7 @@ func (e *localExec) Start(channel ssh.Channel) (*ExecResult, error) {
|
|||
// Wait will block while the command executes.
|
||||
func (e *localExec) Wait() *ExecResult {
|
||||
if e.Cmd.Process == nil {
|
||||
e.Ctx.Errorf("no process")
|
||||
e.Ctx.Error("No process.")
|
||||
}
|
||||
|
||||
// Block until the command is finished executing.
|
||||
|
|
|
@ -90,6 +90,8 @@ type Config struct {
|
|||
// RunOnServer is low level API flag that indicates that
|
||||
// this command will be run on the server
|
||||
RunOnServer bool
|
||||
// Log optionally specifies the logger
|
||||
Log log.FieldLogger
|
||||
}
|
||||
|
||||
// Command is an API that describes command operations
|
||||
|
@ -167,10 +169,25 @@ func CreateUploadCommand(cfg Config) (Command, error) {
|
|||
|
||||
// CheckAndSetDefaults checks and sets default values
|
||||
func (c *Config) CheckAndSetDefaults() error {
|
||||
logger := c.Log
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
c.Log = logger.WithFields(log.Fields{
|
||||
trace.Component: "SCP",
|
||||
trace.ComponentFields: log.Fields{
|
||||
"LocalAddr": c.Flags.LocalAddr,
|
||||
"RemoteAddr": c.Flags.RemoteAddr,
|
||||
"Target": c.Flags.Target,
|
||||
"PreserveAttrs": c.Flags.PreserveAttrs,
|
||||
"User": c.User,
|
||||
"RunOnServer": c.RunOnServer,
|
||||
"RemoteLocation": c.RemoteLocation,
|
||||
},
|
||||
})
|
||||
if c.FileSystem == nil {
|
||||
c.FileSystem = &localFileSystem{}
|
||||
}
|
||||
|
||||
if c.User == "" {
|
||||
return trace.BadParameter("missing User parameter")
|
||||
}
|
||||
|
@ -186,31 +203,17 @@ func CreateCommand(cfg Config) (Command, error) {
|
|||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
cmd := command{
|
||||
return &command{
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
cmd.log = log.WithFields(log.Fields{
|
||||
trace.Component: "SCP",
|
||||
trace.ComponentFields: log.Fields{
|
||||
"LocalAddr": cfg.Flags.LocalAddr,
|
||||
"RemoteAddr": cfg.Flags.RemoteAddr,
|
||||
"Target": cfg.Flags.Target,
|
||||
"PreserveAttrs": cfg.Flags.PreserveAttrs,
|
||||
"User": cfg.User,
|
||||
"RunOnServer": cfg.RunOnServer,
|
||||
"RemoteLocation": cfg.RemoteLocation,
|
||||
},
|
||||
})
|
||||
|
||||
return &cmd, nil
|
||||
log: cfg.Log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Command mimics behavior of SCP command line tool
|
||||
// to teleport can pretend it launches real SCP behind the scenes
|
||||
type command struct {
|
||||
Config
|
||||
log *log.Entry
|
||||
log log.FieldLogger
|
||||
}
|
||||
|
||||
// Execute implements SSH file copy (SCP). It is called on both tsh (client)
|
||||
|
@ -333,7 +336,7 @@ func (cmd *command) sendDir(r *reader, ch io.ReadWriter, fileInfo FileInfo) erro
|
|||
if _, err = fmt.Fprintf(ch, "E\n"); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
return r.read()
|
||||
return trace.Wrap(r.read())
|
||||
}
|
||||
|
||||
func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) error {
|
||||
|
@ -375,7 +378,7 @@ func (cmd *command) sendFile(r *reader, ch io.ReadWriter, fileInfo FileInfo) err
|
|||
func (cmd *command) sendErr(ch io.Writer, err error) {
|
||||
out := fmt.Sprintf("%c%s\n", byte(ErrByte), err)
|
||||
if _, err := ch.Write([]byte(out)); err != nil {
|
||||
cmd.log.Debugf("failed sending SCP error message to the remote side: %v", err)
|
||||
cmd.log.Debugf("Failed sending SCP error message to the remote side: %v.", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -386,23 +389,25 @@ func (cmd *command) serveSink(ch io.ReadWriter) error {
|
|||
// directory.
|
||||
if cmd.Flags.DirectoryMode {
|
||||
if len(cmd.Flags.Target) != 1 {
|
||||
return trace.BadParameter("in directory mode, only single upload target is allowed but %v provided", len(cmd.Flags.Target))
|
||||
return trace.BadParameter("in directory mode, only single upload target is allowed but %q provided",
|
||||
cmd.Flags.Target)
|
||||
}
|
||||
|
||||
fi, err := os.Stat(cmd.Flags.Target[0])
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
if mode := fi.Mode(); !mode.IsDir() {
|
||||
if !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) {
|
||||
return trace.BadParameter("target path must be a directory")
|
||||
}
|
||||
}
|
||||
|
||||
rootDir := localDir
|
||||
if cmd.hasTargetDir() {
|
||||
rootDir = newPathFromDir(cmd.Flags.Target[0])
|
||||
}
|
||||
|
||||
if err := sendOK(ch); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
var st state
|
||||
st.path = localDir
|
||||
st.path = rootDir
|
||||
var b [1]byte
|
||||
scanner := bufio.NewScanner(ch)
|
||||
for {
|
||||
|
@ -478,14 +483,13 @@ func (cmd *command) processCommand(ch io.ReadWriter, st *state, b byte, line str
|
|||
func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) error {
|
||||
cmd.log.Debugf("scp.receiveFile(%v): %v", cmd.Flags.Target, fc.Name)
|
||||
|
||||
// if the destination path is a folder, we should save the file to that folder, but
|
||||
// only if 'recursive' is set
|
||||
|
||||
path := cmd.Flags.Target[0]
|
||||
if cmd.Flags.Recursive || cmd.FileSystem.IsDir(path) {
|
||||
path = st.makePath(fc.Name)
|
||||
// Unless target specifies a file, use the file name from the command
|
||||
filename := fc.Name
|
||||
if !cmd.Flags.Recursive && !cmd.FileSystem.IsDir(cmd.Flags.Target[0]) {
|
||||
filename = cmd.Flags.Target[0]
|
||||
}
|
||||
|
||||
path := st.makePath(filename)
|
||||
writer, err := cmd.FileSystem.CreateFile(path, fc.Length)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -504,7 +508,6 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro
|
|||
|
||||
n, err := io.CopyN(writer, ch, int64(fc.Length))
|
||||
if err != nil {
|
||||
cmd.log.Error(err)
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
|
@ -528,15 +531,9 @@ func (cmd *command) receiveFile(st *state, fc newFileCmd, ch io.ReadWriter) erro
|
|||
|
||||
func (cmd *command) receiveDir(st *state, fc newFileCmd, ch io.ReadWriter) error {
|
||||
cmd.log.Debugf("scp.receiveDir(%v): %v", cmd.Flags.Target, fc.Name)
|
||||
targetDir := cmd.Flags.Target[0]
|
||||
|
||||
// copying into an existing directory? append to it:
|
||||
if cmd.FileSystem.IsDir(targetDir) {
|
||||
targetDir = st.makePath(fc.Name)
|
||||
}
|
||||
st.push(fc.Name, st.stat)
|
||||
|
||||
err := cmd.FileSystem.MkDir(targetDir, int(fc.Mode))
|
||||
err := cmd.FileSystem.MkDir(st.path.join(), int(fc.Mode))
|
||||
if err != nil {
|
||||
return trace.ConvertSystemError(err)
|
||||
}
|
||||
|
@ -599,6 +596,14 @@ func (cmd *command) updateDirTimes(path pathSegments) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (cmd *command) hasTargetDir() bool {
|
||||
return len(cmd.Flags.Target) != 0 && cmd.FileSystem.IsDir(cmd.Flags.Target[0])
|
||||
}
|
||||
|
||||
func (r newFileCmd) String() string {
|
||||
return fmt.Sprintf("newFileCmd(mode=%o,len=%d,name=%v)", r.Mode, r.Length, r.Name)
|
||||
}
|
||||
|
||||
type newFileCmd struct {
|
||||
Mode int64
|
||||
Length uint64
|
||||
|
@ -675,15 +680,19 @@ type state struct {
|
|||
stat *mtimeCmd
|
||||
}
|
||||
|
||||
func (r pathSegments) join() string {
|
||||
func (r pathSegments) join(elems ...string) string {
|
||||
path := make([]string, 0, len(r))
|
||||
for _, s := range r {
|
||||
path = append(path, s.dir)
|
||||
}
|
||||
return filepath.Join(path...)
|
||||
return filepath.Join(append(path, elems...)...)
|
||||
}
|
||||
|
||||
var localDir = pathSegments{{dir: "."}}
|
||||
var localDir = newPathFromDir(".")
|
||||
|
||||
func newPathFromDir(dir string) pathSegments {
|
||||
return pathSegments{{dir: dir}}
|
||||
}
|
||||
|
||||
type pathSegments []pathSegment
|
||||
|
||||
|
@ -710,7 +719,7 @@ func (st *state) pop() pathSegments {
|
|||
}
|
||||
|
||||
func (st *state) makePath(filename string) string {
|
||||
return filepath.Join(st.path.join(), filename)
|
||||
return st.path.join(filename)
|
||||
}
|
||||
|
||||
func newReader(r io.Reader) *reader {
|
||||
|
|
|
@ -30,6 +30,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gravitational/teleport/lib/utils"
|
||||
|
||||
"github.com/gravitational/trace"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
@ -38,7 +39,7 @@ import (
|
|||
)
|
||||
|
||||
func TestHTTPSendFile(t *testing.T) {
|
||||
outDir := tempDir(t)
|
||||
outDir := t.TempDir()
|
||||
|
||||
expectedBytes := []byte("hello")
|
||||
buf := bytes.NewReader(expectedBytes)
|
||||
|
@ -65,7 +66,7 @@ func TestHTTPSendFile(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHTTPReceiveFile(t *testing.T) {
|
||||
source := filepath.Join(tempDir(t), "target")
|
||||
source := filepath.Join(t.TempDir(), "target")
|
||||
|
||||
contents := []byte("hello, file contents!")
|
||||
err := ioutil.WriteFile(source, contents, 0666)
|
||||
|
@ -112,7 +113,7 @@ func TestSend(t *testing.T) {
|
|||
desc: "regular file preserving the attributes",
|
||||
config: newSourceConfig("file", Flags{PreserveAttrs: true}),
|
||||
args: args("-v", "-t", "-p"),
|
||||
fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")),
|
||||
fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")),
|
||||
},
|
||||
{
|
||||
desc: "directory preserving the attributes",
|
||||
|
@ -121,10 +122,10 @@ func TestSend(t *testing.T) {
|
|||
fs: newTestFS(
|
||||
logger,
|
||||
// Use timestamps extending backwards to test time application
|
||||
newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
|
||||
newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
|
||||
newDir("dir/dir2", dirModtime, dirAtime,
|
||||
newFile("dir/dir2/file2", modtime, atime, "file2 contents")),
|
||||
newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
|
||||
newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
|
||||
newDirTimes("dir/dir2", dirModtime, dirAtime,
|
||||
newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")),
|
||||
),
|
||||
),
|
||||
},
|
||||
|
@ -136,7 +137,7 @@ func TestSend(t *testing.T) {
|
|||
cmd, err := CreateCommand(tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
targetDir := tempDir(t)
|
||||
targetDir := t.TempDir()
|
||||
target := filepath.Join(targetDir, tt.config.Flags.Target[0])
|
||||
args := append(tt.args, target)
|
||||
|
||||
|
@ -153,7 +154,7 @@ func TestSend(t *testing.T) {
|
|||
|
||||
fs := newEmptyTestFS(logger)
|
||||
fromOS(t, targetDir, &fs)
|
||||
validateSCP(t, fs, tt.fs)
|
||||
validateSCPTimes(t, fs, tt.fs)
|
||||
validateSCPContents(t, fs, tt.fs)
|
||||
})
|
||||
}
|
||||
|
@ -177,7 +178,7 @@ func TestReceive(t *testing.T) {
|
|||
desc: "regular file preserving the attributes",
|
||||
config: newTargetConfig("file", Flags{PreserveAttrs: true}),
|
||||
args: args("-v", "-f", "-p"),
|
||||
fs: newTestFS(logger, newFile("file", modtime, atime, "file contents")),
|
||||
fs: newTestFS(logger, newFileTimes("file", modtime, atime, "file contents")),
|
||||
},
|
||||
{
|
||||
desc: "directory preserving the attributes",
|
||||
|
@ -186,10 +187,10 @@ func TestReceive(t *testing.T) {
|
|||
fs: newTestFS(
|
||||
logger,
|
||||
// Use timestamps extending backwards to test time application
|
||||
newDir("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
|
||||
newFile("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
|
||||
newDir("dir/dir2", dirModtime, dirAtime,
|
||||
newFile("dir/dir2/file2", modtime, atime, "file2 contents")),
|
||||
newDirTimes("dir", dirModtime.Add(1*time.Second), dirAtime.Add(2*time.Second),
|
||||
newFileTimes("dir/file", modtime.Add(1*time.Minute), atime.Add(2*time.Minute), "file contents"),
|
||||
newDirTimes("dir/dir2", dirModtime, dirAtime,
|
||||
newFileTimes("dir/dir2/file2", modtime, atime, "file2 contents")),
|
||||
),
|
||||
),
|
||||
},
|
||||
|
@ -201,7 +202,7 @@ func TestReceive(t *testing.T) {
|
|||
cmd, err := CreateCommand(tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceDir := tempDir(t)
|
||||
sourceDir := t.TempDir()
|
||||
source := filepath.Join(sourceDir, tt.config.Flags.Target[0])
|
||||
args := append(tt.args, source)
|
||||
|
||||
|
@ -209,8 +210,7 @@ func TestReceive(t *testing.T) {
|
|||
err = runSCP(cmd, args...)
|
||||
require.Regexp(t, ".*No such file or directory", err)
|
||||
|
||||
fs := newEmptyTestFS(logger)
|
||||
tt.config.FileSystem = fs
|
||||
tt.config.FileSystem = newEmptyTestFS(logger)
|
||||
cmd, err = CreateCommand(tt.config)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -221,12 +221,58 @@ func TestReceive(t *testing.T) {
|
|||
err = runSCP(cmd, args...)
|
||||
require.NoError(t, err)
|
||||
|
||||
validateSCP(t, tt.fs, fs)
|
||||
validateSCPContents(t, tt.fs, fs)
|
||||
validateSCPTimes(t, tt.fs, tt.config.FileSystem)
|
||||
validateSCPContents(t, tt.fs, tt.config.FileSystem)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReceiveIntoExistingDirectory validates that the target remote directory
|
||||
// is respected during copy.
|
||||
//
|
||||
// See https://github.com/gravitational/teleport/issues/5497
|
||||
func TestReceiveIntoExistingDirectory(t *testing.T) {
|
||||
utils.InitLoggerForTests(testing.Verbose())
|
||||
logger := logrus.WithField("test", t.Name())
|
||||
config := newTargetConfigWithFS("dir",
|
||||
Flags{PreserveAttrs: true, Recursive: true},
|
||||
newTestFS(logger, newDir("dir")),
|
||||
)
|
||||
sourceFS := newTestFS(
|
||||
logger,
|
||||
newDir("dir",
|
||||
newFile("dir/file", "file contents"),
|
||||
newDir("dir/dir2",
|
||||
newFile("dir/dir2/file2", "file2 contents")),
|
||||
),
|
||||
)
|
||||
expectedFS := newTestFS(
|
||||
logger,
|
||||
// Source is copied into an existing directory
|
||||
newDir("dir/dir",
|
||||
newFile("dir/dir/file", "file contents"),
|
||||
newDir("dir/dir/dir2",
|
||||
newFile("dir/dir/dir2/file2", "file2 contents")),
|
||||
),
|
||||
)
|
||||
sourceDir := t.TempDir()
|
||||
source := filepath.Join(sourceDir, config.Flags.Target[0])
|
||||
args := append(args("-v", "-f", "-r", "-p"), source)
|
||||
|
||||
cmd, err := CreateCommand(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeData(t, sourceDir, sourceFS)
|
||||
writeFileTimes(t, sourceDir, sourceFS)
|
||||
|
||||
err = runSCP(cmd, args...)
|
||||
require.NoError(t, err)
|
||||
|
||||
validateSCP(t, expectedFS, config.FileSystem)
|
||||
validateSCPContents(t, expectedFS, config.FileSystem)
|
||||
}
|
||||
|
||||
func TestInvalidDir(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -295,7 +341,7 @@ func TestInvalidDir(t *testing.T) {
|
|||
// directory.
|
||||
func TestVerifyDir(t *testing.T) {
|
||||
// Create temporary directory with a file "target" in it.
|
||||
dir := tempDir(t)
|
||||
dir := t.TempDir()
|
||||
target := filepath.Join(dir, "target")
|
||||
err := ioutil.WriteFile(target, []byte{}, 0666)
|
||||
require.NoError(t, err)
|
||||
|
@ -492,11 +538,24 @@ func validateSCPContents(t *testing.T, expected testFS, actual FileSystem) {
|
|||
}
|
||||
|
||||
// validateSCP verifies that the specified pair of FileSystems match.
|
||||
// FileSystem match if their contents match incl. access/modification times
|
||||
func validateSCP(t *testing.T, expected testFS, actual FileSystem) {
|
||||
for path, fileinfo := range expected.fs {
|
||||
targetFileinfo, err := actual.GetFileInfo(path)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "expected %v", path)
|
||||
if fileinfo.IsDir() {
|
||||
require.True(t, targetFileinfo.IsDir())
|
||||
} else {
|
||||
require.True(t, targetFileinfo.GetModePerm().IsRegular())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validateSCPTimes verifies that the specified pair of FileSystems match.
|
||||
// FileSystem match if their contents match incl. access/modification times
|
||||
func validateSCPTimes(t *testing.T, expected testFS, actual FileSystem) {
|
||||
for path, fileinfo := range expected.fs {
|
||||
targetFileinfo, err := actual.GetFileInfo(path)
|
||||
require.NoError(t, err, "expected %v", path)
|
||||
if fileinfo.IsDir() {
|
||||
require.True(t, targetFileinfo.IsDir())
|
||||
} else {
|
||||
|
@ -691,13 +750,6 @@ type nopReadCloser struct {
|
|||
|
||||
var errMissingFile = fmt.Errorf("no such file or directory")
|
||||
|
||||
func tempDir(t *testing.T) (dir string) {
|
||||
path, err := ioutil.TempDir("", "test")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { os.RemoveAll(path) })
|
||||
return path
|
||||
}
|
||||
|
||||
func newSourceConfig(path string, flags Flags) Config {
|
||||
flags.Source = true
|
||||
flags.Target = []string{path}
|
||||
|
@ -707,6 +759,12 @@ func newSourceConfig(path string, flags Flags) Config {
|
|||
}
|
||||
}
|
||||
|
||||
func newTargetConfigWithFS(path string, flags Flags, fs testFS) Config {
|
||||
config := newTargetConfig(path, flags)
|
||||
config.FileSystem = &fs
|
||||
return config
|
||||
}
|
||||
|
||||
func newTargetConfig(path string, flags Flags) Config {
|
||||
flags.Sink = true
|
||||
flags.Target = []string{path}
|
||||
|
@ -716,7 +774,25 @@ func newTargetConfig(path string, flags Flags) Config {
|
|||
}
|
||||
}
|
||||
|
||||
func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo {
|
||||
func newDir(name string, ents ...*testFileInfo) *testFileInfo {
|
||||
return &testFileInfo{
|
||||
path: name,
|
||||
ents: ents,
|
||||
dir: true,
|
||||
perms: 0755,
|
||||
}
|
||||
}
|
||||
|
||||
func newFile(name string, contents string) *testFileInfo {
|
||||
return &testFileInfo{
|
||||
path: name,
|
||||
perms: 0666,
|
||||
size: int64(len(contents)),
|
||||
contents: bytes.NewBufferString(contents),
|
||||
}
|
||||
}
|
||||
|
||||
func newDirTimes(name string, modtime, atime time.Time, ents ...*testFileInfo) *testFileInfo {
|
||||
return &testFileInfo{
|
||||
path: name,
|
||||
ents: ents,
|
||||
|
@ -727,7 +803,7 @@ func newDir(name string, modtime, atime time.Time, ents ...*testFileInfo) *testF
|
|||
}
|
||||
}
|
||||
|
||||
func newFile(name string, modtime, atime time.Time, contents string) *testFileInfo {
|
||||
func newFileTimes(name string, modtime, atime time.Time, contents string) *testFileInfo {
|
||||
return &testFileInfo{
|
||||
path: name,
|
||||
modtime: modtime,
|
||||
|
|
|
@ -75,14 +75,23 @@ func (r *RemoveDirCloser) Close() error {
|
|||
}
|
||||
|
||||
// IsDir is a helper function to quickly check if a given path is a valid directory
|
||||
func IsDir(dirPath string) bool {
|
||||
fi, err := os.Stat(dirPath)
|
||||
func IsDir(path string) bool {
|
||||
fi, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return fi.IsDir()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsFile is a convenience helper to check if the given path is a regular file
|
||||
func IsFile(path string) bool {
|
||||
fi, err := os.Stat(path)
|
||||
if err == nil {
|
||||
return fi.Mode().IsRegular()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NormalizePath normalises path, evaluating symlinks and converting local
|
||||
// paths to absolute
|
||||
func NormalizePath(path string) (string, error) {
|
||||
|
|
|
@ -1410,18 +1410,17 @@ func onSCP(cf *CLIConf) error {
|
|||
PreserveAttrs: cf.PreserveAttrs,
|
||||
}
|
||||
err = client.RetryWithRelogin(cf.Context, tc, func() error {
|
||||
return tc.SCP(context.TODO(), cf.CopySpec, int(cf.NodePort), flags, cf.Quiet)
|
||||
return tc.SCP(cf.Context, cf.CopySpec, int(cf.NodePort), flags, cf.Quiet)
|
||||
})
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// exit with the same exit status as the failed command:
|
||||
if tc.ExitStatus != 0 {
|
||||
fmt.Fprintln(os.Stderr, utils.UserMessageFromError(err))
|
||||
os.Exit(tc.ExitStatus)
|
||||
} else {
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeClient takes the command-line configuration and constructs & returns
|
||||
|
|
Loading…
Reference in a new issue