diff --git a/commands.go b/commands.go index 42132c8..8c08807 100644 --- a/commands.go +++ b/commands.go @@ -23,6 +23,7 @@ import ( "github.com/skip2/go-qrcode" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -45,6 +46,7 @@ func (br *DiscordBridge) RegisterCommands() { cmdReconnect, cmdDisconnect, cmdGuilds, + cmdDeleteAllPortals, ) } @@ -234,6 +236,7 @@ var cmdGuilds = &commands.FullHandler{ func fnGuilds(ce *WrappedCommandEvent) { if len(ce.Args) == 0 { ce.Reply("**Usage**: `$cmdprefix guilds [guild ID] [--entire]`") + return } subcommand := strings.ToLower(ce.Args[0]) ce.Args = ce.Args[1:] @@ -283,3 +286,53 @@ func fnUnbridgeGuild(ce *WrappedCommandEvent) { ce.Reply("Successfully unbridged guild") } } + +var cmdDeleteAllPortals = &commands.FullHandler{ + Func: wrapCommand(fnDeleteAllPortals), + Name: "delete-all-portals", + Help: commands.HelpMeta{ + Section: commands.HelpSectionUnclassified, + Description: "Delete all portals.", + }, + RequiresAdmin: true, +} + +func fnDeleteAllPortals(ce *WrappedCommandEvent) { + portals := ce.Bridge.GetAllPortals() + if len(portals) == 0 { + ce.Reply("Didn't find any portals") + return + } + + leave := func(portal *Portal) { + if len(portal.MXID) > 0 { + _, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{ + Reason: "Deleting portal", + UserID: ce.User.MXID, + }) + } + } + customPuppet := ce.Bridge.GetPuppetByCustomMXID(ce.User.MXID) + if customPuppet != nil && customPuppet.CustomIntent() != nil { + intent := customPuppet.CustomIntent() + leave = func(portal *Portal) { + if len(portal.MXID) > 0 { + _, _ = intent.LeaveRoom(portal.MXID) + _, _ = intent.ForgetRoom(portal.MXID) + } + } + } + ce.Reply("Found %d portals, deleting...", len(portals)) + for _, portal := range portals { + portal.Delete() + leave(portal) + } + ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.") + + go func() { + for _, portal := range portals { + portal.cleanup(false) + } + ce.Reply("Finished background cleanup of deleted portal rooms.") + }() +} diff --git a/config/bridge.go b/config/bridge.go index 669565c..70dbd6f 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -39,8 +39,7 @@ type BridgeConfig struct { SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"` SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - DefaultBridgeReceipts bool `yaml:"default_bridge_receipts"` - DefaultBridgePresence bool `yaml:"default_bridge_presence"` + FederateRooms bool `yaml:"federate_rooms"` DoublePuppetServerMap map[string]string `yaml:"double_puppet_server_map"` DoublePuppetAllowDiscovery bool `yaml:"double_puppet_allow_discovery"` diff --git a/config/upgrade.go b/config/upgrade.go index 9c2231c..7b6c675 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -31,8 +31,7 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Int, "bridge", "portal_message_buffer") helper.Copy(up.Bool, "bridge", "sync_with_custom_puppets") helper.Copy(up.Bool, "bridge", "sync_direct_chat_list") - helper.Copy(up.Bool, "bridge", "default_bridge_receipts") - helper.Copy(up.Bool, "bridge", "default_bridge_presence") + helper.Copy(up.Bool, "bridge", "federate_rooms") helper.Copy(up.Map, "bridge", "double_puppet_server_map") helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery") helper.Copy(up.Map, "bridge", "login_shared_secret_map") diff --git a/custompuppet.go b/custompuppet.go index 815cbf2..7796e47 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -324,9 +324,6 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error puppet.bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet } - puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence - puppet.EnableReceipts = puppet.bridge.Config.Bridge.DefaultBridgeReceipts - puppet.bridge.AS.StateStore.MarkRegistered(puppet.CustomMXID) puppet.Update() diff --git a/database/attachment.go b/database/attachment.go index 7451219..a4f8538 100644 --- a/database/attachment.go +++ b/database/attachment.go @@ -10,6 +10,69 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +type AttachmentQuery struct { + db *Database + log log.Logger +} + +const ( + attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment" +) + +func (aq *AttachmentQuery) New() *Attachment { + return &Attachment{ + db: aq.db, + log: aq.log, + } +} + +func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment { + query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3" + + return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) +} + +func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment { + rows, err := aq.db.Query(query, args...) + if err != nil { + aq.log.Debugfln("getAll failed: %v", err) + + return nil + } + + if rows == nil { + return nil + } + + var attachments []*Attachment + for rows.Next() { + attachments = append(attachments, aq.New().Scan(rows)) + } + + return attachments +} + +func (aq *AttachmentQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Attachment { + query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dcid=$4" + + return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID) +} + +func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment { + query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3" + + return aq.get(query, key.ChannelID, key.Receiver, matrixEventID) +} + +func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment { + row := aq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return aq.New().Scan(row) +} + type Attachment struct { db *Database log log.Logger @@ -18,14 +81,14 @@ type Attachment struct { DiscordMessageID string DiscordAttachmentID string - MatrixEventID id.EventID + MXID id.EventID } func (a *Attachment) Scan(row dbutil.Scannable) *Attachment { err := row.Scan( + &a.DiscordAttachmentID, &a.DiscordMessageID, &a.Channel.ChannelID, &a.Channel.Receiver, - &a.DiscordMessageID, &a.DiscordAttachmentID, - &a.MatrixEventID) + &a.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { @@ -40,33 +103,32 @@ func (a *Attachment) Scan(row dbutil.Scannable) *Attachment { func (a *Attachment) Insert() { query := "INSERT INTO attachment" + - " (channel_id, receiver, discord_message_id, discord_attachment_id, " + - " matrix_event_id) VALUES ($1, $2, $3, $4, $5);" + " (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " + + " mxid) VALUES ($1, $2, $3, $4, $5);" _, err := a.db.Exec( query, a.Channel.ChannelID, a.Channel.Receiver, a.DiscordMessageID, a.DiscordAttachmentID, - a.MatrixEventID, + a.MXID, ) if err != nil { - a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.Channel, a.DiscordMessageID, err) + a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err) } } func (a *Attachment) Delete() { query := "DELETE FROM attachment WHERE" + - " channel_id=$1 AND receiver=$2 AND discord_attachment_id=$3 AND" + - " matrix_event_id=$4" + " dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" _, err := a.db.Exec( query, a.Channel.ChannelID, a.Channel.Receiver, - a.DiscordAttachmentID, a.MatrixEventID, + a.DiscordAttachmentID, ) if err != nil { - a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.Channel, a.DiscordAttachmentID, err) + a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err) } } diff --git a/database/attachmentquery.go b/database/attachmentquery.go deleted file mode 100644 index efd9bb8..0000000 --- a/database/attachmentquery.go +++ /dev/null @@ -1,73 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type AttachmentQuery struct { - db *Database - log log.Logger -} - -const ( - attachmentSelect = "SELECT channel_id, receiver, discord_message_id," + - " discord_attachment_id, matrix_event_id FROM attachment" -) - -func (aq *AttachmentQuery) New() *Attachment { - return &Attachment{ - db: aq.db, - log: aq.log, - } -} - -func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment { - query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + - " discord_message_id=$3" - - return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) -} - -func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment { - rows, err := aq.db.Query(query, args...) - if err != nil { - aq.log.Debugfln("getAll failed: %v", err) - - return nil - } - - if rows == nil { - return nil - } - - attachments := []*Attachment{} - for rows.Next() { - attachments = append(attachments, aq.New().Scan(rows)) - } - - return attachments -} - -func (aq *AttachmentQuery) GetByDiscordAttachmentID(key PortalKey, discordMessageID, discordID string) *Attachment { - query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2" + - " AND discord_message_id=$3 AND discord_id=$4" - - return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID) -} - -func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment { - query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2" + - " AND matrix_event_id=$3" - - return aq.get(query, key.ChannelID, key.Receiver, matrixEventID) -} - -func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment { - row := aq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return aq.New().Scan(row) -} diff --git a/database/emoji.go b/database/emoji.go index 3e3a01b..9e10198 100644 --- a/database/emoji.go +++ b/database/emoji.go @@ -10,6 +10,43 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +type EmojiQuery struct { + db *Database + log log.Logger +} + +const ( + emojiSelect = "SELECT discord_id, discord_name, matrix_url FROM emoji" +) + +func (eq *EmojiQuery) New() *Emoji { + return &Emoji{ + db: eq.db, + log: eq.log, + } +} + +func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji { + query := emojiSelect + " WHERE discord_id=$1" + + return eq.get(query, discordID) +} + +func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji { + query := emojiSelect + " WHERE matrix_url=$1" + + return eq.get(query, matrixURL.String()) +} + +func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji { + row := eq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return eq.New().Scan(row) +} + type Emoji struct { db *Database log log.Logger diff --git a/database/emojiquery.go b/database/emojiquery.go deleted file mode 100644 index 98bf6b2..0000000 --- a/database/emojiquery.go +++ /dev/null @@ -1,44 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - - "maunium.net/go/mautrix/id" -) - -type EmojiQuery struct { - db *Database - log log.Logger -} - -const ( - emojiSelect = "SELECT discord_id, discord_name, matrix_url FROM emoji" -) - -func (eq *EmojiQuery) New() *Emoji { - return &Emoji{ - db: eq.db, - log: eq.log, - } -} - -func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji { - query := emojiSelect + " WHERE discord_id=$1" - - return eq.get(query, discordID) -} - -func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji { - query := emojiSelect + " WHERE matrix_url=$1" - - return eq.get(query, matrixURL.String()) -} - -func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji { - row := eq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return eq.New().Scan(row) -} diff --git a/database/guild.go b/database/guild.go index c913d2f..f69b725 100644 --- a/database/guild.go +++ b/database/guild.go @@ -3,12 +3,89 @@ package database import ( "database/sql" "errors" + "fmt" log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/util/dbutil" ) +type GuildQuery struct { + db *Database + log log.Logger +} + +const ( + guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild" +) + +func (gq *GuildQuery) New() *Guild { + return &Guild{ + db: gq.db, + log: gq.log, + } +} + +func (gq *GuildQuery) Get(discordID, guildID string) *Guild { + query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2" + + row := gq.db.QueryRow(query, discordID, guildID) + if row == nil { + return nil + } + + return gq.New().Scan(row) +} + +func (gq *GuildQuery) GetAll(discordID string) []*Guild { + query := guildSelect + " WHERE discord_id=$1" + + rows, err := gq.db.Query(query, discordID) + if err != nil || rows == nil { + return nil + } + + guilds := []*Guild{} + for rows.Next() { + guilds = append(guilds, gq.New().Scan(rows)) + } + + return guilds +} + +func (gq *GuildQuery) Prune(discordID string, guilds []string) { + // We need this interface slice because a variadic function can't mix + // arguements with a `...` expanded slice. + args := []interface{}{discordID} + + nGuilds := len(guilds) + if nGuilds <= 0 { + return + } + + gq.log.Debugfln("prunning guilds for %s", discordID) + + // Build the in query + inQuery := "$2" + for i := 1; i < nGuilds; i++ { + inQuery += fmt.Sprintf(", $%d", i+2) + } + + // Add the arguements for the build query + for _, guildID := range guilds { + args = append(args, guildID) + } + + // Now remove any guilds that the user has left. + query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" + + inQuery + ")" + + _, err := gq.db.Exec(query, args...) + if err != nil { + gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err) + } +} + type Guild struct { db *Database log log.Logger diff --git a/database/guildquery.go b/database/guildquery.go deleted file mode 100644 index 467a588..0000000 --- a/database/guildquery.go +++ /dev/null @@ -1,83 +0,0 @@ -package database - -import ( - "fmt" - - log "maunium.net/go/maulogger/v2" -) - -type GuildQuery struct { - db *Database - log log.Logger -} - -const ( - guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild" -) - -func (gq *GuildQuery) New() *Guild { - return &Guild{ - db: gq.db, - log: gq.log, - } -} - -func (gq *GuildQuery) Get(discordID, guildID string) *Guild { - query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2" - - row := gq.db.QueryRow(query, discordID, guildID) - if row == nil { - return nil - } - - return gq.New().Scan(row) -} - -func (gq *GuildQuery) GetAll(discordID string) []*Guild { - query := guildSelect + " WHERE discord_id=$1" - - rows, err := gq.db.Query(query, discordID) - if err != nil || rows == nil { - return nil - } - - guilds := []*Guild{} - for rows.Next() { - guilds = append(guilds, gq.New().Scan(rows)) - } - - return guilds -} - -func (gq *GuildQuery) Prune(discordID string, guilds []string) { - // We need this interface slice because a variadic function can't mix - // arguements with a `...` expanded slice. - args := []interface{}{discordID} - - nGuilds := len(guilds) - if nGuilds <= 0 { - return - } - - gq.log.Debugfln("prunning guilds for %s", discordID) - - // Build the in query - inQuery := "$2" - for i := 1; i < nGuilds; i++ { - inQuery += fmt.Sprintf(", $%d", i+2) - } - - // Add the arguements for the build query - for _, guildID := range guilds { - args = append(args, guildID) - } - - // Now remove any guilds that the user has left. - query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" + - inQuery + ")" - - _, err := gq.db.Exec(query, args...) - if err != nil { - gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err) - } -} diff --git a/database/message.go b/database/message.go index be1b482..b95ee5e 100644 --- a/database/message.go +++ b/database/message.go @@ -11,23 +11,77 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +type MessageQuery struct { + db *Database + log log.Logger +} + +const ( + messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message" +) + +func (mq *MessageQuery) New() *Message { + return &Message{ + db: mq.db, + log: mq.log, + } +} + +func (mq *MessageQuery) GetAll(key PortalKey) []*Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2" + + rows, err := mq.db.Query(query, key.ChannelID, key.Receiver) + if err != nil || rows == nil { + return nil + } + + var messages []*Message + for rows.Next() { + messages = append(messages, mq.New().Scan(rows)) + } + + return messages +} + +func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" + + row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID) + if row == nil { + mq.log.Debugfln("failed to find existing message for discord_id %s", discordID) + return nil + } + + return mq.New().Scan(row) +} + +func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3" + + row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid) + if row == nil { + return nil + } + + return mq.New().Scan(row) +} + type Message struct { db *Database log log.Logger - Channel PortalKey - DiscordID string - MatrixID id.EventID - - AuthorID string + Channel PortalKey + SenderID string Timestamp time.Time + + MXID id.EventID } func (m *Message) Scan(row dbutil.Scannable) *Message { var ts int64 - err := row.Scan(&m.Channel.ChannelID, &m.Channel.Receiver, &m.DiscordID, &m.MatrixID, &m.AuthorID, &ts) + err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) @@ -44,38 +98,21 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { } func (m *Message) Insert() { - query := "INSERT INTO message" + - " (channel_id, receiver, discord_message_id, matrix_message_id," + - " author_id, timestamp) VALUES ($1, $2, $3, $4, $5, $6)" + query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)" - _, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver, - m.DiscordID, m.MatrixID, m.AuthorID, m.Timestamp.Unix()) + _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID) if err != nil { - m.log.Warnfln("Failed to insert %s@%s: %v", m.Channel, m.DiscordID, err) + m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) } } func (m *Message) Delete() { - query := "DELETE FROM message" + - " WHERE channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" + - " matrix_message_id=$4" + query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3" - _, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver, - m.DiscordID, m.MatrixID) + _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver) if err != nil { - m.log.Warnfln("Failed to delete %s@%s: %v", m.Channel, m.DiscordID, err) - } -} - -func (m *Message) UpdateMatrixID(mxid id.EventID) { - query := "UPDATE message SET matrix_message_id=$1 WHERE channel_id=$2" + - " AND receiver=$3 AND discord_message_id=$4" - m.MatrixID = mxid - - _, err := m.db.Exec(query, m.MatrixID, m.Channel.ChannelID, m.Channel.Receiver, m.DiscordID) - if err != nil { - m.log.Warnfln("Failed to update %s@%s: %v", m.Channel, m.DiscordID, err) + m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err) } } diff --git a/database/messagequery.go b/database/messagequery.go deleted file mode 100644 index adea68f..0000000 --- a/database/messagequery.go +++ /dev/null @@ -1,64 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type MessageQuery struct { - db *Database - log log.Logger -} - -const ( - messageSelect = "SELECT channel_id, receiver, discord_message_id," + - " matrix_message_id, author_id, timestamp FROM message" -) - -func (mq *MessageQuery) New() *Message { - return &Message{ - db: mq.db, - log: mq.log, - } -} - -func (mq *MessageQuery) GetAll(key PortalKey) []*Message { - query := messageSelect + " WHERE channeld_id=$1 AND receiver=$2" - - rows, err := mq.db.Query(query, key.ChannelID, key.Receiver) - if err != nil || rows == nil { - return nil - } - - messages := []*Message{} - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - - return messages -} - -func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { - query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + - " discord_message_id=$3" - - row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID) - if row == nil { - mq.log.Debugfln("failed to find existing message for discord_id %s", discordID) - return nil - } - - return mq.New().Scan(row) -} - -func (mq *MessageQuery) GetByMatrixID(key PortalKey, matrixID id.EventID) *Message { - query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + - " matrix_message_id=$3" - - row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, matrixID) - if row == nil { - return nil - } - - return mq.New().Scan(row) -} diff --git a/database/portal.go b/database/portal.go index c30650f..e04d316 100644 --- a/database/portal.go +++ b/database/portal.go @@ -11,23 +11,85 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +const ( + portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," + + " avatar_url, type, other_user_id, first_event_id, encrypted" + + " FROM portal" +) + +type PortalQuery struct { + db *Database + log log.Logger +} + +func (pq *PortalQuery) New() *Portal { + return &Portal{ + db: pq.db, + log: pq.log, + } +} + +func (pq *PortalQuery) GetAll() []*Portal { + return pq.getAll(portalSelect) +} + +func (pq *PortalQuery) GetByID(key PortalKey) *Portal { + return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver) +} + +func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { + return pq.get(portalSelect+" WHERE mxid=$1", mxid) +} + +func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal { + return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM) +} + +func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal { + query := portalSelect + " portal WHERE receiver=$1 AND type=$2;" + + return pq.getAll(query, receiver, discordgo.ChannelTypeDM) +} + +func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal { + rows, err := pq.db.Query(query, args...) + if err != nil || rows == nil { + return nil + } + defer rows.Close() + + var portals []*Portal + for rows.Next() { + portals = append(portals, pq.New().Scan(rows)) + } + + return portals +} + +func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { + row := pq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return pq.New().Scan(row) +} + type Portal struct { db *Database log log.Logger - Key PortalKey + Key PortalKey + Type discordgo.ChannelType + OtherUserID string + MXID id.RoomID - Name string - Topic string - - Encrypted bool - + Name string + Topic string Avatar string AvatarURL id.ContentURI - - Type discordgo.ChannelType - DMUser string + Encrypted bool FirstEventID id.EventID } @@ -37,7 +99,7 @@ func (p *Portal) Scan(row dbutil.Scannable) *Portal { var typ sql.NullInt32 err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, - &p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID, + &p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID, &p.Encrypted) if err != nil { @@ -66,12 +128,12 @@ func (p *Portal) mxidPtr() *id.RoomID { func (p *Portal) Insert() { query := "INSERT INTO portal" + - " (channel_id, receiver, mxid, name, topic, avatar, avatar_url," + - " type, dmuser, first_event_id, encrypted)" + + " (dcid, receiver, mxid, name, topic, avatar, avatar_url," + + " type, other_user_id, first_event_id, encrypted)" + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(), - p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser, + p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(), p.Encrypted) if err != nil { @@ -82,11 +144,11 @@ func (p *Portal) Insert() { func (p *Portal) Update() { query := "UPDATE portal SET" + " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," + - " dmuser=$7, first_event_id=$8, encrypted=$9" + - " WHERE channel_id=$10 AND receiver=$11" + " other_user_id=$7, first_event_id=$8, encrypted=$9" + + " WHERE dcid=$10 AND receiver=$11" _, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar, - p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(), + p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(), p.Encrypted, p.Key.ChannelID, p.Key.Receiver) @@ -96,7 +158,7 @@ func (p *Portal) Update() { } func (p *Portal) Delete() { - query := "DELETE FROM portal WHERE channel_id=$1 AND receiver=$2" + query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2" _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver) if err != nil { p.log.Warnfln("Failed to delete %s: %v", p.Key, err) diff --git a/database/portalkey.go b/database/portalkey.go index b597cf3..54c1a59 100644 --- a/database/portalkey.go +++ b/database/portalkey.go @@ -13,8 +13,8 @@ func NewPortalKey(channelID, receiver string) PortalKey { } func (key PortalKey) String() string { - if key.ChannelID == key.Receiver { - return key.Receiver + if key.Receiver == "" { + return key.ChannelID } return key.ChannelID + "-" + key.Receiver } diff --git a/database/portalquery.go b/database/portalquery.go deleted file mode 100644 index c055ad6..0000000 --- a/database/portalquery.go +++ /dev/null @@ -1,71 +0,0 @@ -package database - -import ( - "github.com/bwmarrin/discordgo" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -const ( - portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," + - " avatar_url, type, dmuser, first_event_id, encrypted" + - " FROM portal" -) - -type PortalQuery struct { - db *Database - log log.Logger -} - -func (pq *PortalQuery) New() *Portal { - return &Portal{ - db: pq.db, - log: pq.log, - } -} - -func (pq *PortalQuery) GetAll() []*Portal { - return pq.getAll(portalSelect) -} - -func (pq *PortalQuery) GetByID(key PortalKey) *Portal { - return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver) -} - -func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - return pq.get(portalSelect+" WHERE mxid=$1", mxid) -} - -func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal { - return pq.getAll(portalSelect+" WHERE dmuser=$1 AND type=$2", id, discordgo.ChannelTypeDM) -} - -func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal { - query := portalSelect + " portal WHERE receiver=$1 AND type=$2;" - - return pq.getAll(query, receiver, discordgo.ChannelTypeDM) -} - -func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal { - rows, err := pq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - - portals := []*Portal{} - for rows.Next() { - portals = append(portals, pq.New().Scan(rows)) - } - - return portals -} - -func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { - row := pq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return pq.New().Scan(row) -} diff --git a/database/puppet.go b/database/puppet.go index 3cdde8b..c5f1ba9 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -11,11 +11,62 @@ import ( const ( puppetSelect = "SELECT id, display_name, avatar, avatar_url," + - " enable_presence, custom_mxid, access_token, next_batch," + - " enable_receipts" + + " custom_mxid, access_token, next_batch" + " FROM puppet " ) +type PuppetQuery struct { + db *Database + log log.Logger +} + +func (pq *PuppetQuery) New() *Puppet { + return &Puppet{ + db: pq.db, + log: pq.log, + } +} + +func (pq *PuppetQuery) Get(id string) *Puppet { + return pq.get(puppetSelect+" WHERE id=$1", id) +} + +func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { + return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid) +} + +func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet { + row := pq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return pq.New().Scan(row) +} + +func (pq *PuppetQuery) GetAll() []*Puppet { + return pq.getAll(puppetSelect) +} + +func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet { + return pq.getAll(puppetSelect + " WHERE custom_mxid<>''") +} + +func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet { + rows, err := pq.db.Query(query, args...) + if err != nil || rows == nil { + return nil + } + defer rows.Close() + + puppets := []*Puppet{} + for rows.Next() { + puppets = append(puppets, pq.New().Scan(rows)) + } + + return puppets +} + type Puppet struct { db *Database log log.Logger @@ -26,23 +77,17 @@ type Puppet struct { Avatar string AvatarURL id.ContentURI - EnablePresence bool - CustomMXID id.UserID AccessToken string - - NextBatch string - - EnableReceipts bool + NextBatch string } func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { var did, displayName, avatar, avatarURL sql.NullString - var enablePresence sql.NullBool var customMXID, accessToken, nextBatch sql.NullString - err := row.Scan(&did, &displayName, &avatar, &avatarURL, &enablePresence, - &customMXID, &accessToken, &nextBatch, &p.EnableReceipts) + err := row.Scan(&did, &displayName, &avatar, &avatarURL, + &customMXID, &accessToken, &nextBatch) if err != nil { if err != sql.ErrNoRows { @@ -56,7 +101,6 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { p.DisplayName = displayName.String p.Avatar = avatar.String p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) - p.EnablePresence = enablePresence.Bool p.CustomMXID = id.UserID(customMXID.String) p.AccessToken = accessToken.String p.NextBatch = nextBatch.String @@ -66,13 +110,13 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { func (p *Puppet) Insert() { query := "INSERT INTO puppet" + - " (id, display_name, avatar, avatar_url, enable_presence," + - " custom_mxid, access_token, next_batch, enable_receipts)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " (id, display_name, avatar, avatar_url," + + " custom_mxid, access_token, next_batch)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7)" _, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar, - p.AvatarURL.String(), p.EnablePresence, p.CustomMXID, p.AccessToken, - p.NextBatch, p.EnableReceipts) + p.AvatarURL.String(), p.CustomMXID, p.AccessToken, + p.NextBatch) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.ID, err) @@ -81,14 +125,13 @@ func (p *Puppet) Insert() { func (p *Puppet) Update() { query := "UPDATE puppet" + - " SET display_name=$1, avatar=$2, avatar_url=$3, enable_presence=$4," + - " custom_mxid=$5, access_token=$6, next_batch=$7," + - " enable_receipts=$8" + - " WHERE id=$9" + " SET display_name=$1, avatar=$2, avatar_url=$3, " + + " custom_mxid=$4, access_token=$5, next_batch=$6" + + " WHERE id=$7" _, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(), - p.EnablePresence, p.CustomMXID, p.AccessToken, p.NextBatch, - p.EnableReceipts, p.ID) + p.CustomMXID, p.AccessToken, p.NextBatch, + p.ID) if err != nil { p.log.Warnfln("Failed to update %s: %v", p.ID, err) diff --git a/database/puppetquery.go b/database/puppetquery.go deleted file mode 100644 index 5ae32f1..0000000 --- a/database/puppetquery.go +++ /dev/null @@ -1,60 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type PuppetQuery struct { - db *Database - log log.Logger -} - -func (pq *PuppetQuery) New() *Puppet { - return &Puppet{ - db: pq.db, - log: pq.log, - - EnablePresence: true, - } -} - -func (pq *PuppetQuery) Get(id string) *Puppet { - return pq.get(puppetSelect+" WHERE id=$1", id) -} - -func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { - return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid) -} - -func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet { - row := pq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return pq.New().Scan(row) -} - -func (pq *PuppetQuery) GetAll() []*Puppet { - return pq.getAll(puppetSelect) -} - -func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet { - return pq.getAll(puppetSelect + " WHERE custom_mxid<>''") -} - -func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet { - rows, err := pq.db.Query(query, args...) - if err != nil || rows == nil { - return nil - } - defer rows.Close() - - puppets := []*Puppet{} - for rows.Next() { - puppets = append(puppets, pq.New().Scan(rows)) - } - - return puppets -} diff --git a/database/reaction.go b/database/reaction.go index 8f51209..41037c1 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -10,97 +10,102 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +type ReactionQuery struct { + db *Database + log log.Logger +} + +const ( + reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction" +) + +func (rq *ReactionQuery) New() *Reaction { + return &Reaction{ + db: rq.db, + log: rq.log, + } +} + +func (rq *ReactionQuery) GetAllForMessage(key PortalKey, discordMessageID string) []*Reaction { + query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3" + + return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) +} + +func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction { + rows, err := rq.db.Query(query, args...) + if err != nil || rows == nil { + return nil + } + + var reactions []*Reaction + for rows.Next() { + reactions = append(reactions, rq.New().Scan(rows)) + } + + return reactions +} + +func (rq *ReactionQuery) GetByDiscordID(key PortalKey, msgID, sender, emojiName string) *Reaction { + query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dc_sender=$4 AND dc_emoji_name=$5" + + return rq.get(query, key.ChannelID, key.Receiver, msgID, sender, emojiName) +} + +func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction { + query := reactionSelect + " WHERE mxid=$1" + + return rq.get(query, mxid) +} + +func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction { + row := rq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return rq.New().Scan(row) +} + type Reaction struct { db *Database log log.Logger - Channel PortalKey + Channel PortalKey + MessageID string + Sender string + EmojiName string - DiscordMessageID string - MatrixEventID id.EventID - - // The discord ID of who create this reaction - AuthorID string - - MatrixName string - MatrixURL string // Used for custom emoji - - DiscordID string // The id or unicode of the emoji for discord + MXID id.EventID } func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { - var discordID sql.NullString - - err := row.Scan( - &r.Channel.ChannelID, &r.Channel.Receiver, - &r.DiscordMessageID, &r.MatrixEventID, - &r.AuthorID, - &r.MatrixName, &r.MatrixURL, - &discordID) - + err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { r.log.Errorln("Database scan failed:", err) } - return nil } - r.DiscordID = discordID.String - return r } func (r *Reaction) Insert() { - query := "INSERT INTO reaction" + - " (channel_id, receiver, discord_message_id, matrix_event_id," + - " author_id, matrix_name, matrix_url, discord_id)" + - " VALUES($1, $2, $3, $4, $5, $6, $7, $8);" - - var discordID sql.NullString - - if r.DiscordID != "" { - discordID = sql.NullString{r.DiscordID, true} - } - - _, err := r.db.Exec( - query, - r.Channel.ChannelID, r.Channel.Receiver, - r.DiscordMessageID, r.MatrixEventID, - r.AuthorID, - r.MatrixName, r.MatrixURL, - discordID, - ) - + query := ` + INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid) + VALUES($1, $2, $3, $4, $5, $6) + ` + _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID) if err != nil { - r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) + r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err) } } -func (r *Reaction) Update() { - // TODO: determine if we need this. The only scenario I can think of that - // would require this is if we insert a custom emoji before uploading to - // the homeserver? -} - func (r *Reaction) Delete() { - query := "DELETE FROM reaction WHERE" + - " channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" + - " author_id=$4 AND discord_id=$5" - - var discordID sql.NullString - if r.DiscordID != "" { - discordID = sql.NullString{r.DiscordID, true} - } - - _, err := r.db.Exec( - query, - r.Channel.ChannelID, r.Channel.Receiver, - r.DiscordMessageID, r.AuthorID, - discordID, - ) - + query := "DELETE FROM reaction WHERE dc_msg_id=$1 AND dc_sender=$2 AND dc_emoji_name=$3" + _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName) if err != nil { - r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) + r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err) } } diff --git a/database/reactionquery.go b/database/reactionquery.go deleted file mode 100644 index d0bc0af..0000000 --- a/database/reactionquery.go +++ /dev/null @@ -1,75 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type ReactionQuery struct { - db *Database - log log.Logger -} - -const ( - reactionSelect = "SELECT channel_id, receiver, discord_message_id," + - " matrix_event_id, author_id, matrix_name, matrix_url, " + - " discord_id FROM reaction" -) - -func (rq *ReactionQuery) New() *Reaction { - return &Reaction{ - db: rq.db, - log: rq.log, - } -} - -func (rq *ReactionQuery) GetAllByDiscordID(key PortalKey, discordMessageID string) []*Reaction { - query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + - " discord_message_id=$3" - - return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) -} - -func (rq *ReactionQuery) GetAllByMatrixID(key PortalKey, matrixEventID id.EventID) []*Reaction { - query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + - " matrix_event_id=$3" - - return rq.getAll(query, key.ChannelID, key.Receiver, matrixEventID) -} - -func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction { - rows, err := rq.db.Query(query) - if err != nil || rows == nil { - return nil - } - - reactions := []*Reaction{} - for rows.Next() { - reactions = append(reactions, rq.New().Scan(rows)) - } - - return reactions -} - -func (rq *ReactionQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Reaction { - query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" + - " AND discord_message_id=$3 AND discord_id=$4" - - return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID) -} - -func (rq *ReactionQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Reaction { - query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" + - " AND matrix_event_id=$3" - - return rq.get(query, key.ChannelID, key.Receiver, matrixEventID) -} - -func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction { - row := rq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return rq.New().Scan(row) -} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql new file mode 100644 index 0000000..4a63f90 --- /dev/null +++ b/database/upgrades/00-latest-revision.sql @@ -0,0 +1,92 @@ +-- v0 -> v2: Latest revision + +CREATE TABLE portal ( + dcid TEXT, + receiver TEXT, + other_user_id TEXT, + type INTEGER, + + mxid TEXT UNIQUE, + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar TEXT NOT NULL, + avatar_url TEXT NOT NULL, + encrypted BOOLEAN NOT NULL DEFAULT false, + + first_event_id TEXT NOT NULL, + + PRIMARY KEY (dcid, receiver) +); + +CREATE TABLE puppet ( + id TEXT PRIMARY KEY, + + name TEXT, + avatar TEXT, + avatar_url TEXT, + + custom_mxid TEXT, + access_token TEXT, + next_batch TEXT +); + +CREATE TABLE "user" ( + mxid TEXT PRIMARY KEY, + dcid TEXT UNIQUE, + + management_room TEXT, + + token TEXT +); + +CREATE TABLE message ( + dcid TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_sender TEXT NOT NULL, + timestamp BIGINT NOT NULL, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE +); + +CREATE TABLE reaction ( + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_msg_id TEXT, + dc_sender TEXT, + dc_emoji_name TEXT, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), + CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE +); + +CREATE TABLE attachment ( + dcid TEXT, + dc_msg_id TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver), + CONSTRAINT attachment_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE +); + +CREATE TABLE emoji ( + discord_id TEXT PRIMARY KEY, + discord_name TEXT, + matrix_url TEXT +); + +CREATE TABLE guild ( + discord_id TEXT NOT NULL, + guild_id TEXT NOT NULL, + guild_name TEXT NOT NULL, + bridge BOOLEAN DEFAULT FALSE, + PRIMARY KEY(discord_id, guild_id) +); diff --git a/database/upgrades/01-initial-revision.sql b/database/upgrades/01-initial-revision.sql deleted file mode 100644 index c8e48d0..0000000 --- a/database/upgrades/01-initial-revision.sql +++ /dev/null @@ -1,105 +0,0 @@ --- v1: Initial revision - -CREATE TABLE portal ( - channel_id TEXT, - receiver TEXT, - mxid TEXT UNIQUE, - - name TEXT NOT NULL, - topic TEXT NOT NULL, - - avatar TEXT NOT NULL, - avatar_url TEXT, - - encrypted BOOLEAN NOT NULL DEFAULT false, - - type INT, - dmuser TEXT, - - first_event_id TEXT, - - PRIMARY KEY (channel_id, receiver) -); - -CREATE TABLE puppet ( - id TEXT PRIMARY KEY, - display_name TEXT, - - avatar TEXT, - avatar_url TEXT, - - enable_presence BOOLEAN NOT NULL DEFAULT true, - enable_receipts BOOLEAN NOT NULL DEFAULT true, - - custom_mxid TEXT, - access_token TEXT, - next_batch TEXT -); - -CREATE TABLE "user" ( - mxid TEXT PRIMARY KEY, - id TEXT UNIQUE, - - management_room TEXT, - - token TEXT -); - -CREATE TABLE message ( - channel_id TEXT NOT NULL, - receiver TEXT NOT NULL, - - discord_message_id TEXT NOT NULL, - matrix_message_id TEXT NOT NULL UNIQUE, - - author_id TEXT NOT NULL, - timestamp BIGINT NOT NULL, - - PRIMARY KEY(discord_message_id, channel_id, receiver), - FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE -); - -CREATE TABLE reaction ( - channel_id TEXT NOT NULL, - receiver TEXT NOT NULL, - - discord_message_id TEXT NOT NULL, - matrix_event_id TEXT NOT NULL UNIQUE, - - author_id TEXT NOT NULL, - - matrix_name TEXT, - matrix_url TEXT, - - discord_id TEXT, - - UNIQUE (discord_id, author_id, discord_message_id, channel_id, receiver), - FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE -); - -CREATE TABLE attachment ( - channel_id TEXT NOT NULL, - receiver TEXT NOT NULL, - - discord_message_id TEXT NOT NULL, - discord_attachment_id TEXT NOT NULL, - - matrix_event_id TEXT NOT NULL UNIQUE, - - PRIMARY KEY(discord_attachment_id, matrix_event_id), - FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE -); - -CREATE TABLE emoji ( - discord_id TEXT PRIMARY KEY, - discord_name TEXT, - matrix_url TEXT -); - -CREATE TABLE guild ( - discord_id TEXT NOT NULL, - guild_id TEXT NOT NULL, - guild_name TEXT NOT NULL, - bridge BOOLEAN DEFAULT FALSE, - PRIMARY KEY(discord_id, guild_id) -); diff --git a/database/upgrades/02-column-renames.sql b/database/upgrades/02-column-renames.sql new file mode 100644 index 0000000..86b0cb0 --- /dev/null +++ b/database/upgrades/02-column-renames.sql @@ -0,0 +1,53 @@ +-- v2: Rename columns in message-related tables + +ALTER TABLE portal RENAME COLUMN dmuser TO other_user_id; +ALTER TABLE portal RENAME COLUMN channel_id TO dcid; + +ALTER TABLE "user" RENAME COLUMN id TO dcid; + +ALTER TABLE puppet DROP COLUMN enable_presence; +ALTER TABLE puppet DROP COLUMN enable_receipts; + +DROP TABLE message; +DROP TABLE reaction; +DROP TABLE attachment; + +CREATE TABLE message ( + dcid TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_sender TEXT NOT NULL, + timestamp BIGINT NOT NULL, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE +); + +CREATE TABLE reaction ( + dc_chan_id TEXT, + dc_chan_receiver TEXT, + dc_msg_id TEXT, + dc_sender TEXT, + dc_emoji_name TEXT, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), + CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE +); + +CREATE TABLE attachment ( + dcid TEXT, + dc_msg_id TEXT, + dc_chan_id TEXT, + dc_chan_receiver TEXT, + + mxid TEXT NOT NULL UNIQUE, + + PRIMARY KEY (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver), + CONSTRAINT attachment_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE +); + +UPDATE portal SET receiver='' WHERE type<>1; diff --git a/database/user.go b/database/user.go index cc6ceee..1d7b419 100644 --- a/database/user.go +++ b/database/user.go @@ -9,6 +9,54 @@ import ( "maunium.net/go/mautrix/util/dbutil" ) +type UserQuery struct { + db *Database + log log.Logger +} + +func (uq *UserQuery) New() *User { + return &User{ + db: uq.db, + log: uq.log, + } +} + +func (uq *UserQuery) GetByMXID(userID id.UserID) *User { + query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1` + row := uq.db.QueryRow(query, userID) + if row == nil { + return nil + } + + return uq.New().Scan(row) +} + +func (uq *UserQuery) GetByID(id string) *User { + query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1` + row := uq.db.QueryRow(query, id) + if row == nil { + return nil + } + + return uq.New().Scan(row) +} + +func (uq *UserQuery) GetAll() []*User { + rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`) + if err != nil || rows == nil { + return nil + } + + defer rows.Close() + + users := []*User{} + for rows.Next() { + users = append(users, uq.New().Scan(rows)) + } + + return users +} + type User struct { db *Database log log.Logger @@ -46,9 +94,7 @@ func (u *User) Scan(row dbutil.Scannable) *User { } func (u *User) Insert() { - query := "INSERT INTO \"user\"" + - " (mxid, id, management_room, token)" + - " VALUES ($1, $2, $3, $4);" + query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)" var token sql.NullString var discordID sql.NullString @@ -71,9 +117,7 @@ func (u *User) Insert() { } func (u *User) Update() { - query := "UPDATE \"user\" SET" + - " id=$1, management_room=$2, token=$3" + - " WHERE mxid=$4;" + query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4" var token sql.NullString var discordID sql.NullString diff --git a/database/userquery.go b/database/userquery.go deleted file mode 100644 index 25b1a84..0000000 --- a/database/userquery.go +++ /dev/null @@ -1,54 +0,0 @@ -package database - -import ( - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" -) - -type UserQuery struct { - db *Database - log log.Logger -} - -func (uq *UserQuery) New() *User { - return &User{ - db: uq.db, - log: uq.log, - } -} - -func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - query := `SELECT mxid, id, management_room, token FROM "user" WHERE mxid=$1` - row := uq.db.QueryRow(query, userID) - if row == nil { - return nil - } - - return uq.New().Scan(row) -} - -func (uq *UserQuery) GetByID(id string) *User { - query := `SELECT mxid, id, management_room, token FROM "user" WHERE id=$1` - row := uq.db.QueryRow(query, id) - if row == nil { - return nil - } - - return uq.New().Scan(row) -} - -func (uq *UserQuery) GetAll() []*User { - rows, err := uq.db.Query(`SELECT mxid, id, management_room, token FROM "user" WHERE token IS NOT NULL`) - if err != nil || rows == nil { - return nil - } - - defer rows.Close() - - users := []*User{} - for rows.Next() { - users = append(users, uq.New().Scan(rows)) - } - - return users -} diff --git a/example-config.yaml b/example-config.yaml index 0758495..abb8c8c 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -81,11 +81,9 @@ bridge: # Note that updating the m.direct event is not atomic (except with mautrix-asmux) # and is therefore prone to race conditions. sync_direct_chat_list: false - # When double puppeting is enabled, users can use `!wa toggle` to change whether - # presence and read receipts are bridged. These settings set the default values. - # Existing users won't be affected when these are changed. - default_bridge_receipts: true - default_bridge_presence: true + # Whether or not created rooms should have federation enabled. + # If false, created portal rooms will never be federated. + federate_rooms: true # Servers to always allow double puppeting from double_puppet_server_map: example.com: https://example.com diff --git a/portal.go b/portal.go index 9c6d7e1..7ea9b65 100644 --- a/portal.go +++ b/portal.go @@ -174,8 +174,8 @@ func (portal *Portal) IsPrivateChat() bool { } func (portal *Portal) MainIntent() *appservice.IntentAPI { - if portal.IsPrivateChat() && portal.DMUser != "" { - return portal.bridge.GetPuppetByID(portal.DMUser).DefaultIntent() + if portal.IsPrivateChat() && portal.OtherUserID != "" { + return portal.bridge.GetPuppetByID(portal.OtherUserID).DefaultIntent() } return portal.bridge.Bot @@ -184,15 +184,13 @@ func (portal *Portal) MainIntent() *appservice.IntentAPI { func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() - - // If we have a matrix id the room should exist so we have nothing to do. if portal.MXID != "" { return nil } portal.Type = channel.Type if portal.Type == discordgo.ChannelTypeDM { - portal.DMUser = channel.Recipients[0].ID + portal.OtherUserID = channel.Recipients[0].ID } intent := portal.MainIntent() @@ -219,7 +217,9 @@ func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) e initialState := []*event.Event{} creationContent := make(map[string]interface{}) - creationContent["m.federate"] = false + if !portal.bridge.Config.Bridge.FederateRooms { + creationContent["m.federate"] = false + } var invite []id.UserID @@ -325,19 +325,14 @@ func (portal *Portal) ensureUserInvited(user *User) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) } -func (portal *Portal) markMessageHandled(msg *database.Message, discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message { - if msg == nil { - msg := portal.bridge.DB.Message.New() - msg.Channel = portal.Key - msg.DiscordID = discordID - msg.MatrixID = mxid - msg.AuthorID = authorID - msg.Timestamp = timestamp - msg.Insert() - } else { - msg.UpdateMatrixID(mxid) - } - +func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message { + msg := portal.bridge.DB.Message.New() + msg.Channel = portal.Key + msg.DiscordID = discordID + msg.MXID = mxid + msg.SenderID = authorID + msg.Timestamp = timestamp + msg.Insert() return msg } @@ -410,7 +405,7 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI dbAttachment.Channel = portal.Key dbAttachment.DiscordMessageID = msgID dbAttachment.DiscordAttachmentID = attachment.ID - dbAttachment.MatrixEventID = resp.EventID + dbAttachment.MXID = resp.EventID dbAttachment.Insert() } @@ -461,14 +456,14 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess MsgType: event.MsgText, } - if msg.MessageReference != nil { - key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} - existing := portal.bridge.DB.Message.GetByDiscordID(key, msg.MessageReference.MessageID) + if msg.MessageReference != nil && msg.MessageReference.ChannelID == portal.Key.ChannelID { + //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} + replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID) - if existing != nil && existing.MatrixID != "" { + if replyTo != nil { content.RelatesTo = &event.RelatesTo{ Type: event.RelReply, - EventID: existing.MatrixID, + EventID: existing.MXID, } } } @@ -481,7 +476,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } ts, _ := msg.Timestamp.Parse() - portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) + portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts) } // now run through any attachments the message has @@ -544,9 +539,9 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes // Finally run through any attachments still in the map and delete them // on the matrix side and our database. for _, attachment := range attachmentMap { - _, err := intent.RedactEvent(portal.MXID, attachment.MatrixEventID) + _, err := intent.RedactEvent(portal.MXID, attachment.MXID) if err != nil { - portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MatrixEventID, err) + portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err) } attachment.Delete() @@ -560,17 +555,17 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes MsgType: event.MsgText, } - content.SetEdit(existing.MatrixID) + content.SetEdit(existing.MXID) - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) + _, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) if err != nil { portal.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) return } - ts, _ := msg.Timestamp.Parse() - portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) + //ts, _ := msg.Timestamp.Parse() + //portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) } func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { @@ -583,19 +578,12 @@ func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Mess // Find the message that we're working with. This could correctly return // nil if the message was just one or more attachments. existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) - - var intent *appservice.IntentAPI - - if portal.Type == discordgo.ChannelTypeDM { - intent = portal.bridge.GetPuppetByID(portal.DMUser).IntentFor(portal) - } else { - intent = portal.MainIntent() - } + intent := portal.MainIntent() if existing != nil { - _, err := intent.RedactEvent(portal.MXID, existing.MatrixID) + _, err := intent.RedactEvent(portal.MXID, existing.MXID) if err != nil { - portal.log.Warnfln("Failed to remove message %s: %v", existing.MatrixID, err) + portal.log.Warnfln("Failed to remove message %s: %v", existing.MXID, err) } existing.Delete() @@ -604,9 +592,9 @@ func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Mess // Now delete all of the existing attachments. attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) for _, attachment := range attachments { - _, err := intent.RedactEvent(portal.MXID, attachment.MatrixEventID) + _, err := intent.RedactEvent(portal.MXID, attachment.MXID) if err != nil { - portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MatrixEventID, err) + portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err) } attachment.Delete() @@ -646,7 +634,6 @@ func (portal *Portal) encrypt(content *event.Content, eventType event.Type) (eve return eventType, nil } -const doublePuppetKey = "fi.mau.double_puppet_source" const doublePuppetValue = "mautrix-discord" func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { @@ -656,7 +643,7 @@ func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType wrappedContent.Raw = map[string]interface{}{} } if intent.IsCustomPuppet { - wrappedContent.Raw[doublePuppetKey] = doublePuppetValue + wrappedContent.Raw[bridge.DoublePuppetKey] = doublePuppetValue } } var err error @@ -668,7 +655,7 @@ func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType if eventType == event.EventEncrypted { // Clear other custom keys if the event was encrypted, but keep the double puppet identifier if intent.IsCustomPuppet { - wrappedContent.Raw = map[string]interface{}{doublePuppetKey: doublePuppetValue} + wrappedContent.Raw = map[string]interface{}{bridge.DoublePuppetKey: doublePuppetValue} } else { wrappedContent.Raw = nil } @@ -700,13 +687,6 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { return } - existing := portal.bridge.DB.Message.GetByMatrixID(portal.Key, evt.ID) - if existing != nil { - portal.log.Debugln("not handling duplicate message", evt.ID) - - return - } - content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { portal.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed) @@ -715,15 +695,15 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { } if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReplace { - existing := portal.bridge.DB.Message.GetByMatrixID(portal.Key, content.RelatesTo.EventID) + edits := portal.bridge.DB.Message.GetByMXID(portal.Key, content.RelatesTo.EventID) - if existing != nil && existing.DiscordID != "" { + if edits != nil { // we don't have anything to save for the update message right now // as we're not tracking edited timestamps. _, err := sender.Session.ChannelMessageEdit(portal.Key.ChannelID, - existing.DiscordID, content.NewContent.Body) + edits.DiscordID, content.NewContent.Body) if err != nil { - portal.log.Errorln("Failed to update message %s: %v", existing.DiscordID, err) + portal.log.Errorln("Failed to update message %s: %v", edits.DiscordID, err) return } @@ -740,18 +720,18 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { sent := false if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReply { - existing := portal.bridge.DB.Message.GetByMatrixID( + replyTo := portal.bridge.DB.Message.GetByMXID( portal.Key, content.RelatesTo.EventID, ) - if existing != nil && existing.DiscordID != "" { + if replyTo != nil { msg, err = sender.Session.ChannelMessageSendReply( portal.Key.ChannelID, content.Body, &discordgo.MessageReference{ ChannelID: portal.Key.ChannelID, - MessageID: existing.DiscordID, + MessageID: replyTo.DiscordID, }, ) if err == nil { @@ -771,13 +751,11 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { } msgSend := &discordgo.MessageSend{ - Files: []*discordgo.File{ - &discordgo.File{ - Name: content.Body, - ContentType: content.Info.MimeType, - Reader: bytes.NewReader(data), - }, - }, + Files: []*discordgo.File{{ + Name: content.Body, + ContentType: content.Info.MimeType, + Reader: bytes.NewReader(data), + }}, } msg, err = sender.Session.ChannelMessageSendComplex(portal.Key.ChannelID, msgSend) @@ -796,8 +774,9 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { dbMsg := portal.bridge.DB.Message.New() dbMsg.Channel = portal.Key dbMsg.DiscordID = msg.ID - dbMsg.MatrixID = evt.ID - dbMsg.AuthorID = sender.ID + dbMsg.MXID = evt.ID + dbMsg.SenderID = sender.ID + // TODO use actual timestamp dbMsg.Timestamp = time.Now() dbMsg.Insert() } @@ -927,7 +906,7 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { var discordID string - msg := portal.bridge.DB.Message.GetByMatrixID(portal.Key, reaction.RelatesTo.EventID) + msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) // Due to the differences in attachments between Discord and Matrix, if a // user reacts to a media message on discord our lookup above will fail @@ -976,13 +955,11 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { } dbReaction := portal.bridge.DB.Reaction.New() - dbReaction.Channel.ChannelID = portal.Key.ChannelID - dbReaction.Channel.Receiver = portal.Key.Receiver - dbReaction.MatrixEventID = evt.ID - dbReaction.DiscordMessageID = discordID - dbReaction.AuthorID = user.ID - dbReaction.MatrixName = reaction.RelatesTo.Key - dbReaction.DiscordID = emojiID + dbReaction.Channel = portal.Key + dbReaction.MessageID = discordID + dbReaction.Sender = user.ID + dbReaction.EmojiName = emojiID + dbReaction.MXID = evt.ID dbReaction.Insert() } @@ -990,7 +967,7 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess intent := portal.bridge.GetPuppetByID(reaction.UserID).IntentFor(portal) var discordID string - var matrixID string + var matrixReaction string if reaction.Emoji.ID != "" { dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(reaction.Emoji.ID) @@ -1018,10 +995,10 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess } discordID = dbEmoji.DiscordID - matrixID = dbEmoji.MatrixURL.String() + matrixReaction = dbEmoji.MatrixURL.String() } else { discordID = reaction.Emoji.Name - matrixID = reaction.Emoji.Name + matrixReaction = reaction.Emoji.Name } // Find the message that we're working with. @@ -1033,8 +1010,7 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess } // Lookup an existing reaction - existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, discordID) - + existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, reaction.UserID, discordID) if !add { if existing == nil { portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID) @@ -1042,21 +1018,24 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess return } - _, err := intent.RedactEvent(portal.MXID, existing.MatrixEventID) + _, err := intent.RedactEvent(portal.MXID, existing.MXID) if err != nil { portal.log.Warnfln("Failed to remove reaction from %s: %v", portal.MXID, err) } existing.Delete() + return + } else if existing != nil { + portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message.DiscordID) return } content := event.Content{Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ - EventID: message.MatrixID, + EventID: message.MXID, Type: event.RelAnnotation, - Key: matrixID, + Key: matrixReaction, }, }} @@ -1070,13 +1049,10 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess if existing == nil { dbReaction := portal.bridge.DB.Reaction.New() dbReaction.Channel = portal.Key - dbReaction.DiscordMessageID = message.DiscordID - dbReaction.MatrixEventID = resp.EventID - dbReaction.AuthorID = reaction.UserID - - dbReaction.MatrixName = matrixID - dbReaction.DiscordID = discordID - + dbReaction.MessageID = message.DiscordID + dbReaction.Sender = reaction.UserID + dbReaction.EmojiName = discordID + dbReaction.MXID = resp.EventID dbReaction.Insert() } } @@ -1087,7 +1063,7 @@ func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) { } // First look if we're redacting a message - message := portal.bridge.DB.Message.GetByMatrixID(portal.Key, evt.Redacts) + message := portal.bridge.DB.Message.GetByMXID(portal.Key, evt.Redacts) if message != nil { if message.DiscordID != "" { err := user.Session.ChannelMessageDelete(portal.Key.ChannelID, message.DiscordID) @@ -1102,21 +1078,19 @@ func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) { } // Now check if it's a reaction. - reaction := portal.bridge.DB.Reaction.GetByMatrixID(portal.Key, evt.Redacts) - if reaction != nil { - if reaction.DiscordID != "" { - err := user.Session.MessageReactionRemove(portal.Key.ChannelID, reaction.DiscordMessageID, reaction.DiscordID, reaction.AuthorID) - if err != nil { - portal.log.Debugfln("Failed to delete reaction %s for message %s: %v", reaction.DiscordID, reaction.DiscordMessageID, err) - } else { - reaction.Delete() - } + reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts) + if reaction != nil && reaction.Channel == portal.Key { + err := user.Session.MessageReactionRemove(portal.Key.ChannelID, reaction.MessageID, reaction.EmojiName, reaction.Sender) + if err != nil { + portal.log.Debugfln("Failed to delete reaction %s from %s: %v", reaction.EmojiName, reaction.MessageID, err) + } else { + reaction.Delete() } return } - portal.log.Warnfln("Failed to redact %s@%s: no event found", portal.Key, evt.Redacts) + portal.log.Warnfln("Failed to redact %s: no event found", evt.Redacts) } func (portal *Portal) update(user *User, channel *discordgo.Channel) { @@ -1150,9 +1124,9 @@ func (portal *Portal) update(user *User, channel *discordgo.Channel) { var url string if portal.Type == discordgo.ChannelTypeDM { - dmUser, err := user.Session.User(portal.DMUser) + dmUser, err := user.Session.User(portal.OtherUserID) if err != nil { - portal.log.Warnln("failed to lookup the dmuser", err) + portal.log.Warnln("failed to lookup the other user in DM", err) } else { url = dmUser.AvatarURL("") } diff --git a/user.go b/user.go index b5bf75e..cef4b2d 100644 --- a/user.go +++ b/user.go @@ -527,7 +527,7 @@ func (user *User) createChannel(c *discordgo.Channel) { portal.Type = c.Type if portal.Type == discordgo.ChannelTypeDM { - portal.DMUser = c.Recipients[0].ID + portal.OtherUserID = c.Recipients[0].ID } if c.Icon != "" {