Break a bunch of stuff

This commit is contained in:
Tulir Asokan
2022-05-28 23:03:24 +03:00
parent 575f684a54
commit 91dbc83b5d
26 changed files with 1824 additions and 1002 deletions

View File

@@ -3,7 +3,7 @@ package main
import (
"bytes"
"image"
"io/ioutil"
"io"
"net/http"
"strings"
@@ -29,17 +29,16 @@ func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) {
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", discordgo.DroidBrowserUserAgent)
for key, value := range discordgo.DroidDownloadHeaders {
req.Header.Set(key, value)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
return io.ReadAll(resp.Body)
}
func (portal *Portal) downloadMatrixAttachment(eventID id.EventID, content *event.MessageEventContent) ([]byte, error) {

View File

@@ -7,16 +7,25 @@ import (
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/id"
"github.com/bwmarrin/discordgo"
)
func uploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, error) {
getResp, err := http.DefaultClient.Get(url)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return id.ContentURI{}, fmt.Errorf("failed to prepare request: %w", err)
}
for key, value := range discordgo.DroidImageHeaders {
req.Header.Set(key, value)
}
getResp, err := http.DefaultClient.Do(req)
if err != nil {
return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err)
}
data, err := io.ReadAll(getResp.Body)
getResp.Body.Close()
_ = getResp.Body.Close()
if err != nil {
return id.ContentURI{}, fmt.Errorf("failed to read avatar data: %w", err)
}

View File

@@ -116,7 +116,7 @@ func fnLogin(ce *WrappedCommandEvent) {
ce.Reply("Error connecting after login: %v", err)
}
ce.User.Lock()
ce.User.ID = user.UserID
ce.User.DiscordID = user.UserID
ce.User.Update()
ce.User.Unlock()
ce.Reply("Successfully logged in as %s#%s", user.Username, user.Discriminator)
@@ -251,19 +251,21 @@ func fnGuilds(ce *WrappedCommandEvent) {
}
func fnListGuilds(ce *WrappedCommandEvent) {
ce.User.guildsLock.Lock()
defer ce.User.guildsLock.Unlock()
if len(ce.User.guilds) == 0 {
ce.Reply("You haven't joined any guilds")
} else {
var output strings.Builder
for _, guild := range ce.User.guilds {
for _, userGuild := range ce.User.GetGuilds() {
guild := ce.Bridge.GetGuildByID(userGuild.GuildID, false)
if guild == nil {
continue
}
status := "not bridged"
if guild.Bridge {
if guild.MXID != "" {
status = "bridged"
}
_, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.GuildName, guild.GuildID, status)
_, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.Name, guild.ID, status)
}
if output.Len() == 0 {
ce.Reply("No guilds found")
} else {
ce.Reply("List of guilds:\n\n%s", output.String())
}
}

View File

@@ -16,7 +16,7 @@ type AttachmentQuery struct {
}
const (
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment"
attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id FROM attachment"
)
func (aq *AttachmentQuery) New() *Attachment {
@@ -78,57 +78,49 @@ type Attachment struct {
log log.Logger
Channel PortalKey
DiscordMessageID string
DiscordAttachmentID string
ThreadID string
MessageID string
ID string
MXID id.EventID
}
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
err := row.Scan(
&a.DiscordAttachmentID, &a.DiscordMessageID,
&a.Channel.ChannelID, &a.Channel.Receiver,
&a.MXID)
func (a *Attachment) DiscordProtoChannelID() string {
if a.ThreadID != "" {
return a.ThreadID
} else {
return a.Channel.ChannelID
}
}
func (a *Attachment) Scan(row dbutil.Scannable) *Attachment {
err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
a.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
return a
}
func (a *Attachment) Insert() {
query := "INSERT INTO attachment" +
" (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " +
" mxid) VALUES ($1, $2, $3, $4, $5);"
_, err := a.db.Exec(
query,
a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordMessageID, a.DiscordAttachmentID,
a.MXID,
)
query := `
INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, mxid)
VALUES ($1, $2, $3, $4, $5)
`
_, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID)
if err != nil {
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}
func (a *Attachment) Delete() {
query := "DELETE FROM attachment WHERE" +
" dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
_, err := a.db.Exec(
query,
a.Channel.ChannelID, a.Channel.Receiver,
a.DiscordAttachmentID,
)
query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3`
_, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID)
if err != nil {
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err)
a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err)
panic(err)
}
}

View File

@@ -7,8 +7,9 @@ import (
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"go.mau.fi/mautrix-discord/database/upgrades"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/mautrix-discord/database/upgrades"
)
type Database struct {
@@ -18,6 +19,7 @@ type Database struct {
Portal *PortalQuery
Puppet *PuppetQuery
Message *MessageQuery
Thread *ThreadQuery
Reaction *ReactionQuery
Attachment *AttachmentQuery
Emoji *EmojiQuery
@@ -54,6 +56,10 @@ func New(baseDB *dbutil.Database) *Database {
db: db,
log: db.Log.Sub("Message"),
}
db.Thread = &ThreadQuery{
db: db,
log: db.Log.Sub("Thread"),
}
db.Reaction = &ReactionQuery{
db: db,
log: db.Log.Sub("Reaction"),
@@ -72,3 +78,10 @@ func New(baseDB *dbutil.Database) *Database {
}
return db
}
func strPtr(val string) *string {
if val == "" {
return nil
}
return &val
}

View File

@@ -28,23 +28,16 @@ func (eq *EmojiQuery) New() *Emoji {
func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji {
query := emojiSelect + " WHERE discord_id=$1"
return eq.get(query, discordID)
}
func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji {
query := emojiSelect + " WHERE matrix_url=$1"
return eq.get(query, matrixURL.String())
}
func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji {
row := eq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return eq.New().Scan(row)
return eq.New().Scan(eq.db.QueryRow(query, args...))
}
type Emoji struct {
@@ -59,18 +52,17 @@ type Emoji struct {
func (e *Emoji) Scan(row dbutil.Scannable) *Emoji {
var matrixURL sql.NullString
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
e.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
e.MatrixURL, _ = id.ParseContentURI(matrixURL.String)
return e
}
@@ -83,6 +75,7 @@ func (e *Emoji) Insert() {
if err != nil {
e.log.Warnfln("Failed to insert emoji %s: %v", e.DiscordID, err)
panic(err)
}
}
@@ -92,17 +85,15 @@ func (e *Emoji) Delete() {
_, err := e.db.Exec(query, e.DiscordID)
if err != nil {
e.log.Warnfln("Failed to delete emoji %s: %v", e.DiscordID, err)
panic(err)
}
}
func (e *Emoji) APIName() string {
if e.DiscordID != "" && e.DiscordName != "" {
return e.DiscordName + ":" + e.DiscordID
}
if e.DiscordName != "" {
} else if e.DiscordName != "" {
return e.DiscordName
}
return e.DiscordID
}

View File

@@ -3,9 +3,9 @@ package database
import (
"database/sql"
"errors"
"fmt"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
@@ -16,7 +16,7 @@ type GuildQuery struct {
}
const (
guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild"
guildSelect = "SELECT dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild"
)
func (gq *GuildQuery) New() *Guild {
@@ -26,109 +26,100 @@ func (gq *GuildQuery) New() *Guild {
}
}
func (gq *GuildQuery) Get(discordID, guildID string) *Guild {
query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2"
func (gq *GuildQuery) GetByID(dcid string) *Guild {
query := guildSelect + " WHERE dcid=$1"
return gq.New().Scan(gq.db.QueryRow(query, dcid))
}
row := gq.db.QueryRow(query, discordID, guildID)
if row == nil {
func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild {
query := guildSelect + " WHERE mxid=$1"
return gq.New().Scan(gq.db.QueryRow(query, mxid))
}
func (gq *GuildQuery) GetAll() []*Guild {
rows, err := gq.db.Query(guildSelect)
if err != nil {
gq.log.Errorln("Failed to query guilds:", err)
return nil
}
return gq.New().Scan(row)
}
func (gq *GuildQuery) GetAll(discordID string) []*Guild {
query := guildSelect + " WHERE discord_id=$1"
rows, err := gq.db.Query(query, discordID)
if err != nil || rows == nil {
return nil
}
guilds := []*Guild{}
var guilds []*Guild
for rows.Next() {
guilds = append(guilds, gq.New().Scan(rows))
guild := gq.New().Scan(rows)
if guild != nil {
guilds = append(guilds, guild)
}
}
return guilds
}
func (gq *GuildQuery) Prune(discordID string, guilds []string) {
// We need this interface slice because a variadic function can't mix
// arguements with a `...` expanded slice.
args := []interface{}{discordID}
nGuilds := len(guilds)
if nGuilds <= 0 {
return
}
gq.log.Debugfln("prunning guilds for %s", discordID)
// Build the in query
inQuery := "$2"
for i := 1; i < nGuilds; i++ {
inQuery += fmt.Sprintf(", $%d", i+2)
}
// Add the arguements for the build query
for _, guildID := range guilds {
args = append(args, guildID)
}
// Now remove any guilds that the user has left.
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" +
inQuery + ")"
_, err := gq.db.Exec(query, args...)
if err != nil {
gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err)
}
}
type Guild struct {
db *Database
log log.Logger
DiscordID string
GuildID string
GuildName string
Bridge bool
ID string
MXID id.RoomID
Name string
NameSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
AutoBridgeChannels bool
}
func (g *Guild) Scan(row dbutil.Scannable) *Guild {
err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge)
var mxid sql.NullString
var avatarURL string
err := row.Scan(&g.ID, &mxid, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
g.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
g.MXID = id.RoomID(mxid.String)
g.AvatarURL, _ = id.ParseContentURI(avatarURL)
return g
}
func (g *Guild) Upsert() {
query := "INSERT INTO guild" +
" (discord_id, guild_id, guild_name, bridge)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT(discord_id, guild_id)" +
" DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge"
_, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge)
func (g *Guild) mxidPtr() *id.RoomID {
if g.MXID != "" {
return &g.MXID
}
return nil
}
func (g *Guild) Insert() {
query := `
INSERT INTO guild (dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`
_, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels)
if err != nil {
g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err)
g.log.Warnfln("Failed to insert %s: %v", g.ID, err)
panic(err)
}
}
func (g *Guild) Update() {
query := `
UPDATE guild SET mxid=$1, name=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, auto_bridge_channels=$7
WHERE dcid=$8
`
_, err := g.db.Exec(query, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID)
if err != nil {
g.log.Warnfln("Failed to update %s: %v", g.ID, err)
panic(err)
}
}
func (g *Guild) Delete() {
query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2"
_, err := g.db.Exec(query, g.DiscordID, g.GuildID)
_, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID)
if err != nil {
g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err)
g.log.Warnfln("Failed to delete %s: %v", g.ID, err)
panic(err)
}
}

View File

@@ -17,7 +17,7 @@ type MessageQuery struct {
}
const (
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message"
messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message"
)
func (mq *MessageQuery) New() *Message {
@@ -45,14 +45,12 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message {
func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3"
row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)
if row == nil {
mq.log.Debugfln("failed to find existing message for discord_id %s", discordID)
return nil
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID))
}
return mq.New().Scan(row)
func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message {
query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1"
return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID))
}
func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message {
@@ -74,36 +72,54 @@ type Message struct {
Channel PortalKey
SenderID string
Timestamp time.Time
ThreadID string
MXID id.EventID
}
func (m *Message) DiscordProtoChannelID() string {
if m.ThreadID != "" {
return m.ThreadID
} else {
return m.Channel.ChannelID
}
}
func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts int64
var threadID sql.NullString
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID)
err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
if ts != 0 {
m.Timestamp = time.Unix(ts, 0)
m.Timestamp = time.UnixMilli(ts)
}
m.ThreadID = threadID.String
return m
}
func (m *Message) Insert() {
query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)"
query := `
INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
_, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID)
_, err := m.db.Exec(query,
m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID)
if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)
panic(err)
}
}
@@ -114,5 +130,6 @@ func (m *Message) Delete() {
if err != nil {
m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err)
panic(err)
}
}

