downloader: fix internal range check

This commit is contained in:
Matthew Penner 2021-08-02 15:16:38 -06:00
parent 3c54c1f840
commit 2a370a8776
2 changed files with 39 additions and 35 deletions

View File

@ -9,8 +9,6 @@ import (
"net/http"
"net/url"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
@ -46,10 +44,6 @@ var instance = &Downloader{
serverCache: make(map[string][]string),
}
// Regex to match the end of an IPv4/IPv6 address. This allows the port to be removed
// so that we are just working with the raw IP address in question.
var ipMatchRegex = regexp.MustCompile(`(:\d+)$`)
// Internal IP ranges that should be blocked if the resource requested resolves within.
var internalRanges = []*net.IPNet{
mustParseCIDR("127.0.0.1/8"),
@ -62,9 +56,11 @@ var internalRanges = []*net.IPNet{
mustParseCIDR("fc00::/7"),
}
const ErrInternalResolution = errors.Sentinel("downloader: destination resolves to internal network location")
const ErrInvalidIPAddress = errors.Sentinel("downloader: invalid IP address")
const ErrDownloadFailed = errors.Sentinel("downloader: download request failed")
const (
ErrInternalResolution = errors.Sentinel("downloader: destination resolves to internal network location")
ErrInvalidIPAddress = errors.Sentinel("downloader: invalid IP address")
ErrDownloadFailed = errors.Sentinel("downloader: download request failed")
)
type Counter struct {
total int
@ -79,8 +75,8 @@ func (c *Counter) Write(p []byte) (int, error) {
}
type DownloadRequest struct {
URL *url.URL
Directory string
URL *url.URL
}
type Download struct {
@ -92,7 +88,7 @@ type Download struct {
cancelFunc *context.CancelFunc
}
// Starts a new tracked download which allows for cancellation later on by calling
// New starts a new tracked download which allows for cancellation later on by calling
// the Downloader.Cancel function.
func New(s *server.Server, r DownloadRequest) *Download {
dl := Download{
@ -104,14 +100,14 @@ func New(s *server.Server, r DownloadRequest) *Download {
return &dl
}
// Returns all of the tracked downloads for a given server instance.
// ByServer returns all the tracked downloads for a given server instance.
func ByServer(sid string) []*Download {
instance.mu.Lock()
defer instance.mu.Unlock()
var downloads []*Download
if v, ok := instance.serverCache[sid]; ok {
for _, id := range v {
if dl, dlok := instance.downloadCache[id]; dlok {
if dl, ok := instance.downloadCache[id]; ok {
downloads = append(downloads, dl)
}
}
@ -119,7 +115,7 @@ func ByServer(sid string) []*Download {
return downloads
}
// Returns a single Download matching a given identifier. If no download is found
// ByID returns a single Download matching a given identifier. If no download is found
// the second argument in the response will be false.
func ByID(dlid string) *Download {
return instance.find(dlid)
@ -136,7 +132,7 @@ func (dl Download) MarshalJSON() ([]byte, error) {
})
}
// Executes a given download for the server and begins writing the file to the disk. Once
// Execute executes a given download for the server and begins writing the file to the disk. Once
// completed the download will be removed from the cache.
func (dl *Download) Execute() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Hour*12)
@ -187,7 +183,7 @@ func (dl *Download) Execute() error {
return nil
}
// Cancels a running download and frees up the associated resources. If a file is being
// Cancel cancels a running download and frees up the associated resources. If a file is being
// written a partial file will remain present on the disk.
func (dl *Download) Cancel() {
if dl.cancelFunc != nil {
@ -196,12 +192,12 @@ func (dl *Download) Cancel() {
instance.remove(dl.Identifier)
}
// Checks if the given download belongs to the provided server.
// BelongsTo checks if the given download belongs to the provided server.
func (dl *Download) BelongsTo(s *server.Server) bool {
return dl.server.ID() == s.ID()
}
// Returns the current progress of the download as a float value between 0 and 1 where
// Progress returns the current progress of the download as a float value between 0 and 1 where
// 1 indicates that the download is completed.
func (dl *Download) Progress() float64 {
dl.mu.RLock()
@ -234,15 +230,19 @@ func (dl *Download) isExternalNetwork(ctx context.Context) error {
// This cluster-fuck of math and integer shit converts an integer IP into a proper IPv4.
// For example: 16843009 would become 1.1.1.1
if i, err := strconv.ParseInt(host, 10, 64); err == nil {
host = strconv.FormatInt((i>>24)&0xFF, 10) + "." + strconv.FormatInt((i>>16)&0xFF, 10) + "." + strconv.FormatInt((i>>8)&0xFF, 10) + "." + strconv.FormatInt(i&0xFF, 10)
}
//if i, err := strconv.ParseInt(host, 10, 64); err == nil {
// host = strconv.FormatInt((i>>24)&0xFF, 10) + "." + strconv.FormatInt((i>>16)&0xFF, 10) + "." + strconv.FormatInt((i>>8)&0xFF, 10) + "." + strconv.FormatInt(i&0xFF, 10)
//}
if !ipMatchRegex.MatchString(host) {
if dl.req.URL.Scheme == "https" {
host = host + ":443"
} else {
host = host + ":80"
if _, _, err := net.SplitHostPort(host); err != nil {
if !strings.Contains(err.Error(), "missing port in address") {
return errors.WithStack(err)
}
switch dl.req.URL.Scheme {
case "http":
host += ":80"
case "https":
host += ":443"
}
}
@ -252,7 +252,11 @@ func (dl *Download) isExternalNetwork(ctx context.Context) error {
}
_ = c.Close()
ip := net.ParseIP(ipMatchRegex.ReplaceAllString(c.RemoteAddr().String(), ""))
ipStr, _, err := net.SplitHostPort(c.RemoteAddr().String())
if err != nil {
return errors.WithStack(err)
}
ip := net.ParseIP(ipStr)
if ip == nil {
return errors.WithStack(ErrInvalidIPAddress)
}
@ -267,7 +271,7 @@ func (dl *Download) isExternalNetwork(ctx context.Context) error {
return nil
}
// Defines a global downloader struct that keeps track of all currently processing downloads
// Downloader represents a global downloader that keeps track of all currently processing downloads
// for the machine.
type Downloader struct {
mu sync.RWMutex
@ -275,7 +279,7 @@ type Downloader struct {
serverCache map[string][]string
}
// Tracks a download in the internal cache for this instance.
// track tracks a download in the internal cache for this instance.
func (d *Downloader) track(dl *Download) {
d.mu.Lock()
defer d.mu.Unlock()
@ -289,7 +293,7 @@ func (d *Downloader) track(dl *Download) {
}
}
// Finds a given download entry using the provided ID and returns it.
// find finds a given download entry using the provided ID and returns it.
func (d *Downloader) find(dlid string) *Download {
d.mu.RLock()
defer d.mu.RUnlock()
@ -299,7 +303,7 @@ func (d *Downloader) find(dlid string) *Download {
return nil
}
// Remove the given download reference from the cache storing them. This also updates
// remove removes the given download reference from the cache storing them. This also updates
// the slice of active downloads for a given server to not include this download.
func (d *Downloader) remove(dlID string) {
d.mu.Lock()

View File

@ -254,8 +254,8 @@ func getServerPullingFiles(c *gin.Context) {
func postServerPullRemoteFile(c *gin.Context) {
s := ExtractServer(c)
var data struct {
RootPath string `binding:"required,omitempty" json:"root"`
URL string `binding:"required" json:"url"`
Directory string `binding:"required,omitempty" json:"directory"`
}
if err := c.BindJSON(&data); err != nil {
return
@ -286,11 +286,11 @@ func postServerPullRemoteFile(c *gin.Context) {
}
dl := downloader.New(s, downloader.DownloadRequest{
Directory: data.RootPath,
URL: u,
Directory: data.Directory,
})
// Execute this pull in a seperate thread since it may take a long time to complete.
// Execute this pull in a separate thread since it may take a long time to complete.
go func() {
s.Log().WithField("download_id", dl.Identifier).WithField("url", u.String()).Info("starting pull of remote file to disk")
if err := dl.Execute(); err != nil {