From 3f3c86754def6d00776d88432cf2fc72a4c77c49 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 22 Apr 2023 02:50:14 +0300 Subject: [PATCH] Bridge friend nicks as DM room name --- database/portal.go | 47 ++++++++------ database/upgrades/00-latest-revision.sql | 3 +- .../upgrades/17-dm-portal-friend-nick.sql | 2 + go.mod | 2 +- go.sum | 4 +- portal.go | 34 +++++++++-- puppet.go | 2 +- user.go | 61 +++++++++++++++++++ 8 files changed, 124 insertions(+), 31 deletions(-) create mode 100644 database/upgrades/17-dm-portal-friend-nick.sql diff --git a/database/portal.go b/database/portal.go index ba6bccd..995101f 100644 --- a/database/portal.go +++ b/database/portal.go @@ -15,7 +15,7 @@ import ( const ( portalSelect = ` SELECT dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid, - plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, + plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret FROM portal ` @@ -68,6 +68,10 @@ func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { return pq.get(portalSelect+" WHERE mxid=$1", mxid) } +func (pq *PortalQuery) FindPrivateChatBetween(id, receiver string) *Portal { + return pq.get(portalSelect+" WHERE other_user_id=$1 AND receiver=$2 AND type=$3", id, receiver, discordgo.ChannelTypeDM) +} + func (pq *PortalQuery) FindPrivateChatsWith(id string) []*Portal { return pq.getAll(portalSelect+" WHERE other_user_id=$1 AND type=$2", id, discordgo.ChannelTypeDM) } @@ -109,16 +113,17 @@ type Portal struct { MXID id.RoomID - PlainName string - Name string - NameSet bool - Topic string - TopicSet bool - Avatar string - AvatarURL id.ContentURI - AvatarSet bool - Encrypted bool - InSpace id.RoomID + PlainName string + Name string + NameSet bool + FriendNick bool + Topic string + TopicSet bool + Avatar string + AvatarURL id.ContentURI + AvatarSet bool + Encrypted bool + InSpace id.RoomID FirstEventID id.EventID @@ -132,7 +137,7 @@ func (p *Portal) Scan(row dbutil.Scannable) *Portal { var avatarURL string err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &chanType, &otherUserID, &guildID, &parentID, - &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet, + &mxid, &p.PlainName, &p.Name, &p.NameSet, &p.FriendNick, &p.Topic, &p.TopicSet, &p.Avatar, &avatarURL, &p.AvatarSet, &p.Encrypted, &p.InSpace, &firstEventID, &relayWebhookID, &relayWebhookSecret) if err != nil { @@ -160,13 +165,13 @@ func (p *Portal) Scan(row dbutil.Scannable) *Portal { func (p *Portal) Insert() { query := ` INSERT INTO portal (dcid, receiver, type, other_user_id, dc_guild_id, dc_parent_id, mxid, - plain_name, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, + plain_name, name, name_set, friend_nick, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, in_space, first_event_id, relay_webhook_id, relay_webhook_secret) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) ` _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), - p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, + p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret)) if err != nil { @@ -179,14 +184,16 @@ func (p *Portal) Update() { query := ` UPDATE portal SET type=$1, other_user_id=$2, dc_guild_id=$3, dc_parent_id=$4, mxid=$5, - plain_name=$6, name=$7, name_set=$8, topic=$9, topic_set=$10, avatar=$11, avatar_url=$12, avatar_set=$13, - encrypted=$14, in_space=$15, first_event_id=$16, relay_webhook_id=$17, relay_webhook_secret=$18 - WHERE dcid=$19 AND receiver=$20 + plain_name=$6, name=$7, name_set=$8, friend_nick=$9, topic=$10, topic_set=$11, + avatar=$12, avatar_url=$13, avatar_set=$14, encrypted=$15, in_space=$16, first_event_id=$17, + relay_webhook_id=$18, relay_webhook_secret=$19 + WHERE dcid=$20 AND receiver=$21 ` _, err := p.db.Exec(query, p.Type, strPtr(p.OtherUserID), strPtr(p.GuildID), strPtr(p.ParentID), strPtr(string(p.MXID)), - p.PlainName, p.Name, p.NameSet, p.Topic, p.TopicSet, p.Avatar, p.AvatarURL.String(), p.AvatarSet, - p.Encrypted, p.InSpace, p.FirstEventID.String(), strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret), + p.PlainName, p.Name, p.NameSet, p.FriendNick, p.Topic, p.TopicSet, + p.Avatar, p.AvatarURL.String(), p.AvatarSet, p.Encrypted, p.InSpace, p.FirstEventID.String(), + strPtr(p.RelayWebhookID), strPtr(p.RelayWebhookSecret), p.Key.ChannelID, p.Key.Receiver) if err != nil { diff --git a/database/upgrades/00-latest-revision.sql b/database/upgrades/00-latest-revision.sql index df9837e..868020c 100644 --- a/database/upgrades/00-latest-revision.sql +++ b/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v16: Latest revision +-- v0 -> v17: Latest revision CREATE TABLE guild ( dcid TEXT PRIMARY KEY, @@ -29,6 +29,7 @@ CREATE TABLE portal ( plain_name TEXT NOT NULL, name TEXT NOT NULL, name_set BOOLEAN NOT NULL, + friend_nick BOOLEAN NOT NULL, topic TEXT NOT NULL, topic_set BOOLEAN NOT NULL, avatar TEXT NOT NULL, diff --git a/database/upgrades/17-dm-portal-friend-nick.sql b/database/upgrades/17-dm-portal-friend-nick.sql new file mode 100644 index 0000000..2c2b43c --- /dev/null +++ b/database/upgrades/17-dm-portal-friend-nick.sql @@ -0,0 +1,2 @@ +-- v17: Store whether DM portal name is a friend nickname +ALTER TABLE portal ADD COLUMN friend_nick BOOLEAN NOT NULL DEFAULT false; diff --git a/go.mod b/go.mod index d64c6d1..bbe0942 100644 --- a/go.mod +++ b/go.mod @@ -37,4 +37,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230421223629-940c512c92de +replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230421234845-798d8c4943a5 diff --git a/go.sum b/go.sum index c943467..bdd5eb4 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/beeper/discordgo v0.0.0-20230421223629-940c512c92de h1:jq5xgpkSvFJiiXH8w0SWGjK6jzwU8gZvrNahAq24nyI= -github.com/beeper/discordgo v0.0.0-20230421223629-940c512c92de/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/beeper/discordgo v0.0.0-20230421234845-798d8c4943a5 h1:goe2lr+0qYbVtdidyl3UMurIGdFnXXZqmNq0vWZo4f4= +github.com/beeper/discordgo v0.0.0-20230421234845-798d8c4943a5/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/portal.go b/portal.go index 47b36bb..7e1741f 100644 --- a/portal.go +++ b/portal.go @@ -165,6 +165,20 @@ func (user *User) GetPortalByID(id string, chanType discordgo.ChannelType) *Port return user.bridge.GetPortalByID(database.NewPortalKey(id, user.DiscordID), chanType) } +func (user *User) FindPrivateChatWith(userID string) *Portal { + user.bridge.portalsLock.Lock() + defer user.bridge.portalsLock.Unlock() + dbPortal := user.bridge.DB.Portal.FindPrivateChatBetween(userID, user.DiscordID) + if dbPortal == nil { + return nil + } + existing, ok := user.bridge.portalsByID[dbPortal.Key] + if ok { + return existing + } + return user.bridge.loadPortal(dbPortal, nil, discordgo.ChannelTypeDM) +} + func (br *DiscordBridge) GetExistingPortalByID(key database.PortalKey) *Portal { br.portalsLock.Lock() defer br.portalsLock.Unlock() @@ -466,7 +480,7 @@ func (portal *Portal) CreateMatrixRoom(user *User, channel *discordgo.Channel) e InitialState: initialState, CreationContent: creationContent, } - if !portal.shouldSetDMRoomMetadata() { + if !portal.shouldSetDMRoomMetadata() && !portal.FriendNick { req.Name = "" } @@ -1854,11 +1868,13 @@ func (portal *Portal) UpdateName(meta *discordgo.Channel) bool { GuildName: guildName, NSFW: meta.NSFW, Type: meta.Type, - })) || plainNameChanged + }), false) || plainNameChanged } -func (portal *Portal) UpdateNameDirect(name string) bool { - if portal.Name == name && (portal.NameSet || portal.MXID == "" || !portal.shouldSetDMRoomMetadata()) { +func (portal *Portal) UpdateNameDirect(name string, isFriendNick bool) bool { + if portal.FriendNick && !isFriendNick { + return false + } else if portal.Name == name && (portal.NameSet || portal.MXID == "" || (!portal.shouldSetDMRoomMetadata() && !isFriendNick)) { return false } portal.log.Debugfln("Updating name %q -> %q", portal.Name, name) @@ -1869,7 +1885,7 @@ func (portal *Portal) UpdateNameDirect(name string) bool { } func (portal *Portal) updateRoomName() { - if portal.MXID != "" && portal.shouldSetDMRoomMetadata() { + if portal.MXID != "" && (portal.shouldSetDMRoomMetadata() || portal.FriendNick) { _, err := portal.MainIntent().SetRoomName(portal.MXID, portal.Name) if err != nil { portal.log.Warnln("Failed to update room name:", err) @@ -2073,7 +2089,13 @@ func (portal *Portal) UpdateInfo(source *User, meta *discordgo.Channel) *discord if portal.OtherUserID != "" { puppet := portal.bridge.GetPuppetByID(portal.OtherUserID) changed = portal.UpdateAvatarFromPuppet(puppet) || changed - changed = portal.UpdateNameDirect(puppet.Name) || changed + if rel, ok := source.relationships[portal.OtherUserID]; ok && rel.Nickname != "" { + portal.FriendNick = true + changed = portal.UpdateNameDirect(rel.Nickname, true) || changed + } else { + portal.FriendNick = false + changed = portal.UpdateNameDirect(puppet.Name, false) || changed + } } case discordgo.ChannelTypeGroupDM: changed = portal.UpdateGroupDMAvatar(meta.Icon) || changed diff --git a/puppet.go b/puppet.go index 62f9ce5..70fd894 100644 --- a/puppet.go +++ b/puppet.go @@ -205,7 +205,7 @@ func (puppet *Puppet) UpdateName(info *discordgo.User) bool { puppet.log.Warn().Err(err).Msg("Failed to update displayname") } else { go puppet.updatePortalMeta(func(portal *Portal) { - if portal.UpdateNameDirect(puppet.Name) { + if portal.UpdateNameDirect(puppet.Name, false) { portal.Update() portal.UpdateBridgeInfo() } diff --git a/user.go b/user.go index a4d7913..af46533 100644 --- a/user.go +++ b/user.go @@ -62,6 +62,8 @@ type User struct { pendingInteractionsLock sync.Mutex nextDiscordUploadID atomic.Int32 + + relationships map[string]*discordgo.Relationship } func (user *User) GetRemoteID() string { @@ -189,6 +191,8 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User { PermissionLevel: br.Config.Bridge.Permissions.Get(dbUser.MXID), pendingInteractions: make(map[string]*WrappedCommandEvent), + + relationships: make(map[string]*discordgo.Relationship), } user.nextDiscordUploadID.Store(rand.Int31n(100)) user.BridgeState = br.NewBridgeStateQueue(user) @@ -581,6 +585,10 @@ func (user *User) Connect() error { user.Session.AddHandler(user.channelPinsUpdateHandler) user.Session.AddHandler(user.channelUpdateHandler) + user.Session.AddHandler(user.relationshipAddHandler) + user.Session.AddHandler(user.relationshipRemoveHandler) + user.Session.AddHandler(user.relationshipUpdateHandler) + user.Session.AddHandler(user.messageCreateHandler) user.Session.AddHandler(user.messageDeleteHandler) user.Session.AddHandler(user.messageUpdateHandler) @@ -660,6 +668,10 @@ func (user *User) readyHandler(_ *discordgo.Session, r *discordgo.Ready) { user.BridgeState.Send(status.BridgeState{StateEvent: status.StateBackfilling}) user.tryAutomaticDoublePuppeting() + for _, relationship := range r.Relationships { + user.relationships[relationship.ID] = relationship + } + updateTS := time.Now() portalsInSpace := make(map[string]bool) for _, guild := range user.GetPortals() { @@ -738,6 +750,55 @@ func (user *User) addPrivateChannelToSpace(portal *Portal) bool { } } +func (user *User) relationshipAddHandler(_ *discordgo.Session, r *discordgo.RelationshipAdd) { + user.log.Debug().Interface("relationship", r.Relationship).Msg("Relationship added") + user.relationships[r.ID] = r.Relationship + user.handleRelationshipChange(r.ID, r.Nickname) +} + +func (user *User) relationshipUpdateHandler(_ *discordgo.Session, r *discordgo.RelationshipUpdate) { + user.log.Debug().Interface("relationship", r.Relationship).Msg("Relationship update") + user.relationships[r.ID] = r.Relationship + user.handleRelationshipChange(r.ID, r.Nickname) +} + +func (user *User) relationshipRemoveHandler(_ *discordgo.Session, r *discordgo.RelationshipRemove) { + user.log.Debug().Str("other_user_id", r.ID).Msg("Relationship removed") + delete(user.relationships, r.ID) + user.handleRelationshipChange(r.ID, "") +} + +func (user *User) handleRelationshipChange(userID, nickname string) { + puppet := user.bridge.GetPuppetByID(userID) + portal := user.FindPrivateChatWith(userID) + if portal == nil || puppet == nil { + return + } + + updated := portal.FriendNick == (nickname != "") + portal.FriendNick = nickname != "" + if nickname != "" { + updated = portal.UpdateNameDirect(nickname, true) + } else if portal.Name != puppet.Name { + if portal.shouldSetDMRoomMetadata() { + updated = portal.UpdateNameDirect(puppet.Name, false) + } else if portal.NameSet { + _, err := portal.MainIntent().SendStateEvent(portal.MXID, event.StateRoomName, "", map[string]any{}) + if err != nil { + portal.zlog.Warn().Err(err).Msg("Failed to clear room name after friend nickname was removed") + } else { + portal.zlog.Debug().Msg("Cleared room name after friend nickname was removed") + portal.NameSet = false + portal.Update() + updated = true + } + } + } + if !updated { + portal.Update() + } +} + func (user *User) handlePrivateChannel(portal *Portal, meta *discordgo.Channel, timestamp time.Time, create, isInSpace bool) { if create && portal.MXID == "" { err := portal.CreateMatrixRoom(user, meta)