Update database schema and fix things

This commit is contained in:
Tulir Asokan
2022-05-27 15:58:09 +03:00
parent 8c66a064e1
commit 4c7829c304
27 changed files with 806 additions and 903 deletions

View File

@@ -23,6 +23,7 @@ import (
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@@ -45,6 +46,7 @@ func (br *DiscordBridge) RegisterCommands() {
cmdReconnect, cmdReconnect,
cmdDisconnect, cmdDisconnect,
cmdGuilds, cmdGuilds,
cmdDeleteAllPortals,
) )
} }
@@ -234,6 +236,7 @@ var cmdGuilds = &commands.FullHandler{
func fnGuilds(ce *WrappedCommandEvent) { func fnGuilds(ce *WrappedCommandEvent) {
if len(ce.Args) == 0 { if len(ce.Args) == 0 {
ce.Reply("**Usage**: `$cmdprefix guilds <status/bridge/unbridge> [guild ID] [--entire]`") ce.Reply("**Usage**: `$cmdprefix guilds <status/bridge/unbridge> [guild ID] [--entire]`")
return
} }
subcommand := strings.ToLower(ce.Args[0]) subcommand := strings.ToLower(ce.Args[0])
ce.Args = ce.Args[1:] ce.Args = ce.Args[1:]
@@ -283,3 +286,53 @@ func fnUnbridgeGuild(ce *WrappedCommandEvent) {
ce.Reply("Successfully unbridged guild") ce.Reply("Successfully unbridged guild")
} }
} }
var cmdDeleteAllPortals = &commands.FullHandler{
Func: wrapCommand(fnDeleteAllPortals),
Name: "delete-all-portals",
Help: commands.HelpMeta{
Section: commands.HelpSectionUnclassified,
Description: "Delete all portals.",
},
RequiresAdmin: true,
}
func fnDeleteAllPortals(ce *WrappedCommandEvent) {
portals := ce.Bridge.GetAllPortals()
if len(portals) == 0 {
ce.Reply("Didn't find any portals")
return
}
leave := func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = portal.MainIntent().KickUser(portal.MXID, &mautrix.ReqKickUser{
Reason: "Deleting portal",
UserID: ce.User.MXID,
})
}
}
customPuppet := ce.Bridge.GetPuppetByCustomMXID(ce.User.MXID)
if customPuppet != nil && customPuppet.CustomIntent() != nil {
intent := customPuppet.CustomIntent()
leave = func(portal *Portal) {
if len(portal.MXID) > 0 {
_, _ = intent.LeaveRoom(portal.MXID)
_, _ = intent.ForgetRoom(portal.MXID)
}
}
}
ce.Reply("Found %d portals, deleting...", len(portals))
for _, portal := range portals {
portal.Delete()
leave(portal)
}
ce.Reply("Finished deleting portal info. Now cleaning up rooms in background.")
go func() {
for _, portal := range portals {
portal.cleanup(false)
}
ce.Reply("Finished background cleanup of deleted portal rooms.")
}()
}

View File

