Store edit timestamp in database to deduplicate edits. Fixes #86

This commit is contained in:
Tulir Asokan
2023-05-06 22:10:00 +03:00
parent c26de9c7df
commit 4324b60a2c
5 changed files with 145 additions and 48 deletions

View File

@@ -19,7 +19,7 @@ type MessageQuery struct {
} }
const ( const (
messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message"
) )
func (mq *MessageQuery) New() *Message { func (mq *MessageQuery) New() *Message {
@@ -46,17 +46,17 @@ func (mq *MessageQuery) scanAll(rows dbutil.Rows, err error) []*Message {
} }
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message { func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC"
return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID)) return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
} }
func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message { func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id ASC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
} }
func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message { func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id DESC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 ORDER BY dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
} }
@@ -66,12 +66,12 @@ func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time
} }
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
} }
func (mq *MessageQuery) GetLast(key PortalKey) *Message { func (mq *MessageQuery) GetLast(key PortalKey) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_edit_index=0 ORDER BY timestamp DESC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 ORDER BY timestamp DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver))
} }
@@ -99,7 +99,7 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
if len(msgs) == 0 { if len(msgs) == 0 {
return return
} }
valueStringFormat := "($%d, $%d, $%d, $1, $2, $%d, $%d, $%d, $%d)" valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite { if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
} }
@@ -111,9 +111,9 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
baseIndex := 2 + i*7 baseIndex := 2 + i*7
params[baseIndex] = msg.DiscordID params[baseIndex] = msg.DiscordID
params[baseIndex+1] = msg.AttachmentID params[baseIndex+1] = msg.AttachmentID
params[baseIndex+2] = msg.EditIndex params[baseIndex+2] = msg.SenderID
params[baseIndex+3] = msg.SenderID params[baseIndex+3] = msg.Timestamp.UnixMilli()
params[baseIndex+4] = msg.Timestamp.UnixMilli() params[baseIndex+4] = msg.editTimestampVal()
params[baseIndex+5] = msg.ThreadID params[baseIndex+5] = msg.ThreadID
params[baseIndex+6] = msg.MXID params[baseIndex+6] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7) placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
@@ -129,13 +129,13 @@ type Message struct {
db *Database db *Database
log log.Logger log log.Logger
DiscordID string DiscordID string
AttachmentID string AttachmentID string
EditIndex int Channel PortalKey
Channel PortalKey SenderID string
SenderID string Timestamp time.Time
Timestamp time.Time EditTimestamp time.Time
ThreadID string ThreadID string
MXID id.EventID MXID id.EventID
} }
@@ -149,9 +149,9 @@ func (m *Message) DiscordProtoChannelID() string {
} }
func (m *Message) Scan(row dbutil.Scannable) *Message { func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts, editTS int64
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID) err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err) m.log.Errorln("Database scan failed:", err)
@@ -162,7 +162,10 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
} }
if ts != 0 { if ts != 0 {
m.Timestamp = time.UnixMilli(ts) m.Timestamp = time.UnixMilli(ts).UTC()
}
if editTS != 0 {
m.EditTimestamp = time.Unix(0, editTS).UTC()
} }
return m return m
@@ -170,7 +173,7 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
const messageInsertQuery = ` const messageInsertQuery = `
INSERT INTO message ( INSERT INTO message (
dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid
) )
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
` `
@@ -182,6 +185,13 @@ type MessagePart struct {
MXID id.EventID MXID id.EventID
} }
func (m *Message) editTimestampVal() int64 {
if m.EditTimestamp.IsZero() {
return 0
}
return m.EditTimestamp.UnixNano()
}
func (m *Message) MassInsertParts(msgs []MessagePart) { func (m *Message) MassInsertParts(msgs []MessagePart) {
if len(msgs) == 0 { if len(msgs) == 0 {
return return
@@ -193,11 +203,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
params := make([]interface{}, 7+len(msgs)*2) params := make([]interface{}, 7+len(msgs)*2)
placeholders := make([]string, len(msgs)) placeholders := make([]string, len(msgs))
params[0] = m.DiscordID params[0] = m.DiscordID
params[1] = m.EditIndex params[1] = m.Channel.ChannelID
params[2] = m.Channel.ChannelID params[2] = m.Channel.Receiver
params[3] = m.Channel.Receiver params[3] = m.SenderID
params[4] = m.SenderID params[4] = m.Timestamp.UnixMilli()
params[5] = m.Timestamp.UnixMilli() params[5] = m.editTimestampVal()
params[6] = m.ThreadID params[6] = m.ThreadID
for i, msg := range msgs { for i, msg := range msgs {
params[7+i*2] = msg.AttachmentID params[7+i*2] = msg.AttachmentID
@@ -213,8 +223,8 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
func (m *Message) Insert() { func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery, _, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), m.ThreadID, m.MXID) m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID)
if err != nil { if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
@@ -222,6 +232,20 @@ func (m *Message) Insert() {
} }
} }
const editUpdateQuery = `
UPDATE message
SET dc_edit_timestamp=$1
WHERE dcid=$2 AND dc_attachment_id=$3 AND dc_chan_id=$4 AND dc_chan_receiver=$5 AND dc_edit_timestamp<$1
`
func (m *Message) UpdateEditTimestamp(ts time.Time) {
_, err := m.db.Exec(editUpdateQuery, ts.UnixNano(), m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver)
if err != nil {
m.log.Warnfln("Failed to update edit timestamp of %s@%s: %v", m.DiscordID, m.Channel, err)
panic(err)
}
}
func (m *Message) Delete() { func (m *Message) Delete() {
query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4" query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3 AND dc_attachment_id=$4"
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID) _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.AttachmentID)

View File

@@ -1,4 +1,4 @@
-- v0 -> v18 (compatible with v15+): Latest revision -- v0 -> v19: Latest revision
CREATE TABLE guild ( CREATE TABLE guild (
dcid TEXT PRIMARY KEY, dcid TEXT PRIMARY KEY,
@@ -104,18 +104,18 @@ CREATE TABLE user_portal (
); );
CREATE TABLE message ( CREATE TABLE message (
dcid TEXT, dcid TEXT,
dc_attachment_id TEXT, dc_attachment_id TEXT,
dc_edit_index INTEGER, dc_chan_id TEXT,
dc_chan_id TEXT, dc_chan_receiver TEXT,
dc_chan_receiver TEXT, dc_sender TEXT NOT NULL,
dc_sender TEXT NOT NULL, timestamp BIGINT NOT NULL,
timestamp BIGINT NOT NULL, dc_edit_timestamp BIGINT NOT NULL,
dc_thread_id TEXT NOT NULL, dc_thread_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver), PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
); );
@@ -127,13 +127,12 @@ CREATE TABLE reaction (
dc_emoji_name TEXT, dc_emoji_name TEXT,
dc_thread_id TEXT NOT NULL, dc_thread_id TEXT NOT NULL,
dc_first_attachment_id TEXT NOT NULL, dc_first_attachment_id TEXT NOT NULL,
_dc_first_edit_index INTEGER NOT NULL DEFAULT 0,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
); );
CREATE TABLE role ( CREATE TABLE role (

View File

@@ -0,0 +1,15 @@
-- v19: Replace dc_edit_index with dc_edit_timestamp
-- transaction: off
BEGIN;
ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey;
ALTER TABLE message DROP CONSTRAINT message_pkey;
ALTER TABLE message DROP COLUMN dc_edit_index;
ALTER TABLE reaction DROP COLUMN _dc_first_edit_index;
ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver);
ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE;
ALTER TABLE message ADD COLUMN dc_edit_timestamp BIGINT NOT NULL DEFAULT 0;
ALTER TABLE message ALTER COLUMN dc_edit_timestamp DROP DEFAULT;
COMMIT;

View File

@@ -0,0 +1,48 @@
-- v19: Replace dc_edit_index with dc_edit_timestamp
-- transaction: off
PRAGMA foreign_keys = OFF;
BEGIN;
CREATE TABLE message_new (
dcid TEXT,
dc_attachment_id TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
dc_edit_timestamp BIGINT NOT NULL,
dc_thread_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
);
INSERT INTO message_new (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid)
SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, 0, dc_thread_id, mxid FROM message;
DROP TABLE message;
ALTER TABLE message_new RENAME TO message;
CREATE TABLE reaction_new (
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
dc_thread_id TEXT NOT NULL,
dc_first_attachment_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
);
INSERT INTO reaction_new (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid)
SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, dc_first_attachment_id, mxid FROM reaction;
DROP TABLE reaction;
ALTER TABLE reaction_new RENAME TO reaction;
PRAGMA foreign_key_check;
COMMIT;
PRAGMA foreign_keys = ON;

View File

@@ -583,11 +583,10 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool {
return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache) return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache)
} }
func (portal *Portal) markMessageHandled(discordID string, editIndex int, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) { func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Channel = portal.Key msg.Channel = portal.Key
msg.DiscordID = discordID msg.DiscordID = discordID
msg.EditIndex = editIndex
msg.SenderID = authorID msg.SenderID = authorID
msg.Timestamp = timestamp msg.Timestamp = timestamp
msg.ThreadID = threadID msg.ThreadID = threadID
@@ -674,7 +673,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
} else if len(dbParts) == 0 { } else if len(dbParts) == 0 {
log.Warn().Msg("All parts of message failed to send to Matrix") log.Warn().Msg("All parts of message failed to send to Matrix")
} else { } else {
portal.markMessageHandled(msg.ID, 0, msg.Author.ID, ts, discordThreadID, dbParts) portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts)
} }
} }
@@ -778,6 +777,13 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
log.Warn().Msg("Dropping update of unknown message") log.Warn().Msg("Dropping update of unknown message")
return return
} }
if msg.EditedTimestamp != nil && !msg.EditedTimestamp.After(existing[0].EditTimestamp) {
log.Debug().
Time("received_edit_ts", *msg.EditedTimestamp).
Time("db_edit_ts", existing[0].EditTimestamp).
Msg("Dropping update of message with older or equal edit timestamp")
return
}
if msg.Flags == discordgo.MessageFlagsHasThread { if msg.Flags == discordgo.MessageFlagsHasThread {
thread := portal.bridge.GetThreadByID(msg.ID, existing[0]) thread := portal.bridge.GetThreadByID(msg.ID, existing[0])
@@ -885,8 +891,9 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
portal.sendDeliveryReceipt(resp.EventID) portal.sendDeliveryReceipt(resp.EventID)
//ts, _ := msg.Timestamp.Parse() if msg.EditedTimestamp != nil {
//portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) existing[0].UpdateEditTimestamp(*msg.EditedTimestamp)
}
} }
func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) {
@@ -1386,16 +1393,20 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
if edits != nil { if edits != nil {
discordContent, allowedMentions := portal.parseMatrixHTML(content.NewContent) discordContent, allowedMentions := portal.parseMatrixHTML(content.NewContent)
var err error var err error
var msg *discordgo.Message
if !isWebhookSend { if !isWebhookSend {
// TODO save edit in message table // TODO save edit in message table
_, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent) msg, err = sess.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, discordContent)
} else { } else {
_, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{ msg, err = relayClient.WebhookMessageEdit(portal.RelayWebhookID, portal.RelayWebhookSecret, edits.DiscordID, &discordgo.WebhookEdit{
Content: &discordContent, Content: &discordContent,
AllowedMentions: allowedMentions, AllowedMentions: allowedMentions,
}) })
} }
go portal.sendMessageMetrics(evt, err, "Failed to edit") go portal.sendMessageMetrics(evt, err, "Failed to edit")
if msg.EditedTimestamp != nil {
edits.UpdateEditTimestamp(*msg.EditedTimestamp)
}
} else { } else {
go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEditTarget, editMXID), "Ignoring") go portal.sendMessageMetrics(evt, fmt.Errorf("%w %s", errUnknownEditTarget, editMXID), "Ignoring")
} }