From 4194b4dfd9bf55e89bc7e22efeee185495752910 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Apr 2023 15:06:02 +0300 Subject: [PATCH] Improve missed message backfilling --- backfill.go | 259 ++++++++++++++++++++++++++++++++++++++++++++ config/bridge.go | 13 +++ config/upgrade.go | 5 + database/message.go | 32 +++++- example-config.yaml | 14 +++ portal.go | 136 +++++++++++------------ user.go | 4 +- 7 files changed, 389 insertions(+), 74 deletions(-) create mode 100644 backfill.go diff --git a/backfill.go b/backfill.go new file mode 100644 index 0000000..93396ae --- /dev/null +++ b/backfill.go @@ -0,0 +1,259 @@ +package main + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "sort" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/bridgeconfig" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/database" +) + +func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) { + limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel + if portal.GuildID == "" { + limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM + } + if limit == 0 { + return + } + log := portal.zlog.With(). + Str("action", "missed event backfill"). + Int("limit", limit). + Logger() + + portal.forwardBackfillLock.Lock() + defer portal.forwardBackfillLock.Unlock() + + lastMessage := portal.bridge.DB.Message.GetLast(portal.Key) + if lastMessage == nil || meta.LastMessageID == "" { + log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata") + return + } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) { + log.Debug(). + Str("last_bridged_message", lastMessage.DiscordID). + Str("last_server_message", meta.LastMessageID). + Msg("Not backfilling, last message in database is newer than last message in metadata") + return + } + log.Debug(). + Str("last_bridged_message", lastMessage.DiscordID). + Str("last_server_message", meta.LastMessageID). + Msg("Backfilling missed messages") + if limit < 0 { + portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID) + } else { + portal.backfillLimitedMissed(log, source, limit, lastMessage.DiscordID) + } +} + +const messageFetchChunkSize = 50 + +func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) { + var messages []*discordgo.Message + var before string + var foundAll bool + for { + log.Debug().Str("before_id", before).Msg("Fetching messages for backfill") + newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "") + if err != nil { + return nil, false, err + } + for i, msg := range newMessages { + if compareMessageIDs(msg.ID, until) <= 0 { + log.Debug(). + Str("message_id", msg.ID). + Str("until_id", until). + Msg("Found message that was already bridged") + newMessages = newMessages[:i] + foundAll = true + break + } + } + messages = append(messages, newMessages...) + log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection") + if len(newMessages) <= messageFetchChunkSize || len(messages) >= limit { + break + } + before = newMessages[len(newMessages)-1].ID + } + if len(messages) > limit { + messages = messages[:limit] + } + return messages, foundAll, nil +} + +func (portal *Portal) backfillLimitedMissed(log zerolog.Logger, source *User, limit int, after string) { + messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after) + if err != nil { + log.Err(err).Msg("Error collecting messages to forward backfill") + return + } + log.Info(). + Int("count", len(messages)). + Bool("found_all", foundAll). + Msg("Collected messages to backfill") + sort.Sort(MessageSlice(messages)) + if !foundAll { + _, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: "Some messages may have been missed here while the bridge was offline.", + }, nil, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send missed message warning") + } else { + log.Debug().Msg("Sent warning about possibly missed messages") + } + } + portal.sendBackfillBatch(log, source, messages) +} + +func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) { + for { + log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill") + messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "") + if err != nil { + log.Err(err).Msg("Error fetching chunk of messages to forward backfill") + return + } + log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill") + sort.Sort(MessageSlice(messages)) + + portal.sendBackfillBatch(log, source, messages) + + if len(messages) < messageFetchChunkSize { + // Assume that was all the missing messages + log.Debug().Msg("Chunk had less than 50 messages, stopping backfill") + return + } + after = messages[len(messages)-1].ID + } +} + +func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) { + if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint") + portal.forwardBatchSend(log, source, messages) + } else { + log.Debug().Msg("Not using hungryserv, sending messages one by one") + for _, msg := range messages { + portal.handleDiscordMessageCreate(source, msg, nil) + } + } +} + +func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) { + evts := make([]*event.Event, 0, len(messages)) + dbMessages := make([]database.Message, 0, len(messages)) + for _, msg := range messages { + for _, mention := range msg.Mentions { + puppet := portal.bridge.GetPuppetByID(mention.ID) + puppet.UpdateInfo(nil, mention) + } + + puppet := portal.bridge.GetPuppetByID(msg.Author.ID) + puppet.UpdateInfo(source, msg.Author) + intent := puppet.IntentFor(portal) + replyTo := portal.getReplyTarget(source, msg.MessageReference, true) + + ts, _ := discordgo.SnowflakeTimestamp(msg.ID) + parts := portal.convertDiscordMessage(intent, msg) + for i, part := range parts { + if replyTo != nil { + part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo} + // Only set reply for first event + replyTo = nil + } + partName := part.AttachmentID + // Always use blank part name for first part so that replies and other things + // can reference it without knowing about attachments. + if i == 0 { + partName = "" + } + evts = append(evts, &event.Event{ + ID: portal.deterministicEventID(msg.ID, partName), + Type: part.Type, + Sender: intent.UserID, + Timestamp: ts.UnixMilli(), + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) + dbMessages = append(dbMessages, database.Message{ + Channel: portal.Key, + DiscordID: msg.ID, + SenderID: msg.Author.ID, + Timestamp: ts, + AttachmentID: part.AttachmentID, + }) + } + } + log.Info().Int("parts", len(evts)).Msg("Converted messages to backfill") + resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{ + BeeperNewMessages: true, + Events: evts, + }) + if err != nil { + log.Err(err).Msg("Error sending backfill batch") + return + } + for i, evtID := range resp.EventIDs { + dbMessages[i].MXID = evtID + } + portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages) + log.Info().Msg("Inserted backfilled batch to database") +} + +func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID { + data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName) + sum := sha256.Sum256([]byte(data)) + return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:]))) +} + +// compareMessageIDs compares two Discord message IDs. +// +// If the first ID is lower, -1 is returned. +// If the second ID is lower, 1 is returned. +// If the IDs are equal, 0 is returned. +func compareMessageIDs(id1, id2 string) int { + if id1 == id2 { + return 0 + } + if len(id1) < len(id2) { + return -1 + } else if len(id2) < len(id1) { + return 1 + } + if id1 < id2 { + return -1 + } + return 1 +} + +func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool { + return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1 +} + +type MessageSlice []*discordgo.Message + +var _ sort.Interface = (MessageSlice)(nil) + +func (a MessageSlice) Len() int { + return len(a) +} + +func (a MessageSlice) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} + +func (a MessageSlice) Less(i, j int) bool { + return compareMessageIDs(a[i].ID, a[j].ID) == -1 +} diff --git a/config/bridge.go b/config/bridge.go index 6fe0acb..77b4d15 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -66,6 +66,14 @@ type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"` + Backfill struct { + Enabled bool `yaml:"enabled"` + Limits struct { + Initial BackfillLimitPart `yaml:"initial"` + Missed BackfillLimitPart `yaml:"missed"` + } `yaml:"limits"` + } `yaml:"backfill"` + Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"` Provisioning struct { @@ -81,6 +89,11 @@ type BridgeConfig struct { guildNameTemplate *template.Template `yaml:"-"` } +type BackfillLimitPart struct { + DM int `yaml:"dm"` + Channel int `yaml:"channel"` +} + func (bc *BridgeConfig) GetResendBridgeInfo() bool { return bc.ResendBridgeInfo } diff --git a/config/upgrade.go b/config/upgrade.go index d40d675..7d7e82f 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -67,6 +67,11 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected") helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected") helper.Copy(up.Str|up.Null, "bridge", "management_room_text", "additional_help") + helper.Copy(up.Bool, "bridge", "backfill", "enabled") + helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "dm") + helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "channel") + helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "dm") + helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "channel") helper.Copy(up.Bool, "bridge", "encryption", "allow") helper.Copy(up.Bool, "bridge", "encryption", "default") helper.Copy(up.Bool, "bridge", "encryption", "require") diff --git a/database/message.go b/database/message.go index 335fbb2..99b3ad0 100644 --- a/database/message.go +++ b/database/message.go @@ -95,6 +95,36 @@ func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message { return mq.New().Scan(row) } +func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) { + if len(msgs) == 0 { + return + } + valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)" + if mq.db.Dialect == dbutil.SQLite { + valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?") + } + params := make([]interface{}, 2+len(msgs)*7) + placeholders := make([]string, len(msgs)) + params[0] = key.ChannelID + params[1] = key.Receiver + for i, msg := range msgs { + baseIndex := 2 + i*7 + params[baseIndex] = msg.DiscordID + params[baseIndex+1] = msg.AttachmentID + params[baseIndex+2] = msg.EditIndex + params[baseIndex+3] = msg.SenderID + params[baseIndex+4] = msg.Timestamp + params[baseIndex+5] = msg.ThreadID + params[baseIndex+6] = msg.MXID + placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7) + } + _, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...) + if err != nil { + mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err) + panic(err) + } +} + type Message struct { db *Database log log.Logger @@ -152,7 +182,7 @@ type MessagePart struct { MXID id.EventID } -func (m *Message) MassInsert(msgs []MessagePart) { +func (m *Message) MassInsertParts(msgs []MessagePart) { if len(msgs) == 0 { return } diff --git a/example-config.yaml b/example-config.yaml index f1c3575..472ecfd 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -186,6 +186,20 @@ bridge: # Optional extra text sent when joining a management room. additional_help: "" + backfill: + # Should backfill be enabled at all? + enabled: false + # Limits for backfilling. Set to 0 to disable that type of backfill, or -1 for unlimited. + limits: + # Initial backfill (when creating portal). + initial: + dm: 50 + channel: 0 + # Missed message backfill (on startup). + missed: + dm: 50 + channel: 0 + # End-to-bridge encryption support options. # # See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info. diff --git a/portal.go b/portal.go index e304796..20ced51 100644 --- a/portal.go +++ b/portal.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "reflect" - "sort" "strconv" "strings" "sync" @@ -13,8 +12,10 @@ import ( "github.com/bwmarrin/discordgo" "github.com/gabriel-vasile/mimetype" + "github.com/rs/zerolog" + "maunium.net/go/maulogger/v2/maulogadapt" - log "maunium.net/go/maulogger/v2" + "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -52,7 +53,9 @@ type Portal struct { Guild *Guild bridge *DiscordBridge - log log.Logger + // Deprecated + log maulogger.Logger + zlog zerolog.Logger roomCreateLock sync.Mutex encryptLock sync.Mutex @@ -65,6 +68,8 @@ type Portal struct { commands map[string]*discordgo.ApplicationCommand commandsLock sync.RWMutex + forwardBackfillLock sync.Mutex + currentlyTyping []id.UserID currentlyTypingLock sync.Mutex } @@ -233,7 +238,10 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal { portal := &Portal{ Portal: dbPortal, bridge: br, - log: br.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)), + zlog: br.ZLog.With(). + Str("channel_id", dbPortal.Key.ChannelID). + Str("channel_receiver", dbPortal.Key.Receiver). + Logger(), discordMessages: make(chan portalDiscordMessage, br.Config.Bridge.PortalMessageBuffer), matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), @@ -242,6 +250,7 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal { commands: make(map[string]*discordgo.ApplicationCommand), } + portal.log = maulogadapt.ZeroAsMau(&portal.zlog) go portal.messageLoop() @@ -252,10 +261,13 @@ func (portal *Portal) messageLoop() { for { select { case msg := <-portal.matrixMessages: + portal.forwardBackfillLock.Lock() portal.handleMatrixMessages(msg) case msg := <-portal.discordMessages: + portal.forwardBackfillLock.Lock() portal.handleDiscordMessages(msg) } + portal.forwardBackfillLock.Unlock() } } @@ -549,7 +561,7 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author msg.SenderID = authorID msg.Timestamp = timestamp msg.ThreadID = threadID - msg.MassInsert(parts) + msg.MassInsertParts(parts) } func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { @@ -578,7 +590,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess intent := puppet.IntentFor(portal) var discordThreadID string - var threadRootEvent, lastThreadEvent, replyToEvent id.EventID + var threadRootEvent, lastThreadEvent id.EventID if thread != nil { discordThreadID = thread.ID threadRootEvent = thread.RootMXID @@ -588,30 +600,25 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess lastThreadEvent = lastInThread.MXID } } - - if msg.MessageReference != nil { - // This could be used to find cross-channel replies, but Matrix doesn't support those currently. - //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} - replyToMsg := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID) - if len(replyToMsg) > 0 { - replyToEvent = replyToMsg[0].MXID - } - } + replyTo := portal.getReplyTarget(user, msg.MessageReference, false) ts, _ := discordgo.SnowflakeTimestamp(msg.ID) parts := portal.convertDiscordMessage(intent, msg) dbParts := make([]database.MessagePart, 0, len(parts)) for i, part := range parts { - if (replyToEvent != "" || threadRootEvent != "") && part.Content.RelatesTo == nil { + if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil { part.Content.RelatesTo = &event.RelatesTo{} } if threadRootEvent != "" { part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent) } - if replyToEvent != "" { - part.Content.RelatesTo.SetReplyTo(replyToEvent) + if replyTo != nil { + part.Content.RelatesTo.SetReplyTo(replyTo.EventID) + if replyTo.UnstableRoomID != "" { + part.Content.RelatesTo.InReplyTo.UnstableRoomID = replyTo.UnstableRoomID + } // Only set reply for first event - replyToEvent = "" + replyTo = nil } resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli()) if err != nil { @@ -630,6 +637,42 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } } +func (portal *Portal) getReplyTarget(source *User, ref *discordgo.MessageReference, allowNonExistent bool) *event.InReplyTo { + if ref == nil { + return nil + } + isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry + if !isHungry { + allowNonExistent = false + } + // TODO add config option for cross-room replies + crossRoomReplies := isHungry + + targetPortal := portal + if ref.ChannelID != portal.Key.ChannelID && crossRoomReplies { + targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID}) + if targetPortal == nil { + return nil + } + } + replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID) + if len(replyToMsg) > 0 { + if !crossRoomReplies { + return &event.InReplyTo{EventID: replyToMsg[0].MXID} + } + return &event.InReplyTo{ + EventID: replyToMsg[0].MXID, + UnstableRoomID: targetPortal.MXID, + } + } else if allowNonExistent { + return &event.InReplyTo{ + EventID: targetPortal.deterministicEventID(ref.MessageID, ""), + UnstableRoomID: targetPortal.MXID, + } + } + return nil +} + const JoinThreadReaction = "join thread" func (portal *Portal) sendThreadCreationNotice(thread *Thread) { @@ -1066,9 +1109,9 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin evtDescription += fmt.Sprintf(" of %s", evt.Redacts) } if err != nil { - level := log.LevelError + level := maulogger.LevelError if part == "Ignoring" { - level = log.LevelDebug + level = maulogger.LevelDebug } portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err) reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err) @@ -1376,7 +1419,7 @@ func (portal *Portal) cleanup(puppetsOnly bool) { portal.bridge.cleanupRoom(intent, portal.MXID, puppetsOnly, portal.log) } -func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log log.Logger) { +func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log maulogger.Logger) { members, err := intent.JoinedMembers(mxid) if err != nil { log.Errorln("Failed to get portal members for cleanup:", err) @@ -1975,52 +2018,3 @@ func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discord } return meta } - -func (portal *Portal) ForwardBackfill(source *User, meta *discordgo.Channel) error { - portal.log.Debugln("Checking for missing messages to fill") - lastMessage := portal.bridge.DB.Message.GetLast(portal.Key) - if lastMessage == nil { - return nil - } - - metaLastMessageID, err := strconv.ParseInt(meta.LastMessageID, 10, 0) - if err != nil { - portal.log.Errorfln("Last message ID %s isn't integer", meta.LastMessageID) - return err - } - dbLastMessageID, err := strconv.ParseInt(lastMessage.DiscordID, 10, 0) - if err != nil { - portal.log.Errorfln("Last message ID %s isn't integer", lastMessage.DiscordID) - return err - } - if metaLastMessageID <= dbLastMessageID { - return nil - } - - // Get up to 50 messages at a time until everything is fetched - for { - messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, 50, "", lastMessage.DiscordID, "") - if err != nil { - portal.log.Debugln("Error getting messages to forward backfill", err) - return err - } - // Discord seems to return messages in reverse order, - // but no specific order is guaranteed by their API docs? - sort.Slice(messages, func(i, j int) bool { - return messages[i].Timestamp.Before(messages[j].Timestamp) - }) - - for _, msg := range messages { - portal.handleDiscordMessageCreate(source, msg, nil) - } - - if len(messages) < 100 { - // Assume that was all the missing messages - return nil - } - lastMessage = portal.bridge.DB.Message.GetLast(portal.Key) - if lastMessage == nil { - return nil - } - } -} diff --git a/user.go b/user.go index 09fd639..f5fa83e 100644 --- a/user.go +++ b/user.go @@ -728,7 +728,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel, } } else { portal.UpdateInfo(user, meta) - portal.ForwardBackfill(user, meta) + portal.ForwardBackfillMissed(user, meta) } user.MarkInPortal(database.UserPortal{ DiscordID: portal.Key.ChannelID, @@ -843,7 +843,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp } } else { portal.UpdateInfo(user, ch) - portal.ForwardBackfill(user, ch) + portal.ForwardBackfillMissed(user, ch) } } }