Cache files copied to Matrix

This commit is contained in:
Tulir Asokan
2023-01-28 01:57:44 +02:00
parent 9e6b573e83
commit 9ca27a8df6
6 changed files with 250 additions and 42 deletions

View File

@@ -7,18 +7,19 @@ import (
"io"
"net/http"
"strings"
"maunium.net/go/mautrix/crypto/attachment"
"time"
"github.com/bwmarrin/discordgo"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database"
)
func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) {
func downloadDiscordAttachment(url string) ([]byte, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
@@ -68,48 +69,67 @@ func (portal *Portal) downloadMatrixAttachment(content *event.MessageEventConten
return data, nil
}
func (portal *Portal) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, content *event.MessageEventContent) error {
content.Info.Size = len(data)
if content.Info.Width == 0 && content.Info.Height == 0 && strings.HasPrefix(content.Info.MimeType, "image/") {
func (br *DiscordBridge) uploadMatrixAttachment(intent *appservice.IntentAPI, data []byte, url string, encrypt bool, attachmentID, mime string) (*database.File, error) {
dbFile := br.DB.File.New()
dbFile.Timestamp = time.Now()
dbFile.URL = url
dbFile.ID = attachmentID
dbFile.Size = len(data)
if strings.HasPrefix(mime, "image/") {
cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
content.Info.Width = cfg.Width
content.Info.Height = cfg.Height
dbFile.Width = cfg.Width
dbFile.Height = cfg.Height
}
uploadMime := content.Info.MimeType
var file *attachment.EncryptedFile
if portal.Encrypted {
file = attachment.NewEncryptedFile()
file.EncryptInPlace(data)
uploadMime := mime
if encrypt {
dbFile.Encrypted = true
dbFile.DecryptionInfo = attachment.NewEncryptedFile()
dbFile.DecryptionInfo.EncryptInPlace(data)
uploadMime = "application/octet-stream"
}
req := mautrix.ReqUploadMedia{
ContentBytes: data,
ContentType: uploadMime,
}
var mxc id.ContentURI
if portal.bridge.Config.Homeserver.AsyncMedia {
uploaded, err := intent.UnstableUploadAsync(req)
if br.Config.Homeserver.AsyncMedia {
resp, err := intent.UnstableCreateMXC()
if err != nil {
return err
return nil, err
}
mxc = uploaded.ContentURI
dbFile.MXC = resp.ContentURI
req.UnstableMXC = resp.ContentURI
req.UploadURL = resp.UploadURL
go func() {
_, err = intent.UploadMedia(req)
if err != nil {
br.Log.Errorfln("Failed to upload %s: %v", req.UnstableMXC, err)
dbFile.Delete()
}
}()
} else {
uploaded, err := intent.UploadMedia(req)
if err != nil {
return err
return nil, err
}
mxc = uploaded.ContentURI
dbFile.MXC = uploaded.ContentURI
}
if file != nil {
content.File = &event.EncryptedFileInfo{
EncryptedFile: *file,
URL: mxc.CUString(),
}
} else {
content.URL = mxc.CUString()
}
return nil
dbFile.Insert(nil)
return dbFile, nil
}
func (br *DiscordBridge) copyAttachmentToMatrix(intent *appservice.IntentAPI, url string, encrypt bool, attachmentID, mime string) (*database.File, error) {
dbFile := br.DB.File.Get(url, encrypt)
if dbFile == nil {
data, err := downloadDiscordAttachment(url)
if err != nil {
return nil, err
}
dbFile, err = br.uploadMatrixAttachment(intent, data, url, encrypt, attachmentID, mime)
if err != nil {
return nil, err
}
}
return dbFile, nil
}

View File

@@ -24,6 +24,7 @@ type Database struct {
Emoji *EmojiQuery
Guild *GuildQuery
Role *RoleQuery
File *FileQuery
}
func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
@@ -65,6 +66,10 @@ func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
db: db,
log: log.Sub("Role"),
}
db.File = &FileQuery{
db: db,
log: log.Sub("File"),
}
return db
}

132
database/file.go Normal file
View File

@@ -0,0 +1,132 @@
package database
import (
"database/sql"
"encoding/json"
"errors"
"time"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type FileQuery struct {
db *Database
log log.Logger
}
// language=postgresql
const (
fileSelect = "SELECT url, encrypted, id, mxc, size, width, height, decryption_info, timestamp FROM discord_file"
fileInsert = `
INSERT INTO discord_file (url, encrypted, id, mxc, size, width, height, decryption_info, timestamp)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
)
func (fq *FileQuery) New() *File {
return &File{
db: fq.db,
log: fq.log,
}
}
func (fq *FileQuery) Get(url string, encrypted bool) *File {
query := fileSelect + " WHERE url=$1 AND encrypted=$2"
return fq.New().Scan(fq.db.QueryRow(query, url, encrypted))
}
type File struct {
db *Database
log log.Logger
URL string
Encrypted bool
ID string
MXC id.ContentURI
Size int
Width int
Height int
DecryptionInfo *attachment.EncryptedFile
Timestamp time.Time
}
func (f *File) Scan(row dbutil.Scannable) *File {
var fileID sql.NullString
var decryptionInfo []byte
var width, height sql.NullInt32
var timestamp int64
var mxc string
err := row.Scan(&f.URL, &f.Encrypted, &fileID, &mxc, &f.Size, &width, &height, &decryptionInfo, &timestamp)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
f.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
f.ID = fileID.String
f.Timestamp = time.UnixMilli(timestamp)
f.Width = int(width.Int32)
f.Height = int(height.Int32)
f.MXC, err = id.ParseContentURI(mxc)
if err != nil {
f.log.Errorfln("Failed to parse content URI %s: %v", mxc, err)
panic(err)
}
if decryptionInfo != nil {
err = json.Unmarshal(decryptionInfo, &f.DecryptionInfo)
if err != nil {
f.log.Errorfln("Failed to unmarshal decryption info of %v: %v", f.MXC, err)
panic(err)
}
}
return f
}
func positiveIntToNullInt32(val int) (ptr sql.NullInt32) {
if val > 0 {
ptr.Valid = true
ptr.Int32 = int32(val)
}
return
}
func (f *File) Insert(txn dbutil.Execable) {
if txn == nil {
txn = f.db
}
var err error
var decryptionInfo []byte
if f.DecryptionInfo != nil {
decryptionInfo, err = json.Marshal(f.DecryptionInfo)
if err != nil {
f.log.Warnfln("Failed to marshal decryption info of %v: %v", f.MXC, err)
panic(err)
}
}
_, err = txn.Exec(fileInsert,
f.URL, f.Encrypted, strPtr(f.ID), f.MXC.String(), f.Size,
positiveIntToNullInt32(f.Width), positiveIntToNullInt32(f.Height),
decryptionInfo, f.Timestamp.UnixMilli(),
)
if err != nil {
f.log.Warnfln("Failed to insert copied file %v: %v", f.MXC, err)
panic(err)
}
}
func (f *File) Delete() {
_, err := f.db.Exec("DELETE FROM discord_file WHERE url=$1 AND encrypted=$2", f.URL, f.Encrypted)
if err != nil {
f.log.Warnfln("Failed to delete copied file %v: %v", f.MXC, err)
panic(err)
}
}

View File

@@ -1,4 +1,4 @@
-- v0 -> v10: Latest revision
-- v0 -> v11: Latest revision
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
@@ -150,3 +150,21 @@ CREATE TABLE role (
PRIMARY KEY (dc_guild_id, dcid),
CONSTRAINT role_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild (dcid) ON DELETE CASCADE
);
CREATE TABLE discord_file (
url TEXT,
encrypted BOOLEAN,
id TEXT,
mxc TEXT NOT NULL,
size BIGINT NOT NULL,
width INTEGER,
height INTEGER,
decryption_info jsonb,
timestamp BIGINT NOT NULL,
PRIMARY KEY (url, encrypted)
);

View File

@@ -0,0 +1,18 @@
-- v11: Cache files copied from Discord to Matrix
CREATE TABLE discord_file (
url TEXT,
encrypted BOOLEAN,
id TEXT,
mxc TEXT NOT NULL,
size BIGINT NOT NULL,
width INTEGER,
height INTEGER,
decryption_info jsonb,
timestamp BIGINT NOT NULL,
PRIMARY KEY (url, encrypted)
);

View File

@@ -523,31 +523,46 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author
msg.MassInsert(parts)
}
func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) {
func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) id.EventID {
content := &event.MessageEventContent{
Body: fmt.Sprintf("Failed to bridge media: %v", bridgeErr),
MsgType: event.MsgNotice,
}
_, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0)
resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0)
if err != nil {
portal.log.Warnfln("Failed to send media error message to matrix: %v", err)
return ""
}
return resp.EventID
}
const DiscordStickerSize = 160
func (portal *Portal) handleDiscordFile(typeName string, intent *appservice.IntentAPI, id, url string, content *event.MessageEventContent, ts time.Time, threadRelation *event.RelatesTo) *database.MessagePart {
data, err := portal.downloadDiscordAttachment(url)
dbFile, err := portal.bridge.copyAttachmentToMatrix(intent, url, portal.Encrypted, id, content.Info.MimeType)
if err != nil {
portal.sendMediaFailedMessage(intent, err)
errorEventID := portal.sendMediaFailedMessage(intent, err)
if errorEventID != "" {
return &database.MessagePart{
AttachmentID: id,
MXID: errorEventID,
}
}
return nil
}
err = portal.uploadMatrixAttachment(intent, data, content)
if err != nil {
portal.sendMediaFailedMessage(intent, err)
return nil
content.Info.Size = dbFile.Size
if content.Info.Width == 0 && content.Info.Height == 0 {
content.Info.Width = dbFile.Width
content.Info.Height = dbFile.Height
}
if dbFile.DecryptionInfo != nil {
content.File = &event.EncryptedFileInfo{
EncryptedFile: *dbFile.DecryptionInfo,
URL: dbFile.MXC.CUString(),
}
} else {
content.URL = dbFile.MXC.CUString()
}
evtType := event.EventMessage