Add option for autojoining threads when opened

This commit is contained in:
Tulir Asokan
2022-10-28 23:35:31 +03:00
parent 6be531685f
commit f268ddd132
13 changed files with 219 additions and 70 deletions

View File

@@ -41,6 +41,7 @@ type BridgeConfig struct {
MessageStatusEvents bool `yaml:"message_status_events"` MessageStatusEvents bool `yaml:"message_status_events"`
MessageErrorNotices bool `yaml:"message_error_notices"` MessageErrorNotices bool `yaml:"message_error_notices"`
RestrictedRooms bool `yaml:"restricted_rooms"` RestrictedRooms bool `yaml:"restricted_rooms"`
AutojoinThreadOnOpen bool `yaml:"autojoin_thread_on_open"`
SyncDirectChatList bool `yaml:"sync_direct_chat_list"` SyncDirectChatList bool `yaml:"sync_direct_chat_list"`
ResendBridgeInfo bool `yaml:"resend_bridge_info"` ResendBridgeInfo bool `yaml:"resend_bridge_info"`
DeletePortalOnChannelDelete bool `yaml:"delete_portal_on_channel_delete"` DeletePortalOnChannelDelete bool `yaml:"delete_portal_on_channel_delete"`

View File

@@ -36,6 +36,7 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "message_status_events") helper.Copy(up.Bool, "bridge", "message_status_events")
helper.Copy(up.Bool, "bridge", "message_error_notices") helper.Copy(up.Bool, "bridge", "message_error_notices")
helper.Copy(up.Bool, "bridge", "restricted_rooms") helper.Copy(up.Bool, "bridge", "restricted_rooms")
helper.Copy(up.Bool, "bridge", "autojoin_thread_on_open")
helper.Copy(up.Bool, "bridge", "sync_direct_chat_list") helper.Copy(up.Bool, "bridge", "sync_direct_chat_list")
helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "resend_bridge_info")
helper.Copy(up.Bool, "bridge", "delete_portal_on_channel_delete") helper.Copy(up.Bool, "bridge", "delete_portal_on_channel_delete")

View File

@@ -60,9 +60,9 @@ func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Mes
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
} }
func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message { func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli())) return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
} }
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
@@ -115,9 +115,8 @@ func (m *Message) DiscordProtoChannelID() string {
func (m *Message) Scan(row dbutil.Scannable) *Message { func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64 var ts int64
var threadID sql.NullString
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err) m.log.Errorln("Database scan failed:", err)
@@ -130,7 +129,6 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
if ts != 0 { if ts != 0 {
m.Timestamp = time.UnixMilli(ts) m.Timestamp = time.UnixMilli(ts)
} }
m.ThreadID = threadID.String
return m return m
} }
@@ -181,7 +179,7 @@ func (m *Message) MassInsert(msgs []MessagePart) {
func (m *Message) Insert() { func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery, _, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID) m.Timestamp.UnixMilli(), m.ThreadID, m.MXID)
if err != nil { if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)

View File

