From 928a49b42d1d659d2da8ddde7f1cd0b836225625 Mon Sep 17 00:00:00 2001 From: Gary Kramlich Date: Tue, 29 Mar 2022 07:10:16 -0500 Subject: [PATCH] Keep track of guild states. We don't currently bridge anything, but this just keeps everything in sync with the discord server. --- bridge/commands.go | 22 +++++ bridge/user.go | 129 ++++++++++++++++++++++++++++-- database/database.go | 6 ++ database/guild.go | 55 +++++++++++++ database/guildquery.go | 83 +++++++++++++++++++ database/migrations/07-guilds.sql | 7 ++ database/migrations/migrations.go | 1 + 7 files changed, 295 insertions(+), 8 deletions(-) create mode 100644 database/guild.go create mode 100644 database/guildquery.go create mode 100644 database/migrations/07-guilds.sql diff --git a/bridge/commands.go b/bridge/commands.go index 8605d1b..679ad86 100644 --- a/bridge/commands.go +++ b/bridge/commands.go @@ -53,6 +53,8 @@ type commands struct { Reconnect reconnectCmd `kong:"cmd,help='Reconnect to Discord'"` Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"` + Guilds guildsCmd `kong:"cmd,help='Guild bridging management.'"` + LoginMatrix loginMatrixCmd `kong:"cmd,help='Replace the puppet for your Discord account with your real Matrix account.'"` LogoutMatrix logoutMatrixCmd `kong:"cmd,help='Switch the puppet for your Discord account back to the default one.'"` PingMatrix pingMatrixCmd `kong:"cmd,help='check if your double puppet is working properly'"` @@ -286,3 +288,23 @@ func (m *pingMatrixCmd) Run(g *globals) error { return nil } + +/////////////////////////////////////////////////////////////////////////////// +// Guilds Commands +/////////////////////////////////////////////////////////////////////////////// +type guildsCmd struct { + Status guildStatusCmd `kong:"cmd,help='Show the bridge status for the guilds you are in'"` +} + +type guildStatusCmd struct{} + +func (c *guildStatusCmd) Run(g *globals) error { + g.user.guildsLock.Lock() + defer g.user.guildsLock.Unlock() + + for _, guild := range g.user.guilds { + fmt.Fprintf(g.context.Stdout, "%s %s %t\n", guild.GuildName, guild.GuildID, guild.Bridge) + } + + return nil +} diff --git a/bridge/user.go b/bridge/user.go index 21806d7..d259690 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -32,9 +32,20 @@ type User struct { bridge *Bridge log log.Logger + guilds map[string]*database.Guild + guildsLock sync.Mutex + Session *discordgo.Session } +// this assume you are holding the guilds lock!!! +func (u *User) loadGuilds() { + u.guilds = map[string]*database.Guild{} + for _, guild := range u.bridge.db.Guild.GetAll(u.ID) { + u.guilds[guild.GuildID] = guild + } +} + func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { // If we weren't passed in a user we attempt to create one if we were given // a matrix id. @@ -63,6 +74,11 @@ func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { b.managementRoomsLock.Unlock() } + // Load our guilds state from the database and turn it into a map + user.guildsLock.Lock() + user.loadGuilds() + user.guildsLock.Unlock() + return user } @@ -97,6 +113,7 @@ func (b *Bridge) NewUser(dbUser *database.User) *User { User: dbUser, bridge: b, log: b.log.Sub("User").Sub(string(dbUser.MXID)), + guilds: map[string]*database.Guild{}, } return user @@ -315,18 +332,15 @@ func (u *User) Connect() error { u.Session = session - // get our user info - user, err := u.Session.User("@me") - if err != nil { - return err - } - - u.User.ID = user.ID - // Add our event handlers + u.Session.AddHandler(u.readyHandler) u.Session.AddHandler(u.connectedHandler) u.Session.AddHandler(u.disconnectedHandler) + u.Session.AddHandler(u.guildCreateHandler) + u.Session.AddHandler(u.guildDeleteHandler) + u.Session.AddHandler(u.guildUpdateHandler) + u.Session.AddHandler(u.channelCreateHandler) u.Session.AddHandler(u.channelDeleteHandler) u.Session.AddHandler(u.channelPinsUpdateHandler) @@ -360,6 +374,59 @@ func (u *User) Disconnect() error { return nil } +func (u *User) readyHandler(s *discordgo.Session, r *discordgo.Ready) { + u.log.Debugln("discord connection ready") + + // Update our user fields + u.ID = r.User.ID + + // Update our guild map to match watch discord thinks we're in. This is the + // only time we can get the full guild map as discordgo doesn't make it + // available to us later. Also, discord might not give us the full guild + // information here, so we use this to remove guilds the user left and only + // add guilds whose full information we have. The are told about the + // "unavailable" guilds later via the GuildCreate handler. + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + // build a list of the current guilds we're in so we can prune the old ones + current := []string{} + + for _, guild := range r.Guilds { + current = append(current, guild.ID) + + // If we already know about this guild, make sure we reset it's bridge + // status. + if val, found := u.guilds[guild.ID]; found { + bridge := val.Bridge + u.guilds[guild.ID].Bridge = bridge + + // Update the name if the guild is available + if !guild.Unavailable { + u.guilds[guild.ID].GuildName = guild.Name + } + } else { + g := u.bridge.db.Guild.New() + g.DiscordID = u.ID + g.GuildID = guild.ID + u.guilds[guild.ID] = g + + if !guild.Unavailable { + g.GuildName = guild.Name + } + } + } + + // Sync the guilds to the database. + u.bridge.db.Guild.Prune(u.ID, current) + + // Finally reload from the database since it purged servers we're not in + // anymore. + u.loadGuilds() + + u.Update() +} + func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) { u.log.Debugln("connected to discord") @@ -370,6 +437,52 @@ func (u *User) disconnectedHandler(s *discordgo.Session, d *discordgo.Disconnect u.log.Debugln("disconnected from discord") } +func (u *User) guildCreateHandler(s *discordgo.Session, g *discordgo.GuildCreate) { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + // If we somehow already know about the guild, just update it's name + if guild, found := u.guilds[g.ID]; found { + guild.GuildName = g.Name + guild.Upsert() + + return + } + + // This is a brand new guild so lets get it added. + guild := u.bridge.db.Guild.New() + guild.DiscordID = u.ID + guild.GuildID = g.ID + guild.GuildName = g.Name + guild.Upsert() + + u.guilds[g.ID] = guild +} + +func (u *User) guildDeleteHandler(s *discordgo.Session, g *discordgo.GuildDelete) { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + if guild, found := u.guilds[g.ID]; found { + guild.Delete() + delete(u.guilds, g.ID) + u.log.Debugln("deleted guild", g.Guild.ID) + } +} + +func (u *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate) { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + // If we somehow already know about the guild, just update it's name + if guild, found := u.guilds[g.ID]; found { + guild.GuildName = g.Name + guild.Upsert() + + u.log.Debugln("updated guild", g.ID) + } +} + func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { key := database.NewPortalKey(c.ID, u.User.ID) portal := u.bridge.GetPortalByID(key) diff --git a/database/database.go b/database/database.go index 8fa7664..8c31441 100644 --- a/database/database.go +++ b/database/database.go @@ -23,6 +23,7 @@ type Database struct { Reaction *ReactionQuery Attachment *AttachmentQuery Emoji *EmojiQuery + Guild *GuildQuery } func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) (*Database, error) { @@ -85,5 +86,10 @@ func New(dbType, uri string, maxOpenConns, maxIdleConns int, baseLog log.Logger) log: db.log.Sub("Emoji"), } + db.Guild = &GuildQuery{ + db: db, + log: db.log.Sub("Guild"), + } + return db, nil } diff --git a/database/guild.go b/database/guild.go new file mode 100644 index 0000000..3232a26 --- /dev/null +++ b/database/guild.go @@ -0,0 +1,55 @@ +package database + +import ( + "database/sql" + "errors" + + log "maunium.net/go/maulogger/v2" +) + +type Guild struct { + db *Database + log log.Logger + + DiscordID string + GuildID string + GuildName string + Bridge bool +} + +func (g *Guild) Scan(row Scannable) *Guild { + err := row.Scan(&g.DiscordID, &g.GuildID, &g.GuildName, &g.Bridge) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + g.log.Errorln("Database scan failed:", err) + } + + return nil + } + + return g +} + +func (g *Guild) Upsert() { + query := "INSERT INTO guild" + + " (discord_id, guild_id, guild_name, bridge)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT(discord_id, guild_id)" + + " DO UPDATE SET guild_name=excluded.guild_name, bridge=excluded.bridge" + + _, err := g.db.Exec(query, g.DiscordID, g.GuildID, g.GuildName, g.Bridge) + + if err != nil { + g.log.Warnfln("Failed to upsert guild %s for %s: %v", g.GuildID, g.DiscordID, err) + } +} + +func (g *Guild) Delete() { + query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id=$2" + + _, err := g.db.Exec(query, g.DiscordID, g.GuildID) + + if err != nil { + g.log.Warnfln("Failed to delete guild %s for user %s: %v", g.GuildID, g.DiscordID, err) + } +} diff --git a/database/guildquery.go b/database/guildquery.go new file mode 100644 index 0000000..a8faa66 --- /dev/null +++ b/database/guildquery.go @@ -0,0 +1,83 @@ +package database + +import ( + "fmt" + + log "maunium.net/go/maulogger/v2" +) + +type GuildQuery struct { + db *Database + log log.Logger +} + +const ( + guildSelect = "SELECT discord_id, guild_id, guild_name, bridge FROM guild" +) + +func (gq *GuildQuery) New() *Guild { + return &Guild{ + db: gq.db, + log: gq.log, + } +} + +func (gq *GuildQuery) Get(discordID, guildID string) *Guild { + query := guildSelect + " WHERE discord_id=$1 AND guild_id=$2" + + row := gq.db.QueryRow(query, discordID, guildID) + if row == nil { + return nil + } + + return gq.New().Scan(row) +} + +func (gq *GuildQuery) GetAll(discordID string) []*Guild { + query := guildSelect + " WHERE discord_id=$1" + + rows, err := gq.db.Query(query, discordID) + if err != nil || rows == nil { + return nil + } + + guilds := []*Guild{} + for rows.Next() { + guilds = append(guilds, gq.New().Scan(rows)) + } + + return guilds +} + +func (gq *GuildQuery) Prune(discordID string, guilds []string) { + // We need this interface slice because a variadic function can't mix + // arguements with a `...` expanded slice. + args := []interface{}{discordID} + + nGuilds := len(guilds) + if nGuilds <= 0 { + return + } + + gq.log.Debugfln("prunning guilds not in %#v for %s", guilds, discordID) + + // Build the in query + inQuery := "$2" + for i := 1; i < nGuilds; i++ { + inQuery += fmt.Sprintf(", $%d", i+2) + } + + // Add the arguements for the build query + for _, guildID := range guilds { + args = append(args, guildID) + } + + // Now remove any guilds that the user has left. + query := "DELETE FROM guild WHERE discord_id=$1 AND guild_id NOT IN (" + + inQuery + ")" + + _, err := gq.db.Exec(query, args...) + if err != nil { + gq.log.Warnfln("Failed to remove old guilds for user %s: %v", discordID, err) + } +} diff --git a/database/migrations/07-guilds.sql b/database/migrations/07-guilds.sql new file mode 100644 index 0000000..25a868e --- /dev/null +++ b/database/migrations/07-guilds.sql @@ -0,0 +1,7 @@ +CREATE TABLE guild ( + discord_id TEXT NOT NULL, + guild_id TEXT NOT NULL, + guild_name TEXT NOT NULL, + bridge INTEGER(1) DEFAULT FALSE, + PRIMARY KEY(discord_id, guild_id) +); diff --git a/database/migrations/migrations.go b/database/migrations/migrations.go index 17fdbd1..06fcce4 100644 --- a/database/migrations/migrations.go +++ b/database/migrations/migrations.go @@ -19,6 +19,7 @@ var ( "03-emoji.sql", "04-custom-puppet.sql", "05-additional-puppet-fields.sql", + "07-guilds.sql", } sqliteMigrations = []string{