diff --git a/bridge/discord.go b/bridge/discord.go new file mode 100644 index 0000000..2322b81 --- /dev/null +++ b/bridge/discord.go @@ -0,0 +1,16 @@ +package bridge + +import ( + "github.com/bwmarrin/discordgo" +) + +func channelIsBridgeable(channel *discordgo.Channel) bool { + switch channel.Type { + case discordgo.ChannelTypeGuildText: + fallthrough + case discordgo.ChannelTypeGuildNews: + return true + } + + return false +} diff --git a/bridge/portal.go b/bridge/portal.go index ccfa3cc..860acd8 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -659,6 +659,15 @@ func (p *Portal) handleMatrixLeave(sender *User) { p.cleanupIfEmpty() } +func (p *Portal) leave(sender *User) { + if p.MXID == "" { + return + } + + intent := p.bridge.GetPuppetByID(sender.ID).IntentFor(p) + intent.LeaveRoom(p.MXID) +} + func (p *Portal) delete() { p.Portal.Delete() p.bridge.portalsLock.Lock() diff --git a/bridge/provisioning.go b/bridge/provisioning.go index 02aaf31..46f95f9 100644 --- a/bridge/provisioning.go +++ b/bridge/provisioning.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gorilla/mux" "github.com/gorilla/websocket" log "maunium.net/go/maulogger/v2" @@ -48,6 +49,12 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI { r.HandleFunc("/logout", p.logout).Methods(http.MethodPost) r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost) + // Setup the guild endpoints + r.HandleFunc("/guilds", p.guildsList).Methods(http.MethodGet) + r.HandleFunc("/guilds/{guildID}/bridge", p.guildsBridge).Methods(http.MethodPost) + r.HandleFunc("/guilds/{guildID}/unbridge", p.guildsUnbridge).Methods(http.MethodPost) + r.HandleFunc("/guilds/{guildID}/joinentire", p.guildsJoinEntire).Methods(http.MethodPost) + return p } @@ -381,3 +388,71 @@ func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { }) } } + +func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + user.guildsLock.Lock() + defer user.guildsLock.Unlock() + + data := make([]map[string]interface{}, len(user.guilds)) + idx := 0 + for _, guild := range user.guilds { + data[idx] = map[string]interface{}{ + "name": guild.GuildName, + "id": guild.GuildID, + "bridged": guild.Bridge, + } + + idx++ + } + + jsonResponse(w, http.StatusOK, data) +} + +func (p *ProvisioningAPI) guildsBridge(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.bridgeGuild(guildID, false); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + } else { + w.WriteHeader(http.StatusCreated) + } +} + +func (p *ProvisioningAPI) guildsUnbridge(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.unbridgeGuild(guildID); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (p *ProvisioningAPI) guildsJoinEntire(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + guildID, _ := mux.Vars(r)["guildID"] + + if err := user.bridgeGuild(guildID, true); err != nil { + jsonResponse(w, http.StatusNotFound, Error{ + Error: err.Error(), + ErrCode: "M_NOT_FOUND", + }) + } else { + w.WriteHeader(http.StatusCreated) + } +} diff --git a/bridge/user.go b/bridge/user.go index 04a54d6..f34102e 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -503,7 +503,7 @@ func (u *User) guildUpdateHandler(s *discordgo.Session, g *discordgo.GuildUpdate } } -func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { +func (u *User) createChannel(c *discordgo.Channel) { key := database.NewPortalKey(c.ID, u.User.ID) portal := u.bridge.GetPortalByID(key) @@ -525,7 +525,11 @@ func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCr portal.Update() - portal.createMatrixRoom(u, c.Channel) + portal.createMatrixRoom(u, c) +} + +func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) { + u.createChannel(c.Channel) } func (u *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) { @@ -738,3 +742,73 @@ func (u *User) updateDirectChats(chats map[id.UserID][]id.RoomID) { u.log.Warnln("Failed to update m.direct list:", err) } } + +func (u *User) bridgeGuild(guildID string, everything bool) error { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + guild, found := u.guilds[guildID] + if !found { + return fmt.Errorf("guildID not found") + } + + // Update the guild + guild.Bridge = true + guild.Upsert() + + // If this is a full bridge, create portals for all the channels + if everything { + channels, err := u.Session.GuildChannels(guildID) + if err != nil { + return err + } + + for _, channel := range channels { + if channelIsBridgeable(channel) { + u.createChannel(channel) + } + } + } + + return nil +} + +func (u *User) unbridgeGuild(guildID string) error { + u.guildsLock.Lock() + defer u.guildsLock.Unlock() + + guild, exists := u.guilds[guildID] + if !exists { + return fmt.Errorf("guildID not found") + } + + if !guild.Bridge { + return fmt.Errorf("guild not bridged") + } + + // First update the guild so we don't have any other go routines recreating + // channels we're about to destroy. + guild.Bridge = false + guild.Upsert() + + // Now run through the channels in the guild and remove any portals we + // have for them. + channels, err := u.Session.GuildChannels(guildID) + if err != nil { + return err + } + + for _, channel := range channels { + if channelIsBridgeable(channel) { + key := database.PortalKey{ + ChannelID: channel.ID, + Receiver: u.ID, + } + + portal := u.bridge.GetPortalByID(key) + portal.leave(u) + } + } + + return nil +}