diff --git a/router/router_server_ws.go b/router/router_server_ws.go index 8ab1411..22cfca7 100644 --- a/router/router_server_ws.go +++ b/router/router_server_ws.go @@ -5,6 +5,7 @@ import ( "encoding/json" "time" + "emperror.dev/errors" "github.com/gin-gonic/gin" ws "github.com/gorilla/websocket" @@ -24,12 +25,6 @@ var expectedCloseCodes = []int{ func getServerWebsocket(c *gin.Context) { manager := middleware.ExtractManager(c) s, _ := manager.Get(c.Param("server")) - handler, err := websocket.GetHandler(s, c.Writer, c.Request) - if err != nil { - NewServerError(err, s).Abort(c) - return - } - defer handler.Connection.Close() // Create a context that can be canceled when the user disconnects from this // socket that will also cancel listeners running in separate threads. If the @@ -38,10 +33,22 @@ func getServerWebsocket(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() + handler, err := websocket.GetHandler(s, c.Writer, c.Request) + if err != nil { + NewServerError(err, s).Abort(c) + return + } + defer handler.Connection.Close() + // Track this open connection on the server so that we can close them all programmatically // if the server is deleted. s.Websockets().Push(handler.Uuid(), &cancel) - defer s.Websockets().Remove(handler.Uuid()) + handler.Logger().Debug("opening connection to server websocket") + + defer func() { + s.Websockets().Remove(handler.Uuid()) + handler.Logger().Debug("closing connection to server websocket") + }() // If the server is deleted we need to send a close message to the connected client // so that they disconnect since there will be no more events sent along. Listen for @@ -57,7 +64,15 @@ func getServerWebsocket(c *gin.Context) { } }() - go handler.ListenForServerEvents(ctx) + go func() { + if err := handler.ListenForServerEvents(ctx); err != nil { + handler.Logger().Warn("error while processing server event; closing websocket connection") + if err := handler.Connection.Close(); err != nil { + handler.Logger().WithField("error", errors.WithStack(err)).Error("error closing websocket connection") + } + } + }() + go handler.ListenForExpiration(ctx) for { @@ -66,7 +81,7 @@ func getServerWebsocket(c *gin.Context) { _, p, err := handler.Connection.ReadMessage() if err != nil { if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) { - s.Log().WithField("error", err).Warn("error handling websocket message for server") + handler.Logger().WithField("error", err).Warn("error handling websocket message for server") } break } diff --git a/router/websocket/listeners.go b/router/websocket/listeners.go index acb762d..21cf859 100644 --- a/router/websocket/listeners.go +++ b/router/websocket/listeners.go @@ -2,8 +2,10 @@ package websocket import ( "context" + "sync" "time" + "emperror.dev/errors" "github.com/pterodactyl/wings/events" "github.com/pterodactyl/wings/server" ) @@ -52,24 +54,45 @@ var e = []string{ // ListenForServerEvents will listen for different events happening on a server // and send them along to the connected websocket client. This function will // block until the context provided to it is canceled. -func (h *Handler) ListenForServerEvents(ctx context.Context) { - h.server.Log().Debug("listening for server events over websocket") +func (h *Handler) ListenForServerEvents(pctx context.Context) error { + var o sync.Once + var err error + ctx, cancel := context.WithCancel(pctx) + + h.Logger().Debug("listening for server events") callback := func(e events.Event) { - if err := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); err != nil { - h.server.Log().WithField("error", err).Warn("error while sending server data over websocket") + if sendErr := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); sendErr != nil { + h.Logger().WithField("event", e.Topic).WithField("error", sendErr).Error("failed to send event over server websocket") + // Avoid race conditions by only setting the error once and then canceling + // the context. This way if additional processing errors come through due + // to a massive flood of things you still only report and stop at the first. + o.Do(func() { + err = sendErr + cancel() + }) } } - // Subscribe to all of the events with the same callback that will push the data out over the - // websocket for the server. + // Subscribe to all of the events with the same callback that will push the + // data out over the websocket for the server. for _, evt := range e { h.server.Events().On(evt, &callback) } + // When this function returns de-register all of the event listeners. + defer func() { + for _, evt := range e { + h.server.Events().Off(evt, &callback) + } + }() + <-ctx.Done() - // Block until the context is stopped and then de-register all of the event listeners - // that we registered earlier. - for _, evt := range e { - h.server.Events().Off(evt, &callback) + // If the internal context is stopped it is either because the parent context + // got canceled or because we ran into an error. If the "err" variable is nil + // we can assume the parent was canceled and need not perform any actions. + if err != nil { + return errors.WithStack(err) } + + return nil } diff --git a/router/websocket/websocket.go b/router/websocket/websocket.go index ff5855b..0d37fca 100644 --- a/router/websocket/websocket.go +++ b/router/websocket/websocket.go @@ -75,7 +75,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) { return &payload, nil } -// Returns a new websocket handler using the context provided. +// GetHandler returns a new websocket handler using the context provided. func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) { upgrader := websocket.Upgrader{ // Ensure that the websocket request is originating from the Panel itself, @@ -116,6 +116,12 @@ func (h *Handler) Uuid() uuid.UUID { return h.uuid } +func (h *Handler) Logger() *log.Entry { + return log.WithField("subsystem", "websocket"). + WithField("connection", h.Uuid().String()). + WithField("server", h.server.ID()) +} + func (h *Handler) SendJson(v *Message) error { // Do not send JSON down the line if the JWT on the connection is not valid! if err := h.TokenValid(); err != nil { diff --git a/server/websockets.go b/server/websockets.go index 981fa79..e86f88c 100644 --- a/server/websockets.go +++ b/server/websockets.go @@ -52,8 +52,7 @@ func (w *WebsocketBag) CancelAll() { if w.conns != nil { for _, cancel := range w.conns { - c := *cancel - c() + (*cancel)() } }