Fix race condition warnings when attempting to work with JWT data in the websocket

This commit is contained in:
Dane Everitt 2020-04-10 17:06:59 -07:00
parent 2220eb049c
commit e91dd84279
No known key found for this signature in database
GPG Key ID: EEA66103B3D71F53
2 changed files with 58 additions and 47 deletions

View File

@ -13,28 +13,22 @@ func (h *Handler) ListenForExpiration(ctx context.Context) {
// Make a ticker and completion channel that is used to continuously poll the // Make a ticker and completion channel that is used to continuously poll the
// JWT stored in the session to send events to the socket when it is expiring. // JWT stored in the session to send events to the socket when it is expiring.
ticker := time.NewTicker(time.Second * 30) ticker := time.NewTicker(time.Second * 30)
done := make(chan bool)
// Whenever this function is complete, end the ticker, close out the channel, // Whenever this function is complete, end the ticker, close out the channel,
// and then close the websocket connection. // and then close the websocket connection.
defer func() { defer ticker.Stop()
ticker.Stop()
done <- true
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
case <-done:
return return
case <-ticker.C: case <-ticker.C:
{ jwt := h.GetJwt()
if h.JWT != nil { if jwt != nil {
if h.JWT.ExpirationTime.Unix()-time.Now().Unix() <= 0 { if jwt.ExpirationTime.Unix()-time.Now().Unix() <= 0 {
h.SendJson(&Message{Event: TokenExpiredEvent}) h.SendJson(&Message{Event: TokenExpiredEvent})
} else if h.JWT.ExpirationTime.Unix()-time.Now().Unix() <= 180 { } else if jwt.ExpirationTime.Unix()-time.Now().Unix() <= 180 {
h.SendJson(&Message{Event: TokenExpiringEvent}) h.SendJson(&Message{Event: TokenExpiringEvent})
}
} }
} }
} }

View File

