diff --git a/websocket.go b/websocket.go index 3396fc6..c791722 100644 --- a/websocket.go +++ b/websocket.go @@ -19,13 +19,14 @@ import ( ) const ( - TokenExpiringEvent = "token expiring" - TokenExpiredEvent = "token expired" - AuthenticationEvent = "auth" - SetStateEvent = "set state" - SendServerLogsEvent = "send logs" - SendCommandEvent = "send command" - ErrorEvent = "daemon error" + AuthenticationSuccessEvent = "auth success" + TokenExpiringEvent = "token expiring" + TokenExpiredEvent = "token expired" + AuthenticationEvent = "auth" + SetStateEvent = "set state" + SendServerLogsEvent = "send logs" + SendCommandEvent = "send command" + ErrorEvent = "daemon error" ) type WebsocketMessage struct { @@ -140,14 +141,11 @@ func (wsh *WebsocketHandler) TokenValid() error { // Handle a request for a specific server websocket. This will handle inbound requests as well // as ensure that any console output is also passed down the wire on the socket. func (rt *Router) routeWebsocket(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - token, err := ParseJWT([]byte(r.URL.Query().Get("token"))) - if err != nil { - return - } - c, err := rt.upgrader.Upgrade(w, r, nil) if err != nil { - zap.S().Error(err) + zap.S().Errorw("error upgrading websocket", zap.Error(errors.WithStack(err))) + http.Error(w, "failed to upgrade websocket", http.StatusInternalServerError) + return } @@ -169,7 +167,7 @@ func (rt *Router) routeWebsocket(w http.ResponseWriter, r *http.Request, ps http Server: s, Mutex: sync.Mutex{}, Connection: c, - JWT: token, + JWT: nil, } handleOutput := func(data string) { @@ -202,8 +200,6 @@ func (rt *Router) routeWebsocket(w http.ResponseWriter, r *http.Request, ps http s.AddListener(server.StatsEvent, &handleResourceUse) defer s.RemoveListener(server.StatsEvent, &handleResourceUse) - s.Emit(server.StatusEvent, s.State) - // Sit here and check the time to expiration on the JWT every 30 seconds until // the token has expired. If we are within 3 minutes of the token expiring, send // a notice over the socket that it is expiring soon. If it has expired, send that @@ -262,6 +258,19 @@ func (rt *Router) routeWebsocket(w http.ResponseWriter, r *http.Request, ps http // concurrent writes to the connection, which would cause a runtime panic and cause // the program to crash out. func (wsh *WebsocketHandler) SendJson(v interface{}) error { + // Do not send JSON down the line if the JWT on the connection is not + // valid! + if err := wsh.TokenValid(); err != nil { + return nil + } + + return wsh.unsafeSendJson(v) +} + +// Sends JSON over the websocket connection, ignoring the authentication state of the +// socket user. Do not call this directly unless you are positive a response should be +// sent back to the client! +func (wsh *WebsocketHandler) unsafeSendJson(v interface{}) error { wsh.Mutex.Lock() defer wsh.Mutex.Unlock() @@ -315,10 +324,17 @@ func (wsh *WebsocketHandler) HandleInbound(m WebsocketMessage) error { return errors.New("cannot handle websocket message, not an inbound connection") } - if err := wsh.TokenValid(); err != nil { - zap.S().Debugw("jwt token is no longer valid", zap.String("message", err.Error())) + if m.Event != AuthenticationEvent { + if err := wsh.TokenValid(); err != nil { + zap.S().Debugw("jwt token is no longer valid", zap.String("message", err.Error())) - return nil + wsh.unsafeSendJson(WebsocketMessage{ + Event: ErrorEvent, + Args: []string{"could not authenticate client: " + err.Error()}, + }) + + return nil + } } switch m.Event { @@ -326,13 +342,22 @@ func (wsh *WebsocketHandler) HandleInbound(m WebsocketMessage) error { { token, err := ParseJWT([]byte(strings.Join(m.Args, ""))) if err != nil { - return nil + return err } if token.HasPermission(PermissionConnect) { wsh.JWT = token } + // On every authentication event, send the current server status back + // to the client. :) + wsh.Server.Emit(server.StatusEvent, wsh.Server.State) + + wsh.unsafeSendJson(WebsocketMessage{ + Event: AuthenticationSuccessEvent, + Args: []string{}, + }) + return nil } case SetStateEvent: