Improve missed message backfilling

This commit is contained in:
Tulir Asokan
2023-04-16 15:06:02 +03:00
parent d465bd2d67
commit 4194b4dfd9
7 changed files with 389 additions and 74 deletions

259
backfill.go Normal file
View File

@@ -0,0 +1,259 @@
package main
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"sort"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge/bridgeconfig"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database"
)
func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
if portal.GuildID == "" {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
}
if limit == 0 {
return
}
log := portal.zlog.With().
Str("action", "missed event backfill").
Int("limit", limit).
Logger()
portal.forwardBackfillLock.Lock()
defer portal.forwardBackfillLock.Unlock()
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
if lastMessage == nil || meta.LastMessageID == "" {
log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
return
} else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
log.Debug().
Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID).
Msg("Not backfilling, last message in database is newer than last message in metadata")
return
}
log.Debug().
Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID).
Msg("Backfilling missed messages")
if limit < 0 {
portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
} else {
portal.backfillLimitedMissed(log, source, limit, lastMessage.DiscordID)
}
}
const messageFetchChunkSize = 50
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
var messages []*discordgo.Message
var before string
var foundAll bool
for {
log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
if err != nil {
return nil, false, err
}
for i, msg := range newMessages {
if compareMessageIDs(msg.ID, until) <= 0 {
log.Debug().
Str("message_id", msg.ID).
Str("until_id", until).
Msg("Found message that was already bridged")
newMessages = newMessages[:i]
foundAll = true
break
}
}
messages = append(messages, newMessages...)
log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
if len(newMessages) <= messageFetchChunkSize || len(messages) >= limit {
break
}
before = newMessages[len(newMessages)-1].ID
}
if len(messages) > limit {
messages = messages[:limit]
}
return messages, foundAll, nil
}
func (portal *Portal) backfillLimitedMissed(log zerolog.Logger, source *User, limit int, after string) {
messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
if err != nil {
log.Err(err).Msg("Error collecting messages to forward backfill")
return
}
log.Info().
Int("count", len(messages)).
Bool("found_all", foundAll).
Msg("Collected messages to backfill")
sort.Sort(MessageSlice(messages))
if !foundAll {
_, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgNotice,
Body: "Some messages may have been missed here while the bridge was offline.",
}, nil, 0)
if err != nil {
log.Warn().Err(err).Msg("Failed to send missed message warning")
} else {
log.Debug().Msg("Sent warning about possibly missed messages")
}
}
portal.sendBackfillBatch(log, source, messages)
}
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
for {
log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
if err != nil {
log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
return
}
log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
sort.Sort(MessageSlice(messages))
portal.sendBackfillBatch(log, source, messages)
if len(messages) < messageFetchChunkSize {
// Assume that was all the missing messages
log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
return
}
after = messages[len(messages)-1].ID
}
}
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
portal.forwardBatchSend(log, source, messages)
} else {
log.Debug().Msg("Not using hungryserv, sending messages one by one")
for _, msg := range messages {
portal.handleDiscordMessageCreate(source, msg, nil)
}
}
}
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
evts := make([]*event.Event, 0, len(messages))
dbMessages := make([]database.Message, 0, len(messages))
for _, msg := range messages {
for _, mention := range msg.Mentions {
puppet := portal.bridge.GetPuppetByID(mention.ID)
puppet.UpdateInfo(nil, mention)
}
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(source, msg.Author)
intent := puppet.IntentFor(portal)
replyTo := portal.getReplyTarget(source, msg.MessageReference, true)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
parts := portal.convertDiscordMessage(intent, msg)
for i, part := range parts {
if replyTo != nil {
part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
// Only set reply for first event
replyTo = nil
}
partName := part.AttachmentID
// Always use blank part name for first part so that replies and other things
// can reference it without knowing about attachments.
if i == 0 {
partName = ""
}
evts = append(evts, &event.Event{
ID: portal.deterministicEventID(msg.ID, partName),
Type: part.Type,
Sender: intent.UserID,
Timestamp: ts.UnixMilli(),
Content: event.Content{
Parsed: part.Content,
Raw: part.Extra,
},
})
dbMessages = append(dbMessages, database.Message{
Channel: portal.Key,
DiscordID: msg.ID,
SenderID: msg.Author.ID,
Timestamp: ts,
AttachmentID: part.AttachmentID,
})
}
}
log.Info().Int("parts", len(evts)).Msg("Converted messages to backfill")
resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
BeeperNewMessages: true,
Events: evts,
})
if err != nil {
log.Err(err).Msg("Error sending backfill batch")
return
}
for i, evtID := range resp.EventIDs {
dbMessages[i].MXID = evtID
}
portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
log.Info().Msg("Inserted backfilled batch to database")
}
func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
sum := sha256.Sum256([]byte(data))
return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
}
// compareMessageIDs compares two Discord message IDs.
//
// If the first ID is lower, -1 is returned.
// If the second ID is lower, 1 is returned.
// If the IDs are equal, 0 is returned.
func compareMessageIDs(id1, id2 string) int {
if id1 == id2 {
return 0
}
if len(id1) < len(id2) {
return -1
} else if len(id2) < len(id1) {
return 1
}
if id1 < id2 {
return -1
}
return 1
}
func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
}
type MessageSlice []*discordgo.Message
var _ sort.Interface = (MessageSlice)(nil)
func (a MessageSlice) Len() int {
return len(a)
}
func (a MessageSlice) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}
func (a MessageSlice) Less(i, j int) bool {
return compareMessageIDs(a[i].ID, a[j].ID) == -1
}

