From d883befee5377834fecb5c23b64a8bd5ef2cb68f Mon Sep 17 00:00:00 2001 From: Gary Kramlich Date: Mon, 21 Feb 2022 06:49:44 -0600 Subject: [PATCH] Basic provisioning api Includes: login, logout, and ping. --- bridge/bridge.go | 5 + bridge/commands.go | 22 +++ bridge/provisioning.go | 307 +++++++++++++++++++++++++++++++++++++++++ bridge/user.go | 9 ++ config/appservice.go | 2 + config/config.go | 28 +++- config/provisioning.go | 43 ++++++ database/user.go | 15 ++ remoteauth/client.go | 2 +- run/cmd.go | 2 + 10 files changed, 429 insertions(+), 6 deletions(-) create mode 100644 bridge/provisioning.go create mode 100644 config/provisioning.go diff --git a/bridge/bridge.go b/bridge/bridge.go index c2bddda..1ed867b 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -30,6 +30,7 @@ type Bridge struct { eventProcessor *appservice.EventProcessor matrixHandler *matrixHandler bot *appservice.IntentAPI + provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User usersByID map[string]*User @@ -101,6 +102,10 @@ func New(cfg *config.Config) (*Bridge, error) { StateStore: stateStore, } + if cfg.Appservice.Provisioning.Enabled() { + bridge.provisioning = newProvisioningAPI(bridge) + } + // Setup the event processors bridge.setupEvents() diff --git a/bridge/commands.go b/bridge/commands.go index ab6f57a..aa77f6e 100644 --- a/bridge/commands.go +++ b/bridge/commands.go @@ -48,6 +48,7 @@ type commands struct { Help helpCmd `kong:"cmd,help='Displays this message.'"` Login loginCmd `kong:"cmd,help='Log in to Discord.'"` + Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"` Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"` } @@ -134,3 +135,24 @@ func (l *loginCmd) Run(g *globals) error { return nil } + +type logoutCmd struct{} + +func (l *logoutCmd) Run(g *globals) error { + if !g.user.LoggedIn() { + fmt.Fprintln(g.context.Stdout, "You are not logged in") + + return fmt.Errorf("user is not logged in") + } + + err := g.user.DeleteSession() + if err != nil { + fmt.Fprintln(g.context.Stdout, "Failed to log out") + + return err + } + + fmt.Fprintln(g.context.Stdout, "Successfully logged out") + + return nil +} diff --git a/bridge/provisioning.go b/bridge/provisioning.go new file mode 100644 index 0000000..a41ce8b --- /dev/null +++ b/bridge/provisioning.go @@ -0,0 +1,307 @@ +package bridge + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/gorilla/websocket" + log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix/id" + + "gitlab.com/beeper/discord/remoteauth" +) + +const ( + SecWebSocketProtocol = "com.gitlab.beeper.discord" +) + +type ProvisioningAPI struct { + bridge *Bridge + log log.Logger +} + +func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI { + p := &ProvisioningAPI{ + bridge: bridge, + log: bridge.log.Sub("Provisioning"), + } + + prefix := bridge.Config.Appservice.Provisioning.Prefix + + p.log.Debugln("Enabling provisioning API at", prefix) + + r := bridge.as.Router.PathPrefix(prefix).Subrouter() + + r.Use(p.authMiddleware) + + r.HandleFunc("/ping", p.Ping).Methods(http.MethodGet) + r.HandleFunc("/login", p.Login).Methods(http.MethodGet) + r.HandleFunc("/logout", p.Logout).Methods(http.MethodPost) + + return p +} + +func jsonResponse(w http.ResponseWriter, status int, response interface{}) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(response) +} + +// Response structs +type Response struct { + Success bool `json:"success"` + Status string `json:"status"` +} + +type Error struct { + Success bool `json:"success"` + Error string `json:"error"` + ErrCode string `json:"errcode"` +} + +// Wrapped http.ResponseWriter to capture the status code +type responseWrap struct { + http.ResponseWriter + statusCode int +} + +var _ http.Hijacker = (*responseWrap)(nil) + +func (rw *responseWrap) WriteHeader(statusCode int) { + rw.ResponseWriter.WriteHeader(statusCode) + rw.statusCode = statusCode +} + +func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("response does not implement http.Hijacker") + } + return hijacker.Hijack() +} + +// Middleware +func (p *ProvisioningAPI) authMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + + // Special case the login endpoint to use the discord qrcode auth + if auth == "" && strings.HasSuffix(r.URL.Path, "/login") { + authParts := strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ",") + for _, part := range authParts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, SecWebSocketProtocol+"-") { + auth = part[len(SecWebSocketProtocol+"-"):] + + break + } + } + } else if strings.HasPrefix(auth, "Bearer ") { + auth = auth[len("Bearer "):] + } + + if auth != p.bridge.Config.Appservice.Provisioning.SharedSecret { + jsonResponse(w, http.StatusForbidden, map[string]interface{}{ + "error": "Invalid auth token", + "errcode": "M_FORBIDDEN", + }) + + return + } + + userID := r.URL.Query().Get("user_id") + user := p.bridge.GetUserByMXID(id.UserID(userID)) + + start := time.Now() + wWrap := &responseWrap{w, 200} + h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) + duration := time.Now().Sub(start).Seconds() + + p.log.Infofln("%s %s from %s took %.2f seconds and returned status %d", r.Method, r.URL.Path, user.MXID, duration, wWrap.statusCode) + }) +} + +// websocket upgrader +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + Subprotocols: []string{SecWebSocketProtocol}, +} + +// Handlers +func (p *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + discord := map[string]interface{}{ + "has_session": user.Session != nil, + "management_room": user.ManagementRoom, + "conn": nil, + } + + if user.ID != "" { + discord["id"] = user.ID + } + + if user.Session != nil { + discord["conn"] = map[string]interface{}{ + "last_heartbeat_ack": user.Session.LastHeartbeatAck, + "last_heartbeat_sent": user.Session.LastHeartbeatSent, + } + } + + resp := map[string]interface{}{ + "mxid": user.MXID, + "discord": discord, + } + + jsonResponse(w, http.StatusOK, resp) +} + +func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + force := strings.ToLower(r.URL.Query().Get("force")) != "false" + + if user.Session == nil { + if force { + jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) + } else { + jsonResponse(w, http.StatusNotFound, Error{ + Error: "You're not logged in", + ErrCode: "not logged in", + }) + } + + return + } + + err := user.DeleteSession() + if err != nil { + user.log.Warnln("Error while logging out:", err) + + if !force { + jsonResponse(w, http.StatusInternalServerError, Error{ + Error: fmt.Sprintf("Unknown error while logging out: %v", err), + ErrCode: err.Error(), + }) + + return + } + } + + jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) +} + +func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { + userID := r.URL.Query().Get("user_id") + user := p.bridge.GetUserByMXID(id.UserID(userID)) + + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + p.log.Errorln("Failed to upgrade connection to websocket:", err) + return + } + + defer func() { + err := c.Close() + if err != nil { + user.log.Debugln("Error closing websocket:", err) + } + }() + + go func() { + // Read everything so SetCloseHandler() works + for { + _, _, err = c.ReadMessage() + if err != nil { + break + } + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + c.SetCloseHandler(func(code int, text string) error { + user.log.Debugfln("Login websocket closed (%d), cancelling login", code) + + cancel() + + return nil + }) + + client, err := remoteauth.New() + if err != nil { + user.log.Errorf("Failed to log in from provisioning API:", err) + + c.WriteJSON(Error{ + Error: "Failed to connect to Discord", + ErrCode: "connection error", + }) + } + + qrChan := make(chan string) + doneChan := make(chan struct{}) + + user.log.Debugln("Started login via provisioning API") + + err = client.Dial(ctx, qrChan, doneChan) + if err != nil { + close(qrChan) + close(doneChan) + } + + for { + select { + case qrCode, ok := <-qrChan: + if !ok { + continue + } + c.WriteJSON(map[string]interface{}{ + "code": qrCode, + "timeout": 120, // TODO: move this to the library or something + }) + case <-doneChan: + discordUser, err := client.Result() + if err != nil { + c.WriteJSON(Error{ + Error: "Failed to connect to Discord", + ErrCode: "connection error", + }) + + p.log.Errorfln("failed to login via qrcode:", err) + + return + } + + if err := user.Login(discordUser.Token); err != nil { + c.WriteJSON(Error{ + Error: "Failed to connect to Discord", + ErrCode: "connection error", + }) + + p.log.Errorfln("failed to login via qrcode:", err) + + return + } + + user.ID = discordUser.UserID + user.Update() + + c.WriteJSON(map[string]interface{}{ + "success": true, + "id": user.ID, + }) + + return + case <-ctx.Done(): + return + } + } +} diff --git a/bridge/user.go b/bridge/user.go index 72218dd..cc0561d 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -2,6 +2,7 @@ package bridge import ( "errors" + "fmt" "strings" "github.com/bwmarrin/discordgo" @@ -184,6 +185,10 @@ func (u *User) uploadQRCode(code string) (id.ContentURI, error) { } func (u *User) Login(token string) error { + if token == "" { + return fmt.Errorf("No token specified") + } + err := u.User.NewSession(token) if err != nil { return err @@ -192,6 +197,10 @@ func (u *User) Login(token string) error { return u.Connect() } +func (u *User) LoggedIn() bool { + return u.Session != nil +} + func (u *User) Connect() error { u.log.Debugln("connecting to discord") diff --git a/config/appservice.go b/config/appservice.go index 9e2bc56..a37d1ed 100644 --- a/config/appservice.go +++ b/config/appservice.go @@ -13,6 +13,8 @@ type appservice struct { Bot bot `yaml:"bot"` + Provisioning provisioning `yaml:"provisioning"` + Database database `yaml:"database"` ASToken string `yaml:"as_token"` diff --git a/config/config.go b/config/config.go index 01a5755..cddbd55 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "io/ioutil" "gopkg.in/yaml.v2" @@ -11,8 +12,12 @@ type Config struct { Appservice appservice `yaml:"appservice"` Bridge bridge `yaml:"bridge"` Logging logging `yaml:"logging"` + + filename string `yaml:"-"` } +var configUpdated bool + func (cfg *Config) validate() error { if err := cfg.Homeserver.validate(); err != nil { return err @@ -30,13 +35,20 @@ func (cfg *Config) validate() error { return err } + if configUpdated { + return cfg.Save(cfg.filename) + } + return nil } func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { type rawConfig Config - raw := rawConfig{} + raw := rawConfig{ + filename: cfg.filename, + } + if err := unmarshal(&raw); err != nil { return err } @@ -46,8 +58,10 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { return cfg.validate() } -func FromBytes(data []byte) (*Config, error) { - cfg := Config{} +func FromBytes(filename string, data []byte) (*Config, error) { + cfg := Config{ + filename: filename, + } if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, err @@ -61,7 +75,7 @@ func FromBytes(data []byte) (*Config, error) { } func FromString(str string) (*Config, error) { - return FromBytes([]byte(str)) + return FromBytes("", []byte(str)) } func FromFile(filename string) (*Config, error) { @@ -70,10 +84,14 @@ func FromFile(filename string) (*Config, error) { return nil, err } - return FromBytes(data) + return FromBytes(filename, data) } func (cfg *Config) Save(filename string) error { + if filename == "" { + return fmt.Errorf("no filename specified yep") + } + data, err := yaml.Marshal(cfg) if err != nil { return err diff --git a/config/provisioning.go b/config/provisioning.go new file mode 100644 index 0000000..70518b6 --- /dev/null +++ b/config/provisioning.go @@ -0,0 +1,43 @@ +package config + +import ( + "strings" + + as "maunium.net/go/mautrix/appservice" +) + +type provisioning struct { + Prefix string `yaml:"prefix"` + SharedSecret string `yaml:"shared_secret"` +} + +func (p *provisioning) validate() error { + if p.Prefix == "" { + p.Prefix = "/_matrix/provision/v1" + } + + if strings.ToLower(p.SharedSecret) == "generate" { + p.SharedSecret = as.RandomString(64) + + configUpdated = true + } + + return nil +} + +func (p *provisioning) UnmarshalYAML(unmarshal func(interface{}) error) error { + type rawProvisioning provisioning + + raw := rawProvisioning{} + if err := unmarshal(&raw); err != nil { + return err + } + + *p = provisioning(raw) + + return p.validate() +} + +func (p *provisioning) Enabled() bool { + return strings.ToLower(p.SharedSecret) != "disable" +} diff --git a/database/user.go b/database/user.go index 4443fb8..fc2ee03 100644 --- a/database/user.go +++ b/database/user.go @@ -37,6 +37,21 @@ func (u *User) NewSession(token string) error { return nil } +// DeleteSession tries to logout and delete the session from the database. +func (u *User) DeleteSession() error { + err := u.Session.Close() + + if err != nil { + u.log.Warnfln("failed to close the session for %s: %v", u.ID, err) + } + + u.Session = nil + + u.Update() + + return nil +} + func (u *User) Scan(row Scannable) *User { var token sql.NullString diff --git a/remoteauth/client.go b/remoteauth/client.go index 94a1537..4b96d8d 100644 --- a/remoteauth/client.go +++ b/remoteauth/client.go @@ -45,7 +45,7 @@ func New() (*Client, error) { }, nil } -// Dialo will start the QRCode login process. ctx may be used to abandon the +// Dial will start the QRCode login process. ctx may be used to abandon the // process. func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error { header := http.Header{ diff --git a/run/cmd.go b/run/cmd.go index c5af845..06846e9 100644 --- a/run/cmd.go +++ b/run/cmd.go @@ -1,6 +1,7 @@ package run import ( + "fmt" "os" "os/signal" "syscall" @@ -13,6 +14,7 @@ import ( type Cmd struct{} func (c *Cmd) Run(g *globals.Globals) error { + fmt.Printf("g.Config: %q\n", g.Config) cfg, err := config.FromFile(g.Config) if err != nil { return err