From c152e36101aba45d8868a9a0eeb890995e8934b8 Mon Sep 17 00:00:00 2001 From: Matthew Penner Date: Wed, 10 Apr 2024 15:22:09 -0600 Subject: [PATCH] downloader: move internal subnet validation into http Transport --- router/downloader/downloader.go | 125 +++++++++++++------------------- 1 file changed, 52 insertions(+), 73 deletions(-) diff --git a/router/downloader/downloader.go b/router/downloader/downloader.go index d5436d6..e5c6523 100644 --- a/router/downloader/downloader.go +++ b/router/downloader/downloader.go @@ -20,20 +20,58 @@ import ( "github.com/pterodactyl/wings/server" ) -var client = &http.Client{ - Timeout: time.Hour * 12, - // Disallow any redirect on an HTTP call. This is a security requirement: do not modify - // this logic without first ensuring that the new target location IS NOT within the current - // instance's local network. - // - // This specific error response just causes the client to not follow the redirect and - // returns the actual redirect response to the caller. Not perfect, but simple and most - // people won't be using URLs that redirect anyways hopefully? - // - // We'll re-evaluate this down the road if needed. - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, +var client *http.Client + +func init() { + dialer := &net.Dialer{ + LocalAddr: nil, + } + + trnspt := http.DefaultTransport.(*http.Transport).Clone() + trnspt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, errors.WithStack(err) + } + + ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String()) + if err != nil { + return c, errors.WithStack(err) + } + ip := net.ParseIP(ipStr) + if ip == nil { + return c, errors.WithStack(ErrInvalidIPAddress) + } + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() { + return c, errors.WithStack(ErrInternalResolution) + } + for _, block := range internalRanges { + if !block.Contains(ip) { + continue + } + return c, errors.WithStack(ErrInternalResolution) + } + return c, nil + } + + client = &http.Client{ + Timeout: time.Hour * 12, + + Transport: trnspt, + + // Disallow any redirect on an HTTP call. This is a security requirement: do not modify + // this logic without first ensuring that the new target location IS NOT within the current + // instance's local network. + // + // This specific error response just causes the client to not follow the redirect and + // returns the actual redirect response to the caller. Not perfect, but simple and most + // people won't be using URLs that redirect anyways hopefully? + // + // We'll re-evaluate this down the road if needed. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } } var instance = &Downloader{ @@ -143,12 +181,6 @@ func (dl *Download) Execute() error { dl.cancelFunc = &cancel defer dl.Cancel() - // Always ensure that we're checking the destination for the download to avoid a malicious - // user from accessing internal network resources. - if err := dl.isExternalNetwork(ctx); err != nil { - return err - } - // At this point we have verified the destination is not within the local network, so we can // now make a request to that URL and pull down the file, saving it to the server's data // directory. @@ -243,59 +275,6 @@ func (dl *Download) counter(contentLength int64) *Counter { } } -// Verifies that a given download resolves to a location not within the current local -// network for the machine. If the final destination of a resource is within the local -// network an ErrInternalResolution error is returned. -func (dl *Download) isExternalNetwork(ctx context.Context) error { - dialer := &net.Dialer{ - LocalAddr: nil, - } - - host := dl.req.URL.Host - - // This cluster-fuck of math and integer shit converts an integer IP into a proper IPv4. - // For example: 16843009 would become 1.1.1.1 - //if i, err := strconv.ParseInt(host, 10, 64); err == nil { - // host = strconv.FormatInt((i>>24)&0xFF, 10) + "." + strconv.FormatInt((i>>16)&0xFF, 10) + "." + strconv.FormatInt((i>>8)&0xFF, 10) + "." + strconv.FormatInt(i&0xFF, 10) - //} - - if _, _, err := net.SplitHostPort(host); err != nil { - if !strings.Contains(err.Error(), "missing port in address") { - return errors.WithStack(err) - } - switch dl.req.URL.Scheme { - case "http": - host += ":80" - case "https": - host += ":443" - } - } - - c, err := dialer.DialContext(ctx, "tcp", host) - if err != nil { - return errors.WithStack(err) - } - _ = c.Close() - - ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String()) - if err != nil { - return errors.WithStack(err) - } - ip := net.ParseIP(ipStr) - if ip == nil { - return errors.WithStack(ErrInvalidIPAddress) - } - if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() { - return errors.WithStack(ErrInternalResolution) - } - for _, block := range internalRanges { - if block.Contains(ip) { - return errors.WithStack(ErrInternalResolution) - } - } - return nil -} - // Downloader represents a global downloader that keeps track of all currently processing downloads // for the machine. type Downloader struct {