diff --git a/router/router_server_ws.go b/router/router_server_ws.go index c35996b..8ab1411 100644 --- a/router/router_server_ws.go +++ b/router/router_server_ws.go @@ -12,6 +12,14 @@ import ( "github.com/pterodactyl/wings/router/websocket" ) +var expectedCloseCodes = []int{ + ws.CloseGoingAway, + ws.CloseAbnormalClosure, + ws.CloseNormalClosure, + ws.CloseNoStatusReceived, + ws.CloseServiceRestart, +} + // Upgrades a connection to a websocket and passes events along between. func getServerWebsocket(c *gin.Context) { manager := middleware.ExtractManager(c) @@ -24,8 +32,10 @@ func getServerWebsocket(c *gin.Context) { defer handler.Connection.Close() // Create a context that can be canceled when the user disconnects from this - // socket that will also cancel listeners running in separate threads. - ctx, cancel := context.WithCancel(context.Background()) + // socket that will also cancel listeners running in separate threads. If the + // connection itself is terminated listeners using this context will also be + // closed. + ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() // Track this open connection on the server so that we can close them all programmatically @@ -33,22 +43,19 @@ func getServerWebsocket(c *gin.Context) { s.Websockets().Push(handler.Uuid(), &cancel) defer s.Websockets().Remove(handler.Uuid()) - // Listen for the context being canceled and then close the websocket connection. This normally - // just happens because you're disconnecting from the socket in the browser, however in some - // cases we close the connections programmatically (e.g. deleting the server) and need to send - // a close message to the websocket so it disconnects. - go func(ctx context.Context, c *ws.Conn) { - ListenerLoop: - for { - select { - case <-ctx.Done(): - handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5)) - // A break right here without defining the specific loop would only break the select - // and not actually break the for loop, thus causing this routine to stick around forever. - break ListenerLoop - } + // If the server is deleted we need to send a close message to the connected client + // so that they disconnect since there will be no more events sent along. Listen for + // the request context being closed to break this loop, otherwise this routine will + // be left hanging in the background. + go func() { + select { + case <-ctx.Done(): + break + case <-s.Context().Done(): + handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5)) + break } - }(ctx, handler.Connection) + }() go handler.ListenForServerEvents(ctx) go handler.ListenForExpiration(ctx) @@ -58,14 +65,7 @@ func getServerWebsocket(c *gin.Context) { _, p, err := handler.Connection.ReadMessage() if err != nil { - if !ws.IsCloseError( - err, - ws.CloseNormalClosure, - ws.CloseGoingAway, - ws.CloseNoStatusReceived, - ws.CloseServiceRestart, - ws.CloseAbnormalClosure, - ) { + if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) { s.Log().WithField("error", err).Warn("error handling websocket message for server") } break diff --git a/router/websocket/listeners.go b/router/websocket/listeners.go index 332f03a..acb762d 100644 --- a/router/websocket/listeners.go +++ b/router/websocket/listeners.go @@ -8,16 +8,14 @@ import ( "github.com/pterodactyl/wings/server" ) -// Checks the time to expiration on the JWT every 30 seconds until the token has -// expired. If we are within 3 minutes of the token expiring, send a notice over -// the socket that it is expiring soon. If it has expired, send that notice as well. +// ListenForExpiration checks the time to expiration on the JWT every 30 seconds +// until the token has expired. If we are within 3 minutes of the token expiring, +// send a notice over the socket that it is expiring soon. If it has expired, +// send that notice as well. func (h *Handler) ListenForExpiration(ctx context.Context) { // Make a ticker and completion channel that is used to continuously poll the // JWT stored in the session to send events to the socket when it is expiring. ticker := time.NewTicker(time.Second * 30) - - // Whenever this function is complete, end the ticker, close out the channel, - // and then close the websocket connection. defer ticker.Stop() for { @@ -51,8 +49,9 @@ var e = []string{ server.TransferStatusEvent, } -// Listens for different events happening on a server and sends them along -// to the connected websocket. +// ListenForServerEvents will listen for different events happening on a server +// and send them along to the connected websocket client. This function will +// block until the context provided to it is canceled. func (h *Handler) ListenForServerEvents(ctx context.Context) { h.server.Log().Debug("listening for server events over websocket") callback := func(e events.Event) { @@ -67,13 +66,10 @@ func (h *Handler) ListenForServerEvents(ctx context.Context) { h.server.Events().On(evt, &callback) } - go func(ctx context.Context) { - select { - case <-ctx.Done(): - // Once this context is stopped, de-register all of the listeners that have been registered. - for _, evt := range e { - h.server.Events().Off(evt, &callback) - } - } - }(ctx) + <-ctx.Done() + // Block until the context is stopped and then de-register all of the event listeners + // that we registered earlier. + for _, evt := range e { + h.server.Events().Off(evt, &callback) + } } diff --git a/server/websockets.go b/server/websockets.go index 9938512..981fa79 100644 --- a/server/websockets.go +++ b/server/websockets.go @@ -44,11 +44,11 @@ func (w *WebsocketBag) Remove(u uuid.UUID) { w.mu.Unlock() } -// CancelAll cancels all the stored cancel functions which has the effect of disconnecting -// every listening websocket for the server. +// CancelAll cancels all the stored cancel functions which has the effect of +// disconnecting every listening websocket for the server. func (w *WebsocketBag) CancelAll() { w.mu.Lock() - w.mu.Unlock() + defer w.mu.Unlock() if w.conns != nil { for _, cancel := range w.conns {