diff --git a/bridge/bridge.go b/bridge/bridge.go index c1b09a9..31c5463 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -48,6 +48,8 @@ type Bridge struct { puppetsLock sync.Mutex StateStore *database.SQLStateStore + + crypto Crypto } func New(cfg *config.Config) (*Bridge, error) { @@ -104,6 +106,8 @@ func New(cfg *config.Config) (*Bridge, error) { StateStore: stateStore, } + bridge.crypto = NewCryptoHelper(bridge) + if cfg.Appservice.Provisioning.Enabled() { bridge.provisioning = newProvisioningAPI(bridge) } @@ -151,6 +155,13 @@ func (b *Bridge) Start() error { return err } + if b.crypto != nil { + if err := b.crypto.Init(); err != nil { + b.log.Fatalln("Error initializing end-to-bridge encryption:", err) + return err + } + } + b.log.Debugln("Starting application service HTTP server") go b.as.Start() @@ -159,6 +170,10 @@ func (b *Bridge) Start() error { go b.updateBotProfile() + if b.crypto != nil { + go b.crypto.Start() + } + go b.startUsers() // Finally tell the appservice we're ready @@ -168,5 +183,21 @@ func (b *Bridge) Start() error { } func (b *Bridge) Stop() { + if b.crypto != nil { + b.crypto.Stop() + } + + b.as.Stop() + b.eventProcessor.Stop() + + for _, user := range b.usersByMXID { + if user.Session == nil { + continue + } + + b.log.Debugln("Disconnecting", user.MXID) + user.Session.Close() + } + b.log.Infoln("Bridge stopped") } diff --git a/bridge/commands.go b/bridge/commands.go index 2799e7f..83363b6 100644 --- a/bridge/commands.go +++ b/bridge/commands.go @@ -3,6 +3,7 @@ package bridge import ( "context" "fmt" + "time" "github.com/alecthomas/kong" @@ -33,11 +34,17 @@ func (g *globals) reply(msg string) { content.MsgType = event.MsgNotice intent := g.bot - if g.portal != nil && g.portal.IsPrivateChat() { + if g.portal == nil { + g.handler.log.Errorfln("we don't have a portal for this command") + + return + } + + if g.portal.IsPrivateChat() { intent = g.portal.MainIntent() } - _, err := intent.SendMessageEvent(g.roomID, event.EventMessage, content) + _, err := g.portal.sendMatrixMessage(intent, event.EventMessage, &content, nil, time.Now().UTC().UnixMilli()) if err != nil { g.handler.log.Warnfln("Failed to reply to command from %q: %v", g.user.MXID, err) } diff --git a/bridge/crypto.go b/bridge/crypto.go new file mode 100644 index 0000000..7c12eb3 --- /dev/null +++ b/bridge/crypto.go @@ -0,0 +1,339 @@ +package bridge + +import ( + "fmt" + "runtime/debug" + "time" + + "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "gitlab.com/beeper/discord/database" +) + +var NoSessionFound = crypto.NoSessionFound + +var levelTrace = maulogger.Level{ + Name: "TRACE", + Severity: -10, + Color: -1, +} + +type Crypto interface { + HandleMemberEvent(*event.Event) + Decrypt(*event.Event) (*event.Event, error) + Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error) + WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + ResetSession(id.RoomID) + Init() error + Start() + Stop() +} + +type CryptoHelper struct { + bridge *Bridge + client *mautrix.Client + mach *crypto.OlmMachine + store *database.SQLCryptoStore + log maulogger.Logger + baseLog maulogger.Logger +} + +func NewCryptoHelper(bridge *Bridge) Crypto { + if !bridge.Config.Bridge.Encryption.Allow { + bridge.log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config") + return nil + } + + baseLog := bridge.log.Sub("Crypto") + return &CryptoHelper{ + bridge: bridge, + log: baseLog.Sub("Helper"), + baseLog: baseLog, + } +} + +func (helper *CryptoHelper) Init() error { + helper.log.Debugln("Initializing end-to-bridge encryption...") + + helper.store = database.NewSQLCryptoStore(helper.bridge.db, helper.bridge.as.BotMXID(), + fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.as.HomeserverDomain)) + + var err error + helper.client, err = helper.loginBot() + if err != nil { + return err + } + + helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID) + + logger := &cryptoLogger{helper.baseLog} + stateStore := &cryptoStateStore{helper.bridge} + helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore) + helper.mach.AllowKeyShare = helper.allowKeyShare + + helper.client.Syncer = &cryptoSyncer{helper.mach} + helper.client.Store = &cryptoClientStore{helper.store} + + return helper.mach.Load() +} + +func (helper *CryptoHelper) allowKeyShare(device *crypto.DeviceIdentity, info event.RequestedKeyInfo) *crypto.KeyShareRejection { + cfg := helper.bridge.Config.Bridge.Encryption.KeySharing + if !cfg.Allow { + return &crypto.KeyShareRejectNoResponse + } else if device.Trust == crypto.TrustStateBlacklisted { + return &crypto.KeyShareRejectBlacklisted + } else if device.Trust == crypto.TrustStateVerified || !cfg.RequireVerification { + portal := helper.bridge.GetPortalByMXID(info.RoomID) + if portal == nil { + helper.log.Debugfln("Rejecting key request for %s from %s/%s: room is not a portal", info.SessionID, device.UserID, device.DeviceID) + + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} + } + user := helper.bridge.GetUserByMXID(device.UserID) + // FIXME reimplement IsInPortal + if !user.Admin /*&& !user.IsInPortal(portal.Key)*/ { + helper.log.Debugfln("Rejecting key request for %s from %s/%s: user is not in portal", info.SessionID, device.UserID, device.DeviceID) + + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} + } + helper.log.Debugfln("Accepting key request for %s from %s/%s", info.SessionID, device.UserID, device.DeviceID) + + return nil + } + + return &crypto.KeyShareRejectUnverified +} + +func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) { + deviceID := helper.store.FindDeviceID() + if len(deviceID) > 0 { + helper.log.Debugln("Found existing device ID for bot in database:", deviceID) + } + + client, err := mautrix.NewClient(helper.bridge.as.HomeserverURL, "", "") + if err != nil { + return nil, fmt.Errorf("failed to initialize client: %w", err) + } + + client.Logger = helper.baseLog.Sub("Bot") + client.Client = helper.bridge.as.HTTPClient + client.DefaultHTTPRetries = helper.bridge.as.DefaultHTTPRetries + flows, err := client.GetLoginFlows() + if err != nil { + return nil, fmt.Errorf("failed to get supported login flows: %w", err) + } + + flow := flows.FirstFlowOfType(mautrix.AuthTypeAppservice, mautrix.AuthTypeHalfyAppservice) + if flow == nil { + return nil, fmt.Errorf("homeserver does not support appservice login") + } + + // We set the API token to the AS token here to authenticate the appservice login + // It'll get overridden after the login + client.AccessToken = helper.bridge.as.Registration.AppToken + resp, err := client.Login(&mautrix.ReqLogin{ + Type: flow.Type, + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.as.BotMXID())}, + DeviceID: deviceID, + InitialDeviceDisplayName: "Discord Bridge", + StoreCredentials: true, + }) + if err != nil { + return nil, fmt.Errorf("failed to log in as bridge bot: %w", err) + } + + helper.store.DeviceID = resp.DeviceID + + return client, nil +} + +func (helper *CryptoHelper) Start() { + helper.log.Debugln("Starting syncer for receiving to-device messages") + + err := helper.client.Sync() + if err != nil { + helper.log.Errorln("Fatal error syncing:", err) + } else { + helper.log.Infoln("Bridge bot to-device syncer stopped without error") + } +} + +func (helper *CryptoHelper) Stop() { + helper.log.Debugln("CryptoHelper.Stop() called, stopping bridge bot sync") + helper.client.StopSync() +} + +func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { + return helper.mach.DecryptMegolmEvent(evt) +} + +func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) { + encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, &content) + + if err != nil { + if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { + return nil, err + } + + helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID) + users, err := helper.store.GetRoomMembers(roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room member list: %w", err) + } + + err = helper.mach.ShareGroupSession(roomID, users) + if err != nil { + return nil, fmt.Errorf("failed to share group session: %w", err) + } + + encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, &content) + if err != nil { + return nil, fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + + return encrypted, nil +} + +func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) +} + +func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) + if err != nil { + helper.log.Warnfln("Failed to send key request to %s/%s for %s in %s: %v", userID, deviceID, sessionID, roomID, err) + } else { + helper.log.Debugfln("Sent key request to %s/%s for %s in %s", userID, deviceID, sessionID, roomID) + } +} + +func (helper *CryptoHelper) ResetSession(roomID id.RoomID) { + err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID) + if err != nil { + helper.log.Debugfln("Error manually removing outbound group session in %s: %v", roomID, err) + } +} + +func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) { + helper.mach.HandleMemberEvent(evt) +} + +type cryptoSyncer struct { + *crypto.OlmMachine +} + +func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error { + done := make(chan struct{}) + go func() { + defer func() { + if err := recover(); err != nil { + syncer.Log.Error("Processing sync response (%s) panicked: %v\n%s", since, err, debug.Stack()) + } + done <- struct{}{} + }() + syncer.Log.Trace("Starting sync response handling (%s)", since) + syncer.ProcessSyncResponse(resp, since) + syncer.Log.Trace("Successfully handled sync response (%s)", since) + }() + + select { + case <-done: + case <-time.After(30 * time.Second): + syncer.Log.Warn("Handling sync response (%s) is taking unusually long", since) + } + + return nil +} + +func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err) + + return 10 * time.Second, nil +} + +func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + everything := []event.Type{{Type: "*"}} + + return &mautrix.Filter{ + Presence: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + Room: mautrix.RoomFilter{ + IncludeLeave: false, + Ephemeral: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + State: mautrix.FilterPart{NotTypes: everything}, + Timeline: mautrix.FilterPart{NotTypes: everything}, + }, + } +} + +type cryptoLogger struct { + int maulogger.Logger +} + +func (c *cryptoLogger) Error(message string, args ...interface{}) { + c.int.Errorfln(message, args...) +} + +func (c *cryptoLogger) Warn(message string, args ...interface{}) { + c.int.Warnfln(message, args...) +} + +func (c *cryptoLogger) Debug(message string, args ...interface{}) { + c.int.Debugfln(message, args...) +} + +func (c *cryptoLogger) Trace(message string, args ...interface{}) { + c.int.Logfln(levelTrace, message, args...) +} + +type cryptoClientStore struct { + int *database.SQLCryptoStore +} + +func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {} +func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" } +func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {} +func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil } + +func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) { + c.int.PutNextBatch(nextBatchToken) +} + +func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string { + return c.int.GetNextBatch() +} + +var _ mautrix.Storer = (*cryptoClientStore)(nil) + +type cryptoStateStore struct { + bridge *Bridge +} + +var _ crypto.StateStore = (*cryptoStateStore)(nil) + +func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { + portal := c.bridge.GetPortalByMXID(id) + if portal != nil { + return portal.Encrypted + } + + return false +} + +func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID { + return c.bridge.StateStore.FindSharedRooms(id) +} + +func (c *cryptoStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { + // TODO implement + return nil +} diff --git a/bridge/matrix.go b/bridge/matrix.go index 4e88f18..cb9127e 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -1,7 +1,10 @@ package bridge import ( + "errors" + "fmt" "strings" + "time" "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" @@ -29,9 +32,11 @@ func (b *Bridge) setupEvents() { } b.eventProcessor.On(event.EventMessage, b.matrixHandler.handleMessage) + b.eventProcessor.On(event.EventEncrypted, b.matrixHandler.handleEncrypted) b.eventProcessor.On(event.EventReaction, b.matrixHandler.handleReaction) b.eventProcessor.On(event.EventRedaction, b.matrixHandler.handleRedaction) b.eventProcessor.On(event.StateMember, b.matrixHandler.handleMembership) + b.eventProcessor.On(event.StateEncryption, b.matrixHandler.handleEncryption) } func (mh *matrixHandler) join(evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { @@ -185,6 +190,10 @@ func (mh *matrixHandler) handleMembership(evt *event.Event) { return } + if mh.bridge.crypto != nil { + mh.bridge.crypto.HandleMemberEvent(evt) + } + // Grab the content of the event. content := evt.Content.AsMember() @@ -255,3 +264,113 @@ func (mh *matrixHandler) handleRedaction(evt *event.Event) { portal.handleMatrixRedaction(evt) } } + +func (mh *matrixHandler) handleEncryption(evt *event.Event) { + if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { + return + } + + portal := mh.bridge.GetPortalByMXID(evt.RoomID) + if portal != nil && !portal.Encrypted { + mh.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID) + portal.Encrypted = true + portal.Update() + } +} + +const sessionWaitTimeout = 5 * time.Second + +func (mh *matrixHandler) handleEncrypted(evt *event.Event) { + if mh.ignoreEvent(evt) || mh.bridge.crypto == nil { + return + } + + decrypted, err := mh.bridge.crypto.Decrypt(evt) + decryptionRetryCount := 0 + if errors.Is(err, NoSessionFound) { + content := evt.Content.AsEncrypted() + mh.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d seconds...", content.SessionID, evt.ID, int(sessionWaitTimeout.Seconds())) + mh.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, false, decryptionRetryCount) + decryptionRetryCount++ + + if mh.bridge.crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, sessionWaitTimeout) { + mh.log.Debugfln("Got session %s after waiting, trying to decrypt %s again", content.SessionID, evt.ID) + decrypted, err = mh.bridge.crypto.Decrypt(evt) + } else { + mh.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, fmt.Errorf("didn't receive encryption keys"), false, decryptionRetryCount) + + go mh.waitLongerForSession(evt) + + return + } + } + + if err != nil { + mh.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, true, decryptionRetryCount) + + mh.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err) + _, _ = mh.bridge.bot.SendNotice(evt.RoomID, fmt.Sprintf( + "\u26a0 Your message was not bridged: %v", err)) + + return + } + + mh.as.SendMessageSendCheckpoint(decrypted, appservice.StepDecrypted, decryptionRetryCount) + mh.bridge.eventProcessor.Dispatch(decrypted) +} + +func (mh *matrixHandler) waitLongerForSession(evt *event.Event) { + const extendedTimeout = sessionWaitTimeout * 3 + + content := evt.Content.AsEncrypted() + mh.log.Debugfln("Couldn't find session %s trying to decrypt %s, waiting %d more seconds...", + content.SessionID, evt.ID, int(extendedTimeout.Seconds())) + + go mh.bridge.crypto.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + + resp, err := mh.bridge.bot.SendNotice(evt.RoomID, fmt.Sprintf( + "\u26a0 Your message was not bridged: the bridge hasn't received the decryption keys. "+ + "The bridge will retry for %d seconds. If this error keeps happening, try restarting your client.", + int(extendedTimeout.Seconds()))) + if err != nil { + mh.log.Errorfln("Failed to send decryption error to %s: %v", evt.RoomID, err) + } + + update := event.MessageEventContent{MsgType: event.MsgNotice} + + if mh.bridge.crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedTimeout) { + mh.log.Debugfln("Got session %s after waiting more, trying to decrypt %s again", content.SessionID, evt.ID) + + decrypted, err := mh.bridge.crypto.Decrypt(evt) + if err == nil { + mh.as.SendMessageSendCheckpoint(decrypted, appservice.StepDecrypted, 2) + mh.bridge.eventProcessor.Dispatch(decrypted) + _, _ = mh.bridge.bot.RedactEvent(evt.RoomID, resp.EventID) + + return + } + + mh.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err) + mh.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, err, true, 2) + update.Body = fmt.Sprintf("\u26a0 Your message was not bridged: %v", err) + } else { + mh.log.Debugfln("Didn't get %s, giving up on %s", content.SessionID, evt.ID) + mh.as.SendErrorMessageSendCheckpoint(evt, appservice.StepDecrypted, fmt.Errorf("didn't receive encryption keys"), true, 2) + update.Body = "\u26a0 Your message was not bridged: the bridge hasn't received the decryption keys. " + + "If this error keeps happening, try restarting your client." + } + + newContent := update + update.NewContent = &newContent + if resp != nil { + update.RelatesTo = &event.RelatesTo{ + Type: event.RelReplace, + EventID: resp.EventID, + } + } + + _, err = mh.bridge.bot.SendMessageEvent(evt.RoomID, event.EventMessage, &update) + if err != nil { + mh.log.Debugfln("Failed to update decryption error notice %s: %v", resp.EventID, err) + } +} diff --git a/bridge/portal.go b/bridge/portal.go index 05562d3..2eed2ef 100644 --- a/bridge/portal.go +++ b/bridge/portal.go @@ -35,6 +35,7 @@ type Portal struct { log log.Logger roomCreateLock sync.Mutex + encryptLock sync.Mutex discordMessages chan portalDiscordMessage matrixMessages chan portalMatrixMessage @@ -144,7 +145,7 @@ func (p *Portal) handleMatrixInvite(sender *User, evt *event.Event) { p.log.Infoln("no puppet for %v", sender) // Open a conversation on the discord side? } - p.log.Infoln("puppet:", puppet) + p.log.Infoln("matrixInvite: puppet:", puppet) } func (p *Portal) messageLoop() { @@ -212,14 +213,25 @@ func (p *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error var invite []id.UserID - if p.IsPrivateChat() { - invite = append(invite, p.bridge.bot.UserID) + if p.bridge.Config.Bridge.Encryption.Default { + initialState = append(initialState, &event.Event{ + Type: event.StateEncryption, + Content: event.Content{ + Parsed: event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1}, + }, + }) + p.Encrypted = true + + if p.IsPrivateChat() { + invite = append(invite, p.bridge.bot.UserID) + } } resp, err := intent.CreateRoom(&mautrix.ReqCreateRoom{ Visibility: "private", Name: p.Name, Topic: p.Topic, + Invite: invite, Preset: "private_chat", IsDirect: p.IsPrivateChat(), InitialState: initialState, @@ -325,7 +337,7 @@ func (p *Portal) sendMediaFailedMessage(intent *appservice.IntentAPI, bridgeErr MsgType: event.MsgNotice, } - _, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content) + _, err := p.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) if err != nil { p.log.Warnfln("failed to send error message to matrix: %v", err) } @@ -379,7 +391,7 @@ func (p *Portal) handleDiscordAttachment(intent *appservice.IntentAPI, msgID str return } - resp, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content) + resp, err := p.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) if err != nil { p.log.Warnfln("failed to send media message to matrix: %v", err) } @@ -426,7 +438,7 @@ func (p *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Message) } } - resp, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content) + resp, err := p.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) if err != nil { p.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) @@ -498,7 +510,7 @@ func (p *Portal) handleDiscordMessagesUpdate(user *User, msg *discordgo.Message) content.SetEdit(existing.MatrixID) - resp, err := intent.SendMessageEvent(p.MXID, event.EventMessage, content) + resp, err := p.sendMatrixMessage(intent, event.EventMessage, content, nil, time.Now().UTC().UnixMilli()) if err != nil { p.log.Warnfln("failed to send message %q to matrix: %v", msg.ID, err) @@ -567,6 +579,57 @@ func (p *Portal) syncParticipants(source *User, participants []*discordgo.User) } } +func (portal *Portal) encrypt(content *event.Content, eventType event.Type) (event.Type, error) { + if portal.Encrypted && portal.bridge.crypto != nil { + // TODO maybe the locking should be inside mautrix-go? + portal.encryptLock.Lock() + encrypted, err := portal.bridge.crypto.Encrypt(portal.MXID, eventType, *content) + portal.encryptLock.Unlock() + if err != nil { + return eventType, fmt.Errorf("failed to encrypt event: %w", err) + } + eventType = event.EventEncrypted + content.Parsed = encrypted + } + return eventType, nil +} + +const doublePuppetKey = "fi.mau.double_puppet_source" +const doublePuppetValue = "mautrix-discord" + +func (portal *Portal) sendMatrixMessage(intent *appservice.IntentAPI, eventType event.Type, content *event.MessageEventContent, extraContent map[string]interface{}, timestamp int64) (*mautrix.RespSendEvent, error) { + wrappedContent := event.Content{Parsed: content, Raw: extraContent} + if timestamp != 0 && intent.IsCustomPuppet { + if wrappedContent.Raw == nil { + wrappedContent.Raw = map[string]interface{}{} + } + if intent.IsCustomPuppet { + wrappedContent.Raw[doublePuppetKey] = doublePuppetValue + } + } + var err error + eventType, err = portal.encrypt(&wrappedContent, eventType) + if err != nil { + return nil, err + } + + if eventType == event.EventEncrypted { + // Clear other custom keys if the event was encrypted, but keep the double puppet identifier + if intent.IsCustomPuppet { + wrappedContent.Raw = map[string]interface{}{doublePuppetKey: doublePuppetValue} + } else { + wrappedContent.Raw = nil + } + } + + _, _ = intent.UserTyping(portal.MXID, false, 0) + if timestamp == 0 { + return intent.SendMessageEvent(portal.MXID, eventType, &wrappedContent) + } else { + return intent.SendMassagedMessageEvent(portal.MXID, eventType, &wrappedContent, timestamp) + } +} + func (p *Portal) handleMatrixMessages(msg portalMatrixMessage) { switch msg.evt.Type { case event.EventMessage: diff --git a/bridge/user.go b/bridge/user.go index 1dcb262..bf35510 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -32,6 +32,9 @@ type User struct { bridge *Bridge log log.Logger + // TODO finish implementing + Admin bool + guilds map[string]*database.Guild guildsLock sync.Mutex diff --git a/config/bridge.go b/config/bridge.go index 9d0510c..ce863db 100644 --- a/config/bridge.go +++ b/config/bridge.go @@ -30,6 +30,8 @@ type bridge struct { DoublePuppetAllowDiscovery bool `yaml:"double_puppet_allow_discovery"` LoginSharedSecretMap map[string]string `yaml:"login_shared_secret_map"` + Encryption encryption `yaml:"encryption"` + usernameTemplate *template.Template `yaml:"-"` displaynameTemplate *template.Template `yaml:"-"` channelnameTemplate *template.Template `yaml:"-"` diff --git a/config/encryption.go b/config/encryption.go new file mode 100644 index 0000000..1d57c39 --- /dev/null +++ b/config/encryption.go @@ -0,0 +1,29 @@ +package config + +type encryption struct { + Allow bool `yaml:"allow"` + Default bool `yaml:"default"` + + KeySharing struct { + Allow bool `yaml:"allow"` + RequireCrossSigning bool `yaml:"require_cross_signing"` + RequireVerification bool `yaml:"require_verification"` + } `yaml:"key_sharing"` +} + +func (e *encryption) validate() error { + return nil +} + +func (e *encryption) UnmarshalYAML(unmarshal func(interface{}) error) error { + type rawEncryption encryption + + raw := rawEncryption{} + if err := unmarshal(&raw); err != nil { + return err + } + + *e = encryption(raw) + + return e.validate() +} diff --git a/database/cryptostore.go b/database/cryptostore.go new file mode 100644 index 0000000..171c824 --- /dev/null +++ b/database/cryptostore.go @@ -0,0 +1,97 @@ +package database + +import ( + "database/sql" + + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/id" +) + +type SQLCryptoStore struct { + *crypto.SQLCryptoStore + UserID id.UserID + GhostIDFormat string +} + +var _ crypto.Store = (*SQLCryptoStore)(nil) + +func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore { + return &SQLCryptoStore{ + SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "", + []byte("maunium.net/go/mautrix-whatsapp"), + &cryptoLogger{db.log.Sub("CryptoStore")}), + UserID: userID, + GhostIDFormat: ghostIDFormat, + } +} + +func (store *SQLCryptoStore) FindDeviceID() id.DeviceID { + var deviceID id.DeviceID + + query := `SELECT device_id FROM crypto_account WHERE account_id=$1` + err := store.DB.QueryRow(query, store.AccountID).Scan(&deviceID) + if err != nil && err != sql.ErrNoRows { + store.Log.Warn("Failed to scan device ID: %v", err) + } + + return deviceID +} + +func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) ([]id.UserID, error) { + query := ` + SELECT user_id FROM mx_user_profile + WHERE room_id=$1 + AND (membership='join' OR membership='invite') + AND user_id<>$2 + AND user_id NOT LIKE $3 + ` + + members := []id.UserID{} + + rows, err := store.DB.Query(query, roomID, store.UserID, store.GhostIDFormat) + if err != nil { + return members, err + } + + for rows.Next() { + var userID id.UserID + err := rows.Scan(&userID) + if err != nil { + store.Log.Warn("Failed to scan member in %s: %v", roomID, err) + return members, err + } + + members = append(members, userID) + } + + return members, nil +} + +// TODO merge this with the one in the parent package +type cryptoLogger struct { + int log.Logger +} + +var levelTrace = log.Level{ + Name: "TRACE", + Severity: -10, + Color: -1, +} + +func (c *cryptoLogger) Error(message string, args ...interface{}) { + c.int.Errorfln(message, args...) +} + +func (c *cryptoLogger) Warn(message string, args ...interface{}) { + c.int.Warnfln(message, args...) +} + +func (c *cryptoLogger) Debug(message string, args ...interface{}) { + c.int.Debugfln(message, args...) +} + +func (c *cryptoLogger) Trace(message string, args ...interface{}) { + c.int.Logfln(levelTrace, message, args...) +} diff --git a/database/migrations/08-add-crypto-store-to-database.sql b/database/migrations/08-add-crypto-store-to-database.sql new file mode 100644 index 0000000..c615976 --- /dev/null +++ b/database/migrations/08-add-crypto-store-to-database.sql @@ -0,0 +1,3 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 0 +-- which is described as "Add crypto store to database". diff --git a/database/migrations/09-add-account_id-to-crypto-store.sql b/database/migrations/09-add-account_id-to-crypto-store.sql new file mode 100644 index 0000000..03dc1cc --- /dev/null +++ b/database/migrations/09-add-account_id-to-crypto-store.sql @@ -0,0 +1,3 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 1 +-- which is described as "Add account_id to crypto store". diff --git a/database/migrations/10-add-megolm-withheld-data-to-crypto-store.sql b/database/migrations/10-add-megolm-withheld-data-to-crypto-store.sql new file mode 100644 index 0000000..38813b8 --- /dev/null +++ b/database/migrations/10-add-megolm-withheld-data-to-crypto-store.sql @@ -0,0 +1,3 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 2 +-- which is described as "Add megolm withheld data to crypto store". diff --git a/database/migrations/11-add-cross-signing-keys-to-crypto-store.sql b/database/migrations/11-add-cross-signing-keys-to-crypto-store.sql new file mode 100644 index 0000000..39f5041 --- /dev/null +++ b/database/migrations/11-add-cross-signing-keys-to-crypto-store.sql @@ -0,0 +1,3 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 3 +-- which is described as "Add cross-signing keys to crypto store". diff --git a/database/migrations/12-replace-varchar-with-text-in-the-crypto-database.sql b/database/migrations/12-replace-varchar-with-text-in-the-crypto-database.sql new file mode 100644 index 0000000..adb841e --- /dev/null +++ b/database/migrations/12-replace-varchar-with-text-in-the-crypto-database.sql @@ -0,0 +1,4 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 4 +-- which is described as "Replace VARCHAR(255) with TEXT in the crypto +-- database". diff --git a/database/migrations/13-split-last_used-into-last_encrypted-and-last_decrypted-in-crypto-store.sql b/database/migrations/13-split-last_used-into-last_encrypted-and-last_decrypted-in-crypto-store.sql new file mode 100644 index 0000000..28906bd --- /dev/null +++ b/database/migrations/13-split-last_used-into-last_encrypted-and-last_decrypted-in-crypto-store.sql @@ -0,0 +1,4 @@ +-- This migration is implemented in migrations.go as it comes from +-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 5 +-- which is described as "Split last_used into last_encrypted and +-- last_decrypted in crypto store". diff --git a/database/migrations/14-add-encrypted-column-to-portal-table.sql b/database/migrations/14-add-encrypted-column-to-portal-table.sql new file mode 100644 index 0000000..d032fee --- /dev/null +++ b/database/migrations/14-add-encrypted-column-to-portal-table.sql @@ -0,0 +1 @@ +ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false; diff --git a/database/migrations/migrations.go b/database/migrations/migrations.go index 06fcce4..904b9f4 100644 --- a/database/migrations/migrations.go +++ b/database/migrations/migrations.go @@ -3,37 +3,18 @@ package migrations import ( "database/sql" "embed" - "sort" "github.com/lopezator/migrator" log "maunium.net/go/maulogger/v2" + "maunium.net/go/mautrix/crypto/sql_store_upgrade" ) //go:embed *.sql var embeddedMigrations embed.FS -var ( - commonMigrations = []string{ - "01-initial.sql", - "02-attachments.sql", - "03-emoji.sql", - "04-custom-puppet.sql", - "05-additional-puppet-fields.sql", - "07-guilds.sql", - } - - sqliteMigrations = []string{ - "06-remove-unique-user-constraint.sqlite.sql", - } - - postgresMigrations = []string{ - "06-remove-unique-user-constraint.postgres.sql", - } -) - -func migrationFromFile(filename string) *migrator.Migration { +func migrationFromFile(description, filename string) *migrator.Migration { return &migrator.Migration{ - Name: filename, + Name: description, Func: func(tx *sql.Tx) error { data, err := embeddedMigrations.ReadFile(filename) if err != nil { @@ -49,31 +30,83 @@ func migrationFromFile(filename string) *migrator.Migration { } } +func migrationFromFileWithDialect(dialect, description, sqliteFile, postgresFile string) *migrator.Migration { + switch dialect { + case "sqlite3": + return migrationFromFile(description, sqliteFile) + case "postgres": + return migrationFromFile(description, postgresFile) + default: + return nil + } +} + func Run(db *sql.DB, baseLog log.Logger, dialect string) error { subLogger := baseLog.Sub("Migrations") logger := migrator.LoggerFunc(func(msg string, args ...interface{}) { subLogger.Infof(msg, args...) }) - migrationNames := commonMigrations - switch dialect { - case "sqlite3": - migrationNames = append(migrationNames, sqliteMigrations...) - case "postgres": - migrationNames = append(migrationNames, postgresMigrations...) - } - - sort.Strings(migrationNames) - - migrations := make([]interface{}, len(migrationNames)) - for idx, name := range migrationNames { - migrations[idx] = migrationFromFile(name) - } - m, err := migrator.New( migrator.TableName("version"), migrator.WithLogger(logger), - migrator.Migrations(migrations...), + migrator.Migrations( + migrationFromFile("Initial Schema", "01-initial.sql"), + migrationFromFile("Attachments", "02-attachments.sql"), + migrationFromFile("Emoji", "03-emoji.sql"), + migrationFromFile("Custom Puppets", "04-custom-puppet.sql"), + migrationFromFile( + "Additional puppet fields", + "05-additional-puppet-fields.sql", + ), + migrationFromFileWithDialect( + dialect, + "Remove unique user constraint", + "06-remove-unique-user-constraint.sqlite.sql", + "06-remove-unique-user-constraint.postgres.sql", + ), + migrationFromFile("Guild Bridging", "07-guilds.sql"), + &migrator.Migration{ + Name: "Add crypto store to database", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[0](tx, dialect) + }, + }, + &migrator.Migration{ + Name: "Add account_id to crypto store", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[1](tx, dialect) + }, + }, + &migrator.Migration{ + Name: "Add megolm withheld data to crypto store", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[2](tx, dialect) + }, + }, + &migrator.Migration{ + Name: "Add cross-signing keys to crypto store", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[3](tx, dialect) + }, + }, + &migrator.Migration{ + Name: "Replace VARCHAR(255) with TEXT in the crypto database", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[4](tx, dialect) + }, + }, + &migrator.Migration{ + Name: "Split last_used into last_encrypted and last_decrypted in crypto store", + Func: func(tx *sql.Tx) error { + return sql_store_upgrade.Upgrades[5](tx, dialect) + }, + }, + migrationFromFile( + "Add encryption column to portal table", + "14-add-encrypted-column-to-portal-table.sql", + ), + ), ) if err != nil { return err diff --git a/database/portal.go b/database/portal.go index 663651b..96e4c45 100644 --- a/database/portal.go +++ b/database/portal.go @@ -19,6 +19,8 @@ type Portal struct { Name string Topic string + Encrypted bool + Avatar string AvatarURL id.ContentURI @@ -33,7 +35,8 @@ func (p *Portal) Scan(row Scannable) *Portal { var typ sql.NullInt32 err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name, - &p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID) + &p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID, + &p.Encrypted) if err != nil { if err != sql.ErrNoRows { @@ -62,12 +65,12 @@ func (p *Portal) mxidPtr() *id.RoomID { func (p *Portal) Insert() { query := "INSERT INTO portal" + " (channel_id, receiver, mxid, name, topic, avatar, avatar_url," + - " type, dmuser, first_event_id)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" + " type, dmuser, first_event_id, encrypted)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" _, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(), p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser, - p.FirstEventID.String()) + p.FirstEventID.String(), p.Encrypted) if err != nil { p.log.Warnfln("Failed to insert %s: %v", p.Key, err) @@ -77,11 +80,12 @@ func (p *Portal) Insert() { func (p *Portal) Update() { query := "UPDATE portal SET" + " mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," + - " dmuser=$7, first_event_id=$8" + - " WHERE channel_id=$9 AND receiver=$10" + " dmuser=$7, first_event_id=$8, encrypted=$9" + + " WHERE channel_id=$10 AND receiver=$11" _, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(), + p.Encrypted, p.Key.ChannelID, p.Key.Receiver) if err != nil { diff --git a/database/portalquery.go b/database/portalquery.go index 8b4353a..d9dd1ba 100644 --- a/database/portalquery.go +++ b/database/portalquery.go @@ -6,6 +6,12 @@ import ( "maunium.net/go/mautrix/id" ) +const ( + portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," + + " avatar_url, type, dmuser, first_event_id, encrypted" + + " FROM portal" +) + type PortalQuery struct { db *Database log log.Logger @@ -19,23 +25,23 @@ func (pq *PortalQuery) New() *Portal { } func (pq *PortalQuery) GetAll() []*Portal { - return pq.getAll("SELECT * FROM portal") + return pq.getAll(portalSelect) } func (pq *PortalQuery) GetByID(key PortalKey) *Portal { - return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver) + return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver) } func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid) + return pq.get(portalSelect+" WHERE mxid=$1", mxid) } func (pq *PortalQuery) GetAllByID(id string) []*Portal { - return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id) + return pq.getAll(portalSelect+" WHERE receiver=$1", id) } func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal { - query := "SELECT * FROM portal WHERE receiver=$1 AND type=$2;" + query := portalSelect + " portal WHERE receiver=$1 AND type=$2;" return pq.getAll(query, receiver, discordgo.ChannelTypeDM) } diff --git a/database/sqlstatestore.go b/database/sqlstatestore.go index c5d800b..e316773 100644 --- a/database/sqlstatestore.go +++ b/database/sqlstatestore.go @@ -272,3 +272,31 @@ func (s *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventT return s.GetPowerLevel(roomID, userID) >= s.GetPowerLevelRequirement(roomID, eventType) } + +func (store *SQLStateStore) FindSharedRooms(userID id.UserID) []id.RoomID { + query := "SELECT room_id FROM mx_user_profile" + + "LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id" + + "WHERE user_id=$1 AND portal.encrypted=true" + + rooms := []id.RoomID{} + + rows, err := store.db.Query(query, userID) + if err != nil { + store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err) + + return rooms + } + + for rows.Next() { + var roomID id.RoomID + + err = rows.Scan(&roomID) + if err != nil { + store.log.Warnfln("Failed to scan room ID: %v", err) + } else { + rooms = append(rooms, roomID) + } + } + + return rooms +} diff --git a/go.mod b/go.mod index bba6b6a..536f49d 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,10 @@ require ( github.com/gorilla/mux v1.8.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect + github.com/tidwall/gjson v1.14.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.4 // indirect golang.org/x/crypto v0.0.0-20220331220935-ae2d96664a29 // indirect golang.org/x/net v0.0.0-20220403103023-749bd193bc2b // indirect golang.org/x/sys v0.0.0-20220406163625-3f8b81556e12 // indirect diff --git a/go.sum b/go.sum index 2066e59..7ff3c7f 100644 --- a/go.sum +++ b/go.sum @@ -35,9 +35,13 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.0 h1:6aeJ0bzojgWLa82gDQHcx3S0Lr/O51I9bJ5nv6JFx5w= github.com/tidwall/gjson v1.14.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= gitlab.com/beeper/discordgo v0.23.3-0.20220219094025-13ff4cc63da7 h1:8ieR27GadHnShqhsvPrDzL1/ZOntavGGt4TXqafncYE= gitlab.com/beeper/discordgo v0.23.3-0.20220219094025-13ff4cc63da7/go.mod h1:Hwfv4M8yP/MDh47BN+4Z1WItJ1umLKUyplCH5KcQPgE=