More code cleanup and struct simplification
This commit is contained in:
parent
c228acaafc
commit
a48abc92ad
|
@ -272,8 +272,8 @@ func rootCmdRun(cmd *cobra.Command, _ []string) {
|
|||
// Wait until all of the servers are ready to go before we fire up the SFTP and HTTP servers.
|
||||
pool.StopWait()
|
||||
|
||||
// Initialize the SFTP server.
|
||||
if err := sftp.Initialize(c.System); err != nil {
|
||||
// Run the SFTP server.
|
||||
if err := sftp.NewServer().Run(); err != nil {
|
||||
log.WithError(err).Fatal("failed to initialize the sftp server")
|
||||
return
|
||||
}
|
||||
|
|
110
sftp/server.go
110
sftp/server.go
|
@ -13,6 +13,7 @@ import (
|
|||
"path"
|
||||
"strings"
|
||||
|
||||
"emperror.dev/errors"
|
||||
"github.com/apex/log"
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/pterodactyl/wings/api"
|
||||
|
@ -21,59 +22,35 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Settings struct {
|
||||
//goland:noinspection GoNameStartsWithPackageName
|
||||
type SFTPServer struct {
|
||||
BasePath string
|
||||
ReadOnly bool
|
||||
BindPort int
|
||||
BindAddress string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Uid int
|
||||
Gid int
|
||||
var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID")
|
||||
|
||||
func NewServer() *SFTPServer {
|
||||
cfg := config.Get().System
|
||||
return &SFTPServer{
|
||||
BasePath: cfg.Data,
|
||||
ReadOnly: cfg.Sftp.ReadOnly,
|
||||
BindAddress: cfg.Sftp.Address,
|
||||
BindPort: cfg.Sftp.Port,
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoNameStartsWithPackageName
|
||||
type SFTPServer struct {
|
||||
Settings Settings
|
||||
User User
|
||||
// Validator function that is called when a user connects to the server. This should
|
||||
// check against whatever system is desired to confirm if the given username and password
|
||||
// combination is valid. If so, should return an authentication response.
|
||||
credentialValidator func(r api.SftpAuthRequest) (*api.SftpAuthResponse, error)
|
||||
}
|
||||
|
||||
// Initialize the SFTP server and add a persistent listener to handle inbound SFTP connections.
|
||||
func (c *SFTPServer) Initialize() error {
|
||||
// Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
|
||||
func (c *SFTPServer) Run() error {
|
||||
serverConfig := &ssh.ServerConfig{
|
||||
NoClientAuth: false,
|
||||
MaxAuthTries: 6,
|
||||
PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||
resp, err := c.credentialValidator(api.SftpAuthRequest{
|
||||
User: conn.User(),
|
||||
Pass: string(pass),
|
||||
IP: conn.RemoteAddr().String(),
|
||||
SessionID: conn.SessionID(),
|
||||
ClientVersion: conn.ClientVersion(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sshPerm := &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"uuid": resp.Server,
|
||||
"user": conn.User(),
|
||||
"permissions": strings.Join(resp.Permissions, ","),
|
||||
},
|
||||
}
|
||||
|
||||
return sshPerm, nil
|
||||
},
|
||||
NoClientAuth: false,
|
||||
MaxAuthTries: 6,
|
||||
PasswordCallback: c.passwordCallback,
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path.Join(c.Settings.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
|
||||
if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
|
||||
if err := c.generatePrivateKey(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -81,7 +58,7 @@ func (c *SFTPServer) Initialize() error {
|
|||
return err
|
||||
}
|
||||
|
||||
privateBytes, err := ioutil.ReadFile(path.Join(c.Settings.BasePath, ".sftp/id_rsa"))
|
||||
privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,12 +71,12 @@ func (c *SFTPServer) Initialize() error {
|
|||
// Add our private key to the server configuration.
|
||||
serverConfig.AddHostKey(private)
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.Settings.BindAddress, c.Settings.BindPort))
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithField("host", c.Settings.BindAddress).WithField("port", c.Settings.BindPort).Info("sftp subsystem listening for connections")
|
||||
log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections")
|
||||
|
||||
for {
|
||||
conn, _ := listener.Accept()
|
||||
|
@ -171,7 +148,7 @@ func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerCon
|
|||
// Creates a new SFTP handler for a given server. The directory argument should
|
||||
// be the base directory for a server. All actions done on the server will be
|
||||
// relative to that directory, and the user will not be able to escape out of it.
|
||||
func (c SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
|
||||
func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
|
||||
s := server.GetServers().Find(func(s *server.Server) bool {
|
||||
return s.Id() == sc.Permissions.Extensions["uuid"]
|
||||
})
|
||||
|
@ -196,17 +173,17 @@ func (c SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
|
|||
}
|
||||
|
||||
// Generates a private key that will be used by the SFTP server.
|
||||
func (c SFTPServer) generatePrivateKey() error {
|
||||
func (c *SFTPServer) generatePrivateKey() error {
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(path.Join(c.Settings.BasePath, ".sftp"), 0755); err != nil {
|
||||
if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
o, err := os.OpenFile(path.Join(c.Settings.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -223,3 +200,38 @@ func (c SFTPServer) generatePrivateKey() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// A function capable of validating user credentials with the Panel API.
|
||||
func (c *SFTPServer) passwordCallback(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
|
||||
request := api.SftpAuthRequest{
|
||||
User: conn.User(),
|
||||
Pass: string(pass),
|
||||
IP: conn.RemoteAddr().String(),
|
||||
SessionID: conn.SessionID(),
|
||||
ClientVersion: conn.ClientVersion(),
|
||||
}
|
||||
|
||||
logger := log.WithFields(log.Fields{"subsystem": "sftp", "username": conn.User(), "ip": conn.RemoteAddr().String()})
|
||||
logger.Debug("validating credentials for SFTP connection")
|
||||
|
||||
resp, err := api.New().ValidateSftpCredentials(request)
|
||||
if err != nil {
|
||||
if api.IsInvalidCredentialsError(err) {
|
||||
logger.Warn("failed to validate user credentials (invalid username or password)")
|
||||
} else {
|
||||
logger.Error("encountered an error while trying to validate user credentials")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.WithField("server", resp.Server).Debug("credentials validated and matched to server instance")
|
||||
sshPerm := &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
"uuid": resp.Server,
|
||||
"user": conn.User(),
|
||||
"permissions": strings.Join(resp.Permissions, ","),
|
||||
},
|
||||
}
|
||||
|
||||
return sshPerm, nil
|
||||
}
|
||||
|
|
67
sftp/sftp.go
67
sftp/sftp.go
|
@ -1,67 +0,0 @@
|
|||
package sftp
|
||||
|
||||
import (
|
||||
"emperror.dev/errors"
|
||||
"github.com/apex/log"
|
||||
"github.com/pterodactyl/wings/api"
|
||||
"github.com/pterodactyl/wings/config"
|
||||
"github.com/pterodactyl/wings/server"
|
||||
)
|
||||
|
||||
var noMatchingServerError = errors.New("no matching server with that UUID was found")
|
||||
|
||||
func Initialize(config config.SystemConfiguration) error {
|
||||
s := &SFTPServer{
|
||||
User: User{
|
||||
Uid: config.User.Uid,
|
||||
Gid: config.User.Gid,
|
||||
},
|
||||
Settings: Settings{
|
||||
BasePath: config.Data,
|
||||
ReadOnly: config.Sftp.ReadOnly,
|
||||
BindAddress: config.Sftp.Address,
|
||||
BindPort: config.Sftp.Port,
|
||||
},
|
||||
credentialValidator: validateCredentials,
|
||||
}
|
||||
|
||||
// Initialize the SFTP server in a background thread since this is
|
||||
// a long running operation.
|
||||
go func(s *SFTPServer) {
|
||||
if err := s.Initialize(); err != nil {
|
||||
log.WithField("subsystem", "sftp").WithField("error", err).Error("failed to initialize SFTP subsystem")
|
||||
}
|
||||
}(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validates a set of credentials for a SFTP login against Pterodactyl Panel and returns
|
||||
// the server's UUID if the credentials were valid.
|
||||
func validateCredentials(c api.SftpAuthRequest) (*api.SftpAuthResponse, error) {
|
||||
f := log.Fields{"subsystem": "sftp", "username": c.User, "ip": c.IP}
|
||||
|
||||
log.WithFields(f).Debug("validating credentials for SFTP connection")
|
||||
resp, err := api.New().ValidateSftpCredentials(c)
|
||||
if err != nil {
|
||||
if api.IsInvalidCredentialsError(err) {
|
||||
log.WithFields(f).Warn("failed to validate user credentials (invalid username or password)")
|
||||
} else {
|
||||
log.WithFields(f).Error("encountered an error while trying to validate user credentials")
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
s := server.GetServers().Find(func(server *server.Server) bool {
|
||||
return server.Id() == resp.Server
|
||||
})
|
||||
|
||||
if s == nil {
|
||||
return resp, noMatchingServerError
|
||||
}
|
||||
|
||||
s.Log().WithFields(f).Debug("credentials successfully validated and matched user to server instance")
|
||||
|
||||
return resp, err
|
||||
}
|
|
@ -5,9 +5,6 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
type ListerAt []os.FileInfo
|
||||
type fxerr uint32
|
||||
|
||||
const (
|
||||
// Extends the default SFTP server to return a quota exceeded error to the client.
|
||||
//
|
||||
|
@ -15,6 +12,8 @@ const (
|
|||
ErrSSHQuotaExceeded = fxerr(15)
|
||||
)
|
||||
|
||||
type ListerAt []os.FileInfo
|
||||
|
||||
// Returns the number of entries copied and an io.EOF error if we made it to the end of the file list.
|
||||
// Take a look at the pkg/sftp godoc for more information about how this function should work.
|
||||
func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) {
|
||||
|
@ -29,6 +28,8 @@ func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) {
|
|||
}
|
||||
}
|
||||
|
||||
type fxerr uint32
|
||||
|
||||
func (e fxerr) Error() string {
|
||||
switch e {
|
||||
case ErrSSHQuotaExceeded:
|
||||
|
|
Loading…
Reference in New Issue
Block a user