diff --git a/router/websocket/listeners.go b/router/websocket/listeners.go index 6d99b51..bf872c6 100644 --- a/router/websocket/listeners.go +++ b/router/websocket/listeners.go @@ -13,28 +13,22 @@ func (h *Handler) ListenForExpiration(ctx context.Context) { // Make a ticker and completion channel that is used to continuously poll the // JWT stored in the session to send events to the socket when it is expiring. ticker := time.NewTicker(time.Second * 30) - done := make(chan bool) // Whenever this function is complete, end the ticker, close out the channel, // and then close the websocket connection. - defer func() { - ticker.Stop() - done <- true - }() + defer ticker.Stop() for { select { case <-ctx.Done(): - case <-done: return case <-ticker.C: - { - if h.JWT != nil { - if h.JWT.ExpirationTime.Unix()-time.Now().Unix() <= 0 { - h.SendJson(&Message{Event: TokenExpiredEvent}) - } else if h.JWT.ExpirationTime.Unix()-time.Now().Unix() <= 180 { - h.SendJson(&Message{Event: TokenExpiringEvent}) - } + jwt := h.GetJwt() + if jwt != nil { + if jwt.ExpirationTime.Unix()-time.Now().Unix() <= 0 { + h.SendJson(&Message{Event: TokenExpiredEvent}) + } else if jwt.ExpirationTime.Unix()-time.Now().Unix() <= 180 { + h.SendJson(&Message{Event: TokenExpiringEvent}) } } } diff --git a/router/websocket/websocket.go b/router/websocket/websocket.go index 403ab43..5e7d4f6 100644 --- a/router/websocket/websocket.go +++ b/router/websocket/websocket.go @@ -31,10 +31,10 @@ const ( ) type Handler struct { + sync.Mutex Connection *websocket.Conn - JWT *tokens.WebsocketPayload `json:"-"` + jwt *tokens.WebsocketPayload `json:"-"` server *server.Server - mutex sync.Mutex } // Parses a JWT into a websocket token payload. @@ -69,9 +69,8 @@ func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Hand return &Handler{ Connection: conn, - JWT: nil, + jwt: nil, server: s, - mutex: sync.Mutex{}, }, nil } @@ -82,20 +81,23 @@ func (h *Handler) SendJson(v *Message) error { return nil } - // If we're sending installation output but the user does not have the required - // permissions to see the output, don't send it down the line. - if v.Event == server.InstallOutputEvent { - zap.S().Debugf("%+v", v.Args) - if h.JWT != nil && !h.JWT.HasPermission(PermissionReceiveInstall) { - return nil + j := h.GetJwt() + if j != nil { + // If we're sending installation output but the user does not have the required + // permissions to see the output, don't send it down the line. + if v.Event == server.InstallOutputEvent { + zap.S().Debugf("%+v", v.Args) + if !j.HasPermission(PermissionReceiveInstall) { + return nil + } } - } - // If the user does not have permission to see backup events, do not emit - // them over the socket. - if strings.HasPrefix(v.Event, server.BackupCompletedEvent) { - if h.JWT != nil && !h.JWT.HasPermission(PermissionReceiveBackups) { - return nil + // If the user does not have permission to see backup events, do not emit + // them over the socket. + if strings.HasPrefix(v.Event, server.BackupCompletedEvent) { + if !j.HasPermission(PermissionReceiveBackups) { + return nil + } } } @@ -106,27 +108,28 @@ func (h *Handler) SendJson(v *Message) error { // socket user. Do not call this directly unless you are positive a response should be // sent back to the client! func (h *Handler) unsafeSendJson(v interface{}) error { - h.mutex.Lock() - defer h.mutex.Unlock() + h.Lock() + defer h.Unlock() return h.Connection.WriteJSON(v) } // Checks if the JWT is still valid. func (h *Handler) TokenValid() error { - if h.JWT == nil { + j := h.GetJwt() + if j == nil { return errors.New("no jwt present") } - if err := jwt.ExpirationTimeValidator(time.Now())(&h.JWT.Payload); err != nil { + if err := jwt.ExpirationTimeValidator(time.Now())(&j.Payload); err != nil { return err } - if !h.JWT.HasPermission(PermissionConnect) { + if !j.HasPermission(PermissionConnect) { return errors.New("jwt does not have connect permission") } - if h.server.Uuid != h.JWT.ServerUUID { + if h.server.Uuid != j.ServerUUID { return errors.New("jwt server uuid mismatch") } @@ -137,14 +140,14 @@ func (h *Handler) TokenValid() error { // of the token. If the user has the "receive-errors" grant we will send back the actual // error message, otherwise we just send back a standard error message. func (h *Handler) SendErrorJson(err error) error { - h.mutex.Lock() - defer h.mutex.Unlock() + h.Lock() + defer h.Unlock() + + j := h.GetJwt() message := "an unexpected error was encountered while handling this request" - if h.JWT != nil { - if server.IsSuspendedError(err) || h.JWT.HasPermission(PermissionReceiveErrors) { - message = err.Error() - } + if server.IsSuspendedError(err) || (j != nil && j.HasPermission(PermissionReceiveErrors)) { + message = err.Error() } m, u := h.GetErrorMessage(message) @@ -174,6 +177,20 @@ func (h *Handler) GetErrorMessage(msg string) (string, uuid.UUID) { return m, u } +// Sets the JWT for the websocket in a race-safe manner. +func (h *Handler) setJwt(token *tokens.WebsocketPayload) { + h.Lock() + h.jwt = token + h.Unlock() +} + +func (h *Handler) GetJwt() *tokens.WebsocketPayload { + h.Lock() + defer h.Unlock() + + return h.jwt +} + // Handle the inbound socket request and route it to the proper server action. func (h *Handler) HandleInbound(m Message) error { if m.Event != AuthenticationEvent { @@ -198,7 +215,7 @@ func (h *Handler) HandleInbound(m Message) error { } if token.HasPermission(PermissionConnect) { - h.JWT = token + h.setJwt(token) } // On every authentication event, send the current server status back @@ -216,17 +233,17 @@ func (h *Handler) HandleInbound(m Message) error { { switch strings.Join(m.Args, "") { case "start": - if h.JWT.HasPermission(PermissionSendPowerStart) { + if h.GetJwt().HasPermission(PermissionSendPowerStart) { return h.server.Environment.Start() } break case "stop": - if h.JWT.HasPermission(PermissionSendPowerStop) { + if h.GetJwt().HasPermission(PermissionSendPowerStop) { return h.server.Environment.Stop() } break case "restart": - if h.JWT.HasPermission(PermissionSendPowerRestart) { + if h.GetJwt().HasPermission(PermissionSendPowerRestart) { if err := h.server.Environment.WaitForStop(60, false); err != nil { return err } @@ -235,7 +252,7 @@ func (h *Handler) HandleInbound(m Message) error { } break case "kill": - if h.JWT.HasPermission(PermissionSendPowerStop) { + if h.GetJwt().HasPermission(PermissionSendPowerStop) { return h.server.Environment.Terminate(os.Kill) } break @@ -265,7 +282,7 @@ func (h *Handler) HandleInbound(m Message) error { } case SendCommandEvent: { - if !h.JWT.HasPermission(PermissionSendCommand) { + if !h.GetJwt().HasPermission(PermissionSendCommand) { return nil }