From b77eea45869c85124403cf398581c7f5f5fa3d39 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 18 Jun 2023 20:49:27 +0300 Subject: [PATCH] Create threads for backfilled messages --- backfill.go | 24 +++++++++++++++++++++--- portal.go | 16 +++++++++++++--- portal_convert.go | 6 ++++++ 3 files changed, 40 insertions(+), 6 deletions(-) diff --git a/backfill.go b/backfill.go index 75e5efc..20e254d 100644 --- a/backfill.go +++ b/backfill.go @@ -208,7 +208,7 @@ func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messag } func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) { - evts, dbMessages := portal.convertMessageBatch(log, source, messages, thread) + evts, metas, dbMessages := portal.convertMessageBatch(log, source, messages, thread) if len(evts) == 0 { log.Warn().Msg("Didn't get any events to backfill") return @@ -224,12 +224,24 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message } for i, evtID := range resp.EventIDs { dbMessages[i].MXID = evtID + if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread { + thread = portal.bridge.GetThreadByID(metas[i].ID, &dbMessages[i]) + log.Debug(). + Str("message_id", metas[i].ID). + Str("event_id", evtID.String()). + Msg("Marked backfilled message as thread root") + if thread.CreationNoticeMXID == "" { + // TODO proper context + ctx := log.WithContext(context.Background()) + portal.sendThreadCreationNotice(ctx, thread) + } + } } portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages) log.Info().Msg("Inserted backfilled batch to database") } -func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []database.Message) { +func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []*discordgo.Message, []database.Message) { var discordThreadID string var threadRootEvent, lastThreadEvent id.EventID if thread != nil { @@ -244,6 +256,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess evts := make([]*event.Event, 0, len(messages)) dbMessages := make([]database.Message, 0, len(messages)) + metas := make([]*discordgo.Message, 0, len(messages)) ctx := context.Background() for _, msg := range messages { for _, mention := range msg.Mentions { @@ -313,10 +326,15 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess AttachmentID: part.AttachmentID, SenderMXID: intent.UserID, }) + if i == 0 { + metas = append(metas, msg) + } else { + metas = append(metas, nil) + } lastThreadEvent = evt.ID } } - return evts, dbMessages + return evts, metas, dbMessages } func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID { diff --git a/portal.go b/portal.go index ac6ea7b..f3dd341 100644 --- a/portal.go +++ b/portal.go @@ -586,7 +586,7 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache) } -func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) { +func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) *database.Message { msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key msg.DiscordID = discordID @@ -595,6 +595,9 @@ func (portal *Portal) markMessageHandled(discordID string, authorID string, time msg.ThreadID = threadID msg.SenderMXID = senderMXID msg.MassInsertParts(parts) + msg.MXID = parts[0].MXID + msg.AttachmentID = parts[0].AttachmentID + return msg } func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { @@ -678,7 +681,14 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } else if len(dbParts) == 0 { log.Warn().Msg("All parts of message failed to send to Matrix") } else { - portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts) + firstDBMessage := portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts) + if msg.Flags == discordgo.MessageFlagsHasThread { + thread = portal.bridge.GetThreadByID(msg.ID, firstDBMessage) + log.Debug().Msg("Marked message as thread root") + if thread.CreationNoticeMXID == "" { + portal.sendThreadCreationNotice(ctx, thread) + } + } } } @@ -1463,7 +1473,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { } return } else if threadRoot := content.GetRelatesTo().GetThreadParent(); threadRoot != "" { - existingThread := portal.bridge.DB.Thread.GetByMatrixRootMsg(threadRoot) + existingThread := portal.bridge.GetThreadByRootMXID(threadRoot) if existingThread != nil { threadID = existingThread.ID } else { diff --git a/portal_convert.go b/portal_convert.go index 3c72842..10538e5 100644 --- a/portal_convert.go +++ b/portal_convert.go @@ -308,6 +308,12 @@ func (portal *Portal) convertDiscordMessage(ctx context.Context, puppet *Puppet, parts = append(parts, part) } } + if len(parts) == 0 && msg.Thread != nil { + parts = append(parts, &ConvertedMessage{Type: event.EventMessage, Content: &event.MessageEventContent{ + MsgType: event.MsgText, + Body: fmt.Sprintf("Created a thread: %s", msg.Thread.Name), + }}) + } for _, part := range parts { puppet.addWebhookMeta(part, msg) puppet.addMemberMeta(part, msg)