diff --git a/emoji.go b/emoji.go index 1130c7f..2365926 100644 --- a/emoji.go +++ b/emoji.go @@ -10,6 +10,32 @@ import ( "maunium.net/go/mautrix/id" ) +func (portal *Portal) getEmojiMXCByDiscordID(emojiID, name string, animated bool) id.ContentURI { + dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(emojiID) + + if dbEmoji == nil { + data, mimeType, err := portal.downloadDiscordEmoji(emojiID, animated) + if err != nil { + portal.log.Warnfln("Failed to download emoji %s from discord: %v", emojiID, err) + return id.ContentURI{} + } + + uri, err := portal.uploadMatrixEmoji(portal.MainIntent(), data, mimeType) + if err != nil { + portal.log.Warnfln("Failed to upload discord emoji %s to homeserver: %v", emojiID, err) + return id.ContentURI{} + } + + dbEmoji = portal.bridge.DB.Emoji.New() + dbEmoji.DiscordID = emojiID + dbEmoji.DiscordName = name + dbEmoji.MatrixURL = uri + dbEmoji.Insert() + } + + return dbEmoji.MatrixURL +} + func (portal *Portal) downloadDiscordEmoji(id string, animated bool) ([]byte, string, error) { var url string var mimeType string diff --git a/formatter.go b/formatter.go index 934ac07..f3dcef5 100644 --- a/formatter.go +++ b/formatter.go @@ -1,3 +1,19 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2022 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + package main import ( @@ -13,14 +29,17 @@ import ( "maunium.net/go/mautrix/format/mdext" ) -var mdRenderer = goldmark.New(format.Extensions, format.HTMLOptions, - goldmark.WithExtensions(mdext.EscapeHTML, mdext.SimpleSpoiler, mdext.DiscordUnderline)) +var discordExtensions = goldmark.WithExtensions(mdext.EscapeHTML, mdext.SimpleSpoiler, mdext.DiscordUnderline) var escapeFixer = regexp.MustCompile(`\\(__[^_]|\*\*[^*])`) -func renderDiscordMarkdown(text string) event.MessageEventContent { +func (portal *Portal) renderDiscordMarkdown(text string) event.MessageEventContent { text = escapeFixer.ReplaceAllStringFunc(text, func(s string) string { return s[:2] + `\` + s[2:] }) + mdRenderer := goldmark.New( + format.Extensions, format.HTMLOptions, discordExtensions, + goldmark.WithExtensions(&DiscordTag{portal}), + ) return format.RenderMarkdownCustom(text, mdRenderer) } diff --git a/formatter_tag.go b/formatter_tag.go new file mode 100644 index 0000000..120121c --- /dev/null +++ b/formatter_tag.go @@ -0,0 +1,208 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2022 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" + "maunium.net/go/mautrix" + + "go.mau.fi/mautrix-discord/database" +) + +type astDiscordTag struct { + ast.BaseInline + id int64 +} + +var _ ast.Node = (*astDiscordTag)(nil) +var astKindDiscordTag = ast.NewNodeKind("DiscordTag") + +func (n *astDiscordTag) Dump(source []byte, level int) { + ast.DumpHelper(n, source, level, nil, nil) +} + +func (n *astDiscordTag) Kind() ast.NodeKind { + return astKindDiscordTag +} + +type astDiscordUserMention struct { + astDiscordTag + hasNick bool +} + +func (n *astDiscordUserMention) String() string { + if n.hasNick { + return fmt.Sprintf("<@!%d>", n.id) + } + return fmt.Sprintf("<@%d>", n.id) +} + +type astDiscordRoleMention struct { + astDiscordTag +} + +func (n *astDiscordRoleMention) String() string { + return fmt.Sprintf("<@&%d>", n.id) +} + +type astDiscordChannelMention struct { + astDiscordTag + + guildID int64 + name string +} + +func (n *astDiscordChannelMention) String() string { + if n.guildID != 0 { + return fmt.Sprintf("<#%d:%d:%s>", n.id, n.guildID, n.name) + } + return fmt.Sprintf("<#%d>", n.id) +} + +type astDiscordCustomEmoji struct { + astDiscordTag + name string + animated bool +} + +func (n *astDiscordCustomEmoji) String() string { + if n.animated { + return fmt.Sprintf("", n.name, n.id) + } + return fmt.Sprintf("<%s%s>", n.name, n.id) +} + +type discordTagParser struct{} + +var discordTagRegex = regexp.MustCompile(`<(a?:\w+:|@[!&]?|#)(\d+)(?::(\d+):(.+?))?>`) +var defaultDiscordTagParser = &discordTagParser{} + +func (s *discordTagParser) Trigger() []byte { + return []byte{'<'} +} + +func (s *discordTagParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + //before := block.PrecendingCharacter() + line, _ := block.PeekLine() + match := discordTagRegex.FindSubmatch(line) + if match == nil { + return nil + } + //seg := segment.WithStop(segment.Start + len(match[0])) + block.Advance(len(match[0])) + + id, err := strconv.ParseInt(string(match[2]), 10, 64) + if err != nil { + return nil + } + tag := astDiscordTag{id: id} + tagName := string(match[1]) + switch { + case tagName == "@": + return &astDiscordUserMention{astDiscordTag: tag} + case tagName == "@!": + return &astDiscordUserMention{astDiscordTag: tag, hasNick: true} + case tagName == "@&": + return &astDiscordRoleMention{astDiscordTag: tag} + case tagName == "#": + var guildID int64 + var channelName string + if len(match[3]) > 0 && len(match[4]) > 0 { + guildID, _ = strconv.ParseInt(string(match[3]), 10, 64) + channelName = string(match[4]) + } + return &astDiscordChannelMention{astDiscordTag: tag, guildID: guildID, name: channelName} + case strings.HasPrefix(tagName, ":"): + return &astDiscordCustomEmoji{name: tagName, astDiscordTag: tag} + case strings.HasPrefix(tagName, "a:"): + return &astDiscordCustomEmoji{name: tagName[1:], astDiscordTag: tag} + default: + return nil + } +} + +func (s *discordTagParser) CloseBlock(parent ast.Node, pc parser.Context) { + // nothing to do +} + +type discordTagHTMLRenderer struct { + portal *Portal +} + +func (r *discordTagHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(astKindDiscordTag, r.renderDiscordMention) +} + +func (r *discordTagHTMLRenderer) renderDiscordMention(w util.BufWriter, source []byte, n ast.Node, entering bool) (status ast.WalkStatus, err error) { + status = ast.WalkContinue + if !entering { + return + } + switch node := n.(type) { + case *astDiscordUserMention: + puppet := r.portal.bridge.GetPuppetByID(strconv.FormatInt(node.id, 10)) + _, _ = fmt.Fprintf(w, `%s`, puppet.MXID, puppet.Name) + return + case *astDiscordRoleMention: + // TODO + case *astDiscordChannelMention: + portal := r.portal.bridge.GetExistingPortalByID(database.PortalKey{ + ChannelID: strconv.FormatInt(node.id, 10), + Receiver: "", + }) + if portal != nil { + _, _ = fmt.Fprintf(w, `%s`, portal.MXID, portal.bridge.AS.HomeserverDomain, portal.Name) + return + } + case *astDiscordCustomEmoji: + reactionMXC := r.portal.getEmojiMXCByDiscordID(strconv.FormatInt(node.id, 10), node.name, node.animated) + if !reactionMXC.IsEmpty() { + _, _ = fmt.Fprintf(w, `%[2]s`, reactionMXC.String(), node.name) + return + } + } + stringifiable, ok := n.(mautrix.Stringifiable) + if ok { + _, _ = w.WriteString(stringifiable.String()) + } else { + _, _ = w.Write(source) + } + return +} + +type DiscordTag struct { + Portal *Portal +} + +func (e *DiscordTag) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(defaultDiscordTagParser, 600), + )) + m.Renderer().AddOptions(renderer.WithNodeRenderers( + util.Prioritized(&discordTagHTMLRenderer{e.Portal}, 600), + )) +} diff --git a/portal.go b/portal.go index 000ea3c..1190be6 100644 --- a/portal.go +++ b/portal.go @@ -611,7 +611,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess var parts []database.MessagePart ts, _ := discordgo.SnowflakeTimestamp(msg.ID) if msg.Content != "" { - content := renderDiscordMarkdown(msg.Content) + content := portal.renderDiscordMarkdown(msg.Content) content.RelatesTo = threadRelation.Copy() if msg.MessageReference != nil { @@ -697,24 +697,24 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess attachmentMap[existingPart.AttachmentID] = existingPart } } - for _, attachment := range msg.Attachments { - if _, found := attachmentMap[attachment.ID]; found { - delete(attachmentMap, attachment.ID) + for _, remainingAttachment := range msg.Attachments { + if _, found := attachmentMap[remainingAttachment.ID]; found { + delete(attachmentMap, remainingAttachment.ID) } } - for _, attachment := range attachmentMap { - _, err := intent.RedactEvent(portal.MXID, attachment.MXID) + for _, deletedAttachment := range attachmentMap { + _, err := intent.RedactEvent(portal.MXID, deletedAttachment.MXID) if err != nil { - portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err) + portal.log.Warnfln("Failed to remove attachment %s: %v", deletedAttachment.MXID, err) } - attachment.Delete() + deletedAttachment.Delete() } if msg.Content == "" || existing[0].AttachmentID != "" { portal.log.Debugfln("Dropping non-text edit to %s (message on matrix: %t, text on discord: %t)", msg.ID, existing[0].AttachmentID == "", len(msg.Content) > 0) return } - content := renderDiscordMarkdown(msg.Content) + content := portal.renderDiscordMarkdown(msg.Content) content.SetEdit(existing[0].MXID) var editTS int64 @@ -885,7 +885,6 @@ func (portal *Portal) startThreadFromMatrix(sender *User, threadRoot id.EventID) return "", fmt.Errorf("error starting thread: %v", err) } portal.log.Debugfln("Created Discord thread from %s/%s", threadRoot, ch.ID) - fmt.Printf("Created thread %+v\n", ch) portal.bridge.GetThreadByID(existingMsg.DiscordID, existingMsg) return ch.ID, nil } @@ -1295,32 +1294,12 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess var matrixReaction string if reaction.Emoji.ID != "" { - dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(reaction.Emoji.ID) - - if dbEmoji == nil { - data, mimeType, err := portal.downloadDiscordEmoji(reaction.Emoji.ID, reaction.Emoji.Animated) - if err != nil { - portal.log.Warnfln("Failed to download emoji %s from discord: %v", reaction.Emoji.ID, err) - - return - } - - uri, err := portal.uploadMatrixEmoji(intent, data, mimeType) - if err != nil { - portal.log.Warnfln("Failed to upload discord emoji %s to homeserver: %v", reaction.Emoji.ID, err) - - return - } - - dbEmoji = portal.bridge.DB.Emoji.New() - dbEmoji.DiscordID = reaction.Emoji.ID - dbEmoji.DiscordName = reaction.Emoji.Name - dbEmoji.MatrixURL = uri - dbEmoji.Insert() + reactionMXC := portal.getEmojiMXCByDiscordID(reaction.Emoji.ID, reaction.Emoji.Name, reaction.Emoji.Animated) + if reactionMXC.IsEmpty() { + return } - - discordID = dbEmoji.DiscordID - matrixReaction = dbEmoji.MatrixURL.String() + matrixReaction = reactionMXC.String() + discordID = reaction.Emoji.ID } else { discordID = reaction.Emoji.Name matrixReaction = variationselector.Add(reaction.Emoji.Name) diff --git a/user.go b/user.go index c15a676..c802dfe 100644 --- a/user.go +++ b/user.go @@ -675,9 +675,6 @@ func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.Channe } func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) { - if user.Session.LogLevel == discordgo.LogDebug { - fmt.Printf("%+v\n", msg) - } if !user.bridgeMessage(guildID) { return }