Use the request context for cancelation, not a background context

This also fixes an improperly written server deletion listener to look at the correct context cancelation.

Theoretically this should help address the issues in  but I'm not really sure how that happens, and theres no steps for reproduction.
This commit is contained in:
Dane Everitt 2021-09-12 11:14:00 -07:00
parent ee91224eb6
commit 09e1ba6f34
3 changed files with 41 additions and 45 deletions

View File

@ -12,6 +12,14 @@ import (
"github.com/pterodactyl/wings/router/websocket" "github.com/pterodactyl/wings/router/websocket"
) )
var expectedCloseCodes = []int{
ws.CloseGoingAway,
ws.CloseAbnormalClosure,
ws.CloseNormalClosure,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
}
// Upgrades a connection to a websocket and passes events along between. // Upgrades a connection to a websocket and passes events along between.
func getServerWebsocket(c *gin.Context) { func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c) manager := middleware.ExtractManager(c)
@ -24,8 +32,10 @@ func getServerWebsocket(c *gin.Context) {
defer handler.Connection.Close() defer handler.Connection.Close()
// Create a context that can be canceled when the user disconnects from this // Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads. // socket that will also cancel listeners running in separate threads. If the
ctx, cancel := context.WithCancel(context.Background()) // connection itself is terminated listeners using this context will also be
// closed.
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
// Track this open connection on the server so that we can close them all programmatically // Track this open connection on the server so that we can close them all programmatically
@ -33,22 +43,19 @@ func getServerWebsocket(c *gin.Context) {
s.Websockets().Push(handler.Uuid(), &cancel) s.Websockets().Push(handler.Uuid(), &cancel)
defer s.Websockets().Remove(handler.Uuid()) defer s.Websockets().Remove(handler.Uuid())
// Listen for the context being canceled and then close the websocket connection. This normally // If the server is deleted we need to send a close message to the connected client
// just happens because you're disconnecting from the socket in the browser, however in some // so that they disconnect since there will be no more events sent along. Listen for
// cases we close the connections programmatically (e.g. deleting the server) and need to send // the request context being closed to break this loop, otherwise this routine will
// a close message to the websocket so it disconnects. // be left hanging in the background.
go func(ctx context.Context, c *ws.Conn) { go func() {
ListenerLoop: select {
for { case <-ctx.Done():
select { break
case <-ctx.Done(): case <-s.Context().Done():
handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5)) handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
// A break right here without defining the specific loop would only break the select break
// and not actually break the for loop, thus causing this routine to stick around forever.
break ListenerLoop
}
} }
}(ctx, handler.Connection) }()
go handler.ListenForServerEvents(ctx) go handler.ListenForServerEvents(ctx)
go handler.ListenForExpiration(ctx) go handler.ListenForExpiration(ctx)
@ -58,14 +65,7 @@ func getServerWebsocket(c *gin.Context) {
_, p, err := handler.Connection.ReadMessage() _, p, err := handler.Connection.ReadMessage()
if err != nil { if err != nil {
if !ws.IsCloseError( if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
err,
ws.CloseNormalClosure,
ws.CloseGoingAway,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
ws.CloseAbnormalClosure,
) {
s.Log().WithField("error", err).Warn("error handling websocket message for server") s.Log().WithField("error", err).Warn("error handling websocket message for server")
} }
break break

View File

@ -8,16 +8,14 @@ import (
"github.com/pterodactyl/wings/server" "github.com/pterodactyl/wings/server"
) )
// Checks the time to expiration on the JWT every 30 seconds until the token has // ListenForExpiration checks the time to expiration on the JWT every 30 seconds
// expired. If we are within 3 minutes of the token expiring, send a notice over // until the token has expired. If we are within 3 minutes of the token expiring,
// the socket that it is expiring soon. If it has expired, send that notice as well. // send a notice over the socket that it is expiring soon. If it has expired,
// send that notice as well.
func (h *Handler) ListenForExpiration(ctx context.Context) { func (h *Handler) ListenForExpiration(ctx context.Context) {
// Make a ticker and completion channel that is used to continuously poll the // Make a ticker and completion channel that is used to continuously poll the
// JWT stored in the session to send events to the socket when it is expiring. // JWT stored in the session to send events to the socket when it is expiring.
ticker := time.NewTicker(time.Second * 30) ticker := time.NewTicker(time.Second * 30)
// Whenever this function is complete, end the ticker, close out the channel,
// and then close the websocket connection.
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -51,8 +49,9 @@ var e = []string{
server.TransferStatusEvent, server.TransferStatusEvent,
} }
// Listens for different events happening on a server and sends them along // ListenForServerEvents will listen for different events happening on a server
// to the connected websocket. // and send them along to the connected websocket client. This function will
// block until the context provided to it is canceled.
func (h *Handler) ListenForServerEvents(ctx context.Context) { func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Log().Debug("listening for server events over websocket") h.server.Log().Debug("listening for server events over websocket")
callback := func(e events.Event) { callback := func(e events.Event) {
@ -67,13 +66,10 @@ func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Events().On(evt, &callback) h.server.Events().On(evt, &callback)
} }
go func(ctx context.Context) { <-ctx.Done()
select { // Block until the context is stopped and then de-register all of the event listeners
case <-ctx.Done(): // that we registered earlier.
// Once this context is stopped, de-register all of the listeners that have been registered. for _, evt := range e {
for _, evt := range e { h.server.Events().Off(evt, &callback)
h.server.Events().Off(evt, &callback) }
}
}
}(ctx)
} }

View File

@ -44,11 +44,11 @@ func (w *WebsocketBag) Remove(u uuid.UUID) {
w.mu.Unlock() w.mu.Unlock()
} }
// CancelAll cancels all the stored cancel functions which has the effect of disconnecting // CancelAll cancels all the stored cancel functions which has the effect of
// every listening websocket for the server. // disconnecting every listening websocket for the server.
func (w *WebsocketBag) CancelAll() { func (w *WebsocketBag) CancelAll() {
w.mu.Lock() w.mu.Lock()
w.mu.Unlock() defer w.mu.Unlock()
if w.conns != nil { if w.conns != nil {
for _, cancel := range w.conns { for _, cancel := range w.conns {