View File

@@ -12,11 +12,30 @@ import (
)
const (
portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," +
" avatar_url, type, other_user_id, first_event_id, encrypted" +
portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " +
" mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" +
" FROM portal"
)
type PortalKey struct {
ChannelID string
Receiver string
}
func NewPortalKey(channelID, receiver string) PortalKey {
return PortalKey{
ChannelID: channelID,
Receiver: receiver,
}
}
func (key PortalKey) String() string {
if key.Receiver == "" {
return key.ChannelID
}
return key.ChannelID + "-" + key.Receiver
}
type PortalQuery struct {
db *Database
log log.Logger
@@ -34,7 +53,7 @@ func (pq *PortalQuery) GetAll() []*Portal {
}
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver)
return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
@@ -67,12 +86,7 @@ func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {
}
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
return pq.New().Scan(pq.db.QueryRow(query, args...))
}
type Portal struct {
@@ -82,78 +96,87 @@ type Portal struct {
Key PortalKey
Type discordgo.ChannelType
OtherUserID string
ParentID string
GuildID string
MXID id.RoomID
Name string
NameSet bool
Topic string
TopicSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
Encrypted bool
InSpace id.RoomID
FirstEventID id.EventID
}
func (p *Portal) Scan(row dbutil.Scannable) *Portal {
var mxid, avatarURL, firstEventID sql.NullString
var typ sql.NullInt32
var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString
var chanType int32
var avatarURL string
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID,
&p.Encrypted)
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID,
&mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet,
&p.Encrypted, &p.InSpace, &firstEventID)
if err != nil {
if err != sql.ErrNoRows {
p.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
p.MXID = id.RoomID(mxid.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.Type = discordgo.ChannelType(typ.Int32)
p.OtherUserID = otherUserID.String
p.GuildID = guildID.String
p.ParentID = parentID.String
p.Type = discordgo.ChannelType(chanType)
p.FirstEventID = id.EventID(firstEventID.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
return p
}
func (p *Portal) mxidPtr() *id.RoomID {
if p.MXID != "" {
return &p.MXID
}
return nil
}
func (p *Portal) Insert() {
query := "INSERT INTO portal" +
" (dcid, receiver, mxid, name, topic, avatar, avatar_url," +
" type, other_user_id, first_event_id, encrypted)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID,
p.FirstEventID.String(), p.Encrypted)
query := `
INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid,
name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, in_space, first_event_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
`
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type,
strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
p.Encrypted, p.InSpace, p.FirstEventID.String())
if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
panic(err)
}
}
func (p *Portal) Update() {
query := "UPDATE portal SET" +
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
" other_user_id=$7, first_event_id=$8, encrypted=$9" +
" WHERE dcid=$10 AND receiver=$11"
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(),
p.Encrypted,
query := `
UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5,
name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12,
encrypted=$13, in_space=$14, first_event_id=$15
WHERE dcid=$16 AND receiver=$17
`
_, err := p.db.Exec(query,
p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)),
p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
p.Encrypted, p.InSpace, p.FirstEventID.String(),
p.Key.ChannelID, p.Key.Receiver)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.Key, err)
panic(err)
}
}
@@ -162,5 +185,6 @@ func (p *Portal) Delete() {
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver)
if err != nil {
p.log.Warnfln("Failed to delete %s: %v", p.Key, err)
panic(err)
}
}

View File

@@ -1,20 +0,0 @@
package database
type PortalKey struct {
ChannelID string
Receiver string
}
func NewPortalKey(channelID, receiver string) PortalKey {
return PortalKey{
ChannelID: channelID,
Receiver: receiver,
}
}
func (key PortalKey) String() string {
if key.Receiver == "" {
return key.ChannelID
}
return key.ChannelID + "-" + key.Receiver
}

View File

@@ -10,7 +10,7 @@ import (
)
const (
puppetSelect = "SELECT id, display_name, avatar, avatar_url," +
puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," +
" custom_mxid, access_token, next_batch" +
" FROM puppet "
)
@@ -36,12 +36,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
}
func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet {
row := pq.db.QueryRow(query, args...)
if row == nil {
return nil
}
return pq.New().Scan(row)
return pq.New().Scan(pq.db.QueryRow(query, args...))
}
func (pq *PuppetQuery) GetAll() []*Puppet {
@@ -59,7 +54,7 @@ func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet {
}
defer rows.Close()
puppets := []*Puppet{}
var puppets []*Puppet
for rows.Next() {
puppets = append(puppets, pq.New().Scan(rows))
}
@@ -72,10 +67,11 @@ type Puppet struct {
log log.Logger
ID string
DisplayName string
Name string
NameSet bool
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
CustomMXID id.UserID
AccessToken string
@@ -83,24 +79,22 @@ type Puppet struct {
}
func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
var did, displayName, avatar, avatarURL sql.NullString
var avatarURL string
var customMXID, accessToken, nextBatch sql.NullString
err := row.Scan(&did, &displayName, &avatar, &avatarURL,
err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet,
&customMXID, &accessToken, &nextBatch)
if err != nil {
if err != sql.ErrNoRows {
p.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
p.ID = did.String
p.DisplayName = displayName.String
p.Avatar = avatar.String
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL)
p.CustomMXID = id.UserID(customMXID.String)
p.AccessToken = accessToken.String
p.NextBatch = nextBatch.String
@@ -109,31 +103,31 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet {
}
func (p *Puppet) Insert() {
query := "INSERT INTO puppet" +
" (id, display_name, avatar, avatar_url," +
" custom_mxid, access_token, next_batch)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)"
_, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar,
p.AvatarURL.String(), p.CustomMXID, p.AccessToken,
p.NextBatch)
query := `
INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`
_, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch))
if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
panic(err)
}
}
func (p *Puppet) Update() {
query := "UPDATE puppet" +
" SET display_name=$1, avatar=$2, avatar_url=$3, " +
" custom_mxid=$4, access_token=$5, next_batch=$6" +
" WHERE id=$7"
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
p.CustomMXID, p.AccessToken, p.NextBatch,
query := `
UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5,
custom_mxid=$6, access_token=$7, next_batch=$8
WHERE id=$9
`
_, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet,
strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch),
p.ID)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.ID, err)
panic(err)
}
}

