diff --git a/commands.go b/commands.go index f68abf3..93f402b 100644 --- a/commands.go +++ b/commands.go @@ -54,6 +54,8 @@ func (br *DiscordBridge) RegisterCommands() { cmdGuilds, cmdRejoinSpace, cmdDeleteAllPortals, + cmdExec, + cmdCommands, ) } diff --git a/commands_botinteraction.go b/commands_botinteraction.go new file mode 100644 index 0000000..1e37fea --- /dev/null +++ b/commands_botinteraction.go @@ -0,0 +1,316 @@ +// mautrix-discord - A Matrix-Discord puppeting bridge. +// Copyright (C) 2023 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package main + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/google/shlex" + + "maunium.net/go/mautrix/bridge/commands" +) + +var cmdCommands = &commands.FullHandler{ + Func: wrapCommand(fnCommands), + Name: "commands", + Aliases: []string{"cmds", "cs"}, + Help: commands.HelpMeta{ + Section: commands.HelpSectionUnclassified, + Description: "View parameters of bot interaction commands on Discord", + Args: "search <_query_> OR help <_command_>", + }, + RequiresPortal: true, + RequiresLogin: true, +} + +var cmdExec = &commands.FullHandler{ + Func: wrapCommand(fnExec), + Name: "exec", + Aliases: []string{"command", "cmd", "c", "exec", "e"}, + Help: commands.HelpMeta{ + Section: commands.HelpSectionUnclassified, + Description: "Run bot interaction commands on Discord", + Args: "<_command_> [_arg=value ..._]", + }, + RequiresLogin: true, + RequiresPortal: true, +} + +func (portal *Portal) getCommand(user *User, command string) (*discordgo.ApplicationCommand, error) { + portal.commandsLock.Lock() + defer portal.commandsLock.Unlock() + cmd, ok := portal.commands[command] + if !ok { + results, err := user.Session.ApplicationCommandsSearch(portal.Key.ChannelID, command) + if err != nil { + return nil, err + } + for _, result := range results { + if result.Name == command { + portal.commands[result.Name] = result + cmd = result + break + } + } + if cmd == nil { + return nil, nil + } + } + return cmd, nil +} + +func getCommandOptionTypeName(optType discordgo.ApplicationCommandOptionType) string { + switch optType { + case discordgo.ApplicationCommandOptionSubCommand: + return "subcommand" + case discordgo.ApplicationCommandOptionSubCommandGroup: + return "subcommand group (unsupported)" + case discordgo.ApplicationCommandOptionString: + return "string" + case discordgo.ApplicationCommandOptionInteger: + return "integer" + case discordgo.ApplicationCommandOptionBoolean: + return "boolean" + case discordgo.ApplicationCommandOptionUser: + return "user (unsupported)" + case discordgo.ApplicationCommandOptionChannel: + return "channel (unsupported)" + case discordgo.ApplicationCommandOptionRole: + return "role (unsupported)" + case discordgo.ApplicationCommandOptionMentionable: + return "mentionable (unsupported)" + case discordgo.ApplicationCommandOptionNumber: + return "number" + case discordgo.ApplicationCommandOptionAttachment: + return "attachment (unsupported)" + default: + return fmt.Sprintf("unknown type %d", optType) + } +} + +func parseCommandOptionValue(optType discordgo.ApplicationCommandOptionType, value string) (any, error) { + switch optType { + case discordgo.ApplicationCommandOptionSubCommandGroup: + return nil, fmt.Errorf("subcommand groups aren't supported") + case discordgo.ApplicationCommandOptionString: + return value, nil + case discordgo.ApplicationCommandOptionInteger: + return strconv.ParseInt(value, 10, 64) + case discordgo.ApplicationCommandOptionBoolean: + return strconv.ParseBool(value) + case discordgo.ApplicationCommandOptionUser: + return nil, fmt.Errorf("user options aren't supported") + case discordgo.ApplicationCommandOptionChannel: + return nil, fmt.Errorf("channel options aren't supported") + case discordgo.ApplicationCommandOptionRole: + return nil, fmt.Errorf("role options aren't supported") + case discordgo.ApplicationCommandOptionMentionable: + return nil, fmt.Errorf("mentionable options aren't supported") + case discordgo.ApplicationCommandOptionNumber: + return strconv.ParseFloat(value, 64) + case discordgo.ApplicationCommandOptionAttachment: + return nil, fmt.Errorf("attachment options aren't supported") + default: + return nil, fmt.Errorf("unknown option type %d", optType) + } +} + +func indent(text, with string) string { + split := strings.Split(text, "\n") + for i, part := range split { + split[i] = with + part + } + return strings.Join(split, "\n") +} + +func formatOption(opt *discordgo.ApplicationCommandOption) string { + argText := fmt.Sprintf("* `%s`: %s", opt.Name, getCommandOptionTypeName(opt.Type)) + if strings.ToLower(opt.Description) != opt.Name { + argText += fmt.Sprintf(" - %s", opt.Description) + } + if opt.Required { + argText += " (required)" + } + if len(opt.Options) > 0 { + subopts := make([]string, len(opt.Options)) + for i, subopt := range opt.Options { + subopts[i] = indent(formatOption(subopt), " ") + } + argText += "\n" + strings.Join(subopts, "\n") + } + return argText +} + +func formatCommand(cmd *discordgo.ApplicationCommand) string { + baseText := fmt.Sprintf("$cmdprefix exec %s", cmd.Name) + if len(cmd.Options) > 0 { + args := make([]string, len(cmd.Options)) + argPlaceholder := "[arg=value ...]" + for i, opt := range cmd.Options { + args[i] = formatOption(opt) + if opt.Required { + argPlaceholder = "" + } + } + baseText = fmt.Sprintf("`%s %s` - %s\n%s", baseText, argPlaceholder, cmd.Description, strings.Join(args, "\n")) + } else { + baseText = fmt.Sprintf("`%s` - %s", baseText, cmd.Description) + } + return baseText +} + +func parseCommandOptions(opts []*discordgo.ApplicationCommandOption, subcommands []string, namedArgs map[string]string) (res []*discordgo.ApplicationCommandOptionInput, err error) { + subcommandDone := false + for _, opt := range opts { + optRes := &discordgo.ApplicationCommandOptionInput{ + Type: opt.Type, + Name: opt.Name, + } + if opt.Type == discordgo.ApplicationCommandOptionSubCommand { + if !subcommandDone && len(subcommands) > 0 && subcommands[0] == opt.Name { + subcommandDone = true + optRes.Options, err = parseCommandOptions(opt.Options, subcommands[1:], namedArgs) + if err != nil { + err = fmt.Errorf("error parsing subcommand %s: %v", opt.Name, err) + break + } + subcommands = subcommands[1:] + } else { + continue + } + } else if argVal, ok := namedArgs[opt.Name]; ok { + optRes.Value, err = parseCommandOptionValue(opt.Type, argVal) + if err != nil { + err = fmt.Errorf("error parsing parameter %s: %v", opt.Name, err) + break + } + } else if opt.Required { + switch opt.Type { + case discordgo.ApplicationCommandOptionSubCommandGroup, discordgo.ApplicationCommandOptionUser, + discordgo.ApplicationCommandOptionChannel, discordgo.ApplicationCommandOptionRole, + discordgo.ApplicationCommandOptionMentionable, discordgo.ApplicationCommandOptionAttachment: + err = fmt.Errorf("missing required parameter %s (which is not supported by the bridge)", opt.Name) + default: + err = fmt.Errorf("missing required parameter %s", opt.Name) + } + break + } else { + continue + } + res = append(res, optRes) + } + if len(subcommands) > 0 { + err = fmt.Errorf("unparsed subcommands left over (did you forget quoting for parameters with spaces?)") + } + return +} + +func executeCommand(cmd *discordgo.ApplicationCommand, args []string) (res []*discordgo.ApplicationCommandOptionInput, err error) { + namedArgs := map[string]string{} + n := 0 + for _, arg := range args { + name, value, isNamed := strings.Cut(arg, "=") + if isNamed { + namedArgs[name] = value + } else { + args[n] = arg + n++ + } + } + return parseCommandOptions(cmd.Options, args[:n], namedArgs) +} + +func fnCommands(ce *WrappedCommandEvent) { + if len(ce.Args) < 2 { + ce.Reply("**Usage**: `$cmdprefix commands search <_query_>` OR `$cmdprefix commands help <_command_>`") + return + } + subcmd := strings.ToLower(ce.Args[0]) + if subcmd == "search" { + results, err := ce.User.Session.ApplicationCommandsSearch(ce.Portal.Key.ChannelID, ce.Args[1]) + if err != nil { + ce.Reply("Error searching for commands: %v", err) + return + } + formatted := make([]string, len(results)) + ce.Portal.commandsLock.Lock() + for i, result := range results { + ce.Portal.commands[result.Name] = result + formatted[i] = indent(formatCommand(result), " ") + formatted[i] = "*" + formatted[i][1:] + } + ce.Portal.commandsLock.Unlock() + ce.Reply("Found results:\n" + strings.Join(formatted, "\n")) + } else if subcmd == "help" { + command := strings.ToLower(ce.Args[1]) + cmd, err := ce.Portal.getCommand(ce.User, command) + if err != nil { + ce.Reply("Error searching for commands: %v", err) + } else if cmd == nil { + ce.Reply("Command %q not found", command) + } else { + ce.Reply(formatCommand(cmd)) + } + } +} + +func fnExec(ce *WrappedCommandEvent) { + if len(ce.Args) == 0 { + ce.Reply("**Usage**: `$cmdprefix exec [arg=value ...]`") + return + } + args, err := shlex.Split(ce.RawArgs) + if err != nil { + ce.Reply("Error parsing args with shlex: %v", err) + return + } + command := strings.ToLower(args[0]) + cmd, err := ce.Portal.getCommand(ce.User, command) + if err != nil { + ce.Reply("Error searching for commands: %v", err) + } else if cmd == nil { + ce.Reply("Command %q not found", command) + } else if options, err := executeCommand(cmd, args[1:]); err != nil { + ce.Reply("Error parsing arguments: %v\n\n**Usage:** "+formatCommand(cmd), err) + } else { + nonce := generateNonce() + ce.User.pendingInteractionsLock.Lock() + ce.User.pendingInteractions[nonce] = ce + ce.User.pendingInteractionsLock.Unlock() + err = ce.User.Session.SendInteractions(ce.Portal.GuildID, ce.Portal.Key.ChannelID, cmd, options, nonce) + if err != nil { + ce.Reply("Error sending interaction: %v", err) + ce.User.pendingInteractionsLock.Lock() + delete(ce.User.pendingInteractions, nonce) + ce.User.pendingInteractionsLock.Unlock() + } else { + go func() { + time.Sleep(10 * time.Second) + ce.User.pendingInteractionsLock.Lock() + if _, stillWaiting := ce.User.pendingInteractions[nonce]; stillWaiting { + delete(ce.User.pendingInteractions, nonce) + ce.Reply("Timed out waiting for interaction success") + } + ce.User.pendingInteractionsLock.Unlock() + }() + } + } +} diff --git a/go.mod b/go.mod index 697f542..2355631 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/bwmarrin/discordgo v0.26.1 github.com/gabriel-vasile/mimetype v1.4.1 + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.7 @@ -13,7 +14,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/yuin/goldmark v1.5.3 maunium.net/go/maulogger/v2 v2.3.2 - maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6 + maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437 ) require ( @@ -33,4 +34,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e +replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0 diff --git a/go.sum b/go.sum index d6230a3..58b3e67 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e h1:R0Db6p3gANvV2Hk8lbSSlPDNG3zzeOM8nyZHmLl3tkI= -github.com/beeper/discordgo v0.0.0-20230128134018-766d08cb045e/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0 h1:RrF9ffkMyEsUtZqWR/m/KXSrYbpyT7bkuL+KY8pexSE= +github.com/beeper/discordgo v0.0.0-20230129113304-139f45f429a0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -8,6 +8,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/gabriel-vasile/mimetype v1.4.1 h1:TRWk7se+TOjCYgRth7+1/OYLNiRNIotknkFtf/dnN7Q= github.com/gabriel-vasile/mimetype v1.4.1/go.mod h1:05Vi0w3Y9c/lNvJOdmIwvrrAhX3rYhfQQCaf9VJcv7M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= +github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= @@ -75,5 +77,5 @@ maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= -maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6 h1:c2HKxT3wYxWS213BXaWDY3UkHGfXGmhzOv4h1OKZm20= -maunium.net/go/mautrix v0.13.1-0.20230128124647-7d98a9f8e3a6/go.mod h1:gYMQPsZ9lQpyKlVp+DGwOuc9LIcE/c8GZW2CvKHISgM= +maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437 h1:BMfeE1yJNs97rIXCRzIY284g7dXa7E6OaM0HCWpddwU= +maunium.net/go/mautrix v0.13.1-0.20230129104640-4a2a7653e437/go.mod h1:gYMQPsZ9lQpyKlVp+DGwOuc9LIcE/c8GZW2CvKHISgM= diff --git a/portal.go b/portal.go index 0333df8..9fdf18e 100644 --- a/portal.go +++ b/portal.go @@ -60,6 +60,9 @@ type Portal struct { recentMessages *util.RingBuffer[string, *discordgo.Message] + commands map[string]*discordgo.ApplicationCommand + commandsLock sync.RWMutex + currentlyTyping []id.UserID currentlyTypingLock sync.Mutex } @@ -232,6 +235,8 @@ func (br *DiscordBridge) NewPortal(dbPortal *database.Portal) *Portal { matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), recentMessages: util.NewRingBuffer[string, *discordgo.Message](recentMessageBufferSize), + + commands: make(map[string]*discordgo.ApplicationCommand), } go portal.messageLoop() diff --git a/user.go b/user.go index b4c1ff4..1e4dba2 100644 --- a/user.go +++ b/user.go @@ -59,6 +59,9 @@ type User struct { markedOpened map[string]time.Time markedOpenedLock sync.Mutex + pendingInteractions map[string]*WrappedCommandEvent + pendingInteractionsLock sync.Mutex + nextDiscordUploadID atomic.Int32 } @@ -197,6 +200,8 @@ func (br *DiscordBridge) NewUser(dbUser *database.User) *User { markedOpened: make(map[string]time.Time), PermissionLevel: br.Config.Bridge.Permissions.Get(dbUser.MXID), + + pendingInteractions: make(map[string]*WrappedCommandEvent), } user.nextDiscordUploadID.Store(rand.Int31n(100)) user.BridgeState = br.NewBridgeStateQueue(user, user.log) @@ -540,6 +545,8 @@ func (user *User) Connect() error { user.Session.AddHandler(user.messageAckHandler) user.Session.AddHandler(user.typingStartHandler) + user.Session.AddHandler(user.interactionSuccessHandler) + user.Session.Identify.Presence.Status = "online" return user.Session.Open() @@ -963,6 +970,19 @@ func (user *User) typingStartHandler(_ *discordgo.Session, t *discordgo.TypingSt } } +func (user *User) interactionSuccessHandler(_ *discordgo.Session, s *discordgo.InteractionSuccess) { + user.pendingInteractionsLock.Lock() + defer user.pendingInteractionsLock.Unlock() + ce, ok := user.pendingInteractions[s.Nonce] + if !ok { + user.log.Debugfln("Got interaction success for unknown interaction %s/%s", s.Nonce, s.ID) + } else { + user.log.Infofln("Got interaction success for pending interaction %s/%s", s.Nonce, s.ID) + ce.React("✅") + delete(user.pendingInteractions, s.Nonce) + } +} + func (user *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool { if intent == nil { intent = user.bridge.Bot