Break a bunch of stuff
This commit is contained in:
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -29,17 +29,16 @@ func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", discordgo.DroidBrowserUserAgent)
|
||||
for key, value := range discordgo.DroidDownloadHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (portal *Portal) downloadMatrixAttachment(eventID id.EventID, content *event.MessageEventContent) ([]byte, error) {
|
||||
|
||||
13
avatar.go
13
avatar.go
@@ -7,16 +7,25 @@ import (
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
func uploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, error) {
|
||||
getResp, err := http.DefaultClient.Get(url)
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return id.ContentURI{}, fmt.Errorf("failed to prepare request: %w", err)
|
||||
}
|
||||
for key, value := range discordgo.DroidImageHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
getResp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err)
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(getResp.Body)
|
||||
getResp.Body.Close()
|
||||
_ = getResp.Body.Close()
|
||||
if err != nil {
|
||||
return id.ContentURI{}, fmt.Errorf("failed to read avatar data: %w", err)
|
||||
}
|
||||
|
||||
28
commands.go
28
commands.go
@@ -116,7 +116,7 @@ func fnLogin(ce *WrappedCommandEvent) {
|
||||
ce.Reply("Error connecting after login: %v", err)
|
||||
}
|
||||
ce.User.Lock()
|
||||
ce.User.ID = user.UserID
|
||||
ce.User.DiscordID = user.UserID
|
||||
ce.User.Update()
|
||||
ce.User.Unlock()
|
||||
ce.Reply("Successfully logged in as %s#%s", user.Username, user.Discriminator)
|
||||
@@ -251,19 +251,21 @@ func fnGuilds(ce *WrappedCommandEvent) {
|
||||
}
|
||||
|
||||
func fnListGuilds(ce *WrappedCommandEvent) {
|
||||
ce.User.guildsLock.Lock()
|
||||
defer ce.User.guildsLock.Unlock()
|
||||
if len(ce.User.guilds) == 0 {
|
||||
ce.Reply("You haven't joined any guilds")
|
||||
} else {
|
||||
var output strings.Builder
|
||||
for _, guild := range ce.User.guilds {
|
||||
status := "not bridged"
|
||||
if guild.Bridge {
|
||||
status = "bridged"
|
||||
}
|
||||
_, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.GuildName, guild.GuildID, status)
|
||||
var output strings.Builder
|
||||
for _, userGuild := range ce.User.GetGuilds() {
|
||||
guild := ce.Bridge.GetGuildByID(userGuild.GuildID, false)
|
||||
if guild == nil {
|
||||
continue
|
||||
}
|
||||
status := "not bridged"
|
||||
if guild.MXID != "" {
|
||||
status = "bridged"
|
||||
}
|
||||
_, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.Name, guild.ID, status)
|
||||
}
|
||||
if output.Len() == 0 {
|
||||
ce.Reply("No guilds found")
|
||||
} else {
|
||||
ce.Reply("List of guilds:\n\n%s", output.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type AttachmentQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment"
|
||||
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id FROM attachment"
|
||||
)
|
||||
|
||||
func (aq *AttachmentQuery) New() *Attachment {
|
||||
@@ -77,58 +77,50 @@ type Attachment struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
Channel PortalKey
|
||||
Channel PortalKey
|
||||
ThreadID string
|
||||
MessageID string
|
||||
ID string
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
DiscordMessageID string
|
||||
DiscordAttachmentID string
|
||||
MXID id.EventID
|
||||
func (a *Attachment) DiscordProtoChannelID() string {
|
||||
if a.ThreadID != "" {
|
||||
return a.ThreadID
|
||||
} else {
|
||||
return a.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
|
||||
err := row.Scan(
|
||||
&a.DiscordAttachmentID, &a.DiscordMessageID,
|
||||
&a.Channel.ChannelID, &a.Channel.Receiver,
|
||||
&a.MXID)
|
||||
|
||||
err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
a.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *Attachment) Insert() {
|
||||
query := "INSERT INTO attachment" +
|
||||
" (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " +
|
||||
" mxid) VALUES ($1, $2, $3, $4, $5);"
|
||||
|
||||
_, err := a.db.Exec(
|
||||
query,
|
||||
a.Channel.ChannelID, a.Channel.Receiver,
|
||||
a.DiscordMessageID, a.DiscordAttachmentID,
|
||||
a.MXID,
|
||||
)
|
||||
|
||||
query := `
|
||||
INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, mxid)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`
|
||||
_, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID)
|
||||
if err != nil {
|
||||
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
|
||||
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Attachment) Delete() {
|
||||
query := "DELETE FROM attachment WHERE" +
|
||||
" dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
|
||||
|
||||
_, err := a.db.Exec(
|
||||
query,
|
||||
a.Channel.ChannelID, a.Channel.Receiver,
|
||||
a.DiscordAttachmentID,
|
||||
)
|
||||
|
||||
query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3`
|
||||
_, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID)
|
||||
if err != nil {
|
||||
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
|
||||
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,9 @@ import (
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database/upgrades"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database/upgrades"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
@@ -18,6 +19,7 @@ type Database struct {
|
||||
Portal *PortalQuery
|
||||
Puppet *PuppetQuery
|
||||
Message *MessageQuery
|
||||
Thread *ThreadQuery
|
||||
Reaction *ReactionQuery
|
||||
Attachment *AttachmentQuery
|
||||
Emoji *EmojiQuery
|
||||
@@ -54,6 +56,10 @@ func New(baseDB *dbutil.Database) *Database {
|
||||
db: db,
|
||||
log: db.Log.Sub("Message"),
|
||||
}
|
||||
db.Thread = &ThreadQuery{
|
||||
db: db,
|
||||
log: db.Log.Sub("Thread"),
|
||||
}
|
||||
db.Reaction = &ReactionQuery{
|
||||
db: db,
|
||||
log: db.Log.Sub("Reaction"),
|
||||
@@ -72,3 +78,10 @@ func New(baseDB *dbutil.Database) *Database {
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func strPtr(val string) *string {
|
||||
if val == "" {
|
||||
return nil
|
||||
}
|
||||
return &val
|
||||
}
|
||||
|
||||
@@ -28,23 +28,16 @@ func (eq *EmojiQuery) New() *Emoji {
|
||||
|
||||
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
|
||||
query := emojiSelect + " WHERE discord_id=$1"
|
||||
|
||||
return eq.get(query, discordID)
|
||||
}
|
||||
|
||||
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
|
||||
query := emojiSelect + " WHERE matrix_url=$1"
|
||||
|
||||
return eq.get(query, matrixURL.String())
|
||||
}
|
||||
|
||||
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
|
||||
row := eq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return eq.New().Scan(row)
|
||||
return eq.New().Scan(eq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
type Emoji struct {
|
||||
@@ -59,18 +52,17 @@ type Emoji struct {
|
||||
|
||||
func (e *Emoji) Scan(row dbutil.Scannable) *Emoji {
|
||||
var matrixURL sql.NullString
|
||||
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
|
||||
|
||||
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
e.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
e.MatrixURL, _ = id.ParseContentURI(matrixURL.String)
|
||||
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -83,6 +75,7 @@ func (e *Emoji) Insert() {
|
||||
|
||||
if err != nil {
|
||||
e.log.Warnfln("Failed to insert emoji %s: %v", e.DiscordID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,17 +85,15 @@ func (e *Emoji) Delete() {
|
||||
_, err := e.db.Exec(query, e.DiscordID)
|
||||
if err != nil {
|
||||
e.log.Warnfln("Failed to delete emoji %s: %v", e.DiscordID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Emoji) APIName() string {
|
||||
if e.DiscordID != "" && e.DiscordName != "" {
|
||||
return e.DiscordName + ":" + e.DiscordID
|
||||
}
|
||||
|
||||
if e.DiscordName != "" {
|
||||
} else if e.DiscordName != "" {
|
||||
return e.DiscordName
|
||||
}
|
||||
|
||||
return e.DiscordID
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
@@ -16,7 +16,7 @@ type GuildQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
|
||||
guildSelect = "SELECT dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild"
|
||||
)
|
||||
|
||||
func (gq *GuildQuery) New() *Guild {
|
||||
@@ -26,109 +26,100 @@ func (gq *GuildQuery) New() *Guild {
|
||||
}
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
|
||||
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
|
||||
|
||||
row := gq.db.QueryRow(query, discordID, guildID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return gq.New().Scan(row)
|
||||
func (gq *GuildQuery) GetByID(dcid string) *Guild {
|
||||
query := guildSelect + " WHERE dcid=$1"
|
||||
return gq.New().Scan(gq.db.QueryRow(query, dcid))
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
|
||||
query := guildSelect + " WHERE discord_id=$1"
|
||||
func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild {
|
||||
query := guildSelect + " WHERE mxid=$1"
|
||||
return gq.New().Scan(gq.db.QueryRow(query, mxid))
|
||||
}
|
||||
|
||||
rows, err := gq.db.Query(query, discordID)
|
||||
if err != nil || rows == nil {
|
||||
func (gq *GuildQuery) GetAll() []*Guild {
|
||||
rows, err := gq.db.Query(guildSelect)
|
||||
if err != nil {
|
||||
gq.log.Errorln("Failed to query guilds:", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
guilds := []*Guild{}
|
||||
var guilds []*Guild
|
||||
for rows.Next() {
|
||||
guilds = append(guilds, gq.New().Scan(rows))
|
||||
guild := gq.New().Scan(rows)
|
||||
if guild != nil {
|
||||
guilds = append(guilds, guild)
|
||||
}
|
||||
}
|
||||
|
||||
return guilds
|
||||
}
|
||||
|
||||
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
|
||||
// We need this interface slice because a variadic function can't mix
|
||||
// arguements with a `...` expanded slice.
|
||||
args := []interface{}{discordID}
|
||||
|
||||
nGuilds := len(guilds)
|
||||
if nGuilds <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
gq.log.Debugfln("prunning guilds for %s", discordID)
|
||||
|
||||
// Build the in query
|
||||
inQuery := "$2"
|
||||
for i := 1; i < nGuilds; i++ {
|
||||
inQuery += fmt.Sprintf(", $%d", i+2)
|
||||
}
|
||||
|
||||
// Add the arguements for the build query
|
||||
for _, guildID := range guilds {
|
||||
args = append(args, guildID)
|
||||
}
|
||||
|
||||
// Now remove any guilds that the user has left.
|
||||
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
|
||||
inQuery + ")"
|
||||
|
||||
_, err := gq.db.Exec(query, args...)
|
||||
if err != nil {
|
||||
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
|
||||
}
|
||||
}
|
||||
|
||||
type Guild struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
DiscordID string
|
||||
GuildID string
|
||||
GuildName string
|
||||
Bridge bool
|
||||
ID string
|
||||
MXID id.RoomID
|
||||
Name string
|
||||
NameSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
|
||||
AutoBridgeChannels bool
|
||||
}
|
||||
|
||||
func (g *Guild) Scan(row dbutil.Scannable) *Guild {
|
||||
err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge)
|
||||
var mxid sql.NullString
|
||||
var avatarURL string
|
||||
err := row.Scan(&g.ID, &mxid, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
g.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
g.MXID = id.RoomID(mxid.String)
|
||||
g.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *Guild) Upsert() {
|
||||
query := "INSERT INTO guild" +
|
||||
" (discord_id, guild_id, guild_name, bridge)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT(discord_id, guild_id)" +
|
||||
" DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge"
|
||||
|
||||
_, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge)
|
||||
|
||||
func (g *Guild) mxidPtr() *id.RoomID {
|
||||
if g.MXID != "" {
|
||||
return &g.MXID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (g *Guild) Insert() {
|
||||
query := `
|
||||
INSERT INTO guild (dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`
|
||||
_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err)
|
||||
g.log.Warnfln("Failed to insert %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guild) Update() {
|
||||
query := `
|
||||
UPDATE guild SET mxid=$1, name=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, auto_bridge_channels=$7
|
||||
WHERE dcid=$8
|
||||
`
|
||||
_, err := g.db.Exec(query, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to update %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guild) Delete() {
|
||||
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2"
|
||||
|
||||
_, err := g.db.Exec(query, g.DiscordID, g.GuildID)
|
||||
|
||||
_, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID)
|
||||
if err != nil {
|
||||
g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err)
|
||||
g.log.Warnfln("Failed to delete %s: %v", g.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ type MessageQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
|
||||
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
|
||||
)
|
||||
|
||||
func (mq *MessageQuery) New() *Message {
|
||||
@@ -45,14 +45,12 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
|
||||
|
||||
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
|
||||
}
|
||||
|
||||
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
|
||||
if row == nil {
|
||||
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return mq.New().Scan(row)
|
||||
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
|
||||
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1"
|
||||
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
|
||||
}
|
||||
|
||||
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
|
||||
@@ -74,36 +72,54 @@ type Message struct {
|
||||
Channel PortalKey
|
||||
SenderID string
|
||||
Timestamp time.Time
|
||||
ThreadID string
|
||||
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
func (m *Message) DiscordProtoChannelID() string {
|
||||
if m.ThreadID != "" {
|
||||
return m.ThreadID
|
||||
} else {
|
||||
return m.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Scan(row dbutil.Scannable) *Message {
|
||||
var ts int64
|
||||
var threadID sql.NullString
|
||||
|
||||
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
|
||||
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
m.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if ts != 0 {
|
||||
m.Timestamp = time.Unix(ts, 0)
|
||||
m.Timestamp = time.UnixMilli(ts)
|
||||
}
|
||||
m.ThreadID = threadID.String
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Message) Insert() {
|
||||
query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
query := `
|
||||
INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
|
||||
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
|
||||
_, err := m.db.Exec(query,
|
||||
m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
|
||||
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
|
||||
|
||||
if err != nil {
|
||||
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,5 +130,6 @@ func (m *Message) Delete() {
|
||||
|
||||
if err != nil {
|
||||
m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,11 +12,30 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," +
|
||||
" avatar_url, type, other_user_id, first_event_id, encrypted" +
|
||||
portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " +
|
||||
" mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" +
|
||||
" FROM portal"
|
||||
)
|
||||
|
||||
type PortalKey struct {
|
||||
ChannelID string
|
||||
Receiver string
|
||||
}
|
||||
|
||||
func NewPortalKey(channelID, receiver string) PortalKey {
|
||||
return PortalKey{
|
||||
ChannelID: channelID,
|
||||
Receiver: receiver,
|
||||
}
|
||||
}
|
||||
|
||||
func (key PortalKey) String() string {
|
||||
if key.Receiver == "" {
|
||||
return key.ChannelID
|
||||
}
|
||||
return key.ChannelID + "-" + key.Receiver
|
||||
}
|
||||
|
||||
type PortalQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
@@ -34,7 +53,7 @@ func (pq *PortalQuery) GetAll() []*Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
|
||||
return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
|
||||
@@ -67,12 +86,7 @@ func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
|
||||
row := pq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pq.New().Scan(row)
|
||||
return pq.New().Scan(pq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
type Portal struct {
|
||||
@@ -82,78 +96,87 @@ type Portal struct {
|
||||
Key PortalKey
|
||||
Type discordgo.ChannelType
|
||||
OtherUserID string
|
||||
ParentID string
|
||||
GuildID string
|
||||
|
||||
MXID id.RoomID
|
||||
|
||||
Name string
|
||||
NameSet bool
|
||||
Topic string
|
||||
TopicSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
Encrypted bool
|
||||
InSpace id.RoomID
|
||||
|
||||
FirstEventID id.EventID
|
||||
}
|
||||
|
||||
func (p *Portal) Scan(row dbutil.Scannable) *Portal {
|
||||
var mxid, avatarURL, firstEventID sql.NullString
|
||||
var typ sql.NullInt32
|
||||
var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
|
||||
var chanType int32
|
||||
var avatarURL string
|
||||
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
|
||||
&p.Encrypted)
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
|
||||
&mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
|
||||
&p.Encrypted, &p.InSpace, &firstEventID)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
p.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.MXID = id.RoomID(mxid.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
|
||||
p.Type = discordgo.ChannelType(typ.Int32)
|
||||
p.OtherUserID = otherUserID.String
|
||||
p.GuildID = guildID.String
|
||||
p.ParentID = parentID.String
|
||||
p.Type = discordgo.ChannelType(chanType)
|
||||
p.FirstEventID = id.EventID(firstEventID.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Portal) mxidPtr() *id.RoomID {
|
||||
if p.MXID != "" {
|
||||
return &p.MXID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Portal) Insert() {
|
||||
query := "INSERT INTO portal" +
|
||||
" (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
|
||||
" type, other_user_id, first_event_id, encrypted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
|
||||
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
|
||||
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
|
||||
p.FirstEventID.String(), p.Encrypted)
|
||||
query := `
|
||||
INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
|
||||
name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
|
||||
encrypted, in_space, first_event_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
|
||||
`
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
|
||||
strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
|
||||
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
p.Encrypted, p.InSpace, p.FirstEventID.String())
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Portal) Update() {
|
||||
query := "UPDATE portal SET" +
|
||||
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
|
||||
" other_user_id=$7, first_event_id=$8, encrypted=$9" +
|
||||
" WHERE dcid=$10 AND receiver=$11"
|
||||
|
||||
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
|
||||
p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
|
||||
p.Encrypted,
|
||||
query := `
|
||||
UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
|
||||
name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12,
|
||||
encrypted=$13, in_space=$14, first_event_id=$15
|
||||
WHERE dcid=$16 AND receiver=$17
|
||||
`
|
||||
_, err := p.db.Exec(query,
|
||||
p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
|
||||
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
p.Encrypted, p.InSpace, p.FirstEventID.String(),
|
||||
p.Key.ChannelID, p.Key.Receiver)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to update %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,5 +185,6 @@ func (p *Portal) Delete() {
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package database
|
||||
|
||||
type PortalKey struct {
|
||||
ChannelID string
|
||||
Receiver string
|
||||
}
|
||||
|
||||
func NewPortalKey(channelID, receiver string) PortalKey {
|
||||
return PortalKey{
|
||||
ChannelID: channelID,
|
||||
Receiver: receiver,
|
||||
}
|
||||
}
|
||||
|
||||
func (key PortalKey) String() string {
|
||||
if key.Receiver == "" {
|
||||
return key.ChannelID
|
||||
}
|
||||
return key.ChannelID + "-" + key.Receiver
|
||||
}
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
puppetSelect = "SELECT id, display_name, avatar, avatar_url," +
|
||||
puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
|
||||
" custom_mxid, access_token, next_batch" +
|
||||
" FROM puppet "
|
||||
)
|
||||
@@ -36,12 +36,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
|
||||
row := pq.db.QueryRow(query, args...)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return pq.New().Scan(row)
|
||||
return pq.New().Scan(pq.db.QueryRow(query, args...))
|
||||
}
|
||||
|
||||
func (pq *PuppetQuery) GetAll() []*Puppet {
|
||||
@@ -59,7 +54,7 @@ func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
puppets := []*Puppet{}
|
||||
var puppets []*Puppet
|
||||
for rows.Next() {
|
||||
puppets = append(puppets, pq.New().Scan(rows))
|
||||
}
|
||||
@@ -71,11 +66,12 @@ type Puppet struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
ID string
|
||||
DisplayName string
|
||||
|
||||
ID string
|
||||
Name string
|
||||
NameSet bool
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
AvatarSet bool
|
||||
|
||||
CustomMXID id.UserID
|
||||
AccessToken string
|
||||
@@ -83,24 +79,22 @@ type Puppet struct {
|
||||
}
|
||||
|
||||
func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
var did, displayName, avatar, avatarURL sql.NullString
|
||||
var avatarURL string
|
||||
var customMXID, accessToken, nextBatch sql.NullString
|
||||
|
||||
err := row.Scan(&did, &displayName, &avatar, &avatarURL,
|
||||
err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet,
|
||||
&customMXID, &accessToken, &nextBatch)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
p.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ID = did.String
|
||||
p.DisplayName = displayName.String
|
||||
p.Avatar = avatar.String
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
|
||||
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
|
||||
p.CustomMXID = id.UserID(customMXID.String)
|
||||
p.AccessToken = accessToken.String
|
||||
p.NextBatch = nextBatch.String
|
||||
@@ -109,31 +103,31 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
|
||||
}
|
||||
|
||||
func (p *Puppet) Insert() {
|
||||
query := "INSERT INTO puppet" +
|
||||
" (id, display_name, avatar, avatar_url," +
|
||||
" custom_mxid, access_token, next_batch)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
|
||||
_, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
|
||||
p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
|
||||
p.NextBatch)
|
||||
query := `
|
||||
INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
_, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch))
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Puppet) Update() {
|
||||
query := "UPDATE puppet" +
|
||||
" SET display_name=$1, avatar=$2, avatar_url=$3, " +
|
||||
" custom_mxid=$4, access_token=$5, next_batch=$6" +
|
||||
" WHERE id=$7"
|
||||
|
||||
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
|
||||
p.CustomMXID, p.AccessToken, p.NextBatch,
|
||||
query := `
|
||||
UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5,
|
||||
custom_mxid=$6, access_token=$7, next_batch=$8
|
||||
WHERE id=$9
|
||||
`
|
||||
_, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
|
||||
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch),
|
||||
p.ID)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to update %s: %v", p.ID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ type ReactionQuery struct {
|
||||
}
|
||||
|
||||
const (
|
||||
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction"
|
||||
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction"
|
||||
)
|
||||
|
||||
func (rq *ReactionQuery) New() *Reaction {
|
||||
@@ -75,15 +75,17 @@ type Reaction struct {
|
||||
MessageID string
|
||||
Sender string
|
||||
EmojiName string
|
||||
ThreadID string
|
||||
|
||||
MXID id.EventID
|
||||
}
|
||||
|
||||
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID)
|
||||
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
r.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -91,14 +93,23 @@ func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Reaction) DiscordProtoChannelID() string {
|
||||
if r.ThreadID != "" {
|
||||
return r.ThreadID
|
||||
} else {
|
||||
return r.Channel.ChannelID
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Reaction) Insert() {
|
||||
query := `
|
||||
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid)
|
||||
VALUES($1, $2, $3, $4, $5, $6)
|
||||
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
|
||||
VALUES($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID)
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID)
|
||||
if err != nil {
|
||||
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,5 +118,6 @@ func (r *Reaction) Delete() {
|
||||
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName)
|
||||
if err != nil {
|
||||
r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
105
database/thread.go
Normal file
105
database/thread.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type ThreadQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
}
|
||||
|
||||
const (
|
||||
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
|
||||
)
|
||||
|
||||
func (tq *ThreadQuery) New() *Thread {
|
||||
return &Thread{
|
||||
db: tq.db,
|
||||
log: tq.log,
|
||||
}
|
||||
}
|
||||
|
||||
func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
|
||||
query := threadSelect + " WHERE dcid=$1"
|
||||
|
||||
row := tq.db.QueryRow(query, discordID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
|
||||
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
|
||||
//
|
||||
// row := tq.db.QueryRow(query, channelID, messageID)
|
||||
// if row == nil {
|
||||
// return nil
|
||||
// }
|
||||
//
|
||||
// return tq.New().Scan(row)
|
||||
//}
|
||||
|
||||
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
|
||||
query := threadSelect + " WHERE root_msg_mxid=$1"
|
||||
|
||||
row := tq.db.QueryRow(query, mxid)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return tq.New().Scan(row)
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
ID string
|
||||
ParentID string
|
||||
|
||||
RootDiscordID string
|
||||
RootMXID id.EventID
|
||||
}
|
||||
|
||||
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
|
||||
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *Thread) Insert() {
|
||||
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
|
||||
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Thread) Delete() {
|
||||
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
|
||||
|
||||
_, err := t.db.Exec(query, t.ID, t.ParentID)
|
||||
|
||||
if err != nil {
|
||||
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
@@ -1,29 +1,65 @@
|
||||
-- v0 -> v2: Latest revision
|
||||
-- v0 -> v3: Latest revision
|
||||
|
||||
CREATE TABLE guild (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
auto_bridge_channels BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE portal (
|
||||
dcid TEXT,
|
||||
receiver TEXT,
|
||||
other_user_id TEXT,
|
||||
type INTEGER,
|
||||
type INTEGER NOT NULL,
|
||||
|
||||
dc_guild_id TEXT,
|
||||
dc_parent_id TEXT,
|
||||
-- This is not accessed by the bridge, it's only used for the portal parent foreign key.
|
||||
-- Only guild channels have parents, but only DMs have a receiver field.
|
||||
dc_parent_receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
topic TEXT NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
encrypted BOOLEAN NOT NULL DEFAULT false,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
encrypted BOOLEAN NOT NULL,
|
||||
in_space TEXT NOT NULL,
|
||||
|
||||
first_event_id TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY (dcid, receiver)
|
||||
PRIMARY KEY (dcid, receiver),
|
||||
CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE,
|
||||
CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE thread (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
parent_chan_id TEXT NOT NULL,
|
||||
root_msg_dcid TEXT NOT NULL,
|
||||
root_msg_mxid TEXT NOT NULL,
|
||||
-- This is also not accessed by the bridge.
|
||||
receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE puppet (
|
||||
id TEXT PRIMARY KEY,
|
||||
|
||||
name TEXT,
|
||||
avatar TEXT,
|
||||
avatar_url TEXT,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
custom_mxid TEXT,
|
||||
access_token TEXT,
|
||||
@@ -34,9 +70,20 @@ CREATE TABLE "user" (
|
||||
mxid TEXT PRIMARY KEY,
|
||||
dcid TEXT UNIQUE,
|
||||
|
||||
discord_token TEXT,
|
||||
management_room TEXT,
|
||||
space_room TEXT
|
||||
);
|
||||
|
||||
token TEXT
|
||||
CREATE TABLE user_guild (
|
||||
guild_id TEXT,
|
||||
user_mxid TEXT,
|
||||
in_space BOOLEAN NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
|
||||
PRIMARY KEY (guild_id, user_mxid),
|
||||
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
|
||||
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE message (
|
||||
@@ -45,6 +92,7 @@ CREATE TABLE message (
|
||||
dc_chan_receiver TEXT,
|
||||
dc_sender TEXT NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -58,6 +106,7 @@ CREATE TABLE reaction (
|
||||
dc_msg_id TEXT,
|
||||
dc_sender TEXT,
|
||||
dc_emoji_name TEXT,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -70,6 +119,7 @@ CREATE TABLE attachment (
|
||||
dc_msg_id TEXT,
|
||||
dc_chan_id TEXT,
|
||||
dc_chan_receiver TEXT,
|
||||
dc_thread_id TEXT,
|
||||
|
||||
mxid TEXT NOT NULL UNIQUE,
|
||||
|
||||
@@ -82,11 +132,3 @@ CREATE TABLE emoji (
|
||||
discord_name TEXT,
|
||||
matrix_url TEXT
|
||||
);
|
||||
|
||||
CREATE TABLE guild (
|
||||
discord_id TEXT NOT NULL,
|
||||
guild_id TEXT NOT NULL,
|
||||
guild_name TEXT NOT NULL,
|
||||
bridge BOOLEAN DEFAULT FALSE,
|
||||
PRIMARY KEY(discord_id, guild_id)
|
||||
);
|
||||
|
||||
72
database/upgrades/03-spaces.sql
Normal file
72
database/upgrades/03-spaces.sql
Normal file
@@ -0,0 +1,72 @@
|
||||
-- v3: Store portal parent metadata for spaces
|
||||
DROP TABLE guild;
|
||||
|
||||
CREATE TABLE guild (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
mxid TEXT UNIQUE,
|
||||
name TEXT NOT NULL,
|
||||
name_set BOOLEAN NOT NULL,
|
||||
avatar TEXT NOT NULL,
|
||||
avatar_url TEXT NOT NULL,
|
||||
avatar_set BOOLEAN NOT NULL,
|
||||
|
||||
auto_bridge_channels BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE user_guild (
|
||||
guild_id TEXT,
|
||||
user_mxid TEXT,
|
||||
in_space BOOLEAN NOT NULL,
|
||||
timestamp BIGINT NOT NULL,
|
||||
|
||||
PRIMARY KEY (guild_id, user_mxid),
|
||||
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
|
||||
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
ALTER TABLE portal ADD COLUMN dc_guild_id TEXT;
|
||||
ALTER TABLE portal ADD COLUMN dc_parent_id TEXT;
|
||||
ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE;
|
||||
ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE;
|
||||
DELETE FROM portal WHERE type IS NULL;
|
||||
-- only: postgres
|
||||
ALTER TABLE portal ALTER COLUMN type SET NOT NULL;
|
||||
|
||||
ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT '';
|
||||
ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
|
||||
-- only: postgres for next 5 lines
|
||||
ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT;
|
||||
ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT;
|
||||
|
||||
ALTER TABLE puppet RENAME COLUMN display_name TO name;
|
||||
ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
|
||||
ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
|
||||
-- only: postgres for next 2 lines
|
||||
ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT;
|
||||
ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT;
|
||||
|
||||
ALTER TABLE "user" ADD COLUMN space_room TEXT;
|
||||
ALTER TABLE "user" RENAME COLUMN token TO discord_token;
|
||||
|
||||
UPDATE message SET timestamp=timestamp*1000;
|
||||
|
||||
CREATE TABLE thread (
|
||||
dcid TEXT PRIMARY KEY,
|
||||
parent_chan_id TEXT NOT NULL,
|
||||
root_msg_dcid TEXT NOT NULL,
|
||||
root_msg_mxid TEXT NOT NULL,
|
||||
-- This is also not accessed by the bridge.
|
||||
receiver TEXT NOT NULL DEFAULT '',
|
||||
|
||||
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
|
||||
);
|
||||
|
||||
ALTER TABLE message ADD COLUMN dc_thread_id TEXT;
|
||||
ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT;
|
||||
ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT;
|
||||
176
database/user.go
176
database/user.go
@@ -2,6 +2,8 @@ package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
@@ -22,38 +24,32 @@ func (uq *UserQuery) New() *User {
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
|
||||
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
|
||||
row := uq.db.QueryRow(query, userID)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return uq.New().Scan(row)
|
||||
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1`
|
||||
return uq.New().Scan(uq.db.QueryRow(query, userID))
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetByID(id string) *User {
|
||||
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
|
||||
row := uq.db.QueryRow(query, id)
|
||||
if row == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return uq.New().Scan(row)
|
||||
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1`
|
||||
return uq.New().Scan(uq.db.QueryRow(query, id))
|
||||
}
|
||||
|
||||
func (uq *UserQuery) GetAll() []*User {
|
||||
rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
|
||||
func (uq *UserQuery) GetAllWithToken() []*User {
|
||||
query := `
|
||||
SELECT mxid, dcid, discord_token, management_room, space_room
|
||||
FROM "user" WHERE discord_token IS NOT NULL
|
||||
`
|
||||
rows, err := uq.db.Query(query)
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
users := []*User{}
|
||||
var users []*User
|
||||
for rows.Next() {
|
||||
users = append(users, uq.New().Scan(rows))
|
||||
user := uq.New().Scan(rows)
|
||||
if user != nil {
|
||||
users = append(users, user)
|
||||
}
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
@@ -61,80 +57,114 @@ type User struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
|
||||
MXID id.UserID
|
||||
ID string
|
||||
|
||||
MXID id.UserID
|
||||
DiscordID string
|
||||
DiscordToken string
|
||||
ManagementRoom id.RoomID
|
||||
SpaceRoom id.RoomID
|
||||
}
|
||||
|
||||
Token string
|
||||
type UserGuild struct {
|
||||
GuildID string
|
||||
Timestamp time.Time
|
||||
InSpace bool
|
||||
}
|
||||
|
||||
func (u *User) GetGuilds() []UserGuild {
|
||||
res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID)
|
||||
if err != nil {
|
||||
u.log.Errorln("Failed to get guilds:", err)
|
||||
panic(err)
|
||||
return nil
|
||||
}
|
||||
var guilds []UserGuild
|
||||
for res.Next() {
|
||||
var guild UserGuild
|
||||
var ts int64
|
||||
err = res.Scan(&guild.GuildID, &ts, &guild.InSpace)
|
||||
if err != nil {
|
||||
u.log.Errorln("Error scanning user guild:", err)
|
||||
panic(err)
|
||||
} else {
|
||||
guild.Timestamp = time.UnixMilli(ts)
|
||||
guilds = append(guilds, guild)
|
||||
}
|
||||
}
|
||||
return guilds
|
||||
}
|
||||
|
||||
func (u *User) IsInSpace(guildID string) (isIn bool) {
|
||||
query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
|
||||
err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err)
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (u *User) MarkInGuild(guild UserGuild) {
|
||||
query := `
|
||||
INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (guild_id, user_mxid) DO UPDATE
|
||||
SET timestamp=excluded.timestamp, in_space=excluded.in_space
|
||||
`
|
||||
_, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace)
|
||||
if err != nil {
|
||||
u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) MarkNotInGuild(guildID string) {
|
||||
query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
|
||||
_, err := u.db.Exec(query, u.MXID, guildID)
|
||||
if err != nil {
|
||||
u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) PruneGuildList(beforeTS time.Time) {
|
||||
_, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli())
|
||||
if err != nil {
|
||||
u.log.Errorln("Failed to prune user guild list:", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) Scan(row dbutil.Scannable) *User {
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token)
|
||||
var discordID, managementRoom, spaceRoom, discordToken sql.NullString
|
||||
err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
u.log.Errorln("Database scan failed:", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if token.Valid {
|
||||
u.Token = token.String
|
||||
}
|
||||
|
||||
if discordID.Valid {
|
||||
u.ID = discordID.String
|
||||
}
|
||||
|
||||
u.DiscordID = discordID.String
|
||||
u.DiscordToken = discordToken.String
|
||||
u.ManagementRoom = id.RoomID(managementRoom.String)
|
||||
u.SpaceRoom = id.RoomID(spaceRoom.String)
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *User) Insert() {
|
||||
query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
if u.Token != "" {
|
||||
token.String = u.Token
|
||||
token.Valid = true
|
||||
}
|
||||
|
||||
if u.ID != "" {
|
||||
discordID.String = u.ID
|
||||
discordID.Valid = true
|
||||
}
|
||||
|
||||
_, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token)
|
||||
|
||||
query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
|
||||
_, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)))
|
||||
if err != nil {
|
||||
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *User) Update() {
|
||||
query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
|
||||
|
||||
var token sql.NullString
|
||||
var discordID sql.NullString
|
||||
|
||||
if u.Token != "" {
|
||||
token.String = u.Token
|
||||
token.Valid = true
|
||||
}
|
||||
|
||||
if u.ID != "" {
|
||||
discordID.String = u.ID
|
||||
discordID.Valid = true
|
||||
}
|
||||
|
||||
_, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID)
|
||||
|
||||
query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5`
|
||||
_, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID)
|
||||
if err != nil {
|
||||
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
6
go.mod
6
go.mod
@@ -3,14 +3,14 @@ module go.mau.fi/mautrix-discord
|
||||
go 1.17
|
||||
|
||||
require (
|
||||
github.com/bwmarrin/discordgo v0.23.2
|
||||
github.com/bwmarrin/discordgo v0.25.0
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/lib/pq v1.10.6
|
||||
github.com/mattn/go-sqlite3 v1.14.13
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
maunium.net/go/maulogger/v2 v2.3.2
|
||||
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365
|
||||
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -27,4 +27,4 @@ require (
|
||||
maunium.net/go/mauflag v1.0.0 // indirect
|
||||
)
|
||||
|
||||
replace github.com/bwmarrin/discordgo v0.23.2 => gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685
|
||||
replace github.com/bwmarrin/discordgo => gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7
|
||||
|
||||
14
go.sum
14
go.sum
@@ -3,7 +3,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
|
||||
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
|
||||
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
|
||||
@@ -28,11 +28,12 @@ github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc=
|
||||
github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM=
|
||||
github.com/yuin/goldmark v1.4.12 h1:6hffw6vALvEDqJ19dOJvJKOoAOKe4NDaTqvd2sktGN0=
|
||||
github.com/yuin/goldmark v1.4.12/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685 h1:WLswEvEb4wi02p+VjMl8slhG/4+3pdfx04xQfqQjpoc=
|
||||
gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685/go.mod h1:Hwfv4M8yP/MDh47BN+4Z1WItJ1umLKUyplCH5KcQPgE=
|
||||
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
|
||||
gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7 h1:S8hbrkgKGU4aU5kXW4d8CA/9ayi8ymI3QU6yg/aWfUw=
|
||||
gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
|
||||
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
|
||||
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
|
||||
@@ -44,6 +45,7 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16C
|
||||
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -56,5 +58,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
|
||||
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
|
||||
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
|
||||
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365 h1:ghL/KMHiMdcwKmuxvVYZvwe6Yk7BHeWlRfoHZovDA2Q=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf h1:hNCuEyp9oZPM08n7UsI/gq95ptbhOV1A8heFoH/eiM0=
|
||||
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I=
|
||||
|
||||
287
guildportal.go
Normal file
287
guildportal.go
Normal file
@@ -0,0 +1,287 @@
|
||||
// mautrix-discord - A Matrix-Discord puppeting bridge.
|
||||
// Copyright (C) 2022 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database"
|
||||
)
|
||||
|
||||
type Guild struct {
|
||||
*database.Guild
|
||||
|
||||
bridge *DiscordBridge
|
||||
log log.Logger
|
||||
|
||||
roomCreateLock sync.Mutex
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) loadGuild(dbGuild *database.Guild, id string, createIfNotExist bool) *Guild {
|
||||
if dbGuild == nil {
|
||||
if id == "" || !createIfNotExist {
|
||||
return nil
|
||||
}
|
||||
|
||||
dbGuild = br.DB.Guild.New()
|
||||
dbGuild.ID = id
|
||||
dbGuild.Insert()
|
||||
}
|
||||
|
||||
guild := br.NewGuild(dbGuild)
|
||||
|
||||
br.guildsByID[guild.ID] = guild
|
||||
if guild.MXID != "" {
|
||||
br.guildsByMXID[guild.MXID] = guild
|
||||
}
|
||||
|
||||
return guild
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetGuildByMXID(mxid id.RoomID) *Guild {
|
||||
br.guildsLock.Lock()
|
||||
defer br.guildsLock.Unlock()
|
||||
|
||||
portal, ok := br.guildsByMXID[mxid]
|
||||
if !ok {
|
||||
return br.loadGuild(br.DB.Guild.GetByMXID(mxid), "", false)
|
||||
}
|
||||
|
||||
return portal
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetGuildByID(id string, createIfNotExist bool) *Guild {
|
||||
br.guildsLock.Lock()
|
||||
defer br.guildsLock.Unlock()
|
||||
|
||||
guild, ok := br.guildsByID[id]
|
||||
if !ok {
|
||||
return br.loadGuild(br.DB.Guild.GetByID(id), id, createIfNotExist)
|
||||
}
|
||||
|
||||
return guild
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetAllGuilds() []*Guild {
|
||||
return br.dbGuildsToGuilds(br.DB.Guild.GetAll())
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) dbGuildsToGuilds(dbGuilds []*database.Guild) []*Guild {
|
||||
br.guildsLock.Lock()
|
||||
defer br.guildsLock.Unlock()
|
||||
|
||||
output := make([]*Guild, len(dbGuilds))
|
||||
for index, dbGuild := range dbGuilds {
|
||||
if dbGuild == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
guild, ok := br.guildsByID[dbGuild.ID]
|
||||
if !ok {
|
||||
guild = br.loadGuild(dbGuild, "", false)
|
||||
}
|
||||
|
||||
output[index] = guild
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) NewGuild(dbGuild *database.Guild) *Guild {
|
||||
guild := &Guild{
|
||||
Guild: dbGuild,
|
||||
bridge: br,
|
||||
log: br.Log.Sub(fmt.Sprintf("Guild/%s", dbGuild.ID)),
|
||||
}
|
||||
|
||||
return guild
|
||||
}
|
||||
|
||||
func (guild *Guild) getBridgeInfo() (string, event.BridgeEventContent) {
|
||||
bridgeInfo := event.BridgeEventContent{
|
||||
BridgeBot: guild.bridge.Bot.UserID,
|
||||
Creator: guild.bridge.Bot.UserID,
|
||||
Protocol: event.BridgeInfoSection{
|
||||
ID: "discord",
|
||||
DisplayName: "Discord",
|
||||
AvatarURL: guild.bridge.Config.AppService.Bot.ParsedAvatar.CUString(),
|
||||
ExternalURL: "https://discord.com/",
|
||||
},
|
||||
Channel: event.BridgeInfoSection{
|
||||
ID: guild.ID,
|
||||
DisplayName: guild.Name,
|
||||
AvatarURL: guild.AvatarURL.CUString(),
|
||||
},
|
||||
}
|
||||
bridgeInfoStateKey := fmt.Sprintf("fi.mau.discord://discord/%s", guild.ID)
|
||||
return bridgeInfoStateKey, bridgeInfo
|
||||
}
|
||||
|
||||
func (guild *Guild) UpdateBridgeInfo() {
|
||||
if len(guild.MXID) == 0 {
|
||||
guild.log.Debugln("Not updating bridge info: no Matrix room created")
|
||||
return
|
||||
}
|
||||
guild.log.Debugln("Updating bridge info...")
|
||||
stateKey, content := guild.getBridgeInfo()
|
||||
_, err := guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateBridge, stateKey, content)
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to update m.bridge:", err)
|
||||
}
|
||||
// TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
|
||||
_, err = guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateHalfShotBridge, stateKey, content)
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to update uk.half-shot.bridge:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (guild *Guild) CreateMatrixRoom(user *User, meta *discordgo.Guild) error {
|
||||
guild.roomCreateLock.Lock()
|
||||
defer guild.roomCreateLock.Unlock()
|
||||
if guild.MXID != "" {
|
||||
return nil
|
||||
}
|
||||
guild.log.Infoln("Creating Matrix room for guild")
|
||||
guild.UpdateInfo(user, meta)
|
||||
|
||||
bridgeInfoStateKey, bridgeInfo := guild.getBridgeInfo()
|
||||
|
||||
initialState := []*event.Event{{
|
||||
Type: event.StateBridge,
|
||||
Content: event.Content{Parsed: bridgeInfo},
|
||||
StateKey: &bridgeInfoStateKey,
|
||||
}, {
|
||||
// TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
|
||||
Type: event.StateHalfShotBridge,
|
||||
Content: event.Content{Parsed: bridgeInfo},
|
||||
StateKey: &bridgeInfoStateKey,
|
||||
}}
|
||||
|
||||
if !guild.AvatarURL.IsEmpty() {
|
||||
initialState = append(initialState, &event.Event{
|
||||
Type: event.StateRoomAvatar,
|
||||
Content: event.Content{Parsed: &event.RoomAvatarEventContent{
|
||||
URL: guild.AvatarURL,
|
||||
}},
|
||||
})
|
||||
}
|
||||
|
||||
creationContent := map[string]interface{}{
|
||||
"type": event.RoomTypeSpace,
|
||||
}
|
||||
if !guild.bridge.Config.Bridge.FederateRooms {
|
||||
creationContent["m.federate"] = false
|
||||
}
|
||||
|
||||
resp, err := guild.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{
|
||||
Visibility: "private",
|
||||
Name: guild.Name,
|
||||
Preset: "private_chat",
|
||||
InitialState: initialState,
|
||||
CreationContent: creationContent,
|
||||
})
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to create room:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
guild.MXID = resp.RoomID
|
||||
guild.NameSet = true
|
||||
guild.AvatarSet = !guild.AvatarURL.IsEmpty()
|
||||
guild.Update()
|
||||
guild.bridge.guildsLock.Lock()
|
||||
guild.bridge.guildsByMXID[guild.MXID] = guild
|
||||
guild.bridge.guildsLock.Unlock()
|
||||
guild.log.Infoln("Matrix room created:", guild.MXID)
|
||||
|
||||
user.ensureInvited(nil, guild.MXID, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (guild *Guild) UpdateInfo(source *User, meta *discordgo.Guild) *discordgo.Guild {
|
||||
if meta.Unavailable {
|
||||
return meta
|
||||
}
|
||||
changed := false
|
||||
// FIXME
|
||||
//name, err := guild.bridge.Config.Bridge.FormatChannelname(meta, user.Session)
|
||||
//if err != nil {
|
||||
// guild.log.Warnfln("failed to format name, proceeding with generic name: %v", err)
|
||||
// guild.Name = meta.Name
|
||||
//} else {
|
||||
//}
|
||||
changed = guild.UpdateName(meta.Name) || changed
|
||||
changed = guild.UpdateAvatar(meta.Icon) || changed
|
||||
if changed {
|
||||
guild.UpdateBridgeInfo()
|
||||
guild.Update()
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func (guild *Guild) UpdateName(name string) bool {
|
||||
if guild.Name == name && guild.NameSet {
|
||||
return false
|
||||
}
|
||||
guild.Name = name
|
||||
guild.NameSet = false
|
||||
if guild.MXID != "" {
|
||||
_, err := guild.bridge.Bot.SetRoomName(guild.MXID, guild.Name)
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to update room name: %s", err)
|
||||
} else {
|
||||
guild.NameSet = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (guild *Guild) UpdateAvatar(iconID string) bool {
|
||||
if guild.Avatar == iconID && guild.AvatarSet {
|
||||
return false
|
||||
}
|
||||
guild.AvatarSet = false
|
||||
guild.Avatar = iconID
|
||||
if guild.Avatar != "" {
|
||||
var err error
|
||||
guild.AvatarURL, err = uploadAvatar(guild.bridge.Bot, discordgo.EndpointGuildIcon(guild.ID, iconID))
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to reupload avatar %s: %v", iconID, err)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
guild.AvatarURL = id.ContentURI{}
|
||||
}
|
||||
if guild.MXID != "" {
|
||||
_, err := guild.bridge.Bot.SetRoomAvatar(guild.MXID, guild.AvatarURL)
|
||||
if err != nil {
|
||||
guild.log.Warnln("Failed to update room avatar:", err)
|
||||
} else {
|
||||
guild.AvatarSet = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
17
main.go
17
main.go
@@ -20,13 +20,13 @@ import (
|
||||
_ "embed"
|
||||
"sync"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/bridge/commands"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/configupgrade"
|
||||
|
||||
"go.mau.fi/mautrix-discord/config"
|
||||
"go.mau.fi/mautrix-discord/database"
|
||||
)
|
||||
|
||||
// Information to find out exactly which commit the bridge was built from.
|
||||
@@ -59,6 +59,14 @@ type DiscordBridge struct {
|
||||
portalsByID map[database.PortalKey]*Portal
|
||||
portalsLock sync.Mutex
|
||||
|
||||
threadsByID map[string]*Thread
|
||||
threadsByRootMXID map[id.EventID]*Thread
|
||||
threadsLock sync.Mutex
|
||||
|
||||
guildsByMXID map[id.RoomID]*Guild
|
||||
guildsByID map[string]*Guild
|
||||
guildsLock sync.Mutex
|
||||
|
||||
puppets map[string]*Puppet
|
||||
puppetsByCustomMXID map[id.UserID]*Puppet
|
||||
puppetsLock sync.Mutex
|
||||
@@ -81,6 +89,7 @@ func (br *DiscordBridge) Init() {
|
||||
br.RegisterCommands()
|
||||
|
||||
br.DB = database.New(br.Bridge.DB)
|
||||
discordLog = br.Log.Sub("Discord")
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) Start() {
|
||||
@@ -144,6 +153,12 @@ func main() {
|
||||
portalsByMXID: make(map[id.RoomID]*Portal),
|
||||
portalsByID: make(map[database.PortalKey]*Portal),
|
||||
|
||||
threadsByID: make(map[string]*Thread),
|
||||
threadsByRootMXID: make(map[id.EventID]*Thread),
|
||||
|
||||
guildsByID: make(map[string]*Guild),
|
||||
guildsByMXID: make(map[id.RoomID]*Guild),
|
||||
|
||||
puppets: make(map[string]*Puppet),
|
||||
puppetsByCustomMXID: make(map[id.UserID]*Puppet),
|
||||
}
|
||||
|
||||
@@ -49,7 +49,6 @@ func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI {
|
||||
r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
|
||||
r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
|
||||
|
||||
// Setup the guild endpoints
|
||||
r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet)
|
||||
r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost)
|
||||
r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost)
|
||||
@@ -182,8 +181,8 @@ func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
user.Lock()
|
||||
if user.ID != "" {
|
||||
discord["id"] = user.ID
|
||||
if user.DiscordID != "" {
|
||||
discord["id"] = user.DiscordID
|
||||
}
|
||||
|
||||
if user.Session != nil {
|
||||
@@ -338,7 +337,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user.ID = discordUser.UserID
|
||||
user.DiscordID = discordUser.UserID
|
||||
user.Update()
|
||||
|
||||
if err := user.Login(discordUser.Token); err != nil {
|
||||
@@ -354,7 +353,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
c.WriteJSON(map[string]interface{}{
|
||||
"success": true,
|
||||
"id": user.ID,
|
||||
"id": user.DiscordID,
|
||||
})
|
||||
|
||||
return
|
||||
@@ -392,19 +391,17 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
|
||||
func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) {
|
||||
user := r.Context().Value("user").(*User)
|
||||
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
data := make([]map[string]interface{}, len(user.guilds))
|
||||
idx := 0
|
||||
for _, guild := range user.guilds {
|
||||
data[idx] = map[string]interface{}{
|
||||
"name": guild.GuildName,
|
||||
"id": guild.GuildID,
|
||||
"bridged": guild.Bridge,
|
||||
var data []map[string]interface{}
|
||||
for _, userGuild := range user.GetGuilds() {
|
||||
guild := p.bridge.GetGuildByID(userGuild.GuildID, false)
|
||||
if guild == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
idx++
|
||||
data = append(data, map[string]interface{}{
|
||||
"name": guild.Name,
|
||||
"id": guild.ID,
|
||||
"mxid": guild.MXID,
|
||||
})
|
||||
}
|
||||
|
||||
jsonResponse(w, http.StatusOK, data)
|
||||
|
||||
163
puppet.go
163
puppet.go
@@ -7,6 +7,8 @@ import (
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/bridge"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -66,12 +68,12 @@ func (br *DiscordBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) {
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
|
||||
id, ok := br.ParsePuppetMXID(mxid)
|
||||
discordID, ok := br.ParsePuppetMXID(mxid)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return br.GetPuppetByID(id)
|
||||
return br.GetPuppetByID(discordID)
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetPuppetByID(id string) *Puppet {
|
||||
@@ -159,7 +161,7 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
|
||||
}
|
||||
|
||||
func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI {
|
||||
if puppet.customIntent == nil {
|
||||
if puppet.customIntent == nil || (portal.Key.Receiver != "" && portal.Key.Receiver != puppet.ID) {
|
||||
return puppet.DefaultIntent()
|
||||
}
|
||||
|
||||
@@ -179,121 +181,86 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
|
||||
}
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updateName(source *User) bool {
|
||||
user, err := source.Session.User(puppet.ID)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("failed to get user from id:", err)
|
||||
func (puppet *Puppet) UpdateName(info *discordgo.User) bool {
|
||||
newName := puppet.bridge.Config.Bridge.FormatDisplayname(info)
|
||||
if puppet.Name == newName && puppet.NameSet {
|
||||
return false
|
||||
}
|
||||
|
||||
newName := puppet.bridge.Config.Bridge.FormatDisplayname(user)
|
||||
|
||||
if puppet.DisplayName != newName {
|
||||
err := puppet.DefaultIntent().SetDisplayName(newName)
|
||||
if err == nil {
|
||||
puppet.DisplayName = newName
|
||||
go puppet.updatePortalName()
|
||||
puppet.Update()
|
||||
} else {
|
||||
puppet.log.Warnln("failed to set display name:", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updatePortalName() {
|
||||
puppet.updatePortalMeta(func(portal *Portal) {
|
||||
if portal.MXID != "" {
|
||||
_, err := portal.MainIntent().SetRoomName(portal.MXID, puppet.DisplayName)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to set name:", err)
|
||||
puppet.Name = newName
|
||||
puppet.NameSet = false
|
||||
err := puppet.DefaultIntent().SetDisplayName(newName)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to update displayname:", err)
|
||||
} else {
|
||||
go puppet.updatePortalMeta(func(portal *Portal) {
|
||||
if portal.UpdateName(puppet.Name) {
|
||||
portal.Update()
|
||||
}
|
||||
}
|
||||
|
||||
portal.Name = puppet.DisplayName
|
||||
portal.Update()
|
||||
})
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updateAvatar(source *User) bool {
|
||||
user, err := source.Session.User(puppet.ID)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to get user:", err)
|
||||
|
||||
return false
|
||||
})
|
||||
puppet.NameSet = true
|
||||
}
|
||||
|
||||
if puppet.Avatar == user.Avatar {
|
||||
return false
|
||||
}
|
||||
|
||||
if user.Avatar == "" {
|
||||
puppet.log.Warnln("User does not have an avatar")
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
url, err := uploadAvatar(puppet.DefaultIntent(), user.AvatarURL(""))
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to upload user avatar:", err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
puppet.AvatarURL = url
|
||||
|
||||
err = puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to set avatar:", err)
|
||||
}
|
||||
|
||||
puppet.log.Debugln("Updated avatar", puppet.Avatar, "->", user.Avatar)
|
||||
puppet.Avatar = user.Avatar
|
||||
go puppet.updatePortalAvatar()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (puppet *Puppet) updatePortalAvatar() {
|
||||
puppet.updatePortalMeta(func(portal *Portal) {
|
||||
if portal.MXID != "" {
|
||||
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
|
||||
if err != nil {
|
||||
portal.log.Warnln("Failed to set avatar:", err)
|
||||
}
|
||||
func (puppet *Puppet) UpdateAvatar(info *discordgo.User) bool {
|
||||
if puppet.Avatar == info.Avatar && puppet.AvatarSet {
|
||||
return false
|
||||
}
|
||||
puppet.Avatar = info.Avatar
|
||||
puppet.AvatarSet = false
|
||||
|
||||
if puppet.Avatar == "" {
|
||||
// TODO should we just use discord's default avatars?
|
||||
puppet.AvatarURL = id.ContentURI{}
|
||||
} else {
|
||||
url, err := uploadAvatar(puppet.DefaultIntent(), info.AvatarURL(""))
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to reupload user avatar:", err)
|
||||
return true
|
||||
}
|
||||
puppet.AvatarURL = url
|
||||
}
|
||||
|
||||
portal.AvatarURL = puppet.AvatarURL
|
||||
portal.Avatar = puppet.Avatar
|
||||
portal.Update()
|
||||
})
|
||||
|
||||
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
|
||||
if err != nil {
|
||||
puppet.log.Warnln("Failed to update avatar:", err)
|
||||
} else {
|
||||
go puppet.updatePortalMeta(func(portal *Portal) {
|
||||
if portal.UpdateAvatarFromPuppet(puppet) {
|
||||
portal.Update()
|
||||
}
|
||||
})
|
||||
puppet.AvatarSet = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (puppet *Puppet) SyncContact(source *User) {
|
||||
func (puppet *Puppet) UpdateInfo(source *User, info *discordgo.User) {
|
||||
puppet.syncLock.Lock()
|
||||
defer puppet.syncLock.Unlock()
|
||||
|
||||
puppet.log.Debugln("syncing contact", puppet.DisplayName)
|
||||
if info == nil || len(info.Username) == 0 || len(info.Discriminator) == 0 {
|
||||
if puppet.Name != "" {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
puppet.log.Debugfln("Fetching info through %s to update", source.DiscordID)
|
||||
info, err = source.Session.User(puppet.ID)
|
||||
if err != nil {
|
||||
puppet.log.Errorfln("Failed to fetch info through %s: %v", source.DiscordID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err := puppet.DefaultIntent().EnsureRegistered()
|
||||
if err != nil {
|
||||
puppet.log.Errorln("Failed to ensure registered:", err)
|
||||
}
|
||||
|
||||
update := false
|
||||
|
||||
update = puppet.updateName(source) || update
|
||||
|
||||
if puppet.Avatar == "" {
|
||||
update = puppet.updateAvatar(source) || update
|
||||
puppet.log.Debugln("update avatar returned", update)
|
||||
}
|
||||
|
||||
if update {
|
||||
changed := false
|
||||
changed = puppet.UpdateName(info) || changed
|
||||
changed = puppet.UpdateAvatar(info) || changed
|
||||
if changed {
|
||||
puppet.Update()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
sync.Mutex
|
||||
|
||||
URL string
|
||||
Origin string
|
||||
URL string
|
||||
|
||||
conn *websocket.Conn
|
||||
|
||||
@@ -43,7 +44,6 @@ func New() (*Client, error) {
|
||||
|
||||
return &Client{
|
||||
URL: "wss://remote-auth-gateway.discord.gg/?v=1",
|
||||
Origin: "https://discord.com",
|
||||
privateKey: privateKey,
|
||||
}, nil
|
||||
}
|
||||
@@ -54,8 +54,9 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
header := http.Header{
|
||||
"Origin": []string{c.Origin},
|
||||
header := http.Header{}
|
||||
for key, value := range discordgo.DroidWSHeaders {
|
||||
header.Set(key, value)
|
||||
}
|
||||
|
||||
c.qrChan = qrChan
|
||||
|
||||
53
thread.go
Normal file
53
thread.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"go.mau.fi/mautrix-discord/database"
|
||||
)
|
||||
|
||||
type Thread struct {
|
||||
*database.Thread
|
||||
Parent *Portal
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
|
||||
br.threadsLock.Lock()
|
||||
defer br.threadsLock.Unlock()
|
||||
thread, ok := br.threadsByID[id]
|
||||
if !ok {
|
||||
return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root)
|
||||
}
|
||||
return thread
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
|
||||
br.threadsLock.Lock()
|
||||
defer br.threadsLock.Unlock()
|
||||
thread, ok := br.threadsByRootMXID[mxid]
|
||||
if !ok {
|
||||
return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil)
|
||||
}
|
||||
return thread
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
|
||||
if dbThread == nil {
|
||||
if root == nil {
|
||||
return nil
|
||||
}
|
||||
dbThread = br.DB.Thread.New()
|
||||
dbThread.ID = id
|
||||
dbThread.RootDiscordID = root.DiscordID
|
||||
dbThread.RootMXID = root.MXID
|
||||
dbThread.ParentID = root.Channel.ChannelID
|
||||
dbThread.Insert()
|
||||
}
|
||||
thread := &Thread{
|
||||
Thread: dbThread,
|
||||
}
|
||||
thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
|
||||
br.threadsByID[thread.ID] = thread
|
||||
br.threadsByRootMXID[thread.RootMXID] = thread
|
||||
return thread
|
||||
}
|
||||
567
user.go
567
user.go
@@ -4,11 +4,15 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
@@ -35,12 +39,43 @@ type User struct {
|
||||
|
||||
PermissionLevel bridgeconfig.PermissionLevel
|
||||
|
||||
guilds map[string]*database.Guild
|
||||
guildsLock sync.Mutex
|
||||
spaceCreateLock sync.Mutex
|
||||
spaceMembershipChecked bool
|
||||
|
||||
Session *discordgo.Session
|
||||
}
|
||||
|
||||
var discordLog log.Logger
|
||||
|
||||
func init() {
|
||||
discordgo.Logger = func(msgL, caller int, format string, a ...interface{}) {
|
||||
pc, file, line, _ := runtime.Caller(caller + 1)
|
||||
|
||||
files := strings.Split(file, "/")
|
||||
file = files[len(files)-1]
|
||||
|
||||
name := runtime.FuncForPC(pc).Name()
|
||||
fns := strings.Split(name, ".")
|
||||
name = fns[len(fns)-1]
|
||||
|
||||
msg := fmt.Sprintf(format, a...)
|
||||
|
||||
var level log.Level
|
||||
switch msgL {
|
||||
case discordgo.LogError:
|
||||
level = log.LevelError
|
||||
case discordgo.LogWarning:
|
||||
level = log.LevelWarn
|
||||
case discordgo.LogInformational:
|
||||
level = log.LevelInfo
|
||||
case discordgo.LogDebug:
|
||||
level = log.LevelDebug
|
||||
}
|
||||
|
||||
discordLog.Logfln(level, "%s:%d:%s() %s", file, line, name, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel {
|
||||
return user.PermissionLevel
|
||||
}
|
||||
@@ -66,10 +101,10 @@ func (user *User) GetIDoublePuppet() bridge.DoublePuppet {
|
||||
}
|
||||
|
||||
func (user *User) GetIGhost() bridge.Ghost {
|
||||
if user.ID == "" {
|
||||
if user.DiscordID == "" {
|
||||
return nil
|
||||
}
|
||||
p := user.bridge.GetPuppetByID(user.ID)
|
||||
p := user.bridge.GetPuppetByID(user.DiscordID)
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -78,14 +113,6 @@ func (user *User) GetIGhost() bridge.Ghost {
|
||||
|
||||
var _ bridge.User = (*User)(nil)
|
||||
|
||||
// this assume you are holding the guilds lock!!!
|
||||
func (user *User) loadGuilds() {
|
||||
user.guilds = map[string]*database.Guild{}
|
||||
for _, guild := range user.bridge.DB.Guild.GetAll(user.ID) {
|
||||
user.guilds[guild.GuildID] = guild
|
||||
}
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
|
||||
// If we weren't passed in a user we attempt to create one if we were given
|
||||
// a matrix id.
|
||||
@@ -103,8 +130,8 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User
|
||||
|
||||
// We assume the usersLock was acquired by our caller.
|
||||
br.usersByMXID[user.MXID] = user
|
||||
if user.ID != "" {
|
||||
br.usersByID[user.ID] = user
|
||||
if user.DiscordID != "" {
|
||||
br.usersByID[user.DiscordID] = user
|
||||
}
|
||||
|
||||
if user.ManagementRoom != "" {
|
||||
@@ -114,17 +141,10 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User
|
||||
br.managementRoomsLock.Unlock()
|
||||
}
|
||||
|
||||
// Load our guilds state from the database and turn it into a map
|
||||
user.guildsLock.Lock()
|
||||
user.loadGuilds()
|
||||
user.guildsLock.Unlock()
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) GetUserByMXID(userID id.UserID) *User {
|
||||
// TODO: check if puppet
|
||||
|
||||
br.usersLock.Lock()
|
||||
defer br.usersLock.Unlock()
|
||||
|
||||
@@ -153,7 +173,6 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
|
||||
User: dbUser,
|
||||
bridge: br,
|
||||
log: br.Log.Sub("User").Sub(string(dbUser.MXID)),
|
||||
guilds: map[string]*database.Guild{},
|
||||
}
|
||||
|
||||
user.PermissionLevel = br.Config.Bridge.Permissions.Get(user.MXID)
|
||||
@@ -161,11 +180,11 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
|
||||
return user
|
||||
}
|
||||
|
||||
func (br *DiscordBridge) getAllUsers() []*User {
|
||||
func (br *DiscordBridge) getAllUsersWithToken() []*User {
|
||||
br.usersLock.Lock()
|
||||
defer br.usersLock.Unlock()
|
||||
|
||||
dbUsers := br.DB.User.GetAll()
|
||||
dbUsers := br.DB.User.GetAllWithToken()
|
||||
users := make([]*User, len(dbUsers))
|
||||
|
||||
for idx, dbUser := range dbUsers {
|
||||
@@ -182,7 +201,7 @@ func (br *DiscordBridge) getAllUsers() []*User {
|
||||
func (br *DiscordBridge) startUsers() {
|
||||
br.Log.Debugln("Starting users")
|
||||
|
||||
for _, u := range br.getAllUsers() {
|
||||
for _, u := range br.getAllUsersWithToken() {
|
||||
go func(user *User) {
|
||||
err := user.Connect()
|
||||
if err != nil {
|
||||
@@ -209,10 +228,6 @@ func (user *User) SetManagementRoom(roomID id.RoomID) {
|
||||
|
||||
existing, ok := user.bridge.managementRooms[roomID]
|
||||
if ok {
|
||||
// If there's a user already assigned to this management room, clear it
|
||||
// out.
|
||||
// I think this is due a name change or something? I dunno, leaving it
|
||||
// for now.
|
||||
existing.ManagementRoom = ""
|
||||
existing.Update()
|
||||
}
|
||||
@@ -222,6 +237,52 @@ func (user *User) SetManagementRoom(roomID id.RoomID) {
|
||||
user.Update()
|
||||
}
|
||||
|
||||
func (user *User) GetSpaceRoom() id.RoomID {
|
||||
if len(user.SpaceRoom) == 0 {
|
||||
user.spaceCreateLock.Lock()
|
||||
defer user.spaceCreateLock.Unlock()
|
||||
if len(user.SpaceRoom) > 0 {
|
||||
return user.SpaceRoom
|
||||
}
|
||||
|
||||
resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{
|
||||
Visibility: "private",
|
||||
Name: "Discord",
|
||||
Topic: "Your Discord bridged chats",
|
||||
InitialState: []*event.Event{{
|
||||
Type: event.StateRoomAvatar,
|
||||
Content: event.Content{
|
||||
Parsed: &event.RoomAvatarEventContent{
|
||||
URL: user.bridge.Config.AppService.Bot.ParsedAvatar,
|
||||
},
|
||||
},
|
||||
}},
|
||||
CreationContent: map[string]interface{}{
|
||||
"type": event.RoomTypeSpace,
|
||||
},
|
||||
PowerLevelOverride: &event.PowerLevelsEventContent{
|
||||
Users: map[id.UserID]int{
|
||||
user.bridge.Bot.UserID: 9001,
|
||||
user.MXID: 50,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
user.log.Errorln("Failed to auto-create space room:", err)
|
||||
} else {
|
||||
user.SpaceRoom = resp.RoomID
|
||||
user.Update()
|
||||
user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false)
|
||||
}
|
||||
} else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(user.SpaceRoom, user.MXID) {
|
||||
user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false)
|
||||
}
|
||||
user.spaceMembershipChecked = true
|
||||
|
||||
return user.SpaceRoom
|
||||
}
|
||||
|
||||
func (user *User) tryAutomaticDoublePuppeting() {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
@@ -232,7 +293,7 @@ func (user *User) tryAutomaticDoublePuppeting() {
|
||||
|
||||
user.log.Debugln("Checking if double puppeting needs to be enabled")
|
||||
|
||||
puppet := user.bridge.GetPuppetByID(user.ID)
|
||||
puppet := user.bridge.GetPuppetByID(user.DiscordID)
|
||||
if puppet.CustomMXID != "" {
|
||||
user.log.Debugln("User already has double-puppeting enabled")
|
||||
|
||||
@@ -270,7 +331,7 @@ func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool)
|
||||
}
|
||||
|
||||
func (user *User) Login(token string) error {
|
||||
user.Token = token
|
||||
user.DiscordToken = token
|
||||
user.Update()
|
||||
return user.Connect()
|
||||
}
|
||||
@@ -279,7 +340,7 @@ func (user *User) IsLoggedIn() bool {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
return user.Token != ""
|
||||
return user.DiscordToken != ""
|
||||
}
|
||||
|
||||
func (user *User) Logout() error {
|
||||
@@ -290,7 +351,7 @@ func (user *User) Logout() error {
|
||||
return ErrNotLoggedIn
|
||||
}
|
||||
|
||||
puppet := user.bridge.GetPuppetByID(user.ID)
|
||||
puppet := user.bridge.GetPuppetByID(user.DiscordID)
|
||||
if puppet.CustomMXID != "" {
|
||||
err := puppet.SwitchCustomMXID("", "")
|
||||
if err != nil {
|
||||
@@ -304,7 +365,7 @@ func (user *User) Logout() error {
|
||||
|
||||
user.Session = nil
|
||||
|
||||
user.Token = ""
|
||||
user.DiscordToken = ""
|
||||
user.Update()
|
||||
|
||||
return nil
|
||||
@@ -321,16 +382,20 @@ func (user *User) Connect() error {
|
||||
user.Lock()
|
||||
defer user.Unlock()
|
||||
|
||||
if user.Token == "" {
|
||||
if user.DiscordToken == "" {
|
||||
return ErrNotLoggedIn
|
||||
}
|
||||
|
||||
user.log.Debugln("connecting to discord")
|
||||
user.log.Debugln("Connecting to discord")
|
||||
|
||||
session, err := discordgo.New(user.Token)
|
||||
session, err := discordgo.New(user.DiscordToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO move to config
|
||||
if os.Getenv("DISCORD_DEBUG") == "1" {
|
||||
session.LogLevel = discordgo.LogDebug
|
||||
}
|
||||
|
||||
user.Session = session
|
||||
|
||||
@@ -382,263 +447,178 @@ func (user *User) bridgeMessage(guildID string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
if guild, found := user.guilds[guildID]; found {
|
||||
if guild.Bridge {
|
||||
return true
|
||||
}
|
||||
guild := user.bridge.GetGuildByID(guildID, false)
|
||||
if guild.MXID != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
user.log.Debugfln("ignoring message for non-bridged guild %s-%s", user.ID, guildID)
|
||||
|
||||
user.log.Debugfln("Cgnoring message for non-bridged guild %s", guildID)
|
||||
return false
|
||||
}
|
||||
|
||||
func (user *User) readyHandler(s *discordgo.Session, r *discordgo.Ready) {
|
||||
user.log.Debugln("discord connection ready")
|
||||
func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) {
|
||||
user.log.Debugln("Discord connection ready")
|
||||
|
||||
// Update our user fields
|
||||
user.ID = r.User.ID
|
||||
|
||||
// Update our guild map to match watch discord thinks we're in. This is the
|
||||
// only time we can get the full guild map as discordgo doesn't make it
|
||||
// available to us later. Also, discord might not give us the full guild
|
||||
// information here, so we use this to remove guilds the user left and only
|
||||
// add guilds whose full information we have. The are told about the
|
||||
// "unavailable" guilds later via the GuildCreate handler.
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
// build a list of the current guilds we're in so we can prune the old ones
|
||||
current := []string{}
|
||||
|
||||
user.log.Debugln("database guild count", len(user.guilds))
|
||||
user.log.Debugln("discord guild count", len(r.Guilds))
|
||||
|
||||
for _, guild := range r.Guilds {
|
||||
current = append(current, guild.ID)
|
||||
|
||||
// If we already know about this guild, make sure we reset it's bridge
|
||||
// status.
|
||||
if val, found := user.guilds[guild.ID]; found {
|
||||
bridge := val.Bridge
|
||||
user.guilds[guild.ID].Bridge = bridge
|
||||
|
||||
// Update the name if the guild is available
|
||||
if !guild.Unavailable {
|
||||
user.guilds[guild.ID].GuildName = guild.Name
|
||||
}
|
||||
|
||||
val.Upsert()
|
||||
} else {
|
||||
g := user.bridge.DB.Guild.New()
|
||||
g.DiscordID = user.ID
|
||||
g.GuildID = guild.ID
|
||||
user.guilds[guild.ID] = g
|
||||
|
||||
if !guild.Unavailable {
|
||||
g.GuildName = guild.Name
|
||||
}
|
||||
|
||||
g.Upsert()
|
||||
}
|
||||
if user.DiscordID != r.User.ID {
|
||||
user.DiscordID = r.User.ID
|
||||
user.Update()
|
||||
}
|
||||
|
||||
// Sync the guilds to the database.
|
||||
user.bridge.DB.Guild.Prune(user.ID, current)
|
||||
|
||||
// Finally reload from the database since it purged servers we're not in
|
||||
// anymore.
|
||||
user.loadGuilds()
|
||||
|
||||
user.log.Debugln("updated database guild count", len(user.guilds))
|
||||
|
||||
user.Update()
|
||||
updateTS := time.Now()
|
||||
guildsInSpace := make(map[string]bool)
|
||||
for _, guild := range user.GetGuilds() {
|
||||
guildsInSpace[guild.GuildID] = guild.InSpace
|
||||
}
|
||||
for _, guild := range r.Guilds {
|
||||
user.handleGuild(guild, updateTS, guildsInSpace[guild.ID])
|
||||
}
|
||||
user.PruneGuildList(updateTS)
|
||||
const maxCreate = 5
|
||||
for i, ch := range r.PrivateChannels {
|
||||
portal := user.GetPortalByMeta(ch)
|
||||
if i < maxCreate && portal.MXID == "" {
|
||||
err := portal.CreateMatrixRoom(user, ch)
|
||||
if err != nil {
|
||||
user.log.Errorfln("Failed to create portal for private channel %s in initial sync: %v", ch.ID, err)
|
||||
}
|
||||
} else {
|
||||
portal.UpdateInfo(user, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {
|
||||
user.log.Debugln("connected to discord")
|
||||
func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSpace bool) {
|
||||
guild := user.bridge.GetGuildByID(meta.ID, true)
|
||||
guild.UpdateInfo(user, meta)
|
||||
if len(meta.Channels) > 0 {
|
||||
for _, ch := range meta.Channels {
|
||||
portal := user.GetPortalByMeta(ch)
|
||||
if guild.AutoBridgeChannels && portal.MXID == "" {
|
||||
err := portal.CreateMatrixRoom(user, ch)
|
||||
if err != nil {
|
||||
user.log.Errorfln("Failed to create portal for guild channel %s/%s in initial sync: %v", guild.ID, ch.ID, err)
|
||||
}
|
||||
} else {
|
||||
portal.UpdateInfo(user, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(guild.MXID) > 0 && !isInSpace {
|
||||
_, err := user.bridge.Bot.SendStateEvent(user.GetSpaceRoom(), event.StateSpaceChild, guild.MXID.String(), &event.SpaceChildEventContent{
|
||||
Via: []string{user.bridge.AS.HomeserverDomain},
|
||||
})
|
||||
if err != nil {
|
||||
user.log.Errorfln("Failed to add guild space %s to user space: %v", guild.MXID, err)
|
||||
} else {
|
||||
isInSpace = true
|
||||
}
|
||||
}
|
||||
user.MarkInGuild(database.UserGuild{GuildID: meta.ID, Timestamp: timestamp, InSpace: isInSpace})
|
||||
}
|
||||
|
||||
func (user *User) connectedHandler(_ *discordgo.Session, c *discordgo.Connect) {
|
||||
user.log.Debugln("Connected to discord")
|
||||
|
||||
user.tryAutomaticDoublePuppeting()
|
||||
}
|
||||
|
||||
func (user *User) disconnectedHandler(s *discordgo.Session, d *discordgo.Disconnect) {
|
||||
user.log.Debugln("disconnected from discord")
|
||||
func (user *User) disconnectedHandler(_ *discordgo.Session, d *discordgo.Disconnect) {
|
||||
user.log.Debugln("Disconnected from discord")
|
||||
}
|
||||
|
||||
func (user *User) guildCreateHandler(s *discordgo.Session, g *discordgo.GuildCreate) {
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
// If we somehow already know about the guild, just update it's name
|
||||
if guild, found := user.guilds[g.ID]; found {
|
||||
guild.GuildName = g.Name
|
||||
guild.Upsert()
|
||||
func (user *User) guildCreateHandler(_ *discordgo.Session, g *discordgo.GuildCreate) {
|
||||
user.handleGuild(g.Guild, time.Now(), false)
|
||||
}
|
||||
|
||||
func (user *User) guildDeleteHandler(_ *discordgo.Session, g *discordgo.GuildDelete) {
|
||||
user.MarkNotInGuild(g.ID)
|
||||
guild := user.bridge.GetGuildByID(g.ID, false)
|
||||
if guild == nil || guild.MXID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// This is a brand new guild so lets get it added.
|
||||
guild := user.bridge.DB.Guild.New()
|
||||
guild.DiscordID = user.ID
|
||||
guild.GuildID = g.ID
|
||||
guild.GuildName = g.Name
|
||||
guild.Upsert()
|
||||
|
||||
user.guilds[g.ID] = guild
|
||||
// TODO clean up?
|
||||
}
|
||||
|
||||
func (user *User) guildDeleteHandler(s *discordgo.Session, g *discordgo.GuildDelete) {
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
func (user *User) guildUpdateHandler(_ *discordgo.Session, g *discordgo.GuildUpdate) {
|
||||
user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID))
|
||||
}
|
||||
|
||||
if guild, found := user.guilds[g.ID]; found {
|
||||
guild.Delete()
|
||||
delete(user.guilds, g.ID)
|
||||
user.log.Debugln("deleted guild", g.Guild.ID)
|
||||
func (user *User) channelCreateHandler(_ *discordgo.Session, c *discordgo.ChannelCreate) {
|
||||
if !user.bridgeMessage(c.GuildID) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate) {
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
// If we somehow already know about the guild, just update it's name
|
||||
if guild, found := user.guilds[g.ID]; found {
|
||||
guild.GuildName = g.Name
|
||||
guild.Upsert()
|
||||
|
||||
user.log.Debugln("updated guild", g.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) createChannel(c *discordgo.Channel) {
|
||||
key := database.NewPortalKey(c.ID, user.User.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
portal := user.GetPortalByMeta(c.Channel)
|
||||
if portal.MXID != "" {
|
||||
return
|
||||
}
|
||||
|
||||
portal.Name = c.Name
|
||||
portal.Topic = c.Topic
|
||||
portal.Type = c.Type
|
||||
|
||||
if portal.Type == discordgo.ChannelTypeDM {
|
||||
portal.OtherUserID = c.Recipients[0].ID
|
||||
err := portal.CreateMatrixRoom(user, c.Channel)
|
||||
if err != nil {
|
||||
user.log.Errorfln("Error creating Matrix room for %s on channel create event: %v", c.ID, err)
|
||||
}
|
||||
|
||||
if c.Icon != "" {
|
||||
user.log.Debugln("channel icon", c.Icon)
|
||||
}
|
||||
|
||||
portal.Update()
|
||||
|
||||
portal.createMatrixRoom(user, c)
|
||||
}
|
||||
|
||||
func (user *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) {
|
||||
user.createChannel(c.Channel)
|
||||
}
|
||||
|
||||
func (user *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) {
|
||||
func (user *User) channelDeleteHandler(_ *discordgo.Session, c *discordgo.ChannelDelete) {
|
||||
user.log.Debugln("channel delete handler")
|
||||
}
|
||||
|
||||
func (user *User) channelPinsUpdateHandler(s *discordgo.Session, c *discordgo.ChannelPinsUpdate) {
|
||||
func (user *User) channelPinsUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelPinsUpdate) {
|
||||
user.log.Debugln("channel pins update")
|
||||
}
|
||||
|
||||
func (user *User) channelUpdateHandler(s *discordgo.Session, c *discordgo.ChannelUpdate) {
|
||||
key := database.NewPortalKey(c.ID, user.User.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
portal.update(user, c.Channel)
|
||||
func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelUpdate) {
|
||||
portal := user.GetPortalByMeta(c.Channel)
|
||||
portal.UpdateInfo(user, c.Channel)
|
||||
}
|
||||
|
||||
func (user *User) messageCreateHandler(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
if !user.bridgeMessage(m.GuildID) {
|
||||
func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) {
|
||||
fmt.Printf("%+v\n", msg)
|
||||
if !user.bridgeMessage(guildID) {
|
||||
return
|
||||
}
|
||||
|
||||
key := database.NewPortalKey(m.ChannelID, user.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
msg := portalDiscordMessage{
|
||||
msg: m,
|
||||
user: user,
|
||||
portal := user.GetExistingPortalByID(channelID)
|
||||
var thread *Thread
|
||||
if portal == nil {
|
||||
thread = user.bridge.GetThreadByID(channelID, nil)
|
||||
if thread == nil || thread.Parent == nil {
|
||||
user.log.Debugfln("Dropping %s in unknown channel %s/%s", typeName, guildID, channelID)
|
||||
return
|
||||
}
|
||||
portal = thread.Parent
|
||||
}
|
||||
|
||||
portal.discordMessages <- msg
|
||||
portal.discordMessages <- portalDiscordMessage{
|
||||
msg: msg,
|
||||
user: user,
|
||||
thread: thread,
|
||||
}
|
||||
}
|
||||
|
||||
func (user *User) messageDeleteHandler(s *discordgo.Session, m *discordgo.MessageDelete) {
|
||||
if !user.bridgeMessage(m.GuildID) {
|
||||
return
|
||||
}
|
||||
|
||||
key := database.NewPortalKey(m.ChannelID, user.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
msg := portalDiscordMessage{
|
||||
msg: m,
|
||||
user: user,
|
||||
}
|
||||
|
||||
portal.discordMessages <- msg
|
||||
func (user *User) messageCreateHandler(_ *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
user.pushPortalMessage(m, "message create", m.ChannelID, m.GuildID)
|
||||
}
|
||||
|
||||
func (user *User) messageUpdateHandler(s *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
if !user.bridgeMessage(m.GuildID) {
|
||||
return
|
||||
}
|
||||
|
||||
key := database.NewPortalKey(m.ChannelID, user.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
msg := portalDiscordMessage{
|
||||
msg: m,
|
||||
user: user,
|
||||
}
|
||||
|
||||
portal.discordMessages <- msg
|
||||
func (user *User) messageDeleteHandler(_ *discordgo.Session, m *discordgo.MessageDelete) {
|
||||
user.pushPortalMessage(m, "message delete", m.ChannelID, m.GuildID)
|
||||
}
|
||||
|
||||
func (user *User) reactionAddHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) {
|
||||
if !user.bridgeMessage(m.MessageReaction.GuildID) {
|
||||
return
|
||||
}
|
||||
|
||||
key := database.NewPortalKey(m.ChannelID, user.User.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
msg := portalDiscordMessage{
|
||||
msg: m,
|
||||
user: user,
|
||||
}
|
||||
|
||||
portal.discordMessages <- msg
|
||||
func (user *User) messageUpdateHandler(_ *discordgo.Session, m *discordgo.MessageUpdate) {
|
||||
user.pushPortalMessage(m, "message update", m.ChannelID, m.GuildID)
|
||||
}
|
||||
|
||||
func (user *User) reactionRemoveHandler(s *discordgo.Session, m *discordgo.MessageReactionRemove) {
|
||||
if !user.bridgeMessage(m.MessageReaction.GuildID) {
|
||||
return
|
||||
}
|
||||
func (user *User) reactionAddHandler(_ *discordgo.Session, m *discordgo.MessageReactionAdd) {
|
||||
user.pushPortalMessage(m, "reaction add", m.ChannelID, m.GuildID)
|
||||
}
|
||||
|
||||
key := database.NewPortalKey(m.ChannelID, user.User.ID)
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
|
||||
msg := portalDiscordMessage{
|
||||
msg: m,
|
||||
user: user,
|
||||
}
|
||||
|
||||
portal.discordMessages <- msg
|
||||
func (user *User) reactionRemoveHandler(_ *discordgo.Session, m *discordgo.MessageReactionRemove) {
|
||||
user.pushPortalMessage(m, "reaction remove", m.ChannelID, m.GuildID)
|
||||
}
|
||||
|
||||
func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool {
|
||||
if intent == nil {
|
||||
intent = user.bridge.Bot
|
||||
}
|
||||
ret := false
|
||||
|
||||
inviteContent := event.Content{
|
||||
@@ -682,7 +662,7 @@ func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID,
|
||||
func (user *User) getDirectChats() map[id.UserID][]id.RoomID {
|
||||
chats := map[id.UserID][]id.RoomID{}
|
||||
|
||||
privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.ID)
|
||||
privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.DiscordID)
|
||||
for _, portal := range privateChats {
|
||||
if portal.MXID != "" {
|
||||
puppetMXID := user.bridge.FormatPuppetMXID(portal.Key.Receiver)
|
||||
@@ -755,28 +735,21 @@ func (user *User) updateDirectChats(chats map[id.UserID][]id.RoomID) {
|
||||
}
|
||||
|
||||
func (user *User) bridgeGuild(guildID string, everything bool) error {
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
guild, found := user.guilds[guildID]
|
||||
if !found {
|
||||
return fmt.Errorf("guildID not found")
|
||||
guild := user.bridge.GetGuildByID(guildID, false)
|
||||
if guild == nil {
|
||||
return errors.New("guild not found")
|
||||
}
|
||||
|
||||
// Update the guild
|
||||
guild.Bridge = true
|
||||
guild.Upsert()
|
||||
|
||||
// If this is a full bridge, create portals for all the channels
|
||||
if everything {
|
||||
channels, err := user.Session.GuildChannels(guildID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
if channelIsBridgeable(channel) {
|
||||
user.createChannel(channel)
|
||||
meta, _ := user.Session.State.Guild(guildID)
|
||||
err := guild.CreateMatrixRoom(user, meta)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ch := range meta.Channels {
|
||||
portal := user.GetPortalByMeta(ch)
|
||||
if (everything && channelIsBridgeable(ch)) || ch.Type == discordgo.ChannelTypeGuildCategory {
|
||||
err = portal.CreateMatrixRoom(user, ch)
|
||||
if err != nil {
|
||||
user.log.Warnfln("Error creating room for guild channel %s: %v", ch.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -785,41 +758,41 @@ func (user *User) bridgeGuild(guildID string, everything bool) error {
|
||||
}
|
||||
|
||||
func (user *User) unbridgeGuild(guildID string) error {
|
||||
user.guildsLock.Lock()
|
||||
defer user.guildsLock.Unlock()
|
||||
|
||||
guild, exists := user.guilds[guildID]
|
||||
if !exists {
|
||||
return fmt.Errorf("guildID not found")
|
||||
}
|
||||
|
||||
if !guild.Bridge {
|
||||
return fmt.Errorf("guild not bridged")
|
||||
}
|
||||
|
||||
// First update the guild so we don't have any other go routines recreating
|
||||
// channels we're about to destroy.
|
||||
guild.Bridge = false
|
||||
guild.Upsert()
|
||||
|
||||
// Now run through the channels in the guild and remove any portals we
|
||||
// have for them.
|
||||
channels, err := user.Session.GuildChannels(guildID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
if channelIsBridgeable(channel) {
|
||||
key := database.PortalKey{
|
||||
ChannelID: channel.ID,
|
||||
Receiver: user.ID,
|
||||
}
|
||||
|
||||
portal := user.bridge.GetPortalByID(key)
|
||||
portal.leave(user)
|
||||
}
|
||||
}
|
||||
//user.guildsLock.Lock()
|
||||
//defer user.guildsLock.Unlock()
|
||||
//
|
||||
//guild, exists := user.guilds[guildID]
|
||||
//if !exists {
|
||||
// return fmt.Errorf("guildID not found")
|
||||
//}
|
||||
//
|
||||
//if !guild.Bridge {
|
||||
// return fmt.Errorf("guild not bridged")
|
||||
//}
|
||||
//
|
||||
//// First update the guild so we don't have any other go routines recreating
|
||||
//// channels we're about to destroy.
|
||||
//guild.Bridge = false
|
||||
//guild.Upsert()
|
||||
//
|
||||
//// Now run through the channels in the guild and remove any portals we
|
||||
//// have for them.
|
||||
//channels, err := user.Session.GuildChannels(guildID)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//for _, channel := range channels {
|
||||
// if channelIsBridgeable(channel) {
|
||||
// key := database.PortalKey{
|
||||
// ChannelID: channel.ID,
|
||||
// Receiver: user.DiscordID,
|
||||
// }
|
||||
//
|
||||
// portal := user.bridge.GetPortalByID(key)
|
||||
// portal.leave(user)
|
||||
// }
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user