Use the request context for cancelation, not a background context

This also fixes an improperly written server deletion listener to look at the correct context cancelation.

Theoretically this should help address the issues in pterodactyl/panel#3596 but I'm not really sure how that happens, and theres no steps for reproduction.
This commit is contained in:
Dane Everitt 2021-09-12 11:14:00 -07:00
parent ee91224eb6
commit 09e1ba6f34
3 changed files with 41 additions and 45 deletions

View File

@ -12,6 +12,14 @@ import (
"github.com/pterodactyl/wings/router/websocket" "github.com/pterodactyl/wings/router/websocket"
) )
var expectedCloseCodes = []int{
ws.CloseGoingAway,
ws.CloseAbnormalClosure,
ws.CloseNormalClosure,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
}
// Upgrades a connection to a websocket and passes events along between. // Upgrades a connection to a websocket and passes events along between.
func getServerWebsocket(c *gin.Context) { func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c) manager := middleware.ExtractManager(c)
@ -24,8 +32,10 @@ func getServerWebsocket(c *gin.Context) {
defer handler.Connection.Close() defer handler.Connection.Close()
// Create a context that can be canceled when the user disconnects from this // Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads. // socket that will also cancel listeners running in separate threads. If the
ctx, cancel := context.WithCancel(context.Background()) // connection itself is terminated listeners using this context will also be
// closed.
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
// Track this open connection on the server so that we can close them all programmatically // Track this open connection on the server so that we can close them all programmatically
@ -33,22 +43,19 @@ func getServerWebsocket(c *gin.Context) {
s.Websockets().Push(handler.Uuid(), &cancel) s.Websockets().Push(handler.Uuid(), &cancel)
defer s.Websockets().Remove(handler.Uuid()) defer s.Websockets().Remove(handler.Uuid())
// Listen for the context being canceled and then close the websocket connection. This normally // If the server is deleted we need to send a close message to the connected client
// just happens because you're disconnecting from the socket in the browser, however in some // so that they disconnect since there will be no more events sent along. Listen for
// cases we close the connections programmatically (e.g. deleting the server) and need to send // the request context being closed to break this loop, otherwise this routine will
// a close message to the websocket so it disconnects. // be left hanging in the background.
go func(ctx context.Context, c *ws.Conn) { go func() {
ListenerLoop:
for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
break
case <-s.Context().Done():
handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5)) handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
// A break right here without defining the specific loop would only break the select break
// and not actually break the for loop, thus causing this routine to stick around forever.
break ListenerLoop
} }
} }()
}(ctx, handler.Connection)
go handler.ListenForServerEvents(ctx) go handler.ListenForServerEvents(ctx)
go handler.ListenForExpiration(ctx) go handler.ListenForExpiration(ctx)
@ -58,14 +65,7 @@ func getServerWebsocket(c *gin.Context) {
_, p, err := handler.Connection.ReadMessage() _, p, err := handler.Connection.ReadMessage()
if err != nil { if err != nil {
if !ws.IsCloseError( if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
err,
ws.CloseNormalClosure,
ws.CloseGoingAway,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
ws.CloseAbnormalClosure,
) {
s.Log().WithField("error", err).Warn("error handling websocket message for server") s.Log().WithField("error", err).Warn("error handling websocket message for server")
} }
break break

View File

@ -8,16 +8,14 @@ import (
"github.com/pterodactyl/wings/server" "github.com/pterodactyl/wings/server"
) )
// Checks the time to expiration on the JWT every 30 seconds until the token has // ListenForExpiration checks the time to expiration on the JWT every 30 seconds
// expired. If we are within 3 minutes of the token expiring, send a notice over // until the token has expired. If we are within 3 minutes of the token expiring,
// the socket that it is expiring soon. If it has expired, send that notice as well. // send a notice over the socket that it is expiring soon. If it has expired,
// send that notice as well.
func (h *Handler) ListenForExpiration(ctx context.Context) { func (h *Handler) ListenForExpiration(ctx context.Context) {
// Make a ticker and completion channel that is used to continuously poll the // Make a ticker and completion channel that is used to continuously poll the
// JWT stored in the session to send events to the socket when it is expiring. // JWT stored in the session to send events to the socket when it is expiring.
ticker := time.NewTicker(time.Second * 30) ticker := time.NewTicker(time.Second * 30)
// Whenever this function is complete, end the ticker, close out the channel,
// and then close the websocket connection.
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -51,8 +49,9 @@ var e = []string{
server.TransferStatusEvent, server.TransferStatusEvent,
} }
// Listens for different events happening on a server and sends them along // ListenForServerEvents will listen for different events happening on a server
// to the connected websocket. // and send them along to the connected websocket client. This function will
// block until the context provided to it is canceled.
func (h *Handler) ListenForServerEvents(ctx context.Context) { func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Log().Debug("listening for server events over websocket") h.server.Log().Debug("listening for server events over websocket")
callback := func(e events.Event) { callback := func(e events.Event) {
@ -67,13 +66,10 @@ func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Events().On(evt, &callback) h.server.Events().On(evt, &callback)
} }
go func(ctx context.Context) { <-ctx.Done()
select { // Block until the context is stopped and then de-register all of the event listeners
case <-ctx.Done(): // that we registered earlier.
// Once this context is stopped, de-register all of the listeners that have been registered.
for _, evt := range e { for _, evt := range e {
h.server.Events().Off(evt, &callback) h.server.Events().Off(evt, &callback)
} }
} }
}(ctx)
}

View File

@ -44,11 +44,11 @@ func (w *WebsocketBag) Remove(u uuid.UUID) {
w.mu.Unlock() w.mu.Unlock()
} }
// CancelAll cancels all the stored cancel functions which has the effect of disconnecting // CancelAll cancels all the stored cancel functions which has the effect of
// every listening websocket for the server. // disconnecting every listening websocket for the server.
func (w *WebsocketBag) CancelAll() { func (w *WebsocketBag) CancelAll() {
w.mu.Lock() w.mu.Lock()
w.mu.Unlock() defer w.mu.Unlock()
if w.conns != nil { if w.conns != nil {
for _, cancel := range w.conns { for _, cancel := range w.conns {