Support canceling an in-progress download
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"emperror.dev/errors"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pterodactyl/wings/server"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
@@ -13,6 +14,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Counter struct {
|
||||
total int
|
||||
onWrite func(total int)
|
||||
}
|
||||
|
||||
func (c *Counter) Write(p []byte) (int, error) {
|
||||
n := len(p)
|
||||
c.total += n
|
||||
c.onWrite(c.total)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type Downloader struct {
|
||||
mu sync.RWMutex
|
||||
downloadCache map[string]Download
|
||||
@@ -28,6 +41,7 @@ type Download struct {
|
||||
Identifier string
|
||||
req DownloadRequest
|
||||
server *server.Server
|
||||
progress float64
|
||||
cancelFunc *context.CancelFunc
|
||||
}
|
||||
|
||||
@@ -53,6 +67,28 @@ func New(s *server.Server, r DownloadRequest) *Download {
|
||||
return &dl
|
||||
}
|
||||
|
||||
// Returns all of the tracked downloads for a given server instance.
|
||||
func ByServer(sid string) []Download {
|
||||
var downloads []Download
|
||||
if v, ok := instance.serverCache[sid]; ok {
|
||||
for _, id := range v {
|
||||
if dl, dlok := instance.downloadCache[id]; dlok {
|
||||
downloads = append(downloads, dl)
|
||||
}
|
||||
}
|
||||
}
|
||||
return downloads
|
||||
}
|
||||
|
||||
// Returns a single Download matching a given identifier. If no download is found
|
||||
// the second argument in the response will be false.
|
||||
func ByID(dlid string) (Download, bool) {
|
||||
if v, ok := instance.downloadCache[dlid]; ok {
|
||||
return v, true
|
||||
}
|
||||
return Download{}, false
|
||||
}
|
||||
|
||||
// Executes a given download for the server and begins writing the file to the disk. Once
|
||||
// completed the download will be removed from the cache.
|
||||
func (dl *Download) Execute() error {
|
||||
@@ -82,7 +118,9 @@ func (dl *Download) Execute() error {
|
||||
fnameparts := strings.Split(dl.req.URL.Path, "/")
|
||||
p := filepath.Join(dl.req.Directory, fnameparts[len(fnameparts)-1])
|
||||
dl.server.Log().WithField("path", p).Debug("writing remote file to disk")
|
||||
if err := dl.server.Filesystem().Writefile(p, res.Body); err != nil {
|
||||
|
||||
r := io.TeeReader(res.Body, dl.counter(res.ContentLength))
|
||||
if err := dl.server.Filesystem().Writefile(p, r); err != nil {
|
||||
return errors.WrapIf(err, "downloader: failed to write file to server directory")
|
||||
}
|
||||
return nil
|
||||
@@ -97,6 +135,23 @@ func (dl *Download) Cancel() {
|
||||
instance.remove(dl.Identifier)
|
||||
}
|
||||
|
||||
// Checks if the given download belongs to the provided server.
|
||||
func (dl *Download) BelongsTo(s *server.Server) bool {
|
||||
return dl.server.Id() == s.Id()
|
||||
}
|
||||
|
||||
// Handles a write event by updating the progress completed percentage and firing off
|
||||
// events to the server websocket as needed.
|
||||
func (dl *Download) counter(contentLength int64) *Counter {
|
||||
onWrite := func(t int) {
|
||||
dl.progress = float64(t) / float64(contentLength)
|
||||
}
|
||||
|
||||
return &Counter{
|
||||
onWrite: onWrite,
|
||||
}
|
||||
}
|
||||
|
||||
// Tracks a download in the internal cache for this instance.
|
||||
func (d *Downloader) track(dl Download) {
|
||||
d.mu.Lock()
|
||||
|
||||
Reference in New Issue
Block a user