From 9ca27a8df655bc1d8385ca754050d7856bcbc045 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 28 Jan 2023 01:57:44 +0200 Subject: [PATCH] Cache files copied to Matrix --- attachments.go | 84 ++++++----- database/database.go | 5 + database/file.go | 132 ++++++++++++++++++ database/upgrades/00-latest-revision.sql | 20 ++- .../upgrades/11-cache-reuploaded-files.sql | 18 +++ portal.go | 33 +++-- 6 files changed, 250 insertions(+), 42 deletions(-) create mode 100644 database/file.go create mode 100644 database/upgrades/11-cache-reuploaded-files.sql diff --git a/attachments.go b/attachments.go index e115a44..4d248bd 100644 --- a/attachments.go +++ b/attachments.go @@ -7,18 +7,19 @@ import ( "io" "net/http" "strings" - - "maunium.net/go/mautrix/crypto/attachment" + "time" "github.com/bwmarrin/discordgo" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/database" ) -func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) { +func downloadDiscordAttachment(url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err @@ -68,48 +69,67 @@ func (portal *Portal) downloadMatrixAttachment(content *event.MessageEventConten return data, nil } -func (portal *Portal) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error { - content.Info.Size = len(data) - if content.Info.Width == 0 && content.Info.Height == 0 && strings.HasPrefix(content.Info.MimeType, "image/") { +func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, attachmentID, mime string) (*database.File, error) { + dbFile := br.DB.File.New() + dbFile.Timestamp = time.Now() + dbFile.URL = url + dbFile.ID = attachmentID + dbFile.Size = len(data) + if strings.HasPrefix(mime, "image/") { cfg, _, _ := image.DecodeConfig(bytes.NewReader(data)) - content.Info.Width = cfg.Width - content.Info.Height = cfg.Height + dbFile.Width = cfg.Width + dbFile.Height = cfg.Height } - uploadMime := content.Info.MimeType - var file *attachment.EncryptedFile - if portal.Encrypted { - file = attachment.NewEncryptedFile() - file.EncryptInPlace(data) + uploadMime := mime + if encrypt { + dbFile.Encrypted = true + dbFile.DecryptionInfo = attachment.NewEncryptedFile() + dbFile.DecryptionInfo.EncryptInPlace(data) uploadMime = "application/octet-stream" } req := mautrix.ReqUploadMedia{ ContentBytes: data, ContentType: uploadMime, } - var mxc id.ContentURI - if portal.bridge.Config.Homeserver.AsyncMedia { - uploaded, err := intent.UnstableUploadAsync(req) + if br.Config.Homeserver.AsyncMedia { + resp, err := intent.UnstableCreateMXC() if err != nil { - return err + return nil, err } - mxc = uploaded.ContentURI + dbFile.MXC = resp.ContentURI + req.UnstableMXC = resp.ContentURI + req.UploadURL = resp.UploadURL + go func() { + _, err = intent.UploadMedia(req) + if err != nil { + br.Log.Errorfln("Failed to upload %s: %v", req.UnstableMXC, err) + dbFile.Delete() + } + }() } else { uploaded, err := intent.UploadMedia(req) if err != nil { - return err + return nil, err } - mxc = uploaded.ContentURI + dbFile.MXC = uploaded.ContentURI } - - if file != nil { - content.File = &event.EncryptedFileInfo{ - EncryptedFile: *file, - URL: mxc.CUString(), - } - } else { - content.URL = mxc.CUString() - } - - return nil + dbFile.Insert(nil) + return dbFile, nil +} + +func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, url string, encrypt bool, attachmentID, mime string) (*database.File, error) { + dbFile := br.DB.File.Get(url, encrypt) + if dbFile == nil { + data, err := downloadDiscordAttachment(url) + if err != nil { + return nil, err + } + + dbFile, err = br.uploadMatrixAttachment(intent, data, url, encrypt, attachmentID, mime) + if err != nil { + return nil, err + } + } + return dbFile, nil } diff --git a/database/database.go b/database/database.go index 1b69d7e..080d5c6 100644 --- a/database/database.go +++ b/database/database.go @@ -24,6 +24,7 @@ type Database struct { Emoji *EmojiQuery Guild *GuildQuery Role *RoleQuery + File *FileQuery } func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { @@ -65,6 +66,10 @@ func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { db: db, log: log.Sub("Role"), } + db.File = &FileQuery{ + db: db, + log: log.Sub("File"), + } return db } diff --git a/database/file.go b/database/file.go new file mode 100644 index 0000000..8f7c403 --- /dev/null +++ b/database/file.go @@ -0,0 +1,132 @@ +package database + +import ( + "database/sql" + "encoding/json" + "errors" + "time" + + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" +) + +type FileQuery struct { + db *Database + log log.Logger +} + +// language=postgresql +const ( + fileSelect = "SELECT url, encrypted, id, mxc, size, width, height, decryption_info, timestamp FROM discord_file" + fileInsert = ` + INSERT INTO discord_file (url, encrypted, id, mxc, size, width, height, decryption_info, timestamp) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ` +) + +func (fq *FileQuery) New() *File { + return &File{ + db: fq.db, + log: fq.log, + } +} + +func (fq *FileQuery) Get(url string, encrypted bool) *File { + query := fileSelect + " WHERE url=$1 AND encrypted=$2" + return fq.New().Scan(fq.db.QueryRow(query, url, encrypted)) +} + +type File struct { + db *Database + log log.Logger + + URL string + Encrypted bool + + ID string + MXC id.ContentURI + + Size int + Width int + Height int + + DecryptionInfo *attachment.EncryptedFile + + Timestamp time.Time +} + +func (f *File) Scan(row dbutil.Scannable) *File { + var fileID sql.NullString + var decryptionInfo []byte + var width, height sql.NullInt32 + var timestamp int64 + var mxc string + err := row.Scan(&f.URL, &f.Encrypted, &fileID, &mxc, &f.Size, &width, &height, &decryptionInfo, ×tamp) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + f.log.Errorln("Database scan failed:", err) + panic(err) + } + return nil + } + f.ID = fileID.String + f.Timestamp = time.UnixMilli(timestamp) + f.Width = int(width.Int32) + f.Height = int(height.Int32) + f.MXC, err = id.ParseContentURI(mxc) + if err != nil { + f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err) + panic(err) + } + if decryptionInfo != nil { + err = json.Unmarshal(decryptionInfo, &f.DecryptionInfo) + if err != nil { + f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err) + panic(err) + } + } + return f +} + +func positiveIntToNullInt32(val int) (ptr sql.NullInt32) { + if val > 0 { + ptr.Valid = true + ptr.Int32 = int32(val) + } + return +} + +func (f *File) Insert(txn dbutil.Execable) { + if txn == nil { + txn = f.db + } + var err error + var decryptionInfo []byte + if f.DecryptionInfo != nil { + decryptionInfo, err = json.Marshal(f.DecryptionInfo) + if err != nil { + f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err) + panic(err) + } + } + _, err = txn.Exec(fileInsert, + f.URL, f.Encrypted, strPtr(f.ID), f.MXC.String(), f.Size, + positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height), + decryptionInfo, f.Timestamp.UnixMilli(), + ) + if err != nil { + f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err) + panic(err) + } +} + +func (f *File) Delete() { + _, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted) + if err != nil { + f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err) + panic(err) + } +} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 6624be1..8425584 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v10: Latest revision +-- v0 -> v11: Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -150,3 +150,21 @@ CREATE TABLE role ( PRIMARY KEY (dc_guild_id, dcid), CONSTRAINT role_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild (dcid) ON DELETE CASCADE ); + +CREATE TABLE discord_file ( + url TEXT, + encrypted BOOLEAN, + + id TEXT, + mxc TEXT NOT NULL, + + size BIGINT NOT NULL, + width INTEGER, + height INTEGER, + + decryption_info jsonb, + + timestamp BIGINT NOT NULL, + + PRIMARY KEY (url, encrypted) +); diff --git a/database/upgrades/11-cache-reuploaded-files.sql b/database/upgrades/11-cache-reuploaded-files.sql new file mode 100644 index 0000000..c32c2bc --- /dev/null +++ b/database/upgrades/11-cache-reuploaded-files.sql @@ -0,0 +1,18 @@ +-- v11: Cache files copied from Discord to Matrix +CREATE TABLE discord_file ( + url TEXT, + encrypted BOOLEAN, + + id TEXT, + mxc TEXT NOT NULL, + + size BIGINT NOT NULL, + width INTEGER, + height INTEGER, + + decryption_info jsonb, + + timestamp BIGINT NOT NULL, + + PRIMARY KEY (url, encrypted) +); diff --git a/portal.go b/portal.go index 7fbb544..2c9825c 100644 --- a/portal.go +++ b/portal.go @@ -523,31 +523,46 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author msg.MassInsert(parts) } -func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) { +func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) id.EventID { content := &event.MessageEventContent{ Body: fmt.Sprintf("Failed to bridge media: %v", bridgeErr), MsgType: event.MsgNotice, } - _, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0) + resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0) if err != nil { portal.log.Warnfln("Failed to send media error message to matrix: %v", err) + return "" } + return resp.EventID } const DiscordStickerSize = 160 func (portal *Portal) handleDiscordFile(typeName string, intent *appservice.IntentAPI, id, url string, content *event.MessageEventContent, ts time.Time, threadRelation *event.RelatesTo) *database.MessagePart { - data, err := portal.downloadDiscordAttachment(url) + dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, url, portal.Encrypted, id, content.Info.MimeType) if err != nil { - portal.sendMediaFailedMessage(intent, err) + errorEventID := portal.sendMediaFailedMessage(intent, err) + if errorEventID != "" { + return &database.MessagePart{ + AttachmentID: id, + MXID: errorEventID, + } + } return nil } - - err = portal.uploadMatrixAttachment(intent, data, content) - if err != nil { - portal.sendMediaFailedMessage(intent, err) - return nil + content.Info.Size = dbFile.Size + if content.Info.Width == 0 && content.Info.Height == 0 { + content.Info.Width = dbFile.Width + content.Info.Height = dbFile.Height + } + if dbFile.DecryptionInfo != nil { + content.File = &event.EncryptedFileInfo{ + EncryptedFile: *dbFile.DecryptionInfo, + URL: dbFile.MXC.CUString(), + } + } else { + content.URL = dbFile.MXC.CUString() } evtType := event.EventMessage