@ -31,10 +31,10 @@ const (
) )
type Handler struct { type Handler struct {
sync.Mutex
Connection *websocket.Conn Connection *websocket.Conn
JWT *tokens.WebsocketPayload `json:"-"` jwt *tokens.WebsocketPayload `json:"-"`
server *server.Server server *server.Server
mutex sync.Mutex
} }
// Parses a JWT into a websocket token payload. // Parses a JWT into a websocket token payload.
@ -69,9 +69,8 @@ func GetHandler(s *server.Server, w http.ResponseWriter, r *http.Request) (*Hand
return &Handler{ return &Handler{
Connection: conn, Connection: conn,
JWT: nil, jwt: nil,
server: s, server: s,
mutex: sync.Mutex{},
}, nil }, nil
} }
@ -82,20 +81,23 @@ func (h *Handler) SendJson(v *Message) error {
return nil return nil
} }
// If we're sending installation output but the user does not have the required j := h.GetJwt()
// permissions to see the output, don't send it down the line. if j != nil {
if v.Event == server.InstallOutputEvent { // If we're sending installation output but the user does not have the required
zap.S().Debugf("%+v", v.Args) // permissions to see the output, don't send it down the line.
if h.JWT != nil && !h.JWT.HasPermission(PermissionReceiveInstall) { if v.Event == server.InstallOutputEvent {
return nil zap.S().Debugf("%+v", v.Args)
if !j.HasPermission(PermissionReceiveInstall) {
return nil
}
} }
}
// If the user does not have permission to see backup events, do not emit // If the user does not have permission to see backup events, do not emit
// them over the socket. // them over the socket.
if strings.HasPrefix(v.Event, server.BackupCompletedEvent) { if strings.HasPrefix(v.Event, server.BackupCompletedEvent) {
if h.JWT != nil && !h.JWT.HasPermission(PermissionReceiveBackups) { if !j.HasPermission(PermissionReceiveBackups) {
return nil return nil
}
} }
} }
@ -106,27 +108,28 @@ func (h *Handler) SendJson(v *Message) error {
// socket user. Do not call this directly unless you are positive a response should be // socket user. Do not call this directly unless you are positive a response should be
// sent back to the client! // sent back to the client!
func (h *Handler) unsafeSendJson(v interface{}) error { func (h *Handler) unsafeSendJson(v interface{}) error {
h.mutex.Lock() h.Lock()
defer h.mutex.Unlock() defer h.Unlock()
return h.Connection.WriteJSON(v) return h.Connection.WriteJSON(v)
} }
// Checks if the JWT is still valid. // Checks if the JWT is still valid.
func (h *Handler) TokenValid() error { func (h *Handler) TokenValid() error {
if h.JWT == nil { j := h.GetJwt()
if j == nil {
return errors.New("no jwt present") return errors.New("no jwt present")
} }
if err := jwt.ExpirationTimeValidator(time.Now())(&h.JWT.Payload); err != nil { if err := jwt.ExpirationTimeValidator(time.Now())(&j.Payload); err != nil {
return err return err
} }
if !h.JWT.HasPermission(PermissionConnect) { if !j.HasPermission(PermissionConnect) {
return errors.New("jwt does not have connect permission") return errors.New("jwt does not have connect permission")
} }
if h.server.Uuid != h.JWT.ServerUUID { if h.server.Uuid != j.ServerUUID {
return errors.New("jwt server uuid mismatch") return errors.New("jwt server uuid mismatch")
} }
@ -137,14 +140,14 @@ func (h *Handler) TokenValid() error {
// of the token. If the user has the "receive-errors" grant we will send back the actual // of the token. If the user has the "receive-errors" grant we will send back the actual
// error message, otherwise we just send back a standard error message. // error message, otherwise we just send back a standard error message.
func (h *Handler) SendErrorJson(err error) error { func (h *Handler) SendErrorJson(err error) error {
h.mutex.Lock() h.Lock()
defer h.mutex.Unlock() defer h.Unlock()
j := h.GetJwt()
message := "an unexpected error was encountered while handling this request" message := "an unexpected error was encountered while handling this request"
if h.JWT != nil { if server.IsSuspendedError(err) || (j != nil && j.HasPermission(PermissionReceiveErrors)) {
if server.IsSuspendedError(err) || h.JWT.HasPermission(PermissionReceiveErrors) { message = err.Error()
message = err.Error()
}
} }
m, u := h.GetErrorMessage(message) m, u := h.GetErrorMessage(message)
@ -174,6 +177,20 @@ func (h *Handler) GetErrorMessage(msg string) (string, uuid.UUID) {
return m, u return m, u
} }
// Sets the JWT for the websocket in a race-safe manner.
func (h *Handler) setJwt(token *tokens.WebsocketPayload) {
h.Lock()
h.jwt = token
h.Unlock()
}
func (h *Handler) GetJwt() *tokens.WebsocketPayload {
h.Lock()
defer h.Unlock()
return h.jwt
}
// Handle the inbound socket request and route it to the proper server action. // Handle the inbound socket request and route it to the proper server action.
func (h *Handler) HandleInbound(m Message) error { func (h *Handler) HandleInbound(m Message) error {
if m.Event != AuthenticationEvent { if m.Event != AuthenticationEvent {
@ -198,7 +215,7 @@ func (h *Handler) HandleInbound(m Message) error {
} }
if token.HasPermission(PermissionConnect) { if token.HasPermission(PermissionConnect) {
h.JWT = token h.setJwt(token)
} }
// On every authentication event, send the current server status back // On every authentication event, send the current server status back
@ -216,17 +233,17 @@ func (h *Handler) HandleInbound(m Message) error {
{ {
switch strings.Join(m.Args, "") { switch strings.Join(m.Args, "") {
case "start": case "start":
if h.JWT.HasPermission(PermissionSendPowerStart) { if h.GetJwt().HasPermission(PermissionSendPowerStart) {
return h.server.Environment.Start() return h.server.Environment.Start()
} }
break break
case "stop": case "stop":
if h.JWT.HasPermission(PermissionSendPowerStop) { if h.GetJwt().HasPermission(PermissionSendPowerStop) {
return h.server.Environment.Stop() return h.server.Environment.Stop()
} }
break break
case "restart": case "restart":
if h.JWT.HasPermission(PermissionSendPowerRestart) { if h.GetJwt().HasPermission(PermissionSendPowerRestart) {
if err := h.server.Environment.WaitForStop(60, false); err != nil { if err := h.server.Environment.WaitForStop(60, false); err != nil {
return err return err
} }
@ -235,7 +252,7 @@ func (h *Handler) HandleInbound(m Message) error {
} }
break break
case "kill": case "kill":
if h.JWT.HasPermission(PermissionSendPowerStop) { if h.GetJwt().HasPermission(PermissionSendPowerStop) {
return h.server.Environment.Terminate(os.Kill) return h.server.Environment.Terminate(os.Kill)
} }
break break
@ -265,7 +282,7 @@ func (h *Handler) HandleInbound(m Message) error {
} }
case SendCommandEvent: case SendCommandEvent:
{ {
if !h.JWT.HasPermission(PermissionSendCommand) { if !h.GetJwt().HasPermission(PermissionSendCommand) {
return nil return nil
} }