From 57e7eb714ccacb9789c5638244c3d8ad2af0274f Mon Sep 17 00:00:00 2001 From: Matthew Penner Date: Mon, 14 Nov 2022 18:25:01 -0700 Subject: [PATCH] Streaming Transfers (#153) --- internal/cron/cron.go | 22 +- internal/progress/progress.go | 90 +++ .../progress/progress_test.go | 26 +- router/router.go | 13 +- router/router_server.go | 16 +- router/router_server_transfer.go | 129 ++++ router/router_system.go | 6 +- router/router_transfer.go | 698 ++++++------------ server/backup/backup_local.go | 2 +- server/backup/backup_s3.go | 2 +- server/events.go | 4 +- server/filesystem/archive.go | 118 +-- server/filesystem/compress.go | 57 +- {installer => server/installer}/errors.go | 0 {installer => server/installer}/installer.go | 0 server/transfer/archive.go | 52 ++ server/transfer/doc.go | 4 + server/transfer/manager.go | 57 ++ server/transfer/source.go | 159 ++++ server/transfer/transfer.go | 128 ++++ system/utils.go | 44 +- 21 files changed, 1015 insertions(+), 612 deletions(-) create mode 100644 internal/progress/progress.go rename server/filesystem/archive_test.go => internal/progress/progress_test.go (63%) create mode 100644 router/router_server_transfer.go rename {installer => server/installer}/errors.go (100%) rename {installer => server/installer}/installer.go (100%) create mode 100644 server/transfer/archive.go create mode 100644 server/transfer/doc.go create mode 100644 server/transfer/manager.go create mode 100644 server/transfer/source.go create mode 100644 server/transfer/transfer.go diff --git a/internal/cron/cron.go b/internal/cron/cron.go index e697d81..fed4c04 100644 --- a/internal/cron/cron.go +++ b/internal/cron/cron.go @@ -5,7 +5,7 @@ import ( "time" "emperror.dev/errors" - log2 "github.com/apex/log" + "github.com/apex/log" "github.com/go-co-op/gocron" "github.com/pterodactyl/wings/config" @@ -24,7 +24,7 @@ func Scheduler(ctx context.Context, m *server.Manager) (*gocron.Scheduler, error if !o.SwapIf(true) { return nil, errors.New("cron: cannot call scheduler more than once in application lifecycle") } - l, err := time.LoadLocation(config.Get().System.Timezone) + location, err := time.LoadLocation(config.Get().System.Timezone) if err != nil { return nil, errors.Wrap(err, "cron: failed to parse configured system timezone") } @@ -41,30 +41,30 @@ func Scheduler(ctx context.Context, m *server.Manager) (*gocron.Scheduler, error max: config.Get().System.ActivitySendCount, } - s := gocron.NewScheduler(l) - log := log2.WithField("subsystem", "cron") + s := gocron.NewScheduler(location) + l := log.WithField("subsystem", "cron") interval := time.Duration(config.Get().System.ActivitySendInterval) * time.Second - log.WithField("interval", interval).Info("configuring system crons") + l.WithField("interval", interval).Info("configuring system crons") _, _ = s.Tag("activity").Every(interval).Do(func() { - log.WithField("cron", "activity").Debug("sending internal activity events to Panel") + l.WithField("cron", "activity").Debug("sending internal activity events to Panel") if err := activity.Run(ctx); err != nil { if errors.Is(err, ErrCronRunning) { - log.WithField("cron", "activity").Warn("activity process is already running, skipping...") + l.WithField("cron", "activity").Warn("activity process is already running, skipping...") } else { - log.WithField("cron", "activity").WithField("error", err).Error("activity process failed to execute") + l.WithField("cron", "activity").WithField("error", err).Error("activity process failed to execute") } } }) _, _ = s.Tag("sftp").Every(interval).Do(func() { - log.WithField("cron", "sftp").Debug("sending sftp events to Panel") + l.WithField("cron", "sftp").Debug("sending sftp events to Panel") if err := sftp.Run(ctx); err != nil { if errors.Is(err, ErrCronRunning) { - log.WithField("cron", "sftp").Warn("sftp events process already running, skipping...") + l.WithField("cron", "sftp").Warn("sftp events process already running, skipping...") } else { - log.WithField("cron", "sftp").WithField("error", err).Error("sftp events process failed to execute") + l.WithField("cron", "sftp").WithField("error", err).Error("sftp events process failed to execute") } } }) diff --git a/internal/progress/progress.go b/internal/progress/progress.go new file mode 100644 index 0000000..0e219af --- /dev/null +++ b/internal/progress/progress.go @@ -0,0 +1,90 @@ +package progress + +import ( + "io" + "strings" + "sync/atomic" + + "github.com/pterodactyl/wings/system" +) + +// Progress is used to track the progress of any I/O operation that are being +// performed. +type Progress struct { + // written is the total size of the files that have been written to the writer. + written uint64 + // Total is the total size of the archive in bytes. + total uint64 + + // Writer . + Writer io.Writer +} + +// NewProgress returns a new progress tracker for the given total size. +func NewProgress(total uint64) *Progress { + return &Progress{total: total} +} + +// Written returns the total number of bytes written. +// This function should be used when the progress is tracking data being written. +func (p *Progress) Written() uint64 { + return atomic.LoadUint64(&p.written) +} + +// Total returns the total size in bytes. +func (p *Progress) Total() uint64 { + return atomic.LoadUint64(&p.total) +} + +// SetTotal sets the total size of the archive in bytes. This function is safe +// to call concurrently and can be used to update the total size if it changes, +// such as when the total size is simultaneously being calculated as data is +// being written through the progress writer. +func (p *Progress) SetTotal(total uint64) { + atomic.StoreUint64(&p.total, total) +} + +// Write totals the number of bytes that have been written to the writer. +func (p *Progress) Write(v []byte) (int, error) { + n := len(v) + atomic.AddUint64(&p.written, uint64(n)) + if p.Writer != nil { + return p.Writer.Write(v) + } + return n, nil +} + +// Progress returns a formatted progress string for the current progress. +func (p *Progress) Progress(width int) string { + // current = 100 (Progress, dynamic) + // total = 1000 (Content-Length, dynamic) + // width = 25 (Number of ticks to display, static) + // widthPercentage = 100 / width (What percentage does each tick represent, static) + // + // percentageDecimal = current / total = 0.1 + // percentage = percentageDecimal * 100 = 10% + // ticks = percentage / widthPercentage = 2.5 + // + // ticks is a float64, so we cast it to an int which rounds it down to 2. + + // Values are cast to floats to prevent integer division. + current := p.Written() + total := p.Total() + // width := is passed as a parameter + widthPercentage := float64(100) / float64(width) + percentageDecimal := float64(current) / float64(total) + percentage := percentageDecimal * 100 + ticks := int(percentage / widthPercentage) + + // Ensure that we never get a negative number of ticks, this will prevent strings#Repeat + // from panicking. A negative number of ticks is likely to happen when the total size is + // inaccurate, such as when we are going off of rough disk usage calculation. + if ticks < 0 { + ticks = 0 + } else if ticks > width { + ticks = width + } + + bar := strings.Repeat("=", ticks) + strings.Repeat(" ", width-ticks) + return "[" + bar + "] " + system.FormatBytes(current) + " / " + system.FormatBytes(total) +} diff --git a/server/filesystem/archive_test.go b/internal/progress/progress_test.go similarity index 63% rename from server/filesystem/archive_test.go rename to internal/progress/progress_test.go index f6739f8..98037f5 100644 --- a/server/filesystem/archive_test.go +++ b/internal/progress/progress_test.go @@ -1,47 +1,49 @@ -package filesystem +package progress_test import ( "bytes" "testing" - . "github.com/franela/goblin" + "github.com/franela/goblin" + + "github.com/pterodactyl/wings/internal/progress" ) func TestProgress(t *testing.T) { - g := Goblin(t) + g := goblin.Goblin(t) g.Describe("Progress", func() { g.It("properly initializes", func() { - total := int64(1000) - p := NewProgress(total) + total := uint64(1000) + p := progress.NewProgress(total) g.Assert(p).IsNotNil() g.Assert(p.Total()).Equal(total) - g.Assert(p.Written()).Equal(int64(0)) + g.Assert(p.Written()).Equal(uint64(0)) }) g.It("increments written when Write is called", func() { v := []byte("hello") - p := NewProgress(1000) + p := progress.NewProgress(1000) _, err := p.Write(v) g.Assert(err).IsNil() - g.Assert(p.Written()).Equal(int64(len(v))) + g.Assert(p.Written()).Equal(uint64(len(v))) }) g.It("renders a progress bar", func() { v := bytes.Repeat([]byte{' '}, 100) - p := NewProgress(1000) + p := progress.NewProgress(1000) _, err := p.Write(v) g.Assert(err).IsNil() - g.Assert(p.Written()).Equal(int64(len(v))) + g.Assert(p.Written()).Equal(uint64(len(v))) g.Assert(p.Progress(25)).Equal("[== ] 100 B / 1000 B") }) g.It("renders a progress bar when written exceeds total", func() { v := bytes.Repeat([]byte{' '}, 1001) - p := NewProgress(1000) + p := progress.NewProgress(1000) _, err := p.Write(v) g.Assert(err).IsNil() - g.Assert(p.Written()).Equal(int64(len(v))) + g.Assert(p.Written()).Equal(uint64(len(v))) g.Assert(p.Progress(25)).Equal("[=========================] 1001 B / 1000 B") }) }) diff --git a/router/router.go b/router/router.go index 60e37cc..c978118 100644 --- a/router/router.go +++ b/router/router.go @@ -16,7 +16,7 @@ func Configure(m *wserver.Manager, client remote.Client) *gin.Engine { router := gin.New() router.Use(gin.Recovery()) - router.SetTrustedProxies(config.Get().Api.TrustedProxies) + _ = router.SetTrustedProxies(config.Get().Api.TrustedProxies) router.Use(middleware.AttachRequestID(), middleware.CaptureErrors(), middleware.SetAccessControlHeaders()) router.Use(middleware.AttachServerManager(m), middleware.AttachApiClient(client)) // @todo log this into a different file so you can setup IP blocking for abusive requests and such. @@ -40,7 +40,7 @@ func Configure(m *wserver.Manager, client remote.Client) *gin.Engine { router.GET("/download/file", getDownloadFile) router.POST("/upload/file", postServerUploadFiles) - // This route is special it sits above all of the other requests because we are + // This route is special it sits above all the other requests because we are // using a JWT to authorize access to it, therefore it needs to be publicly // accessible. router.GET("/api/servers/:server/ws", middleware.ServerExists(), getServerWebsocket) @@ -48,16 +48,16 @@ func Configure(m *wserver.Manager, client remote.Client) *gin.Engine { // This request is called by another daemon when a server is going to be transferred out. // This request does not need the AuthorizationMiddleware as the panel should never call it // and requests are authenticated through a JWT the panel issues to the other daemon. - router.GET("/api/servers/:server/archive", middleware.ServerExists(), getServerArchive) + router.POST("/api/transfers", postTransfers) - // All of the routes beyond this mount will use an authorization middleware + // All the routes beyond this mount will use an authorization middleware // and will not be accessible without the correct Authorization header provided. protected := router.Use(middleware.RequireAuthorization()) protected.POST("/api/update", postUpdateConfiguration) protected.GET("/api/system", getSystemInformation) protected.GET("/api/servers", getAllServers) protected.POST("/api/servers", postCreateServer) - protected.POST("/api/transfer", postTransfer) + protected.DELETE("/api/transfers/:server", deleteTransfer) // These are server specific routes, and require that the request be authorized, and // that the server exist on the Daemon. @@ -77,7 +77,8 @@ func Configure(m *wserver.Manager, client remote.Client) *gin.Engine { // This archive request causes the archive to start being created // this should only be triggered by the panel. - server.POST("/archive", postServerArchive) + server.POST("/transfer", postServerTransfer) + server.DELETE("/transfer", deleteServerTransfer) files := server.Group("/files") { diff --git a/router/router_server.go b/router/router_server.go index 0c53504..2c41272 100644 --- a/router/router_server.go +++ b/router/router_server.go @@ -14,6 +14,7 @@ import ( "github.com/pterodactyl/wings/router/middleware" "github.com/pterodactyl/wings/router/tokens" "github.com/pterodactyl/wings/server" + "github.com/pterodactyl/wings/server/transfer" ) // Returns a single server from the collection of servers. @@ -188,6 +189,17 @@ func deleteServer(c *gin.Context) { // Immediately suspend the server to prevent a user from attempting // to start it while this process is running. s.Config().SetSuspended(true) + + // Notify all websocket clients that the server is being deleted. + // This is useful for two reasons, one to tell clients not to bother + // retrying to connect to the websocket. And two, for transfers when + // the server has been successfully transferred to another node, and + // the client needs to switch to the new node. + if s.IsTransferring() { + s.Events().Publish(server.TransferStatusEvent, transfer.StatusCompleted) + } + s.Events().Publish(server.DeletedEvent, nil) + s.CleanupForDestroy() // Remove any pending remote file downloads for the server. @@ -199,7 +211,7 @@ func deleteServer(c *gin.Context) { // forcibly terminate it before removing the container, so we do not need to handle // that here. if err := s.Environment.Destroy(); err != nil { - WithError(c, err) + _ = WithError(c, err) return } @@ -207,7 +219,7 @@ func deleteServer(c *gin.Context) { // done in a separate process since failure is not the end of the world and can be // manually cleaned up after the fact. // - // In addition, servers with large amounts of files can take some time to finish deleting + // In addition, servers with large amounts of files can take some time to finish deleting, // so we don't want to block the HTTP call while waiting on this. go func(p string) { if err := os.RemoveAll(p); err != nil { diff --git a/router/router_server_transfer.go b/router/router_server_transfer.go new file mode 100644 index 0000000..c4171dd --- /dev/null +++ b/router/router_server_transfer.go @@ -0,0 +1,129 @@ +package router + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/pterodactyl/wings/environment" + "github.com/pterodactyl/wings/router/middleware" + "github.com/pterodactyl/wings/server" + "github.com/pterodactyl/wings/server/installer" + "github.com/pterodactyl/wings/server/transfer" +) + +// Data passed over to initiate a server transfer. +type serverTransferRequest struct { + URL string `binding:"required" json:"url"` + Token string `binding:"required" json:"token"` + Server installer.ServerDetails `json:"server"` +} + +// postServerTransfer handles the start of a transfer for a server. +func postServerTransfer(c *gin.Context) { + var data serverTransferRequest + if err := c.BindJSON(&data); err != nil { + return + } + + s := ExtractServer(c) + + // Check if the server is already being transferred. + // There will be another endpoint for resetting this value either by deleting the + // server, or by canceling the transfer. + if s.IsTransferring() { + c.AbortWithStatusJSON(http.StatusConflict, gin.H{ + "error": "A transfer is already in progress for this server.", + }) + return + } + + manager := middleware.ExtractManager(c) + + notifyPanelOfFailure := func() { + if err := manager.Client().SetTransferStatus(context.Background(), s.ID(), false); err != nil { + s.Log().WithField("subsystem", "transfer"). + WithField("status", false). + WithError(err). + Error("failed to set transfer status") + } + + s.Events().Publish(server.TransferStatusEvent, "failure") + s.SetTransferring(false) + } + + // Block the server from starting while we are transferring it. + s.SetTransferring(true) + + // Ensure the server is offline. Sometimes a "No such container" error gets through + // which means the server is already stopped. We can ignore that. + if s.Environment.State() != environment.ProcessOfflineState { + if err := s.Environment.WaitForStop( + s.Context(), + time.Minute, + false, + ); err != nil && !strings.Contains(strings.ToLower(err.Error()), "no such container") { + notifyPanelOfFailure() + s.Log().WithError(err).Error("failed to stop server for transfer") + return + } + } + + // Create a new transfer instance for this server. + trnsfr := transfer.New(context.Background(), s) + transfer.Outgoing().Add(trnsfr) + + go func() { + defer transfer.Outgoing().Remove(trnsfr) + + if _, err := trnsfr.PushArchiveToTarget(data.URL, data.Token); err != nil { + notifyPanelOfFailure() + + if err == context.Canceled { + trnsfr.Log().Debug("canceled") + trnsfr.SendMessage("Canceled.") + return + } + + trnsfr.Log().WithError(err).Error("failed to push archive to target") + return + } + + // DO NOT NOTIFY THE PANEL OF SUCCESS HERE. The only node that should send + // a success status is the destination node. When we send a failure status, + // the panel will automatically cancel the transfer and attempt to reset + // the server state on the destination node, we just need to make sure + // we clean up our statuses for failure. + + trnsfr.Log().Debug("transfer complete") + }() + + c.Status(http.StatusAccepted) +} + +// deleteServerTransfer cancels an outgoing transfer for a server. +func deleteServerTransfer(c *gin.Context) { + s := ExtractServer(c) + + if !s.IsTransferring() { + c.AbortWithStatusJSON(http.StatusConflict, gin.H{ + "error": "Server is not currently being transferred.", + }) + return + } + + trnsfr := transfer.Outgoing().Get(s.ID()) + if trnsfr == nil { + c.AbortWithStatusJSON(http.StatusConflict, gin.H{ + "error": "Server is not currently being transferred.", + }) + return + } + + trnsfr.Cancel() + + c.Status(http.StatusAccepted) +} diff --git a/router/router_system.go b/router/router_system.go index d1e1469..48bd20f 100644 --- a/router/router_system.go +++ b/router/router_system.go @@ -10,9 +10,9 @@ import ( "github.com/gin-gonic/gin" "github.com/pterodactyl/wings/config" - "github.com/pterodactyl/wings/installer" "github.com/pterodactyl/wings/router/middleware" "github.com/pterodactyl/wings/server" + "github.com/pterodactyl/wings/server/installer" "github.com/pterodactyl/wings/system" ) @@ -28,7 +28,7 @@ func getSystemInformation(c *gin.Context) { c.JSON(http.StatusOK, i) } -// Returns all of the servers that are registered and configured correctly on +// Returns all the servers that are registered and configured correctly on // this wings instance. func getAllServers(c *gin.Context) { servers := middleware.ExtractManager(c).All() @@ -117,7 +117,7 @@ func postUpdateConfiguration(c *gin.Context) { // Try to write this new configuration to the disk before updating our global // state with it. if err := config.WriteToDisk(cfg); err != nil { - WithError(c, err) + _ = WithError(c, err) return } // Since we wrote it to the disk successfully now update the global configuration diff --git a/router/router_transfer.go b/router/router_transfer.go index 9406e11..b57b304 100644 --- a/router/router_transfer.go +++ b/router/router_transfer.go @@ -1,53 +1,33 @@ package router import ( - "bufio" + "bytes" "context" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" + "mime" + "mime/multipart" "net/http" "os" - "path/filepath" - "strconv" "strings" - "time" - "emperror.dev/errors" "github.com/apex/log" "github.com/gin-gonic/gin" "github.com/google/uuid" - "github.com/juju/ratelimit" - "github.com/mitchellh/colorstring" - "github.com/pterodactyl/wings/config" - "github.com/pterodactyl/wings/installer" - "github.com/pterodactyl/wings/remote" "github.com/pterodactyl/wings/router/middleware" "github.com/pterodactyl/wings/router/tokens" "github.com/pterodactyl/wings/server" - "github.com/pterodactyl/wings/server/filesystem" + "github.com/pterodactyl/wings/server/installer" + "github.com/pterodactyl/wings/server/transfer" ) -const progressWidth = 25 - -// Data passed over to initiate a server transfer. -type serverTransferRequest struct { - ServerID string `binding:"required" json:"server_id"` - URL string `binding:"required" json:"url"` - Token string `binding:"required" json:"token"` - Server installer.ServerDetails `json:"server"` -} - -func getArchivePath(sID string) string { - return filepath.Join(config.Get().System.ArchiveDirectory, sID+".tar.gz") -} - -// Returns the archive for a server so that it can be transferred to a new node. -func getServerArchive(c *gin.Context) { +// postTransfers . +func postTransfers(c *gin.Context) { auth := strings.SplitN(c.GetHeader("Authorization"), " ", 2) - if len(auth) != 2 || auth[0] != "Bearer" { c.Header("WWW-Authenticate", "Bearer") c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ @@ -62,458 +42,232 @@ func getServerArchive(c *gin.Context) { return } + manager := middleware.ExtractManager(c) + u, err := uuid.Parse(token.Subject) + if err != nil { + NewTrackedError(err).Abort(c) + return + } + + // Get or create a new transfer instance for this server. + var ( + ctx context.Context + cancel context.CancelFunc + ) + trnsfr := transfer.Incoming().Get(u.String()) + if trnsfr == nil { + // TODO: should this use the request context? + trnsfr = transfer.New(c, nil) + + ctx, cancel = context.WithCancel(trnsfr.Context()) + defer cancel() + + i, err := installer.New(ctx, manager, installer.ServerDetails{ + UUID: u.String(), + StartOnCompletion: false, + }) + if err != nil { + if err := manager.Client().SetTransferStatus(context.Background(), trnsfr.Server.ID(), false); err != nil { + trnsfr.Log().WithField("status", false).WithError(err).Error("failed to set transfer status") + } + NewTrackedError(err).Abort(c) + return + } + + i.Server().SetTransferring(true) + manager.Add(i.Server()) + + // We add the transfer to the list of transfers once we have a server instance to use. + trnsfr.Server = i.Server() + transfer.Incoming().Add(trnsfr) + } else { + ctx, cancel = context.WithCancel(trnsfr.Context()) + defer cancel() + } + + // Any errors past this point (until the transfer is complete) will abort + // the transfer. + + successful := false + defer func(ctx context.Context, trnsfr *transfer.Transfer) { + // Remove the transfer from the list of incoming transfers. + transfer.Incoming().Remove(trnsfr) + + if !successful { + trnsfr.Server.Events().Publish(server.TransferStatusEvent, "failure") + manager.Remove(func(match *server.Server) bool { + return match.ID() == trnsfr.Server.ID() + }) + } + + if err := manager.Client().SetTransferStatus(context.Background(), trnsfr.Server.ID(), successful); err != nil { + // Only delete the files if the transfer actually failed, otherwise we could have + // unrecoverable data-loss. + if !successful && err != nil { + // Delete all extracted files. + go func(trnsfr *transfer.Transfer) { + if err := os.RemoveAll(trnsfr.Server.Filesystem().Path()); err != nil && !os.IsNotExist(err) { + trnsfr.Log().WithError(err).Warn("failed to delete local server files") + } + }(trnsfr) + } + + trnsfr.Log().WithField("status", successful).WithError(err).Error("failed to set transfer status on panel") + return + } + + trnsfr.Server.SetTransferring(false) + trnsfr.Server.Events().Publish(server.TransferStatusEvent, "success") + }(ctx, trnsfr) + + mediaType, params, err := mime.ParseMediaType(c.GetHeader("Content-Type")) + if err != nil { + trnsfr.Log().Debug("failed to parse content type header") + NewTrackedError(err).Abort(c) + return + } + + if !strings.HasPrefix(mediaType, "multipart/") { + trnsfr.Log().Debug("invalid content type") + NewTrackedError(fmt.Errorf("invalid content type \"%s\", expected \"multipart/form-data\"", mediaType)).Abort(c) + return + } + + // Used to calculate the hash of the file as it is being uploaded. + h := sha256.New() + + // Used to read the file and checksum from the request body. + mr := multipart.NewReader(c.Request.Body, params["boundary"]) + + // Loop through the parts of the request body and process them. + var ( + hasArchive bool + hasChecksum bool + checksumVerified bool + ) +out: + for { + select { + case <-ctx.Done(): + break out + default: + p, err := mr.NextPart() + if err == io.EOF { + break out + } + if err != nil { + NewTrackedError(err).Abort(c) + return + } + + name := p.FormName() + switch name { + case "archive": + trnsfr.Log().Debug("received archive") + + if err := trnsfr.Server.EnsureDataDirectoryExists(); err != nil { + NewTrackedError(err).Abort(c) + return + } + + tee := io.TeeReader(p, h) + if err := trnsfr.Server.Filesystem().ExtractStreamUnsafe(ctx, "/", tee); err != nil { + NewTrackedError(err).Abort(c) + return + } + + hasArchive = true + case "checksum": + trnsfr.Log().Debug("received checksum") + + if !hasArchive { + NewTrackedError(errors.New("archive must be sent before the checksum")).Abort(c) + return + } + + hasChecksum = true + + v, err := io.ReadAll(p) + if err != nil { + NewTrackedError(err).Abort(c) + return + } + + expected := make([]byte, hex.DecodedLen(len(v))) + n, err := hex.Decode(expected, v) + if err != nil { + NewTrackedError(err).Abort(c) + return + } + actual := h.Sum(nil) + + trnsfr.Log().WithFields(log.Fields{ + "expected": hex.EncodeToString(expected), + "actual": hex.EncodeToString(actual), + }).Debug("checksums") + + if !bytes.Equal(expected[:n], actual) { + NewTrackedError(errors.New("checksums don't match")).Abort(c) + return + } + + trnsfr.Log().Debug("checksums match") + checksumVerified = true + default: + continue + } + } + } + + if !hasArchive || !hasChecksum { + NewTrackedError(errors.New("missing archive or checksum")).Abort(c) + return + } + + if !checksumVerified { + NewTrackedError(errors.New("checksums don't match")).Abort(c) + return + } + + // Transfer is almost complete, we just want to ensure the environment is + // configured correctly. We might want to not fail the transfer at this + // stage, but we will just to be safe. + + // Ensure the server environment gets configured. + if err := trnsfr.Server.CreateEnvironment(); err != nil { + NewTrackedError(err).Abort(c) + return + } + + // Changing this causes us to notify the panel about a successful transfer, + // rather than failing the transfer like we do by default. + successful = true + + // The rest of the logic for ensuring the server is unlocked and everything + // is handled in the deferred function above. + trnsfr.Log().Debug("done!") +} + +// deleteTransfer cancels an incoming transfer for a server. +func deleteTransfer(c *gin.Context) { s := ExtractServer(c) - if token.Subject != s.ID() { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{ - "error": "Missing required token subject, or subject is not valid for the requested server.", + + if !s.IsTransferring() { + c.AbortWithStatusJSON(http.StatusConflict, gin.H{ + "error": "Server is not currently being transferred.", }) return } - archivePath := getArchivePath(s.ID()) - - // Stat the archive file. - st, err := os.Lstat(archivePath) - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - _ = WithError(c, err) - return - } - c.AbortWithStatus(http.StatusNotFound) + trnsfr := transfer.Incoming().Get(s.ID()) + if trnsfr == nil { + c.AbortWithStatusJSON(http.StatusConflict, gin.H{ + "error": "Server is not currently being transferred.", + }) return } - // Compute sha256 checksum. - h := sha256.New() - f, err := os.Open(archivePath) - if err != nil { - return - } - if _, err := io.Copy(h, bufio.NewReader(f)); err != nil { - _ = f.Close() - _ = WithError(c, err) - return - } - if err := f.Close(); err != nil { - _ = WithError(c, err) - return - } - checksum := hex.EncodeToString(h.Sum(nil)) - - // Stream the file to the client. - f, err = os.Open(archivePath) - if err != nil { - _ = WithError(c, err) - return - } - defer f.Close() - - c.Header("X-Checksum", checksum) - c.Header("X-Mime-Type", "application/tar+gzip") - c.Header("Content-Length", strconv.Itoa(int(st.Size()))) - c.Header("Content-Disposition", "attachment; filename="+strconv.Quote(s.ID()+".tar.gz")) - c.Header("Content-Type", "application/octet-stream") - - _, _ = bufio.NewReader(f).WriteTo(c.Writer) -} - -func postServerArchive(c *gin.Context) { - s := middleware.ExtractServer(c) - manager := middleware.ExtractManager(c) - - go func(s *server.Server) { - l := log.WithField("server", s.ID()) - - // This function automatically adds the Source Node prefix and Timestamp to the log - // output before sending it over the websocket. - sendTransferLog := func(data string) { - output := colorstring.Color(fmt.Sprintf("[yellow][bold]%s [Pterodactyl Transfer System] [Source Node]:[default] %s", time.Now().Format(time.RFC1123), data)) - s.Events().Publish(server.TransferLogsEvent, output) - } - - s.Events().Publish(server.TransferStatusEvent, "starting") - sendTransferLog("Attempting to archive server...") - - hasError := true - defer func() { - if !hasError { - return - } - - // Mark the server as not being transferred so it can actually be used. - s.SetTransferring(false) - s.Events().Publish(server.TransferStatusEvent, "failure") - - sendTransferLog("Attempting to notify panel of archive failure..") - if err := manager.Client().SetArchiveStatus(s.Context(), s.ID(), false); err != nil { - if !remote.IsRequestError(err) { - sendTransferLog("Failed to notify panel of archive failure: " + err.Error()) - l.WithField("error", err).Error("failed to notify panel of failed archive status") - return - } - - sendTransferLog("Panel returned an error while notifying it of a failed archive: " + err.Error()) - l.WithField("error", err.Error()).Error("panel returned an error when notifying it of a failed archive status") - return - } - - sendTransferLog("Successfully notified panel of failed archive status") - l.Info("successfully notified panel of failed archive status") - }() - - // Mark the server as transferring to prevent problems. - s.SetTransferring(true) - - // Ensure the server is offline. Sometimes a "No such container" error gets through - // which means the server is already stopped. We can ignore that. - if err := s.Environment.WaitForStop(s.Context(), time.Minute, false); err != nil && !strings.Contains(strings.ToLower(err.Error()), "no such container") { - sendTransferLog("Failed to stop server, aborting transfer..") - l.WithField("error", err).Error("failed to stop server") - return - } - - // Get the disk usage of the server (used to calculate the progress of the archive process) - rawSize, err := s.Filesystem().DiskUsage(true) - if err != nil { - sendTransferLog("Failed to get disk usage for server, aborting transfer..") - l.WithField("error", err).Error("failed to get disk usage for server") - return - } - - // Create an archive of the entire server's data directory. - a := &filesystem.Archive{ - BasePath: s.Filesystem().Path(), - Progress: filesystem.NewProgress(rawSize), - } - - // Send the archive progress to the websocket every 3 seconds. - ctx2, cancel := context.WithCancel(s.Context()) - defer cancel() - go func(ctx context.Context, p *filesystem.Progress, t *time.Ticker) { - defer t.Stop() - for { - select { - case <-ctx.Done(): - return - case <-t.C: - sendTransferLog("Archiving " + p.Progress(progressWidth)) - } - } - }(ctx2, a.Progress, time.NewTicker(5*time.Second)) - - // Attempt to get an archive of the server. - if err := a.Create(getArchivePath(s.ID())); err != nil { - sendTransferLog("An error occurred while archiving the server: " + err.Error()) - l.WithField("error", err).Error("failed to get transfer archive for server") - return - } - - // Cancel the progress ticker. - cancel() - - // Show 100% completion. - sendTransferLog("Archiving " + a.Progress.Progress(progressWidth)) - - sendTransferLog("Successfully created archive, attempting to notify panel..") - l.Info("successfully created server transfer archive, notifying panel..") - - if err := manager.Client().SetArchiveStatus(s.Context(), s.ID(), true); err != nil { - if !remote.IsRequestError(err) { - sendTransferLog("Failed to notify panel of archive success: " + err.Error()) - l.WithField("error", err).Error("failed to notify panel of successful archive status") - return - } - - sendTransferLog("Panel returned an error while notifying it of a successful archive: " + err.Error()) - l.WithField("error", err.Error()).Error("panel returned an error when notifying it of a successful archive status") - return - } - - hasError = false - - // This log may not be displayed by the client due to the status event being sent before or at the same time. - sendTransferLog("Successfully notified panel of successful archive status") - - l.Info("successfully notified panel of successful transfer archive status") - s.Events().Publish(server.TransferStatusEvent, "archived") - }(s) - - c.Status(http.StatusAccepted) -} - -// Log helper function to attach all errors and info output to a consistently formatted -// log string for easier querying. -func (str serverTransferRequest) log() *log.Entry { - return log.WithField("subsystem", "transfers").WithField("server_id", str.ServerID) -} - -// Downloads an archive from the machine that the server currently lives on. -func (str serverTransferRequest) downloadArchive() (*http.Response, error) { - client := http.Client{Timeout: 0} - req, err := http.NewRequest(http.MethodGet, str.URL, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", str.Token) - res, err := client.Do(req) // lgtm [go/request-forgery] - if err != nil { - return nil, err - } - return res, nil -} - -// Returns the path to the local archive on the system. -func (str serverTransferRequest) path() string { - return getArchivePath(str.ServerID) -} - -// Creates the archive location on this machine by first checking that the required file -// does not already exist. If it does exist, the file is deleted and then re-created as -// an empty file. -func (str serverTransferRequest) createArchiveFile() (*os.File, error) { - p := str.path() - if _, err := os.Stat(p); err != nil { - if !os.IsNotExist(err) { - return nil, err - } - } else if err := os.Remove(p); err != nil { - return nil, err - } - return os.Create(p) -} - -// Deletes the archive from the local filesystem. This is executed as a deferred function. -func (str serverTransferRequest) removeArchivePath() { - p := str.path() - str.log().Debug("deleting temporary transfer archive") - if err := os.Remove(p); err != nil && !os.IsNotExist(err) { - str.log().WithField("path", p).WithField("error", err).Error("failed to delete temporary transfer archive file") - return - } - str.log().Debug("deleted temporary transfer archive successfully") -} - -// Verifies that the SHA-256 checksum of the file on the local filesystem matches the -// expected value from the transfer request. The string value returned is the computed -// checksum on the system. -func (str serverTransferRequest) verifyChecksum(matches string) (bool, string, error) { - f, err := os.Open(str.path()) - if err != nil { - return false, "", err - } - defer f.Close() - h := sha256.New() - if _, err := io.Copy(h, bufio.NewReader(f)); err != nil { - return false, "", err - } - checksum := hex.EncodeToString(h.Sum(nil)) - return checksum == matches, checksum, nil -} - -// Sends a notification to the Panel letting it know what the status of this transfer is. -func (str serverTransferRequest) sendTransferStatus(client remote.Client, successful bool) error { - lg := str.log().WithField("transfer_successful", successful) - lg.Info("notifying Panel of server transfer state") - if err := client.SetTransferStatus(context.Background(), str.ServerID, successful); err != nil { - lg.WithField("error", err).Error("error notifying panel of transfer state") - return err - } - lg.Debug("notified panel of transfer state") - return nil -} - -// Initiates a transfer between two nodes for a server by downloading an archive from the -// remote node and then applying the server details to this machine. -func postTransfer(c *gin.Context) { - var data serverTransferRequest - if err := c.BindJSON(&data); err != nil { - return - } - - manager := middleware.ExtractManager(c) - u, err := uuid.Parse(data.ServerID) - if err != nil { - _ = WithError(c, err) - return - } - // Force the server ID to be a valid UUID string at this point. If it is not an error - // is returned to the caller. This limits injection vulnerabilities that would cause - // the str.path() function to return a location not within the server archive directory. - data.ServerID = u.String() - - data.log().Info("handling incoming server transfer request") - go func(data *serverTransferRequest) { - ctx := context.Background() - hasError := true - - // Create a new server installer. This will only configure the environment and not - // run the installer scripts. - i, err := installer.New(ctx, manager, data.Server) - if err != nil { - _ = data.sendTransferStatus(manager.Client(), false) - data.log().WithField("error", err).Error("failed to validate received server data") - return - } - - // This function automatically adds the Target Node prefix and Timestamp to the log output before sending it - // over the websocket. - sendTransferLog := func(data string) { - output := colorstring.Color(fmt.Sprintf("[yellow][bold]%s [Pterodactyl Transfer System] [Target Node]:[default] %s", time.Now().Format(time.RFC1123), data)) - i.Server().Events().Publish(server.TransferLogsEvent, output) - } - - // Mark the server as transferring to prevent problems later on during the process and - // then push the server into the global server collection for this instance. - i.Server().SetTransferring(true) - manager.Add(i.Server()) - defer func(s *server.Server) { - // In the event that this transfer call fails, remove the server from the global - // server tracking so that we don't have a dangling instance. - if err := data.sendTransferStatus(manager.Client(), !hasError); hasError || err != nil { - sendTransferLog("Server transfer failed, check Wings logs for additional information.") - s.Events().Publish(server.TransferStatusEvent, "failure") - manager.Remove(func(match *server.Server) bool { - return match.ID() == s.ID() - }) - - // If the transfer status was successful but the request failed, act like the transfer failed. - if !hasError && err != nil { - // Delete all extracted files. - if err := os.RemoveAll(s.Filesystem().Path()); err != nil && !os.IsNotExist(err) { - data.log().WithField("error", err).Warn("failed to delete local server files directory") - } - } - } else { - s.SetTransferring(false) - s.Events().Publish(server.TransferStatusEvent, "success") - sendTransferLog("Transfer completed.") - } - }(i.Server()) - - data.log().Info("downloading server archive from current server node") - sendTransferLog("Received incoming transfer from Panel, attempting to download archive from source node...") - res, err := data.downloadArchive() - if err != nil { - sendTransferLog("Failed to retrieve server archive from remote node: " + err.Error()) - data.log().WithField("error", err).Error("failed to download archive for server transfer") - return - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - data.log().WithField("error", err).WithField("status", res.StatusCode).Error("unexpected error response from transfer endpoint") - return - } - - size := res.ContentLength - if size == 0 { - data.log().WithField("error", err).Error("received an archive response with Content-Length of 0") - return - } - sendTransferLog("Got server archive response from remote node. (Content-Length: " + strconv.Itoa(int(size)) + ")") - sendTransferLog("Creating local archive file...") - file, err := data.createArchiveFile() - if err != nil { - data.log().WithField("error", err).Error("failed to create archive file on local filesystem") - return - } - - sendTransferLog("Writing archive to disk...") - data.log().Info("writing transfer archive to disk...") - - progress := filesystem.NewProgress(size) - progress.SetWriter(file) - - // Send the archive progress to the websocket every 3 seconds. - ctx2, cancel := context.WithCancel(ctx) - defer cancel() - go func(ctx context.Context, p *filesystem.Progress, t *time.Ticker) { - defer t.Stop() - for { - select { - case <-ctx.Done(): - return - case <-t.C: - sendTransferLog("Downloading " + p.Progress(progressWidth)) - } - } - }(ctx2, progress, time.NewTicker(5*time.Second)) - - var reader io.Reader - downloadLimit := float64(config.Get().System.Transfers.DownloadLimit) * 1024 * 1024 - if downloadLimit > 0 { - // Wrap the body with a reader that is limited to the defined download limit speed. - reader = ratelimit.Reader(res.Body, ratelimit.NewBucketWithRate(downloadLimit, int64(downloadLimit))) - } else { - reader = res.Body - } - - buf := make([]byte, 1024*4) - if _, err := io.CopyBuffer(progress, reader, buf); err != nil { - _ = file.Close() - - sendTransferLog("Failed while writing archive file to disk: " + err.Error()) - data.log().WithField("error", err).Error("failed to copy archive file to disk") - return - } - cancel() - - // Show 100% completion. - sendTransferLog("Downloading " + progress.Progress(progressWidth)) - - if err := file.Close(); err != nil { - data.log().WithField("error", err).Error("unable to close archive file on local filesystem") - return - } - data.log().Info("finished writing transfer archive to disk") - sendTransferLog("Successfully wrote archive to disk.") - - // Whenever the transfer fails or succeeds, delete the temporary transfer archive that - // was created on the disk. - defer data.removeArchivePath() - - sendTransferLog("Verifying checksum of downloaded archive...") - data.log().Info("computing checksum of downloaded archive file") - expected := res.Header.Get("X-Checksum") - if matches, computed, err := data.verifyChecksum(expected); err != nil { - data.log().WithField("error", err).Error("encountered an error while calculating local filesystem archive checksum") - return - } else if !matches { - sendTransferLog("@@@@@ CHECKSUM VERIFICATION FAILED @@@@@") - sendTransferLog(" - Source Checksum: " + expected) - sendTransferLog(" - Computed Checksum: " + computed) - data.log().WithField("expected_sum", expected).WithField("computed_checksum", computed).Error("checksum mismatch when verifying integrity of local archive") - return - } - - // Create the server's environment. - sendTransferLog("Creating server environment, this could take a while..") - data.log().Info("creating server environment") - if err := i.Server().CreateEnvironment(); err != nil { - data.log().WithField("error", err).Error("failed to create server environment") - return - } - - sendTransferLog("Server environment has been created, extracting transfer archive..") - data.log().Info("server environment configured, extracting transfer archive") - if err := i.Server().Filesystem().DecompressFileUnsafe(ctx, "/", data.path()); err != nil { - // Un-archiving failed, delete the server's data directory. - if err := os.RemoveAll(i.Server().Filesystem().Path()); err != nil && !os.IsNotExist(err) { - data.log().WithField("error", err).Warn("failed to delete local server files directory") - } - data.log().WithField("error", err).Error("failed to extract server archive") - return - } - - // We mark the process as being successful here as if we fail to send a transfer success, - // then a transfer failure won't probably be successful either. - // - // It may be useful to retry sending the transfer success every so often just in case of a small - // hiccup or the fix of whatever error causing the success request to fail. - hasError = false - - data.log().Info("archive extracted successfully, notifying Panel of status") - sendTransferLog("Archive extracted successfully.") - }(&data) + trnsfr.Cancel() c.Status(http.StatusAccepted) } diff --git a/server/backup/backup_local.go b/server/backup/backup_local.go index 4ca20cc..bdaeff6 100644 --- a/server/backup/backup_local.go +++ b/server/backup/backup_local.go @@ -66,7 +66,7 @@ func (b *LocalBackup) Generate(ctx context.Context, basePath, ignore string) (*A } b.log().WithField("path", b.Path()).Info("creating backup for server") - if err := a.Create(b.Path()); err != nil { + if err := a.Create(ctx, b.Path()); err != nil { return nil, err } b.log().Info("created backup successfully") diff --git a/server/backup/backup_s3.go b/server/backup/backup_s3.go index 392497d..ede235a 100644 --- a/server/backup/backup_s3.go +++ b/server/backup/backup_s3.go @@ -57,7 +57,7 @@ func (s *S3Backup) Generate(ctx context.Context, basePath, ignore string) (*Arch } s.log().WithField("path", s.Path()).Info("creating backup for server") - if err := a.Create(s.Path()); err != nil { + if err := a.Create(ctx, s.Path()); err != nil { return nil, err } s.log().Info("created backup successfully") diff --git a/server/events.go b/server/events.go index 6d9af1d..d08411c 100644 --- a/server/events.go +++ b/server/events.go @@ -5,8 +5,7 @@ import ( "github.com/pterodactyl/wings/system" ) -// Defines all of the possible output events for a server. -// noinspection GoNameStartsWithPackageName +// Defines all the possible output events for a server. const ( DaemonMessageEvent = "daemon message" InstallOutputEvent = "install output" @@ -19,6 +18,7 @@ const ( BackupCompletedEvent = "backup completed" TransferLogsEvent = "transfer logs" TransferStatusEvent = "transfer status" + DeletedEvent = "deleted" ) // Events returns the server's emitter instance. diff --git a/server/filesystem/archive.go b/server/filesystem/archive.go index edad0ed..968a280 100644 --- a/server/filesystem/archive.go +++ b/server/filesystem/archive.go @@ -2,13 +2,13 @@ package filesystem import ( "archive/tar" + "context" "io" "io/fs" "os" "path/filepath" "strings" "sync" - "sync/atomic" "emperror.dev/errors" "github.com/apex/log" @@ -18,7 +18,7 @@ import ( ignore "github.com/sabhiram/go-gitignore" "github.com/pterodactyl/wings/config" - "github.com/pterodactyl/wings/system" + "github.com/pterodactyl/wings/internal/progress" ) const memory = 4 * 1024 @@ -33,13 +33,13 @@ var pool = sync.Pool{ // TarProgress . type TarProgress struct { *tar.Writer - p *Progress + p *progress.Progress } // NewTarProgress . -func NewTarProgress(w *tar.Writer, p *Progress) *TarProgress { +func NewTarProgress(w *tar.Writer, p *progress.Progress) *TarProgress { if p != nil { - p.w = w + p.Writer = w } return &TarProgress{ Writer: w, @@ -47,6 +47,7 @@ func NewTarProgress(w *tar.Writer, p *Progress) *TarProgress { } } +// Write . func (p *TarProgress) Write(v []byte) (int, error) { if p.p == nil { return p.Writer.Write(v) @@ -54,84 +55,6 @@ func (p *TarProgress) Write(v []byte) (int, error) { return p.p.Write(v) } -// Progress is used to track the progress of any I/O operation that are being -// performed. -type Progress struct { - // written is the total size of the files that have been written to the writer. - written int64 - // Total is the total size of the archive in bytes. - total int64 - // w . - w io.Writer -} - -// NewProgress . -func NewProgress(total int64) *Progress { - return &Progress{total: total} -} - -// SetWriter sets the writer progress will forward writes to. -// NOTE: This function is not thread safe. -func (p *Progress) SetWriter(w io.Writer) { - p.w = w -} - -// Written returns the total number of bytes written. -// This function should be used when the progress is tracking data being written. -func (p *Progress) Written() int64 { - return atomic.LoadInt64(&p.written) -} - -// Total returns the total size in bytes. -func (p *Progress) Total() int64 { - return atomic.LoadInt64(&p.total) -} - -// Write totals the number of bytes that have been written to the writer. -func (p *Progress) Write(v []byte) (int, error) { - n := len(v) - atomic.AddInt64(&p.written, int64(n)) - if p.w != nil { - return p.w.Write(v) - } - return n, nil -} - -// Progress returns a formatted progress string for the current progress. -func (p *Progress) Progress(width int) string { - // current = 100 (Progress, dynamic) - // total = 1000 (Content-Length, dynamic) - // width = 25 (Number of ticks to display, static) - // widthPercentage = 100 / width (What percentage does each tick represent, static) - // - // percentageDecimal = current / total = 0.1 - // percentage = percentageDecimal * 100 = 10% - // ticks = percentage / widthPercentage = 2.5 - // - // ticks is a float64, so we cast it to an int which rounds it down to 2. - - // Values are cast to floats to prevent integer division. - current := p.Written() - total := p.Total() - // width := is passed as a parameter - widthPercentage := float64(100) / float64(width) - percentageDecimal := float64(current) / float64(total) - percentage := percentageDecimal * 100 - ticks := int(percentage / widthPercentage) - - // Ensure that we never get a negative number of ticks, this will prevent strings#Repeat - // from panicking. A negative number of ticks is likely to happen when the total size is - // inaccurate, such as when we are going off of rough disk usage calculation. - if ticks < 0 { - ticks = 0 - } else if ticks > width { - ticks = width - } - - bar := strings.Repeat("=", ticks) + strings.Repeat(" ", width-ticks) - return "[" + bar + "] " + system.FormatBytes(current) + " / " + system.FormatBytes(total) -} - type Archive struct { // BasePath is the absolute path to create the archive from where Files and Ignore are // relative to. @@ -146,12 +69,12 @@ type Archive struct { Files []string // Progress wraps the writer of the archive to pass through the progress tracker. - Progress *Progress + Progress *progress.Progress } // Create creates an archive at dst with all the files defined in the // included Files array. -func (a *Archive) Create(dst string) error { +func (a *Archive) Create(ctx context.Context, dst string) error { f, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600) if err != nil { return err @@ -169,6 +92,11 @@ func (a *Archive) Create(dst string) error { writer = f } + return a.Stream(ctx, writer) +} + +// Stream . +func (a *Archive) Stream(ctx context.Context, w io.Writer) error { // Choose which compression level to use based on the compression_level configuration option var compressionLevel int switch config.Get().System.Backups.CompressionLevel { @@ -183,7 +111,7 @@ func (a *Archive) Create(dst string) error { } // Create a new gzip writer around the file. - gw, _ := pgzip.NewWriterLevel(writer, compressionLevel) + gw, _ := pgzip.NewWriterLevel(w, compressionLevel) _ = gw.SetConcurrency(1<<20, 1) defer gw.Close() @@ -197,16 +125,16 @@ func (a *Archive) Create(dst string) error { options := &godirwalk.Options{ FollowSymbolicLinks: false, Unsorted: true, - Callback: a.callback(pw), } // If we're specifically looking for only certain files, or have requested // that certain files be ignored we'll update the callback function to reflect // that request. + var callback godirwalk.WalkFunc if len(a.Files) == 0 && len(a.Ignore) > 0 { i := ignore.CompileIgnoreLines(strings.Split(a.Ignore, "\n")...) - options.Callback = a.callback(pw, func(_ string, rp string) error { + callback = a.callback(pw, func(_ string, rp string) error { if i.MatchesPath(rp) { return godirwalk.SkipThis } @@ -214,7 +142,19 @@ func (a *Archive) Create(dst string) error { return nil }) } else if len(a.Files) > 0 { - options.Callback = a.withFilesCallback(pw) + callback = a.withFilesCallback(pw) + } else { + callback = a.callback(pw) + } + + // Set the callback function, wrapped with support for context cancellation. + options.Callback = func(path string, de *godirwalk.Dirent) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return callback(path, de) + } } // Recursively walk the path we are archiving. diff --git a/server/filesystem/compress.go b/server/filesystem/compress.go index 61b0b27..f7c405a 100644 --- a/server/filesystem/compress.go +++ b/server/filesystem/compress.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "context" "fmt" + "io" iofs "io/fs" "os" "path" @@ -21,7 +22,7 @@ import ( "github.com/mholt/archiver/v4" ) -// CompressFiles compresses all of the files matching the given paths in the +// CompressFiles compresses all the files matching the given paths in the // specified directory. This function also supports passing nested paths to only // compress certain files and folders when working in a larger directory. This // effectively creates a local backup, but rather than ignoring specific files @@ -36,7 +37,7 @@ func (fs *Filesystem) CompressFiles(dir string, paths []string) (os.FileInfo, er return nil, err } - // Take all of the paths passed in and merge them together with the root directory we've gotten. + // Take all the paths passed in and merge them together with the root directory we've gotten. for i, p := range paths { paths[i] = filepath.Join(cleanedRootDir, p) } @@ -52,7 +53,7 @@ func (fs *Filesystem) CompressFiles(dir string, paths []string) (os.FileInfo, er fmt.Sprintf("archive-%s.tar.gz", strings.ReplaceAll(time.Now().Format(time.RFC3339), ":", "")), ) - if err := a.Create(d); err != nil { + if err := a.Create(context.Background(), d); err != nil { return nil, err } @@ -147,6 +148,7 @@ func (fs *Filesystem) DecompressFileUnsafe(ctx context.Context, dir string, file if err != nil { return err } + // TODO: defer file close? // Identify the type of archive we are dealing with. format, input, err := archiver.Identify(filepath.Base(file), f) @@ -157,13 +159,49 @@ func (fs *Filesystem) DecompressFileUnsafe(ctx context.Context, dir string, file return err } + return fs.extractStream(ctx, extractStreamOptions{ + Directory: dir, + Format: format, + Reader: input, + }) +} + +// ExtractStreamUnsafe . +func (fs *Filesystem) ExtractStreamUnsafe(ctx context.Context, dir string, r io.Reader) error { + format, input, err := archiver.Identify("archive.tar.gz", r) + if err != nil { + if errors.Is(err, archiver.ErrNoMatch) { + return newFilesystemError(ErrCodeUnknownArchive, err) + } + return err + } + + return fs.extractStream(ctx, extractStreamOptions{ + Directory: dir, + Format: format, + Reader: input, + }) +} + +type extractStreamOptions struct { + // The directory to extract the archive to. + Directory string + // File name of the archive. + FileName string + // Format of the archive. + Format archiver.Format + // Reader for the archive. + Reader io.Reader +} + +func (fs *Filesystem) extractStream(ctx context.Context, opts extractStreamOptions) error { // Decompress and extract archive - if ex, ok := format.(archiver.Extractor); ok { - return ex.Extract(ctx, input, nil, func(ctx context.Context, f archiver.File) error { + if ex, ok := opts.Format.(archiver.Extractor); ok { + return ex.Extract(ctx, opts.Reader, nil, func(ctx context.Context, f archiver.File) error { if f.IsDir() { return nil } - p := filepath.Join(dir, ExtractNameFromArchive(f)) + p := filepath.Join(opts.Directory, ExtractNameFromArchive(f)) // If it is ignored, just don't do anything with the file and skip over it. if err := fs.IsIgnored(p); err != nil { return nil @@ -174,20 +212,19 @@ func (fs *Filesystem) DecompressFileUnsafe(ctx context.Context, dir string, file } defer r.Close() if err := fs.Writefile(p, r); err != nil { - return wrapError(err, file) + return wrapError(err, opts.FileName) } // Update the file permissions to the one set in the archive. if err := fs.Chmod(p, f.Mode()); err != nil { - return wrapError(err, file) + return wrapError(err, opts.FileName) } // Update the file modification time to the one set in the archive. if err := fs.Chtimes(p, f.ModTime(), f.ModTime()); err != nil { - return wrapError(err, file) + return wrapError(err, opts.FileName) } return nil }) } - return nil } diff --git a/installer/errors.go b/server/installer/errors.go similarity index 100% rename from installer/errors.go rename to server/installer/errors.go diff --git a/installer/installer.go b/server/installer/installer.go similarity index 100% rename from installer/installer.go rename to server/installer/installer.go diff --git a/server/transfer/archive.go b/server/transfer/archive.go new file mode 100644 index 0000000..e5457f1 --- /dev/null +++ b/server/transfer/archive.go @@ -0,0 +1,52 @@ +package transfer + +import ( + "context" + "fmt" + "io" + + "github.com/pterodactyl/wings/internal/progress" + "github.com/pterodactyl/wings/server/filesystem" +) + +// Archive returns an archive that can be used to stream the contents of the +// contents of a server. +func (t *Transfer) Archive() (*Archive, error) { + if t.archive == nil { + // Get the disk usage of the server (used to calculate the progress of the archive process) + rawSize, err := t.Server.Filesystem().DiskUsage(true) + if err != nil { + return nil, fmt.Errorf("transfer: failed to get server disk usage: %w", err) + } + + // Create a new archive instance and assign it to the transfer. + t.archive = NewArchive(t, uint64(rawSize)) + } + + return t.archive, nil +} + +// Archive represents an archive used to transfer the contents of a server. +type Archive struct { + archive *filesystem.Archive +} + +// NewArchive returns a new archive associated with the given transfer. +func NewArchive(t *Transfer, size uint64) *Archive { + return &Archive{ + archive: &filesystem.Archive{ + BasePath: t.Server.Filesystem().Path(), + Progress: progress.NewProgress(size), + }, + } +} + +// Stream returns a reader that can be used to stream the contents of the archive. +func (a *Archive) Stream(ctx context.Context, w io.Writer) error { + return a.archive.Stream(ctx, w) +} + +// Progress returns the current progress of the archive. +func (a *Archive) Progress() *progress.Progress { + return a.archive.Progress +} diff --git a/server/transfer/doc.go b/server/transfer/doc.go new file mode 100644 index 0000000..703dac2 --- /dev/null +++ b/server/transfer/doc.go @@ -0,0 +1,4 @@ +// Package transfer handles all logic related to transferring servers between +// two nodes. This includes the logic for archiving a server on the source node +// and logic for importing a server from the source node into the target node. +package transfer diff --git a/server/transfer/manager.go b/server/transfer/manager.go new file mode 100644 index 0000000..9883b5c --- /dev/null +++ b/server/transfer/manager.go @@ -0,0 +1,57 @@ +package transfer + +import ( + "sync" +) + +var ( + incomingTransfers = NewManager() + outgoingTransfers = NewManager() +) + +// Incoming returns a transfer manager for incoming transfers. +func Incoming() *Manager { + return incomingTransfers +} + +// Outgoing returns a transfer manager for outgoing transfers. +func Outgoing() *Manager { + return outgoingTransfers +} + +// Manager manages transfers. +type Manager struct { + mu sync.RWMutex + transfers map[string]*Transfer +} + +// NewManager returns a new transfer manager. +func NewManager() *Manager { + return &Manager{ + transfers: make(map[string]*Transfer), + } +} + +// Add adds a transfer to the manager. +func (m *Manager) Add(transfer *Transfer) { + m.mu.Lock() + defer m.mu.Unlock() + + m.transfers[transfer.Server.ID()] = transfer +} + +// Remove removes a transfer from the manager. +func (m *Manager) Remove(transfer *Transfer) { + m.mu.Lock() + defer m.mu.Unlock() + + delete(m.transfers, transfer.Server.ID()) +} + +// Get gets a transfer from the manager using a server ID. +func (m *Manager) Get(id string) *Transfer { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.transfers[id] +} diff --git a/server/transfer/source.go b/server/transfer/source.go new file mode 100644 index 0000000..2775363 --- /dev/null +++ b/server/transfer/source.go @@ -0,0 +1,159 @@ +package transfer + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "mime/multipart" + "net/http" + "time" + + "github.com/pterodactyl/wings/internal/progress" +) + +// PushArchiveToTarget POSTs the archive to the target node and returns the +// response body. +func (t *Transfer) PushArchiveToTarget(url, token string) ([]byte, error) { + ctx, cancel := context.WithCancel(t.ctx) + defer cancel() + + t.SendMessage("Preparing to stream server data to destination...") + t.SetStatus(StatusProcessing) + + a, err := t.Archive() + if err != nil { + t.Error(err, "Failed to get archive for transfer.") + return nil, errors.New("failed to get archive for transfer") + } + + t.SendMessage("Streaming archive to destination...") + + // Send the upload progress to the websocket every 5 seconds. + ctx2, cancel2 := context.WithCancel(ctx) + defer cancel2() + go func(ctx context.Context, p *progress.Progress, tc *time.Ticker) { + defer tc.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-tc.C: + t.SendMessage("Uploading " + p.Progress(25)) + } + } + }(ctx2, a.Progress(), time.NewTicker(5*time.Second)) + + // Create a new request using the pipe as the body. + body, writer := io.Pipe() + defer body.Close() + defer writer.Close() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", token) + + // Create a new multipart writer that writes the archive to the pipe. + mp := multipart.NewWriter(writer) + defer mp.Close() + req.Header.Set("Content-Type", mp.FormDataContentType()) + + // Create a new goroutine to write the archive to the pipe used by the + // multipart writer. + errChan := make(chan error) + go func() { + defer close(errChan) + defer writer.Close() + defer mp.Close() + + src, pw := io.Pipe() + defer src.Close() + defer pw.Close() + + h := sha256.New() + tee := io.TeeReader(src, h) + + dest, err := mp.CreateFormFile("archive", "archive.tar.gz") + if err != nil { + errChan <- errors.New("failed to create form file") + return + } + + ch := make(chan error) + go func() { + defer close(ch) + + if _, err := io.Copy(dest, tee); err != nil { + ch <- fmt.Errorf("failed to stream archive to destination: %w", err) + return + } + + t.Log().Debug("finished copying dest to tee") + }() + + if err := a.Stream(ctx, pw); err != nil { + errChan <- errors.New("failed to stream archive to pipe") + return + } + t.Log().Debug("finished streaming archive to pipe") + + // Close the pipe writer early to release resources and ensure that the data gets flushed. + _ = pw.Close() + + // Wait for the copy to finish before we continue. + t.Log().Debug("waiting on copy to finish") + if err := <-ch; err != nil { + errChan <- err + return + } + + if err := mp.WriteField("checksum", hex.EncodeToString(h.Sum(nil))); err != nil { + errChan <- errors.New("failed to stream checksum") + return + } + + cancel2() + t.SendMessage("Finished streaming archive to destination.") + + if err := mp.Close(); err != nil { + t.Log().WithError(err).Error("error while closing multipart writer") + } + t.Log().Debug("closed multipart writer") + }() + + t.Log().Debug("sending archive to destination") + client := http.Client{Timeout: 0} + res, err := client.Do(req) + t.Log().Debug("waiting for stream to complete") + select { + case <-ctx.Done(): + return nil, ctx.Err() + case err2 := <-errChan: + t.Log().Debug("stream completed") + if err != nil || err2 != nil { + if err == context.Canceled { + return nil, err + } + + t.Log().WithError(err).Debug("failed to send archive to destination") + return nil, fmt.Errorf("http error: %w, multipart error: %v", err, err2) + } + defer res.Body.Close() + t.Log().Debug("received response from destination") + + v, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if res.StatusCode != http.StatusOK { + return nil, errors.New(string(v)) + } + + return v, nil + } +} diff --git a/server/transfer/transfer.go b/server/transfer/transfer.go new file mode 100644 index 0000000..6511cfc --- /dev/null +++ b/server/transfer/transfer.go @@ -0,0 +1,128 @@ +package transfer + +import ( + "context" + "time" + + "github.com/apex/log" + "github.com/mitchellh/colorstring" + + "github.com/pterodactyl/wings/server" + "github.com/pterodactyl/wings/system" +) + +// Status represents the current status of a transfer. +type Status string + +// String satisfies the fmt.Stringer interface. +func (s Status) String() string { + return string(s) +} + +const ( + // StatusPending is the status of a transfer when it is first created. + StatusPending Status = "pending" + // StatusProcessing is the status of a transfer when it is currently in + // progress, such as when the archive is being streamed to the target node. + StatusProcessing Status = "processing" + + // StatusCancelling is the status of a transfer when it is in the process of + // being cancelled. + StatusCancelling Status = "cancelling" + + // StatusCancelled is the final status of a transfer when it has been + // cancelled. + StatusCancelled Status = "cancelled" + // StatusFailed is the final status of a transfer when it has failed. + StatusFailed Status = "failed" + // StatusCompleted is the final status of a transfer when it has completed. + StatusCompleted Status = "completed" +) + +// Transfer represents a transfer of a server from one node to another. +type Transfer struct { + // ctx is the context for the transfer. + ctx context.Context + // cancel is used to cancel all ongoing transfer operations for the server. + cancel *context.CancelFunc + + // Server associated with the transfer. + Server *server.Server + // status of the transfer. + status *system.Atomic[Status] + + // archive is the archive that is being created for the transfer. + archive *Archive +} + +// New returns a new transfer instance for the given server. +func New(ctx context.Context, s *server.Server) *Transfer { + ctx, cancel := context.WithCancel(ctx) + + return &Transfer{ + ctx: ctx, + cancel: &cancel, + + Server: s, + status: system.NewAtomic(StatusPending), + } +} + +// Context returns the context for the transfer. +func (t *Transfer) Context() context.Context { + return t.ctx +} + +// Cancel cancels the transfer. +func (t *Transfer) Cancel() { + status := t.Status() + if status == StatusCancelling || + status == StatusCancelled || + status == StatusCompleted || + status == StatusFailed { + return + } + + if t.cancel == nil { + return + } + + t.SetStatus(StatusCancelling) + (*t.cancel)() +} + +// Status returns the current status of the transfer. +func (t *Transfer) Status() Status { + return t.status.Load() +} + +// SetStatus sets the status of the transfer. +func (t *Transfer) SetStatus(s Status) { + // TODO: prevent certain status changes from happening. + // If we are cancelling, then we can't go back to processing. + t.status.Store(s) + + t.Server.Events().Publish(server.TransferStatusEvent, s) +} + +// SendMessage sends a message to the server's console. +func (t *Transfer) SendMessage(v string) { + t.Server.Events().Publish( + server.TransferLogsEvent, + colorstring.Color("[yellow][bold]"+time.Now().Format(time.RFC1123)+" [Transfer System] [Source Node]:[default] "+v), + ) +} + +// Error logs an error that occurred on the source node. +func (t *Transfer) Error(err error, v string) { + t.Log().WithError(err).Error(v) + t.SendMessage(v) +} + +// Log returns a logger for the transfer. +func (t *Transfer) Log() *log.Entry { + if t.Server == nil { + return log.WithField("subsystem", "transfer") + } + return t.Server.Log().WithField("subsystem", "transfer") +} diff --git a/system/utils.go b/system/utils.go index 31b63ca..a740ec5 100644 --- a/system/utils.go +++ b/system/utils.go @@ -120,7 +120,7 @@ func ScanReader(r io.Reader, callback func(line []byte)) error { return nil } -func FormatBytes(b int64) string { +func FormatBytes[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](b T) string { if b < 1024 { return fmt.Sprintf("%d B", b) } @@ -188,14 +188,14 @@ func NewAtomicString(v string) *AtomicString { return &AtomicString{v: v} } -// Stores the string value passed atomically. +// Store stores the string value passed atomically. func (as *AtomicString) Store(v string) { as.mu.Lock() as.v = v as.mu.Unlock() } -// Loads the string value and returns it. +// Load loads the string value and returns it. func (as *AtomicString) Load() string { as.mu.RLock() defer as.mu.RUnlock() @@ -211,3 +211,41 @@ func (as *AtomicString) UnmarshalJSON(b []byte) error { func (as *AtomicString) MarshalJSON() ([]byte, error) { return json.Marshal(as.Load()) } + +type Atomic[T any] struct { + v T + mu sync.RWMutex +} + +func NewAtomic[T any](v T) *Atomic[T] { + return &Atomic[T]{v: v} +} + +// Store stores the string value passed atomically. +func (a *Atomic[T]) Store(v T) { + a.mu.Lock() + defer a.mu.Unlock() + + a.v = v +} + +// Load loads the string value and returns it. +func (a *Atomic[T]) Load() T { + a.mu.RLock() + defer a.mu.RUnlock() + + return a.v +} + +// UnmarshalJSON unmarshals the JSON value into the Atomic[T] value. +func (a *Atomic[T]) UnmarshalJSON(b []byte) error { + a.mu.Lock() + defer a.mu.Unlock() + + return json.Unmarshal(b, &a.v) +} + +// MarshalJSON marshals the Atomic[T] value into JSON. +func (a *Atomic[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(a.Load()) +}