Add (dis|re)connect commands and provision api
Also fixed a number of data races.
This commit is contained in:
@@ -101,7 +101,12 @@ func (h *commandHandler) handle(roomID id.RoomID, user *User, message string, re
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
h.log.Warnf("Command %q failed: %v", message, err)
|
h.log.Warnf("Command %q failed: %v", message, err)
|
||||||
|
|
||||||
cmd.globals.reply("unexpected failure")
|
output := buf.String()
|
||||||
|
if output != "" {
|
||||||
|
cmd.globals.reply(output)
|
||||||
|
} else {
|
||||||
|
cmd.globals.reply("unexpected failure")
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,10 +46,12 @@ func (g *globals) reply(msg string) {
|
|||||||
type commands struct {
|
type commands struct {
|
||||||
globals
|
globals
|
||||||
|
|
||||||
Help helpCmd `kong:"cmd,help='Displays this message.'"`
|
Disconnect disconnectCmd `kong:"cmd,help='Disconnect from Discord'"`
|
||||||
Login loginCmd `kong:"cmd,help='Log in to Discord.'"`
|
Help helpCmd `kong:"cmd,help='Displays this message.'"`
|
||||||
Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"`
|
Login loginCmd `kong:"cmd,help='Log in to Discord.'"`
|
||||||
Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"`
|
Logout logoutCmd `kong:"cmd,help='Log out of Discord.'"`
|
||||||
|
Reconnect reconnectCmd `kong:"cmd,help='Reconnect to Discord'"`
|
||||||
|
Version versionCmd `kong:"cmd,help='Displays the version of the bridge.'"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type helpCmd struct {
|
type helpCmd struct {
|
||||||
@@ -87,6 +89,12 @@ func (c *versionCmd) Run(g *globals) error {
|
|||||||
type loginCmd struct{}
|
type loginCmd struct{}
|
||||||
|
|
||||||
func (l *loginCmd) Run(g *globals) error {
|
func (l *loginCmd) Run(g *globals) error {
|
||||||
|
if g.user.LoggedIn() {
|
||||||
|
fmt.Fprintf(g.context.Stdout, "You are already logged in")
|
||||||
|
|
||||||
|
return fmt.Errorf("user already logged in")
|
||||||
|
}
|
||||||
|
|
||||||
client, err := remoteauth.New()
|
client, err := remoteauth.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -145,7 +153,7 @@ func (l *logoutCmd) Run(g *globals) error {
|
|||||||
return fmt.Errorf("user is not logged in")
|
return fmt.Errorf("user is not logged in")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := g.user.DeleteSession()
|
err := g.user.Logout()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(g.context.Stdout, "Failed to log out")
|
fmt.Fprintln(g.context.Stdout, "Failed to log out")
|
||||||
|
|
||||||
@@ -156,3 +164,43 @@ func (l *logoutCmd) Run(g *globals) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type disconnectCmd struct{}
|
||||||
|
|
||||||
|
func (d *disconnectCmd) Run(g *globals) error {
|
||||||
|
if !g.user.Connected() {
|
||||||
|
fmt.Fprintln(g.context.Stdout, "You are not connected")
|
||||||
|
|
||||||
|
return fmt.Errorf("user is not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.user.Disconnect(); err != nil {
|
||||||
|
fmt.Fprintln(g.context.Stdout, "Failed to disconnect")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(g.context.Stdout, "Successfully disconnected")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type reconnectCmd struct{}
|
||||||
|
|
||||||
|
func (r *reconnectCmd) Run(g *globals) error {
|
||||||
|
if g.user.Connected() {
|
||||||
|
fmt.Fprintln(g.context.Stdout, "You are already connected")
|
||||||
|
|
||||||
|
return fmt.Errorf("user is already connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.user.Connect(); err != nil {
|
||||||
|
fmt.Fprintln(g.context.Stdout, "Failed to connect")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintln(g.context.Stdout, "Successfully connected")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func (mh *matrixHandler) handleBotInvite(evt *event.Event) {
|
|||||||
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome)
|
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome)
|
||||||
|
|
||||||
if evt.RoomID == user.ManagementRoom {
|
if evt.RoomID == user.ManagementRoom {
|
||||||
if user.HasSession() {
|
if user.Connected() {
|
||||||
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected)
|
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected)
|
||||||
} else {
|
} else {
|
||||||
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected)
|
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected)
|
||||||
|
|||||||
@@ -42,9 +42,11 @@ func newProvisioningAPI(bridge *Bridge) *ProvisioningAPI {
|
|||||||
|
|
||||||
r.Use(p.authMiddleware)
|
r.Use(p.authMiddleware)
|
||||||
|
|
||||||
r.HandleFunc("/ping", p.Ping).Methods(http.MethodGet)
|
r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost)
|
||||||
r.HandleFunc("/login", p.Login).Methods(http.MethodGet)
|
r.HandleFunc("/ping", p.ping).Methods(http.MethodGet)
|
||||||
r.HandleFunc("/logout", p.Logout).Methods(http.MethodPost)
|
r.HandleFunc("/login", p.login).Methods(http.MethodGet)
|
||||||
|
r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
|
||||||
|
r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
|
||||||
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -138,38 +140,78 @@ var upgrader = websocket.Upgrader{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handlers
|
// Handlers
|
||||||
func (p *ProvisioningAPI) Ping(w http.ResponseWriter, r *http.Request) {
|
func (p *ProvisioningAPI) disconnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := r.Context().Value("user").(*User)
|
||||||
|
|
||||||
|
if !user.Connected() {
|
||||||
|
jsonResponse(w, http.StatusConflict, Error{
|
||||||
|
Error: "You're not connected to discord",
|
||||||
|
ErrCode: "not connected",
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Disconnect(); err != nil {
|
||||||
|
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||||
|
Error: "Failed to disconnect from discord",
|
||||||
|
ErrCode: "failed to disconnect",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
jsonResponse(w, http.StatusOK, Response{
|
||||||
|
Success: true,
|
||||||
|
Status: "Disconnected from Discord",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ProvisioningAPI) ping(w http.ResponseWriter, r *http.Request) {
|
||||||
user := r.Context().Value("user").(*User)
|
user := r.Context().Value("user").(*User)
|
||||||
|
|
||||||
discord := map[string]interface{}{
|
discord := map[string]interface{}{
|
||||||
"has_session": user.Session != nil,
|
"logged_in": user.LoggedIn(),
|
||||||
"management_room": user.ManagementRoom,
|
"connected": user.Connected(),
|
||||||
"conn": nil,
|
"conn": nil,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.Lock()
|
||||||
if user.ID != "" {
|
if user.ID != "" {
|
||||||
discord["id"] = user.ID
|
discord["id"] = user.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
if user.Session != nil {
|
if user.Session != nil {
|
||||||
|
user.Session.Lock()
|
||||||
discord["conn"] = map[string]interface{}{
|
discord["conn"] = map[string]interface{}{
|
||||||
"last_heartbeat_ack": user.Session.LastHeartbeatAck,
|
"last_heartbeat_ack": user.Session.LastHeartbeatAck,
|
||||||
"last_heartbeat_sent": user.Session.LastHeartbeatSent,
|
"last_heartbeat_sent": user.Session.LastHeartbeatSent,
|
||||||
}
|
}
|
||||||
|
user.Session.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := map[string]interface{}{
|
resp := map[string]interface{}{
|
||||||
"mxid": user.MXID,
|
"discord": discord,
|
||||||
"discord": discord,
|
"management_room": user.ManagementRoom,
|
||||||
|
"mxid": user.MXID,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.Unlock()
|
||||||
|
|
||||||
jsonResponse(w, http.StatusOK, resp)
|
jsonResponse(w, http.StatusOK, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
|
func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) {
|
||||||
user := r.Context().Value("user").(*User)
|
user := r.Context().Value("user").(*User)
|
||||||
force := strings.ToLower(r.URL.Query().Get("force")) != "false"
|
force := strings.ToLower(r.URL.Query().Get("force")) != "false"
|
||||||
|
|
||||||
|
if !user.LoggedIn() {
|
||||||
|
jsonResponse(w, http.StatusNotFound, Error{
|
||||||
|
Error: "You're not logged in",
|
||||||
|
ErrCode: "not logged in",
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if user.Session == nil {
|
if user.Session == nil {
|
||||||
if force {
|
if force {
|
||||||
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
|
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
|
||||||
@@ -183,7 +225,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err := user.DeleteSession()
|
err := user.Logout()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user.log.Warnln("Error while logging out:", err)
|
user.log.Warnln("Error while logging out:", err)
|
||||||
|
|
||||||
@@ -200,7 +242,7 @@ func (p *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
|
|||||||
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
|
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
|
||||||
userID := r.URL.Query().Get("user_id")
|
userID := r.URL.Query().Get("user_id")
|
||||||
user := p.bridge.GetUserByMXID(id.UserID(userID))
|
user := p.bridge.GetUserByMXID(id.UserID(userID))
|
||||||
|
|
||||||
@@ -220,7 +262,7 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
go func() {
|
go func() {
|
||||||
// Read everything so SetCloseHandler() works
|
// Read everything so SetCloseHandler() works
|
||||||
for {
|
for {
|
||||||
_, _, err = c.ReadMessage()
|
_, _, err := c.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -236,6 +278,15 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if user.LoggedIn() {
|
||||||
|
c.WriteJSON(Error{
|
||||||
|
Error: "You're already logged into Discord",
|
||||||
|
ErrCode: "already logged in",
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
client, err := remoteauth.New()
|
client, err := remoteauth.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
user.log.Errorf("Failed to log in from provisioning API:", err)
|
user.log.Errorf("Failed to log in from provisioning API:", err)
|
||||||
@@ -280,6 +331,9 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
user.ID = discordUser.UserID
|
||||||
|
user.Update()
|
||||||
|
|
||||||
if err := user.Login(discordUser.Token); err != nil {
|
if err := user.Login(discordUser.Token); err != nil {
|
||||||
c.WriteJSON(Error{
|
c.WriteJSON(Error{
|
||||||
Error: "Failed to connect to Discord",
|
Error: "Failed to connect to Discord",
|
||||||
@@ -291,9 +345,6 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user.ID = discordUser.UserID
|
|
||||||
user.Update()
|
|
||||||
|
|
||||||
c.WriteJSON(map[string]interface{}{
|
c.WriteJSON(map[string]interface{}{
|
||||||
"success": true,
|
"success": true,
|
||||||
"id": user.ID,
|
"id": user.ID,
|
||||||
@@ -305,3 +356,28 @@ func (p *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
user := r.Context().Value("user").(*User)
|
||||||
|
|
||||||
|
if user.Connected() {
|
||||||
|
jsonResponse(w, http.StatusConflict, Error{
|
||||||
|
Error: "You're already connected to discord",
|
||||||
|
ErrCode: "already connected",
|
||||||
|
})
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := user.Connect(); err != nil {
|
||||||
|
jsonResponse(w, http.StatusInternalServerError, Error{
|
||||||
|
Error: "Failed to connect to discord",
|
||||||
|
ErrCode: "failed to connect",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
jsonResponse(w, http.StatusOK, Response{
|
||||||
|
Success: true,
|
||||||
|
Status: "Connected to Discord",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
121
bridge/user.go
121
bridge/user.go
@@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
"github.com/skip2/go-qrcode"
|
"github.com/skip2/go-qrcode"
|
||||||
@@ -17,11 +18,20 @@ import (
|
|||||||
"gitlab.com/beeper/discord/database"
|
"gitlab.com/beeper/discord/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNotConnected = errors.New("not connected")
|
||||||
|
ErrNotLoggedIn = errors.New("not logged in")
|
||||||
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
*database.User
|
*database.User
|
||||||
|
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
bridge *Bridge
|
bridge *Bridge
|
||||||
log log.Logger
|
log log.Logger
|
||||||
|
|
||||||
|
Session *discordgo.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
|
func (b *Bridge) loadUser(dbUser *database.User, mxid *id.UserID) *User {
|
||||||
@@ -140,10 +150,6 @@ func (u *User) SetManagementRoom(roomID id.RoomID) {
|
|||||||
u.Update()
|
u.Update()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) HasSession() bool {
|
|
||||||
return u.User.Session != nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) {
|
func (u *User) sendQRCode(bot *appservice.IntentAPI, roomID id.RoomID, code string) (id.EventID, error) {
|
||||||
url, err := u.uploadQRCode(code)
|
url, err := u.uploadQRCode(code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -189,23 +195,65 @@ func (u *User) Login(token string) error {
|
|||||||
return fmt.Errorf("No token specified")
|
return fmt.Errorf("No token specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := u.User.NewSession(token)
|
u.Token = token
|
||||||
if err != nil {
|
u.Update()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return u.Connect()
|
return u.Connect()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) LoggedIn() bool {
|
func (u *User) LoggedIn() bool {
|
||||||
|
u.Lock()
|
||||||
|
defer u.Unlock()
|
||||||
|
|
||||||
|
return u.Token != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Logout() error {
|
||||||
|
u.Lock()
|
||||||
|
defer u.Unlock()
|
||||||
|
|
||||||
|
if u.Session == nil {
|
||||||
|
return ErrNotLoggedIn
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.Session.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Session = nil
|
||||||
|
|
||||||
|
u.Token = ""
|
||||||
|
u.Update()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) Connected() bool {
|
||||||
|
u.Lock()
|
||||||
|
defer u.Unlock()
|
||||||
|
|
||||||
return u.Session != nil
|
return u.Session != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Connect() error {
|
func (u *User) Connect() error {
|
||||||
|
u.Lock()
|
||||||
|
defer u.Unlock()
|
||||||
|
|
||||||
|
if u.Token == "" {
|
||||||
|
return ErrNotLoggedIn
|
||||||
|
}
|
||||||
|
|
||||||
u.log.Debugln("connecting to discord")
|
u.log.Debugln("connecting to discord")
|
||||||
|
|
||||||
|
session, err := discordgo.New(u.Token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Session = session
|
||||||
|
|
||||||
// get our user info
|
// get our user info
|
||||||
user, err := u.User.Session.User("@me")
|
user, err := u.Session.User("@me")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -213,37 +261,40 @@ func (u *User) Connect() error {
|
|||||||
u.User.ID = user.ID
|
u.User.ID = user.ID
|
||||||
|
|
||||||
// Add our event handlers
|
// Add our event handlers
|
||||||
u.User.Session.AddHandler(u.connectedHandler)
|
u.Session.AddHandler(u.connectedHandler)
|
||||||
u.User.Session.AddHandler(u.disconnectedHandler)
|
u.Session.AddHandler(u.disconnectedHandler)
|
||||||
|
|
||||||
u.User.Session.AddHandler(u.channelCreateHandler)
|
u.Session.AddHandler(u.channelCreateHandler)
|
||||||
u.User.Session.AddHandler(u.channelDeleteHandler)
|
u.Session.AddHandler(u.channelDeleteHandler)
|
||||||
u.User.Session.AddHandler(u.channelPinsUpdateHandler)
|
u.Session.AddHandler(u.channelPinsUpdateHandler)
|
||||||
u.User.Session.AddHandler(u.channelUpdateHandler)
|
u.Session.AddHandler(u.channelUpdateHandler)
|
||||||
|
|
||||||
u.User.Session.AddHandler(u.messageCreateHandler)
|
u.Session.AddHandler(u.messageCreateHandler)
|
||||||
u.User.Session.AddHandler(u.messageDeleteHandler)
|
u.Session.AddHandler(u.messageDeleteHandler)
|
||||||
u.User.Session.AddHandler(u.messageUpdateHandler)
|
u.Session.AddHandler(u.messageUpdateHandler)
|
||||||
u.User.Session.AddHandler(u.reactionAddHandler)
|
u.Session.AddHandler(u.reactionAddHandler)
|
||||||
u.User.Session.AddHandler(u.reactionRemoveHandler)
|
u.Session.AddHandler(u.reactionRemoveHandler)
|
||||||
|
|
||||||
// u.User.Session.Identify.Capabilities = 125
|
u.Session.Identify.Presence.Status = "online"
|
||||||
// // Setup our properties
|
|
||||||
// u.User.Session.Identify.Properties = discordgo.IdentifyProperties{
|
|
||||||
// OS: "Windows",
|
|
||||||
// OSVersion: "10",
|
|
||||||
// Browser: "Chrome",
|
|
||||||
// BrowserUserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36",
|
|
||||||
// BrowserVersion: "92.0.4515.159",
|
|
||||||
// Referrer: "https://discord.com/channels/@me",
|
|
||||||
// ReferringDomain: "discord.com",
|
|
||||||
// ClientBuildNumber: "83364",
|
|
||||||
// ReleaseChannel: "stable",
|
|
||||||
// }
|
|
||||||
|
|
||||||
u.User.Session.Identify.Presence.Status = "online"
|
return u.Session.Open()
|
||||||
|
}
|
||||||
|
|
||||||
return u.User.Session.Open()
|
func (u *User) Disconnect() error {
|
||||||
|
u.Lock()
|
||||||
|
defer u.Unlock()
|
||||||
|
|
||||||
|
if u.Session == nil {
|
||||||
|
return ErrNotConnected
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.Session.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Session = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {
|
func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
|
||||||
|
|
||||||
log "maunium.net/go/maulogger/v2"
|
log "maunium.net/go/maulogger/v2"
|
||||||
"maunium.net/go/mautrix/id"
|
"maunium.net/go/mautrix/id"
|
||||||
)
|
)
|
||||||
@@ -18,38 +16,7 @@ type User struct {
|
|||||||
|
|
||||||
ManagementRoom id.RoomID
|
ManagementRoom id.RoomID
|
||||||
|
|
||||||
Session *discordgo.Session
|
Token string
|
||||||
}
|
|
||||||
|
|
||||||
// NewSession is just used to create the session and update the database. It
|
|
||||||
// should only be called by bridge.User.Connect which will continue setting up
|
|
||||||
// event handlers and everything else.
|
|
||||||
func (u *User) NewSession(token string) error {
|
|
||||||
session, err := discordgo.New(token)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Session = session
|
|
||||||
|
|
||||||
u.Update()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteSession tries to logout and delete the session from the database.
|
|
||||||
func (u *User) DeleteSession() error {
|
|
||||||
err := u.Session.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
u.log.Warnfln("failed to close the session for %s: %v", u.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
u.Session = nil
|
|
||||||
|
|
||||||
u.Update()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Scan(row Scannable) *User {
|
func (u *User) Scan(row Scannable) *User {
|
||||||
@@ -65,31 +32,25 @@ func (u *User) Scan(row Scannable) *User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if token.Valid {
|
if token.Valid {
|
||||||
if err := u.NewSession(token.String); err != nil {
|
u.Token = token.String
|
||||||
u.log.Errorln("Failed to login: ", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) sessionNonptr() discordgo.Session {
|
|
||||||
if u.Session != nil {
|
|
||||||
return *u.Session
|
|
||||||
}
|
|
||||||
|
|
||||||
return discordgo.Session{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *User) Insert() {
|
func (u *User) Insert() {
|
||||||
session := u.sessionNonptr()
|
|
||||||
|
|
||||||
query := "INSERT INTO \"user\"" +
|
query := "INSERT INTO \"user\"" +
|
||||||
" (mxid, id, management_room, token)" +
|
" (mxid, id, management_room, token)" +
|
||||||
" VALUES ($1, $2, $3, $4);"
|
" VALUES ($1, $2, $3, $4);"
|
||||||
|
|
||||||
_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom,
|
var token sql.NullString
|
||||||
session.Identify.Token)
|
|
||||||
|
if u.Token != "" {
|
||||||
|
token.String = u.Token
|
||||||
|
token.Valid = true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := u.db.Exec(query, u.MXID, u.ID, u.ManagementRoom, token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
|
u.log.Warnfln("Failed to insert %s: %v", u.MXID, err)
|
||||||
@@ -97,13 +58,18 @@ func (u *User) Insert() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) Update() {
|
func (u *User) Update() {
|
||||||
session := u.sessionNonptr()
|
|
||||||
|
|
||||||
query := "UPDATE \"user\" SET" +
|
query := "UPDATE \"user\" SET" +
|
||||||
" id=$1, management_room=$2, token=$3" +
|
" id=$1, management_room=$2, token=$3" +
|
||||||
" WHERE mxid=$4;"
|
" WHERE mxid=$4;"
|
||||||
|
|
||||||
_, err := u.db.Exec(query, u.ID, u.ManagementRoom, session.Identify.Token, u.MXID)
|
var token sql.NullString
|
||||||
|
|
||||||
|
if u.Token != "" {
|
||||||
|
token.String = u.Token
|
||||||
|
token.Valid = true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := u.db.Exec(query, u.ID, u.ManagementRoom, token, u.MXID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
|
u.log.Warnfln("Failed to update %q: %v", u.MXID, err)
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
sync.Mutex
|
||||||
|
|
||||||
URL string
|
URL string
|
||||||
Origin string
|
Origin string
|
||||||
|
|
||||||
@@ -48,6 +51,9 @@ func New() (*Client, error) {
|
|||||||
// Dial will start the QRCode login process. ctx may be used to abandon the
|
// Dial will start the QRCode login process. ctx may be used to abandon the
|
||||||
// process.
|
// process.
|
||||||
func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error {
|
func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan struct{}) error {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
header := http.Header{
|
header := http.Header{
|
||||||
"Origin": []string{c.Origin},
|
"Origin": []string{c.Origin},
|
||||||
}
|
}
|
||||||
@@ -68,10 +74,16 @@ func (c *Client) Dial(ctx context.Context, qrChan chan string, doneChan chan str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Result() (User, error) {
|
func (c *Client) Result() (User, error) {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
return c.user, c.err
|
return c.user, c.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) close() error {
|
func (c *Client) close() error {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -89,6 +101,9 @@ func (c *Client) close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) write(p clientPacket) error {
|
func (c *Client) write(p clientPacket) error {
|
||||||
|
c.Lock()
|
||||||
|
defer c.Unlock()
|
||||||
|
|
||||||
payload, err := json.Marshal(p)
|
payload, err := json.Marshal(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -22,10 +22,15 @@ func (c *Client) processMessages() {
|
|||||||
defer c.close()
|
defer c.close()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
c.Lock()
|
||||||
_, packet, err := c.conn.ReadMessage()
|
_, packet, err := c.conn.ReadMessage()
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
|
||||||
|
c.Lock()
|
||||||
c.err = err
|
c.err = err
|
||||||
|
c.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -33,7 +38,9 @@ func (c *Client) processMessages() {
|
|||||||
|
|
||||||
raw := rawPacket{}
|
raw := rawPacket{}
|
||||||
if err := json.Unmarshal(packet, &raw); err != nil {
|
if err := json.Unmarshal(packet, &raw); err != nil {
|
||||||
|
c.Lock()
|
||||||
c.err = err
|
c.err = err
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -57,7 +64,9 @@ func (c *Client) processMessages() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(packet, dest); err != nil {
|
if err := json.Unmarshal(packet, dest); err != nil {
|
||||||
|
c.Lock()
|
||||||
c.err = err
|
c.err = err
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -65,7 +74,9 @@ func (c *Client) processMessages() {
|
|||||||
op := dest.(serverPacket)
|
op := dest.(serverPacket)
|
||||||
err = op.process(c)
|
err = op.process(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.Lock()
|
||||||
c.err = err
|
c.err = err
|
||||||
|
c.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -92,7 +103,10 @@ func (h *serverHello) process(client *Client) error {
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
h := clientHeartbeat{}
|
h := clientHeartbeat{}
|
||||||
if err := h.send(client); err != nil {
|
if err := h.send(client); err != nil {
|
||||||
|
client.Lock()
|
||||||
client.err = err
|
client.err = err
|
||||||
|
client.Unlock()
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -104,8 +118,10 @@ func (h *serverHello) process(client *Client) error {
|
|||||||
|
|
||||||
<-time.After(duration)
|
<-time.After(duration)
|
||||||
|
|
||||||
|
client.Lock()
|
||||||
client.err = fmt.Errorf("Timed out after %s", duration)
|
client.err = fmt.Errorf("Timed out after %s", duration)
|
||||||
client.close()
|
client.close()
|
||||||
|
client.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
i := clientInit{}
|
i := clientInit{}
|
||||||
|
|||||||
Reference in New Issue
Block a user