server(filesystem): fix inaccurate archive progress (#145)

This commit is contained in:
Matthew Penner
2022-11-06 13:38:30 -07:00
committed by GitHub
parent 3337362955
commit eb4df39d14
12 changed files with 191 additions and 144 deletions

View File

@@ -30,6 +30,30 @@ var pool = sync.Pool{
},
}
// TarProgress pairs a tar.Writer with an optional Progress. The embedded
// *tar.Writer provides the tar methods, while the Write method defined on
// TarProgress sends data through the Progress whenever one is attached.
type TarProgress struct {
*tar.Writer
// p is the attached Progress; it may be nil, in which case writes go
// directly to the embedded tar.Writer.
p *Progress
}
// NewTarProgress builds a TarProgress around w. When p is non-nil, w is also
// installed as the writer p forwards to, so data written through p lands in
// the tar stream. A nil p yields a wrapper that writes to w directly.
func NewTarProgress(w *tar.Writer, p *Progress) *TarProgress {
	tp := &TarProgress{Writer: w, p: p}
	if p == nil {
		return tp
	}
	// Point the progress at the tar writer before handing the wrapper out.
	p.w = w
	return tp
}
// Write sends v through the attached Progress when there is one; otherwise it
// writes straight to the embedded tar.Writer.
func (p *TarProgress) Write(v []byte) (int, error) {
	if tracker := p.p; tracker != nil {
		return tracker.Write(v)
	}
	return p.Writer.Write(v)
}
// Progress is used to track the progress of any I/O operations that are being
// performed.
type Progress struct {
@@ -46,6 +70,12 @@ func NewProgress(total int64) *Progress {
return &Progress{total: total}
}
// SetWriter sets the writer that this Progress will forward writes to,
// replacing whatever writer was set before.
// NOTE: This function is not thread safe; the writer field is assigned
// without any synchronization.
func (p *Progress) SetWriter(w io.Writer) {
p.w = w
}
// Written returns the total number of bytes written.
// This function should be used when the progress is tracking data being written.
func (p *Progress) Written() int64 {
@@ -157,23 +187,17 @@ func (a *Archive) Create(dst string) error {
_ = gw.SetConcurrency(1<<20, 1)
defer gw.Close()
var pw io.Writer
if a.Progress != nil {
a.Progress.w = gw
pw = a.Progress
} else {
pw = gw
}
// Create a new tar writer around the gzip writer.
tw := tar.NewWriter(pw)
tw := tar.NewWriter(gw)
defer tw.Close()
pw := NewTarProgress(tw, a.Progress)
// Configure godirwalk.
options := &godirwalk.Options{
FollowSymbolicLinks: false,
Unsorted: true,
Callback: a.callback(tw),
Callback: a.callback(pw),
}
// If we're specifically looking for only certain files, or have requested
@@ -182,7 +206,7 @@ func (a *Archive) Create(dst string) error {
if len(a.Files) == 0 && len(a.Ignore) > 0 {
i := ignore.CompileIgnoreLines(strings.Split(a.Ignore, "\n")...)
options.Callback = a.callback(tw, func(_ string, rp string) error {
options.Callback = a.callback(pw, func(_ string, rp string) error {
if i.MatchesPath(rp) {
return godirwalk.SkipThis
}
@@ -190,7 +214,7 @@ func (a *Archive) Create(dst string) error {
return nil
})
} else if len(a.Files) > 0 {
options.Callback = a.withFilesCallback(tw)
options.Callback = a.withFilesCallback(pw)
}
// Recursively walk the path we are archiving.
@@ -199,7 +223,7 @@ func (a *Archive) Create(dst string) error {
// Callback function used to determine if a given file should be included in the archive
// being generated.
func (a *Archive) callback(tw *tar.Writer, opts ...func(path string, relative string) error) func(path string, de *godirwalk.Dirent) error {
func (a *Archive) callback(tw *TarProgress, opts ...func(path string, relative string) error) func(path string, de *godirwalk.Dirent) error {
return func(path string, de *godirwalk.Dirent) error {
// Skip directories because we are walking them recursively.
if de.IsDir() {
@@ -223,7 +247,7 @@ func (a *Archive) callback(tw *tar.Writer, opts ...func(path string, relative st
}
// Pushes only files defined in the Files key to the final archive.
func (a *Archive) withFilesCallback(tw *tar.Writer) func(path string, de *godirwalk.Dirent) error {
func (a *Archive) withFilesCallback(tw *TarProgress) func(path string, de *godirwalk.Dirent) error {
return a.callback(tw, func(p string, rp string) error {
for _, f := range a.Files {
// If the given doesn't match, or doesn't have the same prefix continue
@@ -244,7 +268,7 @@ func (a *Archive) withFilesCallback(tw *tar.Writer) func(path string, de *godirw
}
// Adds a given file path to the final archive being created.
func (a *Archive) addToArchive(p string, rp string, w *tar.Writer) error {
func (a *Archive) addToArchive(p string, rp string, w *TarProgress) error {
// Lstat the file, this will give us the same information as Stat except that it will not
// follow a symlink to its target automatically. This is important to avoid including
// files that exist outside the server root unintentionally in the backup.

View File

@@ -4,7 +4,9 @@ import (
"archive/tar"
"archive/zip"
"compress/gzip"
"context"
"fmt"
iofs "io/fs"
"os"
"path"
"path/filepath"
@@ -13,11 +15,10 @@ import (
"sync/atomic"
"time"
"emperror.dev/errors"
gzip2 "github.com/klauspost/compress/gzip"
zip2 "github.com/klauspost/compress/zip"
"emperror.dev/errors"
"github.com/mholt/archiver/v3"
"github.com/mholt/archiver/v4"
)
// CompressFiles compresses all of the files matching the given paths in the
@@ -73,7 +74,7 @@ func (fs *Filesystem) CompressFiles(dir string, paths []string) (os.FileInfo, er
// SpaceAvailableForDecompression looks through a given archive and determines
// if decompressing it would put the server over its allocated disk space limit.
func (fs *Filesystem) SpaceAvailableForDecompression(dir string, file string) error {
func (fs *Filesystem) SpaceAvailableForDecompression(ctx context.Context, dir string, file string) error {
// Don't waste time trying to determine this if we know the server will have the space for
// it since there is no limit.
if fs.MaxDisk() <= 0 {
@@ -89,69 +90,104 @@ func (fs *Filesystem) SpaceAvailableForDecompression(dir string, file string) er
// waiting an unnecessary amount of time on this call.
dirSize, err := fs.DiskUsage(false)
var size int64
// Walk over the archive and figure out just how large the final output would be from unarchiving it.
err = archiver.Walk(source, func(f archiver.File) error {
if atomic.AddInt64(&size, f.Size())+dirSize > fs.MaxDisk() {
return newFilesystemError(ErrCodeDiskSpace, nil)
}
return nil
})
fsys, err := archiver.FileSystem(source)
if err != nil {
if IsUnknownArchiveFormatError(err) {
if errors.Is(err, archiver.ErrNoMatch) {
return newFilesystemError(ErrCodeUnknownArchive, err)
}
return err
}
return err
var size int64
return iofs.WalkDir(fsys, ".", func(path string, d iofs.DirEntry, err error) error {
if err != nil {
return err
}
select {
case <-ctx.Done():
// Stop walking if the context is canceled.
return ctx.Err()
default:
info, err := d.Info()
if err != nil {
return err
}
if atomic.AddInt64(&size, info.Size())+dirSize > fs.MaxDisk() {
return newFilesystemError(ErrCodeDiskSpace, nil)
}
return nil
}
})
}
// DecompressFile will decompress a file in a given directory by using the
// archiver tool to infer the file type and go from there. This will walk over
// all of the files within the given archive and ensure that there is not a
// all the files within the given archive and ensure that there is not a
// zip-slip attack being attempted by validating that the final path is within
// the server data directory.
func (fs *Filesystem) DecompressFile(dir string, file string) error {
func (fs *Filesystem) DecompressFile(ctx context.Context, dir string, file string) error {
source, err := fs.SafePath(filepath.Join(dir, file))
if err != nil {
return err
}
// Ensure that the source archive actually exists on the system.
if _, err := os.Stat(source); err != nil {
return fs.DecompressFileUnsafe(ctx, dir, source)
}
// DecompressFileUnsafe will decompress any file on the local disk without checking
// if it is owned by the server. The file will be SAFELY decompressed and extracted
// into the server's directory.
func (fs *Filesystem) DecompressFileUnsafe(ctx context.Context, dir string, file string) error {
// Ensure that the archive actually exists on the system.
if _, err := os.Stat(file); err != nil {
return errors.WithStack(err)
}
// Walk all of the files in the archiver file and write them to the disk. If any
// directory is encountered it will be skipped since we handle creating any missing
// directories automatically when writing files.
err = archiver.Walk(source, func(f archiver.File) error {
if f.IsDir() {
return nil
}
p := filepath.Join(dir, ExtractNameFromArchive(f))
// If it is ignored, just don't do anything with the file and skip over it.
if err := fs.IsIgnored(p); err != nil {
return nil
}
if err := fs.Writefile(p, f); err != nil {
return wrapError(err, source)
}
// Update the file permissions to the one set in the archive.
if err := fs.Chmod(p, f.Mode()); err != nil {
return wrapError(err, source)
}
// Update the file modification time to the one set in the archive.
if err := fs.Chtimes(p, f.ModTime(), f.ModTime()); err != nil {
return wrapError(err, source)
}
return nil
})
f, err := os.Open(file)
if err != nil {
if IsUnknownArchiveFormatError(err) {
return err
}
// Identify the type of archive we are dealing with.
format, input, err := archiver.Identify(filepath.Base(file), f)
if err != nil {
if errors.Is(err, archiver.ErrNoMatch) {
return newFilesystemError(ErrCodeUnknownArchive, err)
}
return err
}
// Decompress and extract archive
if ex, ok := format.(archiver.Extractor); ok {
return ex.Extract(ctx, input, nil, func(ctx context.Context, f archiver.File) error {
if f.IsDir() {
return nil
}
p := filepath.Join(dir, ExtractNameFromArchive(f))
// If it is ignored, just don't do anything with the file and skip over it.
if err := fs.IsIgnored(p); err != nil {
return nil
}
r, err := f.Open()
if err != nil {
return err
}
defer r.Close()
if err := fs.Writefile(p, r); err != nil {
return wrapError(err, file)
}
// Update the file permissions to the one set in the archive.
if err := fs.Chmod(p, f.Mode()); err != nil {
return wrapError(err, file)
}
// Update the file modification time to the one set in the archive.
if err := fs.Chtimes(p, f.ModTime(), f.ModTime()); err != nil {
return wrapError(err, file)
}
return nil
})
}
return nil
}

View File

@@ -1,6 +1,7 @@
package filesystem
import (
"context"
"os"
"sync/atomic"
"testing"
@@ -28,7 +29,7 @@ func TestFilesystem_DecompressFile(t *testing.T) {
g.Assert(err).IsNil()
// decompress
err = fs.DecompressFile("/", "test."+ext)
err = fs.DecompressFile(context.Background(), "/", "test."+ext)
g.Assert(err).IsNil()
// make sure everything is where it is supposed to be

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"emperror.dev/errors"
"github.com/apex/log"
@@ -122,15 +121,6 @@ func IsErrorCode(err error, code ErrorCode) bool {
return false
}
// IsUnknownArchiveFormatError reports whether err is non-nil and its message
// begins with "format ", which is how an archive in an unexpected file format
// is recognized here.
func IsUnknownArchiveFormatError(err error) bool {
	return err != nil && strings.HasPrefix(err.Error(), "format ")
}
// NewBadPathResolution returns a new BadPathResolution error.
func NewBadPathResolution(path string, resolved string) error {
return errors.WithStackDepth(&Error{code: ErrCodePathResolution, path: path, resolved: resolved}, 1)