Make backfilling code compatible with threads

This doesn't trigger thread backfill yet, but the backfill methods can
handle threads now.
This commit is contained in:
Tulir Asokan
2023-06-18 20:09:21 +03:00
parent 248664f8b0
commit 8ebad277f5
6 changed files with 84 additions and 35 deletions

View File

@@ -17,7 +17,7 @@ import (
"go.mau.fi/mautrix-discord/database" "go.mau.fi/mautrix-discord/database"
) )
func (portal *Portal) forwardBackfillInitial(source *User) { func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
defer portal.forwardBackfillLock.Unlock() defer portal.forwardBackfillLock.Unlock()
// This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room. // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
if portal.forwardBackfillLock.TryLock() { if portal.forwardBackfillLock.TryLock() {
@@ -27,21 +27,27 @@ func (portal *Portal) forwardBackfillInitial(source *User) {
limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel
if portal.GuildID == "" { if portal.GuildID == "" {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM
if thread != nil {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread
}
} }
if limit == 0 { if limit == 0 {
return return
} }
log := portal.log.With(). with := portal.log.With().
Str("action", "initial backfill"). Str("action", "initial backfill").
Str("room_id", portal.MXID.String()). Str("room_id", portal.MXID.String()).
Int("limit", limit). Int("limit", limit)
Logger() if thread != nil {
with = with.Str("thread_id", thread.ID)
}
log := with.Logger()
portal.backfillLimited(log, source, limit, "") portal.backfillLimited(log, source, limit, "", thread)
} }
func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) { func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) {
if portal.MXID == "" { if portal.MXID == "" {
return return
} }
@@ -49,50 +55,65 @@ func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channe
limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
if portal.GuildID == "" { if portal.GuildID == "" {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
if thread != nil {
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread
}
} }
if limit == 0 { if limit == 0 {
return return
} }
log := portal.log.With(). with := portal.log.With().
Str("action", "missed event backfill"). Str("action", "missed event backfill").
Str("room_id", portal.MXID.String()). Str("room_id", portal.MXID.String()).
Int("limit", limit). Int("limit", limit)
Logger() if thread != nil {
with = with.Str("thread_id", thread.ID)
}
log := with.Logger()
portal.forwardBackfillLock.Lock() portal.forwardBackfillLock.Lock()
defer portal.forwardBackfillLock.Unlock() defer portal.forwardBackfillLock.Unlock()
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key) var lastMessage *database.Message
if lastMessage == nil || meta.LastMessageID == "" { if thread != nil {
lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
} else {
lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
}
if lastMessage == nil || serverLastMessageID == "" {
log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata") log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
return return
} else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) { } else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) {
log.Debug(). log.Debug().
Str("last_bridged_message", lastMessage.DiscordID). Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID). Str("last_server_message", serverLastMessageID).
Msg("Not backfilling, last message in database is newer than last message in metadata") Msg("Not backfilling, last message in database is newer than last message in metadata")
return return
} }
log.Debug(). log.Debug().
Str("last_bridged_message", lastMessage.DiscordID). Str("last_bridged_message", lastMessage.DiscordID).
Str("last_server_message", meta.LastMessageID). Str("last_server_message", serverLastMessageID).
Msg("Backfilling missed messages") Msg("Backfilling missed messages")
if limit < 0 { if limit < 0 {
portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID) portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread)
} else { } else {
portal.backfillLimited(log, source, limit, lastMessage.DiscordID) portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread)
} }
} }
const messageFetchChunkSize = 50 const messageFetchChunkSize = 50
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) { func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) {
var messages []*discordgo.Message var messages []*discordgo.Message
var before string var before string
var foundAll bool var foundAll bool
protoChannelID := portal.Key.ChannelID
if thread != nil {
protoChannelID = thread.ID
}
for { for {
log.Debug().Str("before_id", before).Msg("Fetching messages for backfill") log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "") newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "")
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
@@ -123,8 +144,8 @@ func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User,
return messages, foundAll, nil return messages, foundAll, nil
} }
func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) { func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) {
messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after) messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread)
if err != nil { if err != nil {
log.Err(err).Msg("Error collecting messages to forward backfill") log.Err(err).Msg("Error collecting messages to forward backfill")
return return
@@ -145,13 +166,17 @@ func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit in
log.Debug().Msg("Sent warning about possibly missed messages") log.Debug().Msg("Sent warning about possibly missed messages")
} }
} }
portal.sendBackfillBatch(log, source, messages) portal.sendBackfillBatch(log, source, messages, thread)
} }
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) { func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) {
protoChannelID := portal.Key.ChannelID
if thread != nil {
protoChannelID = thread.ID
}
for { for {
log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill") log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "") messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "")
if err != nil { if err != nil {
log.Err(err).Msg("Error fetching chunk of messages to forward backfill") log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
return return
@@ -159,7 +184,7 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill") log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
sort.Sort(MessageSlice(messages)) sort.Sort(MessageSlice(messages))
portal.sendBackfillBatch(log, source, messages) portal.sendBackfillBatch(log, source, messages, thread)
if len(messages) < messageFetchChunkSize { if len(messages) < messageFetchChunkSize {
// Assume that was all the missing messages // Assume that was all the missing messages
@@ -170,20 +195,20 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User,
} }
} }
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) { func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint") log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
portal.forwardBatchSend(log, source, messages) portal.forwardBatchSend(log, source, messages, thread)
} else { } else {
log.Debug().Msg("Not using hungryserv, sending messages one by one") log.Debug().Msg("Not using hungryserv, sending messages one by one")
for _, msg := range messages { for _, msg := range messages {
portal.handleDiscordMessageCreate(source, msg, nil) portal.handleDiscordMessageCreate(source, msg, thread)
} }
} }
} }
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) { func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
evts, dbMessages := portal.convertMessageBatch(log, source, messages) evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
if len(evts) == 0 { if len(evts) == 0 {
log.Warn().Msg("Didn't get any events to backfill") log.Warn().Msg("Didn't get any events to backfill")
return return
@@ -204,7 +229,19 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message
log.Info().Msg("Inserted backfilled batch to database") log.Info().Msg("Inserted backfilled batch to database")
} }
func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) { func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) {
var discordThreadID string
var threadRootEvent, lastThreadEvent id.EventID
if thread != nil {
discordThreadID = thread.ID
threadRootEvent = thread.RootMXID
lastThreadEvent = threadRootEvent
lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
if lastInThread != nil {
lastThreadEvent = lastInThread.MXID
}
}
evts := make([]*event.Event, 0, len(messages)) evts := make([]*event.Event, 0, len(messages))
dbMessages := make([]database.Message, 0, len(messages)) dbMessages := make([]database.Message, 0, len(messages))
ctx := context.Background() ctx := context.Background()
@@ -217,7 +254,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
puppet := portal.bridge.GetPuppetByID(msg.Author.ID) puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(source, msg.Author, msg.WebhookID) puppet.UpdateInfo(source, msg.Author, msg.WebhookID)
intent := puppet.IntentFor(portal) intent := puppet.IntentFor(portal)
replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true) replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true)
mentions := portal.convertDiscordMentions(msg, false) mentions := portal.convertDiscordMentions(msg, false)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID) ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
@@ -228,8 +265,14 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
Logger() Logger()
parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg) parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg)
for i, part := range parts { for i, part := range parts {
if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
part.Content.RelatesTo = &event.RelatesTo{}
}
if threadRootEvent != "" {
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
}
if replyTo != nil { if replyTo != nil {
part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo} part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
// Only set reply for first event // Only set reply for first event
replyTo = nil replyTo = nil
} }
@@ -270,6 +313,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
AttachmentID: part.AttachmentID, AttachmentID: part.AttachmentID,
SenderMXID: intent.UserID, SenderMXID: intent.UserID,
}) })
lastThreadEvent = evt.ID
} }
} }
return evts, dbMessages return evts, dbMessages

