Close websocket connections and allow the client to re-connect on send errors; ref pterodactyl/panel#3596
This commit is contained in:
		
							parent
							
								
									981756b456
								
							
						
					
					
						commit
						023d7ec1ec
					
				| 
						 | 
					@ -5,6 +5,7 @@ import (
 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"emperror.dev/errors"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	ws "github.com/gorilla/websocket"
 | 
						ws "github.com/gorilla/websocket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -24,12 +25,6 @@ var expectedCloseCodes = []int{
 | 
				
			||||||
func getServerWebsocket(c *gin.Context) {
 | 
					func getServerWebsocket(c *gin.Context) {
 | 
				
			||||||
	manager := middleware.ExtractManager(c)
 | 
						manager := middleware.ExtractManager(c)
 | 
				
			||||||
	s, _ := manager.Get(c.Param("server"))
 | 
						s, _ := manager.Get(c.Param("server"))
 | 
				
			||||||
	handler, err := websocket.GetHandler(s, c.Writer, c.Request)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		NewServerError(err, s).Abort(c)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	defer handler.Connection.Close()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Create a context that can be canceled when the user disconnects from this
 | 
						// Create a context that can be canceled when the user disconnects from this
 | 
				
			||||||
	// socket that will also cancel listeners running in separate threads. If the
 | 
						// socket that will also cancel listeners running in separate threads. If the
 | 
				
			||||||
| 
						 | 
					@ -38,10 +33,22 @@ func getServerWebsocket(c *gin.Context) {
 | 
				
			||||||
	ctx, cancel := context.WithCancel(c.Request.Context())
 | 
						ctx, cancel := context.WithCancel(c.Request.Context())
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						handler, err := websocket.GetHandler(s, c.Writer, c.Request)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							NewServerError(err, s).Abort(c)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer handler.Connection.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Track this open connection on the server so that we can close them all programmatically
 | 
						// Track this open connection on the server so that we can close them all programmatically
 | 
				
			||||||
	// if the server is deleted.
 | 
						// if the server is deleted.
 | 
				
			||||||
	s.Websockets().Push(handler.Uuid(), &cancel)
 | 
						s.Websockets().Push(handler.Uuid(), &cancel)
 | 
				
			||||||
	defer s.Websockets().Remove(handler.Uuid())
 | 
						handler.Logger().Debug("opening connection to server websocket")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							s.Websockets().Remove(handler.Uuid())
 | 
				
			||||||
 | 
							handler.Logger().Debug("closing connection to server websocket")
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If the server is deleted we need to send a close message to the connected client
 | 
						// If the server is deleted we need to send a close message to the connected client
 | 
				
			||||||
	// so that they disconnect since there will be no more events sent along. Listen for
 | 
						// so that they disconnect since there will be no more events sent along. Listen for
 | 
				
			||||||
| 
						 | 
					@ -57,7 +64,15 @@ func getServerWebsocket(c *gin.Context) {
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}()
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go handler.ListenForServerEvents(ctx)
 | 
						go func() {
 | 
				
			||||||
 | 
							if err := handler.ListenForServerEvents(ctx); err != nil {
 | 
				
			||||||
 | 
								handler.Logger().Warn("error while processing server event; closing websocket connection")
 | 
				
			||||||
 | 
								if err := handler.Connection.Close(); err != nil {
 | 
				
			||||||
 | 
									handler.Logger().WithField("error", errors.WithStack(err)).Error("error closing websocket connection")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	go handler.ListenForExpiration(ctx)
 | 
						go handler.ListenForExpiration(ctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
| 
						 | 
					@ -66,7 +81,7 @@ func getServerWebsocket(c *gin.Context) {
 | 
				
			||||||
		_, p, err := handler.Connection.ReadMessage()
 | 
							_, p, err := handler.Connection.ReadMessage()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
 | 
								if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
 | 
				
			||||||
				s.Log().WithField("error", err).Warn("error handling websocket message for server")
 | 
									handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2,8 +2,10 @@ package websocket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"emperror.dev/errors"
 | 
				
			||||||
	"github.com/pterodactyl/wings/events"
 | 
						"github.com/pterodactyl/wings/events"
 | 
				
			||||||
	"github.com/pterodactyl/wings/server"
 | 
						"github.com/pterodactyl/wings/server"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
| 
						 | 
					@ -52,24 +54,45 @@ var e = []string{
 | 
				
			||||||
// ListenForServerEvents will listen for different events happening on a server
 | 
					// ListenForServerEvents will listen for different events happening on a server
 | 
				
			||||||
// and send them along to the connected websocket client. This function will
 | 
					// and send them along to the connected websocket client. This function will
 | 
				
			||||||
// block until the context provided to it is canceled.
 | 
					// block until the context provided to it is canceled.
 | 
				
			||||||
func (h *Handler) ListenForServerEvents(ctx context.Context) {
 | 
					func (h *Handler) ListenForServerEvents(pctx context.Context) error {
 | 
				
			||||||
	h.server.Log().Debug("listening for server events over websocket")
 | 
						var o sync.Once
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						ctx, cancel := context.WithCancel(pctx)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						h.Logger().Debug("listening for server events")
 | 
				
			||||||
	callback := func(e events.Event) {
 | 
						callback := func(e events.Event) {
 | 
				
			||||||
		if err := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); err != nil {
 | 
							if sendErr := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); sendErr != nil {
 | 
				
			||||||
			h.server.Log().WithField("error", err).Warn("error while sending server data over websocket")
 | 
								h.Logger().WithField("event", e.Topic).WithField("error", sendErr).Error("failed to send event over server websocket")
 | 
				
			||||||
 | 
								// Avoid race conditions by only setting the error once and then canceling
 | 
				
			||||||
 | 
								// the context. This way if additional processing errors come through due
 | 
				
			||||||
 | 
								// to a massive flood of things you still only report and stop at the first.
 | 
				
			||||||
 | 
								o.Do(func() {
 | 
				
			||||||
 | 
									err = sendErr
 | 
				
			||||||
 | 
									cancel()
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Subscribe to all of the events with the same callback that will push the data out over the
 | 
						// Subscribe to all of the events with the same callback that will push the
 | 
				
			||||||
	// websocket for the server.
 | 
						// data out over the websocket for the server.
 | 
				
			||||||
	for _, evt := range e {
 | 
						for _, evt := range e {
 | 
				
			||||||
		h.server.Events().On(evt, &callback)
 | 
							h.server.Events().On(evt, &callback)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// When this function returns de-register all of the event listeners.
 | 
				
			||||||
 | 
						defer func() {
 | 
				
			||||||
 | 
							for _, evt := range e {
 | 
				
			||||||
 | 
								h.server.Events().Off(evt, &callback)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	<-ctx.Done()
 | 
						<-ctx.Done()
 | 
				
			||||||
	// Block until the context is stopped and then de-register all of the event listeners
 | 
						// If the internal context is stopped it is either because the parent context
 | 
				
			||||||
	// that we registered earlier.
 | 
						// got canceled or because we ran into an error. If the "err" variable is nil
 | 
				
			||||||
	for _, evt := range e {
 | 
						// we can assume the parent was canceled and need not perform any actions.
 | 
				
			||||||
		h.server.Events().Off(evt, &callback)
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return errors.WithStack(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -75,7 +75,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) {
 | 
				
			||||||
	return &payload, nil
 | 
						return &payload, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Returns a new websocket handler using the context provided.
 | 
					// GetHandler returns a new websocket handler using the context provided.
 | 
				
			||||||
func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) {
 | 
					func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) {
 | 
				
			||||||
	upgrader := websocket.Upgrader{
 | 
						upgrader := websocket.Upgrader{
 | 
				
			||||||
		// Ensure that the websocket request is originating from the Panel itself,
 | 
							// Ensure that the websocket request is originating from the Panel itself,
 | 
				
			||||||
| 
						 | 
					@ -116,6 +116,12 @@ func (h *Handler) Uuid() uuid.UUID {
 | 
				
			||||||
	return h.uuid
 | 
						return h.uuid
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *Handler) Logger() *log.Entry {
 | 
				
			||||||
 | 
						return log.WithField("subsystem", "websocket").
 | 
				
			||||||
 | 
							WithField("connection", h.Uuid().String()).
 | 
				
			||||||
 | 
							WithField("server", h.server.ID())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Handler) SendJson(v *Message) error {
 | 
					func (h *Handler) SendJson(v *Message) error {
 | 
				
			||||||
	// Do not send JSON down the line if the JWT on the connection is not valid!
 | 
						// Do not send JSON down the line if the JWT on the connection is not valid!
 | 
				
			||||||
	if err := h.TokenValid(); err != nil {
 | 
						if err := h.TokenValid(); err != nil {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -52,8 +52,7 @@ func (w *WebsocketBag) CancelAll() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if w.conns != nil {
 | 
						if w.conns != nil {
 | 
				
			||||||
		for _, cancel := range w.conns {
 | 
							for _, cancel := range w.conns {
 | 
				
			||||||
			c := *cancel
 | 
								(*cancel)()
 | 
				
			||||||
			c()
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user