Support canceling an in-progress download

This commit is contained in:
Dane Everitt
2020-12-20 12:53:40 -08:00
parent c718da20e3
commit c8d297a056
5 changed files with 91 additions and 11 deletions

View File

@@ -5,6 +5,7 @@ import (
"emperror.dev/errors"
"github.com/google/uuid"
"github.com/pterodactyl/wings/server"
"io"
"net/http"
"net/url"
"path/filepath"
@@ -13,6 +14,18 @@ import (
"time"
)
type Counter struct {
total int
onWrite func(total int)
}
func (c *Counter) Write(p []byte) (int, error) {
n := len(p)
c.total += n
c.onWrite(c.total)
return n, nil
}
type Downloader struct {
mu sync.RWMutex
downloadCache map[string]Download
@@ -28,6 +41,7 @@ type Download struct {
Identifier string
req DownloadRequest
server *server.Server
progress float64
cancelFunc *context.CancelFunc
}
@@ -53,6 +67,28 @@ func New(s *server.Server, r DownloadRequest) *Download {
return &dl
}
// Returns all of the tracked downloads for a given server instance.
func ByServer(sid string) []Download {
var downloads []Download
if v, ok := instance.serverCache[sid]; ok {
for _, id := range v {
if dl, dlok := instance.downloadCache[id]; dlok {
downloads = append(downloads, dl)
}
}
}
return downloads
}
// Returns a single Download matching a given identifier. If no download is found
// the second argument in the response will be false.
func ByID(dlid string) (Download, bool) {
if v, ok := instance.downloadCache[dlid]; ok {
return v, true
}
return Download{}, false
}
// Executes a given download for the server and begins writing the file to the disk. Once
// completed the download will be removed from the cache.
func (dl *Download) Execute() error {
@@ -82,7 +118,9 @@ func (dl *Download) Execute() error {
fnameparts := strings.Split(dl.req.URL.Path, "/")
p := filepath.Join(dl.req.Directory, fnameparts[len(fnameparts)-1])
dl.server.Log().WithField("path", p).Debug("writing remote file to disk")
if err := dl.server.Filesystem().Writefile(p, res.Body); err != nil {
r := io.TeeReader(res.Body, dl.counter(res.ContentLength))
if err := dl.server.Filesystem().Writefile(p, r); err != nil {
return errors.WrapIf(err, "downloader: failed to write file to server directory")
}
return nil
@@ -97,6 +135,23 @@ func (dl *Download) Cancel() {
instance.remove(dl.Identifier)
}
// Checks if the given download belongs to the provided server.
func (dl *Download) BelongsTo(s *server.Server) bool {
return dl.server.Id() == s.Id()
}
// Handles a write event by updating the progress completed percentage and firing off
// events to the server websocket as needed.
func (dl *Download) counter(contentLength int64) *Counter {
onWrite := func(t int) {
dl.progress = float64(t) / float64(contentLength)
}
return &Counter{
onWrite: onWrite,
}
}
// Tracks a download in the internal cache for this instance.
func (d *Downloader) track(dl Download) {
d.mu.Lock()