From 434f27c8b4a8aa4e411051a950bf39f39216e5c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 May 2023 13:18:23 +0300 Subject: [PATCH] Add support for intentional mentions --- backfill.go | 9 ++++- config/upgrade.go | 1 + database/message.go | 33 ++++++++++-------- database/upgrades/00-latest-revision.sql | 5 +-- database/upgrades/20-message-sender-mxid.sql | 2 ++ example-config.yaml | 2 ++ formatter.go | 9 +++++ portal.go | 36 ++++++++++++-------- portal_convert.go | 27 +++++++++++++++ 9 files changed, 91 insertions(+), 33 deletions(-) create mode 100644 database/upgrades/20-message-sender-mxid.sql diff --git a/backfill.go b/backfill.go index 9a5dcad..d6e15f1 100644 --- a/backfill.go +++ b/backfill.go @@ -217,7 +217,8 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess puppet := portal.bridge.GetPuppetByID(msg.Author.ID) puppet.UpdateInfo(source, msg.Author) intent := puppet.IntentFor(portal) - replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true) + replyTo, replySenderMXID := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true) + mentions := portal.convertDiscordMentions(msg, replySenderMXID, false) ts, _ := discordgo.SnowflakeTimestamp(msg.ID) log := log.With(). @@ -232,6 +233,11 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess // Only set reply for first event replyTo = nil } + + part.Content.Mentions = mentions + // Only set mentions for first event, but keep empty object for rest + mentions = &event.Mentions{} + partName := part.AttachmentID // Always use blank part name for first part so that replies and other things // can reference it without knowing about attachments. @@ -262,6 +268,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess SenderID: msg.Author.ID, Timestamp: ts, AttachmentID: part.AttachmentID, + SenderMXID: intent.UserID, }) } } diff --git a/config/upgrade.go b/config/upgrade.go index 4f1e99d..e207c41 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -85,6 +85,7 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "encryption", "require") helper.Copy(up.Bool, "bridge", "encryption", "appservice") helper.Copy(up.Bool, "bridge", "encryption", "allow_key_sharing") + helper.Copy(up.Bool, "bridge", "encryption", "plaintext_mentions") helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_outbound_on_ack") helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "dont_store_outbound") helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "ratchet_on_decrypt") diff --git a/database/message.go b/database/message.go index 35d6552..d47a1ef 100644 --- a/database/message.go +++ b/database/message.go @@ -19,7 +19,7 @@ type MessageQuery struct { } const ( - messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message" + messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message" ) func (mq *MessageQuery) New() *Message { @@ -99,11 +99,11 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { if len(msgs) == 0 { return } - valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)" + valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)" if mq.db.Dialect == dbutil.SQLite { valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") } - params := make([]interface{}, 2+len(msgs)*7) + params := make([]interface{}, 2+len(msgs)*8) placeholders := make([]string, len(msgs)) params[0] = key.ChannelID params[1] = key.Receiver @@ -116,7 +116,8 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { params[baseIndex+4] = msg.editTimestampVal() params[baseIndex+5] = msg.ThreadID params[baseIndex+6] = msg.MXID - placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7) + params[baseIndex+7] = msg.SenderMXID.String() + placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8) } _, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) if err != nil { @@ -137,7 +138,8 @@ type Message struct { EditTimestamp time.Time ThreadID string - MXID id.EventID + MXID id.EventID + SenderMXID id.UserID } func (m *Message) DiscordProtoChannelID() string { @@ -151,7 +153,7 @@ func (m *Message) DiscordProtoChannelID() string { func (m *Message) Scan(row dbutil.Scannable) *Message { var ts, editTS int64 - err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID) + err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID, &m.SenderMXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) @@ -173,12 +175,12 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { const messageInsertQuery = ` INSERT INTO message ( - dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid + dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ` -var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1) +var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1) type MessagePart struct { AttachmentID string @@ -196,11 +198,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) { if len(msgs) == 0 { return } - valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)" + valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)" if m.db.Dialect == dbutil.SQLite { valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") } - params := make([]interface{}, 7+len(msgs)*2) + params := make([]interface{}, 8+len(msgs)*2) placeholders := make([]string, len(msgs)) params[0] = m.DiscordID params[1] = m.Channel.ChannelID @@ -209,10 +211,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) { params[4] = m.Timestamp.UnixMilli() params[5] = m.editTimestampVal() params[6] = m.ThreadID + params[7] = m.SenderMXID.String() for i, msg := range msgs { - params[7+i*2] = msg.AttachmentID - params[7+i*2+1] = msg.MXID - placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2) + params[8+i*2] = msg.AttachmentID + params[8+i*2+1] = msg.MXID + placeholders[i] = fmt.Sprintf(valueStringFormat, 8+i*2+1, 8+i*2+2) } _, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) if err != nil { @@ -224,7 +227,7 @@ func (m *Message) MassInsertParts(msgs []MessagePart) { func (m *Message) Insert() { _, err := m.db.Exec(messageInsertQuery, m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, - m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID) + m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String()) if err != nil { m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 777e220..3bef8ab 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v19: Latest revision +-- v0 -> v20 (compatible with v19+): Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -113,7 +113,8 @@ CREATE TABLE message ( dc_edit_timestamp BIGINT NOT NULL, dc_thread_id TEXT NOT NULL, - mxid TEXT NOT NULL UNIQUE, + mxid TEXT NOT NULL UNIQUE, + sender_mxid TEXT NOT NULL DEFAULT '', PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver), CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE diff --git a/database/upgrades/20-message-sender-mxid.sql b/database/upgrades/20-message-sender-mxid.sql new file mode 100644 index 0000000..aa2bd65 --- /dev/null +++ b/database/upgrades/20-message-sender-mxid.sql @@ -0,0 +1,2 @@ +-- v20 (compatible with v19+): Store message sender Matrix user ID +ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; diff --git a/example-config.yaml b/example-config.yaml index 0b63ec9..f45aabd 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -247,6 +247,8 @@ bridge: # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: false + # Should users mentions be in the event wire content to enable the server to send push notifications? + plaintext_mentions: false # Options for deleting megolm sessions from the bridge. delete_keys: # Beeper-specific: delete outbound sessions when hungryserv confirms diff --git a/formatter.go b/formatter.go index 4255a57..42ed101 100644 --- a/formatter.go +++ b/formatter.go @@ -26,6 +26,7 @@ import ( "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/parser" "github.com/yuin/goldmark/util" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" @@ -93,6 +94,7 @@ func (portal *Portal) renderDiscordMarkdownOnlyHTML(text string, allowInlineLink const formatterContextPortalKey = "fi.mau.discord.portal" const formatterContextAllowedMentionsKey = "fi.mau.discord.allowed_mentions" +const formatterContextInputAllowedMentionsKey = "fi.mau.discord.input_allowed_mentions" func appendIfNotContains(arr []string, newItem string) []string { for _, item := range arr { @@ -135,6 +137,10 @@ func (br *DiscordBridge) pillConverter(displayname, mxid, eventID string, ctx fo } } } else if mxid[0] == '@' { + allowedMentions, _ := ctx.ReturnData[formatterContextInputAllowedMentionsKey].([]id.UserID) + if allowedMentions != nil && !slices.Contains(allowedMentions, id.UserID(mxid)) { + return displayname + } mentions := ctx.ReturnData[formatterContextAllowedMentionsKey].(*discordgo.MessageAllowedMentions) parsedID, ok := br.ParsePuppetMXID(id.UserID(mxid)) if ok { @@ -219,6 +225,9 @@ func (portal *Portal) parseMatrixHTML(content *event.MessageEventContent) (strin ctx := format.NewContext() ctx.ReturnData[formatterContextPortalKey] = portal ctx.ReturnData[formatterContextAllowedMentionsKey] = allowedMentions + if content.Mentions != nil { + ctx.ReturnData[formatterContextInputAllowedMentionsKey] = content.Mentions.UserIDs + } return variationselector.FullyQualify(matrixHTMLParser.Parse(content.FormattedBody, ctx)), allowedMentions } else { return variationselector.FullyQualify(escapeDiscordMarkdown(content.Body)), allowedMentions diff --git a/portal.go b/portal.go index 4c0ae4d..324974a 100644 --- a/portal.go +++ b/portal.go @@ -584,13 +584,14 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache) } -func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) { +func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) { msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key msg.DiscordID = discordID msg.SenderID = authorID msg.Timestamp = timestamp msg.ThreadID = threadID + msg.SenderMXID = senderMXID msg.MassInsertParts(parts) } @@ -618,11 +619,6 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } log.Debug().Msg("Starting handling of Discord message") - for _, mention := range msg.Mentions { - puppet := portal.bridge.GetPuppetByID(mention.ID) - puppet.UpdateInfo(nil, mention) - } - puppet := portal.bridge.GetPuppetByID(msg.Author.ID) puppet.UpdateInfo(user, msg.Author) intent := puppet.IntentFor(portal) @@ -638,7 +634,8 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess lastThreadEvent = lastInThread.MXID } } - replyTo := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false) + replyTo, replySenderMXID := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false) + mentions := portal.convertDiscordMentions(msg, replySenderMXID, true) ts, _ := discordgo.SnowflakeTimestamp(msg.ID) parts := portal.convertDiscordMessage(ctx, intent, msg) @@ -658,6 +655,11 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess // Only set reply for first event replyTo = nil } + + part.Content.Mentions = mentions + // Only set mentions for first event, but keep empty object for rest + mentions = &event.Mentions{} + resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli()) if err != nil { log.Err(err). @@ -674,7 +676,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } else if len(dbParts) == 0 { log.Warn().Msg("All parts of message failed to send to Matrix") } else { - portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts) + portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts) } } @@ -684,7 +686,7 @@ func isReplyEmbed(embed *discordgo.MessageEmbed) bool { return hackyReplyPattern.MatchString(embed.Description) } -func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) *event.InReplyTo { +func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) (*event.InReplyTo, id.UserID) { if ref == nil && len(embeds) > 0 { match := hackyReplyPattern.FindStringSubmatch(embeds[0].Description) if match != nil && match[1] == portal.GuildID && (match[2] == portal.Key.ChannelID || match[2] == threadID) { @@ -696,7 +698,7 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord } } if ref == nil { - return nil + return nil, "" } isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry if !isHungry { @@ -709,25 +711,25 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord if ref.ChannelID != portal.Key.ChannelID && ref.ChannelID != threadID && crossRoomReplies { targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID}) if targetPortal == nil { - return nil + return nil, "" } } replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID) if len(replyToMsg) > 0 { if !crossRoomReplies { - return &event.InReplyTo{EventID: replyToMsg[0].MXID} + return &event.InReplyTo{EventID: replyToMsg[0].MXID}, replyToMsg[0].SenderMXID } return &event.InReplyTo{ EventID: replyToMsg[0].MXID, UnstableRoomID: targetPortal.MXID, - } + }, replyToMsg[0].SenderMXID } else if allowNonExistent { return &event.InReplyTo{ EventID: targetPortal.deterministicEventID(ref.MessageID, ""), UnstableRoomID: targetPortal.MXID, - } + }, "" } - return nil + return nil, "" } const JoinThreadReaction = "join thread" @@ -895,7 +897,10 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess Msg("Dropping non-text edit") return } + converted.Content.Mentions = portal.convertDiscordMentions(msg, "", false) converted.Content.SetEdit(existing[0].MXID) + // Never actually mention new users of edits, only include mentions inside m.new_content + converted.Content.Mentions = &event.Mentions{} if converted.Extra != nil { converted.Extra = map[string]any{ "m.new_content": converted.Extra, @@ -1585,6 +1590,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { } else { dbMsg.SenderID = portal.RelayWebhookID } + dbMsg.SenderMXID = sender.MXID dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID) dbMsg.ThreadID = threadID dbMsg.Insert() diff --git a/portal_convert.go b/portal_convert.go index b339e53..d8fcb5a 100644 --- a/portal_convert.go +++ b/portal_convert.go @@ -26,6 +26,8 @@ import ( "github.com/bwmarrin/discordgo" "github.com/rs/zerolog" + "golang.org/x/exp/slices" + "maunium.net/go/mautrix/id" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -518,6 +520,31 @@ func isPlainGifMessage(msg *discordgo.Message) bool { return len(msg.Embeds) == 1 && msg.Embeds[0].Video != nil && msg.Embeds[0].URL == msg.Content && msg.Embeds[0].Type == discordgo.EmbedTypeGifv } +func (portal *Portal) convertDiscordMentions(msg *discordgo.Message, replySender id.UserID, syncGhosts bool) *event.Mentions { + var matrixMentions event.Mentions + for _, mention := range msg.Mentions { + puppet := portal.bridge.GetPuppetByID(mention.ID) + if syncGhosts { + puppet.UpdateInfo(nil, mention) + } + user := portal.bridge.GetUserByID(mention.ID) + if user != nil { + matrixMentions.UserIDs = append(matrixMentions.UserIDs, user.MXID) + } else { + matrixMentions.UserIDs = append(matrixMentions.UserIDs, puppet.MXID) + } + } + if replySender != "" { + matrixMentions.UserIDs = append(matrixMentions.UserIDs, replySender) + } + slices.Sort(matrixMentions.UserIDs) + matrixMentions.UserIDs = slices.Compact(matrixMentions.UserIDs) + if msg.MentionEveryone { + matrixMentions.Room = true + } + return &matrixMentions +} + func (portal *Portal) convertDiscordTextMessage(ctx context.Context, intent *appservice.IntentAPI, msg *discordgo.Message) *ConvertedMessage { log := zerolog.Ctx(ctx) if msg.Type == discordgo.MessageTypeCall {