Add option for autojoining threads when opened
This commit is contained in:
@@ -60,9 +60,9 @@ func (mq *MessageQuery) GetLastByDiscordID(key PortalKey, discordID string) *Mes
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetClosestBefore(key PortalKey, ts time.Time) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND timestamp<=$3 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, ts.UnixMilli()))
|
||||
func (mq *MessageQuery) GetClosestBefore(key PortalKey, threadID string, ts time.Time) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 AND timestamp<=$4 ORDER BY timestamp DESC, dc_attachment_id DESC LIMIT 1"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID, ts.UnixMilli()))
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
|
||||
@@ -115,9 +115,8 @@ func (m *Message) DiscordProtoChannelID() string {
|
||||
|
||||
func (m *Message) Scan(row dbutil.Scannable) *Message {
|
||||
var ts int64
|
||||
var threadID sql.NullString
|
||||
|
||||
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
|
||||
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.EditIndex, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.ThreadID, &m.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
m.log.Errorln("Database scan failed:", err)
|
||||
@@ -130,7 +129,6 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
|
||||
if ts != 0 {
|
||||
m.Timestamp = time.UnixMilli(ts)
|
||||
}
|
||||
m.ThreadID = threadID.String
|
||||
|
||||
return m
|
||||
}
|
||||
@@ -181,7 +179,7 @@ func (m *Message) MassInsert(msgs []MessagePart) {
|
||||
func (m *Message) Insert() {
|
||||
_, err := m.db.Exec(messageInsertQuery,
|
||||
m.DiscordID, m.AttachmentID, m.EditIndex, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
|
||||
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
|
||||
m.Timestamp.UnixMilli(), m.ThreadID, m.MXID)
|
||||
|
||||
if err != nil {
|
||||
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
|
||||
|
||||
@@ -16,7 +16,7 @@ type ThreadQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
|
||||
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid FROM thread"
|
||||
)
|
||||
|
||||
func (tq *ThreadQuery) New() *Thread {
|
||||
@@ -37,17 +37,6 @@ func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
|
||||
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
|
||||
//
|
||||
// row := tq.db.QueryRow(query, channelID, messageID)
|
||||
// if row == nil {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// return tq.New().Scan(row)
|
||||
//}
|
||||
|
||||
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
|
||||
query := threadSelect + " WHERE root_msg_mxid=$1"
|
||||
|
||||
@@ -59,6 +48,17 @@ func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
func (tq *ThreadQuery) GetByMatrixRootOrCreationNoticeMsg(mxid id.EventID) *Thread {
|
||||
query := threadSelect + " WHERE root_msg_mxid=$1 OR creation_notice_mxid=$1"
|
||||
|
||||
row := tq.db.QueryRow(query, mxid)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
@@ -68,10 +68,12 @@ type Thread struct {
|
||||
|
||||
RootDiscordID string
|
||||
RootMXID id.EventID
|
||||
|
||||
CreationNoticeMXID id.EventID
|
||||
}
|
||||
|
||||
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
|
||||
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
|
||||
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID, &t.CreationNoticeMXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.log.Errorln("Database scan failed:", err)
|
||||
@@ -83,21 +85,26 @@ func (t *Thread) Scan(row dbutil.Scannable) *Thread {
|
||||
}
|
||||
|
||||
func (t *Thread) Insert() {
|
||||
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
|
||||
|
||||
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid, creation_notice_mxid) VALUES ($1, $2, $3, $4, $5)"
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID, t.CreationNoticeMXID)
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) Update() {
|
||||
query := "UPDATE thread SET creation_notice_mxid=$2 WHERE dcid=$1"
|
||||
_, err := t.db.Exec(query, t.ID, t.CreationNoticeMXID)
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to update %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) Delete() {
|
||||
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID)
|
||||
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
-- v0 -> v8: Latest revision
|
||||
-- v0 -> v9: Latest revision
|
||||
|
||||
CREATE TABLE guild (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
@@ -49,6 +49,7 @@ CREATE TABLE thread (
|
||||
parent_chan_id TEXT NOT NULL,
|
||||
root_msg_dcid TEXT NOT NULL,
|
||||
root_msg_mxid TEXT NOT NULL,
|
||||
creation_notice_mxid TEXT NOT NULL,
|
||||
-- This is also not accessed by the bridge.
|
||||
receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
@@ -98,9 +99,9 @@ CREATE TABLE message (
|
||||
dc_edit_index INTEGER,
|
||||
dc_chan_id TEXT,
|
||||
dc_chan_receiver TEXT,
|
||||
dc_sender TEXT NOT NULL,
|
||||
dc_sender TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
dc_thread_id TEXT,
|
||||
dc_thread_id TEXT NOT NULL,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -114,9 +115,9 @@ CREATE TABLE reaction (
|
||||
dc_msg_id TEXT,
|
||||
dc_sender TEXT,
|
||||
dc_emoji_name TEXT,
|
||||
dc_thread_id TEXT,
|
||||
dc_thread_id TEXT NOT NULL,
|
||||
|
||||
dc_first_attachment_id TEXT NOT NULL,
|
||||
dc_first_attachment_id TEXT NOT NULL,
|
||||
_dc_first_edit_index INTEGER NOT NULL DEFAULT 0,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
9
database/upgrades/09-more-thread-data.sql
Normal file
9
database/upgrades/09-more-thread-data.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- v9: Store more info for proper thread support
|
||||
ALTER TABLE thread ADD COLUMN creation_notice_mxid TEXT NOT NULL DEFAULT '';
|
||||
UPDATE message SET dc_thread_id='' WHERE dc_thread_id IS NULL;
|
||||
UPDATE reaction SET dc_thread_id='' WHERE dc_thread_id IS NULL;
|
||||
|
||||
-- only: postgres for next 3 lines
|
||||
ALTER TABLE thread ALTER COLUMN creation_notice_mxid DROP DEFAULT;
|
||||
ALTER TABLE message ALTER COLUMN dc_thread_id SET NOT NULL;
|
||||
ALTER TABLE reaction ALTER COLUMN dc_thread_id SET NOT NULL;
|
||||
@@ -10,8 +10,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UserPortalTypeDM = "dm"
|
||||
UserPortalTypeGuild = "guild"
|
||||
UserPortalTypeDM = "dm"
|
||||
UserPortalTypeGuild = "guild"
|
||||
UserPortalTypeThread = "thread"
|
||||
)
|
||||
|
||||
type UserPortal struct {
|
||||
@@ -62,6 +63,16 @@ func (u *User) IsInSpace(discordID string) (isIn bool) {
|
||||
return
|
||||
}
|
||||
|
||||
func (u *User) IsInPortal(discordID string) (isIn bool) {
|
||||
query := `SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_mxid=$1 AND discord_id=$2)`
|
||||
err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err)
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (u *User) MarkInPortal(portal UserPortal) {
|
||||
query := `
|
||||
INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space)
|
||||
@@ -88,8 +99,8 @@ func (u *User) MarkNotInPortal(discordID string) {
|
||||
func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal {
|
||||
query := `
|
||||
DELETE FROM user_portal
|
||||
WHERE user_mxid=$1 AND timestamp<$2
|
||||
RETURNING discord_id, type, timestamp, in_space
|
||||
WHERE user_mxid=$1 AND timestamp<$2 AND type IN ('dm', 'guild')
|
||||
RETURNING discord_id, type, timestamp, in_space
|
||||
`
|
||||
rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli())
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user