View File

@@ -207,6 +207,7 @@ func (mp *MediaPatterns) Avatar(userID, avatarID, ext string) id.ContentURI {
type BackfillLimitPart struct { type BackfillLimitPart struct {
DM int `yaml:"dm"` DM int `yaml:"dm"`
Channel int `yaml:"channel"` Channel int `yaml:"channel"`
Thread int `yaml:"thread"`
} }
func (bc *BridgeConfig) GetResendBridgeInfo() bool { func (bc *BridgeConfig) GetResendBridgeInfo() bool {

View File

@@ -79,8 +79,10 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "backfill", "enabled") helper.Copy(up.Bool, "bridge", "backfill", "enabled")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "dm") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "dm")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "channel") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "channel")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "thread")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "dm") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "dm")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "channel") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "channel")
helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "thread")
helper.Copy(up.Int, "bridge", "backfill", "max_guild_members") helper.Copy(up.Int, "bridge", "backfill", "max_guild_members")
helper.Copy(up.Bool, "bridge", "encryption", "allow") helper.Copy(up.Bool, "bridge", "encryption", "allow")
helper.Copy(up.Bool, "bridge", "encryption", "default") helper.Copy(up.Bool, "bridge", "encryption", "default")

View File

@@ -232,6 +232,7 @@ bridge:
initial: initial:
dm: 0 dm: 0
channel: 0 channel: 0
thread: 0
# Missed message backfill (on startup). # Missed message backfill (on startup).
# 0 means backfill is disabled, -1 means fetch all messages since last bridged message. # 0 means backfill is disabled, -1 means fetch all messages since last bridged message.
# When using unlimited backfill (-1), messages are backfilled as they are fetched. # When using unlimited backfill (-1), messages are backfilled as they are fetched.
@@ -239,6 +240,7 @@ bridge:
missed: missed:
dm: 0 dm: 0
channel: 0 channel: 0
thread: 0
# Maximum members in a guild to enable backfilling. Set to -1 to disable limit. # Maximum members in a guild to enable backfilling. Set to -1 to disable limit.
# This can be used as a rough heuristic to disable backfilling in channels that are too active. # This can be used as a rough heuristic to disable backfilling in channels that are too active.
# Currently only applies to missed message backfill. # Currently only applies to missed message backfill.

