diff --git a/pkg/connector/attachments.go b/pkg/connector/attachments.go index ae12af1..fe2ccc3 100644 --- a/pkg/connector/attachments.go +++ b/pkg/connector/attachments.go @@ -27,7 +27,6 @@ import ( "strconv" "github.com/bwmarrin/discordgo" - "github.com/rs/zerolog" "go.mau.fi/mautrix-discord/pkg/attachment" "maunium.net/go/mautrix/bridgev2" ) diff --git a/pkg/connector/backfill.go b/pkg/connector/backfill.go index 99dee51..e7606e8 100644 --- a/pkg/connector/backfill.go +++ b/pkg/connector/backfill.go @@ -23,7 +23,6 @@ import ( "github.com/bwmarrin/discordgo" "github.com/rs/zerolog" - "go.mau.fi/mautrix-discord/pkg/msgconv" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" ) @@ -65,10 +64,6 @@ func (dc *DiscordClient) FetchMessages(ctx context.Context, fetchParams bridgev2 } converted := make([]*bridgev2.BackfillMessage, 0, len(msgs)) - mc := msgconv.MessageConverter{ - Bridge: dc.connector.Bridge, - ReuploadMedia: dc.connector.ReuploadMedia, - } for _, msg := range msgs { streamOrder, _ := strconv.ParseInt(msg.ID, 10, 64) ts, _ := discordgo.SnowflakeTimestamp(msg.ID) @@ -90,7 +85,7 @@ func (dc *DiscordClient) FetchMessages(ctx context.Context, fetchParams bridgev2 converted = append(converted, &bridgev2.BackfillMessage{ ID: networkid.MessageID(msg.ID), - ConvertedMessage: mc.ToMatrix(ctx, fetchParams.Portal, intent, dc.UserLogin, msg), + ConvertedMessage: dc.connector.MsgConv.ToMatrix(ctx, fetchParams.Portal, intent, dc.UserLogin, msg), Sender: sender, Timestamp: ts, StreamOrder: streamOrder, diff --git a/pkg/connector/client.go b/pkg/connector/client.go index f43b115..5fd5f08 100644 --- a/pkg/connector/client.go +++ b/pkg/connector/client.go @@ -28,6 +28,7 @@ import ( "github.com/bwmarrin/discordgo" "github.com/rs/zerolog" "go.mau.fi/mautrix-discord/pkg/discordid" + "go.mau.fi/mautrix-discord/pkg/msgconv" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -40,6 +41,7 @@ type DiscordClient struct { UserLogin *bridgev2.UserLogin Session *discordgo.Session hasBegunSyncing bool + MsgConv msgconv.MessageConverter } func (d *DiscordConnector) LoadUserLogin(ctx context.Context, login *bridgev2.UserLogin) error { @@ -59,14 +61,16 @@ func (d *DiscordConnector) LoadUserLogin(ctx context.Context, login *bridgev2.Us return err } - // FIXME(skip): Implement. - session.EventHandler = func(evt any) {} - - login.Client = &DiscordClient{ + cl := DiscordClient{ connector: d, UserLogin: login, Session: session, + MsgConv: msgconv.MessageConverter{ + Bridge: d.Bridge, + ReuploadMedia: d.ReuploadMedia, + }, } + login.Client = &cl return nil } @@ -97,10 +101,18 @@ func (d *DiscordClient) Connect(ctx context.Context) { }) } +func (cl *DiscordClient) handleDiscordEventSync(event any) { + go cl.handleDiscordEvent(event) +} + func (cl *DiscordClient) connect(ctx context.Context) error { log := zerolog.Ctx(ctx) log.Info().Msg("Opening session") + cl.Session.EventHandler = func(event any) { + go cl.handleDiscordEvent(event) + } + err := cl.Session.Open() for attempts := 0; errors.Is(err, discordgo.ErrImmediateDisconnect) && attempts < 2; attempts += 1 { log.Err(err).Int("attempts", attempts).Msg("Immediately disconnected while trying to open session, trying again in 5 seconds") diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index 56c156b..cf3a1fd 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -19,17 +19,23 @@ package connector import ( "context" + "go.mau.fi/mautrix-discord/pkg/msgconv" "maunium.net/go/mautrix/bridgev2" ) type DiscordConnector struct { - Bridge *bridgev2.Bridge + Bridge *bridgev2.Bridge + MsgConv *msgconv.MessageConverter } var _ bridgev2.NetworkConnector = (*DiscordConnector)(nil) func (d *DiscordConnector) Init(bridge *bridgev2.Bridge) { d.Bridge = bridge + d.MsgConv = &msgconv.MessageConverter{ + Bridge: bridge, + ReuploadMedia: d.ReuploadMedia, + } } func (d *DiscordConnector) Start(ctx context.Context) error { diff --git a/pkg/connector/handlediscord.go b/pkg/connector/handlediscord.go new file mode 100644 index 0000000..7ea84f0 --- /dev/null +++ b/pkg/connector/handlediscord.go @@ -0,0 +1,118 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2024 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package connector + +import ( + "context" + "runtime/debug" + + "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type DiscordEventMeta struct { + Type bridgev2.RemoteEventType + PortalKey networkid.PortalKey + LogContext func(c zerolog.Context) zerolog.Context +} + +func (em *DiscordEventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + if em.LogContext == nil { + return c + } + c = em.LogContext(c) + return c +} + +func (em *DiscordEventMeta) GetType() bridgev2.RemoteEventType { + return em.Type +} + +func (em *DiscordEventMeta) GetPortalKey() networkid.PortalKey { + return em.PortalKey +} + +type DiscordMessage struct { + *DiscordEventMeta + Data *discordgo.Message + Client *DiscordClient +} + +var ( + _ bridgev2.RemoteMessage = (*DiscordMessage)(nil) + // _ bridgev2.RemoteEdit = (*DiscordMessage)(nil) + // _ bridgev2.RemoteMessageRemove = (*DiscordMessage)(nil) +) + +func (m *DiscordMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return m.Client.connector.MsgConv.ToMatrix(ctx, portal, intent, m.Client.UserLogin, m.Data), nil +} + +func (m *DiscordMessage) GetID() networkid.MessageID { + return networkid.MessageID(m.Data.ID) +} + +func (m *DiscordMessage) GetSender() bridgev2.EventSender { + return m.Client.makeEventSender(m.Data.Author) +} + +func (d *DiscordClient) wrapDiscordMessage(evt *discordgo.MessageCreate) DiscordMessage { + return DiscordMessage{ + DiscordEventMeta: &DiscordEventMeta{ + Type: bridgev2.RemoteEventMessage, + PortalKey: networkid.PortalKey{ + ID: networkid.PortalID(evt.ChannelID), + Receiver: d.UserLogin.ID, + }, + }, + Data: evt.Message, + Client: d, + } +} + +func (d *DiscordClient) handleDiscordEvent(rawEvt any) { + defer func() { + err := recover() + if err != nil { + d.UserLogin.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in Discord event handler") + } + }() + + log := d.UserLogin.Log.With().Str("action", "handle discord event"). + Type("event_type", rawEvt). + Logger() + + switch evt := rawEvt.(type) { + case *discordgo.MessageCreate: + wrappedEvt := d.wrapDiscordMessage(evt) + d.UserLogin.Bridge.QueueRemoteEvent(d.UserLogin, &wrappedEvt) + case *discordgo.PresenceUpdate: + return + case *discordgo.Event: + // For presently unknown reasons sometimes discordgo won't unmarshal + // events into their proper corresponding structs. + if evt.Type == "PRESENCE_UPDATE" || evt.Type == "PASSIVE_UPDATE_V2" || evt.Type == "CONVERSATION_SUMMARY_UPDATE" { + return + } + log.Debug().Str("event_type", evt.Type).Msg("Ignoring unknown Discord event") + } +} diff --git a/pkg/connector/login.go b/pkg/connector/login.go index efea9d1..b5579da 100644 --- a/pkg/connector/login.go +++ b/pkg/connector/login.go @@ -94,9 +94,6 @@ func (dl *DiscordLogin) SubmitUserInput(ctx context.Context, input map[string]st return nil, fmt.Errorf("couldn't create discord session: %w", err) } - // FIXME(skip): Implement. - session.EventHandler = func(evt any) {} - // Set up logging. session.LogLevel = discordgo.LogInformational session.Logger = func(msgL, caller int, format string, a ...any) {