From b66556ad994a77804d64f2dff3871f32494077f4 Mon Sep 17 00:00:00 2001 From: Gary Kramlich Date: Thu, 14 Apr 2022 12:03:44 -0500 Subject: [PATCH] Add provisioning api for bridging guilds This isn't exactly the same as the puppet bridge, basically it removes the puppetID from the url and just works against the logged in user. There is a known errata right now where some times all channels aren't left when unbridging a guild. I figured it's more important to get some more testing on this while I try to figure that out in the mean time. Also when you call /guilds/:guildID/joinentire, it joins each channel serially. I tried to make this concurrent but I maxed out the rate limit. We can do this in the future, but again, rather get this into everyone's hands sooner rather than later. I also did the same thing with unbridge and had to revert for the same reasons. Refs #8 --- bridge/discord.go | 16 +++++++++ bridge/portal.go | 9 +++++ bridge/provisioning.go | 75 ++++++++++++++++++++++++++++++++++++++++ bridge/user.go | 78 ++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 bridge/discord.go diff --git a/bridge/discord.go b/bridge/discord.go new file mode 100644 index 0000000..2322b81 --- /dev/null +++ b/bridge/discord.go @@ -0,0 +1,16 @@ +package bridge + +import ( + "github.com/bwmarrin/discordgo" +) + +func channelIsBridgeable(channel *discordgo.Channel) bool { + switch channel.Type { + case discordgo.ChannelTypeGuildText: + fallthrough + case discordgo.ChannelTypeGuildNews: + return true + } + + return false +} diff --git a/bridge/portal.go b/bridge/portal.go index ccfa3cc..860acd8 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -659,6 +659,15 @@ func (p *Portal) handleMatrixLeave(sender *User) { p.cleanupIfEmpty() } +func (p *Portal) leave(sender *User) { + if p.MXID == "" { + return + } + + intent := p.bridge.GetPuppetByID(sender.ID).IntentFor(p) + intent.LeaveRoom(p.MXID) +} + func (p *Portal) delete() { p.Portal.Delete() p.bridge.portalsLock.Lock() diff --git a/bridge/provisioning.go b/bridge/provisioning.go index 02aaf31..46f95f9 100644 --- a/bridge/provisioning.go +++ b/bridge/provisioning.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gorilla/mux" "github.com/gorilla/websocket" log "maunium.net/go/maulogger/v2" @@ -48,6 +49,12 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI { r.HandleFunc("/logout", p.logout).Methods(http.MethodPost) r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost) + // Setup the guild endpoints + r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet) + r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost) + r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost) + r.HandleFunc("/guilds/{guildID}/joinentire", p.guildsJoinEntire).Methods(http.MethodPost) + return p } @@ -381,3 +388,71 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { }) } } + +func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + user.guildsLock.Lock() + defer user.guildsLock.Unlock() + + data := make([]map[string]interface{}, len(user.guilds)) + idx := 0 + for _, guild := range user.guilds { + data[idx] = map[string]interface{}{ + "name": guild.GuildName, + "id": guild.GuildID, + "bridged": guild.Bridge, + } + + idx++ + } + + jsonResponse(w, http.StatusOK, data) +} + +func (p *ProvisioningAPI) guildsBridge(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.bridgeGuild(guildID, false); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + } else { + w.WriteHeader(http.StatusCreated) + } +} + +func (p *ProvisioningAPI) guildsUnbridge(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.unbridgeGuild(guildID); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (p *ProvisioningAPI) guildsJoinEntire(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.bridgeGuild(guildID, true); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + } else { + w.WriteHeader(http.StatusCreated) + } +} diff --git a/bridge/user.go b/bridge/user.go index 04a54d6..f34102e 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -503,7 +503,7 @@ func (u *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate } } -func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { +func (u *User) createChannel(c *discordgo.Channel) { key := database.NewPortalKey(c.ID, u.User.ID) portal := u.bridge.GetPortalByID(key) @@ -525,7 +525,11 @@ func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCr portal.Update() - portal.createMatrixRoom(u, c.Channel) + portal.createMatrixRoom(u, c) +} + +func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { + u.createChannel(c.Channel) } func (u *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) { @@ -738,3 +742,73 @@ func (u *User) updateDirectChats(chats map[id.UserID][]id.RoomID) { u.log.Warnln("Failed to update m.direct list:", err) } } + +func (u *User) bridgeGuild(guildID string, everything bool) error { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + guild, found := u.guilds[guildID] + if !found { + return fmt.Errorf("guildID not found") + } + + // Update the guild + guild.Bridge = true + guild.Upsert() + + // If this is a full bridge, create portals for all the channels + if everything { + channels, err := u.Session.GuildChannels(guildID) + if err != nil { + return err + } + + for _, channel := range channels { + if channelIsBridgeable(channel) { + u.createChannel(channel) + } + } + } + + return nil +} + +func (u *User) unbridgeGuild(guildID string) error { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + guild, exists := u.guilds[guildID] + if !exists { + return fmt.Errorf("guildID not found") + } + + if !guild.Bridge { + return fmt.Errorf("guild not bridged") + } + + // First update the guild so we don't have any other go routines recreating + // channels we're about to destroy. + guild.Bridge = false + guild.Upsert() + + // Now run through the channels in the guild and remove any portals we + // have for them. + channels, err := u.Session.GuildChannels(guildID) + if err != nil { + return err + } + + for _, channel := range channels { + if channelIsBridgeable(channel) { + key := database.PortalKey{ + ChannelID: channel.ID, + Receiver: u.ID, + } + + portal := u.bridge.GetPortalByID(key) + portal.leave(u) + } + } + + return nil +}