downloader: move internal subnet validation into http Transport

This commit is contained in:
Matthew Penner 2024-04-10 15:22:09 -06:00
parent 5b0422d756
commit c152e36101
No known key found for this signature in database

View File

@ -20,20 +20,58 @@ import (
"github.com/pterodactyl/wings/server" "github.com/pterodactyl/wings/server"
) )
var client = &http.Client{ var client *http.Client
Timeout: time.Hour * 12,
// Disallow any redirect on an HTTP call. This is a security requirement: do not modify func init() {
// this logic without first ensuring that the new target location IS NOT within the current dialer := &net.Dialer{
// instance's local network. LocalAddr: nil,
// }
// This specific error response just causes the client to not follow the redirect and
// returns the actual redirect response to the caller. Not perfect, but simple and most trnspt := http.DefaultTransport.(*http.Transport).Clone()
// people won't be using URLs that redirect anyways hopefully? trnspt.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
// c, err := dialer.DialContext(ctx, network, addr)
// We'll re-evaluate this down the road if needed. if err != nil {
CheckRedirect: func(req *http.Request, via []*http.Request) error { return nil, errors.WithStack(err)
return http.ErrUseLastResponse }
},
ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String())
if err != nil {
return c, errors.WithStack(err)
}
ip := net.ParseIP(ipStr)
if ip == nil {
return c, errors.WithStack(ErrInvalidIPAddress)
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
return c, errors.WithStack(ErrInternalResolution)
}
for _, block := range internalRanges {
if !block.Contains(ip) {
continue
}
return c, errors.WithStack(ErrInternalResolution)
}
return c, nil
}
client = &http.Client{
Timeout: time.Hour * 12,
Transport: trnspt,
// Disallow any redirect on an HTTP call. This is a security requirement: do not modify
// this logic without first ensuring that the new target location IS NOT within the current
// instance's local network.
//
// This specific error response just causes the client to not follow the redirect and
// returns the actual redirect response to the caller. Not perfect, but simple and most
// people won't be using URLs that redirect anyways hopefully?
//
// We'll re-evaluate this down the road if needed.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
} }
var instance = &Downloader{ var instance = &Downloader{
@ -143,12 +181,6 @@ func (dl *Download) Execute() error {
dl.cancelFunc = &cancel dl.cancelFunc = &cancel
defer dl.Cancel() defer dl.Cancel()
// Always ensure that we're checking the destination for the download to avoid a malicious
// user from accessing internal network resources.
if err := dl.isExternalNetwork(ctx); err != nil {
return err
}
// At this point we have verified the destination is not within the local network, so we can // At this point we have verified the destination is not within the local network, so we can
// now make a request to that URL and pull down the file, saving it to the server's data // now make a request to that URL and pull down the file, saving it to the server's data
// directory. // directory.
@ -243,59 +275,6 @@ func (dl *Download) counter(contentLength int64) *Counter {
} }
} }
// Verifies that a given download resolves to a location not within the current local
// network for the machine. If the final destination of a resource is within the local
// network an ErrInternalResolution error is returned.
func (dl *Download) isExternalNetwork(ctx context.Context) error {
dialer := &net.Dialer{
LocalAddr: nil,
}
host := dl.req.URL.Host
// This cluster-fuck of math and integer shit converts an integer IP into a proper IPv4.
// For example: 16843009 would become 1.1.1.1
//if i, err := strconv.ParseInt(host, 10, 64); err == nil {
// host = strconv.FormatInt((i>>24)&0xFF, 10) + "." + strconv.FormatInt((i>>16)&0xFF, 10) + "." + strconv.FormatInt((i>>8)&0xFF, 10) + "." + strconv.FormatInt(i&0xFF, 10)
//}
if _, _, err := net.SplitHostPort(host); err != nil {
if !strings.Contains(err.Error(), "missing port in address") {
return errors.WithStack(err)
}
switch dl.req.URL.Scheme {
case "http":
host += ":80"
case "https":
host += ":443"
}
}
c, err := dialer.DialContext(ctx, "tcp", host)
if err != nil {
return errors.WithStack(err)
}
_ = c.Close()
ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String())
if err != nil {
return errors.WithStack(err)
}
ip := net.ParseIP(ipStr)
if ip == nil {
return errors.WithStack(ErrInvalidIPAddress)
}
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || ip.IsInterfaceLocalMulticast() {
return errors.WithStack(ErrInternalResolution)
}
for _, block := range internalRanges {
if block.Contains(ip) {
return errors.WithStack(ErrInternalResolution)
}
}
return nil
}
// Downloader represents a global downloader that keeps track of all currently processing downloads // Downloader represents a global downloader that keeps track of all currently processing downloads
// for the machine. // for the machine.
type Downloader struct { type Downloader struct {