diff --git a/bridge/matrix.go b/bridge/matrix.go index 7e4686a..b99f5ad 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -226,7 +226,7 @@ func (mh *matrixHandler) handleMembership(evt *event.Event) { } else if puppet != nil { portal.handleMatrixKick(user, puppet) } - } else if content.Membership == event.MembershipInvite && !isSelf { + } else if content.Membership == event.MembershipInvite { portal.handleMatrixInvite(user, evt) } } diff --git a/bridge/portal.go b/bridge/portal.go index 620fc8e..9328cd4 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -3,6 +3,7 @@ package bridge import ( "fmt" "sync" + "time" "github.com/bwmarrin/discordgo" @@ -297,21 +298,51 @@ func (p *Portal) ensureUserInvited(user *User) bool { return user.ensureInvited(p.MainIntent(), p.MXID, p.IsPrivateChat()) } +func (p *Portal) markMessageHandled(msg *database.Message, discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message { + if msg == nil { + msg := p.bridge.db.Message.New() + msg.Channel = p.Key + msg.DiscordID = discordID + msg.MatrixID = mxid + msg.AuthorID = authorID + msg.Timestamp = timestamp + msg.Insert() + } else { + msg.UpdateMatrixID(mxid) + } + + return msg +} + func (p *Portal) handleDiscordMessage(msg *discordgo.Message) { if p.MXID == "" { p.log.Warnln("handle message called without a valid portal") + + return + } + + existing := p.bridge.db.Message.GetByDiscordID(p.Key, msg.ID) + if existing != nil { + p.log.Debugln("not handling duplicate message", msg.ID) + return } - // TODO: Check if we already got the message content := &event.MessageEventContent{ Body: msg.Content, MsgType: event.MsgText, } - resp, err := p.MainIntent().SendMessageEvent(p.MXID, event.EventMessage, content) - p.log.Warnln("response:", resp) - p.log.Warnln("error:", err) + intent := p.bridge.GetPuppetByID(msg.Author.ID).IntentFor(p) + + resp, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content) + if err != nil { + p.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) + return + } + + ts, _ := msg.Timestamp.Parse() + p.markMessageHandled(nil, msg.ID, resp.EventID, msg.Author.ID, ts) } func (p *Portal) syncParticipants(source *User, participants []*discordgo.User) { @@ -344,6 +375,13 @@ func (p *Portal) handleMatrixMessage(sender *User, evt *event.Event) { return } + existing := p.bridge.db.Message.GetByMatrixID(p.Key, evt.ID) + if existing != nil { + p.log.Debugln("not handling duplicate message", evt.ID) + + return + } + content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { p.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed) @@ -351,8 +389,20 @@ func (p *Portal) handleMatrixMessage(sender *User, evt *event.Event) { return } - sender.Session.ChannelMessageSend(p.Key.ChannelID, content.Body) - p.log.Debugln("sent message:", content.Body) + msg, err := sender.Session.ChannelMessageSend(p.Key.ChannelID, content.Body) + if err != nil { + p.log.Errorfln("Failed to send message: %v", err) + + return + } + + dbMsg := p.bridge.db.Message.New() + dbMsg.Channel = p.Key + dbMsg.DiscordID = msg.ID + dbMsg.MatrixID = evt.ID + dbMsg.AuthorID = sender.ID + dbMsg.Timestamp = time.Now() + dbMsg.Insert() } func (p *Portal) handleMatrixLeave(sender *User) { diff --git a/database/database.go b/database/database.go index 2edcc30..2e5c036 100644 --- a/database/database.go +++ b/database/database.go @@ -16,9 +16,10 @@ type Database struct { log log.Logger dialect string - User *UserQuery - Portal *PortalQuery - Puppet *PuppetQuery + User *UserQuery + Portal *PortalQuery + Puppet *PuppetQuery + Message *MessageQuery } func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) (*Database, error) { @@ -61,5 +62,10 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) log: db.log.Sub("Puppet"), } + db.Message = &MessageQuery{ + db: db, + log: db.log.Sub("Message"), + } + return db, nil } diff --git a/database/message.go b/database/message.go new file mode 100644 index 0000000..703120d --- /dev/null +++ b/database/message.go @@ -0,0 +1,66 @@ +package database + +import ( + "database/sql" + "errors" + "time" + + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type Message struct { + db *Database + log log.Logger + + Channel PortalKey + + DiscordID string + MatrixID id.EventID + + AuthorID string + Timestamp time.Time +} + +func (m *Message) Scan(row Scannable) *Message { + var ts int64 + + err := row.Scan(&m.Channel.ChannelID, &m.Channel.Receiver, &m.DiscordID, &m.MatrixID, &m.AuthorID, &ts) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + m.log.Errorln("Database scan failed:", err) + } + + return nil + } + + if ts != 0 { + m.Timestamp = time.Unix(ts, 0) + } + + return m +} + +func (m *Message) Insert() { + query := "INSERT INTO message" + + " (channel_id, receiver, discord_message_id, matrix_message_id," + + " author_id, timestamp) VALUES ($1, $2, $3, $4, $5, $6)" + + _, err := m.db.Exec(query, m.Channel.ChannelID, m.Channel.Receiver, + m.DiscordID, m.MatrixID, m.AuthorID, m.Timestamp.Unix()) + + if err != nil { + m.log.Warnfln("Failed to insert %s@%s: %v", m.Channel, m.DiscordID, err) + } +} + +func (m *Message) UpdateMatrixID(mxid id.EventID) { + query := "UPDATE message SET matrix_message_id=$1 WHERE channel_id=$2" + + "AND receiver=$3 AND discord_message_id=$4" + m.MatrixID = mxid + + _, err := m.db.Exec(query, m.MatrixID, m.Channel.ChannelID, m.Channel.Receiver, m.DiscordID) + if err != nil { + m.log.Warnfln("Failed to update %s@%s: %v", m.Channel, m.DiscordID, err) + } +} diff --git a/database/messagequery.go b/database/messagequery.go new file mode 100644 index 0000000..adea68f --- /dev/null +++ b/database/messagequery.go @@ -0,0 +1,64 @@ +package database + +import ( + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" +) + +type MessageQuery struct { + db *Database + log log.Logger +} + +const ( + messageSelect = "SELECT channel_id, receiver, discord_message_id," + + " matrix_message_id, author_id, timestamp FROM message" +) + +func (mq *MessageQuery) New() *Message { + return &Message{ + db: mq.db, + log: mq.log, + } +} + +func (mq *MessageQuery) GetAll(key PortalKey) []*Message { + query := messageSelect + " WHERE channeld_id=$1 AND receiver=$2" + + rows, err := mq.db.Query(query, key.ChannelID, key.Receiver) + if err != nil || rows == nil { + return nil + } + + messages := []*Message{} + for rows.Next() { + messages = append(messages, mq.New().Scan(rows)) + } + + return messages +} + +func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { + query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + + " discord_message_id=$3" + + row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID) + if row == nil { + mq.log.Debugfln("failed to find existing message for discord_id %s", discordID) + return nil + } + + return mq.New().Scan(row) +} + +func (mq *MessageQuery) GetByMatrixID(key PortalKey, matrixID id.EventID) *Message { + query := messageSelect + " WHERE channel_id=$1 AND receiver=$2 AND" + + " matrix_message_id=$3" + + row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, matrixID) + if row == nil { + return nil + } + + return mq.New().Scan(row) +} diff --git a/database/migrations/01-initial.sql b/database/migrations/01-initial.sql index 35ff22d..1d4bf12 100644 --- a/database/migrations/01-initial.sql +++ b/database/migrations/01-initial.sql @@ -33,6 +33,20 @@ CREATE TABLE user ( token TEXT ); +CREATE TABLE message ( + channel_id TEXT NOT NULL, + receiver TEXT NOT NULL, + + discord_message_id TEXT NOT NULL, + matrix_message_id TEXT NOT NULL UNIQUE, + + author_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + + PRIMARY KEY(discord_message_id, channel_id, receiver), + FOREIGN KEY(channel_id, receiver) REFERENCES portal(channel_id, receiver) ON DELETE CASCADE +); + CREATE TABLE mx_user_profile ( room_id TEXT, user_id TEXT, diff --git a/database/portal.go b/database/portal.go index ceb6a00..88f0ed9 100644 --- a/database/portal.go +++ b/database/portal.go @@ -20,6 +20,9 @@ type Portal struct { Avatar string AvatarURL id.ContentURI + Type int + DMUser string + FirstEventID id.EventID }