From 3e9f927ed44ddbab46a2c2e243b8cc47a6b52654 Mon Sep 17 00:00:00 2001 From: Gary Kramlich Date: Tue, 8 Feb 2022 05:18:25 -0600 Subject: [PATCH] Add support for deleting messages from discord --- bridge/portal.go | 70 ++++++++++++++++++------------ bridge/user.go | 30 +++++++++++-- database/migrations/01-initial.sql | 3 ++ database/portal.go | 27 ++++++++---- 4 files changed, 90 insertions(+), 40 deletions(-) diff --git a/bridge/portal.go b/bridge/portal.go index da2876d..38d94ba 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -32,8 +32,6 @@ type Portal struct { bridge *Bridge log log.Logger - channel *discordgo.Channel - roomCreateLock sync.Mutex discordMessages chan portalDiscordMessage @@ -159,39 +157,22 @@ func (p *Portal) messageLoop() { } func (p *Portal) IsPrivateChat() bool { - if p.channel != nil { - return p.channel.Type == discordgo.ChannelTypeDM - } - - return false + return p.Type == discordgo.ChannelTypeDM } func (p *Portal) MainIntent() *appservice.IntentAPI { - if p.IsPrivateChat() && p.channel != nil && len(p.channel.Recipients) == 1 { - return p.bridge.GetPuppetByID(p.channel.Recipients[0].ID).DefaultIntent() + if p.IsPrivateChat() && p.DMUser != "" { + return p.bridge.GetPuppetByID(p.DMUser).DefaultIntent() } return p.bridge.bot } -func (p *Portal) getMessagePuppet(user *User, message *discordgo.Message) *Puppet { - p.log.Debugf("getMessagePuppet") - if message.Author.ID == user.ID { - return p.bridge.GetPuppetByID(user.ID) - } - - puppet := p.bridge.GetPuppetByID(message.Author.ID) - puppet.SyncContact(user) - - return puppet -} - -func (p *Portal) getMessageIntent(user *User, message *discordgo.Message) *appservice.IntentAPI { - return p.getMessagePuppet(user, message).IntentFor(p) -} - func (p *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error { - p.channel = channel + p.Type = channel.Type + if p.Type == discordgo.ChannelTypeDM { + p.DMUser = channel.Recipients[0].ID + } p.roomCreateLock.Lock() defer p.roomCreateLock.Unlock() @@ -288,7 +269,9 @@ func (p *Portal) handleDiscordMessages(msg portalDiscordMessage) { switch msg.msg.(type) { case *discordgo.MessageCreate: - p.handleDiscordMessage(msg.user, msg.msg.(*discordgo.MessageCreate).Message) + p.handleDiscordMessageCreate(msg.user, msg.msg.(*discordgo.MessageCreate).Message) + case *discordgo.MessageDelete: + p.handleDiscordMessageDelete(msg.user, msg.msg.(*discordgo.MessageDelete).Message) case *discordgo.MessageReactionAdd: p.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionAdd).MessageReaction, true) case *discordgo.MessageReactionRemove: @@ -318,7 +301,7 @@ func (p *Portal) markMessageHandled(msg *database.Message, discordID string, mxi return msg } -func (p *Portal) handleDiscordMessage(user *User, msg *discordgo.Message) { +func (p *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message) { if user.ID == msg.Author.ID { return } @@ -353,6 +336,37 @@ func (p *Portal) handleDiscordMessage(user *User, msg *discordgo.Message) { p.markMessageHandled(nil, msg.ID, resp.EventID, msg.Author.ID, ts) } +func (p *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Message) { + // The discord delete message object is pretty empty and doesn't include + // the author so we have to use the DMUser from the portal that was added + // at creation time if we're a DM. We'll might have similar issues when we + // add guild message support, but we'll cross that bridge when we get + // there. + + // Find the message that we're working with. + existing := p.bridge.db.Message.GetByDiscordID(p.Key, msg.ID) + if existing == nil { + p.log.Debugfln("failed to find message", msg.ID) + + return + } + + var intent *appservice.IntentAPI + + if p.Type == discordgo.ChannelTypeDM { + intent = p.bridge.GetPuppetByID(p.DMUser).IntentFor(p) + } else { + p.log.Errorfln("no guilds yet...") + } + + _, err := intent.RedactEvent(p.MXID, existing.MatrixID) + if err != nil { + p.log.Warnfln("Failed to remove message %s: %v", existing.MatrixID, err) + } + + existing.Delete() +} + func (p *Portal) syncParticipants(source *User, participants []*discordgo.User) { for _, participant := range participants { puppet := p.bridge.GetPuppetByID(participant.ID) diff --git a/bridge/user.go b/bridge/user.go index 1a8501c..a38b54e 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -212,7 +212,8 @@ func (u *User) Connect() error { u.User.Session.AddHandler(u.channelPinsUpdateHandler) u.User.Session.AddHandler(u.channelUpdateHandler) - u.User.Session.AddHandler(u.messageHandler) + u.User.Session.AddHandler(u.messageCreateHandler) + u.User.Session.AddHandler(u.messageDeleteHandler) u.User.Session.AddHandler(u.reactionAddHandler) u.User.Session.AddHandler(u.reactionRemoveHandler) @@ -249,6 +250,11 @@ func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCr portal.Name = c.Name portal.Topic = c.Topic + portal.Type = c.Type + + if portal.Type == discordgo.ChannelTypeDM { + portal.DMUser = c.Recipients[0].ID + } if c.Icon != "" { u.log.Debugln("channel icon", c.Icon) @@ -279,14 +285,32 @@ func (u *User) channelUpdateHandler(s *discordgo.Session, c *discordgo.ChannelUp u.log.Debugln("channel update") } -func (u *User) messageHandler(s *discordgo.Session, m *discordgo.MessageCreate) { +func (u *User) messageCreateHandler(s *discordgo.Session, m *discordgo.MessageCreate) { if m.GuildID != "" { u.log.Debugln("ignoring message for guild") return } - key := database.NewPortalKey(m.ChannelID, u.User.ID) + key := database.NewPortalKey(m.ChannelID, u.ID) + portal := u.bridge.GetPortalByID(key) + + msg := portalDiscordMessage{ + msg: m, + user: u, + } + + portal.discordMessages <- msg +} + +func (u *User) messageDeleteHandler(s *discordgo.Session, m *discordgo.MessageDelete) { + if m.GuildID != "" { + u.log.Debugln("ignoring message delete for guild message") + + return + } + + key := database.NewPortalKey(m.ChannelID, u.ID) portal := u.bridge.GetPortalByID(key) msg := portalDiscordMessage{ diff --git a/database/migrations/01-initial.sql b/database/migrations/01-initial.sql index ccb1329..fb521a4 100644 --- a/database/migrations/01-initial.sql +++ b/database/migrations/01-initial.sql @@ -9,6 +9,9 @@ CREATE TABLE portal ( avatar TEXT NOT NULL, avatar_url TEXT, + type INT, + dmuser TEXT, + first_event_id TEXT, PRIMARY KEY (channel_id, receiver) diff --git a/database/portal.go b/database/portal.go index 88f0ed9..b21d298 100644 --- a/database/portal.go +++ b/database/portal.go @@ -3,6 +3,8 @@ package database import ( "database/sql" + "github.com/bwmarrin/discordgo" + log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) @@ -20,7 +22,7 @@ type Portal struct { Avatar string AvatarURL id.ContentURI - Type int + Type discordgo.ChannelType DMUser string FirstEventID id.EventID @@ -28,8 +30,11 @@ type Portal struct { func (p *Portal) Scan(row Scannable) *Portal { var mxid, avatarURL, firstEventID sql.NullString + var typ sql.NullInt32 + + err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, + &p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID) - err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, &p.Topic, &p.Avatar, &avatarURL, &firstEventID) if err != nil { if err != sql.ErrNoRows { p.log.Errorln("Database scan failed:", err) @@ -40,6 +45,7 @@ func (p *Portal) Scan(row Scannable) *Portal { p.MXID = id.RoomID(mxid.String) p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) + p.Type = discordgo.ChannelType(typ.Int32) p.FirstEventID = id.EventID(firstEventID.String) return p @@ -47,11 +53,13 @@ func (p *Portal) Scan(row Scannable) *Portal { func (p *Portal) Insert() { query := "INSERT INTO portal" + - " (channel_id, receiver, mxid, name, topic, avatar, avatar_url, first_event_id)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + " (channel_id, receiver, mxid, name, topic, avatar, avatar_url," + + " type, dmuser, first_event_id)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.MXID, - p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.FirstEventID.String()) + p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser, + p.FirstEventID.String()) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.Key, err) @@ -60,12 +68,13 @@ func (p *Portal) Insert() { func (p *Portal) Update() { query := "UPDATE portal SET" + - " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, first_event_id=$6" + - " WHERE channel_id=$7 AND receiver=$8" + " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," + + " dmuser=$7, first_event_id=$8" + + " WHERE channel_id=$9 AND receiver=$10" _, err := p.db.Exec(query, p.MXID, p.Name, p.Topic, p.Avatar, - p.AvatarURL.String(), p.FirstEventID.String(), p.Key.ChannelID, - p.Key.Receiver) + p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(), + p.Key.ChannelID, p.Key.Receiver) if err != nil { p.log.Warnfln("Failed to update %s: %v", p.Key, err)