From 40836a5ec4ea6422a2c38e09b2636154a537e3db Mon Sep 17 00:00:00 2001 From: Jakob Schrettenbrunner Date: Wed, 2 Aug 2017 23:47:09 +0200 Subject: [PATCH] api: rewrite auth to handle new routes api: add functions to retrieve auth handler and server from a gin.Context --- api/auth.go | 50 +++++++++++++++++++++++++++++++++++------------- api/auth_test.go | 19 +++++++----------- api/handlers.go | 4 ++-- 3 files changed, 46 insertions(+), 27 deletions(-) diff --git a/api/auth.go b/api/auth.go index fc67cf9..e4552d6 100644 --- a/api/auth.go +++ b/api/auth.go @@ -10,14 +10,10 @@ import ( ) const ( - accessTokenHeader = "X-Access-Token" - accessServerHeader = "X-Access-Server" + accessTokenHeader = "X-Access-Token" - // ContextVarServer is the gin.Context field containing the requested server (gin.Context.Get()) - ContextVarServer = "server" - // ContextVarAuth is the gin.Context field containing the authorizationManager - // for the request (gin.Context.Get()) - ContextVarAuth = "auth" + contextVarServer = "server" + contextVarAuth = "auth" ) type responseError struct { @@ -26,7 +22,7 @@ type responseError struct { // AuthorizationManager handles permission checks type AuthorizationManager interface { - hasPermission(string) bool + HasPermission(string) bool } type authorizationManager struct { @@ -43,7 +39,7 @@ func newAuthorizationManager(token string, server control.Server) *authorization } } -func (a *authorizationManager) hasPermission(permission string) bool { +func (a *authorizationManager) HasPermission(permission string) bool { if permission == "" { return true } @@ -52,6 +48,7 @@ func (a *authorizationManager) hasPermission(permission string) bool { return config.Get().ContainsAuthKey(a.token) } if a.server == nil { + log.WithField("permission", permission).Error("Auth: Server required but none found.") return false } if prefix == "g" { @@ -68,8 +65,9 @@ func (a *authorizationManager) hasPermission(permission string) bool { func AuthHandler(permission string) gin.HandlerFunc { return func(c *gin.Context) { requestToken := c.Request.Header.Get(accessTokenHeader) - requestServer := c.Request.Header.Get(accessServerHeader) + requestServer := c.Param("server") var server control.Server + if requestToken == "" && permission != "" { log.Debug("Token missing in request.") c.JSON(http.StatusBadRequest, responseError{"Missing required " + accessTokenHeader + " header."}) @@ -86,9 +84,9 @@ func AuthHandler(permission string) gin.HandlerFunc { auth := newAuthorizationManager(requestToken, server) - if auth.hasPermission(permission) { - c.Set(ContextVarServer, server) - c.Set(ContextVarAuth, auth) + if auth.HasPermission(permission) { + c.Set(contextVarServer, server) + c.Set(contextVarAuth, auth) return } @@ -96,3 +94,29 @@ func AuthHandler(permission string) gin.HandlerFunc { c.Abort() } } + +// GetContextAuthManager returns a AuthorizationManager contained in +// a gin.Context or nil +func GetContextAuthManager(c *gin.Context) AuthorizationManager { + auth, exists := c.Get(contextVarAuth) + if !exists { + return nil + } + if auth, ok := auth.(AuthorizationManager); ok { + return auth + } + return nil +} + +// GetContextServer returns a control.Server contained in a gin.Context +// or null +func GetContextServer(c *gin.Context) control.Server { + server, exists := c.Get(contextVarAuth) + if !exists { + return nil + } + if server, ok := server.(control.Server); ok { + return server + } + return nil +} diff --git a/api/auth_test.go b/api/auth_test.go index bb33ba7..086f4a0 100644 --- a/api/auth_test.go +++ b/api/auth_test.go @@ -44,15 +44,6 @@ func TestAuthHandler(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) }) - t.Run("rejects missing server uuid", func(t *testing.T) { - loadConfiguration(t, true) - - responded, rec := requestMiddlewareWith("g:test", "existingkey", "") - - assert.False(t, responded) - assert.Equal(t, http.StatusForbidden, rec.Code) - }) - t.Run("rejects not existing server", func(t *testing.T) { loadConfiguration(t, true) @@ -121,15 +112,19 @@ func requestMiddlewareWith(neededPermission string, token string, serverUUID str router := gin.New() responded = false recorder = httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/"+serverUUID, nil) - router.GET("/", AuthHandler(neededPermission), func(c *gin.Context) { + endpoint := "/" + if serverUUID != "" { + endpoint += ":server" + } + + router.GET(endpoint, AuthHandler(neededPermission), func(c *gin.Context) { c.String(http.StatusOK, "Access granted.") responded = true }) req.Header.Set(accessTokenHeader, token) - req.Header.Set(accessServerHeader, serverUUID) router.ServeHTTP(recorder, req) return } diff --git a/api/handlers.go b/api/handlers.go index 7740667..73b4c95 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -13,9 +13,9 @@ import ( // handleGetIndex handles GET / func handleGetIndex(c *gin.Context) { - auth, _ := c.Get(ContextVarAuth) + auth := GetContextAuthManager(c) - if auth := auth.(AuthorizationManager); auth.hasPermission("c:info") { + if auth != nil && auth.HasPermission("c:info") { hostInfo, err := host.Info() if err != nil { log.WithError(err).Error("Failed to retrieve host information.")