diff --git a/config/bridge.go b/config/bridge.go index 6e1fa5a..b7e97cf 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -41,6 +41,7 @@ type BridgeConfig struct { MessageStatusEvents bool `yaml:"message_status_events"` MessageErrorNotices bool `yaml:"message_error_notices"` RestrictedRooms bool `yaml:"restricted_rooms"` + AutojoinThreadOnOpen bool `yaml:"autojoin_thread_on_open"` SyncDirectChatList bool `yaml:"sync_direct_chat_list"` ResendBridgeInfo bool `yaml:"resend_bridge_info"` DeletePortalOnChannelDelete bool `yaml:"delete_portal_on_channel_delete"` diff --git a/config/upgrade.go b/config/upgrade.go index 80a2389..d620bcd 100644 --- a/config/upgrade.go +++ b/config/upgrade.go @@ -36,6 +36,7 @@ func DoUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "bridge", "message_status_events") helper.Copy(up.Bool, "bridge", "message_error_notices") helper.Copy(up.Bool, "bridge", "restricted_rooms") + helper.Copy(up.Bool, "bridge", "autojoin_thread_on_open") helper.Copy(up.Bool, "bridge", "sync_direct_chat_list") helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "delete_portal_on_channel_delete") diff --git a/database/message.go b/database/message.go index 7b6d651..0b170b3 100644 --- a/database/message.go +++ b/database/message.go @@ -60,9 +60,9 @@ func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Mes return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) } -func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message { - query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" - return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli())) +func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1" + return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli())) } func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { @@ -115,9 +115,8 @@ func (m *Message) DiscordProtoChannelID() string { func (m *Message) Scan(row dbutil.Scannable) *Message { var ts int64 - var threadID sql.NullString - err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) + err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) @@ -130,7 +129,6 @@ func (m *Message) Scan(row dbutil.Scannable) *Message { if ts != 0 { m.Timestamp = time.UnixMilli(ts) } - m.ThreadID = threadID.String return m } @@ -181,7 +179,7 @@ func (m *Message) MassInsert(msgs []MessagePart) { func (m *Message) Insert() { _, err := m.db.Exec(messageInsertQuery, m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, - m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID) + m.Timestamp.UnixMilli(), m.ThreadID, m.MXID) if err != nil { m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) diff --git a/database/thread.go b/database/thread.go index 09bf1de..351ed0b 100644 --- a/database/thread.go +++ b/database/thread.go @@ -16,7 +16,7 @@ type ThreadQuery struct { } const ( - threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread" + threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread" ) func (tq *ThreadQuery) New() *Thread { @@ -37,17 +37,6 @@ func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread { return tq.New().Scan(row) } -//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread { -// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2" -// -// row := tq.db.QueryRow(query, channelID, messageID) -// if row == nil { -// return nil -// } -// -// return tq.New().Scan(row) -//} - func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread { query := threadSelect + " WHERE root_msg_mxid=$1" @@ -59,6 +48,17 @@ func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread { return tq.New().Scan(row) } +func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread { + query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1" + + row := tq.db.QueryRow(query, mxid) + if row == nil { + return nil + } + + return tq.New().Scan(row) +} + type Thread struct { db *Database log log.Logger @@ -68,10 +68,12 @@ type Thread struct { RootDiscordID string RootMXID id.EventID + + CreationNoticeMXID id.EventID } func (t *Thread) Scan(row dbutil.Scannable) *Thread { - err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID) + err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { t.log.Errorln("Database scan failed:", err) @@ -83,21 +85,26 @@ func (t *Thread) Scan(row dbutil.Scannable) *Thread { } func (t *Thread) Insert() { - query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)" - - _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID) - + query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)" + _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID) if err != nil { t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err) panic(err) } } +func (t *Thread) Update() { + query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1" + _, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID) + if err != nil { + t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err) + panic(err) + } +} + func (t *Thread) Delete() { query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2" - _, err := t.db.Exec(query, t.ID, t.ParentID) - if err != nil { t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err) panic(err) diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 0ad34ea..e5d13f7 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v8: Latest revision +-- v0 -> v9: Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -49,6 +49,7 @@ CREATE TABLE thread ( parent_chan_id TEXT NOT NULL, root_msg_dcid TEXT NOT NULL, root_msg_mxid TEXT NOT NULL, + creation_notice_mxid TEXT NOT NULL, -- This is also not accessed by the bridge. receiver TEXT NOT NULL DEFAULT '', @@ -98,9 +99,9 @@ CREATE TABLE message ( dc_edit_index INTEGER, dc_chan_id TEXT, dc_chan_receiver TEXT, - dc_sender TEXT NOT NULL, + dc_sender TEXT NOT NULL, timestamp BIGINT NOT NULL, - dc_thread_id TEXT, + dc_thread_id TEXT NOT NULL, mxid TEXT NOT NULL UNIQUE, @@ -114,9 +115,9 @@ CREATE TABLE reaction ( dc_msg_id TEXT, dc_sender TEXT, dc_emoji_name TEXT, - dc_thread_id TEXT, + dc_thread_id TEXT NOT NULL, - dc_first_attachment_id TEXT NOT NULL, + dc_first_attachment_id TEXT NOT NULL, _dc_first_edit_index INTEGER NOT NULL DEFAULT 0, mxid TEXT NOT NULL UNIQUE, diff --git a/database/upgrades/09-more-thread-data.sql b/database/upgrades/09-more-thread-data.sql new file mode 100644 index 0000000..461a1d4 --- /dev/null +++ b/database/upgrades/09-more-thread-data.sql @@ -0,0 +1,9 @@ +-- v9: Store more info for proper thread support +ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT ''; +UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL; +UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL; + +-- only: postgres for next 3 lines +ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT; +ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL; +ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL; diff --git a/database/userportal.go b/database/userportal.go index 85cba45..aa3f79f 100644 --- a/database/userportal.go +++ b/database/userportal.go @@ -10,8 +10,9 @@ import ( ) const ( - UserPortalTypeDM = "dm" - UserPortalTypeGuild = "guild" + UserPortalTypeDM = "dm" + UserPortalTypeGuild = "guild" + UserPortalTypeThread = "thread" ) type UserPortal struct { @@ -62,6 +63,16 @@ func (u *User) IsInSpace(discordID string) (isIn bool) { return } +func (u *User) IsInPortal(discordID string) (isIn bool) { + query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)` + err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err) + panic(err) + } + return +} + func (u *User) MarkInPortal(portal UserPortal) { query := ` INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space) @@ -88,8 +99,8 @@ func (u *User) MarkNotInPortal(discordID string) { func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal { query := ` DELETE FROM user_portal - WHERE user_mxid=$1 AND timestamp<$2 - RETURNING discord_id, type, timestamp, in_space + WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild') + RETURNING discord_id, type, timestamp, in_space ` rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli()) if err != nil { diff --git a/example-config.yaml b/example-config.yaml index 83f113d..d5aa486 100644 --- a/example-config.yaml +++ b/example-config.yaml @@ -106,6 +106,9 @@ bridge: # Should the bridge use space-restricted join rules instead of invite-only for guild rooms? # This can avoid unnecessary invite events in guild rooms when members are synced in. restricted_rooms: true + # Should the bridge automatically join the user to threads on Discord when the thread is opened on Matrix? + # This only works with clients that support thread read receipts (MSC3771 added in Matrix v1.4). + autojoin_thread_on_open: true # Should the bridge update the m.direct account data event when double puppeting is enabled. # Note that updating the m.direct event is not atomic (except with mautrix-asmux) # and is therefore prone to race conditions. diff --git a/go.mod b/go.mod index 71fd112..9bbd387 100644 --- a/go.mod +++ b/go.mod @@ -6,13 +6,13 @@ require ( github.com/bwmarrin/discordgo v0.26.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 - github.com/lib/pq v1.10.6 + github.com/lib/pq v1.10.7 github.com/mattn/go-sqlite3 v1.14.15 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.8.0 - github.com/yuin/goldmark v1.4.13 + github.com/yuin/goldmark v1.5.2 maunium.net/go/maulogger/v2 v2.3.2 - maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e + maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8 ) require ( @@ -20,13 +20,13 @@ require ( github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rs/zerolog v1.27.0 // indirect + github.com/rs/zerolog v1.28.0 // indirect github.com/tidwall/gjson v1.14.3 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect - golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 // indirect - golang.org/x/net v0.0.0-20220812174116-3211cb980234 // indirect + golang.org/x/crypto v0.0.0-20221012134737-56aed061732a // indirect + golang.org/x/net v0.0.0-20221014081412-f15817d10f9b // indirect golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect maunium.net/go/mauflag v1.0.0 // indirect diff --git a/go.sum b/go.sum index c93ee95..24a0948 100644 --- a/go.sum +++ b/go.sum @@ -11,8 +11,8 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7 github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= -github.com/lib/pq v1.10.6/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= @@ -22,9 +22,9 @@ github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.27.0 h1:1T7qCieN22GVc8S4Q2yuexzBb1EqjbgjSH9RohbMjKs= -github.com/rs/zerolog v1.27.0/go.mod h1:7frBqO0oezxmnO7GF86FY++uy8I0Tk/If5ni1G9Qc0U= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= +github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -41,14 +41,14 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/goldmark v1.5.2 h1:ALmeCk/px5FSm1MAcFBAsVKZjDuMVj8Tm7FFIlMJnqU= +github.com/yuin/goldmark v1.5.2/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8 h1:GIAS/yBem/gq2MUqgNIzUHW7cJMmx3TGZOrnyYaNQ6c= -golang.org/x/crypto v0.0.0-20220817201139-bc19a97f63c8/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.0.0-20221012134737-56aed061732a h1:NmSIgad6KjE6VvHciPZuNRTKxGhlPfD6OA87W/PLkqg= +golang.org/x/crypto v0.0.0-20221012134737-56aed061732a/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220812174116-3211cb980234 h1:RDqmgfe7SvlMWoqC3xwQ2blLO3fcWcxMa3eBLRdRW7E= -golang.org/x/net v0.0.0-20220812174116-3211cb980234/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= +golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -66,5 +66,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= -maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e h1:NkB/p2VTBiF+kis+JhOYKKyAgE0KH7kN9744WETU9aA= -maunium.net/go/mautrix v0.12.2-0.20220919211529-34e3589a5d5e/go.mod h1:/jxQFIipObSsjZPH6o3xyUi8uoULz3Hfr/8p9loqpYE= +maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8 h1:OsKa24dXPnBgT5lr5ResRbTO9f+bZpZe/K5ioO2+1a8= +maunium.net/go/mautrix v0.12.2-0.20221016082058-f7d28cc18df8/go.mod h1:bCw45Qx/m9qsz7eazmbe7Rzq5ZbTPzwRE1UgX2S9DXs= diff --git a/main.go b/main.go index c707e3d..d95ca0c 100644 --- a/main.go +++ b/main.go @@ -59,9 +59,10 @@ type DiscordBridge struct { portalsByID map[database.PortalKey]*Portal portalsLock sync.Mutex - threadsByID map[string]*Thread - threadsByRootMXID map[id.EventID]*Thread - threadsLock sync.Mutex + threadsByID map[string]*Thread + threadsByRootMXID map[id.EventID]*Thread + threadsByCreationNoticeMXID map[id.EventID]*Thread + threadsLock sync.Mutex guildsByMXID map[id.RoomID]*Guild guildsByID map[string]*Guild @@ -153,8 +154,9 @@ func main() { portalsByMXID: make(map[id.RoomID]*Portal), portalsByID: make(map[database.PortalKey]*Portal), - threadsByID: make(map[string]*Thread), - threadsByRootMXID: make(map[id.EventID]*Thread), + threadsByID: make(map[string]*Thread), + threadsByRootMXID: make(map[id.EventID]*Thread), + threadsByCreationNoticeMXID: make(map[id.EventID]*Thread), guildsByID: make(map[string]*Guild), guildsByMXID: make(map[id.RoomID]*Guild), diff --git a/portal.go b/portal.go index dc77c3c..73ebca3 100644 --- a/portal.go +++ b/portal.go @@ -59,6 +59,14 @@ type Portal struct { currentlyTypingLock sync.Mutex } +var _ bridge.Portal = (*Portal)(nil) +var _ bridge.ReadReceiptHandlingPortal = (*Portal)(nil) +var _ bridge.MembershipHandlingPortal = (*Portal)(nil) +var _ bridge.TypingPortal = (*Portal)(nil) + +//var _ bridge.MetaHandlingPortal = (*Portal)(nil) +//var _ bridge.DisappearingPortal = (*Portal)(nil) + func (portal *Portal) IsEncrypted() bool { return portal.Encrypted } @@ -74,8 +82,6 @@ func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) { } } -var _ bridge.Portal = (*Portal)(nil) - var ( portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} ) @@ -714,6 +720,48 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } } +const JoinThreadReaction = "join thread" + +func (portal *Portal) sendThreadCreationNotice(thread *Thread) { + thread.creationNoticeLock.Lock() + defer thread.creationNoticeLock.Unlock() + if thread.CreationNoticeMXID != "" { + return + } + creationNotice := "Thread created. React to this message with \"join thread\" to join the thread on Discord." + if portal.bridge.Config.Bridge.AutojoinThreadOnOpen { + creationNotice = "Thread created. Opening this thread will auto-join you to it on Discord." + } + resp, err := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, &event.MessageEventContent{ + Body: creationNotice, + MsgType: event.MsgNotice, + RelatesTo: (&event.RelatesTo{}).SetThread(thread.RootMXID, thread.RootMXID), + }, nil, time.Now().UnixMilli()) + if err != nil { + portal.log.Errorfln("Failed to send thread creation notice: %v", err) + return + } + portal.bridge.threadsLock.Lock() + thread.CreationNoticeMXID = resp.EventID + portal.bridge.threadsByCreationNoticeMXID[resp.EventID] = thread + portal.bridge.threadsLock.Unlock() + thread.Update() + portal.log.Debugfln("Sent notice %s about thread for %s being created", thread.CreationNoticeMXID, thread.ID) + + resp, err = portal.MainIntent().SendMessageEvent(portal.MXID, event.EventReaction, &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: thread.CreationNoticeMXID, + Key: JoinThreadReaction, + }, + }) + if err != nil { + portal.log.Errorfln("Failed to send prefilled reaction to thread creation notice: %v", err) + } else { + portal.log.Debugfln("Sent prefilled reaction %s to thread creation notice %s", resp.EventID, thread.CreationNoticeMXID) + } +} + func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { if portal.MXID == "" { portal.log.Warnln("handle message called without a valid portal") @@ -728,13 +776,11 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess } if msg.Flags == discordgo.MessageFlagsHasThread { - portal.bridge.GetThreadByID(msg.ID, existing[0]) + thread := portal.bridge.GetThreadByID(msg.ID, existing[0]) portal.log.Debugfln("Marked %s as a thread root", msg.ID) - // TODO make autojoining configurable - //err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu) - //if err != nil { - // user.log.Warnfln("Error autojoining thread %s@%s: %v", msg.ChannelID, portal.Key.ChannelID, err) - //} + if thread.CreationNoticeMXID == "" { + portal.sendThreadCreationNotice(thread) + } } // There's a few scenarios where the author is nil but I haven't figured @@ -1322,6 +1368,16 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { return } + if reaction.RelatesTo.Key == JoinThreadReaction { + thread := portal.bridge.GetThreadByRootOrCreationNoticeMXID(reaction.RelatesTo.EventID) + if thread == nil { + go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring thread join") + return + } + thread.Join(sender) + return + } + msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) if msg == nil { go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") @@ -1479,14 +1535,31 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { go portal.sendMessageMetrics(evt, errTargetNotFound, "Ignoring") } -func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receiptTimestamp time.Time) { +func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.EventID, receipt event.ReadReceipt) { sender := brUser.(*User) if sender.Session == nil { return } + var thread *Thread + discordThreadID := "" + if receipt.ThreadID != "" && receipt.ThreadID != event.ReadReceiptThreadMain { + thread = portal.bridge.GetThreadByRootMXID(receipt.ThreadID) + if thread != nil { + discordThreadID = thread.ID + } + } + if thread != nil { + if portal.bridge.Config.Bridge.AutojoinThreadOnOpen { + thread.Join(sender) + } + if eventID == thread.CreationNoticeMXID { + portal.log.Debugfln("Dropping Matrix read receipt from %s for thread creation notice %s of %s", sender.MXID, thread.CreationNoticeMXID, thread.ID) + return + } + } msg := portal.bridge.DB.Message.GetByMXID(portal.Key, eventID) if msg == nil { - msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, receiptTimestamp) + msg = portal.bridge.DB.Message.GetClosestBefore(portal.Key, discordThreadID, receipt.Timestamp) if msg == nil { portal.log.Debugfln("Dropping Matrix read receipt from %s for %s: no messages found", sender.MXID, eventID) return @@ -1494,13 +1567,17 @@ func (portal *Portal) HandleMatrixReadReceipt(brUser bridge.User, eventID id.Eve portal.log.Debugfln("Matrix read receipt target %s from %s not found, using closest message %s", eventID, sender.MXID, msg.MXID) } } + if receipt.ThreadID != "" && msg.ThreadID != discordThreadID { + portal.log.Debugfln("Dropping Matrix read receipt from %s for %s in unexpected thread (receipt: %s, message: %s)", receipt.ThreadID, msg.ThreadID) + return + } resp, err := sender.Session.ChannelMessageAckNoToken(msg.DiscordProtoChannelID(), msg.DiscordID) if err != nil { portal.log.Warnfln("Failed to handle read receipt for %s/%s from %s: %v", msg.MXID, msg.DiscordID, sender.MXID, err) } else if resp.Token != nil { portal.log.Debugfln("Marked %s/%s as read by %s (and got unexpected non-nil token %s)", msg.MXID, msg.DiscordID, sender.MXID, *resp.Token) } else { - portal.log.Debugfln("Marked %s/%s as read by %s", msg.MXID, msg.DiscordID, sender.MXID) + portal.log.Debugfln("Marked %s/%s in %s as read by %s", msg.MXID, msg.DiscordID, msg.DiscordProtoChannelID(), sender.MXID) } } diff --git a/thread.go b/thread.go index 8d7468f..5b0437b 100644 --- a/thread.go +++ b/thread.go @@ -1,6 +1,10 @@ package main import ( + "sync" + "time" + + "github.com/bwmarrin/discordgo" "maunium.net/go/mautrix/id" "go.mau.fi/mautrix-discord/database" @@ -9,6 +13,8 @@ import ( type Thread struct { *database.Thread Parent *Portal + + creationNoticeLock sync.Mutex } func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread { @@ -31,6 +37,19 @@ func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread { return thread } +func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread { + br.threadsLock.Lock() + defer br.threadsLock.Unlock() + thread, ok := br.threadsByRootMXID[mxid] + if !ok { + thread, ok = br.threadsByCreationNoticeMXID[mxid] + if !ok { + return br.loadThread(br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid), "", nil) + } + } + return thread +} + func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread { if dbThread == nil { if root == nil { @@ -49,5 +68,25 @@ func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root * thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, "")) br.threadsByID[thread.ID] = thread br.threadsByRootMXID[thread.RootMXID] = thread + if thread.CreationNoticeMXID != "" { + br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread + } return thread } + +func (thread *Thread) Join(user *User) { + if user.IsInPortal(thread.ID) { + return + } + user.log.Debugfln("Joining thread %s@%s", thread.ID, thread.ParentID) + err := user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu) + if err != nil { + user.log.Errorfln("Error joining thread %s@%s: %v", thread.ID, thread.ParentID, err) + } else { + user.MarkInPortal(database.UserPortal{ + DiscordID: thread.ID, + Type: database.UserPortalTypeThread, + Timestamp: time.Now(), + }) + } +}