diff --git a/database/attachment.go b/database/attachment.go deleted file mode 100644 index 2a2d1e2..0000000 --- a/database/attachment.go +++ /dev/null @@ -1,126 +0,0 @@ -package database - -import ( - "database/sql" - "errors" - - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/util/dbutil" -) - -type AttachmentQuery struct { - db *Database - log log.Logger -} - -const ( - attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid FROM attachment" -) - -func (aq *AttachmentQuery) New() *Attachment { - return &Attachment{ - db: aq.db, - log: aq.log, - } -} - -func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment { - query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3" - - return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) -} - -func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment { - rows, err := aq.db.Query(query, args...) - if err != nil { - aq.log.Debugfln("getAll failed: %v", err) - - return nil - } - - if rows == nil { - return nil - } - - var attachments []*Attachment - for rows.Next() { - attachments = append(attachments, aq.New().Scan(rows)) - } - - return attachments -} - -func (aq *AttachmentQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Attachment { - query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dcid=$4" - - return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID) -} - -func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment { - query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3" - - return aq.get(query, key.ChannelID, key.Receiver, matrixEventID) -} - -func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment { - row := aq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return aq.New().Scan(row) -} - -type Attachment struct { - db *Database - log log.Logger - - Channel PortalKey - ThreadID string - MessageID string - ID string - MXID id.EventID -} - -func (a *Attachment) DiscordProtoChannelID() string { - if a.ThreadID != "" { - return a.ThreadID - } else { - return a.Channel.ChannelID - } -} - -func (a *Attachment) Scan(row dbutil.Scannable) *Attachment { - err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - a.log.Errorln("Database scan failed:", err) - panic(err) - } - return nil - } - return a -} - -func (a *Attachment) Insert() { - query := ` - INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) - VALUES ($1, $2, $3, $4, $5, $6) - ` - _, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID) - if err != nil { - a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err) - panic(err) - } -} - -func (a *Attachment) Delete() { - query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3` - _, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID) - if err != nil { - a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err) - panic(err) - } -} diff --git a/database/database.go b/database/database.go index 5a74633..cfd874d 100644 --- a/database/database.go +++ b/database/database.go @@ -14,15 +14,14 @@ import ( type Database struct { *dbutil.Database - User *UserQuery - Portal *PortalQuery - Puppet *PuppetQuery - Message *MessageQuery - Thread *ThreadQuery - Reaction *ReactionQuery - Attachment *AttachmentQuery - Emoji *EmojiQuery - Guild *GuildQuery + User *UserQuery + Portal *PortalQuery + Puppet *PuppetQuery + Message *MessageQuery + Thread *ThreadQuery + Reaction *ReactionQuery + Emoji *EmojiQuery + Guild *GuildQuery } func New(baseDB *dbutil.Database) *Database { @@ -52,10 +51,6 @@ func New(baseDB *dbutil.Database) *Database { db: db, log: db.Log.Sub("Reaction"), } - db.Attachment = &AttachmentQuery{ - db: db, - log: db.Log.Sub("Attachment"), - } db.Emoji = &EmojiQuery{ db: db, log: db.Log.Sub("Emoji"), diff --git a/database/message.go b/database/message.go index 04fe525..1dd692e 100644 --- a/database/message.go +++ b/database/message.go @@ -3,6 +3,8 @@ package database import ( "database/sql" "errors" + "fmt" + "strings" "time" log "maunium.net/go/maulogger/v2" @@ -17,7 +19,7 @@ type MessageQuery struct { } const ( - messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" + messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" ) func (mq *MessageQuery) New() *Message { @@ -27,11 +29,12 @@ func (mq *MessageQuery) New() *Message { } } -func (mq *MessageQuery) GetAll(key PortalKey) []*Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2" - - rows, err := mq.db.Query(query, key.ChannelID, key.Receiver) - if err != nil || rows == nil { +func (mq *MessageQuery) scanAll(rows *sql.Rows, err error) []*Message { + if err != nil { + mq.log.Warnfln("Failed to query many messages: %v", err) + panic(err) + return nil + } else if rows == nil { return nil } @@ -43,13 +46,18 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message { return messages } -func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" +func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC" + return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID)) +} + +func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) } func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) } @@ -68,11 +76,13 @@ type Message struct { db *Database log log.Logger - DiscordID string - Channel PortalKey - SenderID string - Timestamp time.Time - ThreadID string + DiscordID string + AttachmentID string + EditIndex int + Channel PortalKey + SenderID string + Timestamp time.Time + ThreadID string MXID id.EventID } @@ -89,7 +99,7 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { var ts int64 var threadID sql.NullString - err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) + err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) @@ -107,14 +117,49 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { return m } -func (m *Message) Insert() { - query := ` - INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ` +const messageInsertQuery = ` + INSERT INTO message ( + dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) +` - _, err := m.db.Exec(query, - m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, +var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1) + +type MessagePart struct { + AttachmentID string + MXID id.EventID +} + +func (m *Message) MassInsert(msgs []MessagePart) { + valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)" + if m.db.Dialect == dbutil.SQLite { + valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") + } + params := make([]interface{}, 7+len(msgs)*2) + placeholders := make([]string, len(msgs)) + params[0] = m.DiscordID + params[1] = m.EditIndex + params[2] = m.Channel.ChannelID + params[3] = m.Channel.Receiver + params[4] = m.SenderID + params[5] = m.Timestamp.UnixMilli() + params[6] = m.ThreadID + for i, msg := range msgs { + params[7+i*2] = msg.AttachmentID + params[7+i*2+1] = msg.MXID + placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2) + } + _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) + if err != nil { + m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err) + panic(err) + } +} + +func (m *Message) Insert() { + _, err := m.db.Exec(messageInsertQuery, + m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID) if err != nil { diff --git a/database/reaction.go b/database/reaction.go index 71584db..99dbca9 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -78,6 +78,8 @@ type Reaction struct { ThreadID string MXID id.EventID + + FirstAttachmentID string } func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { @@ -105,10 +107,10 @@ func (r *Reaction) DiscordProtoChannelID() string { func (r *Reaction) Insert() { query := ` - INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) - VALUES($1, $2, $3, $4, $5, $6, $7) + INSERT INTO reaction (dc_msg_id, dc_first_attachment_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) + VALUES($1, $2, $3, $4, $5, $6, $7, $8) ` - _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID) + _, err := r.db.Exec(query, r.MessageID, r.FirstAttachmentID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID) if err != nil { r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err) panic(err) diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index f272adc..3dc764a 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v3: Latest revision +-- v0 -> v4: Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -90,6 +90,8 @@ CREATE TABLE user_portal ( CREATE TABLE message ( dcid TEXT, + dc_attachment_id TEXT, + dc_edit_index INTEGER, dc_chan_id TEXT, dc_chan_receiver TEXT, dc_sender TEXT NOT NULL, @@ -98,7 +100,7 @@ CREATE TABLE message ( mxid TEXT NOT NULL UNIQUE, - PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver), + PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver), CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE ); @@ -110,10 +112,13 @@ CREATE TABLE reaction ( dc_emoji_name TEXT, dc_thread_id TEXT, + dc_first_attachment_id TEXT NOT NULL, + _dc_first_edit_index INTEGER NOT NULL DEFAULT 0, + mxid TEXT NOT NULL UNIQUE, PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE + CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE ); CREATE TABLE attachment ( diff --git a/database/upgrades/04-attachment-fix.sql b/database/upgrades/04-attachment-fix.sql new file mode 100644 index 0000000..c476afd --- /dev/null +++ b/database/upgrades/04-attachment-fix.sql @@ -0,0 +1,20 @@ +-- v4: Fix storing attachments +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE attachment DROP CONSTRAINT attachment_message_fkey; +ALTER TABLE message DROP CONSTRAINT message_pkey; +ALTER TABLE message ADD COLUMN dc_attachment_id TEXT NOT NULL DEFAULT ''; +ALTER TABLE message ADD COLUMN dc_edit_index INTEGER NOT NULL DEFAULT 0; +ALTER TABLE message ALTER COLUMN dc_attachment_id DROP DEFAULT; +ALTER TABLE message ALTER COLUMN dc_edit_index DROP DEFAULT; +ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver); +INSERT INTO message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) + SELECT message.dcid, attachment.dcid, 0, attachment.dc_chan_id, attachment.dc_chan_receiver, message.dc_sender, message.timestamp, attachment.dc_thread_id, attachment.mxid + FROM attachment LEFT JOIN message ON attachment.dc_msg_id = message.dcid; +DROP TABLE attachment; + +ALTER TABLE reaction ADD COLUMN dc_first_attachment_id TEXT NOT NULL DEFAULT ''; +ALTER TABLE reaction ALTER COLUMN dc_first_attachment_id DROP DEFAULT; +ALTER TABLE reaction ADD COLUMN _dc_first_edit_index INTEGER DEFAULT 0; +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey + FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) + REFERENCES message(dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver); diff --git a/portal.go b/portal.go index 2b0706c..c970039 100644 --- a/portal.go +++ b/portal.go @@ -459,16 +459,15 @@ func (portal *Portal) ensureUserInvited(user *User) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) } -func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time, threadID string) *database.Message { +func (portal *Portal) markMessageHandled(discordID string, editIndex int, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) { msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key msg.DiscordID = discordID - msg.MXID = mxid + msg.EditIndex = editIndex msg.SenderID = authorID msg.Timestamp = timestamp msg.ThreadID = threadID - msg.Insert() - return msg + msg.MassInsert(parts) } func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) { @@ -483,7 +482,7 @@ func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridg } } -func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment, ts time.Time, threadRelation *event.RelatesTo, threadID string) { +func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment, ts time.Time, threadRelation *event.RelatesTo, threadID string) *database.MessagePart { // var captionContent *event.MessageEventContent // if attachment.Description != "" { @@ -521,34 +520,27 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI data, err := portal.downloadDiscordAttachment(attachment.URL) if err != nil { portal.sendMediaFailedMessage(intent, err) - - return + return nil } err = portal.uploadMatrixAttachment(intent, data, content) if err != nil { portal.sendMediaFailedMessage(intent, err) - - return + return nil } resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, ts.UnixMilli()) if err != nil { portal.log.Warnfln("failed to send media message to matrix: %v", err) } - - dbAttachment := portal.bridge.DB.Attachment.New() - dbAttachment.Channel = portal.Key - dbAttachment.MessageID = msgID - dbAttachment.ID = attachment.ID - dbAttachment.MXID = resp.EventID - dbAttachment.ThreadID = threadID - dbAttachment.Insert() // Update the fallback reply event for the next attachment if threadRelation != nil { threadRelation.InReplyTo.EventID = resp.EventID } - go portal.sendDeliveryReceipt(resp.EventID) + return &database.MessagePart{ + AttachmentID: attachment.ID, + MXID: resp.EventID, + } } func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { @@ -604,6 +596,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess threadRelation = (&event.RelatesTo{}).SetThread(thread.RootMXID, lastEventID) } + var parts []database.MessagePart ts, _ := discordgo.SnowflakeTimestamp(msg.ID) if msg.Content != "" { content := renderDiscordMarkdown(msg.Content) @@ -612,11 +605,11 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess if msg.MessageReference != nil { //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID) - if replyTo != nil { + if len(replyTo) > 0 { if content.RelatesTo == nil { content.RelatesTo = &event.RelatesTo{} } - content.RelatesTo.SetReplyTo(replyTo.MXID) + content.RelatesTo.SetReplyTo(replyTo[0].MXID) } } @@ -626,7 +619,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess return } - portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts, threadID) + parts = append(parts, database.MessagePart{MXID: resp.EventID}) // Update the fallback reply event for attachments if threadRelation != nil { threadRelation.InReplyTo.EventID = resp.EventID @@ -635,8 +628,12 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } for _, attachment := range msg.Attachments { - portal.handleDiscordAttachment(intent, msg.ID, attachment, ts, threadRelation, threadID) + part := portal.handleDiscordAttachment(intent, msg.ID, attachment, ts, threadRelation, threadID) + if part != nil { + parts = append(parts, *part) + } } + portal.markMessageHandled(msg.ID, 0, msg.Author.ID, ts, threadID, parts) } func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { @@ -653,7 +650,7 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess } if msg.Flags == discordgo.MessageFlagsHasThread { - portal.bridge.GetThreadByID(msg.ID, existing) + portal.bridge.GetThreadByID(msg.ID, existing[0]) portal.log.Debugfln("Marked %s as a thread root", msg.ID) // TODO make autojoining configurable //err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu) @@ -682,46 +679,12 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess intent := portal.bridge.GetPuppetByID(msg.Author.ID).IntentFor(portal) - if existing == nil { - // Due to the differences in Discord and Matrix attachment handling, - // existing will return nil if the original message was empty as we - // don't store/save those messages so we can determine when we're - // working against an attachment and do the attachment lookup instead. - - // Find all the existing attachments and drop them in a map so we can - // figure out which, if any have been deleted and clean them up on the - // matrix side. - attachmentMap := map[string]*database.Attachment{} - attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) - - for _, attachment := range attachments { - attachmentMap[attachment.ID] = attachment - } - - // Now run through the list of attachments on this message and remove - // them from the map. - for _, attachment := range msg.Attachments { - if _, found := attachmentMap[attachment.ID]; found { - delete(attachmentMap, attachment.ID) - } - } - - // Finally run through any attachments still in the map and delete them - // on the matrix side and our database. - for _, attachment := range attachmentMap { - _, err := intent.RedactEvent(portal.MXID, attachment.MXID) - if err != nil { - portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err) - } - - attachment.Delete() - } - + if msg.Content == "" || existing[0].AttachmentID != "" { + portal.log.Debugfln("Dropping non-text edit to %s", msg.ID) return } - content := renderDiscordMarkdown(msg.Content) - content.SetEdit(existing.MXID) + content.SetEdit(existing[0].MXID) var editTS int64 if msg.EditedTimestamp != nil { @@ -742,33 +705,20 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess } func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { - // The discord delete message object is pretty empty and doesn't include - // the author so we have to use the DMUser from the portal that was added - // at creation time if we're a DM. We'll might have similar issues when we - // add guild message support, but we'll cross that bridge when we get - // there. - - // Find the message that we're working with. This could correctly return - // nil if the message was just one or more attachments. existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) intent := portal.MainIntent() - - if existing != nil { - attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) - for _, attachment := range attachments { - _, err := intent.RedactEvent(portal.MXID, attachment.MXID) - if err != nil { - portal.log.Warnfln("Failed to redact attachment %s: %v", attachment.MXID, err) - } - attachment.Delete() - } - - resp, err := intent.RedactEvent(portal.MXID, existing.MXID) + var lastResp id.EventID + for _, dbMsg := range existing { + resp, err := intent.RedactEvent(portal.MXID, dbMsg.MXID) if err != nil { - portal.log.Warnfln("Failed to redact message %s: %v", existing.MXID, err) + portal.log.Warnfln("Failed to redact message %s: %v", dbMsg.MXID, err) + } else if resp != nil && resp.EventID != "" { + lastResp = resp.EventID } - existing.Delete() - portal.sendDeliveryReceipt(resp.EventID) + dbMsg.Delete() + } + if lastResp != "" { + portal.sendDeliveryReceipt(lastResp) } } @@ -1017,6 +967,9 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { dbMsg := portal.bridge.DB.Message.New() dbMsg.Channel = portal.Key dbMsg.DiscordID = msg.ID + if len(msg.Attachments) > 0 { + dbMsg.AttachmentID = msg.Attachments[0].ID + } dbMsg.MXID = evt.ID dbMsg.SenderID = sender.DiscordID dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID) @@ -1161,34 +1114,15 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { return } - var discordID, threadID string - channelID := portal.Key.ChannelID - msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) - - // Due to the differences in attachments between Discord and Matrix, if a - // user reacts to a media message on discord our lookup above will fail - // because the relation of matrix media messages to attachments in handled - // in the attachments table instead of messages so we need to check that - // before continuing. - // - // This also leads to interesting problems when a Discord message comes in - // with multiple attachments. A user can react to each one individually on - // Matrix, which will cause us to send it twice. Discord tends to ignore - // this, but if the user removes one of them, discord removes it and now - // they're out of sync. Perhaps we should add a counter to the reactions - // table to keep them in sync and to avoid sending duplicates to Discord. if msg == nil { - attachment := portal.bridge.DB.Attachment.GetByMatrixID(portal.Key, reaction.RelatesTo.EventID) - if attachment == nil { - portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, errors.New("unknown reaction target"), true, 0) - return - } - discordID = attachment.MessageID - threadID = attachment.ThreadID - } else { - discordID = msg.DiscordID - threadID = msg.ThreadID + portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, errors.New("unknown reaction target"), true, 0) + } + + firstMsg := msg + if msg.AttachmentID != "" { + firstMsg = portal.bridge.DB.Message.GetFirstByDiscordID(portal.Key, msg.DiscordID) + // TODO should the emoji be rerouted to the first message if it's different? } // Figure out if this is a custom emoji or not. @@ -1207,22 +1141,28 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { emojiID = variationselector.Remove(emojiID) } - if threadID != "" { - channelID = threadID + existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, msg.DiscordID, sender.DiscordID, emojiID) + if existing != nil { + portal.log.Debugfln("Dropping duplicate Matrix reaction %s (already sent as %s)", evt.ID, existing.MXID) + portal.bridge.SendMessageSuccessCheckpoint(evt, bridge.MsgStepRemote, 0) + portal.sendDeliveryReceipt(evt.ID) + return } - err := sender.Session.MessageReactionAdd(channelID, discordID, emojiID) + + err := sender.Session.MessageReactionAdd(msg.DiscordProtoChannelID(), msg.DiscordID, emojiID) if err != nil { - portal.log.Debugf("Failed to send reaction to %s: %v", discordID, err) + portal.log.Debugf("Failed to send reaction to %s: %v", msg.DiscordID, err) portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, err, true, 0) return } dbReaction := portal.bridge.DB.Reaction.New() dbReaction.Channel = portal.Key - dbReaction.MessageID = discordID + dbReaction.MessageID = msg.DiscordID + dbReaction.FirstAttachmentID = firstMsg.AttachmentID dbReaction.Sender = sender.DiscordID dbReaction.EmojiName = emojiID - dbReaction.ThreadID = threadID + dbReaction.ThreadID = msg.ThreadID dbReaction.MXID = evt.ID dbReaction.Insert() portal.log.Debugfln("Handled Matrix reaction %s", evt.ID) @@ -1272,12 +1212,11 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess message := portal.bridge.DB.Message.GetByDiscordID(portal.Key, reaction.MessageID) if message == nil { portal.log.Debugfln("failed to add reaction to message %s: message not found", reaction.MessageID) - return } // Lookup an existing reaction - existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, reaction.UserID, discordID) + existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message[0].DiscordID, reaction.UserID, discordID) if !add { if existing == nil { portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID) @@ -1293,13 +1232,13 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess go portal.sendDeliveryReceipt(resp.EventID) return } else if existing != nil { - portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message.DiscordID) + portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message[0].DiscordID) return } content := event.Content{Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ - EventID: message.MXID, + EventID: message[0].MXID, Type: event.RelAnnotation, Key: matrixReaction, }, @@ -1320,7 +1259,8 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess if existing == nil { dbReaction := portal.bridge.DB.Reaction.New() dbReaction.Channel = portal.Key - dbReaction.MessageID = message.DiscordID + dbReaction.MessageID = message[0].DiscordID + dbReaction.FirstAttachmentID = message[0].AttachmentID dbReaction.Sender = reaction.UserID dbReaction.EmojiName = discordID dbReaction.MXID = resp.EventID