diff --git a/router/downloader/downloader.go b/router/downloader/downloader.go index 85034ed..a89cca1 100644 --- a/router/downloader/downloader.go +++ b/router/downloader/downloader.go @@ -4,17 +4,50 @@ import ( "context" "emperror.dev/errors" "encoding/json" + "fmt" "github.com/google/uuid" "github.com/pterodactyl/wings/server" "io" + "net" "net/http" "net/url" "path/filepath" + "regexp" "strings" "sync" "time" ) +var client = &http.Client{Timeout: time.Hour * 12} +var instance = &Downloader{ + // Tracks all of the active downloads. + downloadCache: make(map[string]*Download), + // Tracks all of the downloads active for a given server instance. This is + // primarily used to make things quicker and keep the code a little more + // legible throughout here. + serverCache: make(map[string][]string), +} + +// Regex to match the end of an IPv4/IPv6 address. This allows the port to be removed +// so that we are just working with the raw IP address in question. +var ipMatchRegex = regexp.MustCompile(`(:\d+)$`) + +// Internal IP ranges that should be blocked if the resource requested resolves within. +var internalRanges = []*net.IPNet{ + mustParseCIDR("127.0.0.1/8"), + mustParseCIDR("10.0.0.0/8"), + mustParseCIDR("172.16.0.0/12"), + mustParseCIDR("192.168.0.0/16"), + mustParseCIDR("169.254.0.0/16"), + mustParseCIDR("::1/128"), + mustParseCIDR("fe80::/10"), + mustParseCIDR("fc00::/7"), +} + +const ErrInternalResolution = errors.Sentinel("downloader: destination resolves to internal network location") +const ErrInvalidIPAddress = errors.Sentinel("downloader: invalid IP address") +const ErrDownloadFailed = errors.Sentinel("downloader: download request failed") + type Counter struct { total int onWrite func(total int) @@ -27,12 +60,6 @@ func (c *Counter) Write(p []byte) (int, error) { return n, nil } -type Downloader struct { - mu sync.RWMutex - downloadCache map[string]*Download - serverCache map[string][]string -} - type DownloadRequest struct { URL *url.URL Directory string @@ -47,16 +74,6 @@ type Download struct { cancelFunc *context.CancelFunc } -var client = &http.Client{Timeout: time.Hour * 12} -var instance = &Downloader{ - // Tracks all of the active downloads. - downloadCache: make(map[string]*Download), - // Tracks all of the downloads active for a given server instance. This is - // primarily used to make things quicker and keep the code a little more - // legible throughout here. - serverCache: make(map[string][]string), -} - // Starts a new tracked download which allows for cancelation later on by calling // the Downloader.Cancel function. func New(s *server.Server, r DownloadRequest) *Download { @@ -108,15 +125,24 @@ func (dl *Download) Execute() error { dl.cancelFunc = &cancel defer dl.Cancel() + // Always ensure that we're checking the destination for the download to avoid a malicious + // user from accessing internal network resources. + if err := dl.isExternalNetwork(ctx); err != nil { + return err + } + + // At this point we have verified the destination is not within the local network, so we can + // now make a request to that URL and pull down the file, saving it to the server's data + // directory. req, err := http.NewRequestWithContext(ctx, http.MethodGet, dl.req.URL.String(), nil) if err != nil { return errors.WrapIf(err, "downloader: failed to create request") } req.Header.Set("User-Agent", "Pterodactyl Panel (https://pterodactyl.io)") - res, err := client.Do(req) // lgtm [go/request-forgery] + res, err := client.Do(req) if err != nil { - return errors.New("downloader: failed opening request to download file") + return ErrDownloadFailed } defer res.Body.Close() if res.StatusCode != http.StatusOK { @@ -178,6 +204,52 @@ func (dl *Download) counter(contentLength int64) *Counter { } } +// Verifies that a given download resolves to a location not within the current local +// network for the machine. If the final destination of a resource is within the local +// network an ErrInternalResolution error is returned. +func (dl *Download) isExternalNetwork(ctx context.Context) error { + dialer := &net.Dialer{ + LocalAddr: nil, + } + + host := dl.req.URL.Host + if !ipMatchRegex.MatchString(host) { + if dl.req.URL.Scheme == "https" { + host = host + ":443" + } else { + host = host + ":80" + } + } + + c, err := dialer.DialContext(ctx, "tcp", host) + if err != nil { + return errors.WithStack(err) + } + c.Close() + + ip := net.ParseIP(ipMatchRegex.ReplaceAllString(c.RemoteAddr().String(), "")) + if ip == nil { + return errors.WithStack(ErrInvalidIPAddress) + } + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() { + return errors.WithStack(ErrInternalResolution) + } + for _, block := range internalRanges { + if block.Contains(ip) { + return errors.WithStack(ErrInternalResolution) + } + } + return nil +} + +// Defines a global downloader struct that keeps track of all currently processing downloads +// for the machine. +type Downloader struct { + mu sync.RWMutex + downloadCache map[string]*Download + serverCache map[string][]string +} + // Tracks a download in the internal cache for this instance. func (d *Downloader) track(dl *Download) { d.mu.Lock() @@ -222,3 +294,11 @@ func (d *Downloader) remove(dlid string) { d.serverCache[sid] = out } } + +func mustParseCIDR(ip string) *net.IPNet { + _, block, err := net.ParseCIDR(ip) + if err != nil { + panic(fmt.Errorf("downloader: failed to parse CIDR: %s", err)) + } + return block +}