From d797541cdc899409293c3c10a3c1b770eb736acc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 28 May 2022 23:41:34 +0300 Subject: [PATCH] Add DM rooms to space --- commands.go | 4 +- database/upgrades/00-latest-revision.sql | 20 +-- database/upgrades/03-spaces.sql | 17 +-- database/user.go | 92 ++------------ database/userportal.go | 102 +++++++++++++++ provisioning.go | 4 +- user.go | 152 ++++++++++++++++------- 7 files changed, 247 insertions(+), 144 deletions(-) create mode 100644 database/userportal.go diff --git a/commands.go b/commands.go index b8f65c0..e1a9f36 100644 --- a/commands.go +++ b/commands.go @@ -252,8 +252,8 @@ func fnGuilds(ce *WrappedCommandEvent) { func fnListGuilds(ce *WrappedCommandEvent) { var output strings.Builder - for _, userGuild := range ce.User.GetGuilds() { - guild := ce.Bridge.GetGuildByID(userGuild.GuildID, false) + for _, userGuild := range ce.User.GetPortals() { + guild := ce.Bridge.GetGuildByID(userGuild.DiscordID, false) if guild == nil { continue } diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index 0a1c872..f272adc 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -28,6 +28,7 @@ CREATE TABLE portal ( name TEXT NOT NULL, name_set BOOLEAN NOT NULL, topic TEXT NOT NULL, + topic_set BOOLEAN NOT NULL, avatar TEXT NOT NULL, avatar_url TEXT NOT NULL, avatar_set BOOLEAN NOT NULL, @@ -72,18 +73,19 @@ CREATE TABLE "user" ( discord_token TEXT, management_room TEXT, - space_room TEXT + space_room TEXT, + dm_space_room TEXT ); -CREATE TABLE user_guild ( - guild_id TEXT, - user_mxid TEXT, - in_space BOOLEAN NOT NULL, - timestamp BIGINT NOT NULL, +CREATE TABLE user_portal ( + discord_id TEXT, + user_mxid TEXT, + type TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + timestamp BIGINT NOT NULL, - PRIMARY KEY (guild_id, user_mxid), - CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE, - CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE + PRIMARY KEY (discord_id, user_mxid), + CONSTRAINT up_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ); CREATE TABLE message ( diff --git a/database/upgrades/03-spaces.sql b/database/upgrades/03-spaces.sql index 53ae9bc..79bc3c5 100644 --- a/database/upgrades/03-spaces.sql +++ b/database/upgrades/03-spaces.sql @@ -13,15 +13,15 @@ CREATE TABLE guild ( auto_bridge_channels BOOLEAN NOT NULL ); -CREATE TABLE user_guild ( - guild_id TEXT, - user_mxid TEXT, - in_space BOOLEAN NOT NULL, - timestamp BIGINT NOT NULL, +CREATE TABLE user_portal ( + discord_id TEXT, + user_mxid TEXT, + type TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + timestamp BIGINT NOT NULL, - PRIMARY KEY (guild_id, user_mxid), - CONSTRAINT ug_guild_fkey FOREIGN KEY (guild_id) REFERENCES guild (dcid) ON DELETE CASCADE, - CONSTRAINT ug_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE + PRIMARY KEY (discord_id, user_mxid), + CONSTRAINT up_user_fkey FOREIGN KEY (user_mxid) REFERENCES "user" (mxid) ON DELETE CASCADE ); ALTER TABLE portal ADD COLUMN dc_guild_id TEXT; @@ -52,6 +52,7 @@ ALTER TABLE puppet ALTER COLUMN name_set DROP DEFAULT; ALTER TABLE puppet ALTER COLUMN avatar_set DROP DEFAULT; ALTER TABLE "user" ADD COLUMN space_room TEXT; +ALTER TABLE "user" ADD COLUMN dm_space_room TEXT; ALTER TABLE "user" RENAME COLUMN token TO discord_token; UPDATE message SET timestamp=timestamp*1000; diff --git a/database/user.go b/database/user.go index 0150f16..6c05208 100644 --- a/database/user.go +++ b/database/user.go @@ -2,8 +2,6 @@ package database import ( "database/sql" - "errors" - "time" log "maunium.net/go/maulogger/v2" @@ -24,18 +22,18 @@ func (uq *UserQuery) New() *User { } func (uq *UserQuery) GetByMXID(userID id.UserID) *User { - query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE mxid=$1` + query := `SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room FROM "user" WHERE mxid=$1` return uq.New().Scan(uq.db.QueryRow(query, userID)) } func (uq *UserQuery) GetByID(id string) *User { - query := `SELECT mxid, dcid, discord_token, management_room, space_room FROM "user" WHERE dcid=$1` + query := `SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room FROM "user" WHERE dcid=$1` return uq.New().Scan(uq.db.QueryRow(query, id)) } func (uq *UserQuery) GetAllWithToken() []*User { query := ` - SELECT mxid, dcid, discord_token, management_room, space_room + SELECT mxid, dcid, discord_token, management_room, space_room, dm_space_room FROM "user" WHERE discord_token IS NOT NULL ` rows, err := uq.db.Query(query) @@ -62,81 +60,12 @@ type User struct { DiscordToken string ManagementRoom id.RoomID SpaceRoom id.RoomID -} - -type UserGuild struct { - GuildID string - Timestamp time.Time - InSpace bool -} - -func (u *User) GetGuilds() []UserGuild { - res, err := u.db.Query("SELECT guild_id, timestamp, in_space FROM user_guild WHERE user_mxid=$1", u.MXID) - if err != nil { - u.log.Errorln("Failed to get guilds:", err) - panic(err) - return nil - } - var guilds []UserGuild - for res.Next() { - var guild UserGuild - var ts int64 - err = res.Scan(&guild.GuildID, &ts, &guild.InSpace) - if err != nil { - u.log.Errorln("Error scanning user guild:", err) - panic(err) - } else { - guild.Timestamp = time.UnixMilli(ts) - guilds = append(guilds, guild) - } - } - return guilds -} - -func (u *User) IsInSpace(guildID string) (isIn bool) { - query := `SELECT in_space FROM user_guild WHERE user_mxid=$1 AND guild_id=$2` - err := u.db.QueryRow(query, u.MXID, guildID).Scan(&isIn) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, guildID, err) - panic(err) - } - return -} - -func (u *User) MarkInGuild(guild UserGuild) { - query := ` - INSERT INTO user_guild (guild_id, user_mxid, timestamp, in_space) - VALUES ($1, $2, $3, $4) - ON CONFLICT (guild_id, user_mxid) DO UPDATE - SET timestamp=excluded.timestamp, in_space=excluded.in_space - ` - _, err := u.db.Exec(query, guild.GuildID, u.MXID, guild.Timestamp.UnixMilli(), guild.InSpace) - if err != nil { - u.log.Errorfln("Failed to insert user guild %s/%s: %v", u.MXID, guild.GuildID, err) - panic(err) - } -} - -func (u *User) MarkNotInGuild(guildID string) { - query := `DELETE FROM user_guild WHERE user_mxid=$1 AND guild_id=$2` - _, err := u.db.Exec(query, u.MXID, guildID) - if err != nil { - u.log.Errorfln("Failed to remove user guild %s/%s: %v", u.MXID, guildID, err) - panic(err) - } -} - -func (u *User) PruneGuildList(beforeTS time.Time) { - _, err := u.db.Exec("DELETE FROM user_guild WHERE user_mxid=$1 AND timestamp<$2", u.MXID, beforeTS.UnixMilli()) - if err != nil { - u.log.Errorln("Failed to prune user guild list:", err) - panic(err) - } + DMSpaceRoom id.RoomID } func (u *User) Scan(row dbutil.Scannable) *User { - var discordID, managementRoom, spaceRoom, discordToken sql.NullString - err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom) + var discordID, managementRoom, spaceRoom, dmSpaceRoom, discordToken sql.NullString + err := row.Scan(&u.MXID, &discordID, &discordToken, &managementRoom, &spaceRoom, &dmSpaceRoom) if err != nil { if err != sql.ErrNoRows { u.log.Errorln("Database scan failed:", err) @@ -148,12 +77,13 @@ func (u *User) Scan(row dbutil.Scannable) *User { u.DiscordToken = discordToken.String u.ManagementRoom = id.RoomID(managementRoom.String) u.SpaceRoom = id.RoomID(spaceRoom.String) + u.DMSpaceRoom = id.RoomID(dmSpaceRoom.String) return u } func (u *User) Insert() { - query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room) VALUES ($1, $2, $3, $4, $5)` - _, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom))) + query := `INSERT INTO "user" (mxid, dcid, discord_token, management_room, space_room, dm_space_room) VALUES ($1, $2, $3, $4, $5, $6)` + _, err := u.db.Exec(query, u.MXID, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), strPtr(string(u.DMSpaceRoom))) if err != nil { u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) panic(err) @@ -161,8 +91,8 @@ func (u *User) Insert() { } func (u *User) Update() { - query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4 WHERE mxid=$5` - _, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), u.MXID) + query := `UPDATE "user" SET dcid=$1, discord_token=$2, management_room=$3, space_room=$4, dm_space_room=$5 WHERE mxid=$6` + _, err := u.db.Exec(query, strPtr(u.DiscordID), strPtr(u.DiscordToken), strPtr(string(u.ManagementRoom)), strPtr(string(u.SpaceRoom)), strPtr(string(u.DMSpaceRoom)), u.MXID) if err != nil { u.log.Warnfln("Failed to update %q: %v", u.MXID, err) panic(err) diff --git a/database/userportal.go b/database/userportal.go new file mode 100644 index 0000000..0eaa0a3 --- /dev/null +++ b/database/userportal.go @@ -0,0 +1,102 @@ +package database + +import ( + "database/sql" + "errors" + "time" + + log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/util/dbutil" +) + +const ( + UserPortalTypeDM = "dm" + UserPortalTypeGuild = "guild" +) + +type UserPortal struct { + DiscordID string + Type string + Timestamp time.Time + InSpace bool +} + +func (up UserPortal) Scan(l log.Logger, row dbutil.Scannable) *UserPortal { + var ts int64 + err := row.Scan(&up.DiscordID, &up.Type, &ts, &up.InSpace) + if err != nil { + l.Errorln("Error scanning user portal:", err) + panic(err) + return nil + } + up.Timestamp = time.UnixMilli(ts) + return &up +} + +func (u *User) scanUserPortals(rows *sql.Rows) []UserPortal { + var ups []UserPortal + for rows.Next() { + up := UserPortal{}.Scan(u.log, rows) + if up != nil { + ups = append(ups, *up) + } + } + return ups +} + +func (u *User) GetPortals() []UserPortal { + rows, err := u.db.Query("SELECT discord_id, type, timestamp, in_space FROM user_portal WHERE user_mxid=$1", u.MXID) + if err != nil { + u.log.Errorln("Failed to get portals:", err) + panic(err) + return nil + } + return u.scanUserPortals(rows) +} + +func (u *User) IsInSpace(discordID string) (isIn bool) { + query := `SELECT in_space FROM user_portal WHERE user_mxid=$1 AND discord_id=$2` + err := u.db.QueryRow(query, u.MXID, discordID).Scan(&isIn) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + u.log.Warnfln("Failed to scan in_space for %s/%s: %v", u.MXID, discordID, err) + panic(err) + } + return +} + +func (u *User) MarkInPortal(portal UserPortal) { + query := ` + INSERT INTO user_portal (discord_id, type, user_mxid, timestamp, in_space) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (discord_id, user_mxid) DO UPDATE + SET timestamp=excluded.timestamp, in_space=excluded.in_space + ` + _, err := u.db.Exec(query, portal.DiscordID, portal.Type, u.MXID, portal.Timestamp.UnixMilli(), portal.InSpace) + if err != nil { + u.log.Errorfln("Failed to insert user portal %s/%s: %v", u.MXID, portal.DiscordID, err) + panic(err) + } +} + +func (u *User) MarkNotInPortal(discordID string) { + query := `DELETE FROM user_portal WHERE user_mxid=$1 AND discord_id=$2` + _, err := u.db.Exec(query, u.MXID, discordID) + if err != nil { + u.log.Errorfln("Failed to remove user portal %s/%s: %v", u.MXID, discordID, err) + panic(err) + } +} + +func (u *User) PrunePortalList(beforeTS time.Time) []UserPortal { + query := ` + DELETE FROM user_portal + WHERE user_mxid=$1 AND timestamp<$2 + RETURNING discord_id, type, timestamp, in_space + ` + rows, err := u.db.Query(query, u.MXID, beforeTS.UnixMilli()) + if err != nil { + u.log.Errorln("Failed to prune user guild list:", err) + panic(err) + } + return u.scanUserPortals(rows) +} diff --git a/provisioning.go b/provisioning.go index 7a01825..6cc7a14 100644 --- a/provisioning.go +++ b/provisioning.go @@ -392,8 +392,8 @@ func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) { user := r.Context().Value("user").(*User) var data []map[string]interface{} - for _, userGuild := range user.GetGuilds() { - guild := p.bridge.GetGuildByID(userGuild.GuildID, false) + for _, userGuild := range user.GetPortals() { + guild := p.bridge.GetGuildByID(userGuild.DiscordID, false) if guild == nil { continue } diff --git a/user.go b/user.go index 40699f5..f956fd3 100644 --- a/user.go +++ b/user.go @@ -39,8 +39,9 @@ type User struct { PermissionLevel bridgeconfig.PermissionLevel - spaceCreateLock sync.Mutex - spaceMembershipChecked bool + spaceCreateLock sync.Mutex + spaceMembershipChecked bool + dmSpaceMembershipChecked bool Session *discordgo.Session } @@ -237,26 +238,42 @@ func (user *User) SetManagementRoom(roomID id.RoomID) { user.Update() } -func (user *User) GetSpaceRoom() id.RoomID { - if len(user.SpaceRoom) == 0 { +func (user *User) getSpaceRoom(ptr *id.RoomID, checked *bool, name, topic string, parent id.RoomID) id.RoomID { + if len(*ptr) == 0 { user.spaceCreateLock.Lock() defer user.spaceCreateLock.Unlock() - if len(user.SpaceRoom) > 0 { - return user.SpaceRoom + if len(*ptr) > 0 { + return *ptr + } + + initialState := []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: user.bridge.Config.AppService.Bot.ParsedAvatar, + }, + }, + }} + + if parent != "" { + parentIDStr := parent.String() + initialState = append(initialState, &event.Event{ + Type: event.StateSpaceParent, + StateKey: &parentIDStr, + Content: event.Content{ + Parsed: &event.SpaceParentEventContent{ + Canonical: true, + Via: []string{user.bridge.AS.HomeserverDomain}, + }, + }, + }) } resp, err := user.bridge.Bot.CreateRoom(&mautrix.ReqCreateRoom{ - Visibility: "private", - Name: "Discord", - Topic: "Your Discord bridged chats", - InitialState: []*event.Event{{ - Type: event.StateRoomAvatar, - Content: event.Content{ - Parsed: &event.RoomAvatarEventContent{ - URL: user.bridge.Config.AppService.Bot.ParsedAvatar, - }, - }, - }}, + Visibility: "private", + Name: name, + Topic: topic, + InitialState: initialState, CreationContent: map[string]interface{}{ "type": event.RoomTypeSpace, }, @@ -271,16 +288,34 @@ func (user *User) GetSpaceRoom() id.RoomID { if err != nil { user.log.Errorln("Failed to auto-create space room:", err) } else { - user.SpaceRoom = resp.RoomID + *ptr = resp.RoomID user.Update() - user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) - } - } else if !user.spaceMembershipChecked && !user.bridge.StateStore.IsInRoom(user.SpaceRoom, user.MXID) { - user.ensureInvited(user.bridge.Bot, user.SpaceRoom, false) - } - user.spaceMembershipChecked = true + user.ensureInvited(user.bridge.Bot, *ptr, false) - return user.SpaceRoom + if parent != "" { + _, err = user.bridge.Bot.SendStateEvent(parent, event.StateSpaceChild, resp.RoomID.String(), &event.SpaceChildEventContent{ + Via: []string{user.bridge.AS.HomeserverDomain}, + Order: " 0000", + }) + if err != nil { + user.log.Errorfln("Failed to add space room %s to parent space %s: %v", resp.RoomID, parent, err) + } + } + } + } else if !*checked && !user.bridge.StateStore.IsInRoom(*ptr, user.MXID) { + user.ensureInvited(user.bridge.Bot, *ptr, false) + } + *checked = true + + return *ptr +} + +func (user *User) GetSpaceRoom() id.RoomID { + return user.getSpaceRoom(&user.SpaceRoom, &user.spaceMembershipChecked, "Discord", "Your Discord bridged chats", "") +} + +func (user *User) GetDMSpaceRoom() id.RoomID { + return user.getSpaceRoom(&user.DMSpaceRoom, &user.dmSpaceMembershipChecked, "Direct Messages", "Your Discord direct messages", user.GetSpaceRoom()) } func (user *User) tryAutomaticDoublePuppeting() { @@ -466,26 +501,46 @@ func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) { } updateTS := time.Now() - guildsInSpace := make(map[string]bool) - for _, guild := range user.GetGuilds() { - guildsInSpace[guild.GuildID] = guild.InSpace + portalsInSpace := make(map[string]bool) + for _, guild := range user.GetPortals() { + portalsInSpace[guild.DiscordID] = guild.InSpace } for _, guild := range r.Guilds { - user.handleGuild(guild, updateTS, guildsInSpace[guild.ID]) + user.handleGuild(guild, updateTS, portalsInSpace[guild.ID]) } - user.PruneGuildList(updateTS) + user.PrunePortalList(updateTS) const maxCreate = 5 for i, ch := range r.PrivateChannels { portal := user.GetPortalByMeta(ch) - if i < maxCreate && portal.MXID == "" { - err := portal.CreateMatrixRoom(user, ch) - if err != nil { - user.log.Errorfln("Failed to create portal for private channel %s in initial sync: %v", ch.ID, err) - } + user.handlePrivateChannel(portal, ch, updateTS, i < maxCreate, portalsInSpace[portal.Key.String()]) + } +} + +func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel, timestamp time.Time, create, isInSpace bool) { + if create && portal.MXID == "" { + err := portal.CreateMatrixRoom(user, meta) + if err != nil { + user.log.Errorfln("Failed to create portal for private channel %s in initial sync: %v", meta.ID, err) + } + } else { + portal.UpdateInfo(user, meta) + } + if len(portal.MXID) > 0 && !isInSpace { + _, err := user.bridge.Bot.SendStateEvent(user.GetDMSpaceRoom(), event.StateSpaceChild, portal.MXID.String(), &event.SpaceChildEventContent{ + Via: []string{user.bridge.AS.HomeserverDomain}, + }) + if err != nil { + user.log.Errorfln("Failed to add DM room %s to user DM space: %v", portal.MXID, err) } else { - portal.UpdateInfo(user, ch) + isInSpace = true } } + user.MarkInPortal(database.UserPortal{ + DiscordID: meta.ID, + Type: database.UserPortalTypeDM, + Timestamp: timestamp, + InSpace: isInSpace, + }) } func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSpace bool) { @@ -514,7 +569,12 @@ func (user *User) handleGuild(meta *discordgo.Guild, timestamp time.Time, isInSp isInSpace = true } } - user.MarkInGuild(database.UserGuild{GuildID: meta.ID, Timestamp: timestamp, InSpace: isInSpace}) + user.MarkInPortal(database.UserPortal{ + DiscordID: meta.ID, + Type: database.UserPortalTypeGuild, + Timestamp: timestamp, + InSpace: isInSpace, + }) } func (user *User) connectedHandler(_ *discordgo.Session, c *discordgo.Connect) { @@ -532,7 +592,7 @@ func (user *User) guildCreateHandler(_ *discordgo.Session, g *discordgo.GuildCre } func (user *User) guildDeleteHandler(_ *discordgo.Session, g *discordgo.GuildDelete) { - user.MarkNotInGuild(g.ID) + user.MarkNotInPortal(g.ID) guild := user.bridge.GetGuildByID(g.ID, false) if guild == nil || guild.MXID == "" { return @@ -552,9 +612,13 @@ func (user *User) channelCreateHandler(_ *discordgo.Session, c *discordgo.Channe if portal.MXID != "" { return } - err := portal.CreateMatrixRoom(user, c.Channel) - if err != nil { - user.log.Errorfln("Error creating Matrix room for %s on channel create event: %v", c.ID, err) + if c.GuildID == "" { + user.handlePrivateChannel(portal, c.Channel, time.Now(), true, user.IsInSpace(portal.Key.String())) + } else { + err := portal.CreateMatrixRoom(user, c.Channel) + if err != nil { + user.log.Errorfln("Error creating Matrix room for %s on channel create event: %v", c.ID, err) + } } } @@ -568,7 +632,11 @@ func (user *User) channelPinsUpdateHandler(_ *discordgo.Session, c *discordgo.Ch func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.ChannelUpdate) { portal := user.GetPortalByMeta(c.Channel) - portal.UpdateInfo(user, c.Channel) + if c.GuildID == "" { + user.handlePrivateChannel(portal, c.Channel, time.Now(), true, user.IsInSpace(portal.Key.String())) + } else { + portal.UpdateInfo(user, c.Channel) + } } func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) {