diff --git a/attachments.go b/attachments.go index b0de607..e82f73b 100644 --- a/attachments.go +++ b/attachments.go @@ -3,7 +3,7 @@ package main import ( "bytes" "image" - "io/ioutil" + "io" "net/http" "strings" @@ -29,17 +29,16 @@ func (portal *Portal) downloadDiscordAttachment(url string) ([]byte, error) { if err != nil { return nil, err } - - req.Header.Set("User-Agent", discordgo.DroidBrowserUserAgent) + for key, value := range discordgo.DroidDownloadHeaders { + req.Header.Set(key, value) + } resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } - defer resp.Body.Close() - - return ioutil.ReadAll(resp.Body) + return io.ReadAll(resp.Body) } func (portal *Portal) downloadMatrixAttachment(eventID id.EventID, content *event.MessageEventContent) ([]byte, error) { diff --git a/avatar.go b/avatar.go index 272c038..320ac72 100644 --- a/avatar.go +++ b/avatar.go @@ -7,16 +7,25 @@ import ( "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/id" + + "github.com/bwmarrin/discordgo" ) func uploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, error) { - getResp, err := http.DefaultClient.Get(url) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return id.ContentURI{}, fmt.Errorf("failed to prepare request: %w", err) + } + for key, value := range discordgo.DroidImageHeaders { + req.Header.Set(key, value) + } + getResp, err := http.DefaultClient.Do(req) if err != nil { return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err) } data, err := io.ReadAll(getResp.Body) - getResp.Body.Close() + _ = getResp.Body.Close() if err != nil { return id.ContentURI{}, fmt.Errorf("failed to read avatar data: %w", err) } diff --git a/commands.go b/commands.go index 8c08807..b8f65c0 100644 --- a/commands.go +++ b/commands.go @@ -116,7 +116,7 @@ func fnLogin(ce *WrappedCommandEvent) { ce.Reply("Error connecting after login: %v", err) } ce.User.Lock() - ce.User.ID = user.UserID + ce.User.DiscordID = user.UserID ce.User.Update() ce.User.Unlock() ce.Reply("Successfully logged in as %s#%s", user.Username, user.Discriminator) @@ -251,19 +251,21 @@ func fnGuilds(ce *WrappedCommandEvent) { } func fnListGuilds(ce *WrappedCommandEvent) { - ce.User.guildsLock.Lock() - defer ce.User.guildsLock.Unlock() - if len(ce.User.guilds) == 0 { - ce.Reply("You haven't joined any guilds") - } else { - var output strings.Builder - for _, guild := range ce.User.guilds { - status := "not bridged" - if guild.Bridge { - status = "bridged" - } - _, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.GuildName, guild.GuildID, status) + var output strings.Builder + for _, userGuild := range ce.User.GetGuilds() { + guild := ce.Bridge.GetGuildByID(userGuild.GuildID, false) + if guild == nil { + continue } + status := "not bridged" + if guild.MXID != "" { + status = "bridged" + } + _, _ = fmt.Fprintf(&output, "* %s (`%s`) - %s\n", guild.Name, guild.ID, status) + } + if output.Len() == 0 { + ce.Reply("No guilds found") + } else { ce.Reply("List of guilds:\n\n%s", output.String()) } } diff --git a/database/attachment.go b/database/attachment.go index a4f8538..a30a1b9 100644 --- a/database/attachment.go +++ b/database/attachment.go @@ -16,7 +16,7 @@ type AttachmentQuery struct { } const ( - attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver FROM attachment" + attachmentSelect = "SELECT dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, dc_thread_id FROM attachment" ) func (aq *AttachmentQuery) New() *Attachment { @@ -77,58 +77,50 @@ type Attachment struct { db *Database log log.Logger - Channel PortalKey + Channel PortalKey + ThreadID string + MessageID string + ID string + MXID id.EventID +} - DiscordMessageID string - DiscordAttachmentID string - MXID id.EventID +func (a *Attachment) DiscordProtoChannelID() string { + if a.ThreadID != "" { + return a.ThreadID + } else { + return a.Channel.ChannelID + } } func (a *Attachment) Scan(row dbutil.Scannable) *Attachment { - err := row.Scan( - &a.DiscordAttachmentID, &a.DiscordMessageID, - &a.Channel.ChannelID, &a.Channel.Receiver, - &a.MXID) - + err := row.Scan(&a.ID, &a.MessageID, &a.Channel.ChannelID, &a.Channel.Receiver, &a.ThreadID, &a.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { a.log.Errorln("Database scan failed:", err) + panic(err) } - return nil } - return a } func (a *Attachment) Insert() { - query := "INSERT INTO attachment" + - " (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, " + - " mxid) VALUES ($1, $2, $3, $4, $5);" - - _, err := a.db.Exec( - query, - a.Channel.ChannelID, a.Channel.Receiver, - a.DiscordMessageID, a.DiscordAttachmentID, - a.MXID, - ) - + query := ` + INSERT INTO attachment (dcid, dc_msg_id, dc_chan_id, dc_chan_receiver, mxid) + VALUES ($1, $2, $3, $4, $5) + ` + _, err := a.db.Exec(query, a.ID, a.MessageID, a.Channel.ChannelID, a.Channel.Receiver, strPtr(a.ThreadID), a.MXID) if err != nil { - a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err) + a.log.Warnfln("Failed to insert attachment for %s@%s: %v", a.ID, a.Channel, err) + panic(err) } } func (a *Attachment) Delete() { - query := "DELETE FROM attachment WHERE" + - " dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" - - _, err := a.db.Exec( - query, - a.Channel.ChannelID, a.Channel.Receiver, - a.DiscordAttachmentID, - ) - + query := `DELETE FROM attachment WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3` + _, err := a.db.Exec(query, a.Channel.ChannelID, a.Channel.Receiver, a.ID) if err != nil { - a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.DiscordAttachmentID, a.Channel, err) + a.log.Warnfln("Failed to delete attachment for %s@%s: %v", a.ID, a.Channel, err) + panic(err) } } diff --git a/database/database.go b/database/database.go index d53c6f3..a3bee8a 100644 --- a/database/database.go +++ b/database/database.go @@ -7,8 +7,9 @@ import ( _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - "go.mau.fi/mautrix-discord/database/upgrades" "maunium.net/go/mautrix/util/dbutil" + + "go.mau.fi/mautrix-discord/database/upgrades" ) type Database struct { @@ -18,6 +19,7 @@ type Database struct { Portal *PortalQuery Puppet *PuppetQuery Message *MessageQuery + Thread *ThreadQuery Reaction *ReactionQuery Attachment *AttachmentQuery Emoji *EmojiQuery @@ -54,6 +56,10 @@ func New(baseDB *dbutil.Database) *Database { db: db, log: db.Log.Sub("Message"), } + db.Thread = &ThreadQuery{ + db: db, + log: db.Log.Sub("Thread"), + } db.Reaction = &ReactionQuery{ db: db, log: db.Log.Sub("Reaction"), @@ -72,3 +78,10 @@ func New(baseDB *dbutil.Database) *Database { } return db } + +func strPtr(val string) *string { + if val == "" { + return nil + } + return &val +} diff --git a/database/emoji.go b/database/emoji.go index 9e10198..285b09f 100644 --- a/database/emoji.go +++ b/database/emoji.go @@ -28,23 +28,16 @@ func (eq *EmojiQuery) New() *Emoji { func (eq *EmojiQuery) GetByDiscordID(discordID string) *Emoji { query := emojiSelect + " WHERE discord_id=$1" - return eq.get(query, discordID) } func (eq *EmojiQuery) GetByMatrixURL(matrixURL id.ContentURI) *Emoji { query := emojiSelect + " WHERE matrix_url=$1" - return eq.get(query, matrixURL.String()) } func (eq *EmojiQuery) get(query string, args ...interface{}) *Emoji { - row := eq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return eq.New().Scan(row) + return eq.New().Scan(eq.db.QueryRow(query, args...)) } type Emoji struct { @@ -59,18 +52,17 @@ type Emoji struct { func (e *Emoji) Scan(row dbutil.Scannable) *Emoji { var matrixURL sql.NullString - err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL) + err := row.Scan(&e.DiscordID, &e.DiscordName, &matrixURL) if err != nil { if !errors.Is(err, sql.ErrNoRows) { e.log.Errorln("Database scan failed:", err) + panic(err) } - return nil } e.MatrixURL, _ = id.ParseContentURI(matrixURL.String) - return e } @@ -83,6 +75,7 @@ func (e *Emoji) Insert() { if err != nil { e.log.Warnfln("Failed to insert emoji %s: %v", e.DiscordID, err) + panic(err) } } @@ -92,17 +85,15 @@ func (e *Emoji) Delete() { _, err := e.db.Exec(query, e.DiscordID) if err != nil { e.log.Warnfln("Failed to delete emoji %s: %v", e.DiscordID, err) + panic(err) } } func (e *Emoji) APIName() string { if e.DiscordID != "" && e.DiscordName != "" { return e.DiscordName + ":" + e.DiscordID - } - - if e.DiscordName != "" { + } else if e.DiscordName != "" { return e.DiscordName } - return e.DiscordID } diff --git a/database/guild.go b/database/guild.go index f69b725..2a09d69 100644 --- a/database/guild.go +++ b/database/guild.go @@ -3,9 +3,9 @@ package database import ( "database/sql" "errors" - "fmt" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/util/dbutil" ) @@ -16,7 +16,7 @@ type GuildQuery struct { } const ( - guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild" + guildSelect = "SELECT dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels FROM guild" ) func (gq *GuildQuery) New() *Guild { @@ -26,109 +26,100 @@ func (gq *GuildQuery) New() *Guild { } } -func (gq *GuildQuery) Get(discordID, guildID string) *Guild { - query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2" - - row := gq.db.QueryRow(query, discordID, guildID) - if row == nil { - return nil - } - - return gq.New().Scan(row) +func (gq *GuildQuery) GetByID(dcid string) *Guild { + query := guildSelect + " WHERE dcid=$1" + return gq.New().Scan(gq.db.QueryRow(query, dcid)) } -func (gq *GuildQuery) GetAll(discordID string) []*Guild { - query := guildSelect + " WHERE discord_id=$1" +func (gq *GuildQuery) GetByMXID(mxid id.RoomID) *Guild { + query := guildSelect + " WHERE mxid=$1" + return gq.New().Scan(gq.db.QueryRow(query, mxid)) +} - rows, err := gq.db.Query(query, discordID) - if err != nil || rows == nil { +func (gq *GuildQuery) GetAll() []*Guild { + rows, err := gq.db.Query(guildSelect) + if err != nil { + gq.log.Errorln("Failed to query guilds:", err) return nil } - guilds := []*Guild{} + var guilds []*Guild for rows.Next() { - guilds = append(guilds, gq.New().Scan(rows)) + guild := gq.New().Scan(rows) + if guild != nil { + guilds = append(guilds, guild) + } } return guilds } -func (gq *GuildQuery) Prune(discordID string, guilds []string) { - // We need this interface slice because a variadic function can't mix - // arguements with a `...` expanded slice. - args := []interface{}{discordID} - - nGuilds := len(guilds) - if nGuilds <= 0 { - return - } - - gq.log.Debugfln("prunning guilds for %s", discordID) - - // Build the in query - inQuery := "$2" - for i := 1; i < nGuilds; i++ { - inQuery += fmt.Sprintf(", $%d", i+2) - } - - // Add the arguements for the build query - for _, guildID := range guilds { - args = append(args, guildID) - } - - // Now remove any guilds that the user has left. - query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" + - inQuery + ")" - - _, err := gq.db.Exec(query, args...) - if err != nil { - gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err) - } -} - type Guild struct { db *Database log log.Logger - DiscordID string - GuildID string - GuildName string - Bridge bool + ID string + MXID id.RoomID + Name string + NameSet bool + Avatar string + AvatarURL id.ContentURI + AvatarSet bool + + AutoBridgeChannels bool } func (g *Guild) Scan(row dbutil.Scannable) *Guild { - err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge) + var mxid sql.NullString + var avatarURL string + err := row.Scan(&g.ID, &mxid, &g.Name, &g.NameSet, &g.Avatar, &avatarURL, &g.AvatarSet, &g.AutoBridgeChannels) if err != nil { if !errors.Is(err, sql.ErrNoRows) { g.log.Errorln("Database scan failed:", err) + panic(err) } return nil } - + g.MXID = id.RoomID(mxid.String) + g.AvatarURL, _ = id.ParseContentURI(avatarURL) return g } -func (g *Guild) Upsert() { - query := "INSERT INTO guild" + - " (discord_id, guild_id, guild_name, bridge)" + - " VALUES ($1, $2, $3, $4)" + - " ON CONFLICT(discord_id, guild_id)" + - " DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge" - - _, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge) - +func (g *Guild) mxidPtr() *id.RoomID { + if g.MXID != "" { + return &g.MXID + } + return nil +} +func (g *Guild) Insert() { + query := ` + INSERT INTO guild (dcid, mxid, name, name_set, avatar, avatar_url, avatar_set, auto_bridge_channels) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ` + _, err := g.db.Exec(query, g.ID, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels) if err != nil { - g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err) + g.log.Warnfln("Failed to insert %s: %v", g.ID, err) + panic(err) + } +} + +func (g *Guild) Update() { + query := ` + UPDATE guild SET mxid=$1, name=$2, name_set=$3, avatar=$4, avatar_url=$5, avatar_set=$6, auto_bridge_channels=$7 + WHERE dcid=$8 + ` + _, err := g.db.Exec(query, g.mxidPtr(), g.Name, g.NameSet, g.Avatar, g.AvatarURL.String(), g.AvatarSet, g.AutoBridgeChannels, g.ID) + if err != nil { + g.log.Warnfln("Failed to update %s: %v", g.ID, err) + panic(err) } } func (g *Guild) Delete() { - query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2" - - _, err := g.db.Exec(query, g.DiscordID, g.GuildID) - + _, err := g.db.Exec("DELETE FROM guild WHERE dcid=$1", g.ID) if err != nil { - g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err) + g.log.Warnfln("Failed to delete %s: %v", g.ID, err) + panic(err) } } diff --git a/database/message.go b/database/message.go index b95ee5e..04fe525 100644 --- a/database/message.go +++ b/database/message.go @@ -17,7 +17,7 @@ type MessageQuery struct { } const ( - messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid FROM message" + messageSelect = "SELECT dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid FROM message" ) func (mq *MessageQuery) New() *Message { @@ -45,14 +45,12 @@ func (mq *MessageQuery) GetAll(key PortalKey) []*Message { func (mq *MessageQuery) GetByDiscordID(key PortalKey, discordID string) *Message { query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dcid=$3" + return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID)) +} - row := mq.db.QueryRow(query, key.ChannelID, key.Receiver, discordID) - if row == nil { - mq.log.Debugfln("failed to find existing message for discord_id %s", discordID) - return nil - } - - return mq.New().Scan(row) +func (mq *MessageQuery) GetLastInThread(key PortalKey, threadID string) *Message { + query := messageSelect + " WHERE dc_chan_id=$1 AND dc_chan_receiver=$2 AND dc_thread_id=$3 ORDER BY timestamp DESC LIMIT 1" + return mq.New().Scan(mq.db.QueryRow(query, key.ChannelID, key.Receiver, threadID)) } func (mq *MessageQuery) GetByMXID(key PortalKey, mxid id.EventID) *Message { @@ -74,36 +72,54 @@ type Message struct { Channel PortalKey SenderID string Timestamp time.Time + ThreadID string MXID id.EventID } +func (m *Message) DiscordProtoChannelID() string { + if m.ThreadID != "" { + return m.ThreadID + } else { + return m.Channel.ChannelID + } +} + func (m *Message) Scan(row dbutil.Scannable) *Message { var ts int64 + var threadID sql.NullString - err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &m.MXID) + err := row.Scan(&m.DiscordID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &threadID, &m.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { m.log.Errorln("Database scan failed:", err) + panic(err) } return nil } if ts != 0 { - m.Timestamp = time.Unix(ts, 0) + m.Timestamp = time.UnixMilli(ts) } + m.ThreadID = threadID.String return m } func (m *Message) Insert() { - query := "INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, mxid) VALUES ($1, $2, $3, $4, $5, $6)" + query := ` + INSERT INTO message (dcid, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_thread_id, mxid) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ` - _, err := m.db.Exec(query, m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, m.Timestamp.Unix(), m.MXID) + _, err := m.db.Exec(query, + m.DiscordID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID, + m.Timestamp.UnixMilli(), strPtr(m.ThreadID), m.MXID) if err != nil { m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err) + panic(err) } } @@ -114,5 +130,6 @@ func (m *Message) Delete() { if err != nil { m.log.Warnfln("Failed to delete %s@%s: %v", m.DiscordID, m.Channel, err) + panic(err) } } diff --git a/database/portal.go b/database/portal.go index e04d316..b5d5dce 100644 --- a/database/portal.go +++ b/database/portal.go @@ -12,11 +12,30 @@ import ( ) const ( - portalSelect = "SELECT dcid, receiver, mxid, name, topic, avatar," + - " avatar_url, type, other_user_id, first_event_id, encrypted" + + portalSelect = "SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, " + + " mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id" + " FROM portal" ) +type PortalKey struct { + ChannelID string + Receiver string +} + +func NewPortalKey(channelID, receiver string) PortalKey { + return PortalKey{ + ChannelID: channelID, + Receiver: receiver, + } +} + +func (key PortalKey) String() string { + if key.Receiver == "" { + return key.ChannelID + } + return key.ChannelID + "-" + key.Receiver +} + type PortalQuery struct { db *Database log log.Logger @@ -34,7 +53,7 @@ func (pq *PortalQuery) GetAll() []*Portal { } func (pq *PortalQuery) GetByID(key PortalKey) *Portal { - return pq.get(portalSelect+" WHERE dcid=$1 AND receiver=$2", key.ChannelID, key.Receiver) + return pq.get(portalSelect+" WHERE dcid=$1 AND (receiver=$2 OR receiver='')", key.ChannelID, key.Receiver) } func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { @@ -67,12 +86,7 @@ func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal { } func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { - row := pq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return pq.New().Scan(row) + return pq.New().Scan(pq.db.QueryRow(query, args...)) } type Portal struct { @@ -82,78 +96,87 @@ type Portal struct { Key PortalKey Type discordgo.ChannelType OtherUserID string + ParentID string + GuildID string MXID id.RoomID Name string + NameSet bool Topic string + TopicSet bool Avatar string AvatarURL id.ContentURI + AvatarSet bool Encrypted bool + InSpace id.RoomID FirstEventID id.EventID } func (p *Portal) Scan(row dbutil.Scannable) *Portal { - var mxid, avatarURL, firstEventID sql.NullString - var typ sql.NullInt32 + var otherUserID, guildID, parentID, mxid, firstEventID sql.NullString + var chanType int32 + var avatarURL string - err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, - &p.Topic, &p.Avatar, &avatarURL, &typ, &p.OtherUserID, &firstEventID, - &p.Encrypted) + err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID, + &mxid, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet, + &p.Encrypted, &p.InSpace, &firstEventID) if err != nil { if err != sql.ErrNoRows { p.log.Errorln("Database scan failed:", err) + panic(err) } return nil } p.MXID = id.RoomID(mxid.String) - p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) - p.Type = discordgo.ChannelType(typ.Int32) + p.OtherUserID = otherUserID.String + p.GuildID = guildID.String + p.ParentID = parentID.String + p.Type = discordgo.ChannelType(chanType) p.FirstEventID = id.EventID(firstEventID.String) + p.AvatarURL, _ = id.ParseContentURI(avatarURL) return p } -func (p *Portal) mxidPtr() *id.RoomID { - if p.MXID != "" { - return &p.MXID - } - - return nil -} - func (p *Portal) Insert() { - query := "INSERT INTO portal" + - " (dcid, receiver, mxid, name, topic, avatar, avatar_url," + - " type, other_user_id, first_event_id, encrypted)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" - - _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(), - p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.OtherUserID, - p.FirstEventID.String(), p.Encrypted) + query := ` + INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid, + name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, + encrypted, in_space, first_event_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + ` + _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type, + strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), + p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, + p.Encrypted, p.InSpace, p.FirstEventID.String()) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.Key, err) + panic(err) } } func (p *Portal) Update() { - query := "UPDATE portal SET" + - " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," + - " other_user_id=$7, first_event_id=$8, encrypted=$9" + - " WHERE dcid=$10 AND receiver=$11" - - _, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar, - p.AvatarURL.String(), p.Type, p.OtherUserID, p.FirstEventID.String(), - p.Encrypted, + query := ` + UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5, + name=$6, name_set=$7, topic=$8, topic_set=$9, avatar=$10, avatar_url=$11, avatar_set=$12, + encrypted=$13, in_space=$14, first_event_id=$15 + WHERE dcid=$16 AND receiver=$17 + ` + _, err := p.db.Exec(query, + p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), + p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, + p.Encrypted, p.InSpace, p.FirstEventID.String(), p.Key.ChannelID, p.Key.Receiver) if err != nil { p.log.Warnfln("Failed to update %s: %v", p.Key, err) + panic(err) } } @@ -162,5 +185,6 @@ func (p *Portal) Delete() { _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver) if err != nil { p.log.Warnfln("Failed to delete %s: %v", p.Key, err) + panic(err) } } diff --git a/database/portalkey.go b/database/portalkey.go deleted file mode 100644 index 54c1a59..0000000 --- a/database/portalkey.go +++ /dev/null @@ -1,20 +0,0 @@ -package database - -type PortalKey struct { - ChannelID string - Receiver string -} - -func NewPortalKey(channelID, receiver string) PortalKey { - return PortalKey{ - ChannelID: channelID, - Receiver: receiver, - } -} - -func (key PortalKey) String() string { - if key.Receiver == "" { - return key.ChannelID - } - return key.ChannelID + "-" + key.Receiver -} diff --git a/database/puppet.go b/database/puppet.go index c5f1ba9..1200a53 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -10,7 +10,7 @@ import ( ) const ( - puppetSelect = "SELECT id, display_name, avatar, avatar_url," + + puppetSelect = "SELECT id, name, name_set, avatar, avatar_url, avatar_set," + " custom_mxid, access_token, next_batch" + " FROM puppet " ) @@ -36,12 +36,7 @@ func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { } func (pq *PuppetQuery) get(query string, args ...interface{}) *Puppet { - row := pq.db.QueryRow(query, args...) - if row == nil { - return nil - } - - return pq.New().Scan(row) + return pq.New().Scan(pq.db.QueryRow(query, args...)) } func (pq *PuppetQuery) GetAll() []*Puppet { @@ -59,7 +54,7 @@ func (pq *PuppetQuery) getAll(query string, args ...interface{}) []*Puppet { } defer rows.Close() - puppets := []*Puppet{} + var puppets []*Puppet for rows.Next() { puppets = append(puppets, pq.New().Scan(rows)) } @@ -71,11 +66,12 @@ type Puppet struct { db *Database log log.Logger - ID string - DisplayName string - + ID string + Name string + NameSet bool Avatar string AvatarURL id.ContentURI + AvatarSet bool CustomMXID id.UserID AccessToken string @@ -83,24 +79,22 @@ type Puppet struct { } func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { - var did, displayName, avatar, avatarURL sql.NullString + var avatarURL string var customMXID, accessToken, nextBatch sql.NullString - err := row.Scan(&did, &displayName, &avatar, &avatarURL, + err := row.Scan(&p.ID, &p.Name, &p.NameSet, &p.Avatar, &avatarURL, &p.AvatarSet, &customMXID, &accessToken, &nextBatch) if err != nil { if err != sql.ErrNoRows { p.log.Errorln("Database scan failed:", err) + panic(err) } return nil } - p.ID = did.String - p.DisplayName = displayName.String - p.Avatar = avatar.String - p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) + p.AvatarURL, _ = id.ParseContentURI(avatarURL) p.CustomMXID = id.UserID(customMXID.String) p.AccessToken = accessToken.String p.NextBatch = nextBatch.String @@ -109,31 +103,31 @@ func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { } func (p *Puppet) Insert() { - query := "INSERT INTO puppet" + - " (id, display_name, avatar, avatar_url," + - " custom_mxid, access_token, next_batch)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7)" - - _, err := p.db.Exec(query, p.ID, p.DisplayName, p.Avatar, - p.AvatarURL.String(), p.CustomMXID, p.AccessToken, - p.NextBatch) + query := ` + INSERT INTO puppet (id, name, name_set, avatar, avatar_url, avatar_set, custom_mxid, access_token, next_batch) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ` + _, err := p.db.Exec(query, p.ID, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, + strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch)) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.ID, err) + panic(err) } } func (p *Puppet) Update() { - query := "UPDATE puppet" + - " SET display_name=$1, avatar=$2, avatar_url=$3, " + - " custom_mxid=$4, access_token=$5, next_batch=$6" + - " WHERE id=$7" - - _, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(), - p.CustomMXID, p.AccessToken, p.NextBatch, + query := ` + UPDATE puppet SET name=$1, name_set=$2, avatar=$3, avatar_url=$4, avatar_set=$5, + custom_mxid=$6, access_token=$7, next_batch=$8 + WHERE id=$9 + ` + _, err := p.db.Exec(query, p.Name, p.NameSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, + strPtr(string(p.CustomMXID)), strPtr(p.AccessToken), strPtr(p.NextBatch), p.ID) if err != nil { p.log.Warnfln("Failed to update %s: %v", p.ID, err) + panic(err) } } diff --git a/database/reaction.go b/database/reaction.go index 41037c1..e681918 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -16,7 +16,7 @@ type ReactionQuery struct { } const ( - reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, mxid FROM reaction" + reactionSelect = "SELECT dc_chan_id, dc_chan_receiver, dc_msg_id, dc_sender, dc_emoji_name, dc_thread_id, mxid FROM reaction" ) func (rq *ReactionQuery) New() *Reaction { @@ -75,15 +75,17 @@ type Reaction struct { MessageID string Sender string EmojiName string + ThreadID string MXID id.EventID } func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { - err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.MXID) + err := row.Scan(&r.Channel.ChannelID, &r.Channel.Receiver, &r.MessageID, &r.Sender, &r.EmojiName, &r.ThreadID, &r.MXID) if err != nil { if !errors.Is(err, sql.ErrNoRows) { r.log.Errorln("Database scan failed:", err) + panic(err) } return nil } @@ -91,14 +93,23 @@ func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { return r } +func (r *Reaction) DiscordProtoChannelID() string { + if r.ThreadID != "" { + return r.ThreadID + } else { + return r.Channel.ChannelID + } +} + func (r *Reaction) Insert() { query := ` - INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, mxid) - VALUES($1, $2, $3, $4, $5, $6) + INSERT INTO reaction (dc_msg_id, dc_sender, dc_emoji_name, dc_chan_id, dc_chan_receiver, dc_thread_id, mxid) + VALUES($1, $2, $3, $4, $5, $6, $7) ` - _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, r.MXID) + _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName, r.Channel.ChannelID, r.Channel.Receiver, strPtr(r.ThreadID), r.MXID) if err != nil { r.log.Warnfln("Failed to insert reaction for %s@%s: %v", r.MessageID, r.Channel, err) + panic(err) } } @@ -107,5 +118,6 @@ func (r *Reaction) Delete() { _, err := r.db.Exec(query, r.MessageID, r.Sender, r.EmojiName) if err != nil { r.log.Warnfln("Failed to delete reaction for %s@%s: %v", r.MessageID, r.Channel, err) + panic(err) } } diff --git a/database/thread.go b/database/thread.go new file mode 100644 index 0000000..09bf1de --- /dev/null +++ b/database/thread.go @@ -0,0 +1,105 @@ +package database + +import ( + "database/sql" + "errors" + + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/util/dbutil" +) + +type ThreadQuery struct { + db *Database + log log.Logger +} + +const ( + threadSelect = "SELECT dcid, parent_chan_id, root_msg_dcid, root_msg_mxid FROM thread" +) + +func (tq *ThreadQuery) New() *Thread { + return &Thread{ + db: tq.db, + log: tq.log, + } +} + +func (tq *ThreadQuery) GetByDiscordID(discordID string) *Thread { + query := threadSelect + " WHERE dcid=$1" + + row := tq.db.QueryRow(query, discordID) + if row == nil { + return nil + } + + return tq.New().Scan(row) +} + +//func (tq *ThreadQuery) GetByDiscordRootMsg(channelID, messageID string) *Thread { +// query := messageSelect + " WHERE parent_chan_id=$1 AND root_msg_dcid=$2" +// +// row := tq.db.QueryRow(query, channelID, messageID) +// if row == nil { +// return nil +// } +// +// return tq.New().Scan(row) +//} + +func (tq *ThreadQuery) GetByMatrixRootMsg(mxid id.EventID) *Thread { + query := threadSelect + " WHERE root_msg_mxid=$1" + + row := tq.db.QueryRow(query, mxid) + if row == nil { + return nil + } + + return tq.New().Scan(row) +} + +type Thread struct { + db *Database + log log.Logger + + ID string + ParentID string + + RootDiscordID string + RootMXID id.EventID +} + +func (t *Thread) Scan(row dbutil.Scannable) *Thread { + err := row.Scan(&t.ID, &t.ParentID, &t.RootDiscordID, &t.RootMXID) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + t.log.Errorln("Database scan failed:", err) + panic(err) + } + return nil + } + return t +} + +func (t *Thread) Insert() { + query := "INSERT INTO thread (dcid, parent_chan_id, root_msg_dcid, root_msg_mxid) VALUES ($1, $2, $3, $4)" + + _, err := t.db.Exec(query, t.ID, t.ParentID, t.RootDiscordID, t.RootMXID) + + if err != nil { + t.log.Warnfln("Failed to insert %s@%s: %v", t.ID, t.ParentID, err) + panic(err) + } +} + +func (t *Thread) Delete() { + query := "DELETE FROM thread WHERE dcid=$1 AND parent_chan_id=$2" + + _, err := t.db.Exec(query, t.ID, t.ParentID) + + if err != nil { + t.log.Warnfln("Failed to delete %s@%s: %v", t.ID, t.ParentID, err) + panic(err) + } +} diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 4a63f90..0a1c872 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,29 +1,65 @@ --- v0 -> v2: Latest revision +-- v0 -> v3: Latest revision + +CREATE TABLE guild ( + dcid TEXT PRIMARY KEY, + mxid TEXT UNIQUE, + name TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar TEXT NOT NULL, + avatar_url TEXT NOT NULL, + avatar_set BOOLEAN NOT NULL, + + auto_bridge_channels BOOLEAN NOT NULL +); CREATE TABLE portal ( dcid TEXT, receiver TEXT, other_user_id TEXT, - type INTEGER, + type INTEGER NOT NULL, + + dc_guild_id TEXT, + dc_parent_id TEXT, + -- This is not accessed by the bridge, it's only used for the portal parent foreign key. + -- Only guild channels have parents, but only DMs have a receiver field. + dc_parent_receiver TEXT NOT NULL DEFAULT '', mxid TEXT UNIQUE, name TEXT NOT NULL, + name_set BOOLEAN NOT NULL, topic TEXT NOT NULL, avatar TEXT NOT NULL, avatar_url TEXT NOT NULL, - encrypted BOOLEAN NOT NULL DEFAULT false, + avatar_set BOOLEAN NOT NULL, + encrypted BOOLEAN NOT NULL, + in_space TEXT NOT NULL, first_event_id TEXT NOT NULL, - PRIMARY KEY (dcid, receiver) + PRIMARY KEY (dcid, receiver), + CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE, + CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE +); + +CREATE TABLE thread ( + dcid TEXT PRIMARY KEY, + parent_chan_id TEXT NOT NULL, + root_msg_dcid TEXT NOT NULL, + root_msg_mxid TEXT NOT NULL, + -- This is also not accessed by the bridge. + receiver TEXT NOT NULL DEFAULT '', + + CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE ); CREATE TABLE puppet ( id TEXT PRIMARY KEY, - name TEXT, - avatar TEXT, - avatar_url TEXT, + name TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar TEXT NOT NULL, + avatar_url TEXT NOT NULL, + avatar_set BOOLEAN NOT NULL, custom_mxid TEXT, access_token TEXT, @@ -34,9 +70,20 @@ CREATE TABLE "user" ( mxid TEXT PRIMARY KEY, dcid TEXT UNIQUE, + discord_token TEXT, management_room TEXT, + space_room TEXT +); - token TEXT +CREATE TABLE user_guild ( + guild_id TEXT, + user_mxid TEXT, + in_space BOOLEAN NOT NULL, + timestamp BIGINT NOT NULL, + + PRIMARY KEY (guild_id, user_mxid), + CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE, + CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ); CREATE TABLE message ( @@ -45,6 +92,7 @@ CREATE TABLE message ( dc_chan_receiver TEXT, dc_sender TEXT NOT NULL, timestamp BIGINT NOT NULL, + dc_thread_id TEXT, mxid TEXT NOT NULL UNIQUE, @@ -58,6 +106,7 @@ CREATE TABLE reaction ( dc_msg_id TEXT, dc_sender TEXT, dc_emoji_name TEXT, + dc_thread_id TEXT, mxid TEXT NOT NULL UNIQUE, @@ -70,6 +119,7 @@ CREATE TABLE attachment ( dc_msg_id TEXT, dc_chan_id TEXT, dc_chan_receiver TEXT, + dc_thread_id TEXT, mxid TEXT NOT NULL UNIQUE, @@ -82,11 +132,3 @@ CREATE TABLE emoji ( discord_name TEXT, matrix_url TEXT ); - -CREATE TABLE guild ( - discord_id TEXT NOT NULL, - guild_id TEXT NOT NULL, - guild_name TEXT NOT NULL, - bridge BOOLEAN DEFAULT FALSE, - PRIMARY KEY(discord_id, guild_id) -); diff --git a/database/upgrades/03-spaces.sql b/database/upgrades/03-spaces.sql new file mode 100644 index 0000000..53ae9bc --- /dev/null +++ b/database/upgrades/03-spaces.sql @@ -0,0 +1,72 @@ +-- v3: Store portal parent metadata for spaces +DROP TABLE guild; + +CREATE TABLE guild ( + dcid TEXT PRIMARY KEY, + mxid TEXT UNIQUE, + name TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar TEXT NOT NULL, + avatar_url TEXT NOT NULL, + avatar_set BOOLEAN NOT NULL, + + auto_bridge_channels BOOLEAN NOT NULL +); + +CREATE TABLE user_guild ( + guild_id TEXT, + user_mxid TEXT, + in_space BOOLEAN NOT NULL, + timestamp BIGINT NOT NULL, + + PRIMARY KEY (guild_id, user_mxid), + CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE, + CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE +); + +ALTER TABLE portal ADD COLUMN dc_guild_id TEXT; +ALTER TABLE portal ADD COLUMN dc_parent_id TEXT; +ALTER TABLE portal ADD COLUMN dc_parent_receiver TEXT NOT NULL DEFAULT ''; +ALTER TABLE portal ADD CONSTRAINT portal_parent_fkey FOREIGN KEY (dc_parent_id, dc_parent_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE; +ALTER TABLE portal ADD CONSTRAINT portal_guild_fkey FOREIGN KEY (dc_guild_id) REFERENCES guild(dcid) ON DELETE CASCADE; +DELETE FROM portal WHERE type IS NULL; +-- only: postgres +ALTER TABLE portal ALTER COLUMN type SET NOT NULL; + +ALTER TABLE portal ADD COLUMN in_space TEXT NOT NULL DEFAULT ''; +ALTER TABLE portal ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE portal ADD COLUMN topic_set BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE portal ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false; +-- only: postgres for next 5 lines +ALTER TABLE portal ALTER COLUMN in_space DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN name_set DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN topic_set DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN avatar_set DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN encrypted DROP DEFAULT; + +ALTER TABLE puppet RENAME COLUMN display_name TO name; +ALTER TABLE puppet ADD COLUMN name_set BOOLEAN NOT NULL DEFAULT false; +ALTER TABLE puppet ADD COLUMN avatar_set BOOLEAN NOT NULL DEFAULT false; +-- only: postgres for next 2 lines +ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT; +ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT; + +ALTER TABLE "user" ADD COLUMN space_room TEXT; +ALTER TABLE "user" RENAME COLUMN token TO discord_token; + +UPDATE message SET timestamp=timestamp*1000; + +CREATE TABLE thread ( + dcid TEXT PRIMARY KEY, + parent_chan_id TEXT NOT NULL, + root_msg_dcid TEXT NOT NULL, + root_msg_mxid TEXT NOT NULL, + -- This is also not accessed by the bridge. + receiver TEXT NOT NULL DEFAULT '', + + CONSTRAINT thread_parent_fkey FOREIGN KEY (parent_chan_id, receiver) REFERENCES portal(dcid, receiver) ON DELETE CASCADE ON UPDATE CASCADE +); + +ALTER TABLE message ADD COLUMN dc_thread_id TEXT; +ALTER TABLE attachment ADD COLUMN dc_thread_id TEXT; +ALTER TABLE reaction ADD COLUMN dc_thread_id TEXT; diff --git a/database/user.go b/database/user.go index 1d7b419..0150f16 100644 --- a/database/user.go +++ b/database/user.go @@ -2,6 +2,8 @@ package database import ( "database/sql" + "errors" + "time" log "maunium.net/go/maulogger/v2" @@ -22,38 +24,32 @@ func (uq *UserQuery) New() *User { } func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE mxid=$1` - row := uq.db.QueryRow(query, userID) - if row == nil { - return nil - } - - return uq.New().Scan(row) + query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1` + return uq.New().Scan(uq.db.QueryRow(query, userID)) } func (uq *UserQuery) GetByID(id string) *User { - query := `SELECT mxid, dcid, management_room, token FROM "user" WHERE dcid=$1` - row := uq.db.QueryRow(query, id) - if row == nil { - return nil - } - - return uq.New().Scan(row) + query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1` + return uq.New().Scan(uq.db.QueryRow(query, id)) } -func (uq *UserQuery) GetAll() []*User { - rows, err := uq.db.Query(`SELECT mxid, dcid, management_room, token FROM "user" WHERE token IS NOT NULL`) +func (uq *UserQuery) GetAllWithToken() []*User { + query := ` + SELECT mxid, dcid, discord_token, management_room, space_room + FROM "user" WHERE discord_token IS NOT NULL + ` + rows, err := uq.db.Query(query) if err != nil || rows == nil { return nil } - defer rows.Close() - - users := []*User{} + var users []*User for rows.Next() { - users = append(users, uq.New().Scan(rows)) + user := uq.New().Scan(rows) + if user != nil { + users = append(users, user) + } } - return users } @@ -61,80 +57,114 @@ type User struct { db *Database log log.Logger - MXID id.UserID - ID string - + MXID id.UserID + DiscordID string + DiscordToken string ManagementRoom id.RoomID + SpaceRoom id.RoomID +} - Token string +type UserGuild struct { + GuildID string + Timestamp time.Time + InSpace bool +} + +func (u *User) GetGuilds() []UserGuild { + res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID) + if err != nil { + u.log.Errorln("Failed to get guilds:", err) + panic(err) + return nil + } + var guilds []UserGuild + for res.Next() { + var guild UserGuild + var ts int64 + err = res.Scan(&guild.GuildID, &ts, &guild.InSpace) + if err != nil { + u.log.Errorln("Error scanning user guild:", err) + panic(err) + } else { + guild.Timestamp = time.UnixMilli(ts) + guilds = append(guilds, guild) + } + } + return guilds +} + +func (u *User) IsInSpace(guildID string) (isIn bool) { + query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2` + err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err) + panic(err) + } + return +} + +func (u *User) MarkInGuild(guild UserGuild) { + query := ` + INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space) + VALUES ($1, $2, $3, $4) + ON CONFLICT (guild_id, user_mxid) DO UPDATE + SET timestamp=excluded.timestamp, in_space=excluded.in_space + ` + _, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace) + if err != nil { + u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err) + panic(err) + } +} + +func (u *User) MarkNotInGuild(guildID string) { + query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2` + _, err := u.db.Exec(query, u.MXID, guildID) + if err != nil { + u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err) + panic(err) + } +} + +func (u *User) PruneGuildList(beforeTS time.Time) { + _, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli()) + if err != nil { + u.log.Errorln("Failed to prune user guild list:", err) + panic(err) + } } func (u *User) Scan(row dbutil.Scannable) *User { - var token sql.NullString - var discordID sql.NullString - - err := row.Scan(&u.MXID, &discordID, &u.ManagementRoom, &token) + var discordID, managementRoom, spaceRoom, discordToken sql.NullString + err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom) if err != nil { if err != sql.ErrNoRows { u.log.Errorln("Database scan failed:", err) + panic(err) } - return nil } - - if token.Valid { - u.Token = token.String - } - - if discordID.Valid { - u.ID = discordID.String - } - + u.DiscordID = discordID.String + u.DiscordToken = discordToken.String + u.ManagementRoom = id.RoomID(managementRoom.String) + u.SpaceRoom = id.RoomID(spaceRoom.String) return u } func (u *User) Insert() { - query := "INSERT INTO \"user\" (mxid, dcid, management_room, token) VALUES ($1, $2, $3, $4)" - - var token sql.NullString - var discordID sql.NullString - - if u.Token != "" { - token.String = u.Token - token.Valid = true - } - - if u.ID != "" { - discordID.String = u.ID - discordID.Valid = true - } - - _, err := u.db.Exec(query, u.MXID, discordID, u.ManagementRoom, token) - + query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)` + _, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom))) if err != nil { u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) + panic(err) } } func (u *User) Update() { - query := "UPDATE \"user\" SET dcid=$1, management_room=$2, token=$3 WHERE mxid=$4" - - var token sql.NullString - var discordID sql.NullString - - if u.Token != "" { - token.String = u.Token - token.Valid = true - } - - if u.ID != "" { - discordID.String = u.ID - discordID.Valid = true - } - - _, err := u.db.Exec(query, discordID, u.ManagementRoom, token, u.MXID) - + query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5` + _, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID) if err != nil { u.log.Warnfln("Failed to update %q: %v", u.MXID, err) + panic(err) } } diff --git a/go.mod b/go.mod index c6ae083..9b1bdf5 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module go.mau.fi/mautrix-discord go 1.17 require ( - github.com/bwmarrin/discordgo v0.23.2 + github.com/bwmarrin/discordgo v0.25.0 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.6 github.com/mattn/go-sqlite3 v1.14.13 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e maunium.net/go/maulogger/v2 v2.3.2 - maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365 + maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf ) require ( @@ -27,4 +27,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace github.com/bwmarrin/discordgo v0.23.2 => gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685 +replace github.com/bwmarrin/discordgo => gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7 diff --git a/go.sum b/go.sum index dbdc21e..09d1e53 100644 --- a/go.sum +++ b/go.sum @@ -3,7 +3,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.6 h1:jbk+ZieJ0D7EVGJYpL9QTz7/YW6UHbmdnZWYyK5cdBs= @@ -28,11 +28,12 @@ github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= github.com/yuin/goldmark v1.4.12 h1:6hffw6vALvEDqJ19dOJvJKOoAOKe4NDaTqvd2sktGN0= github.com/yuin/goldmark v1.4.12/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685 h1:WLswEvEb4wi02p+VjMl8slhG/4+3pdfx04xQfqQjpoc= -gitlab.com/beeper/discordgo v0.23.3-0.20220527120019-b3fefa61d685/go.mod h1:Hwfv4M8yP/MDh47BN+4Z1WItJ1umLKUyplCH5KcQPgE= -golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7 h1:S8hbrkgKGU4aU5kXW4d8CA/9ayi8ymI3QU6yg/aWfUw= +gitlab.com/beeper/discordgo v0.23.3-0.20220528185832-6fcb85e150f7/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c= golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0= golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -44,6 +45,7 @@ golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 h1:xHms4gcpe1YE7A3yIllJXP16C golang.org/x/sys v0.0.0-20220422013727-9388b58f7150/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -56,5 +58,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= -maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365 h1:ghL/KMHiMdcwKmuxvVYZvwe6Yk7BHeWlRfoHZovDA2Q= -maunium.net/go/mautrix v0.11.1-0.20220527114846-a9044b80e365/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I= +maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf h1:hNCuEyp9oZPM08n7UsI/gq95ptbhOV1A8heFoH/eiM0= +maunium.net/go/mautrix v0.11.1-0.20220528195646-45d989575cbf/go.mod h1:CiKpMhAx5QZFHK03jpWb0iKI3sGU8x6+LfsOjDrcO8I= diff --git a/guildportal.go b/guildportal.go new file mode 100644 index 0000000..dbe030e --- /dev/null +++ b/guildportal.go @@ -0,0 +1,287 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2022 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "fmt" + "sync" + + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "github.com/bwmarrin/discordgo" + + "go.mau.fi/mautrix-discord/database" +) + +type Guild struct { + *database.Guild + + bridge *DiscordBridge + log log.Logger + + roomCreateLock sync.Mutex +} + +func (br *DiscordBridge) loadGuild(dbGuild *database.Guild, id string, createIfNotExist bool) *Guild { + if dbGuild == nil { + if id == "" || !createIfNotExist { + return nil + } + + dbGuild = br.DB.Guild.New() + dbGuild.ID = id + dbGuild.Insert() + } + + guild := br.NewGuild(dbGuild) + + br.guildsByID[guild.ID] = guild + if guild.MXID != "" { + br.guildsByMXID[guild.MXID] = guild + } + + return guild +} + +func (br *DiscordBridge) GetGuildByMXID(mxid id.RoomID) *Guild { + br.guildsLock.Lock() + defer br.guildsLock.Unlock() + + portal, ok := br.guildsByMXID[mxid] + if !ok { + return br.loadGuild(br.DB.Guild.GetByMXID(mxid), "", false) + } + + return portal +} + +func (br *DiscordBridge) GetGuildByID(id string, createIfNotExist bool) *Guild { + br.guildsLock.Lock() + defer br.guildsLock.Unlock() + + guild, ok := br.guildsByID[id] + if !ok { + return br.loadGuild(br.DB.Guild.GetByID(id), id, createIfNotExist) + } + + return guild +} + +func (br *DiscordBridge) GetAllGuilds() []*Guild { + return br.dbGuildsToGuilds(br.DB.Guild.GetAll()) +} + +func (br *DiscordBridge) dbGuildsToGuilds(dbGuilds []*database.Guild) []*Guild { + br.guildsLock.Lock() + defer br.guildsLock.Unlock() + + output := make([]*Guild, len(dbGuilds)) + for index, dbGuild := range dbGuilds { + if dbGuild == nil { + continue + } + + guild, ok := br.guildsByID[dbGuild.ID] + if !ok { + guild = br.loadGuild(dbGuild, "", false) + } + + output[index] = guild + } + + return output +} + +func (br *DiscordBridge) NewGuild(dbGuild *database.Guild) *Guild { + guild := &Guild{ + Guild: dbGuild, + bridge: br, + log: br.Log.Sub(fmt.Sprintf("Guild/%s", dbGuild.ID)), + } + + return guild +} + +func (guild *Guild) getBridgeInfo() (string, event.BridgeEventContent) { + bridgeInfo := event.BridgeEventContent{ + BridgeBot: guild.bridge.Bot.UserID, + Creator: guild.bridge.Bot.UserID, + Protocol: event.BridgeInfoSection{ + ID: "discord", + DisplayName: "Discord", + AvatarURL: guild.bridge.Config.AppService.Bot.ParsedAvatar.CUString(), + ExternalURL: "https://discord.com/", + }, + Channel: event.BridgeInfoSection{ + ID: guild.ID, + DisplayName: guild.Name, + AvatarURL: guild.AvatarURL.CUString(), + }, + } + bridgeInfoStateKey := fmt.Sprintf("fi.mau.discord://discord/%s", guild.ID) + return bridgeInfoStateKey, bridgeInfo +} + +func (guild *Guild) UpdateBridgeInfo() { + if len(guild.MXID) == 0 { + guild.log.Debugln("Not updating bridge info: no Matrix room created") + return + } + guild.log.Debugln("Updating bridge info...") + stateKey, content := guild.getBridgeInfo() + _, err := guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateBridge, stateKey, content) + if err != nil { + guild.log.Warnln("Failed to update m.bridge:", err) + } + // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + _, err = guild.bridge.Bot.SendStateEvent(guild.MXID, event.StateHalfShotBridge, stateKey, content) + if err != nil { + guild.log.Warnln("Failed to update uk.half-shot.bridge:", err) + } +} + +func (guild *Guild) CreateMatrixRoom(user *User, meta *discordgo.Guild) error { + guild.roomCreateLock.Lock() + defer guild.roomCreateLock.Unlock() + if guild.MXID != "" { + return nil + } + guild.log.Infoln("Creating Matrix room for guild") + guild.UpdateInfo(user, meta) + + bridgeInfoStateKey, bridgeInfo := guild.getBridgeInfo() + + initialState := []*event.Event{{ + Type: event.StateBridge, + Content: event.Content{Parsed: bridgeInfo}, + StateKey: &bridgeInfoStateKey, + }, { + // TODO remove this once https://github.com/matrix-org/matrix-doc/pull/2346 is in spec + Type: event.StateHalfShotBridge, + Content: event.Content{Parsed: bridgeInfo}, + StateKey: &bridgeInfoStateKey, + }} + + if !guild.AvatarURL.IsEmpty() { + initialState = append(initialState, &event.Event{ + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{ + URL: guild.AvatarURL, + }}, + }) + } + + creationContent := map[string]interface{}{ + "type": event.RoomTypeSpace, + } + if !guild.bridge.Config.Bridge.FederateRooms { + creationContent["m.federate"] = false + } + + resp, err := guild.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + Visibility: "private", + Name: guild.Name, + Preset: "private_chat", + InitialState: initialState, + CreationContent: creationContent, + }) + if err != nil { + guild.log.Warnln("Failed to create room:", err) + return err + } + + guild.MXID = resp.RoomID + guild.NameSet = true + guild.AvatarSet = !guild.AvatarURL.IsEmpty() + guild.Update() + guild.bridge.guildsLock.Lock() + guild.bridge.guildsByMXID[guild.MXID] = guild + guild.bridge.guildsLock.Unlock() + guild.log.Infoln("Matrix room created:", guild.MXID) + + user.ensureInvited(nil, guild.MXID, false) + + return nil +} + +func (guild *Guild) UpdateInfo(source *User, meta *discordgo.Guild) *discordgo.Guild { + if meta.Unavailable { + return meta + } + changed := false + // FIXME + //name, err := guild.bridge.Config.Bridge.FormatChannelname(meta, user.Session) + //if err != nil { + // guild.log.Warnfln("failed to format name, proceeding with generic name: %v", err) + // guild.Name = meta.Name + //} else { + //} + changed = guild.UpdateName(meta.Name) || changed + changed = guild.UpdateAvatar(meta.Icon) || changed + if changed { + guild.UpdateBridgeInfo() + guild.Update() + } + return meta +} + +func (guild *Guild) UpdateName(name string) bool { + if guild.Name == name && guild.NameSet { + return false + } + guild.Name = name + guild.NameSet = false + if guild.MXID != "" { + _, err := guild.bridge.Bot.SetRoomName(guild.MXID, guild.Name) + if err != nil { + guild.log.Warnln("Failed to update room name: %s", err) + } else { + guild.NameSet = true + } + } + return true +} + +func (guild *Guild) UpdateAvatar(iconID string) bool { + if guild.Avatar == iconID && guild.AvatarSet { + return false + } + guild.AvatarSet = false + guild.Avatar = iconID + if guild.Avatar != "" { + var err error + guild.AvatarURL, err = uploadAvatar(guild.bridge.Bot, discordgo.EndpointGuildIcon(guild.ID, iconID)) + if err != nil { + guild.log.Warnln("Failed to reupload avatar %s: %v", iconID, err) + return true + } + } else { + guild.AvatarURL = id.ContentURI{} + } + if guild.MXID != "" { + _, err := guild.bridge.Bot.SetRoomAvatar(guild.MXID, guild.AvatarURL) + if err != nil { + guild.log.Warnln("Failed to update room avatar:", err) + } else { + guild.AvatarSet = true + } + } + return true +} diff --git a/main.go b/main.go index f14cb7e..d527a12 100644 --- a/main.go +++ b/main.go @@ -20,13 +20,13 @@ import ( _ "embed" "sync" - "go.mau.fi/mautrix-discord/database" "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/util/configupgrade" "go.mau.fi/mautrix-discord/config" + "go.mau.fi/mautrix-discord/database" ) // Information to find out exactly which commit the bridge was built from. @@ -59,6 +59,14 @@ type DiscordBridge struct { portalsByID map[database.PortalKey]*Portal portalsLock sync.Mutex + threadsByID map[string]*Thread + threadsByRootMXID map[id.EventID]*Thread + threadsLock sync.Mutex + + guildsByMXID map[id.RoomID]*Guild + guildsByID map[string]*Guild + guildsLock sync.Mutex + puppets map[string]*Puppet puppetsByCustomMXID map[id.UserID]*Puppet puppetsLock sync.Mutex @@ -81,6 +89,7 @@ func (br *DiscordBridge) Init() { br.RegisterCommands() br.DB = database.New(br.Bridge.DB) + discordLog = br.Log.Sub("Discord") } func (br *DiscordBridge) Start() { @@ -144,6 +153,12 @@ func main() { portalsByMXID: make(map[id.RoomID]*Portal), portalsByID: make(map[database.PortalKey]*Portal), + threadsByID: make(map[string]*Thread), + threadsByRootMXID: make(map[id.EventID]*Thread), + + guildsByID: make(map[string]*Guild), + guildsByMXID: make(map[id.RoomID]*Guild), + puppets: make(map[string]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), } diff --git a/portal.go b/portal.go index 9dda87f..5b68bf8 100644 --- a/portal.go +++ b/portal.go @@ -3,13 +3,15 @@ package main import ( "bytes" "fmt" + "strconv" "strings" "sync" "time" - "github.com/bwmarrin/discordgo" "maunium.net/go/mautrix/util/variationselector" + "github.com/bwmarrin/discordgo" + log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" @@ -25,6 +27,8 @@ import ( type portalDiscordMessage struct { msg interface{} user *User + + thread *Thread } type portalMatrixMessage struct { @@ -35,6 +39,9 @@ type portalMatrixMessage struct { type Portal struct { *database.Portal + Parent *Portal + Guild *Guild + bridge *DiscordBridge log log.Logger @@ -66,28 +73,37 @@ var ( portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} ) -func (br *DiscordBridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { - // If we weren't given a portal we'll attempt to create it if a key was - // provided. +func (br *DiscordBridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey, chanType discordgo.ChannelType) *Portal { if dbPortal == nil { - if key == nil { + if key == nil || chanType < 0 { return nil } dbPortal = br.DB.Portal.New() dbPortal.Key = *key + dbPortal.Type = chanType dbPortal.Insert() } portal := br.NewPortal(dbPortal) - // No need to lock, it is assumed that our callers have already acquired - // the lock. br.portalsByID[portal.Key] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal } + if portal.GuildID != "" { + portal.Guild = portal.bridge.GetGuildByID(portal.GuildID, true) + } + if portal.ParentID != "" { + parentKey := database.NewPortalKey(portal.ParentID, "") + var ok bool + portal.Parent, ok = br.portalsByID[parentKey] + if !ok { + portal.Parent = br.loadPortal(br.DB.Portal.GetByID(parentKey), nil, -1) + } + } + return portal } @@ -97,19 +113,48 @@ func (br *DiscordBridge) GetPortalByMXID(mxid id.RoomID) *Portal { portal, ok := br.portalsByMXID[mxid] if !ok { - return br.loadPortal(br.DB.Portal.GetByMXID(mxid), nil) + return br.loadPortal(br.DB.Portal.GetByMXID(mxid), nil, -1) } return portal } -func (br *DiscordBridge) GetPortalByID(key database.PortalKey) *Portal { +func (user *User) GetPortalByMeta(meta *discordgo.Channel) *Portal { + return user.GetPortalByID(meta.ID, meta.Type) +} + +func (user *User) GetExistingPortalByID(id string) *Portal { + return user.bridge.GetExistingPortalByID(database.NewPortalKey(id, user.DiscordID)) +} + +func (user *User) GetPortalByID(id string, chanType discordgo.ChannelType) *Portal { + return user.bridge.GetPortalByID(database.NewPortalKey(id, user.DiscordID), chanType) +} + +func (br *DiscordBridge) GetExistingPortalByID(key database.PortalKey) *Portal { br.portalsLock.Lock() defer br.portalsLock.Unlock() + portal, ok := br.portalsByID[key] + if !ok { + portal, ok = br.portalsByID[database.NewPortalKey(key.ChannelID, "")] + if !ok { + return br.loadPortal(br.DB.Portal.GetByID(key), nil, -1) + } + } + + return portal +} + +func (br *DiscordBridge) GetPortalByID(key database.PortalKey, chanType discordgo.ChannelType) *Portal { + br.portalsLock.Lock() + defer br.portalsLock.Unlock() + if chanType != discordgo.ChannelTypeDM { + key.Receiver = "" + } portal, ok := br.portalsByID[key] if !ok { - return br.loadPortal(br.DB.Portal.GetByID(key), &key) + return br.loadPortal(br.DB.Portal.GetByID(key), &key, chanType) } return portal @@ -135,7 +180,7 @@ func (br *DiscordBridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Por portal, ok := br.portalsByID[dbPortal.Key] if !ok { - portal = br.loadPortal(dbPortal, nil) + portal = br.loadPortal(dbPortal, nil, -1) } output[index] = portal @@ -192,14 +237,25 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { AvatarURL: portal.bridge.Config.AppService.Bot.ParsedAvatar.CUString(), ExternalURL: "https://discord.com/", }, - // TODO use guild as network Channel: event.BridgeInfoSection{ ID: portal.Key.ChannelID, DisplayName: portal.Name, - AvatarURL: portal.AvatarURL.CUString(), }, } - bridgeInfoStateKey := fmt.Sprintf("fi.mau.discord://discord/%s", portal.Key.ChannelID) + var bridgeInfoStateKey string + if portal.GuildID == "" { + bridgeInfoStateKey = fmt.Sprintf("fi.mau.discord://discord/dm/%s", portal.Key.ChannelID) + } else { + bridgeInfo.Network = &event.BridgeInfoSection{ + ID: portal.GuildID, + } + if portal.Guild != nil { + bridgeInfo.Network.DisplayName = portal.Guild.Name + bridgeInfo.Network.AvatarURL = portal.Guild.AvatarURL.CUString() + // TODO is it possible to find the URL? + } + bridgeInfoStateKey = fmt.Sprintf("fi.mau.discord://discord/%s/%s", portal.GuildID, portal.Key.ChannelID) + } return bridgeInfoStateKey, bridgeInfo } @@ -221,40 +277,22 @@ func (portal *Portal) UpdateBridgeInfo() { } } -func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error { +func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { return nil } + portal.log.Infoln("Creating Matrix room for channel") - portal.Type = channel.Type - if portal.Type == discordgo.ChannelTypeDM { - portal.OtherUserID = channel.Recipients[0].ID - } + channel = portal.UpdateInfo(user, channel) intent := portal.MainIntent() if err := intent.EnsureRegistered(); err != nil { return err } - name, err := portal.bridge.Config.Bridge.FormatChannelname(channel, user.Session) - if err != nil { - portal.log.Warnfln("failed to format name, proceeding with generic name: %v", err) - portal.Name = channel.Name - } else { - portal.Name = name - } - - portal.Topic = channel.Topic - - // TODO: get avatars figured out - // portal.Avatar = puppet.Avatar - // portal.AvatarURL = puppet.AvatarURL - - portal.log.Infoln("Creating Matrix room for channel:", portal.Portal.Key.ChannelID) bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - initialState := []*event.Event{{ Type: event.StateBridge, Content: event.Content{Parsed: bridgeInfo}, @@ -266,10 +304,48 @@ func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) e StateKey: &bridgeInfoStateKey, }} + if !portal.AvatarURL.IsEmpty() { + initialState = append(initialState, &event.Event{ + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{ + URL: portal.AvatarURL, + }}, + }) + } + creationContent := make(map[string]interface{}) + if portal.Type == discordgo.ChannelTypeGuildCategory { + creationContent["type"] = event.RoomTypeSpace + } if !portal.bridge.Config.Bridge.FederateRooms { creationContent["m.federate"] = false } + spaceID := portal.ExpectedSpaceID() + if spaceID != "" { + spaceIDStr := spaceID.String() + initialState = append(initialState, &event.Event{ + Type: event.StateSpaceParent, + StateKey: &spaceIDStr, + Content: event.Content{Parsed: &event.SpaceParentEventContent{ + Via: []string{portal.bridge.AS.HomeserverDomain}, + Canonical: true, + }}, + }) + } + if portal.Guild != nil && portal.Guild.MXID != "" { + // TODO don't do this for private channels in guilds + initialState = append(initialState, &event.Event{ + Type: event.StateJoinRules, + Content: event.Content{Parsed: &event.JoinRulesEventContent{ + JoinRule: event.JoinRuleRestricted, + Allow: []event.JoinRuleAllow{{ + RoomID: spaceID, + Type: event.JoinRuleAllowRoomMembership, + }}, + }}, + }) + } + // TODO set restricted join rule based on guild var invite []id.UserID @@ -302,13 +378,17 @@ func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) e return err } + portal.NameSet = true + portal.TopicSet = true + portal.AvatarSet = !portal.AvatarURL.IsEmpty() portal.MXID = resp.RoomID - portal.Update() portal.bridge.portalsLock.Lock() portal.bridge.portalsByMXID[portal.MXID] = portal portal.bridge.portalsLock.Unlock() + portal.Update() portal.log.Infoln("Matrix room created:", portal.MXID) + portal.updateSpace() portal.ensureUserInvited(user) user.syncChatDoublePuppetDetails(portal, true) @@ -334,39 +414,30 @@ func (portal *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) e func (portal *Portal) handleDiscordMessages(msg portalDiscordMessage) { if portal.MXID == "" { - discordMsg, ok := msg.msg.(*discordgo.MessageCreate) + _, ok := msg.msg.(*discordgo.MessageCreate) if !ok { portal.log.Warnln("Can't create Matrix room from non new message event") return } portal.log.Debugln("Creating Matrix room from incoming message") - - channel, err := msg.user.Session.Channel(discordMsg.ChannelID) - if err != nil { - portal.log.Errorln("Failed to find channel for message:", err) - - return - } - - if err := portal.createMatrixRoom(msg.user, channel); err != nil { + if err := portal.CreateMatrixRoom(msg.user, nil); err != nil { portal.log.Errorln("Failed to create portal room:", err) - return } } - switch msg.msg.(type) { + switch convertedMsg := msg.msg.(type) { case *discordgo.MessageCreate: - portal.handleDiscordMessageCreate(msg.user, msg.msg.(*discordgo.MessageCreate).Message) + portal.handleDiscordMessageCreate(msg.user, convertedMsg.Message, msg.thread) case *discordgo.MessageUpdate: - portal.handleDiscordMessagesUpdate(msg.user, msg.msg.(*discordgo.MessageUpdate).Message) + portal.handleDiscordMessageUpdate(msg.user, convertedMsg.Message) case *discordgo.MessageDelete: - portal.handleDiscordMessageDelete(msg.user, msg.msg.(*discordgo.MessageDelete).Message) + portal.handleDiscordMessageDelete(msg.user, convertedMsg.Message) case *discordgo.MessageReactionAdd: - portal.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionAdd).MessageReaction, true) + portal.handleDiscordReaction(msg.user, convertedMsg.MessageReaction, true, msg.thread) case *discordgo.MessageReactionRemove: - portal.handleDiscordReaction(msg.user, msg.msg.(*discordgo.MessageReactionRemove).MessageReaction, false) + portal.handleDiscordReaction(msg.user, convertedMsg.MessageReaction, false, msg.thread) default: portal.log.Warnln("unknown message type") } @@ -376,13 +447,14 @@ func (portal *Portal) ensureUserInvited(user *User) bool { return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat()) } -func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time) *database.Message { +func (portal *Portal) markMessageHandled(discordID string, mxid id.EventID, authorID string, timestamp time.Time, threadID string) *database.Message { msg := portal.bridge.DB.Message.New() msg.Channel = portal.Key msg.DiscordID = discordID msg.MXID = mxid msg.SenderID = authorID msg.Timestamp = timestamp + msg.ThreadID = threadID msg.Insert() return msg } @@ -399,7 +471,7 @@ func (portal *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridg } } -func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment) { +func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID string, attachment *discordgo.MessageAttachment, ts time.Time, threadRelation *event.RelatesTo, threadID string) { // var captionContent *event.MessageEventContent // if attachment.Description != "" { @@ -420,6 +492,7 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI // This gets overwritten later after the file is uploaded to the homeserver Size: attachment.Size, }, + RelatesTo: threadRelation, } switch strings.ToLower(strings.Split(attachment.ContentType, "/")[0]) { @@ -447,20 +520,25 @@ func (portal *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgI return } - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) + resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, ts.UnixMilli()) if err != nil { portal.log.Warnfln("failed to send media message to matrix: %v", err) } dbAttachment := portal.bridge.DB.Attachment.New() dbAttachment.Channel = portal.Key - dbAttachment.DiscordMessageID = msgID - dbAttachment.DiscordAttachmentID = attachment.ID + dbAttachment.MessageID = msgID + dbAttachment.ID = attachment.ID dbAttachment.MXID = resp.EventID + dbAttachment.ThreadID = threadID dbAttachment.Insert() + // Update the fallback reply event for the next attachment + if threadRelation != nil { + threadRelation.InReplyTo.EventID = resp.EventID + } } -func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message) { +func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message, thread *Thread) { if portal.MXID == "" { portal.log.Warnln("handle message called without a valid portal") @@ -469,22 +547,22 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess // Handle room name changes if msg.Type == discordgo.MessageTypeChannelNameChange { - channel, err := user.Session.Channel(msg.ChannelID) - if err != nil { - portal.log.Errorf("Failed to find the channel for portal %s", portal.Key) - return - } - - name, err := portal.bridge.Config.Bridge.FormatChannelname(channel, user.Session) - if err != nil { - portal.log.Errorf("Failed to format name for portal %s", portal.Key) - return - } - - portal.Name = name - portal.Update() - - portal.MainIntent().SetRoomName(portal.MXID, name) + //channel, err := user.Session.Channel(msg.ChannelID) + //if err != nil { + // portal.log.Errorf("Failed to find the channel for portal %s", portal.Key) + // return + //} + // + //name, err := portal.bridge.Config.Bridge.FormatChannelname(channel, user.Session) + //if err != nil { + // portal.log.Errorf("Failed to format name for portal %s", portal.Key) + // return + //} + // + //portal.Name = name + //portal.Update() + // + //portal.MainIntent().SetRoomName(portal.MXID, name) return } @@ -492,57 +570,87 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess // Handle normal message existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) if existing != nil { - portal.log.Debugln("not handling duplicate message", msg.ID) - + portal.log.Debugln("Dropping duplicate message", msg.ID) return } + portal.log.Debugfln("Starting handling of %s by %s", msg.ID, msg.Author.ID) puppet := portal.bridge.GetPuppetByID(msg.Author.ID) - puppet.SyncContact(user) + puppet.UpdateInfo(user, msg.Author) intent := puppet.IntentFor(portal) + var threadRelation *event.RelatesTo + var threadID string + if thread != nil { + threadID = thread.ID + lastEventID := thread.RootMXID + lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) + if lastInThread != nil { + lastEventID = lastInThread.MXID + } + threadRelation = (&event.RelatesTo{}).SetThread(thread.RootMXID, lastEventID) + } + + ts, _ := discordgo.SnowflakeTimestamp(msg.ID) if msg.Content != "" { content := &event.MessageEventContent{ - Body: msg.Content, - MsgType: event.MsgText, + Body: msg.Content, + MsgType: event.MsgText, + RelatesTo: threadRelation.Copy(), } - if msg.MessageReference != nil && msg.MessageReference.ChannelID == portal.Key.ChannelID { + if msg.MessageReference != nil { //key := database.PortalKey{msg.MessageReference.ChannelID, user.ID} replyTo := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.MessageReference.MessageID) - if replyTo != nil { - content.RelatesTo = &event.RelatesTo{ - Type: event.RelReply, - EventID: existing.MXID, + if content.RelatesTo == nil { + content.RelatesTo = &event.RelatesTo{} } + content.RelatesTo.SetReplyTo(replyTo.MXID) } } - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) + resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, ts.UnixMilli()) if err != nil { portal.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) - return } - ts, _ := msg.Timestamp.Parse() - portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts) + portal.markMessageHandled(msg.ID, resp.EventID, msg.Author.ID, ts, threadID) + // Update the fallback reply event for attachments + if threadRelation != nil { + threadRelation.InReplyTo.EventID = resp.EventID + } } - // now run through any attachments the message has for _, attachment := range msg.Attachments { - portal.handleDiscordAttachment(intent, msg.ID, attachment) + portal.handleDiscordAttachment(intent, msg.ID, attachment, ts, threadRelation, threadID) } } -func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Message) { +func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Message) { if portal.MXID == "" { portal.log.Warnln("handle message called without a valid portal") return } + existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) + if existing == nil { + portal.log.Warnfln("Dropping update of unknown message %s", msg.ID) + return + } + + if msg.Flags == discordgo.MessageFlagsHasThread { + portal.bridge.GetThreadByID(msg.ID, existing) + portal.log.Debugfln("Marked %s as a thread root", msg.ID) + // TODO make autojoining optional + err := user.Session.ThreadJoinWithLocation(msg.ID, discordgo.ThreadJoinLocationContextMenu) + if err != nil { + user.log.Warnfln("Error autojoining thread %s@%s: %v", msg.ChannelID, portal.Key.ChannelID, err) + } + } + // There's a few scenarios where the author is nil but I haven't figured // them all out yet. if msg.Author == nil { @@ -558,11 +666,11 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes } portal.log.Errorfln("author is nil: %#v", msg) + return } intent := portal.bridge.GetPuppetByID(msg.Author.ID).IntentFor(portal) - existing := portal.bridge.DB.Message.GetByDiscordID(portal.Key, msg.ID) if existing == nil { // Due to the differences in Discord and Matrix attachment handling, // existing will return nil if the original message was empty as we @@ -576,7 +684,7 @@ func (portal *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Mes attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) for _, attachment := range attachments { - attachmentMap[attachment.DiscordAttachmentID] = attachment + attachmentMap[attachment.ID] = attachment } // Now run through the list of attachments on this message and remove @@ -632,30 +740,27 @@ func (portal *Portal) handleDiscordMessageDelete(user *User, msg *discordgo.Mess intent := portal.MainIntent() if existing != nil { + attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) + for _, attachment := range attachments { + _, err := intent.RedactEvent(portal.MXID, attachment.MXID) + if err != nil { + portal.log.Warnfln("Failed to redact attachment %s: %v", attachment.MXID, err) + } + attachment.Delete() + } + _, err := intent.RedactEvent(portal.MXID, existing.MXID) if err != nil { - portal.log.Warnfln("Failed to remove message %s: %v", existing.MXID, err) + portal.log.Warnfln("Failed to redact message %s: %v", existing.MXID, err) } - existing.Delete() } - - // Now delete all of the existing attachments. - attachments := portal.bridge.DB.Attachment.GetAllByDiscordMessageID(portal.Key, msg.ID) - for _, attachment := range attachments { - _, err := intent.RedactEvent(portal.MXID, attachment.MXID) - if err != nil { - portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err) - } - - attachment.Delete() - } } func (portal *Portal) syncParticipants(source *User, participants []*discordgo.User) { for _, participant := range participants { puppet := portal.bridge.GetPuppetByID(participant.ID) - puppet.SyncContact(source) + puppet.UpdateInfo(source, participant) user := portal.bridge.GetUserByID(participant.ID) if user != nil { @@ -733,66 +838,63 @@ func (portal *Portal) handleMatrixMessages(msg portalMatrixMessage) { } } +const discordEpoch = 1420070400000 + +func generateNonce() string { + snowflake := (time.Now().UnixMilli() - discordEpoch) << 22 + // Nonce snowflakes don't have internal IDs or increments + return strconv.FormatInt(snowflake, 10) +} + func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { - if portal.IsPrivateChat() && sender.ID != portal.Key.Receiver { + if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { return } content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { portal.log.Debugfln("Failed to handle event %s: unexpected parsed content type %T", evt.ID, evt.Content.Parsed) - return } - if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReplace { - edits := portal.bridge.DB.Message.GetByMXID(portal.Key, content.RelatesTo.EventID) + channelID := portal.Key.ChannelID + var threadID string + if editMXID := content.GetRelatesTo().GetReplaceID(); editMXID != "" && content.NewContent != nil { + edits := portal.bridge.DB.Message.GetByMXID(portal.Key, editMXID) if edits != nil { // we don't have anything to save for the update message right now // as we're not tracking edited timestamps. - _, err := sender.Session.ChannelMessageEdit(portal.Key.ChannelID, - edits.DiscordID, content.NewContent.Body) + _, err := sender.Session.ChannelMessageEdit(edits.DiscordProtoChannelID(), edits.DiscordID, content.NewContent.Body) if err != nil { portal.log.Errorln("Failed to update message %s: %v", edits.DiscordID, err) - - return } } - return + } else if threadRoot := content.GetRelatesTo().GetThreadParent(); threadRoot != "" { + existingThread := portal.bridge.DB.Thread.GetByMatrixRootMsg(threadRoot) + if existingThread != nil { + channelID = existingThread.ID + threadID = existingThread.ID + } else { + // TODO create new thread + } } - var msg *discordgo.Message - var err error + var sendReq discordgo.MessageSend switch content.MsgType { case event.MsgText, event.MsgEmote, event.MsgNotice: - sent := false - - if content.RelatesTo != nil && content.RelatesTo.Type == event.RelReply { - replyTo := portal.bridge.DB.Message.GetByMXID( - portal.Key, - content.RelatesTo.EventID, - ) - - if replyTo != nil { - msg, err = sender.Session.ChannelMessageSendReply( - portal.Key.ChannelID, - content.Body, - &discordgo.MessageReference{ - ChannelID: portal.Key.ChannelID, - MessageID: replyTo.DiscordID, - }, - ) - if err == nil { - sent = true + if replyToMXID := content.GetReplyTo(); replyToMXID != "" { + replyTo := portal.bridge.DB.Message.GetByMXID(portal.Key, replyToMXID) + if replyTo != nil && replyTo.ThreadID == threadID { + sendReq.Reference = &discordgo.MessageReference{ + ChannelID: channelID, + MessageID: replyTo.DiscordID, } } } - if !sent { - msg, err = sender.Session.ChannelMessageSend(portal.Key.ChannelID, content.Body) - } + sendReq.Content = content.Body case event.MsgAudio, event.MsgFile, event.MsgImage, event.MsgVideo: data, err := portal.downloadMatrixAttachment(evt.ID, content) if err != nil { @@ -801,23 +903,19 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { return } - msgSend := &discordgo.MessageSend{ - Files: []*discordgo.File{{ - Name: content.Body, - ContentType: content.Info.MimeType, - Reader: bytes.NewReader(data), - }}, - } - - msg, err = sender.Session.ChannelMessageSendComplex(portal.Key.ChannelID, msgSend) + sendReq.Files = []*discordgo.File{{ + Name: content.Body, + ContentType: content.Info.MimeType, + Reader: bytes.NewReader(data), + }} default: - portal.log.Warnln("unknown message type:", content.MsgType) + portal.log.Warnln("Unknown message type", content.MsgType) return } - + sendReq.Nonce = generateNonce() + msg, err := sender.Session.ChannelMessageSendComplex(channelID, &sendReq) if err != nil { portal.log.Errorfln("Failed to send message: %v", err) - return } @@ -826,16 +924,16 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { dbMsg.Channel = portal.Key dbMsg.DiscordID = msg.ID dbMsg.MXID = evt.ID - dbMsg.SenderID = sender.ID - // TODO use actual timestamp - dbMsg.Timestamp = time.Now() + dbMsg.SenderID = sender.DiscordID + dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID) + dbMsg.ThreadID = threadID dbMsg.Insert() } } func (portal *Portal) HandleMatrixLeave(brSender bridge.User) { portal.log.Debugln("User left private chat portal, cleaning up and deleting...") - portal.delete() + portal.Delete() portal.cleanup(false) // TODO: figure out how to close a dm from the API. @@ -848,11 +946,11 @@ func (portal *Portal) leave(sender *User) { return } - intent := portal.bridge.GetPuppetByID(sender.ID).IntentFor(portal) + intent := portal.bridge.GetPuppetByID(sender.DiscordID).IntentFor(portal) intent.LeaveRoom(portal.MXID) } -func (portal *Portal) delete() { +func (portal *Portal) Delete() { portal.Portal.Delete() portal.bridge.portalsLock.Lock() delete(portal.bridge.portalsByID, portal.Key) @@ -865,22 +963,25 @@ func (portal *Portal) delete() { } func (portal *Portal) cleanupIfEmpty() { + if portal.MXID == "" { + return + } + users, err := portal.getMatrixUsers() if err != nil { portal.log.Errorfln("Failed to get Matrix user list to determine if portal needs to be cleaned up: %v", err) - return } if len(users) == 0 { portal.log.Infoln("Room seems to be empty, cleaning up...") - portal.delete() + portal.Delete() portal.cleanup(false) } } func (portal *Portal) cleanup(puppetsOnly bool) { - if portal.MXID != "" { + if portal.MXID == "" { return } @@ -889,7 +990,6 @@ func (portal *Portal) cleanup(puppetsOnly bool) { if err != nil { portal.log.Warnln("Failed to leave private chat portal with main intent:", err) } - return } @@ -897,7 +997,6 @@ func (portal *Portal) cleanup(puppetsOnly bool) { members, err := intent.JoinedMembers(portal.MXID) if err != nil { portal.log.Errorln("Failed to get portal members for cleanup:", err) - return } @@ -943,19 +1042,19 @@ func (portal *Portal) getMatrixUsers() ([]id.UserID, error) { return users, nil } -func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { - if user.ID != portal.Key.Receiver { +func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { + if portal.IsPrivateChat() && sender.DiscordID != portal.Key.Receiver { return } reaction := evt.Content.AsReaction() if reaction.RelatesTo.Type != event.RelAnnotation { portal.log.Errorfln("Ignoring reaction %s due to unknown m.relates_to data", evt.ID) - return } - var discordID string + var discordID, threadID string + channelID := portal.Key.ChannelID msg := portal.bridge.DB.Message.GetByMXID(portal.Key, reaction.RelatesTo.EventID) @@ -973,15 +1072,14 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { // table to keep them in sync and to avoid sending duplicates to Discord. if msg == nil { attachment := portal.bridge.DB.Attachment.GetByMatrixID(portal.Key, reaction.RelatesTo.EventID) - discordID = attachment.DiscordMessageID - } else { - if msg.DiscordID == "" { - portal.log.Debugf("Message %s has not yet been sent to discord", reaction.RelatesTo.EventID) - + if attachment == nil { return } - + discordID = attachment.MessageID + threadID = attachment.ThreadID + } else { discordID = msg.DiscordID + threadID = msg.ThreadID } // Figure out if this is a custom emoji or not. @@ -990,8 +1088,7 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { uri, _ := id.ParseContentURI(emojiID) emoji := portal.bridge.DB.Emoji.GetByMatrixURL(uri) if emoji == nil { - portal.log.Errorfln("failed to find emoji for %s", emojiID) - + portal.log.Errorfln("Couldn't find emoji corresponding to %s", emojiID) return } @@ -1000,23 +1097,26 @@ func (portal *Portal) handleMatrixReaction(user *User, evt *event.Event) { emojiID = variationselector.Remove(emojiID) } - err := user.Session.MessageReactionAdd(portal.Key.ChannelID, discordID, emojiID) + if threadID != "" { + channelID = threadID + } + err := sender.Session.MessageReactionAdd(channelID, discordID, emojiID) if err != nil { - portal.log.Debugf("Failed to send reaction %s id:%s: %v", portal.Key, discordID, err) - + portal.log.Debugf("Failed to send reaction to %s: %v", discordID, err) return } dbReaction := portal.bridge.DB.Reaction.New() dbReaction.Channel = portal.Key dbReaction.MessageID = discordID - dbReaction.Sender = user.ID + dbReaction.Sender = sender.DiscordID dbReaction.EmojiName = emojiID + dbReaction.ThreadID = threadID dbReaction.MXID = evt.ID dbReaction.Insert() } -func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageReaction, add bool) { +func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.MessageReaction, add bool, thread *Thread) { intent := portal.bridge.GetPuppetByID(reaction.UserID).IntentFor(portal) var discordID string @@ -1067,7 +1167,6 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess if !add { if existing == nil { portal.log.Debugln("Failed to remove reaction for unknown message", reaction.MessageID) - return } @@ -1106,34 +1205,34 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess dbReaction.Sender = reaction.UserID dbReaction.EmojiName = discordID dbReaction.MXID = resp.EventID + if thread != nil { + dbReaction.ThreadID = thread.ID + } dbReaction.Insert() } } func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) { - if user.ID != portal.Key.Receiver { + if user.DiscordID != portal.Key.Receiver { return } // First look if we're redacting a message message := portal.bridge.DB.Message.GetByMXID(portal.Key, evt.Redacts) if message != nil { - if message.DiscordID != "" { - err := user.Session.ChannelMessageDelete(portal.Key.ChannelID, message.DiscordID) - if err != nil { - portal.log.Debugfln("Failed to delete discord message %s: %v", message.DiscordID, err) - } else { - message.Delete() - } + err := user.Session.ChannelMessageDelete(message.DiscordProtoChannelID(), message.DiscordID) + if err != nil { + portal.log.Debugfln("Failed to delete discord message %s: %v", message.DiscordID, err) + } else { + message.Delete() } - return } // Now check if it's a reaction. reaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts) if reaction != nil && reaction.Channel == portal.Key { - err := user.Session.MessageReactionRemove(portal.Key.ChannelID, reaction.MessageID, reaction.EmojiName, reaction.Sender) + err := user.Session.MessageReactionRemove(reaction.DiscordProtoChannelID(), reaction.MessageID, reaction.EmojiName, reaction.Sender) if err != nil { portal.log.Debugfln("Failed to delete reaction %s from %s: %v", reaction.EmojiName, reaction.MessageID, err) } else { @@ -1146,60 +1245,215 @@ func (portal *Portal) handleMatrixRedaction(user *User, evt *event.Event) { portal.log.Warnfln("Failed to redact %s: no event found", evt.Redacts) } -func (portal *Portal) update(user *User, channel *discordgo.Channel) { - name, err := portal.bridge.Config.Bridge.FormatChannelname(channel, user.Session) - if err != nil { - portal.log.Warnln("Failed to format channel name, using existing:", err) - } else { - portal.Name = name +func (portal *Portal) UpdateName(name string) bool { + if portal.Name == name && portal.NameSet { + return false + } else if !portal.Encrypted && portal.IsPrivateChat() { + // TODO custom config option for always setting private chat portal meta? + return false } - - intent := portal.MainIntent() - - if portal.Name != name { - _, err = intent.SetRoomName(portal.MXID, portal.Name) + portal.Name = name + portal.NameSet = false + if portal.MXID != "" { + _, err := portal.MainIntent().SetRoomName(portal.MXID, portal.Name) if err != nil { portal.log.Warnln("Failed to update room name:", err) + } else { + portal.NameSet = true } } + return true +} - if portal.Topic != channel.Topic { - portal.Topic = channel.Topic - _, err = intent.SetRoomTopic(portal.MXID, portal.Topic) +func (portal *Portal) UpdateAvatarFromPuppet(puppet *Puppet) bool { + if portal.Avatar == puppet.Avatar && portal.AvatarSet { + return false + } + portal.Avatar = puppet.Avatar + portal.AvatarURL = puppet.AvatarURL + portal.AvatarSet = false + portal.updateRoomAvatar() + return true +} + +func (portal *Portal) UpdateGroupDMAvatar(iconID string) bool { + if portal.Avatar == iconID && portal.AvatarSet { + return false + } + portal.Avatar = iconID + portal.AvatarSet = false + if portal.Avatar != "" { + uri, err := uploadAvatar(portal.MainIntent(), discordgo.EndpointGroupIcon(portal.Key.ChannelID, portal.Avatar)) + if err != nil { + portal.log.Warnln("Failed to reupload avatar:", err) + return true + } else { + portal.AvatarURL = uri + } + } else { + portal.AvatarURL = id.ContentURI{} + } + portal.updateRoomAvatar() + return true +} + +func (portal *Portal) updateRoomAvatar() { + if portal.MXID == "" { + return + } + _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) + if err != nil { + portal.log.Warnln("Failed to update room avatar:", err) + } else { + portal.AvatarSet = true + } +} + +func (portal *Portal) UpdateTopic(topic string) bool { + if portal.Topic == topic && portal.TopicSet { + return false + } + portal.Topic = topic + portal.TopicSet = false + if portal.MXID != "" { + _, err := portal.MainIntent().SetRoomTopic(portal.MXID, portal.Topic) if err != nil { portal.log.Warnln("Failed to update room topic:", err) } } + return true +} - if portal.Avatar != channel.Icon { - portal.Avatar = channel.Icon - - var url string - - if portal.Type == discordgo.ChannelTypeDM { - dmUser, err := user.Session.User(portal.OtherUserID) - if err != nil { - portal.log.Warnln("failed to lookup the other user in DM", err) - } else { - url = dmUser.AvatarURL("") - } - } else { - url = discordgo.EndpointGroupIcon(channel.ID, channel.Icon) - } - - portal.AvatarURL = id.ContentURI{} - if url != "" { - uri, err := uploadAvatar(intent, url) - if err != nil { - portal.log.Warnf("failed to upload avatar", err) - } else { - portal.AvatarURL = uri - } - } - - intent.SetRoomAvatar(portal.MXID, portal.AvatarURL) +func (portal *Portal) removeFromSpace() { + if portal.InSpace == "" { + return } - portal.Update() - portal.log.Debugln("portal updated") + _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, portal.InSpace.String(), struct{}{}) + if err != nil { + portal.log.Warnfln("Failed to unset canonical space %s: %v", portal.InSpace, err) + } + _, err = portal.bridge.Bot.SendStateEvent(portal.InSpace, event.StateSpaceChild, portal.MXID.String(), struct{}{}) + if err != nil { + portal.log.Warnfln("Failed to add room to space %s: %v", portal.InSpace, err) + } + portal.InSpace = "" +} + +func (portal *Portal) addToSpace(mxid id.RoomID) bool { + if portal.InSpace == mxid { + return false + } + portal.removeFromSpace() + + _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateSpaceParent, mxid.String(), &event.SpaceParentEventContent{ + Via: []string{portal.bridge.AS.HomeserverDomain}, + Canonical: true, + }) + if err != nil { + portal.log.Warnfln("Failed to set canonical space %s: %v", mxid, err) + } + + _, err = portal.bridge.Bot.SendStateEvent(mxid, event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ + Via: []string{portal.bridge.AS.HomeserverDomain}, + // TODO order + }) + if err != nil { + portal.log.Warnfln("Failed to add room to space %s: %v", mxid, err) + } else { + portal.InSpace = mxid + } + return true +} + +func (portal *Portal) UpdateParent(parentID string) bool { + if portal.ParentID == parentID { + return false + } + portal.ParentID = parentID + if portal.ParentID != "" { + portal.Parent = portal.bridge.GetExistingPortalByID(database.NewPortalKey(parentID, "")) + } else { + portal.Parent = nil + } + return true +} + +func (portal *Portal) ExpectedSpaceID() id.RoomID { + if portal.Parent != nil { + return portal.Parent.MXID + } else if portal.Guild != nil { + return portal.Guild.MXID + } + return "" +} + +func (portal *Portal) updateSpace() bool { + if portal.MXID == "" { + return false + } + if portal.Parent != nil { + return portal.addToSpace(portal.Parent.MXID) + } else if portal.Guild != nil { + return portal.addToSpace(portal.Guild.MXID) + } + return false +} + +func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discordgo.Channel { + changed := false + + if portal.Type != meta.Type { + portal.log.Warnfln("Portal type changed from %d to %d", portal.Type, meta.Type) + portal.Type = meta.Type + changed = true + } + if portal.OtherUserID == "" && portal.IsPrivateChat() { + if len(meta.Recipients) == 0 { + var err error + meta, err = source.Session.Channel(meta.ID) + if err != nil { + portal.log.Errorfln("Failed to get DM channel info:", err) + } + } + portal.OtherUserID = meta.Recipients[0].ID + portal.log.Infoln("Found other user ID:", portal.OtherUserID) + changed = true + } + if meta.GuildID != "" && portal.GuildID == "" { + portal.GuildID = meta.GuildID + portal.Guild = portal.bridge.GetGuildByID(portal.GuildID, true) + changed = true + } + + // FIXME + //name, err := portal.bridge.Config.Bridge.FormatChannelname(meta, source.Session) + //if err != nil { + // portal.log.Errorln("Failed to format channel name:", err) + // return + //} + + switch portal.Type { + case discordgo.ChannelTypeDM: + if portal.OtherUserID != "" { + puppet := portal.bridge.GetPuppetByID(portal.OtherUserID) + changed = portal.UpdateAvatarFromPuppet(puppet) || changed + changed = portal.UpdateName(puppet.Name) || changed + } + case discordgo.ChannelTypeGroupDM: + changed = portal.UpdateGroupDMAvatar(meta.Icon) || changed + fallthrough + default: + changed = portal.UpdateName(meta.Name) || changed + } + changed = portal.UpdateTopic(meta.Topic) || changed + changed = portal.UpdateParent(meta.ParentID) || changed + if portal.MXID != "" && portal.ExpectedSpaceID() != portal.InSpace { + changed = portal.updateSpace() || changed + } + if changed { + portal.UpdateBridgeInfo() + portal.Update() + } + return meta } diff --git a/provisioning.go b/provisioning.go index 9c3fc4e..7a01825 100644 --- a/provisioning.go +++ b/provisioning.go @@ -49,7 +49,6 @@ func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI { r.HandleFunc("/logout", p.logout).Methods(http.MethodPost) r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost) - // Setup the guild endpoints r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet) r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost) r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost) @@ -182,8 +181,8 @@ func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) { } user.Lock() - if user.ID != "" { - discord["id"] = user.ID + if user.DiscordID != "" { + discord["id"] = user.DiscordID } if user.Session != nil { @@ -338,7 +337,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) { return } - user.ID = discordUser.UserID + user.DiscordID = discordUser.UserID user.Update() if err := user.Login(discordUser.Token); err != nil { @@ -354,7 +353,7 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) { c.WriteJSON(map[string]interface{}{ "success": true, - "id": user.ID, + "id": user.DiscordID, }) return @@ -392,19 +391,17 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) { user := r.Context().Value("user").(*User) - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - data := make([]map[string]interface{}, len(user.guilds)) - idx := 0 - for _, guild := range user.guilds { - data[idx] = map[string]interface{}{ - "name": guild.GuildName, - "id": guild.GuildID, - "bridged": guild.Bridge, + var data []map[string]interface{} + for _, userGuild := range user.GetGuilds() { + guild := p.bridge.GetGuildByID(userGuild.GuildID, false) + if guild == nil { + continue } - - idx++ + data = append(data, map[string]interface{}{ + "name": guild.Name, + "id": guild.ID, + "mxid": guild.MXID, + }) } jsonResponse(w, http.StatusOK, data) diff --git a/puppet.go b/puppet.go index a803b7c..843c150 100644 --- a/puppet.go +++ b/puppet.go @@ -7,6 +7,8 @@ import ( log "maunium.net/go/maulogger/v2" + "github.com/bwmarrin/discordgo" + "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/id" @@ -66,12 +68,12 @@ func (br *DiscordBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { } func (br *DiscordBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { - id, ok := br.ParsePuppetMXID(mxid) + discordID, ok := br.ParsePuppetMXID(mxid) if !ok { return nil } - return br.GetPuppetByID(id) + return br.GetPuppetByID(discordID) } func (br *DiscordBridge) GetPuppetByID(id string) *Puppet { @@ -159,7 +161,7 @@ func (puppet *Puppet) DefaultIntent() *appservice.IntentAPI { } func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI { - if puppet.customIntent == nil { + if puppet.customIntent == nil || (portal.Key.Receiver != "" && portal.Key.Receiver != puppet.ID) { return puppet.DefaultIntent() } @@ -179,121 +181,86 @@ func (puppet *Puppet) updatePortalMeta(meta func(portal *Portal)) { } } -func (puppet *Puppet) updateName(source *User) bool { - user, err := source.Session.User(puppet.ID) - if err != nil { - puppet.log.Warnln("failed to get user from id:", err) +func (puppet *Puppet) UpdateName(info *discordgo.User) bool { + newName := puppet.bridge.Config.Bridge.FormatDisplayname(info) + if puppet.Name == newName && puppet.NameSet { return false } - - newName := puppet.bridge.Config.Bridge.FormatDisplayname(user) - - if puppet.DisplayName != newName { - err := puppet.DefaultIntent().SetDisplayName(newName) - if err == nil { - puppet.DisplayName = newName - go puppet.updatePortalName() - puppet.Update() - } else { - puppet.log.Warnln("failed to set display name:", err) - } - - return true - } - - return false -} - -func (puppet *Puppet) updatePortalName() { - puppet.updatePortalMeta(func(portal *Portal) { - if portal.MXID != "" { - _, err := portal.MainIntent().SetRoomName(portal.MXID, puppet.DisplayName) - if err != nil { - portal.log.Warnln("Failed to set name:", err) + puppet.Name = newName + puppet.NameSet = false + err := puppet.DefaultIntent().SetDisplayName(newName) + if err != nil { + puppet.log.Warnln("Failed to update displayname:", err) + } else { + go puppet.updatePortalMeta(func(portal *Portal) { + if portal.UpdateName(puppet.Name) { + portal.Update() } - } - - portal.Name = puppet.DisplayName - portal.Update() - }) -} - -func (puppet *Puppet) updateAvatar(source *User) bool { - user, err := source.Session.User(puppet.ID) - if err != nil { - puppet.log.Warnln("Failed to get user:", err) - - return false + }) + puppet.NameSet = true } - - if puppet.Avatar == user.Avatar { - return false - } - - if user.Avatar == "" { - puppet.log.Warnln("User does not have an avatar") - - return false - } - - url, err := uploadAvatar(puppet.DefaultIntent(), user.AvatarURL("")) - if err != nil { - puppet.log.Warnln("Failed to upload user avatar:", err) - - return false - } - - puppet.AvatarURL = url - - err = puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) - if err != nil { - puppet.log.Warnln("Failed to set avatar:", err) - } - - puppet.log.Debugln("Updated avatar", puppet.Avatar, "->", user.Avatar) - puppet.Avatar = user.Avatar - go puppet.updatePortalAvatar() - return true } -func (puppet *Puppet) updatePortalAvatar() { - puppet.updatePortalMeta(func(portal *Portal) { - if portal.MXID != "" { - _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, puppet.AvatarURL) - if err != nil { - portal.log.Warnln("Failed to set avatar:", err) - } +func (puppet *Puppet) UpdateAvatar(info *discordgo.User) bool { + if puppet.Avatar == info.Avatar && puppet.AvatarSet { + return false + } + puppet.Avatar = info.Avatar + puppet.AvatarSet = false + + if puppet.Avatar == "" { + // TODO should we just use discord's default avatars? + puppet.AvatarURL = id.ContentURI{} + } else { + url, err := uploadAvatar(puppet.DefaultIntent(), info.AvatarURL("")) + if err != nil { + puppet.log.Warnln("Failed to reupload user avatar:", err) + return true } + puppet.AvatarURL = url + } - portal.AvatarURL = puppet.AvatarURL - portal.Avatar = puppet.Avatar - portal.Update() - }) - + err := puppet.DefaultIntent().SetAvatarURL(puppet.AvatarURL) + if err != nil { + puppet.log.Warnln("Failed to update avatar:", err) + } else { + go puppet.updatePortalMeta(func(portal *Portal) { + if portal.UpdateAvatarFromPuppet(puppet) { + portal.Update() + } + }) + puppet.AvatarSet = true + } + return true } -func (puppet *Puppet) SyncContact(source *User) { +func (puppet *Puppet) UpdateInfo(source *User, info *discordgo.User) { puppet.syncLock.Lock() defer puppet.syncLock.Unlock() - puppet.log.Debugln("syncing contact", puppet.DisplayName) + if info == nil || len(info.Username) == 0 || len(info.Discriminator) == 0 { + if puppet.Name != "" { + return + } + var err error + puppet.log.Debugfln("Fetching info through %s to update", source.DiscordID) + info, err = source.Session.User(puppet.ID) + if err != nil { + puppet.log.Errorfln("Failed to fetch info through %s: %v", source.DiscordID, err) + return + } + } err := puppet.DefaultIntent().EnsureRegistered() if err != nil { puppet.log.Errorln("Failed to ensure registered:", err) } - update := false - - update = puppet.updateName(source) || update - - if puppet.Avatar == "" { - update = puppet.updateAvatar(source) || update - puppet.log.Debugln("update avatar returned", update) - } - - if update { + changed := false + changed = puppet.UpdateName(info) || changed + changed = puppet.UpdateAvatar(info) || changed + if changed { puppet.Update() } } diff --git a/remoteauth/client.go b/remoteauth/client.go index fdea7df..b039319 100644 --- a/remoteauth/client.go +++ b/remoteauth/client.go @@ -11,13 +11,14 @@ import ( "sync" "github.com/gorilla/websocket" + + "github.com/bwmarrin/discordgo" ) type Client struct { sync.Mutex - URL string - Origin string + URL string conn *websocket.Conn @@ -43,7 +44,6 @@ func New() (*Client, error) { return &Client{ URL: "wss://remote-auth-gateway.discord.gg/?v=1", - Origin: "https://discord.com", privateKey: privateKey, }, nil } @@ -54,8 +54,9 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str c.Lock() defer c.Unlock() - header := http.Header{ - "Origin": []string{c.Origin}, + header := http.Header{} + for key, value := range discordgo.DroidWSHeaders { + header.Set(key, value) } c.qrChan = qrChan diff --git a/thread.go b/thread.go new file mode 100644 index 0000000..8d7468f --- /dev/null +++ b/thread.go @@ -0,0 +1,53 @@ +package main + +import ( + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-discord/database" +) + +type Thread struct { + *database.Thread + Parent *Portal +} + +func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread { + br.threadsLock.Lock() + defer br.threadsLock.Unlock() + thread, ok := br.threadsByID[id] + if !ok { + return br.loadThread(br.DB.Thread.GetByDiscordID(id), id, root) + } + return thread +} + +func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread { + br.threadsLock.Lock() + defer br.threadsLock.Unlock() + thread, ok := br.threadsByRootMXID[mxid] + if !ok { + return br.loadThread(br.DB.Thread.GetByMatrixRootMsg(mxid), "", nil) + } + return thread +} + +func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread { + if dbThread == nil { + if root == nil { + return nil + } + dbThread = br.DB.Thread.New() + dbThread.ID = id + dbThread.RootDiscordID = root.DiscordID + dbThread.RootMXID = root.MXID + dbThread.ParentID = root.Channel.ChannelID + dbThread.Insert() + } + thread := &Thread{ + Thread: dbThread, + } + thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, "")) + br.threadsByID[thread.ID] = thread + br.threadsByRootMXID[thread.RootMXID] = thread + return thread +} diff --git a/user.go b/user.go index cef4b2d..40699f5 100644 --- a/user.go +++ b/user.go @@ -4,11 +4,15 @@ import ( "errors" "fmt" "net/http" + "os" + "runtime" "strings" "sync" + "time" + + log "maunium.net/go/maulogger/v2" "github.com/bwmarrin/discordgo" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -35,12 +39,43 @@ type User struct { PermissionLevel bridgeconfig.PermissionLevel - guilds map[string]*database.Guild - guildsLock sync.Mutex + spaceCreateLock sync.Mutex + spaceMembershipChecked bool Session *discordgo.Session } +var discordLog log.Logger + +func init() { + discordgo.Logger = func(msgL, caller int, format string, a ...interface{}) { + pc, file, line, _ := runtime.Caller(caller + 1) + + files := strings.Split(file, "/") + file = files[len(files)-1] + + name := runtime.FuncForPC(pc).Name() + fns := strings.Split(name, ".") + name = fns[len(fns)-1] + + msg := fmt.Sprintf(format, a...) + + var level log.Level + switch msgL { + case discordgo.LogError: + level = log.LevelError + case discordgo.LogWarning: + level = log.LevelWarn + case discordgo.LogInformational: + level = log.LevelInfo + case discordgo.LogDebug: + level = log.LevelDebug + } + + discordLog.Logfln(level, "%s:%d:%s() %s", file, line, name, msg) + } +} + func (user *User) GetPermissionLevel() bridgeconfig.PermissionLevel { return user.PermissionLevel } @@ -66,10 +101,10 @@ func (user *User) GetIDoublePuppet() bridge.DoublePuppet { } func (user *User) GetIGhost() bridge.Ghost { - if user.ID == "" { + if user.DiscordID == "" { return nil } - p := user.bridge.GetPuppetByID(user.ID) + p := user.bridge.GetPuppetByID(user.DiscordID) if p == nil { return nil } @@ -78,14 +113,6 @@ func (user *User) GetIGhost() bridge.Ghost { var _ bridge.User = (*User)(nil) -// this assume you are holding the guilds lock!!! -func (user *User) loadGuilds() { - user.guilds = map[string]*database.Guild{} - for _, guild := range user.bridge.DB.Guild.GetAll(user.ID) { - user.guilds[guild.GuildID] = guild - } -} - func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { // If we weren't passed in a user we attempt to create one if we were given // a matrix id. @@ -103,8 +130,8 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User // We assume the usersLock was acquired by our caller. br.usersByMXID[user.MXID] = user - if user.ID != "" { - br.usersByID[user.ID] = user + if user.DiscordID != "" { + br.usersByID[user.DiscordID] = user } if user.ManagementRoom != "" { @@ -114,17 +141,10 @@ func (br *DiscordBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User br.managementRoomsLock.Unlock() } - // Load our guilds state from the database and turn it into a map - user.guildsLock.Lock() - user.loadGuilds() - user.guildsLock.Unlock() - return user } func (br *DiscordBridge) GetUserByMXID(userID id.UserID) *User { - // TODO: check if puppet - br.usersLock.Lock() defer br.usersLock.Unlock() @@ -153,7 +173,6 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User { User: dbUser, bridge: br, log: br.Log.Sub("User").Sub(string(dbUser.MXID)), - guilds: map[string]*database.Guild{}, } user.PermissionLevel = br.Config.Bridge.Permissions.Get(user.MXID) @@ -161,11 +180,11 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User { return user } -func (br *DiscordBridge) getAllUsers() []*User { +func (br *DiscordBridge) getAllUsersWithToken() []*User { br.usersLock.Lock() defer br.usersLock.Unlock() - dbUsers := br.DB.User.GetAll() + dbUsers := br.DB.User.GetAllWithToken() users := make([]*User, len(dbUsers)) for idx, dbUser := range dbUsers { @@ -182,7 +201,7 @@ func (br *DiscordBridge) getAllUsers() []*User { func (br *DiscordBridge) startUsers() { br.Log.Debugln("Starting users") - for _, u := range br.getAllUsers() { + for _, u := range br.getAllUsersWithToken() { go func(user *User) { err := user.Connect() if err != nil { @@ -209,10 +228,6 @@ func (user *User) SetManagementRoom(roomID id.RoomID) { existing, ok := user.bridge.managementRooms[roomID] if ok { - // If there's a user already assigned to this management room, clear it - // out. - // I think this is due a name change or something? I dunno, leaving it - // for now. existing.ManagementRoom = "" existing.Update() } @@ -222,6 +237,52 @@ func (user *User) SetManagementRoom(roomID id.RoomID) { user.Update() } +func (user *User) GetSpaceRoom() id.RoomID { + if len(user.SpaceRoom) == 0 { + user.spaceCreateLock.Lock() + defer user.spaceCreateLock.Unlock() + if len(user.SpaceRoom) > 0 { + return user.SpaceRoom + } + + resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ + Visibility: "private", + Name: "Discord", + Topic: "Your Discord bridged chats", + InitialState: []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: user.bridge.Config.AppService.Bot.ParsedAvatar, + }, + }, + }}, + CreationContent: map[string]interface{}{ + "type": event.RoomTypeSpace, + }, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + user.bridge.Bot.UserID: 9001, + user.MXID: 50, + }, + }, + }) + + if err != nil { + user.log.Errorln("Failed to auto-create space room:", err) + } else { + user.SpaceRoom = resp.RoomID + user.Update() + user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) + } + } else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(user.SpaceRoom, user.MXID) { + user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) + } + user.spaceMembershipChecked = true + + return user.SpaceRoom +} + func (user *User) tryAutomaticDoublePuppeting() { user.Lock() defer user.Unlock() @@ -232,7 +293,7 @@ func (user *User) tryAutomaticDoublePuppeting() { user.log.Debugln("Checking if double puppeting needs to be enabled") - puppet := user.bridge.GetPuppetByID(user.ID) + puppet := user.bridge.GetPuppetByID(user.DiscordID) if puppet.CustomMXID != "" { user.log.Debugln("User already has double-puppeting enabled") @@ -270,7 +331,7 @@ func (user *User) syncChatDoublePuppetDetails(portal *Portal, justCreated bool) } func (user *User) Login(token string) error { - user.Token = token + user.DiscordToken = token user.Update() return user.Connect() } @@ -279,7 +340,7 @@ func (user *User) IsLoggedIn() bool { user.Lock() defer user.Unlock() - return user.Token != "" + return user.DiscordToken != "" } func (user *User) Logout() error { @@ -290,7 +351,7 @@ func (user *User) Logout() error { return ErrNotLoggedIn } - puppet := user.bridge.GetPuppetByID(user.ID) + puppet := user.bridge.GetPuppetByID(user.DiscordID) if puppet.CustomMXID != "" { err := puppet.SwitchCustomMXID("", "") if err != nil { @@ -304,7 +365,7 @@ func (user *User) Logout() error { user.Session = nil - user.Token = "" + user.DiscordToken = "" user.Update() return nil @@ -321,16 +382,20 @@ func (user *User) Connect() error { user.Lock() defer user.Unlock() - if user.Token == "" { + if user.DiscordToken == "" { return ErrNotLoggedIn } - user.log.Debugln("connecting to discord") + user.log.Debugln("Connecting to discord") - session, err := discordgo.New(user.Token) + session, err := discordgo.New(user.DiscordToken) if err != nil { return err } + // TODO move to config + if os.Getenv("DISCORD_DEBUG") == "1" { + session.LogLevel = discordgo.LogDebug + } user.Session = session @@ -382,263 +447,178 @@ func (user *User) bridgeMessage(guildID string) bool { return true } - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - if guild, found := user.guilds[guildID]; found { - if guild.Bridge { - return true - } + guild := user.bridge.GetGuildByID(guildID, false) + if guild.MXID != "" { + return true } - user.log.Debugfln("ignoring message for non-bridged guild %s-%s", user.ID, guildID) - + user.log.Debugfln("Cgnoring message for non-bridged guild %s", guildID) return false } -func (user *User) readyHandler(s *discordgo.Session, r *discordgo.Ready) { - user.log.Debugln("discord connection ready") +func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) { + user.log.Debugln("Discord connection ready") // Update our user fields - user.ID = r.User.ID - - // Update our guild map to match watch discord thinks we're in. This is the - // only time we can get the full guild map as discordgo doesn't make it - // available to us later. Also, discord might not give us the full guild - // information here, so we use this to remove guilds the user left and only - // add guilds whose full information we have. The are told about the - // "unavailable" guilds later via the GuildCreate handler. - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - // build a list of the current guilds we're in so we can prune the old ones - current := []string{} - - user.log.Debugln("database guild count", len(user.guilds)) - user.log.Debugln("discord guild count", len(r.Guilds)) - - for _, guild := range r.Guilds { - current = append(current, guild.ID) - - // If we already know about this guild, make sure we reset it's bridge - // status. - if val, found := user.guilds[guild.ID]; found { - bridge := val.Bridge - user.guilds[guild.ID].Bridge = bridge - - // Update the name if the guild is available - if !guild.Unavailable { - user.guilds[guild.ID].GuildName = guild.Name - } - - val.Upsert() - } else { - g := user.bridge.DB.Guild.New() - g.DiscordID = user.ID - g.GuildID = guild.ID - user.guilds[guild.ID] = g - - if !guild.Unavailable { - g.GuildName = guild.Name - } - - g.Upsert() - } + if user.DiscordID != r.User.ID { + user.DiscordID = r.User.ID + user.Update() } - // Sync the guilds to the database. - user.bridge.DB.Guild.Prune(user.ID, current) - - // Finally reload from the database since it purged servers we're not in - // anymore. - user.loadGuilds() - - user.log.Debugln("updated database guild count", len(user.guilds)) - - user.Update() + updateTS := time.Now() + guildsInSpace := make(map[string]bool) + for _, guild := range user.GetGuilds() { + guildsInSpace[guild.GuildID] = guild.InSpace + } + for _, guild := range r.Guilds { + user.handleGuild(guild, updateTS, guildsInSpace[guild.ID]) + } + user.PruneGuildList(updateTS) + const maxCreate = 5 + for i, ch := range r.PrivateChannels { + portal := user.GetPortalByMeta(ch) + if i < maxCreate && portal.MXID == "" { + err := portal.CreateMatrixRoom(user, ch) + if err != nil { + user.log.Errorfln("Failed to create portal for private channel %s in initial sync: %v", ch.ID, err) + } + } else { + portal.UpdateInfo(user, ch) + } + } } -func (user *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) { - user.log.Debugln("connected to discord") +func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSpace bool) { + guild := user.bridge.GetGuildByID(meta.ID, true) + guild.UpdateInfo(user, meta) + if len(meta.Channels) > 0 { + for _, ch := range meta.Channels { + portal := user.GetPortalByMeta(ch) + if guild.AutoBridgeChannels && portal.MXID == "" { + err := portal.CreateMatrixRoom(user, ch) + if err != nil { + user.log.Errorfln("Failed to create portal for guild channel %s/%s in initial sync: %v", guild.ID, ch.ID, err) + } + } else { + portal.UpdateInfo(user, ch) + } + } + } + if len(guild.MXID) > 0 && !isInSpace { + _, err := user.bridge.Bot.SendStateEvent(user.GetSpaceRoom(), event.StateSpaceChild, guild.MXID.String(), &event.SpaceChildEventContent{ + Via: []string{user.bridge.AS.HomeserverDomain}, + }) + if err != nil { + user.log.Errorfln("Failed to add guild space %s to user space: %v", guild.MXID, err) + } else { + isInSpace = true + } + } + user.MarkInGuild(database.UserGuild{GuildID: meta.ID, Timestamp: timestamp, InSpace: isInSpace}) +} + +func (user *User) connectedHandler(_ *discordgo.Session, c *discordgo.Connect) { + user.log.Debugln("Connected to discord") user.tryAutomaticDoublePuppeting() } -func (user *User) disconnectedHandler(s *discordgo.Session, d *discordgo.Disconnect) { - user.log.Debugln("disconnected from discord") +func (user *User) disconnectedHandler(_ *discordgo.Session, d *discordgo.Disconnect) { + user.log.Debugln("Disconnected from discord") } -func (user *User) guildCreateHandler(s *discordgo.Session, g *discordgo.GuildCreate) { - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - // If we somehow already know about the guild, just update it's name - if guild, found := user.guilds[g.ID]; found { - guild.GuildName = g.Name - guild.Upsert() +func (user *User) guildCreateHandler(_ *discordgo.Session, g *discordgo.GuildCreate) { + user.handleGuild(g.Guild, time.Now(), false) +} +func (user *User) guildDeleteHandler(_ *discordgo.Session, g *discordgo.GuildDelete) { + user.MarkNotInGuild(g.ID) + guild := user.bridge.GetGuildByID(g.ID, false) + if guild == nil || guild.MXID == "" { return } - - // This is a brand new guild so lets get it added. - guild := user.bridge.DB.Guild.New() - guild.DiscordID = user.ID - guild.GuildID = g.ID - guild.GuildName = g.Name - guild.Upsert() - - user.guilds[g.ID] = guild + // TODO clean up? } -func (user *User) guildDeleteHandler(s *discordgo.Session, g *discordgo.GuildDelete) { - user.guildsLock.Lock() - defer user.guildsLock.Unlock() +func (user *User) guildUpdateHandler(_ *discordgo.Session, g *discordgo.GuildUpdate) { + user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID)) +} - if guild, found := user.guilds[g.ID]; found { - guild.Delete() - delete(user.guilds, g.ID) - user.log.Debugln("deleted guild", g.Guild.ID) +func (user *User) channelCreateHandler(_ *discordgo.Session, c *discordgo.ChannelCreate) { + if !user.bridgeMessage(c.GuildID) { + return } -} - -func (user *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate) { - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - // If we somehow already know about the guild, just update it's name - if guild, found := user.guilds[g.ID]; found { - guild.GuildName = g.Name - guild.Upsert() - - user.log.Debugln("updated guild", g.ID) - } -} - -func (user *User) createChannel(c *discordgo.Channel) { - key := database.NewPortalKey(c.ID, user.User.ID) - portal := user.bridge.GetPortalByID(key) - + portal := user.GetPortalByMeta(c.Channel) if portal.MXID != "" { return } - - portal.Name = c.Name - portal.Topic = c.Topic - portal.Type = c.Type - - if portal.Type == discordgo.ChannelTypeDM { - portal.OtherUserID = c.Recipients[0].ID + err := portal.CreateMatrixRoom(user, c.Channel) + if err != nil { + user.log.Errorfln("Error creating Matrix room for %s on channel create event: %v", c.ID, err) } - - if c.Icon != "" { - user.log.Debugln("channel icon", c.Icon) - } - - portal.Update() - - portal.createMatrixRoom(user, c) } -func (user *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { - user.createChannel(c.Channel) -} - -func (user *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) { +func (user *User) channelDeleteHandler(_ *discordgo.Session, c *discordgo.ChannelDelete) { user.log.Debugln("channel delete handler") } -func (user *User) channelPinsUpdateHandler(s *discordgo.Session, c *discordgo.ChannelPinsUpdate) { +func (user *User) channelPinsUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelPinsUpdate) { user.log.Debugln("channel pins update") } -func (user *User) channelUpdateHandler(s *discordgo.Session, c *discordgo.ChannelUpdate) { - key := database.NewPortalKey(c.ID, user.User.ID) - portal := user.bridge.GetPortalByID(key) - - portal.update(user, c.Channel) +func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelUpdate) { + portal := user.GetPortalByMeta(c.Channel) + portal.UpdateInfo(user, c.Channel) } -func (user *User) messageCreateHandler(s *discordgo.Session, m *discordgo.MessageCreate) { - if !user.bridgeMessage(m.GuildID) { +func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) { + fmt.Printf("%+v\n", msg) + if !user.bridgeMessage(guildID) { return } - key := database.NewPortalKey(m.ChannelID, user.ID) - portal := user.bridge.GetPortalByID(key) - - msg := portalDiscordMessage{ - msg: m, - user: user, + portal := user.GetExistingPortalByID(channelID) + var thread *Thread + if portal == nil { + thread = user.bridge.GetThreadByID(channelID, nil) + if thread == nil || thread.Parent == nil { + user.log.Debugfln("Dropping %s in unknown channel %s/%s", typeName, guildID, channelID) + return + } + portal = thread.Parent } - portal.discordMessages <- msg + portal.discordMessages <- portalDiscordMessage{ + msg: msg, + user: user, + thread: thread, + } } -func (user *User) messageDeleteHandler(s *discordgo.Session, m *discordgo.MessageDelete) { - if !user.bridgeMessage(m.GuildID) { - return - } - - key := database.NewPortalKey(m.ChannelID, user.ID) - portal := user.bridge.GetPortalByID(key) - - msg := portalDiscordMessage{ - msg: m, - user: user, - } - - portal.discordMessages <- msg +func (user *User) messageCreateHandler(_ *discordgo.Session, m *discordgo.MessageCreate) { + user.pushPortalMessage(m, "message create", m.ChannelID, m.GuildID) } -func (user *User) messageUpdateHandler(s *discordgo.Session, m *discordgo.MessageUpdate) { - if !user.bridgeMessage(m.GuildID) { - return - } - - key := database.NewPortalKey(m.ChannelID, user.ID) - portal := user.bridge.GetPortalByID(key) - - msg := portalDiscordMessage{ - msg: m, - user: user, - } - - portal.discordMessages <- msg +func (user *User) messageDeleteHandler(_ *discordgo.Session, m *discordgo.MessageDelete) { + user.pushPortalMessage(m, "message delete", m.ChannelID, m.GuildID) } -func (user *User) reactionAddHandler(s *discordgo.Session, m *discordgo.MessageReactionAdd) { - if !user.bridgeMessage(m.MessageReaction.GuildID) { - return - } - - key := database.NewPortalKey(m.ChannelID, user.User.ID) - portal := user.bridge.GetPortalByID(key) - - msg := portalDiscordMessage{ - msg: m, - user: user, - } - - portal.discordMessages <- msg +func (user *User) messageUpdateHandler(_ *discordgo.Session, m *discordgo.MessageUpdate) { + user.pushPortalMessage(m, "message update", m.ChannelID, m.GuildID) } -func (user *User) reactionRemoveHandler(s *discordgo.Session, m *discordgo.MessageReactionRemove) { - if !user.bridgeMessage(m.MessageReaction.GuildID) { - return - } +func (user *User) reactionAddHandler(_ *discordgo.Session, m *discordgo.MessageReactionAdd) { + user.pushPortalMessage(m, "reaction add", m.ChannelID, m.GuildID) +} - key := database.NewPortalKey(m.ChannelID, user.User.ID) - portal := user.bridge.GetPortalByID(key) - - msg := portalDiscordMessage{ - msg: m, - user: user, - } - - portal.discordMessages <- msg +func (user *User) reactionRemoveHandler(_ *discordgo.Session, m *discordgo.MessageReactionRemove) { + user.pushPortalMessage(m, "reaction remove", m.ChannelID, m.GuildID) } func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool { + if intent == nil { + intent = user.bridge.Bot + } ret := false inviteContent := event.Content{ @@ -682,7 +662,7 @@ func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, func (user *User) getDirectChats() map[id.UserID][]id.RoomID { chats := map[id.UserID][]id.RoomID{} - privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.ID) + privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.DiscordID) for _, portal := range privateChats { if portal.MXID != "" { puppetMXID := user.bridge.FormatPuppetMXID(portal.Key.Receiver) @@ -755,28 +735,21 @@ func (user *User) updateDirectChats(chats map[id.UserID][]id.RoomID) { } func (user *User) bridgeGuild(guildID string, everything bool) error { - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - guild, found := user.guilds[guildID] - if !found { - return fmt.Errorf("guildID not found") + guild := user.bridge.GetGuildByID(guildID, false) + if guild == nil { + return errors.New("guild not found") } - - // Update the guild - guild.Bridge = true - guild.Upsert() - - // If this is a full bridge, create portals for all the channels - if everything { - channels, err := user.Session.GuildChannels(guildID) - if err != nil { - return err - } - - for _, channel := range channels { - if channelIsBridgeable(channel) { - user.createChannel(channel) + meta, _ := user.Session.State.Guild(guildID) + err := guild.CreateMatrixRoom(user, meta) + if err != nil { + return err + } + for _, ch := range meta.Channels { + portal := user.GetPortalByMeta(ch) + if (everything && channelIsBridgeable(ch)) || ch.Type == discordgo.ChannelTypeGuildCategory { + err = portal.CreateMatrixRoom(user, ch) + if err != nil { + user.log.Warnfln("Error creating room for guild channel %s: %v", ch.ID, err) } } } @@ -785,41 +758,41 @@ func (user *User) bridgeGuild(guildID string, everything bool) error { } func (user *User) unbridgeGuild(guildID string) error { - user.guildsLock.Lock() - defer user.guildsLock.Unlock() - - guild, exists := user.guilds[guildID] - if !exists { - return fmt.Errorf("guildID not found") - } - - if !guild.Bridge { - return fmt.Errorf("guild not bridged") - } - - // First update the guild so we don't have any other go routines recreating - // channels we're about to destroy. - guild.Bridge = false - guild.Upsert() - - // Now run through the channels in the guild and remove any portals we - // have for them. - channels, err := user.Session.GuildChannels(guildID) - if err != nil { - return err - } - - for _, channel := range channels { - if channelIsBridgeable(channel) { - key := database.PortalKey{ - ChannelID: channel.ID, - Receiver: user.ID, - } - - portal := user.bridge.GetPortalByID(key) - portal.leave(user) - } - } + //user.guildsLock.Lock() + //defer user.guildsLock.Unlock() + // + //guild, exists := user.guilds[guildID] + //if !exists { + // return fmt.Errorf("guildID not found") + //} + // + //if !guild.Bridge { + // return fmt.Errorf("guild not bridged") + //} + // + //// First update the guild so we don't have any other go routines recreating + //// channels we're about to destroy. + //guild.Bridge = false + //guild.Upsert() + // + //// Now run through the channels in the guild and remove any portals we + //// have for them. + //channels, err := user.Session.GuildChannels(guildID) + //if err != nil { + // return err + //} + // + //for _, channel := range channels { + // if channelIsBridgeable(channel) { + // key := database.PortalKey{ + // ChannelID: channel.ID, + // Receiver: user.DiscordID, + // } + // + // portal := user.bridge.GetPortalByID(key) + // portal.leave(user) + // } + //} return nil }