View File

@@ -16,7 +16,7 @@ type ReactionQuery struct {
}
const (
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction"
reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction"
)
func (rq *ReactionQuery) New() *Reaction {
@@ -75,15 +75,17 @@ type Reaction struct {
MessageID string
Sender string
EmojiName string
ThreadID string
MXID id.EventID
}
func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID)
err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
r.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
@@ -91,14 +93,23 @@ func (r *Reaction) Scan(row dbutil.Scannable) *Reaction {
return r
}
func (r *Reaction) DiscordProtoChannelID() string {
if r.ThreadID != "" {
return r.ThreadID
} else {
return r.Channel.ChannelID
}
}
func (r *Reaction) Insert() {
query := `
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid)
VALUES($1, $2, $3, $4, $5, $6)
INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid)
VALUES($1, $2, $3, $4, $5, $6, $7)
`
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID)
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID)
if err != nil {
r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err)
panic(err)
}
}
@@ -107,5 +118,6 @@ func (r *Reaction) Delete() {
_, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName)
if err != nil {
r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err)
panic(err)
}
}

105
database/thread.go Normal file
View File

@@ -0,0 +1,105 @@
package database
import (
"database/sql"
"errors"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
type ThreadQuery struct {
db *Database
log log.Logger
}
const (
threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread"
)
func (tq *ThreadQuery) New() *Thread {
return &Thread{
db: tq.db,
log: tq.log,
}
}
func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread {
query := threadSelect + " WHERE dcid=$1"
row := tq.db.QueryRow(query, discordID)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread {
// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2"
//
// row := tq.db.QueryRow(query, channelID, messageID)
// if row == nil {
// return nil
// }
//
// return tq.New().Scan(row)
//}
func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread {
query := threadSelect + " WHERE root_msg_mxid=$1"
row := tq.db.QueryRow(query, mxid)
if row == nil {
return nil
}
return tq.New().Scan(row)
}
type Thread struct {
db *Database
log log.Logger
ID string
ParentID string
RootDiscordID string
RootMXID id.EventID
}
func (t *Thread) Scan(row dbutil.Scannable) *Thread {
err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
t.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
return t
}
func (t *Thread) Insert() {
query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)"
_, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID)
if err != nil {
t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}
func (t *Thread) Delete() {
query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2"
_, err := t.db.Exec(query, t.ID, t.ParentID)
if err != nil {
t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err)
panic(err)
}
}

View File

@@ -1,29 +1,65 @@
-- v0 -> v2: Latest revision
-- v0 -> v3: Latest revision
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
auto_bridge_channels BOOLEAN NOT NULL
);
CREATE TABLE portal (
dcid TEXT,
receiver TEXT,
other_user_id TEXT,
type INTEGER,
type INTEGER NOT NULL,
dc_guild_id TEXT,
dc_parent_id TEXT,
-- This is not accessed by the bridge, it's only used for the portal parent foreign key.
-- Only guild channels have parents, but only DMs have a receiver field.
dc_parent_receiver TEXT NOT NULL DEFAULT '',
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
topic TEXT NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
encrypted BOOLEAN NOT NULL DEFAULT false,
avatar_set BOOLEAN NOT NULL,
encrypted BOOLEAN NOT NULL,
in_space TEXT NOT NULL,
first_event_id TEXT NOT NULL,
PRIMARY KEY (dcid, receiver)
PRIMARY KEY (dcid, receiver),
CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE,
CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE
);
CREATE TABLE thread (
dcid TEXT PRIMARY KEY,
parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '',
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
CREATE TABLE puppet (
id TEXT PRIMARY KEY,
name TEXT,
avatar TEXT,
avatar_url TEXT,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
custom_mxid TEXT,
access_token TEXT,
@@ -34,9 +70,20 @@ CREATE TABLE "user" (
mxid TEXT PRIMARY KEY,
dcid TEXT UNIQUE,
discord_token TEXT,
management_room TEXT,
space_room TEXT
);
token TEXT
CREATE TABLE user_guild (
guild_id TEXT,
user_mxid TEXT,
in_space BOOLEAN NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (guild_id, user_mxid),
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
);
CREATE TABLE message (
@@ -45,6 +92,7 @@ CREATE TABLE message (
dc_chan_receiver TEXT,
dc_sender TEXT NOT NULL,
timestamp BIGINT NOT NULL,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -58,6 +106,7 @@ CREATE TABLE reaction (
dc_msg_id TEXT,
dc_sender TEXT,
dc_emoji_name TEXT,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -70,6 +119,7 @@ CREATE TABLE attachment (
dc_msg_id TEXT,
dc_chan_id TEXT,
dc_chan_receiver TEXT,
dc_thread_id TEXT,
mxid TEXT NOT NULL UNIQUE,
@@ -82,11 +132,3 @@ CREATE TABLE emoji (
discord_name TEXT,
matrix_url TEXT
);
CREATE TABLE guild (
discord_id TEXT NOT NULL,
guild_id TEXT NOT NULL,
guild_name TEXT NOT NULL,
bridge BOOLEAN DEFAULT FALSE,
PRIMARY KEY(discord_id, guild_id)
);

View File

@@ -0,0 +1,72 @@
-- v3: Store portal parent metadata for spaces
DROP TABLE guild;
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
mxid TEXT UNIQUE,
name TEXT NOT NULL,
name_set BOOLEAN NOT NULL,
avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL,
avatar_set BOOLEAN NOT NULL,
auto_bridge_channels BOOLEAN NOT NULL
);
CREATE TABLE user_guild (
guild_id TEXT,
user_mxid TEXT,
in_space BOOLEAN NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (guild_id, user_mxid),
CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE,
CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE
);
ALTER TABLE portal ADD COLUMN dc_guild_id TEXT;
ALTER TABLE portal ADD COLUMN dc_parent_id TEXT;
ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT '';
ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE;
ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE;
DELETE FROM portal WHERE type IS NULL;
-- only: postgres
ALTER TABLE portal ALTER COLUMN type SET NOT NULL;
ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT '';
ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
-- only: postgres for next 5 lines
ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT;
ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT;
ALTER TABLE puppet RENAME COLUMN display_name TO name;
ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false;
ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false;
-- only: postgres for next 2 lines
ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT;
ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT;
ALTER TABLE "user" ADD COLUMN space_room TEXT;
ALTER TABLE "user" RENAME COLUMN token TO discord_token;
UPDATE message SET timestamp=timestamp*1000;
CREATE TABLE thread (
dcid TEXT PRIMARY KEY,
parent_chan_id TEXT NOT NULL,
root_msg_dcid TEXT NOT NULL,
root_msg_mxid TEXT NOT NULL,
-- This is also not accessed by the bridge.
receiver TEXT NOT NULL DEFAULT '',
CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE
);
ALTER TABLE message ADD COLUMN dc_thread_id TEXT;
ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT;
ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT;

View File

@@ -2,6 +2,8 @@ package database
import (
"database/sql"
"errors"
"time"
log "maunium.net/go/maulogger/v2"
@@ -22,38 +24,32 @@ func (uq *UserQuery) New() *User {
}
func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1`
row := uq.db.QueryRow(query, userID)
if row == nil {
return nil
}
return uq.New().Scan(row)
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1`
return uq.New().Scan(uq.db.QueryRow(query, userID))
}
func (uq *UserQuery) GetByID(id string) *User {
query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1`
row := uq.db.QueryRow(query, id)
if row == nil {
return nil
query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1`
return uq.New().Scan(uq.db.QueryRow(query, id))
}
return uq.New().Scan(row)
}
func (uq *UserQuery) GetAll() []*User {
rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`)
func (uq *UserQuery) GetAllWithToken() []*User {
query := `
SELECT mxid, dcid, discord_token, management_room, space_room
FROM "user" WHERE discord_token IS NOT NULL
`
rows, err := uq.db.Query(query)
if err != nil || rows == nil {
return nil
}
defer rows.Close()
users := []*User{}
var users []*User
for rows.Next() {
users = append(users, uq.New().Scan(rows))
user := uq.New().Scan(rows)
if user != nil {
users = append(users, user)
}
}
return users
}
@@ -62,79 +58,113 @@ type User struct {
log log.Logger
MXID id.UserID
ID string
DiscordID string
DiscordToken string
ManagementRoom id.RoomID
SpaceRoom id.RoomID
}
Token string
type UserGuild struct {
GuildID string
Timestamp time.Time
InSpace bool
}
func (u *User) GetGuilds() []UserGuild {
res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID)
if err != nil {
u.log.Errorln("Failed to get guilds:", err)
panic(err)
return nil
}
var guilds []UserGuild
for res.Next() {
var guild UserGuild
var ts int64
err = res.Scan(&guild.GuildID, &ts, &guild.InSpace)
if err != nil {
u.log.Errorln("Error scanning user guild:", err)
panic(err)
} else {
guild.Timestamp = time.UnixMilli(ts)
guilds = append(guilds, guild)
}
}
return guilds
}
func (u *User) IsInSpace(guildID string) (isIn bool) {
query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err)
panic(err)
}
return
}
func (u *User) MarkInGuild(guild UserGuild) {
query := `
INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space)
VALUES ($1, $2, $3, $4)
ON CONFLICT (guild_id, user_mxid) DO UPDATE
SET timestamp=excluded.timestamp, in_space=excluded.in_space
`
_, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace)
if err != nil {
u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err)
panic(err)
}
}
func (u *User) MarkNotInGuild(guildID string) {
query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2`
_, err := u.db.Exec(query, u.MXID, guildID)
if err != nil {
u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err)
panic(err)
}
}
func (u *User) PruneGuildList(beforeTS time.Time) {
_, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli())
if err != nil {
u.log.Errorln("Failed to prune user guild list:", err)
panic(err)
}
}
func (u *User) Scan(row dbutil.Scannable) *User {
var token sql.NullString
var discordID sql.NullString
err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token)
var discordID, managementRoom, spaceRoom, discordToken sql.NullString
err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom)
if err != nil {
if err != sql.ErrNoRows {
u.log.Errorln("Database scan failed:", err)
panic(err)
}
return nil
}
if token.Valid {
u.Token = token.String
}
if discordID.Valid {
u.ID = discordID.String
}
u.DiscordID = discordID.String
u.DiscordToken = discordToken.String
u.ManagementRoom = id.RoomID(managementRoom.String)
u.SpaceRoom = id.RoomID(spaceRoom.String)
return u
}
func (u *User) Insert() {
query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)"
var token sql.NullString
var discordID sql.NullString
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
if u.ID != "" {
discordID.String = u.ID
discordID.Valid = true
}
_, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token)
query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)`
_, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)))
if err != nil {
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
panic(err)
}
}
func (u *User) Update() {
query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4"
var token sql.NullString
var discordID sql.NullString
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
if u.ID != "" {
discordID.String = u.ID
discordID.Valid = true
}
_, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID)
query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5`
_, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID)
if err != nil {
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
panic(err)
}
}

