Break a bunch of stuff

This commit is contained in:
Tulir Asokan
2022-05-28 23:03:24 +03:00
parent 575f684a54
commit 91dbc83b5d
26 changed files with 1824 additions and 1002 deletions

View File

@@ -16,7 +16,7 @@ type AttachmentQuery struct {
}
const (
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment"
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id FROM attachment"
)
func (aq *AttachmentQuery) New() *Attachment {
@@ -77,58 +77,50 @@ type Attachment struct {
db *Database
log log.Logger
Channel PortalKey
Channel PortalKey
ThreadID string
MessageID string
ID string
MXID id.EventID
}
DiscordMessageID string
DiscordAttachmentID string
MXID id.EventID
func (a *Attachment) DiscordProtoChannelID() string {
if a.ThreadID != "" {
return a.ThreadID
} else {
return a.Channel.ChannelID
}
}
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
err := row.Scan(
&a.DiscordAttachmentID, &a.DiscordMessageID,
&a.Channel.ChannelID, &a.Channel.Receiver,
&a.MXID)
err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
a.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
return a
}
func (a *Attachment) Insert() {
query := "INSERT INTO attachment" +
" (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " +
" mxid) VALUES ($1, $2, $3, $4, $5);"
_, err := a.db.Exec(
query,
a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordMessageID, a.DiscordAttachmentID,
a.MXID,
)
query := `
INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, mxid)
VALUES ($1, $2, $3, $4, $5)
`
_, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID)
if err != nil {
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}
func (a *Attachment) Delete() {
query := "DELETE FROM attachment WHERE" +
" dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
_, err := a.db.Exec(
query,
a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordAttachmentID,
)
query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3`
_, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID)
if err != nil {
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}

View File

@@ -7,8 +7,9 @@ import (
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"go.mau.fi/mautrix-discord/database/upgrades"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/mautrix-discord/database/upgrades"
)
type Database struct {
@@ -18,6 +19,7 @@ type Database struct {
Portal *PortalQuery
Puppet *PuppetQuery
Message *MessageQuery
Thread *ThreadQuery
Reaction *ReactionQuery
Attachment *AttachmentQuery
Emoji *EmojiQuery
@@ -54,6 +56,10 @@ func New(baseDB *dbutil.Database) *Database {
db: db,
log: db.Log.Sub("Message"),
}
db.Thread = &ThreadQuery{
db: db,
log: db.Log.Sub("Thread"),
}
db.Reaction = &ReactionQuery{
db: db,
log: db.Log.Sub("Reaction"),
@@ -72,3 +78,10 @@ func New(baseDB *dbutil.Database) *Database {
}
return db
}
func strPtr(val string) *string {
if val == "" {
return nil
}
return &val
}

View File

@@ -28,23 +28,16 @@ func (eq *EmojiQuery) New() *Emoji {
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
query := emojiSelect + " WHERE discord_id=$1"
return eq.get(query, discordID)
}
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
query := emojiSelect + " WHERE matrix_url=$1"
return eq.get(query, matrixURL.String())
}
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
row := eq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return eq.New().Scan(row)
return eq.New().Scan(eq.db.QueryRow(query, args...))
}
type Emoji struct {
@@ -59,18 +52,17 @@ type Emoji struct {
func (e *Emoji) Scan(row dbutil.Scannable) *Emoji {
var matrixURL sql.NullString
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
e.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
e.MatrixURL, _ = id.ParseContentURI(matrixURL.String)
return e
}
@@ -83,6 +75,7 @@ func (e *Emoji) Insert() {
if err != nil {
e.log.Warnfln("Failed to insert emoji %s: %v", e.DiscordID, err)
panic(err)
}
}
@@ -92,17 +85,15 @@ func (e *Emoji) Delete() {
_, err := e.db.Exec(query, e.DiscordID)
if err != nil {
e.log.Warnfln("Failed to delete emoji %s: %v", e.DiscordID, err)
panic(err)
}
}
func (e *Emoji) APIName() string {
if e.DiscordID != "" && e.DiscordName != "" {
return e.DiscordName + ":" + e.DiscordID
}
if e.DiscordName != "" {
} else if e.DiscordName != "" {
return e.DiscordName
}
return e.DiscordID
}

View File

@@ -3,9 +3,9 @@ package database
import (
"database/sql"
"errors"
"fmt"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
@@ -16,7 +16,7 @@ type GuildQuery struct {
}
const (
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
guildSelect = "SELECT dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild"
)
func (gq *GuildQuery) New() *Guild {
@@ -26,109 +26,100 @@ func (gq *GuildQuery) New() *Guild {
}
}
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
row := gq.db.QueryRow(query, discordID, guildID)
if row == nil {
return nil
}
return gq.New().Scan(row)
func (gq *GuildQuery) GetByID(dcid string) *Guild {
query := guildSelect + " WHERE dcid=$1"
return gq.New().Scan(gq.db.QueryRow(query, dcid))
}
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
query := guildSelect + " WHERE discord_id=$1"
func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild {
query := guildSelect + " WHERE mxid=$1"
return gq.New().Scan(gq.db.QueryRow(query, mxid))
}
rows, err := gq.db.Query(query, discordID)
if err != nil || rows == nil {
func (gq *GuildQuery) GetAll() []*Guild {
rows, err := gq.db.Query(guildSelect)
if err != nil {
gq.log.Errorln("Failed to query guilds:", err)
return nil
}
guilds := []*Guild{}
var guilds []*Guild
for rows.Next() {
guilds = append(guilds, gq.New().Scan(rows))
guild := gq.New().Scan(rows)
if guild != nil {
guilds = append(guilds, guild)
}
}
return guilds
}
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
// We need this interface slice because a variadic function can't mix
// arguements with a `...` expanded slice.
args := []interface{}{discordID}
nGuilds := len(guilds)
if nGuilds <= 0 {
return
}
gq.log.Debugfln("prunning guilds for %s", discordID)
// Build the in query
inQuery := "$2"
for i := 1; i < nGuilds; i++ {
inQuery += fmt.Sprintf(", $%d", i+2)
}
// Add the arguements for the build query
for _, guildID := range guilds {
args = append(args, guildID)
}
// Now remove any guilds that the user has left.
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
inQuery + ")"
_, err := gq.db.Exec(query, args...)
if err != nil {
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
}
}
type Guild struct {
db *Database
log log.Logger
DiscordID string
GuildID string
GuildName string
Bridge bool
ID string
MXID id.RoomID
Name string
NameSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
AutoBridgeChannels bool
}
func (g *Guild) Scan(row dbutil.Scannable) *Guild {
err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge)
var mxid sql.NullString
var avatarURL string
err := row.Scan(&g.ID, &mxid, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
g.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
g.MXID = id.RoomID(mxid.String)
g.AvatarURL, _ = id.ParseContentURI(avatarURL)
return g
}
func (g *Guild) Upsert() {
query := "INSERT INTO guild" +
" (discord_id, guild_id, guild_name, bridge)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT(discord_id, guild_id)" +
" DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge"
_, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge)
func (g *Guild) mxidPtr() *id.RoomID {
if g.MXID != "" {
return &g.MXID
}
return nil
}
func (g *Guild) Insert() {
query := `
INSERT INTO guild (dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels)
if err != nil {
g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err)
g.log.Warnfln("Failed to insert %s: %v", g.ID, err)
panic(err)
}
}
func (g *Guild) Update() {
query := `
UPDATE guild SET mxid=$1, name=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, auto_bridge_channels=$7
WHERE dcid=$8
`
_, err := g.db.Exec(query, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID)
if err != nil {
g.log.Warnfln("Failed to update %s: %v", g.ID, err)
panic(err)
}
}
func (g *Guild) Delete() {
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2"
_, err := g.db.Exec(query, g.DiscordID, g.GuildID)
_, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID)
if err != nil {
g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err)
g.log.Warnfln("Failed to delete %s: %v", g.ID, err)
panic(err)
}
}

