Clean up QR login provisioning API

This commit is contained in:
Tulir Asokan
2023-01-13 14:03:08 +02:00
parent de59f449f1
commit 91c3fae7cb

View File

@@ -44,7 +44,7 @@ func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI {
r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost) r.HandleFunc("/disconnect", p.disconnect).Methods(http.MethodPost)
r.HandleFunc("/ping", p.ping).Methods(http.MethodGet) r.HandleFunc("/ping", p.ping).Methods(http.MethodGet)
r.HandleFunc("/login/qr", p.login).Methods(http.MethodGet) r.HandleFunc("/login/qr", p.qrLogin).Methods(http.MethodGet)
r.HandleFunc("/logout", p.logout).Methods(http.MethodPost) r.HandleFunc("/logout", p.logout).Methods(http.MethodPost)
r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost) r.HandleFunc("/reconnect", p.reconnect).Methods(http.MethodPost)
@@ -59,7 +59,7 @@ func newProvisioningAPI(br *DiscordBridge) *ProvisioningAPI {
func jsonResponse(w http.ResponseWriter, status int, response interface{}) { func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
json.NewEncoder(w).Encode(response) _ = json.NewEncoder(w).Encode(response)
} }
// Response structs // Response structs
@@ -216,7 +216,7 @@ func (p *ProvisioningAPI) logout(w http.ResponseWriter, r *http.Request) {
jsonResponse(w, http.StatusOK, Response{true, msg}) jsonResponse(w, http.StatusOK, Response{true, msg})
} }
func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) qrLogin(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := p.bridge.GetUserByMXID(id.UserID(userID)) user := p.bridge.GetUserByMXID(id.UserID(userID))
@@ -226,10 +226,12 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
return return
} }
log := p.log.Sub("QRLogin").Sub(user.MXID.String())
defer func() { defer func() {
err := c.Close() err := c.Close()
if err != nil { if err != nil {
user.log.Debugln("Error closing websocket:", err) log.Debugln("Error closing websocket:", err)
} }
}() }()
@@ -245,41 +247,44 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
c.SetCloseHandler(func(code int, text string) error { c.SetCloseHandler(func(code int, text string) error {
user.log.Debugfln("Login websocket closed (%d), cancelling login", code) log.Debugfln("Login websocket closed (%d), cancelling login", code)
cancel() cancel()
return nil return nil
}) })
if user.IsLoggedIn() { if user.IsLoggedIn() {
c.WriteJSON(Error{ _ = c.WriteJSON(Error{
Error: "You're already logged into Discord", Error: "You're already logged into Discord",
ErrCode: "already logged in", ErrCode: "already logged in",
}) })
return return
} }
client, err := remoteauth.New() client, err := remoteauth.New()
if err != nil { if err != nil {
user.log.Errorf("Failed to log in from provisioning API:", err) log.Errorln("Failed to prepare login:", err)
_ = c.WriteJSON(Error{
c.WriteJSON(Error{ Error: "Failed to prepare login",
Error: "Failed to connect to Discord",
ErrCode: "connection error", ErrCode: "connection error",
}) })
return
} }
qrChan := make(chan string) qrChan := make(chan string)
doneChan := make(chan struct{}) doneChan := make(chan struct{})
user.log.Debugln("Started login via provisioning API") log.Debugln("Started login via provisioning API")
err = client.Dial(ctx, qrChan, doneChan) err = client.Dial(ctx, qrChan, doneChan)
if err != nil { if err != nil {
log.Errorln("Failed to connect to Discord login websocket:", err)
close(qrChan) close(qrChan)
close(doneChan) close(doneChan)
_ = c.WriteJSON(Error{
Error: "Failed to prepare login",
ErrCode: "connection error",
})
return
} }
for { for {
@@ -288,42 +293,47 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
if !ok { if !ok {
continue continue
} }
c.WriteJSON(map[string]interface{}{ err = c.WriteJSON(map[string]interface{}{
"code": qrCode, "code": qrCode,
"timeout": 120, // TODO: move this to the library or something "timeout": 120, // TODO: move this to the library or something
}) })
case <-doneChan:
discordUser, err := client.Result()
if err != nil { if err != nil {
c.WriteJSON(Error{ log.Errorln("Failed to write QR code to websocket:", err)
Error: "Failed to connect to Discord", }
ErrCode: "connection error", case <-doneChan:
var discordUser remoteauth.User
discordUser, err = client.Result()
if err != nil {
log.Errorln("Discord login websocket returned error:", err)
_ = c.WriteJSON(Error{
Error: "Failed to log in",
ErrCode: "login fail",
}) })
p.log.Errorfln("failed to login via qrcode:", err)
return return
} }
log.Infofln("Logged in as %s#%s (%s)", discordUser.Username, discordUser.Discriminator, discordUser.UserID)
user.DiscordID = discordUser.UserID user.DiscordID = discordUser.UserID
user.Update() user.Update()
if err := user.Login(discordUser.Token); err != nil { if err = user.Login(discordUser.Token); err != nil {
c.WriteJSON(Error{ log.Errorln("Failed to connect after logging in:", err)
Error: "Failed to connect to Discord", _ = c.WriteJSON(Error{
ErrCode: "connection error", Error: "Failed to connect to Discord after logging in",
ErrCode: "connect fail",
}) })
p.log.Errorfln("failed to login via qrcode:", err)
return return
} }
c.WriteJSON(map[string]interface{}{ err = c.WriteJSON(respLogin{
"success": true, Success: true,
"id": user.DiscordID, ID: user.DiscordID,
Username: discordUser.Username,
Discriminator: discordUser.Discriminator,
}) })
if err != nil {
log.Errorln("Failed to write login success to websocket:", err)
}
return return
case <-ctx.Done(): case <-ctx.Done():
return return
@@ -331,6 +341,13 @@ func (p *ProvisioningAPI) login(w http.ResponseWriter, r *http.Request) {
} }
} }
type respLogin struct {
Success bool `json:"success"`
ID string `json:"id"`
Username string `json:"username"`
Discriminator string `json:"discriminator"`
}
func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) { func (p *ProvisioningAPI) reconnect(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User) user := r.Context().Value("user").(*User)