diff --git a/bridge/portal.go b/bridge/portal.go index fdf51d9..ec5062c 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -291,6 +291,8 @@ func (p *Portal) handleDiscordMessages(msg portalDiscordMessage) { p.handleDiscordMessage(msg.user, msg.msg.(*discordgo.MessageCreate).Message) case *discordgo.MessageReactionAdd: p.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionAdd).MessageReaction, true) + case *discordgo.MessageReactionRemove: + p.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionRemove).MessageReaction, false) default: p.log.Warnln("unknown message type") } @@ -549,12 +551,18 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe return } - if reaction.Emoji.ID != "" { + // Emoji.ID is only set if it's a custom emote, otherwise Emoji.Name is + // used. + customEmote := (reaction.Emoji.ID != "") + + // This is temporary until we add support for custom emoji. + if customEmote { p.log.Debugln("ignoring non-unicode reaction") return } + // Find the message that we're working with. message := p.bridge.db.Message.GetByDiscordID(p.Key, reaction.MessageID) if message == nil { p.log.Debugfln("failed to add reaction to message %s: message not found", reaction.MessageID) @@ -562,8 +570,34 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe return } + // Lookup an existing reaction + var existing *database.Reaction + + if customEmote { + existing = p.bridge.db.Reaction.GetByDiscordID(p.Key, message.DiscordID, reaction.Emoji.ID) + } else { + existing = p.bridge.db.Reaction.GetByDiscordName(p.Key, message.DiscordID, reaction.Emoji.Name) + } + + if !add && existing == nil { + p.log.Debugln("Failed to remove emote for unknown message", reaction.MessageID) + + return + } + intent := p.bridge.GetPuppetByID(reaction.UserID).IntentFor(p) + if !add { + _, err := intent.RedactEvent(p.MXID, existing.MatrixEventID) + if err != nil { + p.log.Warnfln("Failed to remove reaction from %s: %v", p.MXID, err) + } + + existing.Delete() + + return + } + content := event.Content{Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ EventID: message.MatrixID, @@ -572,10 +606,29 @@ func (p *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageRe }, }} - _, err := intent.Client.SendMessageEvent(p.MXID, event.EventReaction, &content) - if err != nil { - p.log.Errorfln("failed to send reaction from %s: %v", reaction.MessageID, err) + if add { + resp, err := intent.Client.SendMessageEvent(p.MXID, event.EventReaction, &content) + if err != nil { + p.log.Errorfln("failed to send reaction from %s: %v", reaction.MessageID, err) - return + return + } + + if existing == nil { + dbReaction := p.bridge.db.Reaction.New() + dbReaction.Channel = p.Key + dbReaction.DiscordMessageID = message.DiscordID + dbReaction.MatrixEventID = resp.EventID + dbReaction.AuthorID = reaction.UserID + + if customEmote { + // TODO: + } else { + dbReaction.MatrixName = reaction.Emoji.Name + dbReaction.DiscordName = reaction.Emoji.Name + } + + dbReaction.Insert() + } } } diff --git a/bridge/user.go b/bridge/user.go index 8ca1c69..1a8501c 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -213,7 +213,8 @@ func (u *User) Connect() error { u.User.Session.AddHandler(u.channelUpdateHandler) u.User.Session.AddHandler(u.messageHandler) - u.User.Session.AddHandler(u.reactionHandler) + u.User.Session.AddHandler(u.reactionAddHandler) + u.User.Session.AddHandler(u.reactionRemoveHandler) // u.User.Session.Identify.Capabilities = 125 // // Setup our properties @@ -296,7 +297,25 @@ func (u *User) messageHandler(s *discordgo.Session, m *discordgo.MessageCreate) portal.discordMessages <- msg } -func (u *User) reactionHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) { +func (u *User) reactionAddHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) { + if m.GuildID != "" { + u.log.Debugln("ignoring reaction for guild message") + + return + } + + key := database.NewPortalKey(m.ChannelID, u.User.ID) + portal := u.bridge.GetPortalByID(key) + + msg := portalDiscordMessage{ + msg: m, + user: u, + } + + portal.discordMessages <- msg +} + +func (u *User) reactionRemoveHandler(s *discordgo.Session, m *discordgo.MessageReactionRemove) { if m.GuildID != "" { u.log.Debugln("ignoring reaction for guild message") diff --git a/database/database.go b/database/database.go index 2e5c036..37d6aac 100644 --- a/database/database.go +++ b/database/database.go @@ -16,10 +16,11 @@ type Database struct { log log.Logger dialect string - User *UserQuery - Portal *PortalQuery - Puppet *PuppetQuery - Message *MessageQuery + User *UserQuery + Portal *PortalQuery + Puppet *PuppetQuery + Message *MessageQuery + Reaction *ReactionQuery } func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) (*Database, error) { @@ -67,5 +68,10 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) log: db.log.Sub("Message"), } + db.Reaction = &ReactionQuery{ + db: db, + log: db.log.Sub("Reaction"), + } + return db, nil } diff --git a/database/migrations/01-initial.sql b/database/migrations/01-initial.sql index 1d4bf12..bbebc7b 100644 --- a/database/migrations/01-initial.sql +++ b/database/migrations/01-initial.sql @@ -47,6 +47,26 @@ CREATE TABLE message ( FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE ); +CREATE TABLE reaction ( + channel_id TEXT NOT NULL, + receiver TEXT NOT NULL, + + discord_message_id TEXT NOT NULL, + matrix_event_id TEXT NOT NULL UNIQUE, + + author_id TEXT NOT NULL, + + matrix_name TEXT, + matrix_url TEXT, + + discord_name TEXT, + discord_id TEXT, + + CHECK ((discord_name IS NULL AND discord_id IS NOT NULL) OR (discord_name IS NOT NULL AND discord_id IS NULL)), + UNIQUE (discord_name, discord_id, author_id, discord_message_id, channel_id, receiver), + FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE +); + CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, diff --git a/database/reaction.go b/database/reaction.go new file mode 100644 index 0000000..8e634f3 --- /dev/null +++ b/database/reaction.go @@ -0,0 +1,115 @@ +package database + +import ( + "database/sql" + "errors" + + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type Reaction struct { + db *Database + log log.Logger + + Channel PortalKey + + DiscordMessageID string + MatrixEventID id.EventID + + // The discord ID of who create this reaction + AuthorID string + + MatrixName string + MatrixURL string // Used for custom emoji + + DiscordName string // Used for unicode emoji + DiscordID string // Used for custom emoji +} + +func (r *Reaction) Scan(row Scannable) *Reaction { + var discordName, discordID sql.NullString + + err := row.Scan( + &r.Channel.ChannelID, &r.Channel.Receiver, + &r.DiscordMessageID, &r.MatrixEventID, + &r.AuthorID, + &r.MatrixName, &r.MatrixURL, + &discordName, &discordID) + + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + r.log.Errorln("Database scan failed:", err) + } + + return nil + } + + r.DiscordName = discordName.String + r.DiscordID = discordID.String + + return r +} + +func (r *Reaction) Insert() { + query := "INSERT INTO reaction" + + " (channel_id, receiver, discord_message_id, matrix_event_id," + + " author_id, matrix_name, matrix_url, discord_name, discord_id)" + + " VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9);" + + var discordName, discordID sql.NullString + + if r.DiscordName != "" { + discordName = sql.NullString{r.DiscordName, true} + } + + if r.DiscordID != "" { + discordID = sql.NullString{r.DiscordID, true} + } + + _, err := r.db.Exec( + query, + r.Channel.ChannelID, r.Channel.Receiver, + r.DiscordMessageID, r.MatrixEventID, + r.AuthorID, + r.MatrixName, r.MatrixURL, + discordName, discordID, + ) + + if err != nil { + r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) + } +} + +func (r *Reaction) Update() { + // TODO: determine if we need this. The only scenario I can think of that + // would require this is if we insert a custom emoji before uploading to + // the homeserver? +} + +func (r *Reaction) Delete() { + query := "DELETE FROM reaction WHERE" + + " channel_id=$1 AND receiver=$2 AND discord_message_id=$3 AND" + + " author_id=$4 AND discord_name=$5 AND discord_id=$6" + + var discordName, discordID sql.NullString + + if r.DiscordName != "" { + discordName = sql.NullString{r.DiscordName, true} + } + + if r.DiscordID != "" { + discordID = sql.NullString{r.DiscordID, true} + } + + _, err := r.db.Exec( + query, + r.Channel.ChannelID, r.Channel.Receiver, + r.DiscordMessageID, r.AuthorID, + discordName, discordID, + ) + + if err != nil { + r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.Channel, r.DiscordMessageID, err) + } +} diff --git a/database/reactionquery.go b/database/reactionquery.go new file mode 100644 index 0000000..6a5eaa9 --- /dev/null +++ b/database/reactionquery.go @@ -0,0 +1,82 @@ +package database + +import ( + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type ReactionQuery struct { + db *Database + log log.Logger +} + +const ( + reactionSelect = "SELECT channel_id, receiver, discord_message_id," + + " matrix_event_id, author_id, matrix_name, matrix_url, " + + " discord_name, discord_id FROM reaction" +) + +func (rq *ReactionQuery) New() *Reaction { + return &Reaction{ + db: rq.db, + log: rq.log, + } +} + +func (rq *ReactionQuery) GetAllByDiscordID(key PortalKey, discordMessageID string) []*Reaction { + query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + + " discord_message_id=$3" + + return rq.getAll(query, key.ChannelID, key.Receiver, discordMessageID) +} + +func (rq *ReactionQuery) GetAllByMatrixID(key PortalKey, matrixEventID id.EventID) []*Reaction { + query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + + " matrix_event_id=$3" + + return rq.getAll(query, key.ChannelID, key.Receiver, matrixEventID) +} + +func (rq *ReactionQuery) getAll(query string, args ...interface{}) []*Reaction { + rows, err := rq.db.Query(query) + if err != nil || rows == nil { + return nil + } + + reactions := []*Reaction{} + for rows.Next() { + reactions = append(reactions, rq.New().Scan(rows)) + } + + return reactions +} + +func (rq *ReactionQuery) GetByDiscordName(key PortalKey, discordMessageID, discordName string) *Reaction { + query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" + + " AND discord_message_id=$3 AND discord_name=$4" + + return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordName) +} + +func (rq *ReactionQuery) GetByDiscordID(key PortalKey, discordMessageID, discordID string) *Reaction { + query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" + + " AND discord_message_id=$3 AND discord_id=$4" + + return rq.get(query, key.ChannelID, key.Receiver, discordMessageID, discordID) +} + +func (rq *ReactionQuery) GetByMatrixName(key PortalKey, matrixEventID id.EventID, matrixName string) *Reaction { + query := reactionSelect + " WHERE channel_id=$1 AND receiver=$2" + + " AND matrix_event_id=$3 AND matrix_name=$4" + + return rq.get(query, key.ChannelID, key.Receiver, matrixEventID, matrixName) +} + +func (rq *ReactionQuery) get(query string, args ...interface{}) *Reaction { + row := rq.db.QueryRow(query, args...) + if row == nil { + return nil + } + + return rq.New().Scan(row) +}