View File

@@ -17,7 +17,7 @@ type MessageQuery struct {
}
const (
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
)
func (mq *MessageQuery) New() *Message {
@@ -45,14 +45,12 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
}
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
if row == nil {
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
return nil
}
return mq.New().Scan(row)
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
}
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
@@ -74,36 +72,54 @@ type Message struct {
Channel PortalKey
SenderID string
Timestamp time.Time
ThreadID string
MXID id.EventID
}
func (m *Message) DiscordProtoChannelID() string {
if m.ThreadID != "" {
return m.ThreadID
} else {
return m.Channel.ChannelID
}
}
func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64
var threadID sql.NullString
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
if ts != 0 {
m.Timestamp = time.Unix(ts, 0)
m.Timestamp = time.UnixMilli(ts)
}
m.ThreadID = threadID.String
return m
}
func (m *Message) Insert() {
query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
query := `
INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
_, err := m.db.Exec(query,
m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
panic(err)
}
}
@@ -114,5 +130,6 @@ func (m *Message) Delete() {
if err != nil {
m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
panic(err)
}
}

View File

@@ -12,11 +12,30 @@ import (
)
const (
portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," +
" avatar_url, type, other_user_id, first_event_id, encrypted" +
portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " +
" mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" +
" FROM portal"
)
type PortalKey struct {
ChannelID string
Receiver string
}
func NewPortalKey(channelID, receiver string) PortalKey {
return PortalKey{
ChannelID: channelID,
Receiver: receiver,
}
}
func (key PortalKey) String() string {
if key.Receiver == "" {
return key.ChannelID
}
return key.ChannelID + "-" + key.Receiver
}
type PortalQuery struct {
db *Database
log log.Logger
@@ -34,7 +53,7 @@ func (pq *PortalQuery) GetAll() []*Portal {
}
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
@@ -67,12 +86,7 @@ func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
return pq.New().Scan(pq.db.QueryRow(query, args...))
}
type Portal struct {
@@ -82,78 +96,87 @@ type Portal struct {
Key PortalKey
Type discordgo.ChannelType
OtherUserID string
ParentID string
GuildID string
MXID id.RoomID
Name string
NameSet bool
Topic string
TopicSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
Encrypted bool
InSpace id.RoomID
FirstEventID id.EventID
}
func (p *Portal) Scan(row dbutil.Scannable) *Portal {
var mxid, avatarURL, firstEventID sql.NullString
var typ sql.NullInt32
var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
var chanType int32
var avatarURL string
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
&p.Encrypted)
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
&mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
&p.Encrypted, &p.InSpace, &firstEventID)
if err != nil {
if err != sql.ErrNoRows {
p.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
p.MXID = id.RoomID(mxid.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.Type = discordgo.ChannelType(typ.Int32)
p.OtherUserID = otherUserID.String
p.GuildID = guildID.String
p.ParentID = parentID.String
p.Type = discordgo.ChannelType(chanType)
p.FirstEventID = id.EventID(firstEventID.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
return p
}
func (p *Portal) mxidPtr() *id.RoomID {
if p.MXID != "" {
return &p.MXID
}
return nil
}
func (p *Portal) Insert() {
query := "INSERT INTO portal" +
" (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
" type, other_user_id, first_event_id, encrypted)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
p.FirstEventID.String(), p.Encrypted)
query := `
INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, in_space, first_event_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
p.Encrypted, p.InSpace, p.FirstEventID.String())
if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
panic(err)
}
}
func (p *Portal) Update() {
query := "UPDATE portal SET" +
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
" other_user_id=$7, first_event_id=$8, encrypted=$9" +
" WHERE dcid=$10 AND receiver=$11"
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
p.Encrypted,
query := `
UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12,
encrypted=$13, in_space=$14, first_event_id=$15
WHERE dcid=$16 AND receiver=$17
`
_, err := p.db.Exec(query,
p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
p.Encrypted, p.InSpace, p.FirstEventID.String(),
p.Key.ChannelID, p.Key.Receiver)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.Key, err)
panic(err)
}
}
@@ -162,5 +185,6 @@ func (p *Portal) Delete() {
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
if err != nil {
p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
panic(err)
}
}

