Use the request context for cancelation, not a background context

This also fixes an improperly written server deletion listener to look at the correct context cancelation.

Theoretically this should help address the issues in pterodactyl/panel#3596 but I'm not really sure how that happens, and theres no steps for reproduction.
This commit is contained in:
Dane Everitt 2021-09-12 11:14:00 -07:00
parent ee91224eb6
commit 09e1ba6f34
3 changed files with 41 additions and 45 deletions

View File

@ -12,6 +12,14 @@ import (
"github.com/pterodactyl/wings/router/websocket"
)
var expectedCloseCodes = []int{
ws.CloseGoingAway,
ws.CloseAbnormalClosure,
ws.CloseNormalClosure,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
}
// Upgrades a connection to a websocket and passes events along between.
func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c)
@ -24,8 +32,10 @@ func getServerWebsocket(c *gin.Context) {
defer handler.Connection.Close()
// Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads.
ctx, cancel := context.WithCancel(context.Background())
// socket that will also cancel listeners running in separate threads. If the
// connection itself is terminated listeners using this context will also be
// closed.
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
// Track this open connection on the server so that we can close them all programmatically
@ -33,22 +43,19 @@ func getServerWebsocket(c *gin.Context) {
s.Websockets().Push(handler.Uuid(), &cancel)
defer s.Websockets().Remove(handler.Uuid())
// Listen for the context being canceled and then close the websocket connection. This normally
// just happens because you're disconnecting from the socket in the browser, however in some
// cases we close the connections programmatically (e.g. deleting the server) and need to send
// a close message to the websocket so it disconnects.
go func(ctx context.Context, c *ws.Conn) {
ListenerLoop:
for {
select {
case <-ctx.Done():
handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
// A break right here without defining the specific loop would only break the select
// and not actually break the for loop, thus causing this routine to stick around forever.
break ListenerLoop
}
// If the server is deleted we need to send a close message to the connected client
// so that they disconnect since there will be no more events sent along. Listen for
// the request context being closed to break this loop, otherwise this routine will
// be left hanging in the background.
go func() {
select {
case <-ctx.Done():
break
case <-s.Context().Done():
handler.Connection.WriteControl(ws.CloseMessage, ws.FormatCloseMessage(ws.CloseGoingAway, "server deleted"), time.Now().Add(time.Second*5))
break
}
}(ctx, handler.Connection)
}()
go handler.ListenForServerEvents(ctx)
go handler.ListenForExpiration(ctx)
@ -58,14 +65,7 @@ func getServerWebsocket(c *gin.Context) {
_, p, err := handler.Connection.ReadMessage()
if err != nil {
if !ws.IsCloseError(
err,
ws.CloseNormalClosure,
ws.CloseGoingAway,
ws.CloseNoStatusReceived,
ws.CloseServiceRestart,
ws.CloseAbnormalClosure,
) {
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
s.Log().WithField("error", err).Warn("error handling websocket message for server")
}
break

View File

@ -8,16 +8,14 @@ import (
"github.com/pterodactyl/wings/server"
)
// Checks the time to expiration on the JWT every 30 seconds until the token has
// expired. If we are within 3 minutes of the token expiring, send a notice over
// the socket that it is expiring soon. If it has expired, send that notice as well.
// ListenForExpiration checks the time to expiration on the JWT every 30 seconds
// until the token has expired. If we are within 3 minutes of the token expiring,
// send a notice over the socket that it is expiring soon. If it has expired,
// send that notice as well.
func (h *Handler) ListenForExpiration(ctx context.Context) {
// Make a ticker and completion channel that is used to continuously poll the
// JWT stored in the session to send events to the socket when it is expiring.
ticker := time.NewTicker(time.Second * 30)
// Whenever this function is complete, end the ticker, close out the channel,
// and then close the websocket connection.
defer ticker.Stop()
for {
@ -51,8 +49,9 @@ var e = []string{
server.TransferStatusEvent,
}
// Listens for different events happening on a server and sends them along
// to the connected websocket.
// ListenForServerEvents will listen for different events happening on a server
// and send them along to the connected websocket client. This function will
// block until the context provided to it is canceled.
func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Log().Debug("listening for server events over websocket")
callback := func(e events.Event) {
@ -67,13 +66,10 @@ func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Events().On(evt, &callback)
}
go func(ctx context.Context) {
select {
case <-ctx.Done():
// Once this context is stopped, de-register all of the listeners that have been registered.
for _, evt := range e {
h.server.Events().Off(evt, &callback)
}
}
}(ctx)
<-ctx.Done()
// Block until the context is stopped and then de-register all of the event listeners
// that we registered earlier.
for _, evt := range e {
h.server.Events().Off(evt, &callback)
}
}

View File

@ -44,11 +44,11 @@ func (w *WebsocketBag) Remove(u uuid.UUID) {
w.mu.Unlock()
}
// CancelAll cancels all the stored cancel functions which has the effect of disconnecting
// every listening websocket for the server.
// CancelAll cancels all the stored cancel functions which has the effect of
// disconnecting every listening websocket for the server.
func (w *WebsocketBag) CancelAll() {
w.mu.Lock()
w.mu.Unlock()
defer w.mu.Unlock()
if w.conns != nil {
for _, cancel := range w.conns {