pkg/copy: introduce a Copier

Introduce a `Copier` object to separate the copy-rule enforcement from
copying.  That allows for a better error reporting of the REST API.

Signed-off-by: Valentin Rothberg <rothberg@redhat.com>
This commit is contained in:
Valentin Rothberg 2020-12-09 12:18:14 +01:00
parent c2a5011c0d
commit a12323884f
3 changed files with 70 additions and 22 deletions

View file

@ -10,6 +10,7 @@ import (
"github.com/containers/podman/v2/pkg/copy"
"github.com/gorilla/schema"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
func Archive(w http.ResponseWriter, r *http.Request) {
@ -86,11 +87,16 @@ func handleHeadAndGet(w http.ResponseWriter, r *http.Request, decoder *schema.De
return
}
w.WriteHeader(http.StatusOK)
if err := copy.Copy(&source, &destination, false); err != nil {
copier, err := copy.GetCopier(&source, &destination, false)
if err != nil {
utils.Error(w, "Something went wrong", http.StatusInternalServerError, err)
return
}
w.WriteHeader(http.StatusOK)
if err := copier.Copy(); err != nil {
logrus.Errorf("Error during copy: %v", err)
return
}
}
func handlePut(w http.ResponseWriter, r *http.Request, decoder *schema.Decoder, runtime *libpod.Runtime) {
@ -129,9 +135,14 @@ func handlePut(w http.ResponseWriter, r *http.Request, decoder *schema.Decoder,
return
}
w.WriteHeader(http.StatusOK)
if err := copy.Copy(&source, &destination, false); err != nil {
copier, err := copy.GetCopier(&source, &destination, false)
if err != nil {
utils.Error(w, "Something went wrong", http.StatusInternalServerError, err)
return
}
w.WriteHeader(http.StatusOK)
if err := copier.Copy(); err != nil {
logrus.Errorf("Error during copy: %v", err)
return
}
}

View file

@ -25,31 +25,61 @@ import (
//
// ****************************************************************************
// Copy the source item to destination. Use extract to untar the source if
// it's a tar archive.
func Copy(source *CopyItem, destination *CopyItem, extract bool) error {
// Copier copies data from a source to a destination CopyItem.
type Copier struct {
copyFunc func() error
cleanUpFuncs []deferFunc
}
// cleanUp releases resources the Copier may hold open.
func (c *Copier) cleanUp() {
for _, f := range c.cleanUpFuncs {
f()
}
}
// Copy data from a source to a destination CopyItem.
func (c *Copier) Copy() error {
defer c.cleanUp()
return c.copyFunc()
}
// GetCopiers returns a Copier to copy the source item to destination. Use
// extract to untar the source if it's a tar archive.
func GetCopier(source *CopyItem, destination *CopyItem, extract bool) (*Copier, error) {
copier := &Copier{}
// First, do the man-page dance. See podman-cp(1) for details.
if err := enforceCopyRules(source, destination); err != nil {
return err
return nil, err
}
// Destination is a stream (e.g., stdout or an http body).
if destination.info.IsStream {
// Source is a stream (e.g., stdin or an http body).
if source.info.IsStream {
_, err := io.Copy(destination.writer, source.reader)
return err
copier.copyFunc = func() error {
_, err := io.Copy(destination.writer, source.reader)
return err
}
return copier, nil
}
root, glob, err := source.buildahGlobs()
if err != nil {
return err
return nil, err
}
return buildahCopiah.Get(root, "", source.getOptions(), []string{glob}, destination.writer)
copier.copyFunc = func() error {
return buildahCopiah.Get(root, "", source.getOptions(), []string{glob}, destination.writer)
}
return copier, nil
}
// Destination is either a file or a directory.
if source.info.IsStream {
return buildahCopiah.Put(destination.root, destination.resolved, source.putOptions(), source.reader)
copier.copyFunc = func() error {
return buildahCopiah.Put(destination.root, destination.resolved, source.putOptions(), source.reader)
}
return copier, nil
}
tarOptions := &archive.TarOptions{
@ -71,33 +101,36 @@ func Copy(source *CopyItem, destination *CopyItem, extract bool) error {
var tarReader io.ReadCloser
if extract && archive.IsArchivePath(source.resolved) {
if !destination.info.IsDir {
return errors.Errorf("cannot extract archive %q to file %q", source.original, destination.original)
return nil, errors.Errorf("cannot extract archive %q to file %q", source.original, destination.original)
}
reader, err := os.Open(source.resolved)
if err != nil {
return err
return nil, err
}
defer reader.Close()
copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { reader.Close() })
// The stream from stdin may be compressed (e.g., via gzip).
decompressedStream, err := archive.DecompressStream(reader)
if err != nil {
return err
return nil, err
}
defer decompressedStream.Close()
copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { decompressedStream.Close() })
tarReader = decompressedStream
} else {
reader, err := archive.TarWithOptions(source.resolved, tarOptions)
if err != nil {
return err
return nil, err
}
defer reader.Close()
copier.cleanUpFuncs = append(copier.cleanUpFuncs, func() { reader.Close() })
tarReader = reader
}
return buildahCopiah.Put(root, dir, source.putOptions(), tarReader)
copier.copyFunc = func() error {
return buildahCopiah.Put(root, dir, source.putOptions(), tarReader)
}
return copier, nil
}
// enforceCopyRules enforces the rules for copying from a source to a

View file

@ -62,5 +62,9 @@ func (ic *ContainerEngine) ContainerCp(ctx context.Context, source, dest string,
}
// Copy from the host to the container.
return copy.Copy(&sourceItem, &destinationItem, options.Extract)
copier, err := copy.GetCopier(&sourceItem, &destinationItem, options.Extract)
if err != nil {
return err
}
return copier.Copy()
}