Close websocket connections and allow the client to re-connect on send errors; ref pterodactyl/panel#3596

This commit is contained in:
Dane Everitt 2021-10-24 16:14:00 -07:00
parent 981756b456
commit 023d7ec1ec
4 changed files with 65 additions and 22 deletions

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"emperror.dev/errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
ws "github.com/gorilla/websocket" ws "github.com/gorilla/websocket"
@ -24,12 +25,6 @@ var expectedCloseCodes = []int{
func getServerWebsocket(c *gin.Context) { func getServerWebsocket(c *gin.Context) {
manager := middleware.ExtractManager(c) manager := middleware.ExtractManager(c)
s, _ := manager.Get(c.Param("server")) s, _ := manager.Get(c.Param("server"))
handler, err := websocket.GetHandler(s, c.Writer, c.Request)
if err != nil {
NewServerError(err, s).Abort(c)
return
}
defer handler.Connection.Close()
// Create a context that can be canceled when the user disconnects from this // Create a context that can be canceled when the user disconnects from this
// socket that will also cancel listeners running in separate threads. If the // socket that will also cancel listeners running in separate threads. If the
@ -38,10 +33,22 @@ func getServerWebsocket(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
handler, err := websocket.GetHandler(s, c.Writer, c.Request)
if err != nil {
NewServerError(err, s).Abort(c)
return
}
defer handler.Connection.Close()
// Track this open connection on the server so that we can close them all programmatically // Track this open connection on the server so that we can close them all programmatically
// if the server is deleted. // if the server is deleted.
s.Websockets().Push(handler.Uuid(), &cancel) s.Websockets().Push(handler.Uuid(), &cancel)
defer s.Websockets().Remove(handler.Uuid()) handler.Logger().Debug("opening connection to server websocket")
defer func() {
s.Websockets().Remove(handler.Uuid())
handler.Logger().Debug("closing connection to server websocket")
}()
// If the server is deleted we need to send a close message to the connected client // If the server is deleted we need to send a close message to the connected client
// so that they disconnect since there will be no more events sent along. Listen for // so that they disconnect since there will be no more events sent along. Listen for
@ -57,7 +64,15 @@ func getServerWebsocket(c *gin.Context) {
} }
}() }()
go handler.ListenForServerEvents(ctx) go func() {
if err := handler.ListenForServerEvents(ctx); err != nil {
handler.Logger().Warn("error while processing server event; closing websocket connection")
if err := handler.Connection.Close(); err != nil {
handler.Logger().WithField("error", errors.WithStack(err)).Error("error closing websocket connection")
}
}
}()
go handler.ListenForExpiration(ctx) go handler.ListenForExpiration(ctx)
for { for {
@ -66,7 +81,7 @@ func getServerWebsocket(c *gin.Context) {
_, p, err := handler.Connection.ReadMessage() _, p, err := handler.Connection.ReadMessage()
if err != nil { if err != nil {
if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) { if ws.IsUnexpectedCloseError(err, expectedCloseCodes...) {
s.Log().WithField("error", err).Warn("error handling websocket message for server") handler.Logger().WithField("error", err).Warn("error handling websocket message for server")
} }
break break
} }

View File

@ -2,8 +2,10 @@ package websocket
import ( import (
"context" "context"
"sync"
"time" "time"
"emperror.dev/errors"
"github.com/pterodactyl/wings/events" "github.com/pterodactyl/wings/events"
"github.com/pterodactyl/wings/server" "github.com/pterodactyl/wings/server"
) )
@ -52,24 +54,45 @@ var e = []string{
// ListenForServerEvents will listen for different events happening on a server // ListenForServerEvents will listen for different events happening on a server
// and send them along to the connected websocket client. This function will // and send them along to the connected websocket client. This function will
// block until the context provided to it is canceled. // block until the context provided to it is canceled.
func (h *Handler) ListenForServerEvents(ctx context.Context) { func (h *Handler) ListenForServerEvents(pctx context.Context) error {
h.server.Log().Debug("listening for server events over websocket") var o sync.Once
var err error
ctx, cancel := context.WithCancel(pctx)
h.Logger().Debug("listening for server events")
callback := func(e events.Event) { callback := func(e events.Event) {
if err := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); err != nil { if sendErr := h.SendJson(&Message{Event: e.Topic, Args: []string{e.Data}}); sendErr != nil {
h.server.Log().WithField("error", err).Warn("error while sending server data over websocket") h.Logger().WithField("event", e.Topic).WithField("error", sendErr).Error("failed to send event over server websocket")
// Avoid race conditions by only setting the error once and then canceling
// the context. This way if additional processing errors come through due
// to a massive flood of things you still only report and stop at the first.
o.Do(func() {
err = sendErr
cancel()
})
} }
} }
// Subscribe to all of the events with the same callback that will push the data out over the // Subscribe to all of the events with the same callback that will push the
// websocket for the server. // data out over the websocket for the server.
for _, evt := range e { for _, evt := range e {
h.server.Events().On(evt, &callback) h.server.Events().On(evt, &callback)
} }
<-ctx.Done() // When this function returns de-register all of the event listeners.
// Block until the context is stopped and then de-register all of the event listeners defer func() {
// that we registered earlier.
for _, evt := range e { for _, evt := range e {
h.server.Events().Off(evt, &callback) h.server.Events().Off(evt, &callback)
} }
}()
<-ctx.Done()
// If the internal context is stopped it is either because the parent context
// got canceled or because we ran into an error. If the "err" variable is nil
// we can assume the parent was canceled and need not perform any actions.
if err != nil {
return errors.WithStack(err)
}
return nil
} }

View File

@ -75,7 +75,7 @@ func NewTokenPayload(token []byte) (*tokens.WebsocketPayload, error) {
return &payload, nil return &payload, nil
} }
// Returns a new websocket handler using the context provided. // GetHandler returns a new websocket handler using the context provided.
func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) { func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Handler, error) {
upgrader := websocket.Upgrader{ upgrader := websocket.Upgrader{
// Ensure that the websocket request is originating from the Panel itself, // Ensure that the websocket request is originating from the Panel itself,
@ -116,6 +116,12 @@ func (h *Handler) Uuid() uuid.UUID {
return h.uuid return h.uuid
} }
func (h *Handler) Logger() *log.Entry {
return log.WithField("subsystem", "websocket").
WithField("connection", h.Uuid().String()).
WithField("server", h.server.ID())
}
func (h *Handler) SendJson(v *Message) error { func (h *Handler) SendJson(v *Message) error {
// Do not send JSON down the line if the JWT on the connection is not valid! // Do not send JSON down the line if the JWT on the connection is not valid!
if err := h.TokenValid(); err != nil { if err := h.TokenValid(); err != nil {

View File

@ -52,8 +52,7 @@ func (w *WebsocketBag) CancelAll() {
if w.conns != nil { if w.conns != nil {
for _, cancel := range w.conns { for _, cancel := range w.conns {
c := *cancel (*cancel)()
c()
} }
} }