View File

@@ -541,7 +541,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) e
portal.Update() portal.Update()
} }
go portal.forwardBackfillInitial(user) go portal.forwardBackfillInitial(user, nil)
backfillStarted = true backfillStarted = true
return nil return nil

View File

@@ -864,7 +864,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel,
} }
} else { } else {
portal.UpdateInfo(user, meta) portal.UpdateInfo(user, meta)
portal.ForwardBackfillMissed(user, meta) portal.ForwardBackfillMissed(user, meta.LastMessageID, nil)
} }
user.MarkInPortal(database.UserPortal{ user.MarkInPortal(database.UserPortal{
DiscordID: portal.Key.ChannelID, DiscordID: portal.Key.ChannelID,
@@ -966,7 +966,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp
} else { } else {
portal.UpdateInfo(user, ch) portal.UpdateInfo(user, ch)
if user.bridge.Config.Bridge.Backfill.MaxGuildMembers < 0 || meta.MemberCount < user.bridge.Config.Bridge.Backfill.MaxGuildMembers { if user.bridge.Config.Bridge.Backfill.MaxGuildMembers < 0 || meta.MemberCount < user.bridge.Config.Bridge.Backfill.MaxGuildMembers {
portal.ForwardBackfillMissed(user, ch) portal.ForwardBackfillMissed(user, ch.LastMessageID, nil)
} }
} }
} }