Include endpoints for getting active downloads for a server

This commit is contained in:
Dane Everitt 2020-12-20 13:14:07 -08:00
parent f8282c56cb
commit 184013b652
No known key found for this signature in database
GPG Key ID: EEA66103B3D71F53
3 changed files with 52 additions and 18 deletions

View File

@ -3,6 +3,7 @@ package downloader
import ( import (
"context" "context"
"emperror.dev/errors" "emperror.dev/errors"
"encoding/json"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/pterodactyl/wings/server" "github.com/pterodactyl/wings/server"
"io" "io"
@ -28,7 +29,7 @@ func (c *Counter) Write(p []byte) (int, error) {
type Downloader struct { type Downloader struct {
mu sync.RWMutex mu sync.RWMutex
downloadCache map[string]Download downloadCache map[string]*Download
serverCache map[string][]string serverCache map[string][]string
} }
@ -39,6 +40,7 @@ type DownloadRequest struct {
type Download struct { type Download struct {
Identifier string Identifier string
mu sync.RWMutex
req DownloadRequest req DownloadRequest
server *server.Server server *server.Server
progress float64 progress float64
@ -48,7 +50,7 @@ type Download struct {
var client = &http.Client{Timeout: time.Hour * 12} var client = &http.Client{Timeout: time.Hour * 12}
var instance = &Downloader{ var instance = &Downloader{
// Tracks all of the active downloads. // Tracks all of the active downloads.
downloadCache: make(map[string]Download), downloadCache: make(map[string]*Download),
// Tracks all of the downloads active for a given server instance. This is // Tracks all of the downloads active for a given server instance. This is
// primarily used to make things quicker and keep the code a little more // primarily used to make things quicker and keep the code a little more
// legible throughout here. // legible throughout here.
@ -63,13 +65,15 @@ func New(s *server.Server, r DownloadRequest) *Download {
req: r, req: r,
server: s, server: s,
} }
instance.track(dl) instance.track(&dl)
return &dl return &dl
} }
// Returns all of the tracked downloads for a given server instance. // Returns all of the tracked downloads for a given server instance.
func ByServer(sid string) []Download { func ByServer(sid string) []*Download {
var downloads []Download instance.mu.Lock()
defer instance.mu.Unlock()
var downloads []*Download
if v, ok := instance.serverCache[sid]; ok { if v, ok := instance.serverCache[sid]; ok {
for _, id := range v { for _, id := range v {
if dl, dlok := instance.downloadCache[id]; dlok { if dl, dlok := instance.downloadCache[id]; dlok {
@ -82,11 +86,19 @@ func ByServer(sid string) []Download {
// Returns a single Download matching a given identifier. If no download is found // Returns a single Download matching a given identifier. If no download is found
// the second argument in the response will be false. // the second argument in the response will be false.
func ByID(dlid string) (Download, bool) { func ByID(dlid string) *Download {
if v, ok := instance.downloadCache[dlid]; ok { return instance.find(dlid)
return v, true
} }
return Download{}, false
//goland:noinspection GoVetCopyLock
func (dl Download) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Identifier string
Progress float64
}{
Identifier: dl.Identifier,
Progress: dl.Progress(),
})
} }
// Executes a given download for the server and begins writing the file to the disk. Once // Executes a given download for the server and begins writing the file to the disk. Once
@ -140,20 +152,29 @@ func (dl *Download) BelongsTo(s *server.Server) bool {
return dl.server.Id() == s.Id() return dl.server.Id() == s.Id()
} }
// Returns the current progress of the download as a float value between 0 and 1 where
// 1 indicates that the download is completed.
func (dl *Download) Progress() float64 {
dl.mu.RLock()
defer dl.mu.RUnlock()
return dl.progress
}
// Handles a write event by updating the progress completed percentage and firing off // Handles a write event by updating the progress completed percentage and firing off
// events to the server websocket as needed. // events to the server websocket as needed.
func (dl *Download) counter(contentLength int64) *Counter { func (dl *Download) counter(contentLength int64) *Counter {
onWrite := func(t int) { onWrite := func(t int) {
dl.mu.Lock()
defer dl.mu.Unlock()
dl.progress = float64(t) / float64(contentLength) dl.progress = float64(t) / float64(contentLength)
} }
return &Counter{ return &Counter{
onWrite: onWrite, onWrite: onWrite,
} }
} }
// Tracks a download in the internal cache for this instance. // Tracks a download in the internal cache for this instance.
func (d *Downloader) track(dl Download) { func (d *Downloader) track(dl *Download) {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
sid := dl.server.Id() sid := dl.server.Id()
@ -167,13 +188,13 @@ func (d *Downloader) track(dl Download) {
} }
// Finds a given download entry using the provided ID and returns it. // Finds a given download entry using the provided ID and returns it.
func (d *Downloader) find(dlid string) (Download, bool) { func (d *Downloader) find(dlid string) *Download {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
if entry, ok := d.downloadCache[dlid]; ok { if entry, ok := d.downloadCache[dlid]; ok {
return entry, true return entry
} }
return Download{}, false return nil
} }
// Remove the given download reference from the cache storing them. This also updates // Remove the given download reference from the cache storing them. This also updates

View File

@ -82,13 +82,15 @@ func Configure() *gin.Engine {
files.PUT("/rename", putServerRenameFiles) files.PUT("/rename", putServerRenameFiles)
files.POST("/copy", postServerCopyFile) files.POST("/copy", postServerCopyFile)
files.POST("/write", postServerWriteFile) files.POST("/write", postServerWriteFile)
files.POST("/pull", postServerPullRemoteFile)
files.DELETE("/pull/:download", deleteServerPullRemoteFile)
files.POST("/create-directory", postServerCreateDirectory) files.POST("/create-directory", postServerCreateDirectory)
files.POST("/delete", postServerDeleteFiles) files.POST("/delete", postServerDeleteFiles)
files.POST("/compress", postServerCompressFiles) files.POST("/compress", postServerCompressFiles)
files.POST("/decompress", postServerDecompressFiles) files.POST("/decompress", postServerDecompressFiles)
files.POST("/chmod", postServerChmodFile) files.POST("/chmod", postServerChmodFile)
files.GET("/pull", getServerPullingFiles)
files.POST("/pull", postServerPullRemoteFile)
files.DELETE("/pull/:download", deleteServerPullRemoteFile)
} }
backup := server.Group("/backup") backup := server.Group("/backup")

View File

@ -233,6 +233,17 @@ func postServerWriteFile(c *gin.Context) {
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)
} }
// Returns all of the currently in-progress file downloads and their current download
// progress. The progress is also pushed out via a websocket event allowing you to just
// call this once to get current downloads, and then listen to targeted websocket events
// with the current progress for everything.
func getServerPullingFiles(c *gin.Context) {
s := ExtractServer(c)
c.JSON(http.StatusOK, gin.H{
"downloads": downloader.ByServer(s.Id()),
})
}
// Writes the contents of the remote URL to a file on a server. // Writes the contents of the remote URL to a file on a server.
func postServerPullRemoteFile(c *gin.Context) { func postServerPullRemoteFile(c *gin.Context) {
s := ExtractServer(c) s := ExtractServer(c)
@ -284,7 +295,7 @@ func postServerPullRemoteFile(c *gin.Context) {
// Stops a remote file download if it exists and belongs to this server. // Stops a remote file download if it exists and belongs to this server.
func deleteServerPullRemoteFile(c *gin.Context) { func deleteServerPullRemoteFile(c *gin.Context) {
s := ExtractServer(c) s := ExtractServer(c)
if dl, ok := downloader.ByID(c.Param("download")); ok && dl.BelongsTo(s) { if dl := downloader.ByID(c.Param("download")); dl != nil && dl.BelongsTo(s) {
dl.Cancel() dl.Cancel()
} }
c.Status(http.StatusNoContent) c.Status(http.StatusNoContent)