Add (dis|re)connect commands and provision api

Also fixed a number of data races.
This commit is contained in:
Gary Kramlich
2022-02-22 07:56:15 -06:00
parent d883befee5
commit 4b87ea1cc7
8 changed files with 287 additions and 110 deletions

View File

@@ -101,7 +101,12 @@ func (h *commandHandler) handle(roomID id.RoomID, user *User, message string, re
if err != nil { if err != nil {
h.log.Warnf("Command %q failed: %v", message, err) h.log.Warnf("Command %q failed: %v", message, err)
cmd.globals.reply("unexpected failure") output := buf.String()
if output != "" {
cmd.globals.reply(output)
} else {
cmd.globals.reply("unexpected failure")
}
return return
} }

View File

@@ -46,10 +46,12 @@ func (g *globals) reply(msg string) {
type commands struct { type commands struct {
globals globals
Help helpCmd `kong:"cmd,help='Displays this message.'"` Disconnect disconnectCmd `kong:"cmd,help='Disconnect from Discord'"`
Login loginCmd `kong:"cmd,help='Log in to Discord.'"` Help helpCmd `kong:"cmd,help='Displays this message.'"`
Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"` Login loginCmd `kong:"cmd,help='Log in to Discord.'"`
Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"` Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"`
Reconnect reconnectCmd `kong:"cmd,help='Reconnect to Discord'"`
Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"`
} }
type helpCmd struct { type helpCmd struct {
@@ -87,6 +89,12 @@ func (c *versionCmd) Run(g *globals) error {
type loginCmd struct{} type loginCmd struct{}
func (l *loginCmd) Run(g *globals) error { func (l *loginCmd) Run(g *globals) error {
if g.user.LoggedIn() {
fmt.Fprintf(g.context.Stdout, "You are already logged in")
return fmt.Errorf("user already logged in")
}
client, err := remoteauth.New() client, err := remoteauth.New()
if err != nil { if err != nil {
return err return err
@@ -145,7 +153,7 @@ func (l *logoutCmd) Run(g *globals) error {
return fmt.Errorf("user is not logged in") return fmt.Errorf("user is not logged in")
} }
err := g.user.DeleteSession() err := g.user.Logout()
if err != nil { if err != nil {
fmt.Fprintln(g.context.Stdout, "Failed to log out") fmt.Fprintln(g.context.Stdout, "Failed to log out")
@@ -156,3 +164,43 @@ func (l *logoutCmd) Run(g *globals) error {
return nil return nil
} }
type disconnectCmd struct{}
func (d *disconnectCmd) Run(g *globals) error {
if !g.user.Connected() {
fmt.Fprintln(g.context.Stdout, "You are not connected")
return fmt.Errorf("user is not connected")
}
if err := g.user.Disconnect(); err != nil {
fmt.Fprintln(g.context.Stdout, "Failed to disconnect")
return err
}
fmt.Fprintln(g.context.Stdout, "Successfully disconnected")
return nil
}
type reconnectCmd struct{}
func (r *reconnectCmd) Run(g *globals) error {
if g.user.Connected() {
fmt.Fprintln(g.context.Stdout, "You are already connected")
return fmt.Errorf("user is already connected")
}
if err := g.user.Connect(); err != nil {
fmt.Fprintln(g.context.Stdout, "Failed to connect")
return err
}
fmt.Fprintln(g.context.Stdout, "Successfully connected")
return nil
}

View File

@@ -162,7 +162,7 @@ func (mh *matrixHandler) handleBotInvite(evt *event.Event) {
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome)
if evt.RoomID == user.ManagementRoom { if evt.RoomID == user.ManagementRoom {
if user.HasSession() { if user.Connected() {
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected)
} else { } else {
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected)

View File

@@ -42,9 +42,11 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI {
r.Use(p.authMiddleware) r.Use(p.authMiddleware)
r.HandleFunc("/ping", p.Ping).Methods(http.MethodGet) r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost)
r.HandleFunc("/login", p.Login).Methods(http.MethodGet) r.HandleFunc("/ping", p.ping).Methods(http.MethodGet)
r.HandleFunc("/logout", p.Logout).Methods(http.MethodPost) r.HandleFunc("/login", p.login).Methods(http.MethodGet)
r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
return p return p
} }
@@ -138,38 +140,78 @@ var upgrader = websocket.Upgrader{
} }
// Handlers // Handlers
func (p *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) disconnect(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
if !user.Connected() {
jsonResponse(w, http.StatusConflict, Error{
Error: "You're not connected to discord",
ErrCode: "not connected",
})
return
}
if err := user.Disconnect(); err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Failed to disconnect from discord",
ErrCode: "failed to disconnect",
})
} else {
jsonResponse(w, http.StatusOK, Response{
Success: true,
Status: "Disconnected from Discord",
})
}
}
func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User) user := r.Context().Value("user").(*User)
discord := map[string]interface{}{ discord := map[string]interface{}{
"has_session": user.Session != nil, "logged_in": user.LoggedIn(),
"management_room": user.ManagementRoom, "connected": user.Connected(),
"conn": nil, "conn": nil,
} }
user.Lock()
if user.ID != "" { if user.ID != "" {
discord["id"] = user.ID discord["id"] = user.ID
} }
if user.Session != nil { if user.Session != nil {
user.Session.Lock()
discord["conn"] = map[string]interface{}{ discord["conn"] = map[string]interface{}{
"last_heartbeat_ack": user.Session.LastHeartbeatAck, "last_heartbeat_ack": user.Session.LastHeartbeatAck,
"last_heartbeat_sent": user.Session.LastHeartbeatSent, "last_heartbeat_sent": user.Session.LastHeartbeatSent,
} }
user.Session.Unlock()
} }
resp := map[string]interface{}{ resp := map[string]interface{}{
"mxid": user.MXID, "discord": discord,
"discord": discord, "management_room": user.ManagementRoom,
"mxid": user.MXID,
} }
user.Unlock()
jsonResponse(w, http.StatusOK, resp) jsonResponse(w, http.StatusOK, resp)
} }
func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User) user := r.Context().Value("user").(*User)
force := strings.ToLower(r.URL.Query().Get("force")) != "false" force := strings.ToLower(r.URL.Query().Get("force")) != "false"
if !user.LoggedIn() {
jsonResponse(w, http.StatusNotFound, Error{
Error: "You're not logged in",
ErrCode: "not logged in",
})
return
}
if user.Session == nil { if user.Session == nil {
if force { if force {
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
@@ -183,7 +225,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
return return
} }
err := user.DeleteSession() err := user.Logout()
if err != nil { if err != nil {
user.log.Warnln("Error while logging out:", err) user.log.Warnln("Error while logging out:", err)
@@ -200,7 +242,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
} }
func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := p.bridge.GetUserByMXID(id.UserID(userID)) user := p.bridge.GetUserByMXID(id.UserID(userID))
@@ -220,7 +262,7 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
go func() { go func() {
// Read everything so SetCloseHandler() works // Read everything so SetCloseHandler() works
for { for {
_, _, err = c.ReadMessage() _, _, err := c.ReadMessage()
if err != nil { if err != nil {
break break
} }
@@ -236,6 +278,15 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
return nil return nil
}) })
if user.LoggedIn() {
c.WriteJSON(Error{
Error: "You're already logged into Discord",
ErrCode: "already logged in",
})
return
}
client, err := remoteauth.New() client, err := remoteauth.New()
if err != nil { if err != nil {
user.log.Errorf("Failed to log in from provisioning API:", err) user.log.Errorf("Failed to log in from provisioning API:", err)
@@ -280,6 +331,9 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
return return
} }
user.ID = discordUser.UserID
user.Update()
if err := user.Login(discordUser.Token); err != nil { if err := user.Login(discordUser.Token); err != nil {
c.WriteJSON(Error{ c.WriteJSON(Error{
Error: "Failed to connect to Discord", Error: "Failed to connect to Discord",
@@ -291,9 +345,6 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
return return
} }
user.ID = discordUser.UserID
user.Update()
c.WriteJSON(map[string]interface{}{ c.WriteJSON(map[string]interface{}{
"success": true, "success": true,
"id": user.ID, "id": user.ID,
@@ -305,3 +356,28 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
} }
} }
} }
func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
if user.Connected() {
jsonResponse(w, http.StatusConflict, Error{
Error: "You're already connected to discord",
ErrCode: "already connected",
})
return
}
if err := user.Connect(); err != nil {
jsonResponse(w, http.StatusInternalServerError, Error{
Error: "Failed to connect to discord",
ErrCode: "failed to connect",
})
} else {
jsonResponse(w, http.StatusOK, Response{
Success: true,
Status: "Connected to Discord",
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
@@ -17,11 +18,20 @@ import (
"gitlab.com/beeper/discord/database" "gitlab.com/beeper/discord/database"
) )
var (
ErrNotConnected = errors.New("not connected")
ErrNotLoggedIn = errors.New("not logged in")
)
type User struct { type User struct {
*database.User *database.User
sync.Mutex
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
Session *discordgo.Session
} }
func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
@@ -140,10 +150,6 @@ func (u *User) SetManagementRoom(roomID id.RoomID) {
u.Update() u.Update()
} }
func (u *User) HasSession() bool {
return u.User.Session != nil
}
func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) { func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) {
url, err := u.uploadQRCode(code) url, err := u.uploadQRCode(code)
if err != nil { if err != nil {
@@ -189,23 +195,65 @@ func (u *User) Login(token string) error {
return fmt.Errorf("No token specified") return fmt.Errorf("No token specified")
} }
err := u.User.NewSession(token) u.Token = token
if err != nil { u.Update()
return err
}
return u.Connect() return u.Connect()
} }
func (u *User) LoggedIn() bool { func (u *User) LoggedIn() bool {
u.Lock()
defer u.Unlock()
return u.Token != ""
}
func (u *User) Logout() error {
u.Lock()
defer u.Unlock()
if u.Session == nil {
return ErrNotLoggedIn
}
if err := u.Session.Close(); err != nil {
return err
}
u.Session = nil
u.Token = ""
u.Update()
return nil
}
func (u *User) Connected() bool {
u.Lock()
defer u.Unlock()
return u.Session != nil return u.Session != nil
} }
func (u *User) Connect() error { func (u *User) Connect() error {
u.Lock()
defer u.Unlock()
if u.Token == "" {
return ErrNotLoggedIn
}
u.log.Debugln("connecting to discord") u.log.Debugln("connecting to discord")
session, err := discordgo.New(u.Token)
if err != nil {
return err
}
u.Session = session
// get our user info // get our user info
user, err := u.User.Session.User("@me") user, err := u.Session.User("@me")
if err != nil { if err != nil {
return err return err
} }
@@ -213,37 +261,40 @@ func (u *User) Connect() error {
u.User.ID = user.ID u.User.ID = user.ID
// Add our event handlers // Add our event handlers
u.User.Session.AddHandler(u.connectedHandler) u.Session.AddHandler(u.connectedHandler)
u.User.Session.AddHandler(u.disconnectedHandler) u.Session.AddHandler(u.disconnectedHandler)
u.User.Session.AddHandler(u.channelCreateHandler) u.Session.AddHandler(u.channelCreateHandler)
u.User.Session.AddHandler(u.channelDeleteHandler) u.Session.AddHandler(u.channelDeleteHandler)
u.User.Session.AddHandler(u.channelPinsUpdateHandler) u.Session.AddHandler(u.channelPinsUpdateHandler)
u.User.Session.AddHandler(u.channelUpdateHandler) u.Session.AddHandler(u.channelUpdateHandler)
u.User.Session.AddHandler(u.messageCreateHandler) u.Session.AddHandler(u.messageCreateHandler)
u.User.Session.AddHandler(u.messageDeleteHandler) u.Session.AddHandler(u.messageDeleteHandler)
u.User.Session.AddHandler(u.messageUpdateHandler) u.Session.AddHandler(u.messageUpdateHandler)
u.User.Session.AddHandler(u.reactionAddHandler) u.Session.AddHandler(u.reactionAddHandler)
u.User.Session.AddHandler(u.reactionRemoveHandler) u.Session.AddHandler(u.reactionRemoveHandler)
// u.User.Session.Identify.Capabilities = 125 u.Session.Identify.Presence.Status = "online"
// // Setup our properties
// u.User.Session.Identify.Properties = discordgo.IdentifyProperties{
// OS: "Windows",
// OSVersion: "10",
// Browser: "Chrome",
// BrowserUserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36",
// BrowserVersion: "92.0.4515.159",
// Referrer: "https://discord.com/channels/@me",
// ReferringDomain: "discord.com",
// ClientBuildNumber: "83364",
// ReleaseChannel: "stable",
// }
u.User.Session.Identify.Presence.Status = "online" return u.Session.Open()
}
return u.User.Session.Open() func (u *User) Disconnect() error {
u.Lock()
defer u.Unlock()
if u.Session == nil {
return ErrNotConnected
}
if err := u.Session.Close(); err != nil {
return err
}
u.Session = nil
return nil
} }
func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) { func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {

View File

@@ -3,8 +3,6 @@ package database
import ( import (
"database/sql" "database/sql"
"github.com/bwmarrin/discordgo"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -18,38 +16,7 @@ type User struct {
ManagementRoom id.RoomID ManagementRoom id.RoomID
Session *discordgo.Session Token string
}
// NewSession is just used to create the session and update the database. It
// should only be called by bridge.User.Connect which will continue setting up
// event handlers and everything else.
func (u *User) NewSession(token string) error {
session, err := discordgo.New(token)
if err != nil {
return err
}
u.Session = session
u.Update()
return nil
}
// DeleteSession tries to logout and delete the session from the database.
func (u *User) DeleteSession() error {
err := u.Session.Close()
if err != nil {
u.log.Warnfln("failed to close the session for %s: %v", u.ID, err)
}
u.Session = nil
u.Update()
return nil
} }
func (u *User) Scan(row Scannable) *User { func (u *User) Scan(row Scannable) *User {
@@ -65,31 +32,25 @@ func (u *User) Scan(row Scannable) *User {
} }
if token.Valid { if token.Valid {
if err := u.NewSession(token.String); err != nil { u.Token = token.String
u.log.Errorln("Failed to login: ", err)
}
} }
return u return u
} }
func (u *User) sessionNonptr() discordgo.Session {
if u.Session != nil {
return *u.Session
}
return discordgo.Session{}
}
func (u *User) Insert() { func (u *User) Insert() {
session := u.sessionNonptr()
query := "INSERT INTO \"user\"" + query := "INSERT INTO \"user\"" +
" (mxid, id, management_room, token)" + " (mxid, id, management_room, token)" +
" VALUES ($1, $2, $3, $4);" " VALUES ($1, $2, $3, $4);"
_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, var token sql.NullString
session.Identify.Token)
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, token)
if err != nil { if err != nil {
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err) u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
@@ -97,13 +58,18 @@ func (u *User) Insert() {
} }
func (u *User) Update() { func (u *User) Update() {
session := u.sessionNonptr()
query := "UPDATE \"user\" SET" + query := "UPDATE \"user\" SET" +
" id=$1, management_room=$2, token=$3" + " id=$1, management_room=$2, token=$3" +
" WHERE mxid=$4;" " WHERE mxid=$4;"
_, err := u.db.Exec(query, u.ID, u.ManagementRoom, session.Identify.Token, u.MXID) var token sql.NullString
if u.Token != "" {
token.String = u.Token
token.Valid = true
}
_, err := u.db.Exec(query, u.ID, u.ManagementRoom, token, u.MXID)
if err != nil { if err != nil {
u.log.Warnfln("Failed to update %q: %v", u.MXID, err) u.log.Warnfln("Failed to update %q: %v", u.MXID, err)

View File

@@ -8,11 +8,14 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"net/http" "net/http"
"sync"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
type Client struct { type Client struct {
sync.Mutex
URL string URL string
Origin string Origin string
@@ -48,6 +51,9 @@ func New() (*Client, error) {
// Dial will start the QRCode login process. ctx may be used to abandon the // Dial will start the QRCode login process. ctx may be used to abandon the
// process. // process.
func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error { func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error {
c.Lock()
defer c.Unlock()
header := http.Header{ header := http.Header{
"Origin": []string{c.Origin}, "Origin": []string{c.Origin},
} }
@@ -68,10 +74,16 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str
} }
func (c *Client) Result() (User, error) { func (c *Client) Result() (User, error) {
c.Lock()
defer c.Unlock()
return c.user, c.err return c.user, c.err
} }
func (c *Client) close() error { func (c *Client) close() error {
c.Lock()
defer c.Unlock()
if c.closed { if c.closed {
return nil return nil
} }
@@ -89,6 +101,9 @@ func (c *Client) close() error {
} }
func (c *Client) write(p clientPacket) error { func (c *Client) write(p clientPacket) error {
c.Lock()
defer c.Unlock()
payload, err := json.Marshal(p) payload, err := json.Marshal(p)
if err != nil { if err != nil {
return err return err

View File

@@ -22,10 +22,15 @@ func (c *Client) processMessages() {
defer c.close() defer c.close()
for { for {
c.Lock()
_, packet, err := c.conn.ReadMessage() _, packet, err := c.conn.ReadMessage()
c.Unlock()
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
c.Lock()
c.err = err c.err = err
c.Unlock()
} }
return return
@@ -33,7 +38,9 @@ func (c *Client) processMessages() {
raw := rawPacket{} raw := rawPacket{}
if err := json.Unmarshal(packet, &raw); err != nil { if err := json.Unmarshal(packet, &raw); err != nil {
c.Lock()
c.err = err c.err = err
c.Unlock()
return return
} }
@@ -57,7 +64,9 @@ func (c *Client) processMessages() {
} }
if err := json.Unmarshal(packet, dest); err != nil { if err := json.Unmarshal(packet, dest); err != nil {
c.Lock()
c.err = err c.err = err
c.Unlock()
return return
} }
@@ -65,7 +74,9 @@ func (c *Client) processMessages() {
op := dest.(serverPacket) op := dest.(serverPacket)
err = op.process(c) err = op.process(c)
if err != nil { if err != nil {
c.Lock()
c.err = err c.err = err
c.Unlock()
return return
} }
@@ -92,7 +103,10 @@ func (h *serverHello) process(client *Client) error {
case <-ticker.C: case <-ticker.C:
h := clientHeartbeat{} h := clientHeartbeat{}
if err := h.send(client); err != nil { if err := h.send(client); err != nil {
client.Lock()
client.err = err client.err = err
client.Unlock()
return return
} }
} }
@@ -104,8 +118,10 @@ func (h *serverHello) process(client *Client) error {
<-time.After(duration) <-time.After(duration)
client.Lock()
client.err = fmt.Errorf("Timed out after %s", duration) client.err = fmt.Errorf("Timed out after %s", duration)
client.close() client.close()
client.Unlock()
}() }()
i := clientInit{} i := clientInit{}