Improve missed message backfilling
This commit is contained in:
259
backfill.go
Normal file
259
backfill.go
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/bwmarrin/discordgo"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"maunium.net/go/mautrix"
|
||||||
|
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||||
|
"maunium.net/go/mautrix/event"
|
||||||
|
"maunium.net/go/mautrix/id"
|
||||||
|
|
||||||
|
"go.mau.fi/mautrix-discord/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (portal *Portal) ForwardBackfillMissed(source *User, meta *discordgo.Channel) {
|
||||||
|
limit := portal.bridge.Config.Bridge.Backfill.Limits.Missed.Channel
|
||||||
|
if portal.GuildID == "" {
|
||||||
|
limit = portal.bridge.Config.Bridge.Backfill.Limits.Missed.DM
|
||||||
|
}
|
||||||
|
if limit == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log := portal.zlog.With().
|
||||||
|
Str("action", "missed event backfill").
|
||||||
|
Int("limit", limit).
|
||||||
|
Logger()
|
||||||
|
|
||||||
|
portal.forwardBackfillLock.Lock()
|
||||||
|
defer portal.forwardBackfillLock.Unlock()
|
||||||
|
|
||||||
|
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
|
||||||
|
if lastMessage == nil || meta.LastMessageID == "" {
|
||||||
|
log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
|
||||||
|
return
|
||||||
|
} else if !shouldBackfill(lastMessage.DiscordID, meta.LastMessageID) {
|
||||||
|
log.Debug().
|
||||||
|
Str("last_bridged_message", lastMessage.DiscordID).
|
||||||
|
Str("last_server_message", meta.LastMessageID).
|
||||||
|
Msg("Not backfilling, last message in database is newer than last message in metadata")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug().
|
||||||
|
Str("last_bridged_message", lastMessage.DiscordID).
|
||||||
|
Str("last_server_message", meta.LastMessageID).
|
||||||
|
Msg("Backfilling missed messages")
|
||||||
|
if limit < 0 {
|
||||||
|
portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID)
|
||||||
|
} else {
|
||||||
|
portal.backfillLimitedMissed(log, source, limit, lastMessage.DiscordID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageFetchChunkSize = 50
|
||||||
|
|
||||||
|
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string) ([]*discordgo.Message, bool, error) {
|
||||||
|
var messages []*discordgo.Message
|
||||||
|
var before string
|
||||||
|
var foundAll bool
|
||||||
|
for {
|
||||||
|
log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
|
||||||
|
newMessages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, before, "", "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
for i, msg := range newMessages {
|
||||||
|
if compareMessageIDs(msg.ID, until) <= 0 {
|
||||||
|
log.Debug().
|
||||||
|
Str("message_id", msg.ID).
|
||||||
|
Str("until_id", until).
|
||||||
|
Msg("Found message that was already bridged")
|
||||||
|
newMessages = newMessages[:i]
|
||||||
|
foundAll = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, newMessages...)
|
||||||
|
log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
|
||||||
|
if len(newMessages) <= messageFetchChunkSize || len(messages) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
before = newMessages[len(newMessages)-1].ID
|
||||||
|
}
|
||||||
|
if len(messages) > limit {
|
||||||
|
messages = messages[:limit]
|
||||||
|
}
|
||||||
|
return messages, foundAll, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) backfillLimitedMissed(log zerolog.Logger, source *User, limit int, after string) {
|
||||||
|
messages, foundAll, err := portal.collectBackfillMessages(log, source, limit, after)
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Msg("Error collecting messages to forward backfill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info().
|
||||||
|
Int("count", len(messages)).
|
||||||
|
Bool("found_all", foundAll).
|
||||||
|
Msg("Collected messages to backfill")
|
||||||
|
sort.Sort(MessageSlice(messages))
|
||||||
|
if !foundAll {
|
||||||
|
_, err = portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
|
||||||
|
MsgType: event.MsgNotice,
|
||||||
|
Body: "Some messages may have been missed here while the bridge was offline.",
|
||||||
|
}, nil, 0)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("Failed to send missed message warning")
|
||||||
|
} else {
|
||||||
|
log.Debug().Msg("Sent warning about possibly missed messages")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
portal.sendBackfillBatch(log, source, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string) {
|
||||||
|
for {
|
||||||
|
log.Debug().Str("after_id", after).Msg("Fetching chunk of messages to backfill")
|
||||||
|
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, messageFetchChunkSize, "", after, "")
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debug().Int("count", len(messages)).Msg("Fetched chunk of messages to backfill")
|
||||||
|
sort.Sort(MessageSlice(messages))
|
||||||
|
|
||||||
|
portal.sendBackfillBatch(log, source, messages)
|
||||||
|
|
||||||
|
if len(messages) < messageFetchChunkSize {
|
||||||
|
// Assume that was all the missing messages
|
||||||
|
log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
after = messages[len(messages)-1].ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message) {
|
||||||
|
if portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry {
|
||||||
|
log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
|
||||||
|
portal.forwardBatchSend(log, source, messages)
|
||||||
|
} else {
|
||||||
|
log.Debug().Msg("Not using hungryserv, sending messages one by one")
|
||||||
|
for _, msg := range messages {
|
||||||
|
portal.handleDiscordMessageCreate(source, msg, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message) {
|
||||||
|
evts := make([]*event.Event, 0, len(messages))
|
||||||
|
dbMessages := make([]database.Message, 0, len(messages))
|
||||||
|
for _, msg := range messages {
|
||||||
|
for _, mention := range msg.Mentions {
|
||||||
|
puppet := portal.bridge.GetPuppetByID(mention.ID)
|
||||||
|
puppet.UpdateInfo(nil, mention)
|
||||||
|
}
|
||||||
|
|
||||||
|
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
|
||||||
|
puppet.UpdateInfo(source, msg.Author)
|
||||||
|
intent := puppet.IntentFor(portal)
|
||||||
|
replyTo := portal.getReplyTarget(source, msg.MessageReference, true)
|
||||||
|
|
||||||
|
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
|
||||||
|
parts := portal.convertDiscordMessage(intent, msg)
|
||||||
|
for i, part := range parts {
|
||||||
|
if replyTo != nil {
|
||||||
|
part.Content.RelatesTo = &event.RelatesTo{InReplyTo: replyTo}
|
||||||
|
// Only set reply for first event
|
||||||
|
replyTo = nil
|
||||||
|
}
|
||||||
|
partName := part.AttachmentID
|
||||||
|
// Always use blank part name for first part so that replies and other things
|
||||||
|
// can reference it without knowing about attachments.
|
||||||
|
if i == 0 {
|
||||||
|
partName = ""
|
||||||
|
}
|
||||||
|
evts = append(evts, &event.Event{
|
||||||
|
ID: portal.deterministicEventID(msg.ID, partName),
|
||||||
|
Type: part.Type,
|
||||||
|
Sender: intent.UserID,
|
||||||
|
Timestamp: ts.UnixMilli(),
|
||||||
|
Content: event.Content{
|
||||||
|
Parsed: part.Content,
|
||||||
|
Raw: part.Extra,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
dbMessages = append(dbMessages, database.Message{
|
||||||
|
Channel: portal.Key,
|
||||||
|
DiscordID: msg.ID,
|
||||||
|
SenderID: msg.Author.ID,
|
||||||
|
Timestamp: ts,
|
||||||
|
AttachmentID: part.AttachmentID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Info().Int("parts", len(evts)).Msg("Converted messages to backfill")
|
||||||
|
resp, err := portal.MainIntent().BatchSend(portal.MXID, &mautrix.ReqBatchSend{
|
||||||
|
BeeperNewMessages: true,
|
||||||
|
Events: evts,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Err(err).Msg("Error sending backfill batch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, evtID := range resp.EventIDs {
|
||||||
|
dbMessages[i].MXID = evtID
|
||||||
|
}
|
||||||
|
portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
|
||||||
|
log.Info().Msg("Inserted backfilled batch to database")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
|
||||||
|
data := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
|
||||||
|
sum := sha256.Sum256([]byte(data))
|
||||||
|
return id.EventID(fmt.Sprintf("$%s:discord.com", base64.RawURLEncoding.EncodeToString(sum[:])))
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareMessageIDs compares two Discord message IDs.
|
||||||
|
//
|
||||||
|
// If the first ID is lower, -1 is returned.
|
||||||
|
// If the second ID is lower, 1 is returned.
|
||||||
|
// If the IDs are equal, 0 is returned.
|
||||||
|
func compareMessageIDs(id1, id2 string) int {
|
||||||
|
if id1 == id2 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(id1) < len(id2) {
|
||||||
|
return -1
|
||||||
|
} else if len(id2) < len(id1) {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
if id1 < id2 {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
|
||||||
|
return compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr) == -1
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageSlice []*discordgo.Message
|
||||||
|
|
||||||
|
var _ sort.Interface = (MessageSlice)(nil)
|
||||||
|
|
||||||
|
func (a MessageSlice) Len() int {
|
||||||
|
return len(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a MessageSlice) Swap(i, j int) {
|
||||||
|
a[i], a[j] = a[j], a[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a MessageSlice) Less(i, j int) bool {
|
||||||
|
return compareMessageIDs(a[i].ID, a[j].ID) == -1
|
||||||
|
}
|
||||||
@@ -66,6 +66,14 @@ type BridgeConfig struct {
|
|||||||
CommandPrefix string `yaml:"command_prefix"`
|
CommandPrefix string `yaml:"command_prefix"`
|
||||||
ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"`
|
ManagementRoomText bridgeconfig.ManagementRoomTexts `yaml:"management_room_text"`
|
||||||
|
|
||||||
|
Backfill struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
Limits struct {
|
||||||
|
Initial BackfillLimitPart `yaml:"initial"`
|
||||||
|
Missed BackfillLimitPart `yaml:"missed"`
|
||||||
|
} `yaml:"limits"`
|
||||||
|
} `yaml:"backfill"`
|
||||||
|
|
||||||
Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
|
Encryption bridgeconfig.EncryptionConfig `yaml:"encryption"`
|
||||||
|
|
||||||
Provisioning struct {
|
Provisioning struct {
|
||||||
@@ -81,6 +89,11 @@ type BridgeConfig struct {
|
|||||||
guildNameTemplate *template.Template `yaml:"-"`
|
guildNameTemplate *template.Template `yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BackfillLimitPart struct {
|
||||||
|
DM int `yaml:"dm"`
|
||||||
|
Channel int `yaml:"channel"`
|
||||||
|
}
|
||||||
|
|
||||||
func (bc *BridgeConfig) GetResendBridgeInfo() bool {
|
func (bc *BridgeConfig) GetResendBridgeInfo() bool {
|
||||||
return bc.ResendBridgeInfo
|
return bc.ResendBridgeInfo
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ func DoUpgrade(helper *up.Helper) {
|
|||||||
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected")
|
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_connected")
|
||||||
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected")
|
helper.Copy(up.Str, "bridge", "management_room_text", "welcome_unconnected")
|
||||||
helper.Copy(up.Str|up.Null, "bridge", "management_room_text", "additional_help")
|
helper.Copy(up.Str|up.Null, "bridge", "management_room_text", "additional_help")
|
||||||
|
helper.Copy(up.Bool, "bridge", "backfill", "enabled")
|
||||||
|
helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "dm")
|
||||||
|
helper.Copy(up.Int, "bridge", "backfill", "limits", "initial", "channel")
|
||||||
|
helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "dm")
|
||||||
|
helper.Copy(up.Int, "bridge", "backfill", "limits", "missed", "channel")
|
||||||
helper.Copy(up.Bool, "bridge", "encryption", "allow")
|
helper.Copy(up.Bool, "bridge", "encryption", "allow")
|
||||||
helper.Copy(up.Bool, "bridge", "encryption", "default")
|
helper.Copy(up.Bool, "bridge", "encryption", "default")
|
||||||
helper.Copy(up.Bool, "bridge", "encryption", "require")
|
helper.Copy(up.Bool, "bridge", "encryption", "require")
|
||||||
|
|||||||
@@ -95,6 +95,36 @@ func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
|
|||||||
return mq.New().Scan(row)
|
return mq.New().Scan(row)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
|
||||||
|
if mq.db.Dialect == dbutil.SQLite {
|
||||||
|
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
|
||||||
|
}
|
||||||
|
params := make([]interface{}, 2+len(msgs)*7)
|
||||||
|
placeholders := make([]string, len(msgs))
|
||||||
|
params[0] = key.ChannelID
|
||||||
|
params[1] = key.Receiver
|
||||||
|
for i, msg := range msgs {
|
||||||
|
baseIndex := 2 + i*7
|
||||||
|
params[baseIndex] = msg.DiscordID
|
||||||
|
params[baseIndex+1] = msg.AttachmentID
|
||||||
|
params[baseIndex+2] = msg.EditIndex
|
||||||
|
params[baseIndex+3] = msg.SenderID
|
||||||
|
params[baseIndex+4] = msg.Timestamp
|
||||||
|
params[baseIndex+5] = msg.ThreadID
|
||||||
|
params[baseIndex+6] = msg.MXID
|
||||||
|
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
|
||||||
|
}
|
||||||
|
_, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
|
||||||
|
if err != nil {
|
||||||
|
mq.log.Warnfln("Failed to insert %d messages: %v", len(msgs), err)
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
db *Database
|
db *Database
|
||||||
log log.Logger
|
log log.Logger
|
||||||
@@ -152,7 +182,7 @@ type MessagePart struct {
|
|||||||
MXID id.EventID
|
MXID id.EventID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) MassInsert(msgs []MessagePart) {
|
func (m *Message) MassInsertParts(msgs []MessagePart) {
|
||||||
if len(msgs) == 0 {
|
if len(msgs) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -186,6 +186,20 @@ bridge:
|
|||||||
# Optional extra text sent when joining a management room.
|
# Optional extra text sent when joining a management room.
|
||||||
additional_help: ""
|
additional_help: ""
|
||||||
|
|
||||||
|
backfill:
|
||||||
|
# Should backfill be enabled at all?
|
||||||
|
enabled: false
|
||||||
|
# Limits for backfilling. Set to 0 to disable that type of backfill, or -1 for unlimited.
|
||||||
|
limits:
|
||||||
|
# Initial backfill (when creating portal).
|
||||||
|
initial:
|
||||||
|
dm: 50
|
||||||
|
channel: 0
|
||||||
|
# Missed message backfill (on startup).
|
||||||
|
missed:
|
||||||
|
dm: 50
|
||||||
|
channel: 0
|
||||||
|
|
||||||
# End-to-bridge encryption support options.
|
# End-to-bridge encryption support options.
|
||||||
#
|
#
|
||||||
# See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info.
|
# See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info.
|
||||||
|
|||||||
136
portal.go
136
portal.go
@@ -5,7 +5,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -13,8 +12,10 @@ import (
|
|||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
"github.com/gabriel-vasile/mimetype"
|
"github.com/gabriel-vasile/mimetype"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
|
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||||
|
|
||||||
log "maunium.net/go/maulogger/v2"
|
"maunium.net/go/maulogger/v2"
|
||||||
|
|
||||||
"maunium.net/go/mautrix"
|
"maunium.net/go/mautrix"
|
||||||
"maunium.net/go/mautrix/appservice"
|
"maunium.net/go/mautrix/appservice"
|
||||||
@@ -52,7 +53,9 @@ type Portal struct {
|
|||||||
Guild *Guild
|
Guild *Guild
|
||||||
|
|
||||||
bridge *DiscordBridge
|
bridge *DiscordBridge
|
||||||
log log.Logger
|
// Deprecated
|
||||||
|
log maulogger.Logger
|
||||||
|
zlog zerolog.Logger
|
||||||
|
|
||||||
roomCreateLock sync.Mutex
|
roomCreateLock sync.Mutex
|
||||||
encryptLock sync.Mutex
|
encryptLock sync.Mutex
|
||||||
@@ -65,6 +68,8 @@ type Portal struct {
|
|||||||
commands map[string]*discordgo.ApplicationCommand
|
commands map[string]*discordgo.ApplicationCommand
|
||||||
commandsLock sync.RWMutex
|
commandsLock sync.RWMutex
|
||||||
|
|
||||||
|
forwardBackfillLock sync.Mutex
|
||||||
|
|
||||||
currentlyTyping []id.UserID
|
currentlyTyping []id.UserID
|
||||||
currentlyTypingLock sync.Mutex
|
currentlyTypingLock sync.Mutex
|
||||||
}
|
}
|
||||||
@@ -233,7 +238,10 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal {
|
|||||||
portal := &Portal{
|
portal := &Portal{
|
||||||
Portal: dbPortal,
|
Portal: dbPortal,
|
||||||
bridge: br,
|
bridge: br,
|
||||||
log: br.Log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
|
zlog: br.ZLog.With().
|
||||||
|
Str("channel_id", dbPortal.Key.ChannelID).
|
||||||
|
Str("channel_receiver", dbPortal.Key.Receiver).
|
||||||
|
Logger(),
|
||||||
|
|
||||||
discordMessages: make(chan portalDiscordMessage, br.Config.Bridge.PortalMessageBuffer),
|
discordMessages: make(chan portalDiscordMessage, br.Config.Bridge.PortalMessageBuffer),
|
||||||
matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
|
matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer),
|
||||||
@@ -242,6 +250,7 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal {
|
|||||||
|
|
||||||
commands: make(map[string]*discordgo.ApplicationCommand),
|
commands: make(map[string]*discordgo.ApplicationCommand),
|
||||||
}
|
}
|
||||||
|
portal.log = maulogadapt.ZeroAsMau(&portal.zlog)
|
||||||
|
|
||||||
go portal.messageLoop()
|
go portal.messageLoop()
|
||||||
|
|
||||||
@@ -252,10 +261,13 @@ func (portal *Portal) messageLoop() {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-portal.matrixMessages:
|
case msg := <-portal.matrixMessages:
|
||||||
|
portal.forwardBackfillLock.Lock()
|
||||||
portal.handleMatrixMessages(msg)
|
portal.handleMatrixMessages(msg)
|
||||||
case msg := <-portal.discordMessages:
|
case msg := <-portal.discordMessages:
|
||||||
|
portal.forwardBackfillLock.Lock()
|
||||||
portal.handleDiscordMessages(msg)
|
portal.handleDiscordMessages(msg)
|
||||||
}
|
}
|
||||||
|
portal.forwardBackfillLock.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -549,7 +561,7 @@ func (portal *Portal) markMessageHandled(discordID string, editIndex int, author
|
|||||||
msg.SenderID = authorID
|
msg.SenderID = authorID
|
||||||
msg.Timestamp = timestamp
|
msg.Timestamp = timestamp
|
||||||
msg.ThreadID = threadID
|
msg.ThreadID = threadID
|
||||||
msg.MassInsert(parts)
|
msg.MassInsertParts(parts)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) {
|
func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) {
|
||||||
@@ -578,7 +590,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
|
|||||||
intent := puppet.IntentFor(portal)
|
intent := puppet.IntentFor(portal)
|
||||||
|
|
||||||
var discordThreadID string
|
var discordThreadID string
|
||||||
var threadRootEvent, lastThreadEvent, replyToEvent id.EventID
|
var threadRootEvent, lastThreadEvent id.EventID
|
||||||
if thread != nil {
|
if thread != nil {
|
||||||
discordThreadID = thread.ID
|
discordThreadID = thread.ID
|
||||||
threadRootEvent = thread.RootMXID
|
threadRootEvent = thread.RootMXID
|
||||||
@@ -588,30 +600,25 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
|
|||||||
lastThreadEvent = lastInThread.MXID
|
lastThreadEvent = lastInThread.MXID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
replyTo := portal.getReplyTarget(user, msg.MessageReference, false)
|
||||||
if msg.MessageReference != nil {
|
|
||||||
// This could be used to find cross-channel replies, but Matrix doesn't support those currently.
|
|
||||||
//key := database.PortalKey{msg.MessageReference.ChannelID, user.ID}
|
|
||||||
replyToMsg := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID)
|
|
||||||
if len(replyToMsg) > 0 {
|
|
||||||
replyToEvent = replyToMsg[0].MXID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
|
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
|
||||||
parts := portal.convertDiscordMessage(intent, msg)
|
parts := portal.convertDiscordMessage(intent, msg)
|
||||||
dbParts := make([]database.MessagePart, 0, len(parts))
|
dbParts := make([]database.MessagePart, 0, len(parts))
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
if (replyToEvent != "" || threadRootEvent != "") && part.Content.RelatesTo == nil {
|
if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
|
||||||
part.Content.RelatesTo = &event.RelatesTo{}
|
part.Content.RelatesTo = &event.RelatesTo{}
|
||||||
}
|
}
|
||||||
if threadRootEvent != "" {
|
if threadRootEvent != "" {
|
||||||
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
|
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
|
||||||
}
|
}
|
||||||
if replyToEvent != "" {
|
if replyTo != nil {
|
||||||
part.Content.RelatesTo.SetReplyTo(replyToEvent)
|
part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
|
||||||
|
if replyTo.UnstableRoomID != "" {
|
||||||
|
part.Content.RelatesTo.InReplyTo.UnstableRoomID = replyTo.UnstableRoomID
|
||||||
|
}
|
||||||
// Only set reply for first event
|
// Only set reply for first event
|
||||||
replyToEvent = ""
|
replyTo = nil
|
||||||
}
|
}
|
||||||
resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli())
|
resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -630,6 +637,42 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (portal *Portal) getReplyTarget(source *User, ref *discordgo.MessageReference, allowNonExistent bool) *event.InReplyTo {
|
||||||
|
if ref == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry
|
||||||
|
if !isHungry {
|
||||||
|
allowNonExistent = false
|
||||||
|
}
|
||||||
|
// TODO add config option for cross-room replies
|
||||||
|
crossRoomReplies := isHungry
|
||||||
|
|
||||||
|
targetPortal := portal
|
||||||
|
if ref.ChannelID != portal.Key.ChannelID && crossRoomReplies {
|
||||||
|
targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID})
|
||||||
|
if targetPortal == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID)
|
||||||
|
if len(replyToMsg) > 0 {
|
||||||
|
if !crossRoomReplies {
|
||||||
|
return &event.InReplyTo{EventID: replyToMsg[0].MXID}
|
||||||
|
}
|
||||||
|
return &event.InReplyTo{
|
||||||
|
EventID: replyToMsg[0].MXID,
|
||||||
|
UnstableRoomID: targetPortal.MXID,
|
||||||
|
}
|
||||||
|
} else if allowNonExistent {
|
||||||
|
return &event.InReplyTo{
|
||||||
|
EventID: targetPortal.deterministicEventID(ref.MessageID, ""),
|
||||||
|
UnstableRoomID: targetPortal.MXID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
const JoinThreadReaction = "join thread"
|
const JoinThreadReaction = "join thread"
|
||||||
|
|
||||||
func (portal *Portal) sendThreadCreationNotice(thread *Thread) {
|
func (portal *Portal) sendThreadCreationNotice(thread *Thread) {
|
||||||
@@ -1066,9 +1109,9 @@ func (portal *Portal) sendMessageMetrics(evt *event.Event, err error, part strin
|
|||||||
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
|
evtDescription += fmt.Sprintf(" of %s", evt.Redacts)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
level := log.LevelError
|
level := maulogger.LevelError
|
||||||
if part == "Ignoring" {
|
if part == "Ignoring" {
|
||||||
level = log.LevelDebug
|
level = maulogger.LevelDebug
|
||||||
}
|
}
|
||||||
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
|
portal.log.Logfln(level, "%s %s %s from %s: %v", part, msgType, evtDescription, evt.Sender, err)
|
||||||
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
|
reason, statusCode, isCertain, sendNotice, _ := errorToStatusReason(err)
|
||||||
@@ -1376,7 +1419,7 @@ func (portal *Portal) cleanup(puppetsOnly bool) {
|
|||||||
portal.bridge.cleanupRoom(intent, portal.MXID, puppetsOnly, portal.log)
|
portal.bridge.cleanupRoom(intent, portal.MXID, puppetsOnly, portal.log)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log log.Logger) {
|
func (br *DiscordBridge) cleanupRoom(intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool, log maulogger.Logger) {
|
||||||
members, err := intent.JoinedMembers(mxid)
|
members, err := intent.JoinedMembers(mxid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorln("Failed to get portal members for cleanup:", err)
|
log.Errorln("Failed to get portal members for cleanup:", err)
|
||||||
@@ -1975,52 +2018,3 @@ func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discord
|
|||||||
}
|
}
|
||||||
return meta
|
return meta
|
||||||
}
|
}
|
||||||
|
|
||||||
func (portal *Portal) ForwardBackfill(source *User, meta *discordgo.Channel) error {
|
|
||||||
portal.log.Debugln("Checking for missing messages to fill")
|
|
||||||
lastMessage := portal.bridge.DB.Message.GetLast(portal.Key)
|
|
||||||
if lastMessage == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
metaLastMessageID, err := strconv.ParseInt(meta.LastMessageID, 10, 0)
|
|
||||||
if err != nil {
|
|
||||||
portal.log.Errorfln("Last message ID %s isn't integer", meta.LastMessageID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dbLastMessageID, err := strconv.ParseInt(lastMessage.DiscordID, 10, 0)
|
|
||||||
if err != nil {
|
|
||||||
portal.log.Errorfln("Last message ID %s isn't integer", lastMessage.DiscordID)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if metaLastMessageID <= dbLastMessageID {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get up to 50 messages at a time until everything is fetched
|
|
||||||
for {
|
|
||||||
messages, err := source.Session.ChannelMessages(portal.Key.ChannelID, 50, "", lastMessage.DiscordID, "")
|
|
||||||
if err != nil {
|
|
||||||
portal.log.Debugln("Error getting messages to forward backfill", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Discord seems to return messages in reverse order,
|
|
||||||
// but no specific order is guaranteed by their API docs?
|
|
||||||
sort.Slice(messages, func(i, j int) bool {
|
|
||||||
return messages[i].Timestamp.Before(messages[j].Timestamp)
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, msg := range messages {
|
|
||||||
portal.handleDiscordMessageCreate(source, msg, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(messages) < 100 {
|
|
||||||
// Assume that was all the missing messages
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
|
|
||||||
if lastMessage == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
4
user.go
4
user.go
@@ -728,7 +728,7 @@ func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
portal.UpdateInfo(user, meta)
|
portal.UpdateInfo(user, meta)
|
||||||
portal.ForwardBackfill(user, meta)
|
portal.ForwardBackfillMissed(user, meta)
|
||||||
}
|
}
|
||||||
user.MarkInPortal(database.UserPortal{
|
user.MarkInPortal(database.UserPortal{
|
||||||
DiscordID: portal.Key.ChannelID,
|
DiscordID: portal.Key.ChannelID,
|
||||||
@@ -843,7 +843,7 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
portal.UpdateInfo(user, ch)
|
portal.UpdateInfo(user, ch)
|
||||||
portal.ForwardBackfill(user, ch)
|
portal.ForwardBackfillMissed(user, ch)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user