Make backfilling code compatible with threads

This doesn't trigger thread backfill yet, but the backfill methods can
handle threads now.
This commit is contained in:
Tulir Asokan
2023-06-18 20:09:21 +03:00
parent 248664f8b0
commit 8ebad277f5
6 changed files with 84 additions and 35 deletions

View File

@@ -17,7 +17,7 @@ import (
"go.mau.fi/mautrix-discord/database"
)
func (portal *Portal) forwardBackfillInitial(source *User) {
func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
defer portal.forwardBackfillLock.Unlock()
// This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
if portal.forwardBackfillLock.TryLock() {
@@ -27,21 +27,27 @@ func (portal *Portal) forwardBackfillInitial(source *User) {
limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
if portal.GuildID == "" {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
if thread != nil {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread
}
}
if limit == 0 {
return
}
log := portal.log.With().
with := portal.log.With().
Str("action", "initial backfill").
Str("room_id", portal.MXID.String()).
Int("limit", limit).
Logger()
Int("limit", limit)
if thread != nil {
with = with.Str("thread_id", thread.ID)
}
log := with.Logger()
portal.backfillLimited(log, source, limit, "")
portal.backfillLimited(log, source, limit, "", thread)
}
func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) {
if portal.MXID == "" {
return
}
@@ -49,50 +55,65 @@ func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channe
limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
if portal.GuildID == "" {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
if thread != nil {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread
}
}
if limit == 0 {
return
}
log := portal.log.With().
with := portal.log.With().
Str("action", "missed event backfill").
Str("room_id", portal.MXID.String()).
Int("limit", limit).
Logger()
Int("limit", limit)
if thread != nil {
with = with.Str("thread_id", thread.ID)
}
log := with.Logger()
portal.forwardBackfillLock.Lock()
defer portal.forwardBackfillLock.Unlock()
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
if lastMessage == nil || meta.LastMessageID == "" {
var lastMessage *database.Message
if thread != nil {
lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
} else {
lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
}
if lastMessage == nil || serverLastMessageID == "" {
log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
return
} else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
} else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) {
log.Debug().
Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID).
Str("last_server_message", serverLastMessageID).
Msg("Not backfilling, last message in database is newer than last message in metadata")
return
}
log.Debug().
Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID).
Str("last_server_message", serverLastMessageID).
Msg("Backfilling missed messages")
if limit < 0 {
portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread)
} else {
portal.backfillLimited(log, source, limit, lastMessage.DiscordID)
portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread)
}
}
const messageFetchChunkSize = 50
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) {
var messages []*discordgo.Message
var before string
var foundAll bool
protoChannelID := portal.Key.ChannelID
if thread != nil {
protoChannelID = thread.ID
}
for {
log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "")
if err != nil {
return nil, false, err
}
@@ -123,8 +144,8 @@ func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User,
return messages, foundAll, nil
}
func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) {
messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) {
messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread)
if err != nil {
log.Err(err).Msg("Error collecting messages to forward backfill")
return
@@ -145,13 +166,17 @@ func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit in
log.Debug().Msg("Sent warning about possibly missed messages")
}
}
portal.sendBackfillBatch(log, source, messages)
portal.sendBackfillBatch(log, source, messages, thread)
}
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) {
protoChannelID := portal.Key.ChannelID
if thread != nil {
protoChannelID = thread.ID
}
for {
log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "")
if err != nil {
log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
return
@@ -159,7 +184,7 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
sort.Sort(MessageSlice(messages))
portal.sendBackfillBatch(log, source, messages)
portal.sendBackfillBatch(log, source, messages, thread)
if len(messages) < messageFetchChunkSize {
// Assume that was all the missing messages
@@ -170,20 +195,20 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
}
}
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
portal.forwardBatchSend(log, source, messages)
portal.forwardBatchSend(log, source, messages, thread)
} else {
log.Debug().Msg("Not using hungryserv, sending messages one by one")
for _, msg := range messages {
portal.handleDiscordMessageCreate(source, msg, nil)
portal.handleDiscordMessageCreate(source, msg, thread)
}
}
}
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
evts, dbMessages := portal.convertMessageBatch(log, source, messages)
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
if len(evts) == 0 {
log.Warn().Msg("Didn't get any events to backfill")
return
@@ -204,7 +229,19 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message
log.Info().Msg("Inserted backfilled batch to database")
}
func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) {
func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) {
var discordThreadID string
var threadRootEvent, lastThreadEvent id.EventID
if thread != nil {
discordThreadID = thread.ID
threadRootEvent = thread.RootMXID
lastThreadEvent = threadRootEvent
lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
if lastInThread != nil {
lastThreadEvent = lastInThread.MXID
}
}
evts := make([]*event.Event, 0, len(messages))
dbMessages := make([]database.Message, 0, len(messages))
ctx := context.Background()
@@ -217,7 +254,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(source, msg.Author, msg.WebhookID)
intent := puppet.IntentFor(portal)
replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true)
mentions := portal.convertDiscordMentions(msg, false)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
@@ -228,8 +265,14 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
Logger()
parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg)
for i, part := range parts {
if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
part.Content.RelatesTo = &event.RelatesTo{}
}
if threadRootEvent != "" {
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
}
if replyTo != nil {
part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
// Only set reply for first event
replyTo = nil
}
@@ -270,6 +313,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
AttachmentID: part.AttachmentID,
SenderMXID: intent.UserID,
})
lastThreadEvent = evt.ID
}
}
return evts, dbMessages