Break a bunch of stuff
This commit is contained in:
@@ -16,7 +16,7 @@ type AttachmentQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment"
|
||||
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id FROM attachment"
|
||||
)
|
||||
|
||||
func (aq *AttachmentQuery) New() *Attachment {
|
||||
@@ -77,58 +77,50 @@ type Attachment struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
Channel PortalKey
|
||||
Channel PortalKey
|
||||
ThreadID string
|
||||
MessageID string
|
||||
ID string
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
DiscordMessageID string
|
||||
DiscordAttachmentID string
|
||||
MXID id.EventID
|
||||
func (a *Attachment) DiscordProtoChannelID() string {
|
||||
if a.ThreadID != "" {
|
||||
return a.ThreadID
|
||||
} else {
|
||||
return a.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
|
||||
err := row.Scan(
|
||||
&a.DiscordAttachmentID, &a.DiscordMessageID,
|
||||
&a.Channel.ChannelID, &a.Channel.Receiver,
|
||||
&a.MXID)
|
||||
|
||||
err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
a.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *Attachment) Insert() {
|
||||
query := "INSERT INTO attachment" +
|
||||
" (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " +
|
||||
" mxid) VALUES ($1, $2, $3, $4, $5);"
|
||||
|
||||
_, err := a.db.Exec(
|
||||
query,
|
||||
a.Channel.ChannelID, a.Channel.Receiver,
|
||||
a.DiscordMessageID, a.DiscordAttachmentID,
|
||||
a.MXID,
|
||||
)
|
||||
|
||||
query := `
|
||||
INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, mxid)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`
|
||||
_, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID)
|
||||
if err != nil {
|
||||
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
|
||||
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attachment) Delete() {
|
||||
query := "DELETE FROM attachment WHERE" +
|
||||
" dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
|
||||
|
||||
_, err := a.db.Exec(
|
||||
query,
|
||||
a.Channel.ChannelID, a.Channel.Receiver,
|
||||
a.DiscordAttachmentID,
|
||||
)
|
||||
|
||||
query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3`
|
||||
_, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID)
|
||||
if err != nil {
|
||||
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
|
||||
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database/upgrades"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database/upgrades"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
@@ -18,6 +19,7 @@ type Database struct {
|
||||
Portal *PortalQuery
|
||||
Puppet *PuppetQuery
|
||||
Message *MessageQuery
|
||||
Thread *ThreadQuery
|
||||
Reaction *ReactionQuery
|
||||
Attachment *AttachmentQuery
|
||||
Emoji *EmojiQuery
|
||||
@@ -54,6 +56,10 @@ func New(baseDB *dbutil.Database) *Database {
|
||||
db: db,
|
||||
log: db.Log.Sub("Message"),
|
||||
}
|
||||
db.Thread = &ThreadQuery{
|
||||
db: db,
|
||||
log: db.Log.Sub("Thread"),
|
||||
}
|
||||
db.Reaction = &ReactionQuery{
|
||||
db: db,
|
||||
log: db.Log.Sub("Reaction"),
|
||||
@@ -72,3 +78,10 @@ func New(baseDB *dbutil.Database) *Database {
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func strPtr(val string) *string {
|
||||
if val == "" {
|
||||
return nil
|
||||
}
|
||||
return &val
|
||||
}
|
||||
|
||||
@@ -28,23 +28,16 @@ func (eq *EmojiQuery) New() *Emoji {
|
||||
|
||||
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
|
||||
query := emojiSelect + " WHERE discord_id=$1"
|
||||
|
||||
return eq.get(query, discordID)
|
||||
}
|
||||
|
||||
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
|
||||
query := emojiSelect + " WHERE matrix_url=$1"
|
||||
|
||||
return eq.get(query, matrixURL.String())
|
||||
}
|
||||
|
||||
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
|
||||
row := eq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return eq.New().Scan(row)
|
||||
return eq.New().Scan(eq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
type Emoji struct {
|
||||
@@ -59,18 +52,17 @@ type Emoji struct {
|
||||
|
||||
func (e *Emoji) Scan(row dbutil.Scannable) *Emoji {
|
||||
var matrixURL sql.NullString
|
||||
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
|
||||
|
||||
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
e.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
e.MatrixURL, _ = id.ParseContentURI(matrixURL.String)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -83,6 +75,7 @@ func (e *Emoji) Insert() {
|
||||
|
||||
if err != nil {
|
||||
e.log.Warnfln("Failed to insert emoji %s: %v", e.DiscordID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,17 +85,15 @@ func (e *Emoji) Delete() {
|
||||
_, err := e.db.Exec(query, e.DiscordID)
|
||||
if err != nil {
|
||||
e.log.Warnfln("Failed to delete emoji %s: %v", e.DiscordID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Emoji) APIName() string {
|
||||
if e.DiscordID != "" && e.DiscordName != "" {
|
||||
return e.DiscordName + ":" + e.DiscordID
|
||||
}
|
||||
|
||||
if e.DiscordName != "" {
|
||||
} else if e.DiscordName != "" {
|
||||
return e.DiscordName
|
||||
}
|
||||
|
||||
return e.DiscordID
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
@@ -16,7 +16,7 @@ type GuildQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
|
||||
guildSelect = "SELECT dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild"
|
||||
)
|
||||
|
||||
func (gq *GuildQuery) New() *Guild {
|
||||
@@ -26,109 +26,100 @@ func (gq *GuildQuery) New() *Guild {
|
||||
}
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
|
||||
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
|
||||
|
||||
row := gq.db.QueryRow(query, discordID, guildID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return gq.New().Scan(row)
|
||||
func (gq *GuildQuery) GetByID(dcid string) *Guild {
|
||||
query := guildSelect + " WHERE dcid=$1"
|
||||
return gq.New().Scan(gq.db.QueryRow(query, dcid))
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
|
||||
query := guildSelect + " WHERE discord_id=$1"
|
||||
func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild {
|
||||
query := guildSelect + " WHERE mxid=$1"
|
||||
return gq.New().Scan(gq.db.QueryRow(query, mxid))
|
||||
}
|
||||
|
||||
rows, err := gq.db.Query(query, discordID)
|
||||
if err != nil || rows == nil {
|
||||
func (gq *GuildQuery) GetAll() []*Guild {
|
||||
rows, err := gq.db.Query(guildSelect)
|
||||
if err != nil {
|
||||
gq.log.Errorln("Failed to query guilds:", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
guilds := []*Guild{}
|
||||
var guilds []*Guild
|
||||
for rows.Next() {
|
||||
guilds = append(guilds, gq.New().Scan(rows))
|
||||
guild := gq.New().Scan(rows)
|
||||
if guild != nil {
|
||||
guilds = append(guilds, guild)
|
||||
}
|
||||
}
|
||||
|
||||
return guilds
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
|
||||
// We need this interface slice because a variadic function can't mix
|
||||
// arguements with a `...` expanded slice.
|
||||
args := []interface{}{discordID}
|
||||
|
||||
nGuilds := len(guilds)
|
||||
if nGuilds <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
gq.log.Debugfln("prunning guilds for %s", discordID)
|
||||
|
||||
// Build the in query
|
||||
inQuery := "$2"
|
||||
for i := 1; i < nGuilds; i++ {
|
||||
inQuery += fmt.Sprintf(", $%d", i+2)
|
||||
}
|
||||
|
||||
// Add the arguements for the build query
|
||||
for _, guildID := range guilds {
|
||||
args = append(args, guildID)
|
||||
}
|
||||
|
||||
// Now remove any guilds that the user has left.
|
||||
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
|
||||
inQuery + ")"
|
||||
|
||||
_, err := gq.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
type Guild struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
DiscordID string
|
||||
GuildID string
|
||||
GuildName string
|
||||
Bridge bool
|
||||
ID string
|
||||
MXID id.RoomID
|
||||
Name string
|
||||
NameSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
|
||||
AutoBridgeChannels bool
|
||||
}
|
||||
|
||||
func (g *Guild) Scan(row dbutil.Scannable) *Guild {
|
||||
err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge)
|
||||
var mxid sql.NullString
|
||||
var avatarURL string
|
||||
err := row.Scan(&g.ID, &mxid, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
g.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
g.MXID = id.RoomID(mxid.String)
|
||||
g.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *Guild) Upsert() {
|
||||
query := "INSERT INTO guild" +
|
||||
" (discord_id, guild_id, guild_name, bridge)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT(discord_id, guild_id)" +
|
||||
" DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge"
|
||||
|
||||
_, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge)
|
||||
|
||||
func (g *Guild) mxidPtr() *id.RoomID {
|
||||
if g.MXID != "" {
|
||||
return &g.MXID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (g *Guild) Insert() {
|
||||
query := `
|
||||
INSERT INTO guild (dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err)
|
||||
g.log.Warnfln("Failed to insert %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guild) Update() {
|
||||
query := `
|
||||
UPDATE guild SET mxid=$1, name=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, auto_bridge_channels=$7
|
||||
WHERE dcid=$8
|
||||
`
|
||||
_, err := g.db.Exec(query, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to update %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guild) Delete() {
|
||||
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2"
|
||||
|
||||
_, err := g.db.Exec(query, g.DiscordID, g.GuildID)
|
||||
|
||||
_, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err)
|
||||
g.log.Warnfln("Failed to delete %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ type MessageQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
|
||||
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
|
||||
)
|
||||
|
||||
func (mq *MessageQuery) New() *Message {
|
||||
@@ -45,14 +45,12 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
|
||||
|
||||
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
|
||||
}
|
||||
|
||||
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
|
||||
if row == nil {
|
||||
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return mq.New().Scan(row)
|
||||
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
|
||||
@@ -74,36 +72,54 @@ type Message struct {
|
||||
Channel PortalKey
|
||||
SenderID string
|
||||
Timestamp time.Time
|
||||
ThreadID string
|
||||
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
func (m *Message) DiscordProtoChannelID() string {
|
||||
if m.ThreadID != "" {
|
||||
return m.ThreadID
|
||||
} else {
|
||||
return m.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Scan(row dbutil.Scannable) *Message {
|
||||
var ts int64
|
||||
var threadID sql.NullString
|
||||
|
||||
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
|
||||
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
m.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if ts != 0 {
|
||||
m.Timestamp = time.Unix(ts, 0)
|
||||
m.Timestamp = time.UnixMilli(ts)
|
||||
}
|
||||
m.ThreadID = threadID.String
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Message) Insert() {
|
||||
query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
query := `
|
||||
INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
|
||||
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
|
||||
_, err := m.db.Exec(query,
|
||||
m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
|
||||
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
|
||||
|
||||
if err != nil {
|
||||
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,5 +130,6 @@ func (m *Message) Delete() {
|
||||
|
||||
if err != nil {
|
||||
m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,30 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," +
|
||||
" avatar_url, type, other_user_id, first_event_id, encrypted" +
|
||||
portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " +
|
||||
" mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" +
|
||||
" FROM portal"
|
||||
)
|
||||
|
||||
type PortalKey struct {
|
||||
ChannelID string
|
||||
Receiver string
|
||||
}
|
||||
|
||||
func NewPortalKey(channelID, receiver string) PortalKey {
|
||||
return PortalKey{
|
||||
ChannelID: channelID,
|
||||
Receiver: receiver,
|
||||
}
|
||||
}
|
||||
|
||||
func (key PortalKey) String() string {
|
||||
if key.Receiver == "" {
|
||||
return key.ChannelID
|
||||
}
|
||||
return key.ChannelID + "-" + key.Receiver
|
||||
}
|
||||
|
||||
type PortalQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
@@ -34,7 +53,7 @@ func (pq *PortalQuery) GetAll() []*Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
|
||||
return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
|
||||
@@ -67,12 +86,7 @@ func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
|
||||
row := pq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pq.New().Scan(row)
|
||||
return pq.New().Scan(pq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
type Portal struct {
|
||||
@@ -82,78 +96,87 @@ type Portal struct {
|
||||
Key PortalKey
|
||||
Type discordgo.ChannelType
|
||||
OtherUserID string
|
||||
ParentID string
|
||||
GuildID string
|
||||
|
||||
MXID id.RoomID
|
||||
|
||||
Name string
|
||||
NameSet bool
|
||||
Topic string
|
||||
TopicSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
Encrypted bool
|
||||
InSpace id.RoomID
|
||||
|
||||
FirstEventID id.EventID
|
||||
}
|
||||
|
||||
func (p *Portal) Scan(row dbutil.Scannable) *Portal {
|
||||
var mxid, avatarURL, firstEventID sql.NullString
|
||||
var typ sql.NullInt32
|
||||
var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
|
||||
var chanType int32
|
||||
var avatarURL string
|
||||
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
|
||||
&p.Encrypted)
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
|
||||
&mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
|
||||
&p.Encrypted, &p.InSpace, &firstEventID)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
p.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.MXID = id.RoomID(mxid.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
|
||||
p.Type = discordgo.ChannelType(typ.Int32)
|
||||
p.OtherUserID = otherUserID.String
|
||||
p.GuildID = guildID.String
|
||||
p.ParentID = parentID.String
|
||||
p.Type = discordgo.ChannelType(chanType)
|
||||
p.FirstEventID = id.EventID(firstEventID.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Portal) mxidPtr() *id.RoomID {
|
||||
if p.MXID != "" {
|
||||
return &p.MXID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Portal) Insert() {
|
||||
query := "INSERT INTO portal" +
|
||||
" (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
|
||||
" type, other_user_id, first_event_id, encrypted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
|
||||
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
|
||||
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
|
||||
p.FirstEventID.String(), p.Encrypted)
|
||||
query := `
|
||||
INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
|
||||
name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
|
||||
encrypted, in_space, first_event_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
|
||||
`
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
|
||||
strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
|
||||
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
p.Encrypted, p.InSpace, p.FirstEventID.String())
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Portal) Update() {
|
||||
query := "UPDATE portal SET" +
|
||||
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
|
||||
" other_user_id=$7, first_event_id=$8, encrypted=$9" +
|
||||
" WHERE dcid=$10 AND receiver=$11"
|
||||
|
||||
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
|
||||
p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
|
||||
p.Encrypted,
|
||||
query := `
|
||||
UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
|
||||
name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12,
|
||||
encrypted=$13, in_space=$14, first_event_id=$15
|
||||
WHERE dcid=$16 AND receiver=$17
|
||||
`
|
||||
_, err := p.db.Exec(query,
|
||||
p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
|
||||
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
p.Encrypted, p.InSpace, p.FirstEventID.String(),
|
||||
p.Key.ChannelID, p.Key.Receiver)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to update %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,5 +185,6 @@ func (p *Portal) Delete() {
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package database
|
||||
|
||||
type PortalKey struct {
|
||||
ChannelID string
|
||||
Receiver string
|
||||
}
|
||||
|
||||
func NewPortalKey(channelID, receiver string) PortalKey {
|
||||
return PortalKey{
|
||||
ChannelID: channelID,
|
||||
Receiver: receiver,
|
||||
}
|
||||
}
|
||||
|
||||
func (key PortalKey) String() string {
|
||||
if key.Receiver == "" {
|
||||
return key.ChannelID
|
||||
}
|
||||
return key.ChannelID + "-" + key.Receiver
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
puppetSelect = "SELECT id, display_name, avatar, avatar_url," +
|
||||
puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
|
||||
" custom_mxid, access_token, next_batch" +
|
||||
" FROM puppet "
|
||||
)
|
||||
@@ -36,12 +36,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
|
||||
row := pq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pq.New().Scan(row)
|
||||
return pq.New().Scan(pq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) GetAll() []*Puppet {
|
||||
@@ -59,7 +54,7 @@ func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
puppets := []*Puppet{}
|
||||
var puppets []*Puppet
|
||||
for rows.Next() {
|
||||
puppets = append(puppets, pq.New().Scan(rows))
|
||||
}
|
||||
@@ -71,11 +66,12 @@ type Puppet struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
ID string
|
||||
DisplayName string
|
||||
|
||||
ID string
|
||||
Name string
|
||||
NameSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
|
||||
CustomMXID id.UserID
|
||||
AccessToken string
|
||||
@@ -83,24 +79,22 @@ type Puppet struct {
|
||||
}
|
||||
|
||||
func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
var did, displayName, avatar, avatarURL sql.NullString
|
||||
var avatarURL string
|
||||
var customMXID, accessToken, nextBatch sql.NullString
|
||||
|
||||
err := row.Scan(&did, &displayName, &avatar, &avatarURL,
|
||||
err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet,
|
||||
&customMXID, &accessToken, &nextBatch)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
p.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ID = did.String
|
||||
p.DisplayName = displayName.String
|
||||
p.Avatar = avatar.String
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
p.CustomMXID = id.UserID(customMXID.String)
|
||||
p.AccessToken = accessToken.String
|
||||
p.NextBatch = nextBatch.String
|
||||
@@ -109,31 +103,31 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
}
|
||||
|
||||
func (p *Puppet) Insert() {
|
||||
query := "INSERT INTO puppet" +
|
||||
" (id, display_name, avatar, avatar_url," +
|
||||
" custom_mxid, access_token, next_batch)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
|
||||
_, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
|
||||
p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
|
||||
p.NextBatch)
|
||||
query := `
|
||||
INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
_, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch))
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Puppet) Update() {
|
||||
query := "UPDATE puppet" +
|
||||
" SET display_name=$1, avatar=$2, avatar_url=$3, " +
|
||||
" custom_mxid=$4, access_token=$5, next_batch=$6" +
|
||||
" WHERE id=$7"
|
||||
|
||||
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
|
||||
p.CustomMXID, p.AccessToken, p.NextBatch,
|
||||
query := `
|
||||
UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5,
|
||||
custom_mxid=$6, access_token=$7, next_batch=$8
|
||||
WHERE id=$9
|
||||
`
|
||||
_, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch),
|
||||
p.ID)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to update %s: %v", p.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type ReactionQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction"
|
||||
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction"
|
||||
)
|
||||
|
||||
func (rq *ReactionQuery) New() *Reaction {
|
||||
@@ -75,15 +75,17 @@ type Reaction struct {
|
||||
MessageID string
|
||||
Sender string
|
||||
EmojiName string
|
||||
ThreadID string
|
||||
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID)
|
||||
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
r.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -91,14 +93,23 @@ func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Reaction) DiscordProtoChannelID() string {
|
||||
if r.ThreadID != "" {
|
||||
return r.ThreadID
|
||||
} else {
|
||||
return r.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reaction) Insert() {
|
||||
query := `
|
||||
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid)
|
||||
VALUES($1, $2, $3, $4, $5, $6)
|
||||
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID)
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID)
|
||||
if err != nil {
|
||||
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,5 +118,6 @@ func (r *Reaction) Delete() {
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName)
|
||||
if err != nil {
|
||||
r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
105
database/thread.go
Normal file
105
database/thread.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type ThreadQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
const (
|
||||
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
|
||||
)
|
||||
|
||||
func (tq *ThreadQuery) New() *Thread {
|
||||
return &Thread{
|
||||
db: tq.db,
|
||||
log: tq.log,
|
||||
}
|
||||
}
|
||||
|
||||
func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
|
||||
query := threadSelect + " WHERE dcid=$1"
|
||||
|
||||
row := tq.db.QueryRow(query, discordID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
|
||||
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
|
||||
//
|
||||
// row := tq.db.QueryRow(query, channelID, messageID)
|
||||
// if row == nil {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// return tq.New().Scan(row)
|
||||
//}
|
||||
|
||||
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
|
||||
query := threadSelect + " WHERE root_msg_mxid=$1"
|
||||
|
||||
row := tq.db.QueryRow(query, mxid)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
ID string
|
||||
ParentID string
|
||||
|
||||
RootDiscordID string
|
||||
RootMXID id.EventID
|
||||
}
|
||||
|
||||
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
|
||||
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Thread) Insert() {
|
||||
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
|
||||
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) Delete() {
|
||||
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID)
|
||||
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,29 +1,65 @@
|
||||
-- v0 -> v2: Latest revision
|
||||
-- v0 -> v3: Latest revision
|
||||
|
||||
CREATE TABLE guild (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
auto_bridge_channels BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE portal (
|
||||
dcid TEXT,
|
||||
receiver TEXT,
|
||||
other_user_id TEXT,
|
||||
type INTEGER,
|
||||
type INTEGER NOT NULL,
|
||||
|
||||
dc_guild_id TEXT,
|
||||
dc_parent_id TEXT,
|
||||
-- This is not accessed by the bridge, it's only used for the portal parent foreign key.
|
||||
-- Only guild channels have parents, but only DMs have a receiver field.
|
||||
dc_parent_receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
encrypted BOOLEAN NOT NULL DEFAULT false,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
encrypted BOOLEAN NOT NULL,
|
||||
in_space TEXT NOT NULL,
|
||||
|
||||
first_event_id TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (dcid, receiver)
|
||||
PRIMARY KEY (dcid, receiver),
|
||||
CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE,
|
||||
CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE thread (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
parent_chan_id TEXT NOT NULL,
|
||||
root_msg_dcid TEXT NOT NULL,
|
||||
root_msg_mxid TEXT NOT NULL,
|
||||
-- This is also not accessed by the bridge.
|
||||
receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE puppet (
|
||||
id TEXT PRIMARY KEY,
|
||||
|
||||
name TEXT,
|
||||
avatar TEXT,
|
||||
avatar_url TEXT,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
custom_mxid TEXT,
|
||||
access_token TEXT,
|
||||
@@ -34,9 +70,20 @@ CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
dcid TEXT UNIQUE,
|
||||
|
||||
discord_token TEXT,
|
||||
management_room TEXT,
|
||||
space_room TEXT
|
||||
);
|
||||
|
||||
token TEXT
|
||||
CREATE TABLE user_guild (
|
||||
guild_id TEXT,
|
||||
user_mxid TEXT,
|
||||
in_space BOOLEAN NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
|
||||
PRIMARY KEY (guild_id, user_mxid),
|
||||
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
|
||||
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE message (
|
||||
@@ -45,6 +92,7 @@ CREATE TABLE message (
|
||||
dc_chan_receiver TEXT,
|
||||
dc_sender TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -58,6 +106,7 @@ CREATE TABLE reaction (
|
||||
dc_msg_id TEXT,
|
||||
dc_sender TEXT,
|
||||
dc_emoji_name TEXT,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -70,6 +119,7 @@ CREATE TABLE attachment (
|
||||
dc_msg_id TEXT,
|
||||
dc_chan_id TEXT,
|
||||
dc_chan_receiver TEXT,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -82,11 +132,3 @@ CREATE TABLE emoji (
|
||||
discord_name TEXT,
|
||||
matrix_url TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE guild (
|
||||
discord_id TEXT NOT NULL,
|
||||
guild_id TEXT NOT NULL,
|
||||
guild_name TEXT NOT NULL,
|
||||
bridge BOOLEAN DEFAULT FALSE,
|
||||
PRIMARY KEY(discord_id, guild_id)
|
||||
);
|
||||
|
||||
72
database/upgrades/03-spaces.sql
Normal file
72
database/upgrades/03-spaces.sql
Normal file
@@ -0,0 +1,72 @@
|
||||
-- v3: Store portal parent metadata for spaces
|
||||
DROP TABLE guild;
|
||||
|
||||
CREATE TABLE guild (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
auto_bridge_channels BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE user_guild (
|
||||
guild_id TEXT,
|
||||
user_mxid TEXT,
|
||||
in_space BOOLEAN NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
|
||||
PRIMARY KEY (guild_id, user_mxid),
|
||||
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
|
||||
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
ALTER TABLE portal ADD COLUMN dc_guild_id TEXT;
|
||||
ALTER TABLE portal ADD COLUMN dc_parent_id TEXT;
|
||||
ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE;
|
||||
ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE;
|
||||
DELETE FROM portal WHERE type IS NULL;
|
||||
-- only: postgres
|
||||
ALTER TABLE portal ALTER COLUMN type SET NOT NULL;
|
||||
|
||||
ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
|
||||
-- only: postgres for next 5 lines
|
||||
ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT;
|
||||
|
||||
ALTER TABLE puppet RENAME COLUMN display_name TO name;
|
||||
ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
|
||||
-- only: postgres for next 2 lines
|
||||
ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT;
|
||||
ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT;
|
||||
|
||||
ALTER TABLE "user" ADD COLUMN space_room TEXT;
|
||||
ALTER TABLE "user" RENAME COLUMN token TO discord_token;
|
||||
|
||||
UPDATE message SET timestamp=timestamp*1000;
|
||||
|
||||
CREATE TABLE thread (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
parent_chan_id TEXT NOT NULL,
|
||||
root_msg_dcid TEXT NOT NULL,
|
||||
root_msg_mxid TEXT NOT NULL,
|
||||
-- This is also not accessed by the bridge.
|
||||
receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
ALTER TABLE message ADD COLUMN dc_thread_id TEXT;
|
||||
ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT;
|
||||
ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT;
|
||||
176
database/user.go
176
database/user.go
@@ -2,6 +2,8 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
@@ -22,38 +24,32 @@ func (uq *UserQuery) New() *User {
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
|
||||
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
|
||||
row := uq.db.QueryRow(query, userID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return uq.New().Scan(row)
|
||||
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1`
|
||||
return uq.New().Scan(uq.db.QueryRow(query, userID))
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByID(id string) *User {
|
||||
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
|
||||
row := uq.db.QueryRow(query, id)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return uq.New().Scan(row)
|
||||
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1`
|
||||
return uq.New().Scan(uq.db.QueryRow(query, id))
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetAll() []*User {
|
||||
rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
|
||||
func (uq *UserQuery) GetAllWithToken() []*User {
|
||||
query := `
|
||||
SELECT mxid, dcid, discord_token, management_room, space_room
|
||||
FROM "user" WHERE discord_token IS NOT NULL
|
||||
`
|
||||
rows, err := uq.db.Query(query)
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
users := []*User{}
|
||||
var users []*User
|
||||
for rows.Next() {
|
||||
users = append(users, uq.New().Scan(rows))
|
||||
user := uq.New().Scan(rows)
|
||||
if user != nil {
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
@@ -61,80 +57,114 @@ type User struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
MXID id.UserID
|
||||
ID string
|
||||
|
||||
MXID id.UserID
|
||||
DiscordID string
|
||||
DiscordToken string
|
||||
ManagementRoom id.RoomID
|
||||
SpaceRoom id.RoomID
|
||||
}
|
||||
|
||||
Token string
|
||||
type UserGuild struct {
|
||||
GuildID string
|
||||
Timestamp time.Time
|
||||
InSpace bool
|
||||
}
|
||||
|
||||
func (u *User) GetGuilds() []UserGuild {
|
||||
res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID)
|
||||
if err != nil {
|
||||
u.log.Errorln("Failed to get guilds:", err)
|
||||
panic(err)
|
||||
return nil
|
||||
}
|
||||
var guilds []UserGuild
|
||||
for res.Next() {
|
||||
var guild UserGuild
|
||||
var ts int64
|
||||
err = res.Scan(&guild.GuildID, &ts, &guild.InSpace)
|
||||
if err != nil {
|
||||
u.log.Errorln("Error scanning user guild:", err)
|
||||
panic(err)
|
||||
} else {
|
||||
guild.Timestamp = time.UnixMilli(ts)
|
||||
guilds = append(guilds, guild)
|
||||
}
|
||||
}
|
||||
return guilds
|
||||
}
|
||||
|
||||
func (u *User) IsInSpace(guildID string) (isIn bool) {
|
||||
query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
|
||||
err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err)
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (u *User) MarkInGuild(guild UserGuild) {
|
||||
query := `
|
||||
INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (guild_id, user_mxid) DO UPDATE
|
||||
SET timestamp=excluded.timestamp, in_space=excluded.in_space
|
||||
`
|
||||
_, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace)
|
||||
if err != nil {
|
||||
u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) MarkNotInGuild(guildID string) {
|
||||
query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
|
||||
_, err := u.db.Exec(query, u.MXID, guildID)
|
||||
if err != nil {
|
||||
u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) PruneGuildList(beforeTS time.Time) {
|
||||
_, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli())
|
||||
if err != nil {
|
||||
u.log.Errorln("Failed to prune user guild list:", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) Scan(row dbutil.Scannable) *User {
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token)
|
||||
var discordID, managementRoom, spaceRoom, discordToken sql.NullString
|
||||
err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
u.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if token.Valid {
|
||||
u.Token = token.String
|
||||
}
|
||||
|
||||
if discordID.Valid {
|
||||
u.ID = discordID.String
|
||||
}
|
||||
|
||||
u.DiscordID = discordID.String
|
||||
u.DiscordToken = discordToken.String
|
||||
u.ManagementRoom = id.RoomID(managementRoom.String)
|
||||
u.SpaceRoom = id.RoomID(spaceRoom.String)
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *User) Insert() {
|
||||
query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
if u.Token != "" {
|
||||
token.String = u.Token
|
||||
token.Valid = true
|
||||
}
|
||||
|
||||
if u.ID != "" {
|
||||
discordID.String = u.ID
|
||||
discordID.Valid = true
|
||||
}
|
||||
|
||||
_, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token)
|
||||
|
||||
query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
|
||||
_, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)))
|
||||
if err != nil {
|
||||
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) Update() {
|
||||
query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
|
||||
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
if u.Token != "" {
|
||||
token.String = u.Token
|
||||
token.Valid = true
|
||||
}
|
||||
|
||||
if u.ID != "" {
|
||||
discordID.String = u.ID
|
||||
discordID.Valid = true
|
||||
}
|
||||
|
||||
_, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID)
|
||||
|
||||
query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5`
|
||||
_, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID)
|
||||
if err != nil {
|
||||
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user