View File

@@ -1,20 +0,0 @@
package database
type PortalKey struct {
ChannelID string
Receiver string
}
func NewPortalKey(channelID, receiver string) PortalKey {
return PortalKey{
ChannelID: channelID,
Receiver: receiver,
}
}
func (key PortalKey) String() string {
if key.Receiver == "" {
return key.ChannelID
}
return key.ChannelID + "-" + key.Receiver
}

View File

@@ -10,7 +10,7 @@ import (
)
const (
puppetSelect = "SELECT id, display_name, avatar, avatar_url," +
puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
" custom_mxid, access_token, next_batch" +
" FROM puppet "
)
@@ -36,12 +36,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
}
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
return pq.New().Scan(pq.db.QueryRow(query, args...))
}
func (pq *PuppetQuery) GetAll() []*Puppet {
@@ -59,7 +54,7 @@ func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
}
defer rows.Close()
puppets := []*Puppet{}
var puppets []*Puppet
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
@@ -71,11 +66,12 @@ type Puppet struct {
db *Database
log log.Logger
ID string
DisplayName string
ID string
Name string
NameSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
CustomMXID id.UserID
AccessToken string
@@ -83,24 +79,22 @@ type Puppet struct {
}
func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
var did, displayName, avatar, avatarURL sql.NullString
var avatarURL string
var customMXID, accessToken, nextBatch sql.NullString
err := row.Scan(&did, &displayName, &avatar, &avatarURL,
err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet,
&customMXID, &accessToken, &nextBatch)
if err != nil {
if err != sql.ErrNoRows {
p.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
p.ID = did.String
p.DisplayName = displayName.String
p.Avatar = avatar.String
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
p.CustomMXID = id.UserID(customMXID.String)
p.AccessToken = accessToken.String
p.NextBatch = nextBatch.String
@@ -109,31 +103,31 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
}
func (p *Puppet) Insert() {
query := "INSERT INTO puppet" +
" (id, display_name, avatar, avatar_url," +
" custom_mxid, access_token, next_batch)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)"
_, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
p.NextBatch)
query := `
INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch))
if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
panic(err)
}
}
func (p *Puppet) Update() {
query := "UPDATE puppet" +
" SET display_name=$1, avatar=$2, avatar_url=$3, " +
" custom_mxid=$4, access_token=$5, next_batch=$6" +
" WHERE id=$7"
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
p.CustomMXID, p.AccessToken, p.NextBatch,
query := `
UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5,
custom_mxid=$6, access_token=$7, next_batch=$8
WHERE id=$9
`
_, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch),
p.ID)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.ID, err)
panic(err)
}
}

View File

@@ -16,7 +16,7 @@ type ReactionQuery struct {
}
const (
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction"
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction"
)
func (rq *ReactionQuery) New() *Reaction {
@@ -75,15 +75,17 @@ type Reaction struct {
MessageID string
Sender string
EmojiName string
ThreadID string
MXID id.EventID
}
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID)
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
r.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
@@ -91,14 +93,23 @@ func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
return r
}
func (r *Reaction) DiscordProtoChannelID() string {
if r.ThreadID != "" {
return r.ThreadID
} else {
return r.Channel.ChannelID
}
}
func (r *Reaction) Insert() {
query := `
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid)
VALUES($1, $2, $3, $4, $5, $6)
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
VALUES($1, $2, $3, $4, $5, $6, $7)
`
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID)
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID)
if err != nil {
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
panic(err)
}
}
@@ -107,5 +118,6 @@ func (r *Reaction) Delete() {
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName)
if err != nil {
r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err)
panic(err)
}
}

105
database/thread.go Normal file
View File

@@ -0,0 +1,105 @@
package database
import (
"database/sql"
"errors"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type ThreadQuery struct {
db *Database
log log.Logger
}
const (
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
)
func (tq *ThreadQuery) New() *Thread {
return &Thread{
db: tq.db,
log: tq.log,
}
}
func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
query := threadSelect + " WHERE dcid=$1"
row := tq.db.QueryRow(query, discordID)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
//
// row := tq.db.QueryRow(query, channelID, messageID)
// if row == nil {
// return nil
// }
//
// return tq.New().Scan(row)
//}
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1"
row := tq.db.QueryRow(query, mxid)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
type Thread struct {
db *Database
log log.Logger
ID string
ParentID string
RootDiscordID string
RootMXID id.EventID
}
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
t.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
return t
}
func (t *Thread) Insert() {
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
if err != nil {
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}
func (t *Thread) Delete() {
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
_, err := t.db.Exec(query, t.ID, t.ParentID)
if err != nil {
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}

View File

@@ -1,29 +1,65 @@
-- v0 -> v2: Latest revision
-- v0 -> v3: Latest revision
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
auto_bridge_channels BOOLEAN NOT NULL
);
CREATE TABLE portal (
dcid TEXT,
receiver TEXT,
other_user_id TEXT,
type INTEGER,
type INTEGER NOT NULL,
dc_guild_id TEXT,
dc_parent_id TEXT,
-- This is not accessed by the bridge, it's only used for the portal parent foreign key.
-- Only guild channels have parents, but only DMs have a receiver field.
dc_parent_receiver TEXT NOT NULL DEFAULT '',
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
encrypted BOOLEAN NOT NULL DEFAULT false,
avatar_set BOOLEAN NOT NULL,
encrypted BOOLEAN NOT NULL,
in_space TEXT NOT NULL,
first_event_id TEXT NOT NULL,
PRIMARY KEY (dcid, receiver)
PRIMARY KEY (dcid, receiver),
CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE,
CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE
);
CREATE TABLE thread (
dcid TEXT PRIMARY KEY,
parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '',
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE puppet (
id TEXT PRIMARY KEY,
name TEXT,
avatar TEXT,
avatar_url TEXT,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
custom_mxid TEXT,
access_token TEXT,
@@ -34,9 +70,20 @@ CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
dcid TEXT UNIQUE,
discord_token TEXT,
management_room TEXT,
space_room TEXT
);
token TEXT
CREATE TABLE user_guild (
guild_id TEXT,
user_mxid TEXT,
in_space BOOLEAN NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (guild_id, user_mxid),
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
);
CREATE TABLE message (
@@ -45,6 +92,7 @@ CREATE TABLE message (
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -58,6 +106,7 @@ CREATE TABLE reaction (
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -70,6 +119,7 @@ CREATE TABLE attachment (
dc_msg_id TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -82,11 +132,3 @@ CREATE TABLE emoji (
discord_name TEXT,
matrix_url TEXT
);
CREATE TABLE guild (
discord_id TEXT NOT NULL,
guild_id TEXT NOT NULL,
guild_name TEXT NOT NULL,
bridge BOOLEAN DEFAULT FALSE,
PRIMARY KEY(discord_id, guild_id)
);

View File

@@ -0,0 +1,72 @@
-- v3: Store portal parent metadata for spaces
DROP TABLE guild;
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
auto_bridge_channels BOOLEAN NOT NULL
);
CREATE TABLE user_guild (
guild_id TEXT,
user_mxid TEXT,
in_space BOOLEAN NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (guild_id, user_mxid),
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
);
ALTER TABLE portal ADD COLUMN dc_guild_id TEXT;
ALTER TABLE portal ADD COLUMN dc_parent_id TEXT;
ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT '';
ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE;
ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE;
DELETE FROM portal WHERE type IS NULL;
-- only: postgres
ALTER TABLE portal ALTER COLUMN type SET NOT NULL;
ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT '';
ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
-- only: postgres for next 5 lines
ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT;
ALTER TABLE puppet RENAME COLUMN display_name TO name;
ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
-- only: postgres for next 2 lines
ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT;
ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT;
ALTER TABLE "user" ADD COLUMN space_room TEXT;
ALTER TABLE "user" RENAME COLUMN token TO discord_token;
UPDATE message SET timestamp=timestamp*1000;
CREATE TABLE thread (
dcid TEXT PRIMARY KEY,
parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '',
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
ALTER TABLE message ADD COLUMN dc_thread_id TEXT;
ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT;
ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT;

View File

@@ -2,6 +2,8 @@ package database
import (
"database/sql"
"errors"
"time"
log "maunium.net/go/maulogger/v2"
@@ -22,38 +24,32 @@ func (uq *UserQuery) New() *User {
}
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
row := uq.db.QueryRow(query, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1`
return uq.New().Scan(uq.db.QueryRow(query, userID))
}
func (uq *UserQuery) GetByID(id string) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
row := uq.db.QueryRow(query, id)
if row == nil {
return nil
}
return uq.New().Scan(row)
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1`
return uq.New().Scan(uq.db.QueryRow(query, id))
}
func (uq *UserQuery) GetAll() []*User {
rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
func (uq *UserQuery) GetAllWithToken() []*User {
query := `
SELECT mxid, dcid, discord_token, management_room, space_room
FROM "user" WHERE discord_token IS NOT NULL
`
rows, err := uq.db.Query(query)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
users := []*User{}
var users []*User
for rows.Next() {
users = append(users, uq.New().Scan(rows))
user := uq.New().Scan(rows)
if user != nil {
users = append(users, user)
}
}
return users
}
@@ -61,80 +57,114 @@ type User struct {
db *Database
log log.Logger
MXID id.UserID
ID string
MXID id.UserID
DiscordID string
DiscordToken string
ManagementRoom id.RoomID
SpaceRoom id.RoomID
}
Token string
type UserGuild struct {
GuildID string
Timestamp time.Time
InSpace bool
}
func (u *User) GetGuilds() []UserGuild {
res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID)
if err != nil {
u.log.Errorln("Failed to get guilds:", err)
panic(err)
return nil
}
var guilds []UserGuild
for res.Next() {
var guild UserGuild
var ts int64
err = res.Scan(&guild.GuildID, &ts, &guild.InSpace)
if err != nil {
u.log.Errorln("Error scanning user guild:", err)
panic(err)
} else {
guild.Timestamp = time.UnixMilli(ts)
guilds = append(guilds, guild)
}
}
return guilds
}
func (u *User) IsInSpace(guildID string) (isIn bool) {
query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err)
panic(err)
}
return
}
func (u *User) MarkInGuild(guild UserGuild) {
query := `
INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space)
VALUES ($1, $2, $3, $4)
ON CONFLICT (guild_id, user_mxid) DO UPDATE
SET timestamp=excluded.timestamp, in_space=excluded.in_space
`
_, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace)
if err != nil {
u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err)
panic(err)
}
}
func (u *User) MarkNotInGuild(guildID string) {
query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
_, err := u.db.Exec(query, u.MXID, guildID)
if err != nil {
u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err)
panic(err)
}
}
func (u *User) PruneGuildList(beforeTS time.Time) {
_, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli())
if err != nil {
u.log.Errorln("Failed to prune user guild list:", err)
panic(err)
}
}
func (u *User) Scan(row dbutil.Scannable) *User {
var token sql.NullString
var discordID sql.NullString
err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token)
var discordID, managementRoom, spaceRoom, discordToken sql.NullString
err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom)
if err != nil {
if err != sql.ErrNoRows {
u.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
if token.Valid {
u.Token = token.String
}
if discordID.Valid {
u.ID = discordID.String
}
u.DiscordID = discordID.String
u.DiscordToken = discordToken.String
u.ManagementRoom = id.RoomID(managementRoom.String)
u.SpaceRoom = id.RoomID(spaceRoom.String)
return u
}
func (u *User) Insert() {
query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
var token sql.NullString
var discordID sql.NullString
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
if u.ID != "" {
discordID.String = u.ID
discordID.Valid = true
}
_, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token)
query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
_, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)))
if err != nil {
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
panic(err)
}
}
func (u *User) Update() {
query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
var token sql.NullString
var discordID sql.NullString
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
if u.ID != "" {
discordID.String = u.ID
discordID.Valid = true
}
_, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID)
query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5`
_, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID)
if err != nil {
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
panic(err)
}
}