6
go.mod
View File

@@ -3,14 +3,14 @@ module go.mau.fi/mautrix-discord
go 1.17
require (
github.com/bwmarrin/discordgo v0.23.2
github.com/bwmarrin/discordgo v0.25.0
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/lib/pq v1.10.6
github.com/mattn/go-sqlite3 v1.14.13
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
maunium.net/go/maulogger/v2 v2.3.2
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf
)
require (
@@ -27,4 +27,4 @@ require (
maunium.net/go/mauflag v1.0.0 // indirect
)
replace github.com/bwmarrin/discordgo v0.23.2 => gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685
replace github.com/bwmarrin/discordgo => gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7

14
go.sum
View File

@@ -3,7 +3,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs=
@@ -28,11 +28,12 @@ github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc=
github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM=
github.com/yuin/goldmark v1.4.12 h1:6hffw6vALvEDqJ19dOJvJKOoAOKe4NDaTqvd2sktGN0=
github.com/yuin/goldmark v1.4.12/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685 h1:WLswEvEb4wi02p+VjMl8slhG/4+3pdfx04xQfqQjpoc=
gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685/go.mod h1:Hwfv4M8yP/MDh47BN+4Z1WItJ1umLKUyplCH5KcQPgE=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7 h1:S8hbrkgKGU4aU5kXW4d8CA/9ayi8ymI3QU6yg/aWfUw=
gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c=
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0=
golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
@@ -44,6 +45,7 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16C
golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -56,5 +58,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0=
maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365 h1:ghL/KMHiMdcwKmuxvVYZvwe6Yk7BHeWlRfoHZovDA2Q=
maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I=
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf h1:hNCuEyp9oZPM08n7UsI/gq95ptbhOV1A8heFoH/eiM0=
maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I=

287
guildportal.go Normal file
View File

@@ -0,0 +1,287 @@
// mautrix-discord - A Matrix-Discord puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package main
import (
"fmt"
"sync"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"github.com/bwmarrin/discordgo"
"go.mau.fi/mautrix-discord/database"
)
type Guild struct {
*database.Guild
bridge *DiscordBridge
log log.Logger
roomCreateLock sync.Mutex
}
func (br *DiscordBridge) loadGuild(dbGuild *database.Guild, id string, createIfNotExist bool) *Guild {
if dbGuild == nil {
if id == "" || !createIfNotExist {
return nil
}
dbGuild = br.DB.Guild.New()
dbGuild.ID = id
dbGuild.Insert()
}
guild := br.NewGuild(dbGuild)
br.guildsByID[guild.ID] = guild
if guild.MXID != "" {
br.guildsByMXID[guild.MXID] = guild
}
return guild
}
func (br *DiscordBridge) GetGuildByMXID(mxid id.RoomID) *Guild {
br.guildsLock.Lock()
defer br.guildsLock.Unlock()
portal, ok := br.guildsByMXID[mxid]
if !ok {
return br.loadGuild(br.DB.Guild.GetByMXID(mxid), "", false)
}
return portal
}
func (br *DiscordBridge) GetGuildByID(id string, createIfNotExist bool) *Guild {
br.guildsLock.Lock()
defer br.guildsLock.Unlock()
guild, ok := br.guildsByID[id]
if !ok {
return br.loadGuild(br.DB.Guild.GetByID(id), id, createIfNotExist)
}
return guild
}
func (br *DiscordBridge) GetAllGuilds() []*Guild {
return br.dbGuildsToGuilds(br.DB.Guild.GetAll())
}
func (br *DiscordBridge) dbGuildsToGuilds(dbGuilds []*database.Guild) []*Guild {
br.guildsLock.Lock()
defer br.guildsLock.Unlock()
output := make([]*Guild, len(dbGuilds))
for index, dbGuild := range dbGuilds {
if dbGuild == nil {
continue
}
guild, ok := br.guildsByID[dbGuild.ID]
if !ok {
guild = br.loadGuild(dbGuild, "", false)
}
output[index] = guild
}
return output
}
func (br *DiscordBridge) NewGuild(dbGuild *database.Guild) *Guild {
guild := &Guild{
Guild: dbGuild,
bridge: br,
log: br.Log.Sub(fmt.Sprintf("Guild/%s", dbGuild.ID)),
}
return guild
}
func (guild *Guild) getBridgeInfo() (string, event.BridgeEventContent) {
bridgeInfo := event.BridgeEventContent{
BridgeBot: guild.bridge.Bot.UserID,
Creator: guild.bridge.Bot.UserID,
Protocol: event.BridgeInfoSection{
ID: "discord",
DisplayName: "Discord",
AvatarURL: guild.bridge.Config.AppService.Bot.ParsedAvatar.CUString(),
ExternalURL: "https://discord.com/",
},
Channel: event.BridgeInfoSection{
ID: guild.ID,
DisplayName: guild.Name,
AvatarURL: guild.AvatarURL.CUString(),
},
}
bridgeInfoStateKey := fmt.Sprintf("fi.mau.discord://discord/%s", guild.ID)
return bridgeInfoStateKey, bridgeInfo
}
func (guild *Guild) UpdateBridgeInfo() {
if len(guild.MXID) == 0 {
guild.log.Debugln("Not updating bridge info: no Matrix room created")
return
}
guild.log.Debugln("Updating bridge info...")
stateKey, content := guild.getBridgeInfo()
_, err := guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateBridge, stateKey, content)
if err != nil {
guild.log.Warnln("Failed to update m.bridge:", err)
}
// TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
_, err = guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateHalfShotBridge, stateKey, content)
if err != nil {
guild.log.Warnln("Failed to update uk.half-shot.bridge:", err)
}
}
func (guild *Guild) CreateMatrixRoom(user *User, meta *discordgo.Guild) error {
guild.roomCreateLock.Lock()
defer guild.roomCreateLock.Unlock()
if guild.MXID != "" {
return nil
}
guild.log.Infoln("Creating Matrix room for guild")
guild.UpdateInfo(user, meta)
bridgeInfoStateKey, bridgeInfo := guild.getBridgeInfo()
initialState := []*event.Event{{
Type: event.StateBridge,
Content: event.Content{Parsed: bridgeInfo},
StateKey: &bridgeInfoStateKey,
}, {
// TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec
Type: event.StateHalfShotBridge,
Content: event.Content{Parsed: bridgeInfo},
StateKey: &bridgeInfoStateKey,
}}
if !guild.AvatarURL.IsEmpty() {
initialState = append(initialState, &event.Event{
Type: event.StateRoomAvatar,
Content: event.Content{Parsed: &event.RoomAvatarEventContent{
URL: guild.AvatarURL,
}},
})
}
creationContent := map[string]interface{}{
"type": event.RoomTypeSpace,
}
if !guild.bridge.Config.Bridge.FederateRooms {
creationContent["m.federate"] = false
}
resp, err := guild.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{
Visibility: "private",
Name: guild.Name,
Preset: "private_chat",
InitialState: initialState,
CreationContent: creationContent,
})
if err != nil {
guild.log.Warnln("Failed to create room:", err)
return err
}
guild.MXID = resp.RoomID
guild.NameSet = true
guild.AvatarSet = !guild.AvatarURL.IsEmpty()
guild.Update()
guild.bridge.guildsLock.Lock()
guild.bridge.guildsByMXID[guild.MXID] = guild
guild.bridge.guildsLock.Unlock()
guild.log.Infoln("Matrix room created:", guild.MXID)
user.ensureInvited(nil, guild.MXID, false)
return nil
}
func (guild *Guild) UpdateInfo(source *User, meta *discordgo.Guild) *discordgo.Guild {
if meta.Unavailable {
return meta
}
changed := false
// FIXME
//name, err := guild.bridge.Config.Bridge.FormatChannelname(meta, user.Session)
//if err != nil {
// guild.log.Warnfln("failed to format name, proceeding with generic name: %v", err)
// guild.Name = meta.Name
//} else {
//}
changed = guild.UpdateName(meta.Name) || changed
changed = guild.UpdateAvatar(meta.Icon) || changed
if changed {
guild.UpdateBridgeInfo()
guild.Update()
}
return meta
}
func (guild *Guild) UpdateName(name string) bool {
if guild.Name == name && guild.NameSet {
return false
}
guild.Name = name
guild.NameSet = false
if guild.MXID != "" {
_, err := guild.bridge.Bot.SetRoomName(guild.MXID, guild.Name)
if err != nil {
guild.log.Warnln("Failed to update room name: %s", err)
} else {
guild.NameSet = true
}
}
return true
}
func (guild *Guild) UpdateAvatar(iconID string) bool {
if guild.Avatar == iconID && guild.AvatarSet {
return false
}
guild.AvatarSet = false
guild.Avatar = iconID
if guild.Avatar != "" {
var err error
guild.AvatarURL, err = uploadAvatar(guild.bridge.Bot, discordgo.EndpointGuildIcon(guild.ID, iconID))
if err != nil {
guild.log.Warnln("Failed to reupload avatar %s: %v", iconID, err)
return true
}
} else {
guild.AvatarURL = id.ContentURI{}
}
if guild.MXID != "" {
_, err := guild.bridge.Bot.SetRoomAvatar(guild.MXID, guild.AvatarURL)
if err != nil {
guild.log.Warnln("Failed to update room avatar:", err)
} else {
guild.AvatarSet = true
}
}
return true
}

