wings/router/downloader/downloader.go
Matthew Penner 14ad268af4
go: update dependencies
Signed-off-by: Matthew Penner <me@matthewp.io>
2024-07-24 15:13:04 -06:00

338 lines
9.1 KiB
Go

package downloader
import (
"context"
"encoding/json"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync"
"time"
"emperror.dev/errors"
"github.com/google/uuid"
"github.com/pterodactyl/wings/server"
)
var client *http.Client
func init() {
dialer := &net.Dialer{
LocalAddr: nil,
}
trnspt := http.DefaultTransport.(*http.Transport).Clone()
trnspt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
c, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, errors.WithStack(err)
}
ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String())
if err != nil {
return c, errors.WithStack(err)
}
ip := net.ParseIP(ipStr)
if ip == nil {
return c, errors.WithStack(ErrInvalidIPAddress)
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
return c, errors.WithStack(ErrInternalResolution)
}
for _, block := range internalRanges {
if !block.Contains(ip) {
continue
}
return c, errors.WithStack(ErrInternalResolution)
}
return c, nil
}
client = &http.Client{
Timeout: time.Hour * 12,
Transport: trnspt,
// Disallow any redirect on an HTTP call. This is a security requirement: do not modify
// this logic without first ensuring that the new target location IS NOT within the current
// instance's local network.
//
// This specific error response just causes the client to not follow the redirect and
// returns the actual redirect response to the caller. Not perfect, but simple and most
// people won't be using URLs that redirect anyways hopefully?
//
// We'll re-evaluate this down the road if needed.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
var instance = &Downloader{
// Tracks all the active downloads.
downloadCache: make(map[string]*Download),
// Tracks all the downloads active for a given server instance. This is
// primarily used to make things quicker and keep the code a little more
// legible throughout here.
serverCache: make(map[string][]string),
}
// Internal IP ranges that should be blocked if the resource requested resolves within.
var internalRanges = []*net.IPNet{
mustParseCIDR("127.0.0.1/8"),
mustParseCIDR("10.0.0.0/8"),
mustParseCIDR("172.16.0.0/12"),
mustParseCIDR("192.168.0.0/16"),
mustParseCIDR("169.254.0.0/16"),
mustParseCIDR("::1/128"),
mustParseCIDR("fe80::/10"),
mustParseCIDR("fc00::/7"),
}
const (
ErrInternalResolution = errors.Sentinel("downloader: destination resolves to internal network location")
ErrInvalidIPAddress = errors.Sentinel("downloader: invalid IP address")
ErrDownloadFailed = errors.Sentinel("downloader: download request failed")
)
type Counter struct {
total int
onWrite func(total int)
}
func (c *Counter) Write(p []byte) (int, error) {
n := len(p)
c.total += n
c.onWrite(c.total)
return n, nil
}
type DownloadRequest struct {
Directory string
URL *url.URL
FileName string
UseHeader bool
}
type Download struct {
Identifier string
path string
mu sync.RWMutex
req DownloadRequest
server *server.Server
progress float64
cancelFunc *context.CancelFunc
}
// New starts a new tracked download which allows for cancellation later on by calling
// the Downloader.Cancel function.
func New(s *server.Server, r DownloadRequest) *Download {
dl := Download{
Identifier: uuid.Must(uuid.NewRandom()).String(),
req: r,
server: s,
}
instance.track(&dl)
return &dl
}
// ByServer returns all the tracked downloads for a given server instance.
func ByServer(sid string) []*Download {
instance.mu.Lock()
defer instance.mu.Unlock()
var downloads []*Download
if v, ok := instance.serverCache[sid]; ok {
for _, id := range v {
if dl, ok := instance.downloadCache[id]; ok {
downloads = append(downloads, dl)
}
}
}
return downloads
}
// ByID returns a single Download matching a given identifier. If no download is found
// the second argument in the response will be false.
func ByID(dlid string) *Download {
return instance.find(dlid)
}
//goland:noinspection GoVetCopyLock
func (dl Download) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Identifier string
Progress float64
}{
Identifier: dl.Identifier,
Progress: dl.Progress(),
})
}
// Execute executes a given download for the server and begins writing the file to the disk. Once
// completed the download will be removed from the cache.
func (dl *Download) Execute() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Hour*12)
dl.cancelFunc = &cancel
defer dl.Cancel()
// At this point we have verified the destination is not within the local network, so we can
// now make a request to that URL and pull down the file, saving it to the server's data
// directory.
req, err := http.NewRequestWithContext(ctx, http.MethodGet, dl.req.URL.String(), nil)
if err != nil {
return errors.WrapIf(err, "downloader: failed to create request")
}
req.Header.Set("User-Agent", "Pterodactyl Panel (https://pterodactyl.io)")
res, err := client.Do(req)
if err != nil {
return ErrDownloadFailed
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New("downloader: got bad response status from endpoint: " + res.Status)
}
if res.ContentLength < 1 {
return errors.New("downloader: request is missing ContentLength")
}
if dl.req.UseHeader {
if contentDisposition := res.Header.Get("Content-Disposition"); contentDisposition != "" {
_, params, err := mime.ParseMediaType(contentDisposition)
if err != nil {
return errors.WrapIf(err, "downloader: invalid \"Content-Disposition\" header")
}
if v, ok := params["filename"]; ok {
dl.path = v
}
}
}
if dl.path == "" {
if dl.req.FileName != "" {
dl.path = dl.req.FileName
} else {
parts := strings.Split(dl.req.URL.Path, "/")
dl.path = parts[len(parts)-1]
}
}
p := dl.Path()
dl.server.Log().WithField("path", p).Debug("writing remote file to disk")
// Write the file while tracking the progress, Write will check that the
// size of the file won't exceed the disk limit.
r := io.TeeReader(res.Body, dl.counter(res.ContentLength))
if err := dl.server.Filesystem().Write(p, r, res.ContentLength, 0o644); err != nil {
return errors.WrapIf(err, "downloader: failed to write file to server directory")
}
return nil
}
// Cancel cancels a running download and frees up the associated resources. If a file is being
// written a partial file will remain present on the disk.
func (dl *Download) Cancel() {
if dl.cancelFunc != nil {
(*dl.cancelFunc)()
}
instance.remove(dl.Identifier)
}
// BelongsTo checks if the given download belongs to the provided server.
func (dl *Download) BelongsTo(s *server.Server) bool {
return dl.server.ID() == s.ID()
}
// Progress returns the current progress of the download as a float value between 0 and 1 where
// 1 indicates that the download is completed.
func (dl *Download) Progress() float64 {
dl.mu.RLock()
defer dl.mu.RUnlock()
return dl.progress
}
func (dl *Download) Path() string {
return filepath.Join(dl.req.Directory, dl.path)
}
// Handles a write event by updating the progress completed percentage and firing off
// events to the server websocket as needed.
func (dl *Download) counter(contentLength int64) *Counter {
onWrite := func(t int) {
dl.mu.Lock()
defer dl.mu.Unlock()
dl.progress = float64(t) / float64(contentLength)
}
return &Counter{
onWrite: onWrite,
}
}
// Downloader represents a global downloader that keeps track of all currently processing downloads
// for the machine.
type Downloader struct {
mu sync.RWMutex
downloadCache map[string]*Download
serverCache map[string][]string
}
// track tracks a download in the internal cache for this instance.
func (d *Downloader) track(dl *Download) {
d.mu.Lock()
defer d.mu.Unlock()
sid := dl.server.ID()
if _, ok := d.downloadCache[dl.Identifier]; !ok {
d.downloadCache[dl.Identifier] = dl
if _, ok := d.serverCache[sid]; !ok {
d.serverCache[sid] = []string{}
}
d.serverCache[sid] = append(d.serverCache[sid], dl.Identifier)
}
}
// find finds a given download entry using the provided ID and returns it.
func (d *Downloader) find(dlid string) *Download {
d.mu.RLock()
defer d.mu.RUnlock()
if entry, ok := d.downloadCache[dlid]; ok {
return entry
}
return nil
}
// remove removes the given download reference from the cache storing them. This also updates
// the slice of active downloads for a given server to not include this download.
func (d *Downloader) remove(dlID string) {
d.mu.Lock()
defer d.mu.Unlock()
if _, ok := d.downloadCache[dlID]; !ok {
return
}
sID := d.downloadCache[dlID].server.ID()
delete(d.downloadCache, dlID)
if tracked, ok := d.serverCache[sID]; ok {
var out []string
for _, k := range tracked {
if k != dlID {
out = append(out, k)
}
}
d.serverCache[sID] = out
}
}
func mustParseCIDR(ip string) *net.IPNet {
_, block, err := net.ParseCIDR(ip)
if err != nil {
panic(fmt.Errorf("downloader: failed to parse CIDR: %s", err))
}
return block
}