View File

@@ -66,6 +66,14 @@ type BridgeConfig struct {
CommandPrefix string `yaml:"command_prefix"`
ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"`
Backfill struct {
Enabled bool `yaml:"enabled"`
Limits struct {
Initial BackfillLimitPart `yaml:"initial"`
Missed BackfillLimitPart `yaml:"missed"`
} `yaml:"limits"`
} `yaml:"backfill"`
Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
Provisioning struct {
@@ -81,6 +89,11 @@ type BridgeConfig struct {
guildNameTemplate *template.Template `yaml:"-"`
}
type BackfillLimitPart struct {
DM int `yaml:"dm"`
Channel int `yaml:"channel"`
}
func (bc *BridgeConfig) GetResendBridgeInfo() bool {
return bc.ResendBridgeInfo
}

View File

@@ -67,6 +67,11 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected")
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected")
helper.Copy(up.Str|up.Null, "bridge", "management_room_text", "additional_help")
helper.Copy(up.Bool, "bridge", "backfill", "enabled")
helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "dm")
helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "channel")
helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "dm")
helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "channel")
helper.Copy(up.Bool, "bridge", "encryption", "allow")
helper.Copy(up.Bool, "bridge", "encryption", "default")
helper.Copy(up.Bool, "bridge", "encryption", "require")

View File

