diff --git a/bridge/commandhandler.go b/bridge/commandhandler.go index c5f91e8..3bd49c2 100644 --- a/bridge/commandhandler.go +++ b/bridge/commandhandler.go @@ -101,7 +101,12 @@ func (h *commandHandler) handle(roomID id.RoomID, user *User, message string, re if err != nil { h.log.Warnf("Command %q failed: %v", message, err) - cmd.globals.reply("unexpected failure") + output := buf.String() + if output != "" { + cmd.globals.reply(output) + } else { + cmd.globals.reply("unexpected failure") + } return } diff --git a/bridge/commands.go b/bridge/commands.go index aa77f6e..0d7fbba 100644 --- a/bridge/commands.go +++ b/bridge/commands.go @@ -46,10 +46,12 @@ func (g *globals) reply(msg string) { type commands struct { globals - Help helpCmd `kong:"cmd,help='Displays this message.'"` - Login loginCmd `kong:"cmd,help='Log in to Discord.'"` - Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"` - Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"` + Disconnect disconnectCmd `kong:"cmd,help='Disconnect from Discord'"` + Help helpCmd `kong:"cmd,help='Displays this message.'"` + Login loginCmd `kong:"cmd,help='Log in to Discord.'"` + Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"` + Reconnect reconnectCmd `kong:"cmd,help='Reconnect to Discord'"` + Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"` } type helpCmd struct { @@ -87,6 +89,12 @@ func (c *versionCmd) Run(g *globals) error { type loginCmd struct{} func (l *loginCmd) Run(g *globals) error { + if g.user.LoggedIn() { + fmt.Fprintf(g.context.Stdout, "You are already logged in") + + return fmt.Errorf("user already logged in") + } + client, err := remoteauth.New() if err != nil { return err @@ -145,7 +153,7 @@ func (l *logoutCmd) Run(g *globals) error { return fmt.Errorf("user is not logged in") } - err := g.user.DeleteSession() + err := g.user.Logout() if err != nil { fmt.Fprintln(g.context.Stdout, "Failed to log out") @@ -156,3 +164,43 @@ func (l *logoutCmd) Run(g *globals) error { return nil } + +type disconnectCmd struct{} + +func (d *disconnectCmd) Run(g *globals) error { + if !g.user.Connected() { + fmt.Fprintln(g.context.Stdout, "You are not connected") + + return fmt.Errorf("user is not connected") + } + + if err := g.user.Disconnect(); err != nil { + fmt.Fprintln(g.context.Stdout, "Failed to disconnect") + + return err + } + + fmt.Fprintln(g.context.Stdout, "Successfully disconnected") + + return nil +} + +type reconnectCmd struct{} + +func (r *reconnectCmd) Run(g *globals) error { + if g.user.Connected() { + fmt.Fprintln(g.context.Stdout, "You are already connected") + + return fmt.Errorf("user is already connected") + } + + if err := g.user.Connect(); err != nil { + fmt.Fprintln(g.context.Stdout, "Failed to connect") + + return err + } + + fmt.Fprintln(g.context.Stdout, "Successfully connected") + + return nil +} diff --git a/bridge/matrix.go b/bridge/matrix.go index bbc5a45..a8ceb02 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -162,7 +162,7 @@ func (mh *matrixHandler) handleBotInvite(evt *event.Event) { mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome) if evt.RoomID == user.ManagementRoom { - if user.HasSession() { + if user.Connected() { mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected) } else { mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected) diff --git a/bridge/provisioning.go b/bridge/provisioning.go index a41ce8b..02aaf31 100644 --- a/bridge/provisioning.go +++ b/bridge/provisioning.go @@ -42,9 +42,11 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI { r.Use(p.authMiddleware) - r.HandleFunc("/ping", p.Ping).Methods(http.MethodGet) - r.HandleFunc("/login", p.Login).Methods(http.MethodGet) - r.HandleFunc("/logout", p.Logout).Methods(http.MethodPost) + r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost) + r.HandleFunc("/ping", p.ping).Methods(http.MethodGet) + r.HandleFunc("/login", p.login).Methods(http.MethodGet) + r.HandleFunc("/logout", p.logout).Methods(http.MethodPost) + r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost) return p } @@ -138,38 +140,78 @@ var upgrader = websocket.Upgrader{ } // Handlers -func (p *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) { +func (p *ProvisioningAPI) disconnect(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + if !user.Connected() { + jsonResponse(w, http.StatusConflict, Error{ + Error: "You're not connected to discord", + ErrCode: "not connected", + }) + + return + } + + if err := user.Disconnect(); err != nil { + jsonResponse(w, http.StatusInternalServerError, Error{ + Error: "Failed to disconnect from discord", + ErrCode: "failed to disconnect", + }) + } else { + jsonResponse(w, http.StatusOK, Response{ + Success: true, + Status: "Disconnected from Discord", + }) + } +} + +func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) { user := r.Context().Value("user").(*User) discord := map[string]interface{}{ - "has_session": user.Session != nil, - "management_room": user.ManagementRoom, - "conn": nil, + "logged_in": user.LoggedIn(), + "connected": user.Connected(), + "conn": nil, } + user.Lock() if user.ID != "" { discord["id"] = user.ID } if user.Session != nil { + user.Session.Lock() discord["conn"] = map[string]interface{}{ "last_heartbeat_ack": user.Session.LastHeartbeatAck, "last_heartbeat_sent": user.Session.LastHeartbeatSent, } + user.Session.Unlock() } resp := map[string]interface{}{ - "mxid": user.MXID, - "discord": discord, + "discord": discord, + "management_room": user.ManagementRoom, + "mxid": user.MXID, } + user.Unlock() + jsonResponse(w, http.StatusOK, resp) } -func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { +func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) { user := r.Context().Value("user").(*User) force := strings.ToLower(r.URL.Query().Get("force")) != "false" + if !user.LoggedIn() { + jsonResponse(w, http.StatusNotFound, Error{ + Error: "You're not logged in", + ErrCode: "not logged in", + }) + + return + } + if user.Session == nil { if force { jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) @@ -183,7 +225,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { return } - err := user.DeleteSession() + err := user.Logout() if err != nil { user.log.Warnln("Error while logging out:", err) @@ -200,7 +242,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) } -func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { +func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) { userID := r.URL.Query().Get("user_id") user := p.bridge.GetUserByMXID(id.UserID(userID)) @@ -220,7 +262,7 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { go func() { // Read everything so SetCloseHandler() works for { - _, _, err = c.ReadMessage() + _, _, err := c.ReadMessage() if err != nil { break } @@ -236,6 +278,15 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { return nil }) + if user.LoggedIn() { + c.WriteJSON(Error{ + Error: "You're already logged into Discord", + ErrCode: "already logged in", + }) + + return + } + client, err := remoteauth.New() if err != nil { user.log.Errorf("Failed to log in from provisioning API:", err) @@ -280,6 +331,9 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { return } + user.ID = discordUser.UserID + user.Update() + if err := user.Login(discordUser.Token); err != nil { c.WriteJSON(Error{ Error: "Failed to connect to Discord", @@ -291,9 +345,6 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { return } - user.ID = discordUser.UserID - user.Update() - c.WriteJSON(map[string]interface{}{ "success": true, "id": user.ID, @@ -305,3 +356,28 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { } } } + +func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value("user").(*User) + + if user.Connected() { + jsonResponse(w, http.StatusConflict, Error{ + Error: "You're already connected to discord", + ErrCode: "already connected", + }) + + return + } + + if err := user.Connect(); err != nil { + jsonResponse(w, http.StatusInternalServerError, Error{ + Error: "Failed to connect to discord", + ErrCode: "failed to connect", + }) + } else { + jsonResponse(w, http.StatusOK, Response{ + Success: true, + Status: "Connected to Discord", + }) + } +} diff --git a/bridge/user.go b/bridge/user.go index cc0561d..5ef93e1 100644 --- a/bridge/user.go +++ b/bridge/user.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strings" + "sync" "github.com/bwmarrin/discordgo" "github.com/skip2/go-qrcode" @@ -17,11 +18,20 @@ import ( "gitlab.com/beeper/discord/database" ) +var ( + ErrNotConnected = errors.New("not connected") + ErrNotLoggedIn = errors.New("not logged in") +) + type User struct { *database.User + sync.Mutex + bridge *Bridge log log.Logger + + Session *discordgo.Session } func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { @@ -140,10 +150,6 @@ func (u *User) SetManagementRoom(roomID id.RoomID) { u.Update() } -func (u *User) HasSession() bool { - return u.User.Session != nil -} - func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) { url, err := u.uploadQRCode(code) if err != nil { @@ -189,23 +195,65 @@ func (u *User) Login(token string) error { return fmt.Errorf("No token specified") } - err := u.User.NewSession(token) - if err != nil { - return err - } + u.Token = token + u.Update() return u.Connect() } func (u *User) LoggedIn() bool { + u.Lock() + defer u.Unlock() + + return u.Token != "" +} + +func (u *User) Logout() error { + u.Lock() + defer u.Unlock() + + if u.Session == nil { + return ErrNotLoggedIn + } + + if err := u.Session.Close(); err != nil { + return err + } + + u.Session = nil + + u.Token = "" + u.Update() + + return nil +} + +func (u *User) Connected() bool { + u.Lock() + defer u.Unlock() + return u.Session != nil } func (u *User) Connect() error { + u.Lock() + defer u.Unlock() + + if u.Token == "" { + return ErrNotLoggedIn + } + u.log.Debugln("connecting to discord") + session, err := discordgo.New(u.Token) + if err != nil { + return err + } + + u.Session = session + // get our user info - user, err := u.User.Session.User("@me") + user, err := u.Session.User("@me") if err != nil { return err } @@ -213,37 +261,40 @@ func (u *User) Connect() error { u.User.ID = user.ID // Add our event handlers - u.User.Session.AddHandler(u.connectedHandler) - u.User.Session.AddHandler(u.disconnectedHandler) + u.Session.AddHandler(u.connectedHandler) + u.Session.AddHandler(u.disconnectedHandler) - u.User.Session.AddHandler(u.channelCreateHandler) - u.User.Session.AddHandler(u.channelDeleteHandler) - u.User.Session.AddHandler(u.channelPinsUpdateHandler) - u.User.Session.AddHandler(u.channelUpdateHandler) + u.Session.AddHandler(u.channelCreateHandler) + u.Session.AddHandler(u.channelDeleteHandler) + u.Session.AddHandler(u.channelPinsUpdateHandler) + u.Session.AddHandler(u.channelUpdateHandler) - u.User.Session.AddHandler(u.messageCreateHandler) - u.User.Session.AddHandler(u.messageDeleteHandler) - u.User.Session.AddHandler(u.messageUpdateHandler) - u.User.Session.AddHandler(u.reactionAddHandler) - u.User.Session.AddHandler(u.reactionRemoveHandler) + u.Session.AddHandler(u.messageCreateHandler) + u.Session.AddHandler(u.messageDeleteHandler) + u.Session.AddHandler(u.messageUpdateHandler) + u.Session.AddHandler(u.reactionAddHandler) + u.Session.AddHandler(u.reactionRemoveHandler) - // u.User.Session.Identify.Capabilities = 125 - // // Setup our properties - // u.User.Session.Identify.Properties = discordgo.IdentifyProperties{ - // OS: "Windows", - // OSVersion: "10", - // Browser: "Chrome", - // BrowserUserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36", - // BrowserVersion: "92.0.4515.159", - // Referrer: "https://discord.com/channels/@me", - // ReferringDomain: "discord.com", - // ClientBuildNumber: "83364", - // ReleaseChannel: "stable", - // } + u.Session.Identify.Presence.Status = "online" - u.User.Session.Identify.Presence.Status = "online" + return u.Session.Open() +} - return u.User.Session.Open() +func (u *User) Disconnect() error { + u.Lock() + defer u.Unlock() + + if u.Session == nil { + return ErrNotConnected + } + + if err := u.Session.Close(); err != nil { + return err + } + + u.Session = nil + + return nil } func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) { diff --git a/database/user.go b/database/user.go index fc2ee03..4fc093a 100644 --- a/database/user.go +++ b/database/user.go @@ -3,8 +3,6 @@ package database import ( "database/sql" - "github.com/bwmarrin/discordgo" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) @@ -18,38 +16,7 @@ type User struct { ManagementRoom id.RoomID - Session *discordgo.Session -} - -// NewSession is just used to create the session and update the database. It -// should only be called by bridge.User.Connect which will continue setting up -// event handlers and everything else. -func (u *User) NewSession(token string) error { - session, err := discordgo.New(token) - if err != nil { - return err - } - - u.Session = session - - u.Update() - - return nil -} - -// DeleteSession tries to logout and delete the session from the database. -func (u *User) DeleteSession() error { - err := u.Session.Close() - - if err != nil { - u.log.Warnfln("failed to close the session for %s: %v", u.ID, err) - } - - u.Session = nil - - u.Update() - - return nil + Token string } func (u *User) Scan(row Scannable) *User { @@ -65,31 +32,25 @@ func (u *User) Scan(row Scannable) *User { } if token.Valid { - if err := u.NewSession(token.String); err != nil { - u.log.Errorln("Failed to login: ", err) - } + u.Token = token.String } return u } -func (u *User) sessionNonptr() discordgo.Session { - if u.Session != nil { - return *u.Session - } - - return discordgo.Session{} -} - func (u *User) Insert() { - session := u.sessionNonptr() - query := "INSERT INTO \"user\"" + " (mxid, id, management_room, token)" + " VALUES ($1, $2, $3, $4);" - _, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, - session.Identify.Token) + var token sql.NullString + + if u.Token != "" { + token.String = u.Token + token.Valid = true + } + + _, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, token) if err != nil { u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) @@ -97,13 +58,18 @@ func (u *User) Insert() { } func (u *User) Update() { - session := u.sessionNonptr() - query := "UPDATE \"user\" SET" + " id=$1, management_room=$2, token=$3" + " WHERE mxid=$4;" - _, err := u.db.Exec(query, u.ID, u.ManagementRoom, session.Identify.Token, u.MXID) + var token sql.NullString + + if u.Token != "" { + token.String = u.Token + token.Valid = true + } + + _, err := u.db.Exec(query, u.ID, u.ManagementRoom, token, u.MXID) if err != nil { u.log.Warnfln("Failed to update %q: %v", u.MXID, err) diff --git a/remoteauth/client.go b/remoteauth/client.go index 4b96d8d..fdea7df 100644 --- a/remoteauth/client.go +++ b/remoteauth/client.go @@ -8,11 +8,14 @@ import ( "encoding/base64" "encoding/json" "net/http" + "sync" "github.com/gorilla/websocket" ) type Client struct { + sync.Mutex + URL string Origin string @@ -48,6 +51,9 @@ func New() (*Client, error) { // Dial will start the QRCode login process. ctx may be used to abandon the // process. func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error { + c.Lock() + defer c.Unlock() + header := http.Header{ "Origin": []string{c.Origin}, } @@ -68,10 +74,16 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str } func (c *Client) Result() (User, error) { + c.Lock() + defer c.Unlock() + return c.user, c.err } func (c *Client) close() error { + c.Lock() + defer c.Unlock() + if c.closed { return nil } @@ -89,6 +101,9 @@ func (c *Client) close() error { } func (c *Client) write(p clientPacket) error { + c.Lock() + defer c.Unlock() + payload, err := json.Marshal(p) if err != nil { return err diff --git a/remoteauth/serverpackets.go b/remoteauth/serverpackets.go index 74666a0..7996309 100644 --- a/remoteauth/serverpackets.go +++ b/remoteauth/serverpackets.go @@ -22,10 +22,15 @@ func (c *Client) processMessages() { defer c.close() for { + c.Lock() _, packet, err := c.conn.ReadMessage() + c.Unlock() + if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { + c.Lock() c.err = err + c.Unlock() } return @@ -33,7 +38,9 @@ func (c *Client) processMessages() { raw := rawPacket{} if err := json.Unmarshal(packet, &raw); err != nil { + c.Lock() c.err = err + c.Unlock() return } @@ -57,7 +64,9 @@ func (c *Client) processMessages() { } if err := json.Unmarshal(packet, dest); err != nil { + c.Lock() c.err = err + c.Unlock() return } @@ -65,7 +74,9 @@ func (c *Client) processMessages() { op := dest.(serverPacket) err = op.process(c) if err != nil { + c.Lock() c.err = err + c.Unlock() return } @@ -92,7 +103,10 @@ func (h *serverHello) process(client *Client) error { case <-ticker.C: h := clientHeartbeat{} if err := h.send(client); err != nil { + client.Lock() client.err = err + client.Unlock() + return } } @@ -104,8 +118,10 @@ func (h *serverHello) process(client *Client) error { <-time.After(duration) + client.Lock() client.err = fmt.Errorf("Timed out after %s", duration) client.close() + client.Unlock() }() i := clientInit{}