@@ -16,7 +16,7 @@ type ThreadQuery struct {
} }
const ( const (
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread" threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread"
) )
func (tq *ThreadQuery) New() *Thread { func (tq *ThreadQuery) New() *Thread {
@@ -37,17 +37,6 @@ func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
return tq.New().Scan(row) return tq.New().Scan(row)
} }
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
//
// row := tq.db.QueryRow(query, channelID, messageID)
// if row == nil {
// return nil
// }
//
// return tq.New().Scan(row)
//}
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread { func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1" query := threadSelect + " WHERE root_msg_mxid=$1"
@@ -59,6 +48,17 @@ func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
return tq.New().Scan(row) return tq.New().Scan(row)
} }
func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1"
row := tq.db.QueryRow(query, mxid)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
type Thread struct { type Thread struct {
db *Database db *Database
log log.Logger log log.Logger
@@ -68,10 +68,12 @@ type Thread struct {
RootDiscordID string RootDiscordID string
RootMXID id.EventID RootMXID id.EventID
CreationNoticeMXID id.EventID
} }
func (t *Thread) Scan(row dbutil.Scannable) *Thread { func (t *Thread) Scan(row dbutil.Scannable) *Thread {
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID) err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID)
if err != nil { if err != nil {
if !errors.Is(err, sql.ErrNoRows) { if !errors.Is(err, sql.ErrNoRows) {
t.log.Errorln("Database scan failed:", err) t.log.Errorln("Database scan failed:", err)
@@ -83,21 +85,26 @@ func (t *Thread) Scan(row dbutil.Scannable) *Thread {
} }
func (t *Thread) Insert() { func (t *Thread) Insert() {
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)" query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)"
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID)
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
if err != nil { if err != nil {
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err) t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
panic(err) panic(err)
} }
} }
func (t *Thread) Update() {
query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1"
_, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID)
if err != nil {
t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}
func (t *Thread) Delete() { func (t *Thread) Delete() {
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2" query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
_, err := t.db.Exec(query, t.ID, t.ParentID) _, err := t.db.Exec(query, t.ID, t.ParentID)
if err != nil { if err != nil {
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err) t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
panic(err) panic(err)

View File

@@ -1,4 +1,4 @@
-- v0 -> v8: Latest revision -- v0 -> v9: Latest revision
CREATE TABLE guild ( CREATE TABLE guild (
dcid TEXT PRIMARY KEY, dcid TEXT PRIMARY KEY,
@@ -49,6 +49,7 @@ CREATE TABLE thread (
parent_chan_id TEXT NOT NULL, parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL, root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL, root_msg_mxid TEXT NOT NULL,
creation_notice_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge. -- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '', receiver TEXT NOT NULL DEFAULT '',
@@ -98,9 +99,9 @@ CREATE TABLE message (
dc_edit_index INTEGER, dc_edit_index INTEGER,
dc_chan_id TEXT, dc_chan_id TEXT,
dc_chan_receiver TEXT, dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL, dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL, timestamp BIGINT NOT NULL,
dc_thread_id TEXT, dc_thread_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,
@@ -114,9 +115,9 @@ CREATE TABLE reaction (
dc_msg_id TEXT, dc_msg_id TEXT,
dc_sender TEXT, dc_sender TEXT,
dc_emoji_name TEXT, dc_emoji_name TEXT,
dc_thread_id TEXT, dc_thread_id TEXT NOT NULL,
dc_first_attachment_id TEXT NOT NULL, dc_first_attachment_id TEXT NOT NULL,
_dc_first_edit_index INTEGER NOT NULL DEFAULT 0, _dc_first_edit_index INTEGER NOT NULL DEFAULT 0,
mxid TEXT NOT NULL UNIQUE, mxid TEXT NOT NULL UNIQUE,

View File

@@ -0,0 +1,9 @@
-- v9: Store more info for proper thread support
ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT '';
UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL;
UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL;
-- only: postgres for next 3 lines
ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT;
ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL;
ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL;

View File

@@ -10,8 +10,9 @@ import (
) )
const ( const (
UserPortalTypeDM = "dm" UserPortalTypeDM = "dm"
UserPortalTypeGuild = "guild" UserPortalTypeGuild = "guild"
UserPortalTypeThread = "thread"
) )
type UserPortal struct { type UserPortal struct {
@@ -62,6 +63,16 @@ func (u *User) IsInSpace(discordID string) (isIn bool) {
return return
} }
func (u *User) IsInPortal(discordID string) (isIn bool) {
query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
panic(err)
}
return
}
func (u *User) MarkInPortal(portal UserPortal) { func (u *User) MarkInPortal(portal UserPortal) {
query := ` query := `
INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space) INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
@@ -88,8 +99,8 @@ func (u *User) MarkNotInPortal(discordID string) {
func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal { func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
query := ` query := `
DELETE FROM user_portal DELETE FROM user_portal
WHERE user_mxid=$1 AND timestamp<$2 WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
RETURNING discord_id, type, timestamp, in_space RETURNING discord_id, type, timestamp, in_space
` `
rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli()) rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
if err != nil { if err != nil {

View File

@@ -106,6 +106,9 @@ bridge:
# Should the bridge use space-restricted join rules instead of invite-only for guild rooms? # Should the bridge use space-restricted join rules instead of invite-only for guild rooms?
# This can avoid unnecessary invite events in guild rooms when members are synced in. # This can avoid unnecessary invite events in guild rooms when members are synced in.
restricted_rooms: true restricted_rooms: true
# Should the bridge automatically join the user to threads on Discord when the thread is opened on Matrix?
# This only works with clients that support thread read receipts (MSC3771 added in Matrix v1.4).
autojoin_thread_on_open: true
# Should the bridge update the m.direct account data event when double puppeting is enabled. # Should the bridge update the m.direct account data event when double puppeting is enabled.
# Note that updating the m.direct event is not atomic (except with mautrix-asmux) # Note that updating the m.direct event is not atomic (except with mautrix-asmux)
# and is therefore prone to race conditions. # and is therefore prone to race conditions.

12
go.mod
View File

@@ -6,13 +6,13 @@ require (
github.com/bwmarrin/discordgo v0.26.1 github.com/bwmarrin/discordgo v0.26.1
github.com/gorilla/mux v1.8.0 github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.6 github.com/lib/pq v1.10.7
github.com/mattn/go-sqlite3 v1.14.15 github.com/mattn/go-sqlite3 v1.14.15
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/stretchr/testify v1.8.0 github.com/stretchr/testify v1.8.0
github.com/yuin/goldmark v1.4.13 github.com/yuin/goldmark v1.5.2
maunium.net/go/maulogger/v2 v2.3.2 maunium.net/go/maulogger/v2 v2.3.2
maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8
) )
require ( require (
@@ -20,13 +20,13 @@ require (
github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rs/zerolog v1.27.0 // indirect github.com/rs/zerolog v1.28.0 // indirect
github.com/tidwall/gjson v1.14.3 // indirect github.com/tidwall/gjson v1.14.3 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 // indirect golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect
golang.org/x/net v0.0.0-20220812174116-3211cb980234 // indirect golang.org/x/net v0.0.0-20221014081412-f15817d10f9b // indirect
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
maunium.net/go/mauflag v1.0.0 // indirect maunium.net/go/mauflag v1.0.0 // indirect

26
go.sum
View File

@@ -11,8 +11,8 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw=
github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
@@ -22,9 +22,9 @@ github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.27.0 h1:1T7qCieN22GVc8S4Q2yuexzBb1EqjbgjSH9RohbMjKs= github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
github.com/rs/zerolog v1.27.0/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U= github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -41,14 +41,14 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.5.2 h1:ALmeCk/px5FSm1MAcFBAsVKZjDuMVj8Tm7FFIlMJnqU=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.5.2/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 h1:GIAS/yBem/gq2MUqgNIzUHW7cJMmx3TGZOrnyYaNQ6c= golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg=
golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220812174116-3211cb980234 h1:RDqmgfe7SvlMWoqC3xwQ2blLO3fcWcxMa3eBLRdRW7E= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU=
golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -66,5 +66,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e h1:NkB/p2VTBiF+kis+JhOYKKyAgE0KH7kN9744WETU9aA= maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8 h1:OsKa24dXPnBgT5lr5ResRbTO9f+bZpZe/K5ioO2+1a8=
maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e/go.mod h1:/jxQFIipObSsjZPH6o3xyUi8uoULz3Hfr/8p9loqpYE= maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8/go.mod h1:bCw45Qx/m9qsz7eazmbe7Rzq5ZbTPzwRE1UgX2S9DXs=

12
main.go
View File

@@ -59,9 +59,10 @@ type DiscordBridge struct {
portalsByID map[database.PortalKey]*Portal portalsByID map[database.PortalKey]*Portal
portalsLock sync.Mutex portalsLock sync.Mutex
threadsByID map[string]*Thread threadsByID map[string]*Thread
threadsByRootMXID map[id.EventID]*Thread threadsByRootMXID map[id.EventID]*Thread
threadsLock sync.Mutex threadsByCreationNoticeMXID map[id.EventID]*Thread
threadsLock sync.Mutex
guildsByMXID map[id.RoomID]*Guild guildsByMXID map[id.RoomID]*Guild
guildsByID map[string]*Guild guildsByID map[string]*Guild
@@ -153,8 +154,9 @@ func main() {
portalsByMXID: make(map[id.RoomID]*Portal), portalsByMXID: make(map[id.RoomID]*Portal),
portalsByID: make(map[database.PortalKey]*Portal), portalsByID: make(map[database.PortalKey]*Portal),
threadsByID: make(map[string]*Thread), threadsByID: make(map[string]*Thread),
threadsByRootMXID: make(map[id.EventID]*Thread), threadsByRootMXID: make(map[id.EventID]*Thread),
threadsByCreationNoticeMXID: make(map[id.EventID]*Thread),
guildsByID: make(map[string]*Guild), guildsByID: make(map[string]*Guild),
guildsByMXID: make(map[id.RoomID]*Guild), guildsByMXID: make(map[id.RoomID]*Guild),

View File

@@ -59,6 +59,14 @@ type Portal struct {
currentlyTypingLock sync.Mutex currentlyTypingLock sync.Mutex
} }
var _ bridge.Portal = (*Portal)(nil)
var _ bridge.ReadReceiptHandlingPortal = (*Portal)(nil)
var _ bridge.MembershipHandlingPortal = (*Portal)(nil)
var _ bridge.TypingPortal = (*Portal)(nil)
//var _ bridge.MetaHandlingPortal = (*Portal)(nil)
//var _ bridge.DisappearingPortal = (*Portal)(nil)
func (portal *Portal) IsEncrypted() bool { func (portal *Portal) IsEncrypted() bool {
return portal.Encrypted return portal.Encrypted
} }
@@ -74,8 +82,6 @@ func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) {
} }
} }
var _ bridge.Portal = (*Portal)(nil)
var ( var (
portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
) )
@@ -714,6 +720,48 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
} }
} }
const JoinThreadReaction = "join thread"
func (portal *Portal) sendThreadCreationNotice(thread *Thread) {
thread.creationNoticeLock.Lock()
defer thread.creationNoticeLock.Unlock()
if thread.CreationNoticeMXID != "" {
return
}
creationNotice := "Thread created. React to this message with \"join thread\" to join the thread on Discord."
if portal.bridge.Config.Bridge.AutojoinThreadOnOpen {
creationNotice = "Thread created. Opening this thread will auto-join you to it on Discord."
}
resp, err := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{
Body: creationNotice,
MsgType: event.MsgNotice,
RelatesTo: (&event.RelatesTo{}).SetThread(thread.RootMXID, thread.RootMXID),
}, nil, time.Now().UnixMilli())
if err != nil {
portal.log.Errorfln("Failed to send thread creation notice: %v", err)
return
}
portal.bridge.threadsLock.Lock()
thread.CreationNoticeMXID = resp.EventID
portal.bridge.threadsByCreationNoticeMXID[resp.EventID] = thread
portal.bridge.threadsLock.Unlock()
thread.Update()
portal.log.Debugfln("Sent notice %s about thread for %s being created", thread.CreationNoticeMXID, thread.ID)
resp, err = portal.MainIntent().SendMessageEvent(portal.MXID, event.EventReaction, &event.ReactionEventContent{
RelatesTo: event.RelatesTo{
Type: event.RelAnnotation,
EventID: thread.CreationNoticeMXID,
Key: JoinThreadReaction,
},
})
if err != nil {
portal.log.Errorfln("Failed to send prefilled reaction to thread creation notice: %v", err)
} else {
portal.log.Debugfln("Sent prefilled reaction %s to thread creation notice %s", resp.EventID, thread.CreationNoticeMXID)
}
}
func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) {
if portal.MXID == "" { if portal.MXID == "" {
portal.log.Warnln("handle message called without a valid portal") portal.log.Warnln("handle message called without a valid portal")
@@ -728,13 +776,11 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
} }
if msg.Flags == discordgo.MessageFlagsHasThread { if msg.Flags == discordgo.MessageFlagsHasThread {
portal.bridge.GetThreadByID(msg.ID, existing[0]) thread := portal.bridge.GetThreadByID(msg.ID, existing[0])
portal.log.Debugfln("Marked %s as a thread root", msg.ID) portal.log.Debugfln("Marked %s as a thread root", msg.ID)
// TODO make autojoining configurable if thread.CreationNoticeMXID == "" {
//err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu) portal.sendThreadCreationNotice(thread)
//if err != nil { }
// user.log.Warnfln("Error autojoining thread %s@%s: %v", msg.ChannelID, portal.Key.ChannelID, err)
//}
} }
// There's a few scenarios where the author is nil but I haven't figured // There's a few scenarios where the author is nil but I haven't figured
@@ -1322,6 +1368,16 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) {
return return
} }
if reaction.RelatesTo.Key == JoinThreadReaction {
thread := portal.bridge.GetThreadByRootOrCreationNoticeMXID(reaction.RelatesTo.EventID)
if thread == nil {
go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring thread join")
return
}
thread.Join(sender)
return
}
msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID)
if msg == nil { if msg == nil {
go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring")
@@ -1479,14 +1535,31 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) {
go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring")
} }
func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receiptTimestamp time.Time) { func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receipt event.ReadReceipt) {
sender := brUser.(*User) sender := brUser.(*User)
if sender.Session == nil { if sender.Session == nil {
return return
} }
var thread *Thread
discordThreadID := ""
if receipt.ThreadID != "" && receipt.ThreadID != event.ReadReceiptThreadMain {
thread = portal.bridge.GetThreadByRootMXID(receipt.ThreadID)
if thread != nil {
discordThreadID = thread.ID
}
}
if thread != nil {
if portal.bridge.Config.Bridge.AutojoinThreadOnOpen {
thread.Join(sender)
}
if eventID == thread.CreationNoticeMXID {
portal.log.Debugfln("Dropping Matrix read receipt from %s for thread creation notice %s of %s", sender.MXID, thread.CreationNoticeMXID, thread.ID)
return
}
}
msg := portal.bridge.DB.Message.GetByMXID(portal.Key, eventID) msg := portal.bridge.DB.Message.GetByMXID(portal.Key, eventID)
if msg == nil { if msg == nil {
msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, receiptTimestamp) msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, discordThreadID, receipt.Timestamp)
if msg == nil { if msg == nil {
portal.log.Debugfln("Dropping Matrix read receipt from %s for %s: no messages found", sender.MXID, eventID) portal.log.Debugfln("Dropping Matrix read receipt from %s for %s: no messages found", sender.MXID, eventID)
return return
@@ -1494,13 +1567,17 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve
portal.log.Debugfln("Matrix read receipt target %s from %s not found, using closest message %s", eventID, sender.MXID, msg.MXID) portal.log.Debugfln("Matrix read receipt target %s from %s not found, using closest message %s", eventID, sender.MXID, msg.MXID)
} }
} }
if receipt.ThreadID != "" && msg.ThreadID != discordThreadID {
portal.log.Debugfln("Dropping Matrix read receipt from %s for %s in unexpected thread (receipt: %s, message: %s)", receipt.ThreadID, msg.ThreadID)
return
}
resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID) resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to handle read receipt for %s/%s from %s: %v", msg.MXID, msg.DiscordID, sender.MXID, err) portal.log.Warnfln("Failed to handle read receipt for %s/%s from %s: %v", msg.MXID, msg.DiscordID, sender.MXID, err)
} else if resp.Token != nil { } else if resp.Token != nil {
portal.log.Debugfln("Marked %s/%s as read by %s (and got unexpected non-nil token %s)", msg.MXID, msg.DiscordID, sender.MXID, *resp.Token) portal.log.Debugfln("Marked %s/%s as read by %s (and got unexpected non-nil token %s)", msg.MXID, msg.DiscordID, sender.MXID, *resp.Token)
} else { } else {
portal.log.Debugfln("Marked %s/%s as read by %s", msg.MXID, msg.DiscordID, sender.MXID) portal.log.Debugfln("Marked %s/%s in %s as read by %s", msg.MXID, msg.DiscordID, msg.DiscordProtoChannelID(), sender.MXID)
} }
} }