@@ -95,6 +95,36 @@ func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
return mq.New().Scan(row)
}
func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
if len(msgs) == 0 {
return
}
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
params := make([]interface{}, 2+len(msgs)*7)
placeholders := make([]string, len(msgs))
params[0] = key.ChannelID
params[1] = key.Receiver
for i, msg := range msgs {
baseIndex := 2 + i*7
params[baseIndex] = msg.DiscordID
params[baseIndex+1] = msg.AttachmentID
params[baseIndex+2] = msg.EditIndex
params[baseIndex+3] = msg.SenderID
params[baseIndex+4] = msg.Timestamp
params[baseIndex+5] = msg.ThreadID
params[baseIndex+6] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
}
_, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
if err != nil {
mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err)
panic(err)
}
}
type Message struct {
db *Database
log log.Logger
@@ -152,7 +182,7 @@ type MessagePart struct {
MXID id.EventID
}
func (m *Message) MassInsert(msgs []MessagePart) {
func (m *Message) MassInsertParts(msgs []MessagePart) {
if len(msgs) == 0 {
return
}

View File

@@ -186,6 +186,20 @@ bridge:
# Optional extra text sent when joining a management room.
additional_help: ""
backfill:
# Should backfill be enabled at all?
enabled: false
# Limits for backfilling. Set to 0 to disable that type of backfill, or -1 for unlimited.
limits:
# Initial backfill (when creating portal).
initial:
dm: 50
channel: 0
# Missed message backfill (on startup).
missed:
dm: 50
channel: 0
# End-to-bridge encryption support options.
#
# See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info.

136
portal.go
View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"sync"
@@ -13,8 +12,10 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/gabriel-vasile/mimetype"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2/maulogadapt"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
@@ -52,7 +53,9 @@ type Portal struct {
Guild *Guild
bridge *DiscordBridge
log log.Logger
// Deprecated
log maulogger.Logger
zlog zerolog.Logger
roomCreateLock sync.Mutex
encryptLock sync.Mutex
@@ -65,6 +68,8 @@ type Portal struct {
commands map[string]*discordgo.ApplicationCommand
commandsLock sync.RWMutex
forwardBackfillLock sync.Mutex
currentlyTyping []id.UserID
currentlyTypingLock sync.Mutex
}
@@ -233,7 +238,10 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal {
portal := &Portal{
Portal: dbPortal,
bridge: br,
log: br.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
zlog: br.ZLog.With().
Str("channel_id", dbPortal.Key.ChannelID).
Str("channel_receiver", dbPortal.Key.Receiver).
Logger(),
discordMessages: make(chan portalDiscordMessage, br.Config.Bridge.PortalMessageBuffer),
matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
@@ -242,6 +250,7 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal {
commands: make(map[string]*discordgo.ApplicationCommand),
}
portal.log = maulogadapt.ZeroAsMau(&portal.zlog)
go portal.messageLoop()
@@ -252,10 +261,13 @@ func (portal *Portal) messageLoop() {
for {
select {
case msg := <-portal.matrixMessages:
portal.forwardBackfillLock.Lock()
portal.handleMatrixMessages(msg)
case msg := <-portal.discordMessages:
portal.forwardBackfillLock.Lock()
portal.handleDiscordMessages(msg)
}
portal.forwardBackfillLock.Unlock()
}
}
@@ -549,7 +561,7 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author
msg.SenderID = authorID
msg.Timestamp = timestamp
msg.ThreadID = threadID
msg.MassInsert(parts)
msg.MassInsertParts(parts)
}
func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) {
@@ -578,7 +590,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
intent := puppet.IntentFor(portal)
var discordThreadID string
var threadRootEvent, lastThreadEvent, replyToEvent id.EventID
var threadRootEvent, lastThreadEvent id.EventID
if thread != nil {
discordThreadID = thread.ID
threadRootEvent = thread.RootMXID
@@ -588,30 +600,25 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
lastThreadEvent = lastInThread.MXID
}
}
if msg.MessageReference != nil {
// This could be used to find cross-channel replies, but Matrix doesn't support those currently.
//key := database.PortalKey{msg.MessageReference.ChannelID, user.ID}
replyToMsg := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID)
if len(replyToMsg) > 0 {
replyToEvent = replyToMsg[0].MXID
}
}
replyTo := portal.getReplyTarget(user, msg.MessageReference, false)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
parts := portal.convertDiscordMessage(intent, msg)
dbParts := make([]database.MessagePart, 0, len(parts))
for i, part := range parts {
if (replyToEvent != "" || threadRootEvent != "") && part.Content.RelatesTo == nil {
if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
part.Content.RelatesTo = &event.RelatesTo{}
}
if threadRootEvent != "" {
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
}
if replyToEvent != "" {
part.Content.RelatesTo.SetReplyTo(replyToEvent)
if replyTo != nil {
part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
if replyTo.UnstableRoomID != "" {
part.Content.RelatesTo.InReplyTo.UnstableRoomID = replyTo.UnstableRoomID
}
// Only set reply for first event
replyToEvent = ""
replyTo = nil
}
resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli())
if err != nil {
@@ -630,6 +637,42 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
}
}
func (portal *Portal) getReplyTarget(source *User, ref *discordgo.MessageReference, allowNonExistent bool) *event.InReplyTo {
if ref == nil {
return nil
}
isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry
if !isHungry {
allowNonExistent = false
}
// TODO add config option for cross-room replies
crossRoomReplies := isHungry
targetPortal := portal
if ref.ChannelID != portal.Key.ChannelID && crossRoomReplies {
targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID})
if targetPortal == nil {
return nil
}
}
replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID)
if len(replyToMsg) > 0 {
if !crossRoomReplies {
return &event.InReplyTo{EventID: replyToMsg[0].MXID}
}
return &event.InReplyTo{
EventID: replyToMsg[0].MXID,
UnstableRoomID: targetPortal.MXID,
}
} else if allowNonExistent {
return &event.InReplyTo{
EventID: targetPortal.deterministicEventID(ref.MessageID, ""),
UnstableRoomID: targetPortal.MXID,
}
}
return nil
}
const JoinThreadReaction = "join thread"
func (portal *Portal) sendThreadCreationNotice(thread *Thread) {
@@ -1066,9 +1109,9 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
}
if err != nil {
level := log.LevelError
level := maulogger.LevelError
if part == "Ignoring" {
level = log.LevelDebug
level = maulogger.LevelDebug
}
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
@@ -1376,7 +1419,7 @@ func (portal *Portal) cleanup(puppetsOnly bool) {
portal.bridge.cleanupRoom(intent, portal.MXID, puppetsOnly, portal.log)
}
func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log log.Logger) {
func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log maulogger.Logger) {
members, err := intent.JoinedMembers(mxid)
if err != nil {
log.Errorln("Failed to get portal members for cleanup:", err)
@@ -1975,52 +2018,3 @@ func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discord
}
return meta
}
func (portal *Portal) ForwardBackfill(source *User, meta *discordgo.Channel) error {
portal.log.Debugln("Checking for missing messages to fill")
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
if lastMessage == nil {
return nil
}
metaLastMessageID, err := strconv.ParseInt(meta.LastMessageID, 10, 0)
if err != nil {
portal.log.Errorfln("Last message ID %s isn't integer", meta.LastMessageID)
return err
}
dbLastMessageID, err := strconv.ParseInt(lastMessage.DiscordID, 10, 0)
if err != nil {
portal.log.Errorfln("Last message ID %s isn't integer", lastMessage.DiscordID)
return err
}
if metaLastMessageID <= dbLastMessageID {
return nil
}
// Get up to 50 messages at a time until everything is fetched
for {
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, 50, "", lastMessage.DiscordID, "")
if err != nil {
portal.log.Debugln("Error getting messages to forward backfill", err)
return err
}
// Discord seems to return messages in reverse order,
// but no specific order is guaranteed by their API docs?
sort.Slice(messages, func(i, j int) bool {
return messages[i].Timestamp.Before(messages[j].Timestamp)
})
for _, msg := range messages {
portal.handleDiscordMessageCreate(source, msg, nil)
}
if len(messages) < 100 {
// Assume that was all the missing messages
return nil
}
lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
if lastMessage == nil {
return nil
}
}
}

View File

@@ -728,7 +728,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel,
}
} else {
portal.UpdateInfo(user, meta)
portal.ForwardBackfill(user, meta)
portal.ForwardBackfillMissed(user, meta)
}
user.MarkInPortal(database.UserPortal{
DiscordID: portal.Key.ChannelID,
@@ -843,7 +843,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp
}
} else {
portal.UpdateInfo(user, ch)
portal.ForwardBackfill(user, ch)
portal.ForwardBackfillMissed(user, ch)
}
}
}