Add option for autojoining threads when opened

This commit is contained in:
Tulir Asokan
2022-10-28 23:35:31 +03:00
parent 6be531685f
commit f268ddd132
13 changed files with 219 additions and 70 deletions

View File

@@ -60,9 +60,9 @@ func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Mes
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
}
func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli()))
func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
}
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
@@ -115,9 +115,8 @@ func (m *Message) DiscordProtoChannelID() string {
func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64
var threadID sql.NullString
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err)
@@ -130,7 +129,6 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
if ts != 0 {
m.Timestamp = time.UnixMilli(ts)
}
m.ThreadID = threadID.String
return m
}
@@ -181,7 +179,7 @@ func (m *Message) MassInsert(msgs []MessagePart) {
func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
m.Timestamp.UnixMilli(), m.ThreadID, m.MXID)
if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)

View File

@@ -16,7 +16,7 @@ type ThreadQuery struct {
}
const (
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread"
)
func (tq *ThreadQuery) New() *Thread {
@@ -37,17 +37,6 @@ func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
return tq.New().Scan(row)
}
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
//
// row := tq.db.QueryRow(query, channelID, messageID)
// if row == nil {
// return nil
// }
//
// return tq.New().Scan(row)
//}
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1"
@@ -59,6 +48,17 @@ func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
return tq.New().Scan(row)
}
func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1"
row := tq.db.QueryRow(query, mxid)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
type Thread struct {
db *Database
log log.Logger
@@ -68,10 +68,12 @@ type Thread struct {
RootDiscordID string
RootMXID id.EventID
CreationNoticeMXID id.EventID
}
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
t.log.Errorln("Database scan failed:", err)
@@ -83,21 +85,26 @@ func (t *Thread) Scan(row dbutil.Scannable) *Thread {
}
func (t *Thread) Insert() {
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)"
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID)
if err != nil {
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}
func (t *Thread) Update() {
query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1"
_, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID)
if err != nil {
t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}
func (t *Thread) Delete() {
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
_, err := t.db.Exec(query, t.ID, t.ParentID)
if err != nil {
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
panic(err)

View File

@@ -1,4 +1,4 @@
-- v0 -> v8: Latest revision
-- v0 -> v9: Latest revision
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
@@ -49,6 +49,7 @@ CREATE TABLE thread (
parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL,
creation_notice_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '',
@@ -98,9 +99,9 @@ CREATE TABLE message (
dc_edit_index INTEGER,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
dc_thread_id TEXT,
dc_thread_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
@@ -114,9 +115,9 @@ CREATE TABLE reaction (
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
dc_thread_id TEXT,
dc_thread_id TEXT NOT NULL,
dc_first_attachment_id TEXT NOT NULL,
dc_first_attachment_id TEXT NOT NULL,
_dc_first_edit_index INTEGER NOT NULL DEFAULT 0,
mxid TEXT NOT NULL UNIQUE,

View File

@@ -0,0 +1,9 @@
-- v9: Store more info for proper thread support
ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT '';
UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL;
UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL;
-- only: postgres for next 3 lines
ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT;
ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL;
ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL;

View File

@@ -10,8 +10,9 @@ import (
)
const (
UserPortalTypeDM = "dm"
UserPortalTypeGuild = "guild"
UserPortalTypeDM = "dm"
UserPortalTypeGuild = "guild"
UserPortalTypeThread = "thread"
)
type UserPortal struct {
@@ -62,6 +63,16 @@ func (u *User) IsInSpace(discordID string) (isIn bool) {
return
}
func (u *User) IsInPortal(discordID string) (isIn bool) {
query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
panic(err)
}
return
}
func (u *User) MarkInPortal(portal UserPortal) {
query := `
INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
@@ -88,8 +99,8 @@ func (u *User) MarkNotInPortal(discordID string) {
func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
query := `
DELETE FROM user_portal
WHERE user_mxid=$1 AND timestamp<$2
RETURNING discord_id, type, timestamp, in_space
WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
RETURNING discord_id, type, timestamp, in_space
`
rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
if err != nil {