diff --git a/attachments.go b/attachments.go index c8a7318..34e6b91 100644 --- a/attachments.go +++ b/attachments.go @@ -3,6 +3,7 @@ package main import ( "bytes" "context" + "errors" "fmt" "image" "io" @@ -12,6 +13,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" "github.com/bwmarrin/discordgo" @@ -28,7 +30,7 @@ import ( "go.mau.fi/mautrix-discord/database" ) -func downloadDiscordAttachment(url string) ([]byte, error) { +func downloadDiscordAttachment(url string, maxSize int64) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err @@ -46,7 +48,22 @@ func downloadDiscordAttachment(url string) ([]byte, error) { data, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("unexpected status %d downloading %s: %s", resp.StatusCode, url, data) } - return io.ReadAll(resp.Body) + if resp.Header.Get("Content-Length") != "" { + length, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse content length: %w", err) + } else if length > maxSize { + return nil, fmt.Errorf("attachment too large (%d > %d)", length, maxSize) + } + return io.ReadAll(resp.Body) + } else { + var mbe *http.MaxBytesError + data, err := io.ReadAll(http.MaxBytesReader(nil, resp.Body, maxSize)) + if err != nil && errors.As(err, &mbe) { + return nil, fmt.Errorf("attachment too large (over %d)", maxSize) + } + return data, err + } } func uploadDiscordAttachment(url string, data []byte) error { @@ -99,7 +116,7 @@ func downloadMatrixAttachment(intent *appservice.IntentAPI, content *event.Messa return data, nil } -func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, meta AttachmentMeta) (*database.File, error) { +func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, meta AttachmentMeta, semaWg *sync.WaitGroup) (*database.File, error) { dbFile := br.DB.File.New() dbFile.Timestamp = time.Now() dbFile.URL = url @@ -135,7 +152,9 @@ func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, da dbFile.MXC = resp.ContentURI req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL + semaWg.Add(1) go func() { + defer semaWg.Done() _, err = intent.UploadMedia(req) if err != nil { br.Log.Errorfln("Failed to upload %s: %v", req.MXC, err) @@ -259,8 +278,21 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur } } + const attachmentSizeVal = 1 + onceErr = br.parallelAttachmentSemaphore.Acquire(context.Background(), attachmentSizeVal) + if onceErr != nil { + onceErr = fmt.Errorf("failed to acquire semaphore: %w", onceErr) + return + } + var semaWg sync.WaitGroup + semaWg.Add(1) + go func() { + semaWg.Wait() + br.parallelAttachmentSemaphore.Release(attachmentSizeVal) + }() + var data []byte - data, onceErr = downloadDiscordAttachment(url) + data, onceErr = downloadDiscordAttachment(url, br.MediaConfig.UploadSize) if onceErr != nil { return } @@ -273,7 +305,7 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur } } - onceDBFile, onceErr = br.uploadMatrixAttachment(intent, data, url, encrypt, meta) + onceDBFile, onceErr = br.uploadMatrixAttachment(intent, data, url, encrypt, meta, &semaWg) if onceErr != nil { return } @@ -281,6 +313,7 @@ func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, ur onceDBFile.Insert(nil) } br.attachmentTransfers.Delete(transferKey) + semaWg.Done() return }) } diff --git a/go.mod b/go.mod index 7904de2..9e38037 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/yuin/goldmark v1.5.4 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 + golang.org/x/sync v0.3.0 maunium.net/go/maulogger/v2 v2.4.1 maunium.net/go/mautrix v0.15.4-0.20230623121006-d8b15c18dc3f ) diff --git a/go.sum b/go.sum index b1e04af..55b3993 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERs golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= diff --git a/main.go b/main.go index b2d66d0..f40ec33 100644 --- a/main.go +++ b/main.go @@ -20,6 +20,7 @@ import ( _ "embed" "sync" + "golang.org/x/sync/semaphore" "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/id" @@ -73,7 +74,8 @@ type DiscordBridge struct { puppetsByCustomMXID map[id.UserID]*Puppet puppetsLock sync.Mutex - attachmentTransfers *util.SyncMap[attachmentKey, *util.ReturnableOnce[*database.File]] + attachmentTransfers *util.SyncMap[attachmentKey, *util.ReturnableOnce[*database.File]] + parallelAttachmentSemaphore *semaphore.Weighted } func (br *DiscordBridge) GetExampleConfig() string { @@ -170,7 +172,8 @@ func main() { puppets: make(map[string]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), - attachmentTransfers: util.NewSyncMap[attachmentKey, *util.ReturnableOnce[*database.File]](), + attachmentTransfers: util.NewSyncMap[attachmentKey, *util.ReturnableOnce[*database.File]](), + parallelAttachmentSemaphore: semaphore.NewWeighted(3), } br.Bridge = bridge.Bridge{ Name: "mautrix-discord",