From 4324b60a2c601304e3f28dcf38efdf717a9d3a2a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 May 2023 22:10:00 +0300 Subject: [PATCH] Store edit timestamp in database to deduplicate edits. Fixes #86 --- database/message.go | 80 ++++++++++++------- database/upgrades/00-latest-revision.sql | 25 +++--- .../upgrades/19-message-edit-ts.postgres.sql | 15 ++++ .../upgrades/19-message-edit-ts.sqlite.sql | 48 +++++++++++ portal.go | 25 ++++-- 5 files changed, 145 insertions(+), 48 deletions(-) create mode 100644 database/upgrades/19-message-edit-ts.postgres.sql create mode 100644 database/upgrades/19-message-edit-ts.sqlite.sql diff --git a/database/message.go b/database/message.go index 5fba7e6..35d6552 100644 --- a/database/message.go +++ b/database/message.go @@ -19,7 +19,7 @@ type MessageQuery struct { } const ( - messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" + messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message" ) func (mq *MessageQuery) New() *Message { @@ -46,17 +46,17 @@ func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message { } func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC" return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID)) } func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) } func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id DESC LIMIT 1" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) } @@ -66,12 +66,12 @@ func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time } func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) } func (mq *MessageQuery) GetLast(key PortalKey) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_edit_index=0 ORDER BY timestamp DESC LIMIT 1" + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1" return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver)) } @@ -99,7 +99,7 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { if len(msgs) == 0 { return } - valueStringFormat := "($%d, $%d, $%d, $1, $2, $%d, $%d, $%d, $%d)" + valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)" if mq.db.Dialect == dbutil.SQLite { valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") } @@ -111,9 +111,9 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { baseIndex := 2 + i*7 params[baseIndex] = msg.DiscordID params[baseIndex+1] = msg.AttachmentID - params[baseIndex+2] = msg.EditIndex - params[baseIndex+3] = msg.SenderID - params[baseIndex+4] = msg.Timestamp.UnixMilli() + params[baseIndex+2] = msg.SenderID + params[baseIndex+3] = msg.Timestamp.UnixMilli() + params[baseIndex+4] = msg.editTimestampVal() params[baseIndex+5] = msg.ThreadID params[baseIndex+6] = msg.MXID placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7) @@ -129,13 +129,13 @@ type Message struct { db *Database log log.Logger - DiscordID string - AttachmentID string - EditIndex int - Channel PortalKey - SenderID string - Timestamp time.Time - ThreadID string + DiscordID string + AttachmentID string + Channel PortalKey + SenderID string + Timestamp time.Time + EditTimestamp time.Time + ThreadID string MXID id.EventID } @@ -149,9 +149,9 @@ func (m *Message) DiscordProtoChannelID() string { } func (m *Message) Scan(row dbutil.Scannable) *Message { - var ts int64 + var ts, editTS int64 - err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID) + err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) @@ -162,7 +162,10 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { } if ts != 0 { - m.Timestamp = time.UnixMilli(ts) + m.Timestamp = time.UnixMilli(ts).UTC() + } + if editTS != 0 { + m.EditTimestamp = time.Unix(0, editTS).UTC() } return m @@ -170,7 +173,7 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { const messageInsertQuery = ` INSERT INTO message ( - dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid + dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ` @@ -182,6 +185,13 @@ type MessagePart struct { MXID id.EventID } +func (m *Message) editTimestampVal() int64 { + if m.EditTimestamp.IsZero() { + return 0 + } + return m.EditTimestamp.UnixNano() +} + func (m *Message) MassInsertParts(msgs []MessagePart) { if len(msgs) == 0 { return @@ -193,11 +203,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) { params := make([]interface{}, 7+len(msgs)*2) placeholders := make([]string, len(msgs)) params[0] = m.DiscordID - params[1] = m.EditIndex - params[2] = m.Channel.ChannelID - params[3] = m.Channel.Receiver - params[4] = m.SenderID - params[5] = m.Timestamp.UnixMilli() + params[1] = m.Channel.ChannelID + params[2] = m.Channel.Receiver + params[3] = m.SenderID + params[4] = m.Timestamp.UnixMilli() + params[5] = m.editTimestampVal() params[6] = m.ThreadID for i, msg := range msgs { params[7+i*2] = msg.AttachmentID @@ -213,8 +223,8 @@ func (m *Message) MassInsertParts(msgs []MessagePart) { func (m *Message) Insert() { _, err := m.db.Exec(messageInsertQuery, - m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, - m.Timestamp.UnixMilli(), m.ThreadID, m.MXID) + m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, + m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID) if err != nil { m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) @@ -222,6 +232,20 @@ func (m *Message) Insert() { } } +const editUpdateQuery = ` + UPDATE message + SET dc_edit_timestamp=$1 + WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1 +` + +func (m *Message) UpdateEditTimestamp(ts time.Time) { + _, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver) + if err != nil { + m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err) + panic(err) + } +} + func (m *Message) Delete() { query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4" _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID) diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 08e7b88..777e220 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v18 (compatible with v15+): Latest revision +-- v0 -> v19: Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -104,18 +104,18 @@ CREATE TABLE user_portal ( ); CREATE TABLE message ( - dcid TEXT, - dc_attachment_id TEXT, - dc_edit_index INTEGER, - dc_chan_id TEXT, - dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, - timestamp BIGINT NOT NULL, - dc_thread_id TEXT NOT NULL, + dcid TEXT, + dc_attachment_id TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_sender TEXT NOT NULL, + timestamp BIGINT NOT NULL, + dc_edit_timestamp BIGINT NOT NULL, + dc_thread_id TEXT NOT NULL, mxid TEXT NOT NULL UNIQUE, - PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver), + PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver), CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE ); @@ -127,13 +127,12 @@ CREATE TABLE reaction ( dc_emoji_name TEXT, dc_thread_id TEXT NOT NULL, - dc_first_attachment_id TEXT NOT NULL, - _dc_first_edit_index INTEGER NOT NULL DEFAULT 0, + dc_first_attachment_id TEXT NOT NULL, mxid TEXT NOT NULL UNIQUE, PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), - CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE + CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE ); CREATE TABLE role ( diff --git a/database/upgrades/19-message-edit-ts.postgres.sql b/database/upgrades/19-message-edit-ts.postgres.sql new file mode 100644 index 0000000..231afa1 --- /dev/null +++ b/database/upgrades/19-message-edit-ts.postgres.sql @@ -0,0 +1,15 @@ +-- v19: Replace dc_edit_index with dc_edit_timestamp +-- transaction: off +BEGIN; + +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE message DROP CONSTRAINT message_pkey; +ALTER TABLE message DROP COLUMN dc_edit_index; +ALTER TABLE reaction DROP COLUMN _dc_first_edit_index; +ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver); +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE; + +ALTER TABLE message ADD COLUMN dc_edit_timestamp BIGINT NOT NULL DEFAULT 0; +ALTER TABLE message ALTER COLUMN dc_edit_timestamp DROP DEFAULT; + +COMMIT; diff --git a/database/upgrades/19-message-edit-ts.sqlite.sql b/database/upgrades/19-message-edit-ts.sqlite.sql new file mode 100644 index 0000000..98c3a23 --- /dev/null +++ b/database/upgrades/19-message-edit-ts.sqlite.sql @@ -0,0 +1,48 @@ +-- v19: Replace dc_edit_index with dc_edit_timestamp +-- transaction: off +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + dcid TEXT, + dc_attachment_id TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_sender TEXT NOT NULL, + timestamp BIGINT NOT NULL, + dc_edit_timestamp BIGINT NOT NULL, + dc_thread_id TEXT NOT NULL, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE +); +INSERT INTO message_new (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid) + SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, 0, dc_thread_id, mxid FROM message; +DROP TABLE message; +ALTER TABLE message_new RENAME TO message; + +CREATE TABLE reaction_new ( + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_msg_id TEXT, + dc_sender TEXT, + dc_emoji_name TEXT, + dc_thread_id TEXT NOT NULL, + + dc_first_attachment_id TEXT NOT NULL, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), + CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE +); +INSERT INTO reaction_new (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid) + SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid FROM reaction; +DROP TABLE reaction; +ALTER TABLE reaction_new RENAME TO reaction; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/portal.go b/portal.go index 15c7fea..01eb4f5 100644 --- a/portal.go +++ b/portal.go @@ -583,11 +583,10 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache) } -func (portal *Portal) markMessageHandled(discordID string, editIndex int, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) { +func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) { msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key msg.DiscordID = discordID - msg.EditIndex = editIndex msg.SenderID = authorID msg.Timestamp = timestamp msg.ThreadID = threadID @@ -674,7 +673,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } else if len(dbParts) == 0 { log.Warn().Msg("All parts of message failed to send to Matrix") } else { - portal.markMessageHandled(msg.ID, 0, msg.Author.ID, ts, discordThreadID, dbParts) + portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts) } } @@ -778,6 +777,13 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess log.Warn().Msg("Dropping update of unknown message") return } + if msg.EditedTimestamp != nil && !msg.EditedTimestamp.After(existing[0].EditTimestamp) { + log.Debug(). + Time("received_edit_ts", *msg.EditedTimestamp). + Time("db_edit_ts", existing[0].EditTimestamp). + Msg("Dropping update of message with older or equal edit timestamp") + return + } if msg.Flags == discordgo.MessageFlagsHasThread { thread := portal.bridge.GetThreadByID(msg.ID, existing[0]) @@ -885,8 +891,9 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess portal.sendDeliveryReceipt(resp.EventID) - //ts, _ := msg.Timestamp.Parse() - //portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) + if msg.EditedTimestamp != nil { + existing[0].UpdateEditTimestamp(*msg.EditedTimestamp) + } } func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { @@ -1386,16 +1393,20 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { if edits != nil { discordContent, allowedMentions := portal.parseMatrixHTML(content.NewContent) var err error + var msg *discordgo.Message if !isWebhookSend { // TODO save edit in message table - _, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent) + msg, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent) } else { - _, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{ + msg, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{ Content: &discordContent, AllowedMentions: allowedMentions, }) } go portal.sendMessageMetrics(evt, err, "Failed to edit") + if msg.EditedTimestamp != nil { + edits.UpdateEditTimestamp(*msg.EditedTimestamp) + } } else { go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEditTarget, editMXID), "Ignoring") }