Remove separate attachment table. Fixes #8

This commit is contained in:
Tulir Asokan
2022-06-27 10:53:49 +03:00
parent a55dcc7c9f
commit 21269d91dd
7 changed files with 168 additions and 287 deletions

View File

@@ -1,126 +0,0 @@
package database
import (
"database/sql"
"errors"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type AttachmentQuery struct {
db *Database
log log.Logger
}
const (
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid FROM attachment"
)
func (aq *AttachmentQuery) New() *Attachment {
return &Attachment{
db: aq.db,
log: aq.log,
}
}
func (aq *AttachmentQuery) GetAllByDiscordMessageID(key PortalKey, discordMessageID string) []*Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3"
return aq.getAll(query, key.ChannelID, key.Receiver, discordMessageID)
}
func (aq *AttachmentQuery) getAll(query string, args ...interface{}) []*Attachment {
rows, err := aq.db.Query(query, args...)
if err != nil {
aq.log.Debugfln("getAll failed: %v", err)
return nil
}
if rows == nil {
return nil
}
var attachments []*Attachment
for rows.Next() {
attachments = append(attachments, aq.New().Scan(rows))
}
return attachments
}
func (aq *AttachmentQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_msg_id=$3 AND dcid=$4"
return aq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID)
}
func (aq *AttachmentQuery) GetByMatrixID(key PortalKey, matrixEventID id.EventID) *Attachment {
query := attachmentSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND mxid=$3"
return aq.get(query, key.ChannelID, key.Receiver, matrixEventID)
}
func (aq *AttachmentQuery) get(query string, args ...interface{}) *Attachment {
row := aq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return aq.New().Scan(row)
}
type Attachment struct {
db *Database
log log.Logger
Channel PortalKey
ThreadID string
MessageID string
ID string
MXID id.EventID
}
func (a *Attachment) DiscordProtoChannelID() string {
if a.ThreadID != "" {
return a.ThreadID
} else {
return a.Channel.ChannelID
}
}
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
a.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
return a
}
func (a *Attachment) Insert() {
query := `
INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
VALUES ($1, $2, $3, $4, $5, $6)
`
_, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID)
if err != nil {
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}
func (a *Attachment) Delete() {
query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3`
_, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID)
if err != nil {
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}

View File

@@ -14,15 +14,14 @@ import (
type Database struct { type Database struct {
*dbutil.Database *dbutil.Database
User *UserQuery User *UserQuery
Portal *PortalQuery Portal *PortalQuery
Puppet *PuppetQuery Puppet *PuppetQuery
Message *MessageQuery Message *MessageQuery
Thread *ThreadQuery Thread *ThreadQuery
Reaction *ReactionQuery Reaction *ReactionQuery
Attachment *AttachmentQuery Emoji *EmojiQuery
Emoji *EmojiQuery Guild *GuildQuery
Guild *GuildQuery
} }
func New(baseDB *dbutil.Database) *Database { func New(baseDB *dbutil.Database) *Database {
@@ -52,10 +51,6 @@ func New(baseDB *dbutil.Database) *Database {
db: db, db: db,
log: db.Log.Sub("Reaction"), log: db.Log.Sub("Reaction"),
} }
db.Attachment = &AttachmentQuery{
db: db,
log: db.Log.Sub("Attachment"),
}
db.Emoji = &EmojiQuery{ db.Emoji = &EmojiQuery{
db: db, db: db,
log: db.Log.Sub("Emoji"), log: db.Log.Sub("Emoji"),

View File

@@ -3,6 +3,8 @@ package database
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"strings"
"time" "time"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
@@ -17,7 +19,7 @@ type MessageQuery struct {
} }
const ( const (
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" messageSelect = "SELECT dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
) )
func (mq *MessageQuery) New() *Message { func (mq *MessageQuery) New() *Message {
@@ -27,11 +29,12 @@ func (mq *MessageQuery) New() *Message {
} }
} }
func (mq *MessageQuery) GetAll(key PortalKey) []*Message { func (mq *MessageQuery) scanAll(rows *sql.Rows, err error) []*Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2" if err != nil {
mq.log.Warnfln("Failed to query many messages: %v", err)
rows, err := mq.db.Query(query, key.ChannelID, key.Receiver) panic(err)
if err != nil || rows == nil { return nil
} else if rows == nil {
return nil return nil
} }
@@ -43,13 +46,18 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
return messages return messages
} }
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) []*Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC"
return mq.scanAll(mq.db.Query(query, key.ChannelID, key.Receiver, discordID))
}
func (mq *MessageQuery) GetFirstByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3 AND dc_edit_index=0 ORDER BY dc_attachment_id ASC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
} }
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND dc_edit_index=0 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
} }
@@ -68,11 +76,13 @@ type Message struct {
db *Database db *Database
log log.Logger log log.Logger
DiscordID string DiscordID string
Channel PortalKey AttachmentID string
SenderID string EditIndex int
Timestamp time.Time Channel PortalKey
ThreadID string SenderID string
Timestamp time.Time
ThreadID string
MXID id.EventID MXID id.EventID
} }
@@ -89,7 +99,7 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts int64
var threadID sql.NullString var threadID sql.NullString
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err) m.log.Errorln("Database scan failed:", err)
@@ -107,14 +117,49 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
return m return m
} }
func (m *Message) Insert() { const messageInsertQuery = `
query := ` INSERT INTO message (
INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid
VALUES ($1, $2, $3, $4, $5, $6, $7) )
` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err := m.db.Exec(query, var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1)
m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
type MessagePart struct {
AttachmentID string
MXID id.EventID
}
func (m *Message) MassInsert(msgs []MessagePart) {
valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)"
if m.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
params := make([]interface{}, 7+len(msgs)*2)
placeholders := make([]string, len(msgs))
params[0] = m.DiscordID
params[1] = m.EditIndex
params[2] = m.Channel.ChannelID
params[3] = m.Channel.Receiver
params[4] = m.SenderID
params[5] = m.Timestamp.UnixMilli()
params[6] = m.ThreadID
for i, msg := range msgs {
params[7+i*2] = msg.AttachmentID
params[7+i*2+1] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2)
}
_, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
if err != nil {
m.log.Warnfln("Failed to insert %d parts of %s@%s: %v", len(msgs), m.DiscordID, m.Channel, err)
panic(err)
}
}
func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID) m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
if err != nil { if err != nil {

View File

@@ -78,6 +78,8 @@ type Reaction struct {
ThreadID string ThreadID string
MXID id.EventID MXID id.EventID
FirstAttachmentID string
} }
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
@@ -105,10 +107,10 @@ func (r *Reaction) DiscordProtoChannelID() string {
func (r *Reaction) Insert() { func (r *Reaction) Insert() {
query := ` query := `
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) INSERT INTO reaction (dc_msg_id, dc_first_attachment_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
VALUES($1, $2, $3, $4, $5, $6, $7) VALUES($1, $2, $3, $4, $5, $6, $7, $8)
` `
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID) _, err := r.db.Exec(query, r.MessageID, r.FirstAttachmentID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID)
if err != nil { if err != nil {
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err) r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
panic(err) panic(err)

View File

@@ -1,4 +1,4 @@
-- v0 -> v3: Latest revision -- v0 -> v4: Latest revision
CREATE TABLE guild ( CREATE TABLE guild (
dcid TEXT PRIMARY KEY, dcid TEXT PRIMARY KEY,
@@ -90,6 +90,8 @@ CREATE TABLE user_portal (
CREATE TABLE message ( CREATE TABLE message (
dcid TEXT, dcid TEXT,
dc_attachment_id TEXT,
dc_edit_index INTEGER,
dc_chan_id TEXT, dc_chan_id TEXT,
dc_chan_receiver TEXT, dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL, dc_sender TEXT NOT NULL,
@@ -98,7 +100,7 @@ CREATE TABLE message (
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dcid, dc_chan_id, dc_chan_receiver), PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE
); );
@@ -110,10 +112,13 @@ CREATE TABLE reaction (
dc_emoji_name TEXT, dc_emoji_name TEXT,
dc_thread_id TEXT, dc_thread_id TEXT,
dc_first_attachment_id TEXT NOT NULL,
_dc_first_edit_index INTEGER NOT NULL DEFAULT 0,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name), PRIMARY KEY (dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name),
CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE CONSTRAINT reaction_message_fkey FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver) REFERENCES message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver) ON DELETE CASCADE
); );
CREATE TABLE attachment ( CREATE TABLE attachment (

View File

@@ -0,0 +1,20 @@
-- v4: Fix storing attachments
ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey;
ALTER TABLE attachment DROP CONSTRAINT attachment_message_fkey;
ALTER TABLE message DROP CONSTRAINT message_pkey;
ALTER TABLE message ADD COLUMN dc_attachment_id TEXT NOT NULL DEFAULT '';
ALTER TABLE message ADD COLUMN dc_edit_index INTEGER NOT NULL DEFAULT 0;
ALTER TABLE message ALTER COLUMN dc_attachment_id DROP DEFAULT;
ALTER TABLE message ALTER COLUMN dc_edit_index DROP DEFAULT;
ALTER TABLE message ADD PRIMARY KEY (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver);
INSERT INTO message (dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
SELECT message.dcid, attachment.dcid, 0, attachment.dc_chan_id, attachment.dc_chan_receiver, message.dc_sender, message.timestamp, attachment.dc_thread_id, attachment.mxid
FROM attachment LEFT JOIN message ON attachment.dc_msg_id = message.dcid;
DROP TABLE attachment;
ALTER TABLE reaction ADD COLUMN dc_first_attachment_id TEXT NOT NULL DEFAULT '';
ALTER TABLE reaction ALTER COLUMN dc_first_attachment_id DROP DEFAULT;
ALTER TABLE reaction ADD COLUMN _dc_first_edit_index INTEGER DEFAULT 0;
ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey
FOREIGN KEY (dc_msg_id, dc_first_attachment_id, _dc_first_edit_index, dc_chan_id, dc_chan_receiver)
REFERENCES message(dcid, dc_attachment_id, dc_edit_index, dc_chan_id, dc_chan_receiver);

180
portal.go
View File

@@ -459,16 +459,15 @@ func (portal *Portal) ensureUserInvited(user *User) bool {
return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat())
} }
func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time, threadID string) *database.Message { func (portal *Portal) markMessageHandled(discordID string, editIndex int, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Channel = portal.Key msg.Channel = portal.Key
msg.DiscordID = discordID msg.DiscordID = discordID
msg.MXID = mxid msg.EditIndex = editIndex
msg.SenderID = authorID msg.SenderID = authorID
msg.Timestamp = timestamp msg.Timestamp = timestamp
msg.ThreadID = threadID msg.ThreadID = threadID
msg.Insert() msg.MassInsert(parts)
return msg
} }
func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) { func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr error) {
@@ -483,7 +482,7 @@ func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridg
} }
} }
func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment, ts time.Time, threadRelation *event.RelatesTo, threadID string) { func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment, ts time.Time, threadRelation *event.RelatesTo, threadID string) *database.MessagePart {
// var captionContent *event.MessageEventContent // var captionContent *event.MessageEventContent
// if attachment.Description != "" { // if attachment.Description != "" {
@@ -521,34 +520,27 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI
data, err := portal.downloadDiscordAttachment(attachment.URL) data, err := portal.downloadDiscordAttachment(attachment.URL)
if err != nil { if err != nil {
portal.sendMediaFailedMessage(intent, err) portal.sendMediaFailedMessage(intent, err)
return nil
return
} }
err = portal.uploadMatrixAttachment(intent, data, content) err = portal.uploadMatrixAttachment(intent, data, content)
if err != nil { if err != nil {
portal.sendMediaFailedMessage(intent, err) portal.sendMediaFailedMessage(intent, err)
return nil
return
} }
resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, ts.UnixMilli()) resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, ts.UnixMilli())
if err != nil { if err != nil {
portal.log.Warnfln("failed to send media message to matrix: %v", err) portal.log.Warnfln("failed to send media message to matrix: %v", err)
} }
dbAttachment := portal.bridge.DB.Attachment.New()
dbAttachment.Channel = portal.Key
dbAttachment.MessageID = msgID
dbAttachment.ID = attachment.ID
dbAttachment.MXID = resp.EventID
dbAttachment.ThreadID = threadID
dbAttachment.Insert()
// Update the fallback reply event for the next attachment // Update the fallback reply event for the next attachment
if threadRelation != nil { if threadRelation != nil {
threadRelation.InReplyTo.EventID = resp.EventID threadRelation.InReplyTo.EventID = resp.EventID
} }
go portal.sendDeliveryReceipt(resp.EventID) return &database.MessagePart{
AttachmentID: attachment.ID,
MXID: resp.EventID,
}
} }
func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) {
@@ -604,6 +596,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
threadRelation = (&event.RelatesTo{}).SetThread(thread.RootMXID, lastEventID) threadRelation = (&event.RelatesTo{}).SetThread(thread.RootMXID, lastEventID)
} }
var parts []database.MessagePart
ts, _ := discordgo.SnowflakeTimestamp(msg.ID) ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
if msg.Content != "" { if msg.Content != "" {
content := renderDiscordMarkdown(msg.Content) content := renderDiscordMarkdown(msg.Content)
@@ -612,11 +605,11 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
if msg.MessageReference != nil { if msg.MessageReference != nil {
//key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID}
replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID) replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID)
if replyTo != nil { if len(replyTo) > 0 {
if content.RelatesTo == nil { if content.RelatesTo == nil {
content.RelatesTo = &event.RelatesTo{} content.RelatesTo = &event.RelatesTo{}
} }
content.RelatesTo.SetReplyTo(replyTo.MXID) content.RelatesTo.SetReplyTo(replyTo[0].MXID)
} }
} }
@@ -626,7 +619,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
return return
} }
portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts, threadID) parts = append(parts, database.MessagePart{MXID: resp.EventID})
// Update the fallback reply event for attachments // Update the fallback reply event for attachments
if threadRelation != nil { if threadRelation != nil {
threadRelation.InReplyTo.EventID = resp.EventID threadRelation.InReplyTo.EventID = resp.EventID
@@ -635,8 +628,12 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
} }
for _, attachment := range msg.Attachments { for _, attachment := range msg.Attachments {
portal.handleDiscordAttachment(intent, msg.ID, attachment, ts, threadRelation, threadID) part := portal.handleDiscordAttachment(intent, msg.ID, attachment, ts, threadRelation, threadID)
if part != nil {
parts = append(parts, *part)
}
} }
portal.markMessageHandled(msg.ID, 0, msg.Author.ID, ts, threadID, parts)
} }
func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) {
@@ -653,7 +650,7 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
} }
if msg.Flags == discordgo.MessageFlagsHasThread { if msg.Flags == discordgo.MessageFlagsHasThread {
portal.bridge.GetThreadByID(msg.ID, existing) portal.bridge.GetThreadByID(msg.ID, existing[0])
portal.log.Debugfln("Marked %s as a thread root", msg.ID) portal.log.Debugfln("Marked %s as a thread root", msg.ID)
// TODO make autojoining configurable // TODO make autojoining configurable
//err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu) //err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu)
@@ -682,46 +679,12 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
intent := portal.bridge.GetPuppetByID(msg.Author.ID).IntentFor(portal) intent := portal.bridge.GetPuppetByID(msg.Author.ID).IntentFor(portal)
if existing == nil { if msg.Content == "" || existing[0].AttachmentID != "" {
// Due to the differences in Discord and Matrix attachment handling, portal.log.Debugfln("Dropping non-text edit to %s", msg.ID)
// existing will return nil if the original message was empty as we
// don't store/save those messages so we can determine when we're
// working against an attachment and do the attachment lookup instead.
// Find all the existing attachments and drop them in a map so we can
// figure out which, if any have been deleted and clean them up on the
// matrix side.
attachmentMap := map[string]*database.Attachment{}
attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID)
for _, attachment := range attachments {
attachmentMap[attachment.ID] = attachment
}
// Now run through the list of attachments on this message and remove
// them from the map.
for _, attachment := range msg.Attachments {
if _, found := attachmentMap[attachment.ID]; found {
delete(attachmentMap, attachment.ID)
}
}
// Finally run through any attachments still in the map and delete them
// on the matrix side and our database.
for _, attachment := range attachmentMap {
_, err := intent.RedactEvent(portal.MXID, attachment.MXID)
if err != nil {
portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err)
}
attachment.Delete()
}
return return
} }
content := renderDiscordMarkdown(msg.Content) content := renderDiscordMarkdown(msg.Content)
content.SetEdit(existing.MXID) content.SetEdit(existing[0].MXID)
var editTS int64 var editTS int64
if msg.EditedTimestamp != nil { if msg.EditedTimestamp != nil {
@@ -742,33 +705,20 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
} }
func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) {
// The discord delete message object is pretty empty and doesn't include
// the author so we have to use the DMUser from the portal that was added
// at creation time if we're a DM. We'll might have similar issues when we
// add guild message support, but we'll cross that bridge when we get
// there.
// Find the message that we're working with. This could correctly return
// nil if the message was just one or more attachments.
existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID)
intent := portal.MainIntent() intent := portal.MainIntent()
var lastResp id.EventID
if existing != nil { for _, dbMsg := range existing {
attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) resp, err := intent.RedactEvent(portal.MXID, dbMsg.MXID)
for _, attachment := range attachments {
_, err := intent.RedactEvent(portal.MXID, attachment.MXID)
if err != nil {
portal.log.Warnfln("Failed to redact attachment %s: %v", attachment.MXID, err)
}
attachment.Delete()
}
resp, err := intent.RedactEvent(portal.MXID, existing.MXID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to redact message %s: %v", existing.MXID, err) portal.log.Warnfln("Failed to redact message %s: %v", dbMsg.MXID, err)
} else if resp != nil && resp.EventID != "" {
lastResp = resp.EventID
} }
existing.Delete() dbMsg.Delete()
portal.sendDeliveryReceipt(resp.EventID) }
if lastResp != "" {
portal.sendDeliveryReceipt(lastResp)
} }
} }
@@ -1017,6 +967,9 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
dbMsg := portal.bridge.DB.Message.New() dbMsg := portal.bridge.DB.Message.New()
dbMsg.Channel = portal.Key dbMsg.Channel = portal.Key
dbMsg.DiscordID = msg.ID dbMsg.DiscordID = msg.ID
if len(msg.Attachments) > 0 {
dbMsg.AttachmentID = msg.Attachments[0].ID
}
dbMsg.MXID = evt.ID dbMsg.MXID = evt.ID
dbMsg.SenderID = sender.DiscordID dbMsg.SenderID = sender.DiscordID
dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID) dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID)
@@ -1161,34 +1114,15 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) {
return return
} }
var discordID, threadID string
channelID := portal.Key.ChannelID
msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID)
// Due to the differences in attachments between Discord and Matrix, if a
// user reacts to a media message on discord our lookup above will fail
// because the relation of matrix media messages to attachments in handled
// in the attachments table instead of messages so we need to check that
// before continuing.
//
// This also leads to interesting problems when a Discord message comes in
// with multiple attachments. A user can react to each one individually on
// Matrix, which will cause us to send it twice. Discord tends to ignore
// this, but if the user removes one of them, discord removes it and now
// they're out of sync. Perhaps we should add a counter to the reactions
// table to keep them in sync and to avoid sending duplicates to Discord.
if msg == nil { if msg == nil {
attachment := portal.bridge.DB.Attachment.GetByMatrixID(portal.Key, reaction.RelatesTo.EventID) portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, errors.New("unknown reaction target"), true, 0)
if attachment == nil { }
portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, errors.New("unknown reaction target"), true, 0)
return firstMsg := msg
} if msg.AttachmentID != "" {
discordID = attachment.MessageID firstMsg = portal.bridge.DB.Message.GetFirstByDiscordID(portal.Key, msg.DiscordID)
threadID = attachment.ThreadID // TODO should the emoji be rerouted to the first message if it's different?
} else {
discordID = msg.DiscordID
threadID = msg.ThreadID
} }
// Figure out if this is a custom emoji or not. // Figure out if this is a custom emoji or not.
@@ -1207,22 +1141,28 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) {
emojiID = variationselector.Remove(emojiID) emojiID = variationselector.Remove(emojiID)
} }
if threadID != "" { existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, msg.DiscordID, sender.DiscordID, emojiID)
channelID = threadID if existing != nil {
portal.log.Debugfln("Dropping duplicate Matrix reaction %s (already sent as %s)", evt.ID, existing.MXID)
portal.bridge.SendMessageSuccessCheckpoint(evt, bridge.MsgStepRemote, 0)
portal.sendDeliveryReceipt(evt.ID)
return
} }
err := sender.Session.MessageReactionAdd(channelID, discordID, emojiID)
err := sender.Session.MessageReactionAdd(msg.DiscordProtoChannelID(), msg.DiscordID, emojiID)
if err != nil { if err != nil {
portal.log.Debugf("Failed to send reaction to %s: %v", discordID, err) portal.log.Debugf("Failed to send reaction to %s: %v", msg.DiscordID, err)
portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, err, true, 0) portal.bridge.SendMessageErrorCheckpoint(evt, bridge.MsgStepRemote, err, true, 0)
return return
} }
dbReaction := portal.bridge.DB.Reaction.New() dbReaction := portal.bridge.DB.Reaction.New()
dbReaction.Channel = portal.Key dbReaction.Channel = portal.Key
dbReaction.MessageID = discordID dbReaction.MessageID = msg.DiscordID
dbReaction.FirstAttachmentID = firstMsg.AttachmentID
dbReaction.Sender = sender.DiscordID dbReaction.Sender = sender.DiscordID
dbReaction.EmojiName = emojiID dbReaction.EmojiName = emojiID
dbReaction.ThreadID = threadID dbReaction.ThreadID = msg.ThreadID
dbReaction.MXID = evt.ID dbReaction.MXID = evt.ID
dbReaction.Insert() dbReaction.Insert()
portal.log.Debugfln("Handled Matrix reaction %s", evt.ID) portal.log.Debugfln("Handled Matrix reaction %s", evt.ID)
@@ -1272,12 +1212,11 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
message := portal.bridge.DB.Message.GetByDiscordID(portal.Key, reaction.MessageID) message := portal.bridge.DB.Message.GetByDiscordID(portal.Key, reaction.MessageID)
if message == nil { if message == nil {
portal.log.Debugfln("failed to add reaction to message %s: message not found", reaction.MessageID) portal.log.Debugfln("failed to add reaction to message %s: message not found", reaction.MessageID)
return return
} }
// Lookup an existing reaction // Lookup an existing reaction
existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message.DiscordID, reaction.UserID, discordID) existing := portal.bridge.DB.Reaction.GetByDiscordID(portal.Key, message[0].DiscordID, reaction.UserID, discordID)
if !add { if !add {
if existing == nil { if existing == nil {
portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID) portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID)
@@ -1293,13 +1232,13 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
go portal.sendDeliveryReceipt(resp.EventID) go portal.sendDeliveryReceipt(resp.EventID)
return return
} else if existing != nil { } else if existing != nil {
portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message.DiscordID) portal.log.Debugfln("Ignoring duplicate reaction %s from %s to %s", discordID, reaction.UserID, message[0].DiscordID)
return return
} }
content := event.Content{Parsed: &event.ReactionEventContent{ content := event.Content{Parsed: &event.ReactionEventContent{
RelatesTo: event.RelatesTo{ RelatesTo: event.RelatesTo{
EventID: message.MXID, EventID: message[0].MXID,
Type: event.RelAnnotation, Type: event.RelAnnotation,
Key: matrixReaction, Key: matrixReaction,
}, },
@@ -1320,7 +1259,8 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
if existing == nil { if existing == nil {
dbReaction := portal.bridge.DB.Reaction.New() dbReaction := portal.bridge.DB.Reaction.New()
dbReaction.Channel = portal.Key dbReaction.Channel = portal.Key
dbReaction.MessageID = message.DiscordID dbReaction.MessageID = message[0].DiscordID
dbReaction.FirstAttachmentID = message[0].AttachmentID
dbReaction.Sender = reaction.UserID dbReaction.Sender = reaction.UserID
dbReaction.EmojiName = discordID dbReaction.EmojiName = discordID
dbReaction.MXID = resp.EventID dbReaction.MXID = resp.EventID