Close websocket connections and allow the client to re-connect on send errors; ref pterodactyl/panel#3596

This commit is contained in:
Dane Everitt 2021-10-24 16:14:00 -07:00
parent 981756b456
commit 023d7ec1ec
4 changed files with 65 additions and 22 deletions

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"time"
"emperror.dev/errors"
"github.com/gin-gonic/gin"
ws "github.com/gorilla/websocket"
@ -24,12 +25,6 @@ var expectedCloseCodes = []int{
func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c)
s, _ := manager.Get(c.Param("server"))
handler, err := websocket.GetHandler(s, c.Writer, c.Request)
if err != nil {
NewServerError(err, s).Abort(c)
return
}
defer handler.Connection.Close()
// Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads. If the
@ -38,10 +33,22 @@ func getServerWebsocket(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
handler, err := websocket.GetHandler(s, c.Writer, c.Request)
if err != nil {
NewServerError(err, s).Abort(c)
return
}
defer handler.Connection.Close()
// Track this open connection on the server so that we can close them all programmatically
// if the server is deleted.
s.Websockets().Push(handler.Uuid(), &cancel)
defer s.Websockets().Remove(handler.Uuid())
handler.Logger().Debug("opening connection to server websocket")
defer func() {
s.Websockets().Remove(handler.Uuid())
handler.Logger().Debug("closing connection to server websocket")
}()
// If the server is deleted we need to send a close message to the connected client
// so that they disconnect since there will be no more events sent along. Listen for
@ -57,7 +64,15 @@ func getServerWebsocket(c *gin.Context) {
}
}()
go handler.ListenForServerEvents(ctx)
go func() {
if err := handler.ListenForServerEvents(ctx); err != nil {
handler.Logger().Warn("error while processing server event; closing websocket connection")
if err := handler.Connection.Close(); err != nil {
handler.Logger().WithField("error", errors.WithStack(err)).Error("error closing websocket connection")
}
}
}()
go handler.ListenForExpiration(ctx)
for {
@ -66,7 +81,7 @@ func getServerWebsocket(c *gin.Context) {
_, p, err := handler.Connection.ReadMessage()
if err != nil {
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
s.Log().WithField("error", err).Warn("error handling websocket message for server")
handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
}
break
}

View File

@ -2,8 +2,10 @@ package websocket
import (
"context"
"sync"
"time"
"emperror.dev/errors"
"github.com/pterodactyl/wings/events"
"github.com/pterodactyl/wings/server"
)
@ -52,24 +54,45 @@ var e = []string{
// ListenForServerEvents will listen for different events happening on a server
// and send them along to the connected websocket client. This function will
// block until the context provided to it is canceled.
func (h *Handler) ListenForServerEvents(ctx context.Context) {
h.server.Log().Debug("listening for server events over websocket")
func (h *Handler) ListenForServerEvents(pctx context.Context) error {
var o sync.Once
var err error
ctx, cancel := context.WithCancel(pctx)
h.Logger().Debug("listening for server events")
callback := func(e events.Event) {
if err := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); err != nil {
h.server.Log().WithField("error", err).Warn("error while sending server data over websocket")
if sendErr := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); sendErr != nil {
h.Logger().WithField("event", e.Topic).WithField("error", sendErr).Error("failed to send event over server websocket")
// Avoid race conditions by only setting the error once and then canceling
// the context. This way if additional processing errors come through due
// to a massive flood of things you still only report and stop at the first.
o.Do(func() {
err = sendErr
cancel()
})
}
}
// Subscribe to all of the events with the same callback that will push the data out over the
// websocket for the server.
// Subscribe to all of the events with the same callback that will push the
// data out over the websocket for the server.
for _, evt := range e {
h.server.Events().On(evt, &callback)
}
<-ctx.Done()
// Block until the context is stopped and then de-register all of the event listeners
// that we registered earlier.
// When this function returns de-register all of the event listeners.
defer func() {
for _, evt := range e {
h.server.Events().Off(evt, &callback)
}
}()
<-ctx.Done()
// If the internal context is stopped it is either because the parent context
// got canceled or because we ran into an error. If the "err" variable is nil
// we can assume the parent was canceled and need not perform any actions.
if err != nil {
return errors.WithStack(err)
}
return nil
}

View File

@ -75,7 +75,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) {
return &payload, nil
}
// Returns a new websocket handler using the context provided.
// GetHandler returns a new websocket handler using the context provided.
func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) {
upgrader := websocket.Upgrader{
// Ensure that the websocket request is originating from the Panel itself,
@ -116,6 +116,12 @@ func (h *Handler) Uuid() uuid.UUID {
return h.uuid
}
func (h *Handler) Logger() *log.Entry {
return log.WithField("subsystem", "websocket").
WithField("connection", h.Uuid().String()).
WithField("server", h.server.ID())
}
func (h *Handler) SendJson(v *Message) error {
// Do not send JSON down the line if the JWT on the connection is not valid!
if err := h.TokenValid(); err != nil {

View File

@ -52,8 +52,7 @@ func (w *WebsocketBag) CancelAll() {
if w.conns != nil {
for _, cancel := range w.conns {
c := *cancel
c()
(*cancel)()
}
}