diff --git a/backfill.go b/backfill.go index 030d8ef..75e5efc 100644 --- a/backfill.go +++ b/backfill.go @@ -17,7 +17,7 @@ import ( "go.mau.fi/mautrix-discord/database" ) -func (portal *Portal) forwardBackfillInitial(source *User) { +func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) { defer portal.forwardBackfillLock.Unlock() // This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room. if portal.forwardBackfillLock.TryLock() { @@ -27,21 +27,27 @@ func (portal *Portal) forwardBackfillInitial(source *User) { limit := portal.bridge.Config.Bridge.Backfill.Limits.Initial.Channel if portal.GuildID == "" { limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM + if thread != nil { + limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread + } } if limit == 0 { return } - log := portal.log.With(). + with := portal.log.With(). Str("action", "initial backfill"). Str("room_id", portal.MXID.String()). - Int("limit", limit). - Logger() + Int("limit", limit) + if thread != nil { + with = with.Str("thread_id", thread.ID) + } + log := with.Logger() - portal.backfillLimited(log, source, limit, "") + portal.backfillLimited(log, source, limit, "", thread) } -func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) { +func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) { if portal.MXID == "" { return } @@ -49,50 +55,65 @@ func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channe limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel if portal.GuildID == "" { limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM + if thread != nil { + limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.Thread + } } if limit == 0 { return } - log := portal.log.With(). + with := portal.log.With(). Str("action", "missed event backfill"). Str("room_id", portal.MXID.String()). - Int("limit", limit). - Logger() + Int("limit", limit) + if thread != nil { + with = with.Str("thread_id", thread.ID) + } + log := with.Logger() portal.forwardBackfillLock.Lock() defer portal.forwardBackfillLock.Unlock() - lastMessage := portal.bridge.DB.Message.GetLast(portal.Key) - if lastMessage == nil || meta.LastMessageID == "" { + var lastMessage *database.Message + if thread != nil { + lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) + } else { + lastMessage = portal.bridge.DB.Message.GetLast(portal.Key) + } + if lastMessage == nil || serverLastMessageID == "" { log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata") return - } else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) { + } else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) { log.Debug(). Str("last_bridged_message", lastMessage.DiscordID). - Str("last_server_message", meta.LastMessageID). + Str("last_server_message", serverLastMessageID). Msg("Not backfilling, last message in database is newer than last message in metadata") return } log.Debug(). Str("last_bridged_message", lastMessage.DiscordID). - Str("last_server_message", meta.LastMessageID). + Str("last_server_message", serverLastMessageID). Msg("Backfilling missed messages") if limit < 0 { - portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID) + portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread) } else { - portal.backfillLimited(log, source, limit, lastMessage.DiscordID) + portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread) } } const messageFetchChunkSize = 50 -func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) { +func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) { var messages []*discordgo.Message var before string var foundAll bool + protoChannelID := portal.Key.ChannelID + if thread != nil { + protoChannelID = thread.ID + } for { log.Debug().Str("before_id", before).Msg("Fetching messages for backfill") - newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "") + newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "") if err != nil { return nil, false, err } @@ -123,8 +144,8 @@ func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, return messages, foundAll, nil } -func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string) { - messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after) +func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) { + messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after, thread) if err != nil { log.Err(err).Msg("Error collecting messages to forward backfill") return @@ -145,13 +166,17 @@ func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit in log.Debug().Msg("Sent warning about possibly missed messages") } } - portal.sendBackfillBatch(log, source, messages) + portal.sendBackfillBatch(log, source, messages, thread) } -func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) { +func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) { + protoChannelID := portal.Key.ChannelID + if thread != nil { + protoChannelID = thread.ID + } for { log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill") - messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "") + messages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, "", after, "") if err != nil { log.Err(err).Msg("Error fetching chunk of messages to forward backfill") return @@ -159,7 +184,7 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill") sort.Sort(MessageSlice(messages)) - portal.sendBackfillBatch(log, source, messages) + portal.sendBackfillBatch(log, source, messages, thread) if len(messages) < messageFetchChunkSize { // Assume that was all the missing messages @@ -170,20 +195,20 @@ func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, } } -func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) { +func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) { if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint") - portal.forwardBatchSend(log, source, messages) + portal.forwardBatchSend(log, source, messages, thread) } else { log.Debug().Msg("Not using hungryserv, sending messages one by one") for _, msg := range messages { - portal.handleDiscordMessageCreate(source, msg, nil) + portal.handleDiscordMessageCreate(source, msg, thread) } } } -func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) { - evts, dbMessages := portal.convertMessageBatch(log, source, messages) +func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) { + evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread) if len(evts) == 0 { log.Warn().Msg("Didn't get any events to backfill") return @@ -204,7 +229,19 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message log.Info().Msg("Inserted backfilled batch to database") } -func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) ([]*event.Event, []database.Message) { +func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) { + var discordThreadID string + var threadRootEvent, lastThreadEvent id.EventID + if thread != nil { + discordThreadID = thread.ID + threadRootEvent = thread.RootMXID + lastThreadEvent = threadRootEvent + lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) + if lastInThread != nil { + lastThreadEvent = lastInThread.MXID + } + } + evts := make([]*event.Event, 0, len(messages)) dbMessages := make([]database.Message, 0, len(messages)) ctx := context.Background() @@ -217,7 +254,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess puppet := portal.bridge.GetPuppetByID(msg.Author.ID) puppet.UpdateInfo(source, msg.Author, msg.WebhookID) intent := puppet.IntentFor(portal) - replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true) + replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true) mentions := portal.convertDiscordMentions(msg, false) ts, _ := discordgo.SnowflakeTimestamp(msg.ID) @@ -228,8 +265,14 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess Logger() parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg) for i, part := range parts { + if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil { + part.Content.RelatesTo = &event.RelatesTo{} + } + if threadRootEvent != "" { + part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent) + } if replyTo != nil { - part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo} + part.Content.RelatesTo.SetReplyTo(replyTo.EventID) // Only set reply for first event replyTo = nil } @@ -270,6 +313,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess AttachmentID: part.AttachmentID, SenderMXID: intent.UserID, }) + lastThreadEvent = evt.ID } } return evts, dbMessages diff --git a/config/bridge.go b/config/bridge.go index 9118317..30f69e2 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -207,6 +207,7 @@ func (mp *MediaPatterns) Avatar(userID, avatarID, ext string) id.ContentURI { type BackfillLimitPart struct { DM int `yaml:"dm"` Channel int `yaml:"channel"` + Thread int `yaml:"thread"` } func (bc *BridgeConfig) GetResendBridgeInfo() bool { diff --git a/config/upgrade.go b/config/upgrade.go index 6dc713d..6367280 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -79,8 +79,10 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "backfill", "enabled") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "dm") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "channel") + helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "initial", "thread") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "dm") helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "channel") + helper.Copy(up.Int, "bridge", "backfill", "forward_limits", "missed", "thread") helper.Copy(up.Int, "bridge", "backfill", "max_guild_members") helper.Copy(up.Bool, "bridge", "encryption", "allow") helper.Copy(up.Bool, "bridge", "encryption", "default") diff --git a/example-config.yaml b/example-config.yaml index d2e800d..8c807e9 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -232,6 +232,7 @@ bridge: initial: dm: 0 channel: 0 + thread: 0 # Missed message backfill (on startup). # 0 means backfill is disabled, -1 means fetch all messages since last bridged message. # When using unlimited backfill (-1), messages are backfilled as they are fetched. @@ -239,6 +240,7 @@ bridge: missed: dm: 0 channel: 0 + thread: 0 # Maximum members in a guild to enable backfilling. Set to -1 to disable limit. # This can be used as a rough heuristic to disable backfilling in channels that are too active. # Currently only applies to missed message backfill. diff --git a/portal.go b/portal.go index 4bd1ab4..ac6ea7b 100644 --- a/portal.go +++ b/portal.go @@ -541,7 +541,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) e portal.Update() } - go portal.forwardBackfillInitial(user) + go portal.forwardBackfillInitial(user, nil) backfillStarted = true return nil diff --git a/user.go b/user.go index 1e5aaf2..dd75978 100644 --- a/user.go +++ b/user.go @@ -864,7 +864,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel, } } else { portal.UpdateInfo(user, meta) - portal.ForwardBackfillMissed(user, meta) + portal.ForwardBackfillMissed(user, meta.LastMessageID, nil) } user.MarkInPortal(database.UserPortal{ DiscordID: portal.Key.ChannelID, @@ -966,7 +966,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp } else { portal.UpdateInfo(user, ch) if user.bridge.Config.Bridge.Backfill.MaxGuildMembers < 0 || meta.MemberCount < user.bridge.Config.Bridge.Backfill.MaxGuildMembers { - portal.ForwardBackfillMissed(user, ch) + portal.ForwardBackfillMissed(user, ch.LastMessageID, nil) } } }