discordid,connector: remember which guilds were bridged

This commit is contained in:
Skip R
2026-02-05 21:05:04 -08:00
parent 869d8c5412
commit 4bdb0de559
3 changed files with 71 additions and 5 deletions

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"maps"
"net/http" "net/http"
"slices" "slices"
"sync" "sync"
@@ -276,8 +277,30 @@ func (d *DiscordClient) syncGuildSpace(_ context.Context, guild *discordgo.Guild
}) })
} }
// bridgedGuildIDs returns a set of guild IDs that should be bridged. Note that
// presence in the returned set does not imply that the rooms for the guild have
// already been created.
func (d *DiscordClient) bridgedGuildIDs() map[string]struct{} {
meta := d.UserLogin.Metadata.(*discordid.UserLoginMetadata)
bridgingGuildIDs := map[string]struct{}{}
// guilds that were bridged via the provisioning api
for guildID, bridged := range meta.BridgedGuildIDs {
if bridged {
bridgingGuildIDs[guildID] = struct{}{}
}
}
// guilds that were declared in the configuration file
for _, guildID := range d.connector.Config.Guilds.BridgingGuildIDs {
bridgingGuildIDs[guildID] = struct{}{}
}
return bridgingGuildIDs
}
func (d *DiscordClient) syncGuilds(ctx context.Context) { func (d *DiscordClient) syncGuilds(ctx context.Context) {
guildIDs := d.connector.Config.Guilds.BridgingGuildIDs guildIDs := slices.Sorted(maps.Keys(d.bridgedGuildIDs()))
for _, guildID := range guildIDs { for _, guildID := range guildIDs {
log := zerolog.Ctx(ctx).With(). log := zerolog.Ctx(ctx).With().
@@ -298,6 +321,7 @@ func (d *DiscordClient) bridgeGuild(ctx context.Context, guildID string) error {
guild, err := d.Session.State.Guild(guildID) guild, err := d.Session.State.Guild(guildID)
if errors.Is(err, discordgo.ErrStateNotFound) || guild == nil { if errors.Is(err, discordgo.ErrStateNotFound) || guild == nil {
log.Err(err).Msg("Couldn't find guild, user isn't a member?") log.Err(err).Msg("Couldn't find guild, user isn't a member?")
// TODO likely left/kicked/banned from guild; nuke the portals
return errors.New("couldn't find guild in state") return errors.New("couldn't find guild in state")
} }

View File

@@ -117,19 +117,33 @@ func (p *ProvisioningAPI) makeHandler(handler func(http.ResponseWriter, *http.Re
} }
func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) { func (p *ProvisioningAPI) guildsList(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) {
ctx := r.Context()
p.log.Info().Str("login_id", discordid.ParseUserLoginID(login.ID)).Msg("guilds list requested via provisioning api") p.log.Info().Str("login_id", discordid.ParseUserLoginID(login.ID)).Msg("guilds list requested via provisioning api")
var resp respGuildsList var resp respGuildsList
resp.Guilds = []guildEntry{} resp.Guilds = []guildEntry{}
for _, guild := range client.Session.State.Guilds { for _, guild := range client.Session.State.Guilds {
portalKey := client.guildPortalKeyFromID(guild.ID)
portal, err := p.connector.Bridge.GetExistingPortalByKey(ctx, portalKey)
if err != nil {
p.log.Err(err).
Str("guild_id", guild.ID).
Msg("Failed to get guild portal for provisioning list")
}
mxid := ""
if portal != nil {
mxid = portal.MXID.String()
}
resp.Guilds = append(resp.Guilds, guildEntry{ resp.Guilds = append(resp.Guilds, guildEntry{
ID: guild.ID, ID: guild.ID,
Name: guild.Name, Name: guild.Name,
AvatarURL: discordgo.EndpointGuildIcon(guild.ID, guild.Icon), AvatarURL: discordgo.EndpointGuildIcon(guild.ID, guild.Icon),
MXID: mxid,
BridgingMode: "everything", BridgingMode: "everything",
Available: !guild.Unavailable,
Available: !guild.Unavailable,
}) })
} }
@@ -148,10 +162,27 @@ func (p *ProvisioningAPI) bridgeGuild(w http.ResponseWriter, r *http.Request, lo
Str("guild_id", guildID). Str("guild_id", guildID).
Msg("requested to bridge guild via provisioning api") Msg("requested to bridge guild via provisioning api")
// TODO detect guild already bridged meta := login.Metadata.(*discordid.UserLoginMetadata)
if meta.BridgedGuildIDs == nil {
meta.BridgedGuildIDs = map[string]bool{}
}
_, alreadyBridged := meta.BridgedGuildIDs[guildID]
meta.BridgedGuildIDs[guildID] = true
if err := login.Save(r.Context()); err != nil {
p.log.Err(err).Msg("Failed to save login after guild bridge request")
mautrix.MUnknown.WithMessage("failed to save login: %v", err).Write(w)
return
}
go client.bridgeGuild(context.TODO(), guildID) go client.bridgeGuild(context.TODO(), guildID)
exhttp.WriteJSONResponse(w, 201, nil) responseStatus := 201
if alreadyBridged {
responseStatus = 200
}
exhttp.WriteJSONResponse(w, responseStatus, nil)
} }
func (p *ProvisioningAPI) unbridgeGuild(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) { func (p *ProvisioningAPI) unbridgeGuild(w http.ResponseWriter, r *http.Request, login *bridgev2.UserLogin, client *DiscordClient) {
@@ -166,6 +197,16 @@ func (p *ProvisioningAPI) unbridgeGuild(w http.ResponseWriter, r *http.Request,
Str("guild_id", guildID). Str("guild_id", guildID).
Msg("requested to unbridge guild via provisioning api") Msg("requested to unbridge guild via provisioning api")
meta := login.Metadata.(*discordid.UserLoginMetadata)
if meta.BridgedGuildIDs != nil {
delete(meta.BridgedGuildIDs, guildID)
}
if err := login.Save(r.Context()); err != nil {
p.log.Err(err).Msg("Failed to save login after guild unbridge request")
mautrix.MUnknown.WithMessage("failed to save login: %v", err).Write(w)
return
}
ctx := context.TODO() ctx := context.TODO()
portalKey := client.guildPortalKeyFromID(guildID) portalKey := client.guildPortalKeyFromID(guildID)

View File

@@ -30,4 +30,5 @@ type PortalMetadata struct {
type UserLoginMetadata struct { type UserLoginMetadata struct {
Token string `json:"token"` Token string `json:"token"`
HeartbeatSession discordgo.HeartbeatSession `json:"heartbeat_session"` HeartbeatSession discordgo.HeartbeatSession `json:"heartbeat_session"`
BridgedGuildIDs map[string]bool `json:"bridged_guild_ids,omitempty"`
} }