17
main.go
View File

@@ -20,13 +20,13 @@ import (
_ "embed"
"sync"
"go.mau.fi/mautrix-discord/database"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/bridge/commands"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/configupgrade"
"go.mau.fi/mautrix-discord/config"
"go.mau.fi/mautrix-discord/database"
)
// Information to find out exactly which commit the bridge was built from.
@@ -59,6 +59,14 @@ type DiscordBridge struct {
portalsByID map[database.PortalKey]*Portal
portalsLock sync.Mutex
threadsByID map[string]*Thread
threadsByRootMXID map[id.EventID]*Thread
threadsLock sync.Mutex
guildsByMXID map[id.RoomID]*Guild
guildsByID map[string]*Guild
guildsLock sync.Mutex
puppets map[string]*Puppet
puppetsByCustomMXID map[id.UserID]*Puppet
puppetsLock sync.Mutex
@@ -81,6 +89,7 @@ func (br *DiscordBridge) Init() {
br.RegisterCommands()
br.DB = database.New(br.Bridge.DB)
discordLog = br.Log.Sub("Discord")
}
func (br *DiscordBridge) Start() {
@@ -144,6 +153,12 @@ func main() {
portalsByMXID: make(map[id.RoomID]*Portal),
portalsByID: make(map[database.PortalKey]*Portal),
threadsByID: make(map[string]*Thread),
threadsByRootMXID: make(map[id.EventID]*Thread),
guildsByID: make(map[string]*Guild),
guildsByMXID: make(map[id.RoomID]*Guild),
puppets: make(map[string]*Puppet),
puppetsByCustomMXID: make(map[id.UserID]*Puppet),
}

684
portal.go

File diff suppressed because it is too large Load Diff

View File

@@ -49,7 +49,6 @@ func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI {
r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
// Setup the guild endpoints
r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet)
r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost)
r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost)
@@ -182,8 +181,8 @@ func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) {
}
user.Lock()
if user.ID != "" {
discord["id"] = user.ID
if user.DiscordID != "" {
discord["id"] = user.DiscordID
}
if user.Session != nil {
@@ -338,7 +337,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
return
}
user.ID = discordUser.UserID
user.DiscordID = discordUser.UserID
user.Update()
if err := user.Login(discordUser.Token); err != nil {
@@ -354,7 +353,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
c.WriteJSON(map[string]interface{}{
"success": true,
"id": user.ID,
"id": user.DiscordID,
})
return
@@ -392,19 +391,17 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
data := make([]map[string]interface{}, len(user.guilds))
idx := 0
for _, guild := range user.guilds {
data[idx] = map[string]interface{}{
"name": guild.GuildName,
"id": guild.GuildID,
"bridged": guild.Bridge,
var data []map[string]interface{}
for _, userGuild := range user.GetGuilds() {
guild := p.bridge.GetGuildByID(userGuild.GuildID, false)
if guild == nil {
continue
}
idx++
data = append(data, map[string]interface{}{
"name": guild.Name,
"id": guild.ID,
"mxid": guild.MXID,
})
}
jsonResponse(w, http.StatusOK, data)

147
puppet.go
View File

@@ -7,6 +7,8 @@ import (
log "maunium.net/go/maulogger/v2"
"github.com/bwmarrin/discordgo"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/bridge"
"maunium.net/go/mautrix/id"
@@ -66,12 +68,12 @@ func (br *DiscordBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) {
}
func (br *DiscordBridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
id, ok := br.ParsePuppetMXID(mxid)
discordID, ok := br.ParsePuppetMXID(mxid)
if !ok {
return nil
}
return br.GetPuppetByID(id)
return br.GetPuppetByID(discordID)
}
func (br *DiscordBridge) GetPuppetByID(id string) *Puppet {
@@ -159,7 +161,7 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI {
}
func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI {
if puppet.customIntent == nil {
if puppet.customIntent == nil || (portal.Key.Receiver != "" && portal.Key.Receiver != puppet.ID) {
return puppet.DefaultIntent()
}
@@ -179,121 +181,86 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) {
}
}
func (puppet *Puppet) updateName(source *User) bool {
user, err := source.Session.User(puppet.ID)
if err != nil {
puppet.log.Warnln("failed to get user from id:", err)
func (puppet *Puppet) UpdateName(info *discordgo.User) bool {
newName := puppet.bridge.Config.Bridge.FormatDisplayname(info)
if puppet.Name == newName && puppet.NameSet {
return false
}
newName := puppet.bridge.Config.Bridge.FormatDisplayname(user)
if puppet.DisplayName != newName {
puppet.Name = newName
puppet.NameSet = false
err := puppet.DefaultIntent().SetDisplayName(newName)
if err == nil {
puppet.DisplayName = newName
go puppet.updatePortalName()
puppet.Update()
if err != nil {
puppet.log.Warnln("Failed to update displayname:", err)
} else {
puppet.log.Warnln("failed to set display name:", err)
go puppet.updatePortalMeta(func(portal *Portal) {
if portal.UpdateName(puppet.Name) {
portal.Update()
}
})
puppet.NameSet = true
}
return true
}
func (puppet *Puppet) UpdateAvatar(info *discordgo.User) bool {
if puppet.Avatar == info.Avatar && puppet.AvatarSet {
return false
}
puppet.Avatar = info.Avatar
puppet.AvatarSet = false
func (puppet *Puppet) updatePortalName() {
puppet.updatePortalMeta(func(portal *Portal) {
if portal.MXID != "" {
_, err := portal.MainIntent().SetRoomName(portal.MXID, puppet.DisplayName)
if puppet.Avatar == "" {
// TODO should we just use discord's default avatars?
puppet.AvatarURL = id.ContentURI{}
} else {
url, err := uploadAvatar(puppet.DefaultIntent(), info.AvatarURL(""))
if err != nil {
portal.log.Warnln("Failed to set name:", err)
puppet.log.Warnln("Failed to reupload user avatar:", err)
return true
}
}
portal.Name = puppet.DisplayName
portal.Update()
})
}
func (puppet *Puppet) updateAvatar(source *User) bool {
user, err := source.Session.User(puppet.ID)
if err != nil {
puppet.log.Warnln("Failed to get user:", err)
return false
}
if puppet.Avatar == user.Avatar {
return false
}
if user.Avatar == "" {
puppet.log.Warnln("User does not have an avatar")
return false
}
url, err := uploadAvatar(puppet.DefaultIntent(), user.AvatarURL(""))
if err != nil {
puppet.log.Warnln("Failed to upload user avatar:", err)
return false
}
puppet.AvatarURL = url
err = puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
if err != nil {
puppet.log.Warnln("Failed to set avatar:", err)
}
puppet.log.Debugln("Updated avatar", puppet.Avatar, "->", user.Avatar)
puppet.Avatar = user.Avatar
go puppet.updatePortalAvatar()
err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL)
if err != nil {
puppet.log.Warnln("Failed to update avatar:", err)
} else {
go puppet.updatePortalMeta(func(portal *Portal) {
if portal.UpdateAvatarFromPuppet(puppet) {
portal.Update()
}
})
puppet.AvatarSet = true
}
return true
}
func (puppet *Puppet) updatePortalAvatar() {
puppet.updatePortalMeta(func(portal *Portal) {
if portal.MXID != "" {
_, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL)
if err != nil {
portal.log.Warnln("Failed to set avatar:", err)
}
}
portal.AvatarURL = puppet.AvatarURL
portal.Avatar = puppet.Avatar
portal.Update()
})
}
func (puppet *Puppet) SyncContact(source *User) {
func (puppet *Puppet) UpdateInfo(source *User, info *discordgo.User) {
puppet.syncLock.Lock()
defer puppet.syncLock.Unlock()
puppet.log.Debugln("syncing contact", puppet.DisplayName)
if info == nil || len(info.Username) == 0 || len(info.Discriminator) == 0 {
if puppet.Name != "" {
return
}
var err error
puppet.log.Debugfln("Fetching info through %s to update", source.DiscordID)
info, err = source.Session.User(puppet.ID)
if err != nil {
puppet.log.Errorfln("Failed to fetch info through %s: %v", source.DiscordID, err)
return
}
}
err := puppet.DefaultIntent().EnsureRegistered()
if err != nil {
puppet.log.Errorln("Failed to ensure registered:", err)
}
update := false
update = puppet.updateName(source) || update
if puppet.Avatar == "" {
update = puppet.updateAvatar(source) || update
puppet.log.Debugln("update avatar returned", update)
}
if update {
changed := false
changed = puppet.UpdateName(info) || changed
changed = puppet.UpdateAvatar(info) || changed
if changed {
puppet.Update()
}
}

View File

@@ -11,13 +11,14 @@ import (
"sync"
"github.com/gorilla/websocket"
"github.com/bwmarrin/discordgo"
)
type Client struct {
sync.Mutex
URL string
Origin string
conn *websocket.Conn
@@ -43,7 +44,6 @@ func New() (*Client, error) {
return &Client{
URL: "wss://remote-auth-gateway.discord.gg/?v=1",
Origin: "https://discord.com",
privateKey: privateKey,
}, nil
}
@@ -54,8 +54,9 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str
c.Lock()
defer c.Unlock()
header := http.Header{
"Origin": []string{c.Origin},
header := http.Header{}
for key, value := range discordgo.DroidWSHeaders {
header.Set(key, value)
}
c.qrChan = qrChan

53
thread.go Normal file
View File

@@ -0,0 +1,53 @@
package main
import (
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database"
)
type Thread struct {
*database.Thread
Parent *Portal
}
func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
br.threadsLock.Lock()
defer br.threadsLock.Unlock()
thread, ok := br.threadsByID[id]
if !ok {
return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root)
}
return thread
}
func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
br.threadsLock.Lock()
defer br.threadsLock.Unlock()
thread, ok := br.threadsByRootMXID[mxid]
if !ok {
return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil)
}
return thread
}
func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
if dbThread == nil {
if root == nil {
return nil
}
dbThread = br.DB.Thread.New()
dbThread.ID = id
dbThread.RootDiscordID = root.DiscordID
dbThread.RootMXID = root.MXID
dbThread.ParentID = root.Channel.ChannelID
dbThread.Insert()
}
thread := &Thread{
Thread: dbThread,
}
thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
br.threadsByID[thread.ID] = thread
br.threadsByRootMXID[thread.RootMXID] = thread
return thread
}

555
user.go
View File

@@ -4,11 +4,15 @@ import (
"errors"
"fmt"
"net/http"
"os"
"runtime"
"strings"
"sync"
"time"
log "maunium.net/go/maulogger/v2"
"github.com/bwmarrin/discordgo"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
@@ -35,12 +39,43 @@ type User struct {
PermissionLevel bridgeconfig.PermissionLevel
guilds map[string]*database.Guild
guildsLock sync.Mutex
spaceCreateLock sync.Mutex
spaceMembershipChecked bool
Session *discordgo.Session
}
var discordLog log.Logger
func init() {
discordgo.Logger = func(msgL, caller int, format string, a ...interface{}) {
pc, file, line, _ := runtime.Caller(caller + 1)
files := strings.Split(file, "/")
file = files[len(files)-1]
name := runtime.FuncForPC(pc).Name()
fns := strings.Split(name, ".")
name = fns[len(fns)-1]
msg := fmt.Sprintf(format, a...)
var level log.Level
switch msgL {
case discordgo.LogError:
level = log.LevelError
case discordgo.LogWarning:
level = log.LevelWarn
case discordgo.LogInformational:
level = log.LevelInfo
case discordgo.LogDebug:
level = log.LevelDebug
}
discordLog.Logfln(level, "%s:%d:%s() %s", file, line, name, msg)
}
}
func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel {
return user.PermissionLevel
}
@@ -66,10 +101,10 @@ func (user *User) GetIDoublePuppet() bridge.DoublePuppet {
}
func (user *User) GetIGhost() bridge.Ghost {
if user.ID == "" {
if user.DiscordID == "" {
return nil
}
p := user.bridge.GetPuppetByID(user.ID)
p := user.bridge.GetPuppetByID(user.DiscordID)
if p == nil {
return nil
}
@@ -78,14 +113,6 @@ func (user *User) GetIGhost() bridge.Ghost {
var _ bridge.User = (*User)(nil)
// this assume you are holding the guilds lock!!!
func (user *User) loadGuilds() {
user.guilds = map[string]*database.Guild{}
for _, guild := range user.bridge.DB.Guild.GetAll(user.ID) {
user.guilds[guild.GuildID] = guild
}
}
func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
// If we weren't passed in a user we attempt to create one if we were given
// a matrix id.
@@ -103,8 +130,8 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User
// We assume the usersLock was acquired by our caller.
br.usersByMXID[user.MXID] = user
if user.ID != "" {
br.usersByID[user.ID] = user
if user.DiscordID != "" {
br.usersByID[user.DiscordID] = user
}
if user.ManagementRoom != "" {
@@ -114,17 +141,10 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User
br.managementRoomsLock.Unlock()
}
// Load our guilds state from the database and turn it into a map
user.guildsLock.Lock()
user.loadGuilds()
user.guildsLock.Unlock()
return user
}
func (br *DiscordBridge) GetUserByMXID(userID id.UserID) *User {
// TODO: check if puppet
br.usersLock.Lock()
defer br.usersLock.Unlock()
@@ -153,7 +173,6 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
User: dbUser,
bridge: br,
log: br.Log.Sub("User").Sub(string(dbUser.MXID)),
guilds: map[string]*database.Guild{},
}
user.PermissionLevel = br.Config.Bridge.Permissions.Get(user.MXID)
@@ -161,11 +180,11 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User {
return user
}
func (br *DiscordBridge) getAllUsers() []*User {
func (br *DiscordBridge) getAllUsersWithToken() []*User {
br.usersLock.Lock()
defer br.usersLock.Unlock()
dbUsers := br.DB.User.GetAll()
dbUsers := br.DB.User.GetAllWithToken()
users := make([]*User, len(dbUsers))
for idx, dbUser := range dbUsers {
@@ -182,7 +201,7 @@ func (br *DiscordBridge) getAllUsers() []*User {
func (br *DiscordBridge) startUsers() {
br.Log.Debugln("Starting users")
for _, u := range br.getAllUsers() {
for _, u := range br.getAllUsersWithToken() {
go func(user *User) {
err := user.Connect()
if err != nil {
@@ -209,10 +228,6 @@ func (user *User) SetManagementRoom(roomID id.RoomID) {
existing, ok := user.bridge.managementRooms[roomID]
if ok {
// If there's a user already assigned to this management room, clear it
// out.
// I think this is due a name change or something? I dunno, leaving it
// for now.
existing.ManagementRoom = ""
existing.Update()
}
@@ -222,6 +237,52 @@ func (user *User) SetManagementRoom(roomID id.RoomID) {
user.Update()
}
func (user *User) GetSpaceRoom() id.RoomID {
if len(user.SpaceRoom) == 0 {
user.spaceCreateLock.Lock()
defer user.spaceCreateLock.Unlock()
if len(user.SpaceRoom) > 0 {
return user.SpaceRoom
}
resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{
Visibility: "private",
Name: "Discord",
Topic: "Your Discord bridged chats",
InitialState: []*event.Event{{
Type: event.StateRoomAvatar,
Content: event.Content{
Parsed: &event.RoomAvatarEventContent{
URL: user.bridge.Config.AppService.Bot.ParsedAvatar,
},
},
}},
CreationContent: map[string]interface{}{
"type": event.RoomTypeSpace,
},
PowerLevelOverride: &event.PowerLevelsEventContent{
Users: map[id.UserID]int{
user.bridge.Bot.UserID: 9001,
user.MXID: 50,
},
},
})
if err != nil {
user.log.Errorln("Failed to auto-create space room:", err)
} else {
user.SpaceRoom = resp.RoomID
user.Update()
user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false)
}
} else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(user.SpaceRoom, user.MXID) {
user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false)
}
user.spaceMembershipChecked = true
return user.SpaceRoom
}
func (user *User) tryAutomaticDoublePuppeting() {
user.Lock()
defer user.Unlock()
@@ -232,7 +293,7 @@ func (user *User) tryAutomaticDoublePuppeting() {
user.log.Debugln("Checking if double puppeting needs to be enabled")
puppet := user.bridge.GetPuppetByID(user.ID)
puppet := user.bridge.GetPuppetByID(user.DiscordID)
if puppet.CustomMXID != "" {
user.log.Debugln("User already has double-puppeting enabled")
@@ -270,7 +331,7 @@ func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool)
}
func (user *User) Login(token string) error {
user.Token = token
user.DiscordToken = token
user.Update()
return user.Connect()
}
@@ -279,7 +340,7 @@ func (user *User) IsLoggedIn() bool {
user.Lock()
defer user.Unlock()
return user.Token != ""
return user.DiscordToken != ""
}
func (user *User) Logout() error {
@@ -290,7 +351,7 @@ func (user *User) Logout() error {
return ErrNotLoggedIn
}
puppet := user.bridge.GetPuppetByID(user.ID)
puppet := user.bridge.GetPuppetByID(user.DiscordID)
if puppet.CustomMXID != "" {
err := puppet.SwitchCustomMXID("", "")
if err != nil {
@@ -304,7 +365,7 @@ func (user *User) Logout() error {
user.Session = nil
user.Token = ""
user.DiscordToken = ""
user.Update()
return nil
@@ -321,16 +382,20 @@ func (user *User) Connect() error {
user.Lock()
defer user.Unlock()
if user.Token == "" {
if user.DiscordToken == "" {
return ErrNotLoggedIn
}
user.log.Debugln("connecting to discord")
user.log.Debugln("Connecting to discord")
session, err := discordgo.New(user.Token)
session, err := discordgo.New(user.DiscordToken)
if err != nil {
return err
}
// TODO move to config
if os.Getenv("DISCORD_DEBUG") == "1" {
session.LogLevel = discordgo.LogDebug
}
user.Session = session
@@ -382,263 +447,178 @@ func (user *User) bridgeMessage(guildID string) bool {
return true
}
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
if guild, found := user.guilds[guildID]; found {
if guild.Bridge {
guild := user.bridge.GetGuildByID(guildID, false)
if guild.MXID != "" {
return true
}
}
user.log.Debugfln("ignoring message for non-bridged guild %s-%s", user.ID, guildID)
user.log.Debugfln("Cgnoring message for non-bridged guild %s", guildID)
return false
}
func (user *User) readyHandler(s *discordgo.Session, r *discordgo.Ready) {
user.log.Debugln("discord connection ready")
func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) {
user.log.Debugln("Discord connection ready")
// Update our user fields
user.ID = r.User.ID
// Update our guild map to match watch discord thinks we're in. This is the
// only time we can get the full guild map as discordgo doesn't make it
// available to us later. Also, discord might not give us the full guild
// information here, so we use this to remove guilds the user left and only
// add guilds whose full information we have. The are told about the
// "unavailable" guilds later via the GuildCreate handler.
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
// build a list of the current guilds we're in so we can prune the old ones
current := []string{}
user.log.Debugln("database guild count", len(user.guilds))
user.log.Debugln("discord guild count", len(r.Guilds))
for _, guild := range r.Guilds {
current = append(current, guild.ID)
// If we already know about this guild, make sure we reset it's bridge
// status.
if val, found := user.guilds[guild.ID]; found {
bridge := val.Bridge
user.guilds[guild.ID].Bridge = bridge
// Update the name if the guild is available
if !guild.Unavailable {
user.guilds[guild.ID].GuildName = guild.Name
}
val.Upsert()
} else {
g := user.bridge.DB.Guild.New()
g.DiscordID = user.ID
g.GuildID = guild.ID
user.guilds[guild.ID] = g
if !guild.Unavailable {
g.GuildName = guild.Name
}
g.Upsert()
}
}
// Sync the guilds to the database.
user.bridge.DB.Guild.Prune(user.ID, current)
// Finally reload from the database since it purged servers we're not in
// anymore.
user.loadGuilds()
user.log.Debugln("updated database guild count", len(user.guilds))
if user.DiscordID != r.User.ID {
user.DiscordID = r.User.ID
user.Update()
}
func (user *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {
user.log.Debugln("connected to discord")
updateTS := time.Now()
guildsInSpace := make(map[string]bool)
for _, guild := range user.GetGuilds() {
guildsInSpace[guild.GuildID] = guild.InSpace
}
for _, guild := range r.Guilds {
user.handleGuild(guild, updateTS, guildsInSpace[guild.ID])
}
user.PruneGuildList(updateTS)
const maxCreate = 5
for i, ch := range r.PrivateChannels {
portal := user.GetPortalByMeta(ch)
if i < maxCreate && portal.MXID == "" {
err := portal.CreateMatrixRoom(user, ch)
if err != nil {
user.log.Errorfln("Failed to create portal for private channel %s in initial sync: %v", ch.ID, err)
}
} else {
portal.UpdateInfo(user, ch)
}
}
}
func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSpace bool) {
guild := user.bridge.GetGuildByID(meta.ID, true)
guild.UpdateInfo(user, meta)
if len(meta.Channels) > 0 {
for _, ch := range meta.Channels {
portal := user.GetPortalByMeta(ch)
if guild.AutoBridgeChannels && portal.MXID == "" {
err := portal.CreateMatrixRoom(user, ch)
if err != nil {
user.log.Errorfln("Failed to create portal for guild channel %s/%s in initial sync: %v", guild.ID, ch.ID, err)
}
} else {
portal.UpdateInfo(user, ch)
}
}
}
if len(guild.MXID) > 0 && !isInSpace {
_, err := user.bridge.Bot.SendStateEvent(user.GetSpaceRoom(), event.StateSpaceChild, guild.MXID.String(), &event.SpaceChildEventContent{
Via: []string{user.bridge.AS.HomeserverDomain},
})
if err != nil {
user.log.Errorfln("Failed to add guild space %s to user space: %v", guild.MXID, err)
} else {
isInSpace = true
}
}
user.MarkInGuild(database.UserGuild{GuildID: meta.ID, Timestamp: timestamp, InSpace: isInSpace})
}
func (user *User) connectedHandler(_ *discordgo.Session, c *discordgo.Connect) {
user.log.Debugln("Connected to discord")
user.tryAutomaticDoublePuppeting()
}
func (user *User) disconnectedHandler(s *discordgo.Session, d *discordgo.Disconnect) {
user.log.Debugln("disconnected from discord")
func (user *User) disconnectedHandler(_ *discordgo.Session, d *discordgo.Disconnect) {
user.log.Debugln("Disconnected from discord")
}
func (user *User) guildCreateHandler(s *discordgo.Session, g *discordgo.GuildCreate) {
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
// If we somehow already know about the guild, just update it's name
if guild, found := user.guilds[g.ID]; found {
guild.GuildName = g.Name
guild.Upsert()
func (user *User) guildCreateHandler(_ *discordgo.Session, g *discordgo.GuildCreate) {
user.handleGuild(g.Guild, time.Now(), false)
}
func (user *User) guildDeleteHandler(_ *discordgo.Session, g *discordgo.GuildDelete) {
user.MarkNotInGuild(g.ID)
guild := user.bridge.GetGuildByID(g.ID, false)
if guild == nil || guild.MXID == "" {
return
}
// This is a brand new guild so lets get it added.
guild := user.bridge.DB.Guild.New()
guild.DiscordID = user.ID
guild.GuildID = g.ID
guild.GuildName = g.Name
guild.Upsert()
user.guilds[g.ID] = guild
// TODO clean up?
}
func (user *User) guildDeleteHandler(s *discordgo.Session, g *discordgo.GuildDelete) {
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
if guild, found := user.guilds[g.ID]; found {
guild.Delete()
delete(user.guilds, g.ID)
user.log.Debugln("deleted guild", g.Guild.ID)
}
func (user *User) guildUpdateHandler(_ *discordgo.Session, g *discordgo.GuildUpdate) {
user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID))
}
func (user *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate) {
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
// If we somehow already know about the guild, just update it's name
if guild, found := user.guilds[g.ID]; found {
guild.GuildName = g.Name
guild.Upsert()
user.log.Debugln("updated guild", g.ID)
func (user *User) channelCreateHandler(_ *discordgo.Session, c *discordgo.ChannelCreate) {
if !user.bridgeMessage(c.GuildID) {
return
}
}
func (user *User) createChannel(c *discordgo.Channel) {
key := database.NewPortalKey(c.ID, user.User.ID)
portal := user.bridge.GetPortalByID(key)
portal := user.GetPortalByMeta(c.Channel)
if portal.MXID != "" {
return
}
portal.Name = c.Name
portal.Topic = c.Topic
portal.Type = c.Type
if portal.Type == discordgo.ChannelTypeDM {
portal.OtherUserID = c.Recipients[0].ID
err := portal.CreateMatrixRoom(user, c.Channel)
if err != nil {
user.log.Errorfln("Error creating Matrix room for %s on channel create event: %v", c.ID, err)
}
}
if c.Icon != "" {
user.log.Debugln("channel icon", c.Icon)
}
portal.Update()
portal.createMatrixRoom(user, c)
}
func (user *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) {
user.createChannel(c.Channel)
}
func (user *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) {
func (user *User) channelDeleteHandler(_ *discordgo.Session, c *discordgo.ChannelDelete) {
user.log.Debugln("channel delete handler")
}
func (user *User) channelPinsUpdateHandler(s *discordgo.Session, c *discordgo.ChannelPinsUpdate) {
func (user *User) channelPinsUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelPinsUpdate) {
user.log.Debugln("channel pins update")
}
func (user *User) channelUpdateHandler(s *discordgo.Session, c *discordgo.ChannelUpdate) {
key := database.NewPortalKey(c.ID, user.User.ID)
portal := user.bridge.GetPortalByID(key)
portal.update(user, c.Channel)
func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelUpdate) {
portal := user.GetPortalByMeta(c.Channel)
portal.UpdateInfo(user, c.Channel)
}
func (user *User) messageCreateHandler(s *discordgo.Session, m *discordgo.MessageCreate) {
if !user.bridgeMessage(m.GuildID) {
func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) {
fmt.Printf("%+v\n", msg)
if !user.bridgeMessage(guildID) {
return
}
key := database.NewPortalKey(m.ChannelID, user.ID)
portal := user.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
user: user,
}
portal.discordMessages <- msg
}
func (user *User) messageDeleteHandler(s *discordgo.Session, m *discordgo.MessageDelete) {
if !user.bridgeMessage(m.GuildID) {
portal := user.GetExistingPortalByID(channelID)
var thread *Thread
if portal == nil {
thread = user.bridge.GetThreadByID(channelID, nil)
if thread == nil || thread.Parent == nil {
user.log.Debugfln("Dropping %s in unknown channel %s/%s", typeName, guildID, channelID)
return
}
portal = thread.Parent
}
key := database.NewPortalKey(m.ChannelID, user.ID)
portal := user.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
portal.discordMessages <- portalDiscordMessage{
msg: msg,
user: user,
thread: thread,
}
}
portal.discordMessages <- msg
func (user *User) messageCreateHandler(_ *discordgo.Session, m *discordgo.MessageCreate) {
user.pushPortalMessage(m, "message create", m.ChannelID, m.GuildID)
}
func (user *User) messageUpdateHandler(s *discordgo.Session, m *discordgo.MessageUpdate) {
if !user.bridgeMessage(m.GuildID) {
return
func (user *User) messageDeleteHandler(_ *discordgo.Session, m *discordgo.MessageDelete) {
user.pushPortalMessage(m, "message delete", m.ChannelID, m.GuildID)
}
key := database.NewPortalKey(m.ChannelID, user.ID)
portal := user.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
user: user,
func (user *User) messageUpdateHandler(_ *discordgo.Session, m *discordgo.MessageUpdate) {
user.pushPortalMessage(m, "message update", m.ChannelID, m.GuildID)
}
portal.discordMessages <- msg
func (user *User) reactionAddHandler(_ *discordgo.Session, m *discordgo.MessageReactionAdd) {
user.pushPortalMessage(m, "reaction add", m.ChannelID, m.GuildID)
}
func (user *User) reactionAddHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) {
if !user.bridgeMessage(m.MessageReaction.GuildID) {
return
}
key := database.NewPortalKey(m.ChannelID, user.User.ID)
portal := user.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
user: user,
}
portal.discordMessages <- msg
}
func (user *User) reactionRemoveHandler(s *discordgo.Session, m *discordgo.MessageReactionRemove) {
if !user.bridgeMessage(m.MessageReaction.GuildID) {
return
}
key := database.NewPortalKey(m.ChannelID, user.User.ID)
portal := user.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
user: user,
}
portal.discordMessages <- msg
func (user *User) reactionRemoveHandler(_ *discordgo.Session, m *discordgo.MessageReactionRemove) {
user.pushPortalMessage(m, "reaction remove", m.ChannelID, m.GuildID)
}
func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool {
if intent == nil {
intent = user.bridge.Bot
}
ret := false
inviteContent := event.Content{
@@ -682,7 +662,7 @@ func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID,
func (user *User) getDirectChats() map[id.UserID][]id.RoomID {
chats := map[id.UserID][]id.RoomID{}
privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.ID)
privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.DiscordID)
for _, portal := range privateChats {
if portal.MXID != "" {
puppetMXID := user.bridge.FormatPuppetMXID(portal.Key.Receiver)
@@ -755,28 +735,21 @@ func (user *User) updateDirectChats(chats map[id.UserID][]id.RoomID) {
}
func (user *User) bridgeGuild(guildID string, everything bool) error {
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
guild, found := user.guilds[guildID]
if !found {
return fmt.Errorf("guildID not found")
guild := user.bridge.GetGuildByID(guildID, false)
if guild == nil {
return errors.New("guild not found")
}
// Update the guild
guild.Bridge = true
guild.Upsert()
// If this is a full bridge, create portals for all the channels
if everything {
channels, err := user.Session.GuildChannels(guildID)
meta, _ := user.Session.State.Guild(guildID)
err := guild.CreateMatrixRoom(user, meta)
if err != nil {
return err
}
for _, channel := range channels {
if channelIsBridgeable(channel) {
user.createChannel(channel)
for _, ch := range meta.Channels {
portal := user.GetPortalByMeta(ch)
if (everything && channelIsBridgeable(ch)) || ch.Type == discordgo.ChannelTypeGuildCategory {
err = portal.CreateMatrixRoom(user, ch)
if err != nil {
user.log.Warnfln("Error creating room for guild channel %s: %v", ch.ID, err)
}
}
}
@@ -785,41 +758,41 @@ func (user *User) bridgeGuild(guildID string, everything bool) error {
}
func (user *User) unbridgeGuild(guildID string) error {
user.guildsLock.Lock()
defer user.guildsLock.Unlock()
guild, exists := user.guilds[guildID]
if !exists {
return fmt.Errorf("guildID not found")
}
if !guild.Bridge {
return fmt.Errorf("guild not bridged")
}
// First update the guild so we don't have any other go routines recreating
// channels we're about to destroy.
guild.Bridge = false
guild.Upsert()
// Now run through the channels in the guild and remove any portals we
// have for them.
channels, err := user.Session.GuildChannels(guildID)
if err != nil {
return err
}
for _, channel := range channels {
if channelIsBridgeable(channel) {
key := database.PortalKey{
ChannelID: channel.ID,
Receiver: user.ID,
}
portal := user.bridge.GetPortalByID(key)
portal.leave(user)
}
}
//user.guildsLock.Lock()
//defer user.guildsLock.Unlock()
//
//guild, exists := user.guilds[guildID]
//if !exists {
// return fmt.Errorf("guildID not found")
//}
//
//if !guild.Bridge {
// return fmt.Errorf("guild not bridged")
//}
//
//// First update the guild so we don't have any other go routines recreating
//// channels we're about to destroy.
//guild.Bridge = false
//guild.Upsert()
//
//// Now run through the channels in the guild and remove any portals we
//// have for them.
//channels, err := user.Session.GuildChannels(guildID)
//if err != nil {
// return err
//}
//
//for _, channel := range channels {
// if channelIsBridgeable(channel) {
// key := database.PortalKey{
// ChannelID: channel.ID,
// Receiver: user.DiscordID,
// }
//
// portal := user.bridge.GetPortalByID(key)
// portal.leave(user)
// }
//}
return nil
}