View File

@@ -1,6 +1,10 @@
package main package main
import ( import (
"sync"
"time"
"github.com/bwmarrin/discordgo"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database" "go.mau.fi/mautrix-discord/database"
@@ -9,6 +13,8 @@ import (
type Thread struct { type Thread struct {
*database.Thread *database.Thread
Parent *Portal Parent *Portal
creationNoticeLock sync.Mutex
} }
func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread { func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
@@ -31,6 +37,19 @@ func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
return thread return thread
} }
func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
br.threadsLock.Lock()
defer br.threadsLock.Unlock()
thread, ok := br.threadsByRootMXID[mxid]
if !ok {
thread, ok = br.threadsByCreationNoticeMXID[mxid]
if !ok {
return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil)
}
}
return thread
}
func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread { func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
if dbThread == nil { if dbThread == nil {
if root == nil { if root == nil {
@@ -49,5 +68,25 @@ func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *
thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, "")) thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
br.threadsByID[thread.ID] = thread br.threadsByID[thread.ID] = thread
br.threadsByRootMXID[thread.RootMXID] = thread br.threadsByRootMXID[thread.RootMXID] = thread
if thread.CreationNoticeMXID != "" {
br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
}
return thread return thread
} }
func (thread *Thread) Join(user *User) {
if user.IsInPortal(thread.ID) {
return
}
user.log.Debugfln("Joining thread %s@%s", thread.ID, thread.ParentID)
err := user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
if err != nil {
user.log.Errorfln("Error joining thread %s@%s: %v", thread.ID, thread.ParentID, err)
} else {
user.MarkInPortal(database.UserPortal{
DiscordID: thread.ID,
Type: database.UserPortalTypeThread,
Timestamp: time.Now(),
})
}
}