@@ -39,8 +39,7 @@ type BridgeConfig struct {
SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"` SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"` SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
DefaultBridgeReceipts bool `yaml:"default_bridge_receipts"` FederateRooms bool `yaml:"federate_rooms"`
DefaultBridgePresence bool `yaml:"default_bridge_presence"`
DoublePuppetServerMap map[string]string `yaml:"double_puppet_server_map"` DoublePuppetServerMap map[string]string `yaml:"double_puppet_server_map"`
DoublePuppetAllowDiscovery bool `yaml:"double_puppet_allow_discovery"` DoublePuppetAllowDiscovery bool `yaml:"double_puppet_allow_discovery"`

View File

@@ -31,8 +31,7 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Int, "bridge", "portal_message_buffer") helper.Copy(up.Int, "bridge", "portal_message_buffer")
helper.Copy(up.Bool, "bridge", "sync_with_custom_puppets") helper.Copy(up.Bool, "bridge", "sync_with_custom_puppets")
helper.Copy(up.Bool, "bridge", "sync_direct_chat_list") helper.Copy(up.Bool, "bridge", "sync_direct_chat_list")
helper.Copy(up.Bool, "bridge", "default_bridge_receipts") helper.Copy(up.Bool, "bridge", "federate_rooms")
helper.Copy(up.Bool, "bridge", "default_bridge_presence")
helper.Copy(up.Map, "bridge", "double_puppet_server_map") helper.Copy(up.Map, "bridge", "double_puppet_server_map")
helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery") helper.Copy(up.Bool, "bridge", "double_puppet_allow_discovery")
helper.Copy(up.Map, "bridge", "login_shared_secret_map") helper.Copy(up.Map, "bridge", "login_shared_secret_map")

View File

@@ -324,9 +324,6 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error
puppet.bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet puppet.bridge.puppetsByCustomMXID[puppet.CustomMXID] = puppet
} }
puppet.EnablePresence = puppet.bridge.Config.Bridge.DefaultBridgePresence
puppet.EnableReceipts = puppet.bridge.Config.Bridge.DefaultBridgeReceipts
puppet.bridge.AS.StateStore.MarkRegistered(puppet.CustomMXID) puppet.bridge.AS.StateStore.MarkRegistered(puppet.CustomMXID)
puppet.Update() puppet.Update()

View File

@@ -10,6 +10,69 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type AttachmentQuery struct {
db *Database
log log.Logger
}
const (
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment"
)
func (aq *AttachmentQuery) New() *Attachment {
return &Attachment{
db: aq.db,
log: aq.log,
}
}
func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3"
return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
}
func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment {
rows, err := aq.db.Query(query, args...)
if err != nil {
aq.log.Debugfln("getAll failed: %v", err)
return nil
}
if rows == nil {
return nil
}
var attachments []*Attachment
for rows.Next() {
attachments = append(attachments, aq.New().Scan(rows))
}
return attachments
}
func (aq *AttachmentQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dcid=$4"
return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID)
}
func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
return aq.get(query, key.ChannelID, key.Receiver, matrixEventID)
}
func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment {
row := aq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return aq.New().Scan(row)
}
type Attachment struct { type Attachment struct {
db *Database db *Database
log log.Logger log log.Logger
@@ -18,14 +81,14 @@ type Attachment struct {
DiscordMessageID string DiscordMessageID string
DiscordAttachmentID string DiscordAttachmentID string
MatrixEventID id.EventID MXID id.EventID
} }
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment { func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
err := row.Scan( err := row.Scan(
&a.DiscordAttachmentID, &a.DiscordMessageID,
&a.Channel.ChannelID, &a.Channel.Receiver, &a.Channel.ChannelID, &a.Channel.Receiver,
&a.DiscordMessageID, &a.DiscordAttachmentID, &a.MXID)
&a.MatrixEventID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
@@ -40,33 +103,32 @@ func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
func (a *Attachment) Insert() { func (a *Attachment) Insert() {
query := "INSERT INTO attachment" + query := "INSERT INTO attachment" +
" (channel_id, receiver, discord_message_id, discord_attachment_id, " + " (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " +
" matrix_event_id) VALUES ($1, $2, $3, $4, $5);" " mxid) VALUES ($1, $2, $3, $4, $5);"
_, err := a.db.Exec( _, err := a.db.Exec(
query, query,
a.Channel.ChannelID, a.Channel.Receiver, a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordMessageID, a.DiscordAttachmentID, a.DiscordMessageID, a.DiscordAttachmentID,
a.MatrixEventID, a.MXID,
) )
if err != nil { if err != nil {
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.Channel, a.DiscordMessageID, err) a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
} }
} }
func (a *Attachment) Delete() { func (a *Attachment) Delete() {
query := "DELETE FROM attachment WHERE" + query := "DELETE FROM attachment WHERE" +
" channel_id=$1 AND receiver=$2 AND discord_attachment_id=$3 AND" + " dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
" matrix_event_id=$4"
_, err := a.db.Exec( _, err := a.db.Exec(
query, query,
a.Channel.ChannelID, a.Channel.Receiver, a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordAttachmentID, a.MatrixEventID, a.DiscordAttachmentID,
) )
if err != nil { if err != nil {
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.Channel, a.DiscordAttachmentID, err) a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
} }
} }

View File

@@ -1,73 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type AttachmentQuery struct {
db *Database
log log.Logger
}
const (
attachmentSelect = "SELECT channel_id, receiver, discord_message_id," +
" discord_attachment_id, matrix_event_id FROM attachment"
)
func (aq *AttachmentQuery) New() *Attachment {
return &Attachment{
db: aq.db,
log: aq.log,
}
}
func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment {
query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
" discord_message_id=$3"
return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
}
func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment {
rows, err := aq.db.Query(query, args...)
if err != nil {
aq.log.Debugfln("getAll failed: %v", err)
return nil
}
if rows == nil {
return nil
}
attachments := []*Attachment{}
for rows.Next() {
attachments = append(attachments, aq.New().Scan(rows))
}
return attachments
}
func (aq *AttachmentQuery) GetByDiscordAttachmentID(key PortalKey, discordMessageID, discordID string) *Attachment {
query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2" +
" AND discord_message_id=$3 AND discord_id=$4"
return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID)
}
func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment {
query := attachmentSelect + " WHERE channel_id=$1 AND receiver=$2" +
" AND matrix_event_id=$3"
return aq.get(query, key.ChannelID, key.Receiver, matrixEventID)
}
func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment {
row := aq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return aq.New().Scan(row)
}

View File

@@ -10,6 +10,43 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type EmojiQuery struct {
db *Database
log log.Logger
}
const (
emojiSelect = "SELECT discord_id, discord_name, matrix_url FROM emoji"
)
func (eq *EmojiQuery) New() *Emoji {
return &Emoji{
db: eq.db,
log: eq.log,
}
}
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
query := emojiSelect + " WHERE discord_id=$1"
return eq.get(query, discordID)
}
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
query := emojiSelect + " WHERE matrix_url=$1"
return eq.get(query, matrixURL.String())
}
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
row := eq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return eq.New().Scan(row)
}
type Emoji struct { type Emoji struct {
db *Database db *Database
log log.Logger log log.Logger

View File

@@ -1,44 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type EmojiQuery struct {
db *Database
log log.Logger
}
const (
emojiSelect = "SELECT discord_id, discord_name, matrix_url FROM emoji"
)
func (eq *EmojiQuery) New() *Emoji {
return &Emoji{
db: eq.db,
log: eq.log,
}
}
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
query := emojiSelect + " WHERE discord_id=$1"
return eq.get(query, discordID)
}
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
query := emojiSelect + " WHERE matrix_url=$1"
return eq.get(query, matrixURL.String())
}
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
row := eq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return eq.New().Scan(row)
}

View File

@@ -3,12 +3,89 @@ package database
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type GuildQuery struct {
db *Database
log log.Logger
}
const (
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
)
func (gq *GuildQuery) New() *Guild {
return &Guild{
db: gq.db,
log: gq.log,
}
}
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
row := gq.db.QueryRow(query, discordID, guildID)
if row == nil {
return nil
}
return gq.New().Scan(row)
}
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
query := guildSelect + " WHERE discord_id=$1"
rows, err := gq.db.Query(query, discordID)
if err != nil || rows == nil {
return nil
}
guilds := []*Guild{}
for rows.Next() {
guilds = append(guilds, gq.New().Scan(rows))
}
return guilds
}
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
// We need this interface slice because a variadic function can't mix
// arguements with a `...` expanded slice.
args := []interface{}{discordID}
nGuilds := len(guilds)
if nGuilds <= 0 {
return
}
gq.log.Debugfln("prunning guilds for %s", discordID)
// Build the in query
inQuery := "$2"
for i := 1; i < nGuilds; i++ {
inQuery += fmt.Sprintf(", $%d", i+2)
}
// Add the arguements for the build query
for _, guildID := range guilds {
args = append(args, guildID)
}
// Now remove any guilds that the user has left.
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
inQuery + ")"
_, err := gq.db.Exec(query, args...)
if err != nil {
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
}
}
type Guild struct { type Guild struct {
db *Database db *Database
log log.Logger log log.Logger

View File

@@ -1,83 +0,0 @@
package database
import (
"fmt"
log "maunium.net/go/maulogger/v2"
)
type GuildQuery struct {
db *Database
log log.Logger
}
const (
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
)
func (gq *GuildQuery) New() *Guild {
return &Guild{
db: gq.db,
log: gq.log,
}
}
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
row := gq.db.QueryRow(query, discordID, guildID)
if row == nil {
return nil
}
return gq.New().Scan(row)
}
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
query := guildSelect + " WHERE discord_id=$1"
rows, err := gq.db.Query(query, discordID)
if err != nil || rows == nil {
return nil
}
guilds := []*Guild{}
for rows.Next() {
guilds = append(guilds, gq.New().Scan(rows))
}
return guilds
}
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
// We need this interface slice because a variadic function can't mix
// arguements with a `...` expanded slice.
args := []interface{}{discordID}
nGuilds := len(guilds)
if nGuilds <= 0 {
return
}
gq.log.Debugfln("prunning guilds for %s", discordID)
// Build the in query
inQuery := "$2"
for i := 1; i < nGuilds; i++ {
inQuery += fmt.Sprintf(", $%d", i+2)
}
// Add the arguements for the build query
for _, guildID := range guilds {
args = append(args, guildID)
}
// Now remove any guilds that the user has left.
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
inQuery + ")"
_, err := gq.db.Exec(query, args...)
if err != nil {
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
}
}

View File

@@ -11,23 +11,77 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type MessageQuery struct {
db *Database
log log.Logger
}
const (
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
)
func (mq *MessageQuery) New() *Message {
return &Message{
db: mq.db,
log: mq.log,
}
}
func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2"
rows, err := mq.db.Query(query, key.ChannelID, key.Receiver)
if err != nil || rows == nil {
return nil
}
var messages []*Message
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return messages
}
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
if row == nil {
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
return nil
}
return mq.New().Scan(row)
}
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, mxid)
if row == nil {
return nil
}
return mq.New().Scan(row)
}
type Message struct { type Message struct {
db *Database db *Database
log log.Logger log log.Logger
Channel PortalKey
DiscordID string DiscordID string
MatrixID id.EventID Channel PortalKey
SenderID string
AuthorID string
Timestamp time.Time Timestamp time.Time
MXID id.EventID
} }
func (m *Message) Scan(row dbutil.Scannable) *Message { func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts int64
err := row.Scan(&m.Channel.ChannelID, &m.Channel.Receiver, &m.DiscordID, &m.MatrixID, &m.AuthorID, &ts) err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err) m.log.Errorln("Database scan failed:", err)
@@ -44,38 +98,21 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
} }
func (m *Message) Insert() { func (m *Message) Insert() {
query := "INSERT INTO message" + query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
" (channel_id, receiver, discord_message_id, matrix_message_id," +
" author_id, timestamp) VALUES ($1, $2, $3, $4, $5, $6)"
_, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver, _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
m.DiscordID, m.MatrixID, m.AuthorID, m.Timestamp.Unix())
if err != nil { if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.Channel, m.DiscordID, err) m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
} }
} }
func (m *Message) Delete() { func (m *Message) Delete() {
query := "DELETE FROM message" + query := "DELETE FROM message WHERE dcid=$1 AND dc_chan_id=$2 AND dc_chan_receiver=$3"
" WHERE channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" +
" matrix_message_id=$4"
_, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver, _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver)
m.DiscordID, m.MatrixID)
if err != nil { if err != nil {
m.log.Warnfln("Failed to delete %s@%s: %v", m.Channel, m.DiscordID, err) m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
}
}
func (m *Message) UpdateMatrixID(mxid id.EventID) {
query := "UPDATE message SET matrix_message_id=$1 WHERE channel_id=$2" +
" AND receiver=$3 AND discord_message_id=$4"
m.MatrixID = mxid
_, err := m.db.Exec(query, m.MatrixID, m.Channel.ChannelID, m.Channel.Receiver, m.DiscordID)
if err != nil {
m.log.Warnfln("Failed to update %s@%s: %v", m.Channel, m.DiscordID, err)
} }
} }

View File

@@ -1,64 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type MessageQuery struct {
db *Database
log log.Logger
}
const (
messageSelect = "SELECT channel_id, receiver, discord_message_id," +
" matrix_message_id, author_id, timestamp FROM message"
)
func (mq *MessageQuery) New() *Message {
return &Message{
db: mq.db,
log: mq.log,
}
}
func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
query := messageSelect + " WHERE channeld_id=$1 AND receiver=$2"
rows, err := mq.db.Query(query, key.ChannelID, key.Receiver)
if err != nil || rows == nil {
return nil
}
messages := []*Message{}
for rows.Next() {
messages = append(messages, mq.New().Scan(rows))
}
return messages
}
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
" discord_message_id=$3"
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
if row == nil {
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
return nil
}
return mq.New().Scan(row)
}
func (mq *MessageQuery) GetByMatrixID(key PortalKey, matrixID id.EventID) *Message {
query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
" matrix_message_id=$3"
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, matrixID)
if row == nil {
return nil
}
return mq.New().Scan(row)
}

View File

@@ -11,23 +11,85 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
const (
portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," +
" avatar_url, type, other_user_id, first_event_id, encrypted" +
" FROM portal"
)
type PortalQuery struct {
db *Database
log log.Logger
}
func (pq *PortalQuery) New() *Portal {
return &Portal{
db: pq.db,
log: pq.log,
}
}
func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll(portalSelect)
}
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get(portalSelect+" WHERE mxid=$1", mxid)
}
func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM)
}
func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
}
func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
var portals []*Portal
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return portals
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
type Portal struct { type Portal struct {
db *Database db *Database
log log.Logger log log.Logger
Key PortalKey Key PortalKey
Type discordgo.ChannelType
OtherUserID string
MXID id.RoomID MXID id.RoomID
Name string Name string
Topic string Topic string
Encrypted bool
Avatar string Avatar string
AvatarURL id.ContentURI AvatarURL id.ContentURI
Encrypted bool
Type discordgo.ChannelType
DMUser string
FirstEventID id.EventID FirstEventID id.EventID
} }
@@ -37,7 +99,7 @@ func (p *Portal) Scan(row dbutil.Scannable) *Portal {
var typ sql.NullInt32 var typ sql.NullInt32
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID, &p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
&p.Encrypted) &p.Encrypted)
if err != nil { if err != nil {
@@ -66,12 +128,12 @@ func (p *Portal) mxidPtr() *id.RoomID {
func (p *Portal) Insert() { func (p *Portal) Insert() {
query := "INSERT INTO portal" + query := "INSERT INTO portal" +
" (channel_id, receiver, mxid, name, topic, avatar, avatar_url," + " (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
" type, dmuser, first_event_id, encrypted)" + " type, other_user_id, first_event_id, encrypted)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(), _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser, p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
p.FirstEventID.String(), p.Encrypted) p.FirstEventID.String(), p.Encrypted)
if err != nil { if err != nil {
@@ -82,11 +144,11 @@ func (p *Portal) Insert() {
func (p *Portal) Update() { func (p *Portal) Update() {
query := "UPDATE portal SET" + query := "UPDATE portal SET" +
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," + " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
" dmuser=$7, first_event_id=$8, encrypted=$9" + " other_user_id=$7, first_event_id=$8, encrypted=$9" +
" WHERE channel_id=$10 AND receiver=$11" " WHERE dcid=$10 AND receiver=$11"
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar, _, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(), p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
p.Encrypted, p.Encrypted,
p.Key.ChannelID, p.Key.Receiver) p.Key.ChannelID, p.Key.Receiver)
@@ -96,7 +158,7 @@ func (p *Portal) Update() {
} }
func (p *Portal) Delete() { func (p *Portal) Delete() {
query := "DELETE FROM portal WHERE channel_id=$1 AND receiver=$2" query := "DELETE FROM portal WHERE dcid=$1 AND receiver=$2"
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver) _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
if err != nil { if err != nil {
p.log.Warnfln("Failed to delete %s: %v", p.Key, err) p.log.Warnfln("Failed to delete %s: %v", p.Key, err)

View File

@@ -13,8 +13,8 @@ func NewPortalKey(channelID, receiver string) PortalKey {
} }
func (key PortalKey) String() string { func (key PortalKey) String() string {
if key.ChannelID == key.Receiver { if key.Receiver == "" {
return key.Receiver return key.ChannelID
} }
return key.ChannelID + "-" + key.Receiver return key.ChannelID + "-" + key.Receiver
} }

View File

@@ -1,71 +0,0 @@
package database
import (
"github.com/bwmarrin/discordgo"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
const (
portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," +
" avatar_url, type, dmuser, first_event_id, encrypted" +
" FROM portal"
)
type PortalQuery struct {
db *Database
log log.Logger
}
func (pq *PortalQuery) New() *Portal {
return &Portal{
db: pq.db,
log: pq.log,
}
}
func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll(portalSelect)
}
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get(portalSelect+" WHERE mxid=$1", mxid)
}
func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal {
return pq.getAll(portalSelect+" WHERE dmuser=$1 AND type=$2", id, discordgo.ChannelTypeDM)
}
func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal {
query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
}
func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
portals := []*Portal{}
for rows.Next() {
portals = append(portals, pq.New().Scan(rows))
}
return portals
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
}

View File

@@ -11,11 +11,62 @@ import (
const ( const (
puppetSelect = "SELECT id, display_name, avatar, avatar_url," + puppetSelect = "SELECT id, display_name, avatar, avatar_url," +
" enable_presence, custom_mxid, access_token, next_batch," + " custom_mxid, access_token, next_batch" +
" enable_receipts" +
" FROM puppet " " FROM puppet "
) )
type PuppetQuery struct {
db *Database
log log.Logger
}
func (pq *PuppetQuery) New() *Puppet {
return &Puppet{
db: pq.db,
log: pq.log,
}
}
func (pq *PuppetQuery) Get(id string) *Puppet {
return pq.get(puppetSelect+" WHERE id=$1", id)
}
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
}
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
func (pq *PuppetQuery) GetAll() []*Puppet {
return pq.getAll(puppetSelect)
}
func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
}
func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
puppets := []*Puppet{}
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return puppets
}
type Puppet struct { type Puppet struct {
db *Database db *Database
log log.Logger log log.Logger
@@ -26,23 +77,17 @@ type Puppet struct {
Avatar string Avatar string
AvatarURL id.ContentURI AvatarURL id.ContentURI
EnablePresence bool
CustomMXID id.UserID CustomMXID id.UserID
AccessToken string AccessToken string
NextBatch string
NextBatch string
EnableReceipts bool
} }
func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
var did, displayName, avatar, avatarURL sql.NullString var did, displayName, avatar, avatarURL sql.NullString
var enablePresence sql.NullBool
var customMXID, accessToken, nextBatch sql.NullString var customMXID, accessToken, nextBatch sql.NullString
err := row.Scan(&did, &displayName, &avatar, &avatarURL, &enablePresence, err := row.Scan(&did, &displayName, &avatar, &avatarURL,
&customMXID, &accessToken, &nextBatch, &p.EnableReceipts) &customMXID, &accessToken, &nextBatch)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
@@ -56,7 +101,6 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
p.DisplayName = displayName.String p.DisplayName = displayName.String
p.Avatar = avatar.String p.Avatar = avatar.String
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.EnablePresence = enablePresence.Bool
p.CustomMXID = id.UserID(customMXID.String) p.CustomMXID = id.UserID(customMXID.String)
p.AccessToken = accessToken.String p.AccessToken = accessToken.String
p.NextBatch = nextBatch.String p.NextBatch = nextBatch.String
@@ -66,13 +110,13 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
func (p *Puppet) Insert() { func (p *Puppet) Insert() {
query := "INSERT INTO puppet" + query := "INSERT INTO puppet" +
" (id, display_name, avatar, avatar_url, enable_presence," + " (id, display_name, avatar, avatar_url," +
" custom_mxid, access_token, next_batch, enable_receipts)" + " custom_mxid, access_token, next_batch)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" " VALUES ($1, $2, $3, $4, $5, $6, $7)"
_, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar, _, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
p.AvatarURL.String(), p.EnablePresence, p.CustomMXID, p.AccessToken, p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
p.NextBatch, p.EnableReceipts) p.NextBatch)
if err != nil { if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.ID, err) p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
@@ -81,14 +125,13 @@ func (p *Puppet) Insert() {
func (p *Puppet) Update() { func (p *Puppet) Update() {
query := "UPDATE puppet" + query := "UPDATE puppet" +
" SET display_name=$1, avatar=$2, avatar_url=$3, enable_presence=$4," + " SET display_name=$1, avatar=$2, avatar_url=$3, " +
" custom_mxid=$5, access_token=$6, next_batch=$7," + " custom_mxid=$4, access_token=$5, next_batch=$6" +
" enable_receipts=$8" + " WHERE id=$7"
" WHERE id=$9"
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(), _, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
p.EnablePresence, p.CustomMXID, p.AccessToken, p.NextBatch, p.CustomMXID, p.AccessToken, p.NextBatch,
p.EnableReceipts, p.ID) p.ID)
if err != nil { if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.ID, err) p.log.Warnfln("Failed to update %s: %v", p.ID, err)

View File

@@ -1,60 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type PuppetQuery struct {
db *Database
log log.Logger
}
func (pq *PuppetQuery) New() *Puppet {
return &Puppet{
db: pq.db,
log: pq.log,
EnablePresence: true,
}
}
func (pq *PuppetQuery) Get(id string) *Puppet {
return pq.get(puppetSelect+" WHERE id=$1", id)
}
func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
return pq.get(puppetSelect+" WHERE custom_mxid=$1", mxid)
}
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
}
func (pq *PuppetQuery) GetAll() []*Puppet {
return pq.getAll(puppetSelect)
}
func (pq *PuppetQuery) GetAllWithCustomMXID() []*Puppet {
return pq.getAll(puppetSelect + " WHERE custom_mxid<>''")
}
func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
rows, err := pq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
puppets := []*Puppet{}
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
return puppets
}

View File

@@ -10,97 +10,102 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type ReactionQuery struct {
db *Database
log log.Logger
}
const (
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction"
)
func (rq *ReactionQuery) New() *Reaction {
return &Reaction{
db: rq.db,
log: rq.log,
}
}
func (rq *ReactionQuery) GetAllForMessage(key PortalKey, discordMessageID string) []*Reaction {
query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3"
return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
}
func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction {
rows, err := rq.db.Query(query, args...)
if err != nil || rows == nil {
return nil
}
var reactions []*Reaction
for rows.Next() {
reactions = append(reactions, rq.New().Scan(rows))
}
return reactions
}
func (rq *ReactionQuery) GetByDiscordID(key PortalKey, msgID, sender, emojiName string) *Reaction {
query := reactionSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dc_sender=$4 AND dc_emoji_name=$5"
return rq.get(query, key.ChannelID, key.Receiver, msgID, sender, emojiName)
}
func (rq *ReactionQuery) GetByMXID(mxid id.EventID) *Reaction {
query := reactionSelect + " WHERE mxid=$1"
return rq.get(query, mxid)
}
func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction {
row := rq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return rq.New().Scan(row)
}
type Reaction struct { type Reaction struct {
db *Database db *Database
log log.Logger log log.Logger
Channel PortalKey Channel PortalKey
MessageID string
Sender string
EmojiName string
DiscordMessageID string MXID id.EventID
MatrixEventID id.EventID
// The discord ID of who create this reaction
AuthorID string
MatrixName string
MatrixURL string // Used for custom emoji
DiscordID string // The id or unicode of the emoji for discord
} }
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
var discordID sql.NullString err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID)
err := row.Scan(
&r.Channel.ChannelID, &r.Channel.Receiver,
&r.DiscordMessageID, &r.MatrixEventID,
&r.AuthorID,
&r.MatrixName, &r.MatrixURL,
&discordID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
r.log.Errorln("Database scan failed:", err) r.log.Errorln("Database scan failed:", err)
} }
return nil return nil
} }
r.DiscordID = discordID.String
return r return r
} }
func (r *Reaction) Insert() { func (r *Reaction) Insert() {
query := "INSERT INTO reaction" + query := `
" (channel_id, receiver, discord_message_id, matrix_event_id," + INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid)
" author_id, matrix_name, matrix_url, discord_id)" + VALUES($1, $2, $3, $4, $5, $6)
" VALUES($1, $2, $3, $4, $5, $6, $7, $8);" `
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID)
var discordID sql.NullString
if r.DiscordID != "" {
discordID = sql.NullString{r.DiscordID, true}
}
_, err := r.db.Exec(
query,
r.Channel.ChannelID, r.Channel.Receiver,
r.DiscordMessageID, r.MatrixEventID,
r.AuthorID,
r.MatrixName, r.MatrixURL,
discordID,
)
if err != nil { if err != nil {
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
} }
} }
func (r *Reaction) Update() {
// TODO: determine if we need this. The only scenario I can think of that
// would require this is if we insert a custom emoji before uploading to
// the homeserver?
}
func (r *Reaction) Delete() { func (r *Reaction) Delete() {
query := "DELETE FROM reaction WHERE" + query := "DELETE FROM reaction WHERE dc_msg_id=$1 AND dc_sender=$2 AND dc_emoji_name=$3"
" channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" + _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName)
" author_id=$4 AND discord_id=$5"
var discordID sql.NullString
if r.DiscordID != "" {
discordID = sql.NullString{r.DiscordID, true}
}
_, err := r.db.Exec(
query,
r.Channel.ChannelID, r.Channel.Receiver,
r.DiscordMessageID, r.AuthorID,
discordID,
)
if err != nil { if err != nil {
r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err)
} }
} }

View File

@@ -1,75 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type ReactionQuery struct {
db *Database
log log.Logger
}
const (
reactionSelect = "SELECT channel_id, receiver, discord_message_id," +
" matrix_event_id, author_id, matrix_name, matrix_url, " +
" discord_id FROM reaction"
)
func (rq *ReactionQuery) New() *Reaction {
return &Reaction{
db: rq.db,
log: rq.log,
}
}
func (rq *ReactionQuery) GetAllByDiscordID(key PortalKey, discordMessageID string) []*Reaction {
query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
" discord_message_id=$3"
return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
}
func (rq *ReactionQuery) GetAllByMatrixID(key PortalKey, matrixEventID id.EventID) []*Reaction {
query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" +
" matrix_event_id=$3"
return rq.getAll(query, key.ChannelID, key.Receiver, matrixEventID)
}
func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction {
rows, err := rq.db.Query(query)
if err != nil || rows == nil {
return nil
}
reactions := []*Reaction{}
for rows.Next() {
reactions = append(reactions, rq.New().Scan(rows))
}
return reactions
}
func (rq *ReactionQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Reaction {
query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" +
" AND discord_message_id=$3 AND discord_id=$4"
return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID)
}
func (rq *ReactionQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Reaction {
query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" +
" AND matrix_event_id=$3"
return rq.get(query, key.ChannelID, key.Receiver, matrixEventID)
}
func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction {
row := rq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return rq.New().Scan(row)
}

View File

@@ -0,0 +1,92 @@
-- v0 -> v2: Latest revision
CREATE TABLE portal (
dcid TEXT,
receiver TEXT,
other_user_id TEXT,
type INTEGER,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
encrypted BOOLEAN NOT NULL DEFAULT false,
first_event_id TEXT NOT NULL,
PRIMARY KEY (dcid, receiver)
);
CREATE TABLE puppet (
id TEXT PRIMARY KEY,
name TEXT,
avatar TEXT,
avatar_url TEXT,
custom_mxid TEXT,
access_token TEXT,
next_batch TEXT
);
CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
dcid TEXT UNIQUE,
management_room TEXT,
token TEXT
);
CREATE TABLE message (
dcid TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
);
CREATE TABLE reaction (
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
);
CREATE TABLE attachment (
dcid TEXT,
dc_msg_id TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver),
CONSTRAINT attachment_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
);
CREATE TABLE emoji (
discord_id TEXT PRIMARY KEY,
discord_name TEXT,
matrix_url TEXT
);
CREATE TABLE guild (
discord_id TEXT NOT NULL,
guild_id TEXT NOT NULL,
guild_name TEXT NOT NULL,
bridge BOOLEAN DEFAULT FALSE,
PRIMARY KEY(discord_id, guild_id)
);

View File

@@ -1,105 +0,0 @@
-- v1: Initial revision
CREATE TABLE portal (
channel_id TEXT,
receiver TEXT,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT,
encrypted BOOLEAN NOT NULL DEFAULT false,
type INT,
dmuser TEXT,
first_event_id TEXT,
PRIMARY KEY (channel_id, receiver)
);
CREATE TABLE puppet (
id TEXT PRIMARY KEY,
display_name TEXT,
avatar TEXT,
avatar_url TEXT,
enable_presence BOOLEAN NOT NULL DEFAULT true,
enable_receipts BOOLEAN NOT NULL DEFAULT true,
custom_mxid TEXT,
access_token TEXT,
next_batch TEXT
);
CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
id TEXT UNIQUE,
management_room TEXT,
token TEXT
);
CREATE TABLE message (
channel_id TEXT NOT NULL,
receiver TEXT NOT NULL,
discord_message_id TEXT NOT NULL,
matrix_message_id TEXT NOT NULL UNIQUE,
author_id TEXT NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY(discord_message_id, channel_id, receiver),
FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
);
CREATE TABLE reaction (
channel_id TEXT NOT NULL,
receiver TEXT NOT NULL,
discord_message_id TEXT NOT NULL,
matrix_event_id TEXT NOT NULL UNIQUE,
author_id TEXT NOT NULL,
matrix_name TEXT,
matrix_url TEXT,
discord_id TEXT,
UNIQUE (discord_id, author_id, discord_message_id, channel_id, receiver),
FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
);
CREATE TABLE attachment (
channel_id TEXT NOT NULL,
receiver TEXT NOT NULL,
discord_message_id TEXT NOT NULL,
discord_attachment_id TEXT NOT NULL,
matrix_event_id TEXT NOT NULL UNIQUE,
PRIMARY KEY(discord_attachment_id, matrix_event_id),
FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE
);
CREATE TABLE emoji (
discord_id TEXT PRIMARY KEY,
discord_name TEXT,
matrix_url TEXT
);
CREATE TABLE guild (
discord_id TEXT NOT NULL,
guild_id TEXT NOT NULL,
guild_name TEXT NOT NULL,
bridge BOOLEAN DEFAULT FALSE,
PRIMARY KEY(discord_id, guild_id)
);

View File

@@ -0,0 +1,53 @@
-- v2: Rename columns in message-related tables
ALTER TABLE portal RENAME COLUMN dmuser TO other_user_id;
ALTER TABLE portal RENAME COLUMN channel_id TO dcid;
ALTER TABLE "user" RENAME COLUMN id TO dcid;
ALTER TABLE puppet DROP COLUMN enable_presence;
ALTER TABLE puppet DROP COLUMN enable_receipts;
DROP TABLE message;
DROP TABLE reaction;
DROP TABLE attachment;
CREATE TABLE message (
dcid TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
);
CREATE TABLE reaction (
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
);
CREATE TABLE attachment (
dcid TEXT,
dc_msg_id TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver),
CONSTRAINT attachment_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
);
UPDATE portal SET receiver='' WHERE type<>1;

View File

@@ -9,6 +9,54 @@ import (
"maunium.net/go/mautrix/util/dbutil" "maunium.net/go/mautrix/util/dbutil"
) )
type UserQuery struct {
db *Database
log log.Logger
}
func (uq *UserQuery) New() *User {
return &User{
db: uq.db,
log: uq.log,
}
}
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
row := uq.db.QueryRow(query, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetByID(id string) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
row := uq.db.QueryRow(query, id)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetAll() []*User {
rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
users := []*User{}
for rows.Next() {
users = append(users, uq.New().Scan(rows))
}
return users
}
type User struct { type User struct {
db *Database db *Database
log log.Logger log log.Logger
@@ -46,9 +94,7 @@ func (u *User) Scan(row dbutil.Scannable) *User {
} }
func (u *User) Insert() { func (u *User) Insert() {
query := "INSERT INTO \"user\"" + query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
" (mxid, id, management_room, token)" +
" VALUES ($1, $2, $3, $4);"
var token sql.NullString var token sql.NullString
var discordID sql.NullString var discordID sql.NullString
@@ -71,9 +117,7 @@ func (u *User) Insert() {
} }
func (u *User) Update() { func (u *User) Update() {
query := "UPDATE \"user\" SET" + query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
" id=$1, management_room=$2, token=$3" +
" WHERE mxid=$4;"
var token sql.NullString var token sql.NullString
var discordID sql.NullString var discordID sql.NullString

View File

@@ -1,54 +0,0 @@
package database
import (
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
)
type UserQuery struct {
db *Database
log log.Logger
}
func (uq *UserQuery) New() *User {
return &User{
db: uq.db,
log: uq.log,
}
}
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
query := `SELECT mxid, id, management_room, token FROM "user" WHERE mxid=$1`
row := uq.db.QueryRow(query, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetByID(id string) *User {
query := `SELECT mxid, id, management_room, token FROM "user" WHERE id=$1`
row := uq.db.QueryRow(query, id)
if row == nil {
return nil
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetAll() []*User {
rows, err := uq.db.Query(`SELECT mxid, id, management_room, token FROM "user" WHERE token IS NOT NULL`)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
users := []*User{}
for rows.Next() {
users = append(users, uq.New().Scan(rows))
}
return users
}

View File

@@ -81,11 +81,9 @@ bridge:
# Note that updating the m.direct event is not atomic (except with mautrix-asmux) # Note that updating the m.direct event is not atomic (except with mautrix-asmux)
# and is therefore prone to race conditions. # and is therefore prone to race conditions.
sync_direct_chat_list: false sync_direct_chat_list: false
# When double puppeting is enabled, users can use `!wa toggle` to change whether # Whether or not created rooms should have federation enabled.
# presence and read receipts are bridged. These settings set the default values. # If false, created portal rooms will never be federated.
# Existing users won't be affected when these are changed. federate_rooms: true
default_bridge_receipts: true
default_bridge_presence: true
# Servers to always allow double puppeting from # Servers to always allow double puppeting from
double_puppet_server_map: double_puppet_server_map:
example.com: https://example.com example.com: https://example.com

186
portal.go
View File

@@ -174,8 +174,8 @@ func (portal *Portal) IsPrivateChat() bool {
} }
func (portal *Portal) MainIntent() *appservice.IntentAPI { func (portal *Portal) MainIntent() *appservice.IntentAPI {
if portal.IsPrivateChat() && portal.DMUser != "" { if portal.IsPrivateChat() && portal.OtherUserID != "" {
return portal.bridge.GetPuppetByID(portal.DMUser).DefaultIntent() return portal.bridge.GetPuppetByID(portal.OtherUserID).DefaultIntent()
} }
return portal.bridge.Bot return portal.bridge.Bot
@@ -184,15 +184,13 @@ func (portal *Portal) MainIntent() *appservice.IntentAPI {
func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error { func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error {
portal.roomCreateLock.Lock() portal.roomCreateLock.Lock()
defer portal.roomCreateLock.Unlock() defer portal.roomCreateLock.Unlock()
// If we have a matrix id the room should exist so we have nothing to do.
if portal.MXID != "" { if portal.MXID != "" {
return nil return nil
} }
portal.Type = channel.Type portal.Type = channel.Type
if portal.Type == discordgo.ChannelTypeDM { if portal.Type == discordgo.ChannelTypeDM {
portal.DMUser = channel.Recipients[0].ID portal.OtherUserID = channel.Recipients[0].ID
} }
intent := portal.MainIntent() intent := portal.MainIntent()
@@ -219,7 +217,9 @@ func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) e
initialState := []*event.Event{} initialState := []*event.Event{}
creationContent := make(map[string]interface{}) creationContent := make(map[string]interface{})
creationContent["m.federate"] = false if !portal.bridge.Config.Bridge.FederateRooms {
creationContent["m.federate"] = false
}
var invite []id.UserID var invite []id.UserID
@@ -325,19 +325,14 @@ func (portal *Portal) ensureUserInvited(user *User) bool {
return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat())
} }
func (portal *Portal) markMessageHandled(msg *database.Message, discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message { func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message {
if msg == nil { msg := portal.bridge.DB.Message.New()
msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key
msg.Channel = portal.Key msg.DiscordID = discordID
msg.DiscordID = discordID msg.MXID = mxid
msg.MatrixID = mxid msg.SenderID = authorID
msg.AuthorID = authorID msg.Timestamp = timestamp
msg.Timestamp = timestamp msg.Insert()
msg.Insert()
} else {
msg.UpdateMatrixID(mxid)
}
return msg return msg
} }
@@ -410,7 +405,7 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI
dbAttachment.Channel = portal.Key dbAttachment.Channel = portal.Key
dbAttachment.DiscordMessageID = msgID dbAttachment.DiscordMessageID = msgID
dbAttachment.DiscordAttachmentID = attachment.ID dbAttachment.DiscordAttachmentID = attachment.ID
dbAttachment.MatrixEventID = resp.EventID dbAttachment.MXID = resp.EventID
dbAttachment.Insert() dbAttachment.Insert()
} }
@@ -461,14 +456,14 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
MsgType: event.MsgText, MsgType: event.MsgText,
} }
if msg.MessageReference != nil { if msg.MessageReference != nil && msg.MessageReference.ChannelID == portal.Key.ChannelID {
key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID}
existing := portal.bridge.DB.Message.GetByDiscordID(key, msg.MessageReference.MessageID) replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID)
if existing != nil && existing.MatrixID != "" { if replyTo != nil {
content.RelatesTo = &event.RelatesTo{ content.RelatesTo = &event.RelatesTo{
Type: event.RelReply, Type: event.RelReply,
EventID: existing.MatrixID, EventID: existing.MXID,
} }
} }
} }
@@ -481,7 +476,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
} }
ts, _ := msg.Timestamp.Parse() ts, _ := msg.Timestamp.Parse()
portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts)
} }
// now run through any attachments the message has // now run through any attachments the message has
@@ -544,9 +539,9 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes
// Finally run through any attachments still in the map and delete them // Finally run through any attachments still in the map and delete them
// on the matrix side and our database. // on the matrix side and our database.
for _, attachment := range attachmentMap { for _, attachment := range attachmentMap {
_, err := intent.RedactEvent(portal.MXID, attachment.MatrixEventID) _, err := intent.RedactEvent(portal.MXID, attachment.MXID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MatrixEventID, err) portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err)
} }
attachment.Delete() attachment.Delete()
@@ -560,17 +555,17 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes
MsgType: event.MsgText, MsgType: event.MsgText,
} }
content.SetEdit(existing.MatrixID) content.SetEdit(existing.MXID)
resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) _, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli())
if err != nil { if err != nil {
portal.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) portal.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err)
return return
} }
ts, _ := msg.Timestamp.Parse() //ts, _ := msg.Timestamp.Parse()
portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts) //portal.markMessageHandled(existing, msg.ID, resp.EventID, msg.Author.ID, ts)
} }
func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) {
@@ -583,19 +578,12 @@ func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Mess
// Find the message that we're working with. This could correctly return // Find the message that we're working with. This could correctly return
// nil if the message was just one or more attachments. // nil if the message was just one or more attachments.
existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID)
intent := portal.MainIntent()
var intent *appservice.IntentAPI
if portal.Type == discordgo.ChannelTypeDM {
intent = portal.bridge.GetPuppetByID(portal.DMUser).IntentFor(portal)
} else {
intent = portal.MainIntent()
}
if existing != nil { if existing != nil {
_, err := intent.RedactEvent(portal.MXID, existing.MatrixID) _, err := intent.RedactEvent(portal.MXID, existing.MXID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to remove message %s: %v", existing.MatrixID, err) portal.log.Warnfln("Failed to remove message %s: %v", existing.MXID, err)
} }
existing.Delete() existing.Delete()
@@ -604,9 +592,9 @@ func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Mess
// Now delete all of the existing attachments. // Now delete all of the existing attachments.
attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID)
for _, attachment := range attachments { for _, attachment := range attachments {
_, err := intent.RedactEvent(portal.MXID, attachment.MatrixEventID) _, err := intent.RedactEvent(portal.MXID, attachment.MXID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MatrixEventID, err) portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err)
} }
attachment.Delete() attachment.Delete()
@@ -646,7 +634,6 @@ func (portal *Portal) encrypt(content *event.Content, eventType event.Type) (eve
return eventType, nil return eventType, nil
} }
const doublePuppetKey = "fi.mau.double_puppet_source"
const doublePuppetValue = "mautrix-discord" const doublePuppetValue = "mautrix-discord"
func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) {
@@ -656,7 +643,7 @@ func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType
wrappedContent.Raw = map[string]interface{}{} wrappedContent.Raw = map[string]interface{}{}
} }
if intent.IsCustomPuppet { if intent.IsCustomPuppet {
wrappedContent.Raw[doublePuppetKey] = doublePuppetValue wrappedContent.Raw[bridge.DoublePuppetKey] = doublePuppetValue
} }
} }
var err error var err error
@@ -668,7 +655,7 @@ func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType
if eventType == event.EventEncrypted { if eventType == event.EventEncrypted {
// Clear other custom keys if the event was encrypted, but keep the double puppet identifier // Clear other custom keys if the event was encrypted, but keep the double puppet identifier
if intent.IsCustomPuppet { if intent.IsCustomPuppet {
wrappedContent.Raw = map[string]interface{}{doublePuppetKey: doublePuppetValue} wrappedContent.Raw = map[string]interface{}{bridge.DoublePuppetKey: doublePuppetValue}
} else { } else {
wrappedContent.Raw = nil wrappedContent.Raw = nil
} }
@@ -700,13 +687,6 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
return return
} }
existing := portal.bridge.DB.Message.GetByMatrixID(portal.Key, evt.ID)
if existing != nil {
portal.log.Debugln("not handling duplicate message", evt.ID)
return
}
content, ok := evt.Content.Parsed.(*event.MessageEventContent) content, ok := evt.Content.Parsed.(*event.MessageEventContent)
if !ok { if !ok {
portal.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed) portal.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed)
@@ -715,15 +695,15 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
} }
if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReplace { if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReplace {
existing := portal.bridge.DB.Message.GetByMatrixID(portal.Key, content.RelatesTo.EventID) edits := portal.bridge.DB.Message.GetByMXID(portal.Key, content.RelatesTo.EventID)
if existing != nil && existing.DiscordID != "" { if edits != nil {
// we don't have anything to save for the update message right now // we don't have anything to save for the update message right now
// as we're not tracking edited timestamps. // as we're not tracking edited timestamps.
_, err := sender.Session.ChannelMessageEdit(portal.Key.ChannelID, _, err := sender.Session.ChannelMessageEdit(portal.Key.ChannelID,
existing.DiscordID, content.NewContent.Body) edits.DiscordID, content.NewContent.Body)
if err != nil { if err != nil {
portal.log.Errorln("Failed to update message %s: %v", existing.DiscordID, err) portal.log.Errorln("Failed to update message %s: %v", edits.DiscordID, err)
return return
} }
@@ -740,18 +720,18 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
sent := false sent := false
if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReply { if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReply {
existing := portal.bridge.DB.Message.GetByMatrixID( replyTo := portal.bridge.DB.Message.GetByMXID(
portal.Key, portal.Key,
content.RelatesTo.EventID, content.RelatesTo.EventID,
) )
if existing != nil && existing.DiscordID != "" { if replyTo != nil {
msg, err = sender.Session.ChannelMessageSendReply( msg, err = sender.Session.ChannelMessageSendReply(
portal.Key.ChannelID, portal.Key.ChannelID,
content.Body, content.Body,
&discordgo.MessageReference{ &discordgo.MessageReference{
ChannelID: portal.Key.ChannelID, ChannelID: portal.Key.ChannelID,
MessageID: existing.DiscordID, MessageID: replyTo.DiscordID,
}, },
) )
if err == nil { if err == nil {
@@ -771,13 +751,11 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
} }
msgSend := &discordgo.MessageSend{ msgSend := &discordgo.MessageSend{
Files: []*discordgo.File{ Files: []*discordgo.File{{
&discordgo.File{ Name: content.Body,
Name: content.Body, ContentType: content.Info.MimeType,
ContentType: content.Info.MimeType, Reader: bytes.NewReader(data),
Reader: bytes.NewReader(data), }},
},
},
} }
msg, err = sender.Session.ChannelMessageSendComplex(portal.Key.ChannelID, msgSend) msg, err = sender.Session.ChannelMessageSendComplex(portal.Key.ChannelID, msgSend)
@@ -796,8 +774,9 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
dbMsg := portal.bridge.DB.Message.New() dbMsg := portal.bridge.DB.Message.New()
dbMsg.Channel = portal.Key dbMsg.Channel = portal.Key
dbMsg.DiscordID = msg.ID dbMsg.DiscordID = msg.ID
dbMsg.MatrixID = evt.ID dbMsg.MXID = evt.ID
dbMsg.AuthorID = sender.ID dbMsg.SenderID = sender.ID
// TODO use actual timestamp
dbMsg.Timestamp = time.Now() dbMsg.Timestamp = time.Now()
dbMsg.Insert() dbMsg.Insert()
} }
@@ -927,7 +906,7 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) {
var discordID string var discordID string
msg := portal.bridge.DB.Message.GetByMatrixID(portal.Key, reaction.RelatesTo.EventID) msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID)
// Due to the differences in attachments between Discord and Matrix, if a // Due to the differences in attachments between Discord and Matrix, if a
// user reacts to a media message on discord our lookup above will fail // user reacts to a media message on discord our lookup above will fail
@@ -976,13 +955,11 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) {
} }
dbReaction := portal.bridge.DB.Reaction.New() dbReaction := portal.bridge.DB.Reaction.New()
dbReaction.Channel.ChannelID = portal.Key.ChannelID dbReaction.Channel = portal.Key
dbReaction.Channel.Receiver = portal.Key.Receiver dbReaction.MessageID = discordID
dbReaction.MatrixEventID = evt.ID dbReaction.Sender = user.ID
dbReaction.DiscordMessageID = discordID dbReaction.EmojiName = emojiID
dbReaction.AuthorID = user.ID dbReaction.MXID = evt.ID
dbReaction.MatrixName = reaction.RelatesTo.Key
dbReaction.DiscordID = emojiID
dbReaction.Insert() dbReaction.Insert()
} }
@@ -990,7 +967,7 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
intent := portal.bridge.GetPuppetByID(reaction.UserID).IntentFor(portal) intent := portal.bridge.GetPuppetByID(reaction.UserID).IntentFor(portal)
var discordID string var discordID string
var matrixID string var matrixReaction string
if reaction.Emoji.ID != "" { if reaction.Emoji.ID != "" {
dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(reaction.Emoji.ID) dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(reaction.Emoji.ID)
@@ -1018,10 +995,10 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
} }
discordID = dbEmoji.DiscordID discordID = dbEmoji.DiscordID
matrixID = dbEmoji.MatrixURL.String() matrixReaction = dbEmoji.MatrixURL.String()
} else { } else {
discordID = reaction.Emoji.Name discordID = reaction.Emoji.Name
matrixID = reaction.Emoji.Name matrixReaction = reaction.Emoji.Name
} }
// Find the message that we're working with. // Find the message that we're working with.
@@ -1033,8 +1010,7 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
} }
// Lookup an existing reaction // Lookup an existing reaction
existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, discordID) existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, reaction.UserID, discordID)
if !add { if !add {
if existing == nil { if existing == nil {
portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID) portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID)
@@ -1042,21 +1018,24 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
return return
} }
_, err := intent.RedactEvent(portal.MXID, existing.MatrixEventID) _, err := intent.RedactEvent(portal.MXID, existing.MXID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to remove reaction from %s: %v", portal.MXID, err) portal.log.Warnfln("Failed to remove reaction from %s: %v", portal.MXID, err)
} }
existing.Delete() existing.Delete()
return
} else if existing != nil {
portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message.DiscordID)
return return
} }
content := event.Content{Parsed: &event.ReactionEventContent{ content := event.Content{Parsed: &event.ReactionEventContent{
RelatesTo: event.RelatesTo{ RelatesTo: event.RelatesTo{
EventID: message.MatrixID, EventID: message.MXID,
Type: event.RelAnnotation, Type: event.RelAnnotation,
Key: matrixID, Key: matrixReaction,
}, },
}} }}
@@ -1070,13 +1049,10 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
if existing == nil { if existing == nil {
dbReaction := portal.bridge.DB.Reaction.New() dbReaction := portal.bridge.DB.Reaction.New()
dbReaction.Channel = portal.Key dbReaction.Channel = portal.Key
dbReaction.DiscordMessageID = message.DiscordID dbReaction.MessageID = message.DiscordID
dbReaction.MatrixEventID = resp.EventID dbReaction.Sender = reaction.UserID
dbReaction.AuthorID = reaction.UserID dbReaction.EmojiName = discordID
dbReaction.MXID = resp.EventID
dbReaction.MatrixName = matrixID
dbReaction.DiscordID = discordID
dbReaction.Insert() dbReaction.Insert()
} }
} }
@@ -1087,7 +1063,7 @@ func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) {
} }
// First look if we're redacting a message // First look if we're redacting a message
message := portal.bridge.DB.Message.GetByMatrixID(portal.Key, evt.Redacts) message := portal.bridge.DB.Message.GetByMXID(portal.Key, evt.Redacts)
if message != nil { if message != nil {
if message.DiscordID != "" { if message.DiscordID != "" {
err := user.Session.ChannelMessageDelete(portal.Key.ChannelID, message.DiscordID) err := user.Session.ChannelMessageDelete(portal.Key.ChannelID, message.DiscordID)
@@ -1102,21 +1078,19 @@ func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) {
} }
// Now check if it's a reaction. // Now check if it's a reaction.
reaction := portal.bridge.DB.Reaction.GetByMatrixID(portal.Key, evt.Redacts) reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts)
if reaction != nil { if reaction != nil && reaction.Channel == portal.Key {
if reaction.DiscordID != "" { err := user.Session.MessageReactionRemove(portal.Key.ChannelID, reaction.MessageID, reaction.EmojiName, reaction.Sender)
err := user.Session.MessageReactionRemove(portal.Key.ChannelID, reaction.DiscordMessageID, reaction.DiscordID, reaction.AuthorID) if err != nil {
if err != nil { portal.log.Debugfln("Failed to delete reaction %s from %s: %v", reaction.EmojiName, reaction.MessageID, err)
portal.log.Debugfln("Failed to delete reaction %s for message %s: %v", reaction.DiscordID, reaction.DiscordMessageID, err) } else {
} else { reaction.Delete()
reaction.Delete()
}
} }
return return
} }
portal.log.Warnfln("Failed to redact %s@%s: no event found", portal.Key, evt.Redacts) portal.log.Warnfln("Failed to redact %s: no event found", evt.Redacts)
} }
func (portal *Portal) update(user *User, channel *discordgo.Channel) { func (portal *Portal) update(user *User, channel *discordgo.Channel) {
@@ -1150,9 +1124,9 @@ func (portal *Portal) update(user *User, channel *discordgo.Channel) {
var url string var url string
if portal.Type == discordgo.ChannelTypeDM { if portal.Type == discordgo.ChannelTypeDM {
dmUser, err := user.Session.User(portal.DMUser) dmUser, err := user.Session.User(portal.OtherUserID)
if err != nil { if err != nil {
portal.log.Warnln("failed to lookup the dmuser", err) portal.log.Warnln("failed to lookup the other user in DM", err)
} else { } else {
url = dmUser.AvatarURL("") url = dmUser.AvatarURL("")
} }

View File

@@ -527,7 +527,7 @@ func (user *User) createChannel(c *discordgo.Channel) {
portal.Type = c.Type portal.Type = c.Type
if portal.Type == discordgo.ChannelTypeDM { if portal.Type == discordgo.ChannelTypeDM {
portal.DMUser = c.Recipients[0].ID portal.OtherUserID = c.Recipients[0].ID
} }
if c.Icon != "" { if c.Icon != "" {