diff --git a/http.go b/http.go index a97fd3f..1669cfc 100644 --- a/http.go +++ b/http.go @@ -18,22 +18,14 @@ import ( "strings" ) -type ServerCollection []*server.Server - // Retrieves a server out of the collection by UUID. -func (sc *ServerCollection) Get(uuid string) *server.Server { - for _, s := range *sc { - if s.Uuid == uuid { - return s - } - } - - return nil +func (rt *Router) GetServer(uuid string) *server.Server { + return server.GetServers().Find(func(i *server.Server) bool { + return i.Uuid == uuid + }) } type Router struct { - Servers ServerCollection - upgrader websocket.Upgrader // The authentication token defined in the config.yml file that allows @@ -49,7 +41,7 @@ func (rt *Router) AuthenticateRequest(h httprouter.Handle) httprouter.Handle { // is in a state that allows it to be exposed to the API. func (rt *Router) AuthenticateServer(h httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - if rt.Servers.Get(ps.ByName("server")) != nil { + if rt.GetServer(ps.ByName("server")) != nil { h(w, r, ps) return } @@ -98,12 +90,12 @@ func (rt *Router) routeIndex(w http.ResponseWriter, _ *http.Request, _ httproute // requests that include an administrative control key, otherwise a 404 is returned. This // authentication is handled by a middleware. func (rt *Router) routeAllServers(w http.ResponseWriter, _ *http.Request, _ httprouter.Params) { - json.NewEncoder(w).Encode(rt.Servers) + json.NewEncoder(w).Encode(server.GetServers().All()) } // Returns basic information about a single server found on the Daemon. func (rt *Router) routeServer(w http.ResponseWriter, _ *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) json.NewEncoder(w).Encode(s) } @@ -130,7 +122,7 @@ func (pr *PowerActionRequest) IsValid() bool { // things are happening, so theres no reason to sit and wait for a request to finish. We'll // just see over the socket if something isn't working correctly. func (rt *Router) routeServerPower(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() dec := json.NewDecoder(r.Body) @@ -206,7 +198,7 @@ func (rt *Router) routeServerPower(w http.ResponseWriter, r *http.Request, ps ht // Return the last 1Kb of the server log file. func (rt *Router) routeServerLogs(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) l, _ := strconv.ParseInt(r.URL.Query().Get("size"), 10, 64) if l <= 0 { @@ -225,7 +217,7 @@ func (rt *Router) routeServerLogs(w http.ResponseWriter, r *http.Request, ps htt // Handle a request to get the contents of a file on the server. func (rt *Router) routeServerFileRead(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) cleaned, err := s.Filesystem.SafePath(r.URL.Query().Get("file")) if err != nil { @@ -272,7 +264,7 @@ func (rt *Router) routeServerFileRead(w http.ResponseWriter, r *http.Request, ps // Lists the contents of a directory. func (rt *Router) routeServerListDirectory(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) stats, err := s.Filesystem.ListDirectory(r.URL.Query().Get("directory")) if os.IsNotExist(err) { @@ -290,7 +282,7 @@ func (rt *Router) routeServerListDirectory(w http.ResponseWriter, r *http.Reques // Writes a file to the system for the server. func (rt *Router) routeServerWriteFile(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) p := r.URL.Query().Get("file") defer r.Body.Close() @@ -308,7 +300,7 @@ func (rt *Router) routeServerWriteFile(w http.ResponseWriter, r *http.Request, p // Creates a new directory for the server. func (rt *Router) routeServerCreateDirectory(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() dec := json.NewDecoder(r.Body) @@ -336,7 +328,7 @@ func (rt *Router) routeServerCreateDirectory(w http.ResponseWriter, r *http.Requ } func (rt *Router) routeServerRenameFile(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() data := rt.ReaderToBytes(r.Body) @@ -359,7 +351,7 @@ func (rt *Router) routeServerRenameFile(w http.ResponseWriter, r *http.Request, } func (rt *Router) routeServerCopyFile(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() data := rt.ReaderToBytes(r.Body) @@ -376,7 +368,7 @@ func (rt *Router) routeServerCopyFile(w http.ResponseWriter, r *http.Request, ps } func (rt *Router) routeServerDeleteFile(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() data := rt.ReaderToBytes(r.Body) @@ -393,7 +385,7 @@ func (rt *Router) routeServerDeleteFile(w http.ResponseWriter, r *http.Request, } func (rt *Router) routeServerSendCommand(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() if running, err := s.Environment.IsRunning(); !running || err != nil { @@ -419,7 +411,7 @@ func (rt *Router) routeServerSendCommand(w http.ResponseWriter, r *http.Request, } func (rt *Router) routeServerUpdate(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) defer r.Body.Close() data := rt.ReaderToBytes(r.Body) @@ -447,7 +439,7 @@ func (rt *Router) routeCreateServer(w http.ResponseWriter, r *http.Request, ps h // Plop that server instance onto the request so that it can be referenced in // requests from here-on out. - rt.Servers = append(rt.Servers, inst.Server()) + server.GetServers().Add(inst.Server()) // Begin the installation process in the background to not block the request // cycle. If there are any errors they will be logged and communicated back diff --git a/server/collection.go b/server/collection.go new file mode 100644 index 0000000..6aa7cb3 --- /dev/null +++ b/server/collection.go @@ -0,0 +1,39 @@ +package server + +type Collection struct { + items []*Server +} + +// Return all of the items in the collection. +func (c *Collection) All() []*Server { + return c.items +} + +// Adds an item to the collection store. +func (c *Collection) Add(s *Server) { + c.items = append(c.items, s) +} + +// Returns only those items matching the filter criteria. +func (c *Collection) Filter(filter func (*Server) bool) []*Server { + r := make([]*Server, 0) + for _, v := range c.items { + if filter(v) { + r = append(r, v) + } + } + + return r +} + +// Returns a single element from the collection matching the filter. If nothing is +// found a nil result is returned. +func (c *Collection) Find(filter func (*Server) bool) *Server { + for _, v := range c.items { + if filter(v) { + return v + } + } + + return nil +} \ No newline at end of file diff --git a/server/server.go b/server/server.go index d428e8f..6b1af0d 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,12 @@ import ( "time" ) +var servers *Collection + +func GetServers() *Collection { + return servers +} + // High level definition for a server instance being controlled by Wings. type Server struct { // The unique identifier for the server that should be used when referencing @@ -136,7 +142,7 @@ type Allocations struct { // Iterates over a given directory and loads all of the servers listed before returning // them to the calling function. -func LoadDirectory(dir string, cfg *config.SystemConfiguration) ([]*Server, error) { +func LoadDirectory(dir string, cfg *config.SystemConfiguration) error { // We could theoretically use a standard wait group here, however doing // that introduces the potential to crash the program due to too many // open files. This wouldn't happen on a small setup, but once the daemon is @@ -149,10 +155,10 @@ func LoadDirectory(dir string, cfg *config.SystemConfiguration) ([]*Server, erro f, err := ioutil.ReadDir(dir) if err != nil { - return nil, err + return err } - var servers []*Server + servers = new(Collection) for _, file := range f { if !strings.HasSuffix(file.Name(), ".yml") || file.IsDir() { @@ -177,7 +183,7 @@ func LoadDirectory(dir string, cfg *config.SystemConfiguration) ([]*Server, erro return } - servers = append(servers, s) + servers.Add(s) }(file) } @@ -185,7 +191,7 @@ func LoadDirectory(dir string, cfg *config.SystemConfiguration) ([]*Server, erro // before continuing. wg.Wait() - return servers, nil + return nil } // Initializes the default required internal struct components for a Server. diff --git a/sftp/server.go b/sftp/server.go index 3f842c1..11ee6e2 100644 --- a/sftp/server.go +++ b/sftp/server.go @@ -24,6 +24,8 @@ func Initialize(config *config.Configuration) error { DisableDiskCheck: config.System.Sftp.DisableDiskChecking, }, CredentialValidator: validateCredentials, + PathValidator: validatePath, + DiskSpaceValidator: validateDiskSpace, } if err := sftp_server.New(c); err != nil { @@ -45,6 +47,14 @@ func Initialize(config *config.Configuration) error { return nil } +func validatePath(fs sftp_server.FileSystem, p string) (string, error) { + return p, nil +} + +func validateDiskSpace(fs sftp_server.FileSystem) bool { + return true +} + // Validates a set of credentials for a SFTP login aganist Pterodactyl Panel and returns // the server's UUID if the credentials were valid. func validateCredentials(c sftp_server.AuthenticationRequest) (*sftp_server.AuthenticationResponse, error) { diff --git a/websocket.go b/websocket.go index e6f441f..3396fc6 100644 --- a/websocket.go +++ b/websocket.go @@ -164,7 +164,7 @@ func (rt *Router) routeWebsocket(w http.ResponseWriter, r *http.Request, ps http c.Close() }() - s := rt.Servers.Get(ps.ByName("server")) + s := rt.GetServer(ps.ByName("server")) handler := WebsocketHandler{ Server: s, Mutex: sync.Mutex{}, diff --git a/wings.go b/wings.go index afcda6b..987adf6 100644 --- a/wings.go +++ b/wings.go @@ -65,14 +65,13 @@ func main() { zap.S().Infow("finished ensuring file permissions") } - servers, err := server.LoadDirectory("data/servers", c.System) - if err != nil { + if err := server.LoadDirectory("data/servers", c.System); err != nil { zap.S().Fatalw("failed to load server configurations", zap.Error(err)) return } // Just for some nice log output. - for _, s := range servers { + for _, s := range server.GetServers().All() { zap.S().Infow("loaded configuration for server", zap.String("server", s.Uuid)) } @@ -81,7 +80,7 @@ func main() { // and reboot processes without causing a slow-down due to sequential booting. wg := sizedwaitgroup.New(4) - for _, serv := range servers { + for _, serv := range server.GetServers().All() { wg.Add() go func(s *server.Server) { @@ -144,7 +143,6 @@ func main() { } r := &Router{ - Servers: servers, token: c.AuthenticationToken, upgrader: websocket.Upgrader{ // Ensure that the websocket request is originating from the Panel itself,