api: rewrite auth to handle new routes

api: add functions to retrieve auth handler and server from a gin.Context
This commit is contained in:
Jakob Schrettenbrunner 2017-08-02 23:47:09 +02:00
parent 3d789c0541
commit 40836a5ec4
3 changed files with 46 additions and 27 deletions

View File

@ -11,13 +11,9 @@ import (
const ( const (
accessTokenHeader = "X-Access-Token" accessTokenHeader = "X-Access-Token"
accessServerHeader = "X-Access-Server"
// ContextVarServer is the gin.Context field containing the requested server (gin.Context.Get()) contextVarServer = "server"
ContextVarServer = "server" contextVarAuth = "auth"
// ContextVarAuth is the gin.Context field containing the authorizationManager
// for the request (gin.Context.Get())
ContextVarAuth = "auth"
) )
type responseError struct { type responseError struct {
@ -26,7 +22,7 @@ type responseError struct {
// AuthorizationManager handles permission checks // AuthorizationManager handles permission checks
type AuthorizationManager interface { type AuthorizationManager interface {
hasPermission(string) bool HasPermission(string) bool
} }
type authorizationManager struct { type authorizationManager struct {
@ -43,7 +39,7 @@ func newAuthorizationManager(token string, server control.Server) *authorization
} }
} }
func (a *authorizationManager) hasPermission(permission string) bool { func (a *authorizationManager) HasPermission(permission string) bool {
if permission == "" { if permission == "" {
return true return true
} }
@ -52,6 +48,7 @@ func (a *authorizationManager) hasPermission(permission string) bool {
return config.Get().ContainsAuthKey(a.token) return config.Get().ContainsAuthKey(a.token)
} }
if a.server == nil { if a.server == nil {
log.WithField("permission", permission).Error("Auth: Server required but none found.")
return false return false
} }
if prefix == "g" { if prefix == "g" {
@ -68,8 +65,9 @@ func (a *authorizationManager) hasPermission(permission string) bool {
func AuthHandler(permission string) gin.HandlerFunc { func AuthHandler(permission string) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
requestToken := c.Request.Header.Get(accessTokenHeader) requestToken := c.Request.Header.Get(accessTokenHeader)
requestServer := c.Request.Header.Get(accessServerHeader) requestServer := c.Param("server")
var server control.Server var server control.Server
if requestToken == "" && permission != "" { if requestToken == "" && permission != "" {
log.Debug("Token missing in request.") log.Debug("Token missing in request.")
c.JSON(http.StatusBadRequest, responseError{"Missing required " + accessTokenHeader + " header."}) c.JSON(http.StatusBadRequest, responseError{"Missing required " + accessTokenHeader + " header."})
@ -86,9 +84,9 @@ func AuthHandler(permission string) gin.HandlerFunc {
auth := newAuthorizationManager(requestToken, server) auth := newAuthorizationManager(requestToken, server)
if auth.hasPermission(permission) { if auth.HasPermission(permission) {
c.Set(ContextVarServer, server) c.Set(contextVarServer, server)
c.Set(ContextVarAuth, auth) c.Set(contextVarAuth, auth)
return return
} }
@ -96,3 +94,29 @@ func AuthHandler(permission string) gin.HandlerFunc {
c.Abort() c.Abort()
} }
} }
// GetContextAuthManager returns a AuthorizationManager contained in
// a gin.Context or nil
func GetContextAuthManager(c *gin.Context) AuthorizationManager {
auth, exists := c.Get(contextVarAuth)
if !exists {
return nil
}
if auth, ok := auth.(AuthorizationManager); ok {
return auth
}
return nil
}
// GetContextServer returns a control.Server contained in a gin.Context
// or null
func GetContextServer(c *gin.Context) control.Server {
server, exists := c.Get(contextVarAuth)
if !exists {
return nil
}
if server, ok := server.(control.Server); ok {
return server
}
return nil
}

View File

@ -44,15 +44,6 @@ func TestAuthHandler(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
}) })
t.Run("rejects missing server uuid", func(t *testing.T) {
loadConfiguration(t, true)
responded, rec := requestMiddlewareWith("g:test", "existingkey", "")
assert.False(t, responded)
assert.Equal(t, http.StatusForbidden, rec.Code)
})
t.Run("rejects not existing server", func(t *testing.T) { t.Run("rejects not existing server", func(t *testing.T) {
loadConfiguration(t, true) loadConfiguration(t, true)
@ -121,15 +112,19 @@ func requestMiddlewareWith(neededPermission string, token string, serverUUID str
router := gin.New() router := gin.New()
responded = false responded = false
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil) req, _ := http.NewRequest("GET", "/"+serverUUID, nil)
router.GET("/", AuthHandler(neededPermission), func(c *gin.Context) { endpoint := "/"
if serverUUID != "" {
endpoint += ":server"
}
router.GET(endpoint, AuthHandler(neededPermission), func(c *gin.Context) {
c.String(http.StatusOK, "Access granted.") c.String(http.StatusOK, "Access granted.")
responded = true responded = true
}) })
req.Header.Set(accessTokenHeader, token) req.Header.Set(accessTokenHeader, token)
req.Header.Set(accessServerHeader, serverUUID)
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
return return
} }

View File

@ -13,9 +13,9 @@ import (
// handleGetIndex handles GET / // handleGetIndex handles GET /
func handleGetIndex(c *gin.Context) { func handleGetIndex(c *gin.Context) {
auth, _ := c.Get(ContextVarAuth) auth := GetContextAuthManager(c)
if auth := auth.(AuthorizationManager); auth.hasPermission("c:info") { if auth != nil && auth.HasPermission("c:info") {
hostInfo, err := host.Info() hostInfo, err := host.Info()
if err != nil { if err != nil {
log.WithError(err).Error("Failed to retrieve host information.") log.WithError(err).Error("Failed to retrieve host information.")