More code cleanup and struct simplification

This commit is contained in:
Dane Everitt 2021-01-10 14:43:27 -08:00
parent c228acaafc
commit a48abc92ad
No known key found for this signature in database
GPG Key ID: EEA66103B3D71F53
4 changed files with 67 additions and 121 deletions

View File

@ -272,8 +272,8 @@ func rootCmdRun(cmd *cobra.Command, _ []string) {
// Wait until all of the servers are ready to go before we fire up the SFTP and HTTP servers. // Wait until all of the servers are ready to go before we fire up the SFTP and HTTP servers.
pool.StopWait() pool.StopWait()
// Initialize the SFTP server. // Run the SFTP server.
if err := sftp.Initialize(c.System); err != nil { if err := sftp.NewServer().Run(); err != nil {
log.WithError(err).Fatal("failed to initialize the sftp server") log.WithError(err).Fatal("failed to initialize the sftp server")
return return
} }

View File

@ -13,6 +13,7 @@ import (
"path" "path"
"strings" "strings"
"emperror.dev/errors"
"github.com/apex/log" "github.com/apex/log"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"github.com/pterodactyl/wings/api" "github.com/pterodactyl/wings/api"
@ -21,59 +22,35 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Settings struct { //goland:noinspection GoNameStartsWithPackageName
type SFTPServer struct {
BasePath string BasePath string
ReadOnly bool ReadOnly bool
BindPort int BindPort int
BindAddress string BindAddress string
} }
type User struct { var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID")
Uid int
Gid int func NewServer() *SFTPServer {
cfg := config.Get().System
return &SFTPServer{
BasePath: cfg.Data,
ReadOnly: cfg.Sftp.ReadOnly,
BindAddress: cfg.Sftp.Address,
BindPort: cfg.Sftp.Port,
}
} }
//goland:noinspection GoNameStartsWithPackageName // Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
type SFTPServer struct { func (c *SFTPServer) Run() error {
Settings Settings
User User
// Validator function that is called when a user connects to the server. This should
// check against whatever system is desired to confirm if the given username and password
// combination is valid. If so, should return an authentication response.
credentialValidator func(r api.SftpAuthRequest) (*api.SftpAuthResponse, error)
}
// Initialize the SFTP server and add a persistent listener to handle inbound SFTP connections.
func (c *SFTPServer) Initialize() error {
serverConfig := &ssh.ServerConfig{ serverConfig := &ssh.ServerConfig{
NoClientAuth: false, NoClientAuth: false,
MaxAuthTries: 6, MaxAuthTries: 6,
PasswordCallback: func(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { PasswordCallback: c.passwordCallback,
resp, err := c.credentialValidator(api.SftpAuthRequest{
User: conn.User(),
Pass: string(pass),
IP: conn.RemoteAddr().String(),
SessionID: conn.SessionID(),
ClientVersion: conn.ClientVersion(),
})
if err != nil {
return nil, err
}
sshPerm := &ssh.Permissions{
Extensions: map[string]string{
"uuid": resp.Server,
"user": conn.User(),
"permissions": strings.Join(resp.Permissions, ","),
},
}
return sshPerm, nil
},
} }
if _, err := os.Stat(path.Join(c.Settings.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) { if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
if err := c.generatePrivateKey(); err != nil { if err := c.generatePrivateKey(); err != nil {
return err return err
} }
@ -81,7 +58,7 @@ func (c *SFTPServer) Initialize() error {
return err return err
} }
privateBytes, err := ioutil.ReadFile(path.Join(c.Settings.BasePath, ".sftp/id_rsa")) privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
if err != nil { if err != nil {
return err return err
} }
@ -94,12 +71,12 @@ func (c *SFTPServer) Initialize() error {
// Add our private key to the server configuration. // Add our private key to the server configuration.
serverConfig.AddHostKey(private) serverConfig.AddHostKey(private)
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.Settings.BindAddress, c.Settings.BindPort)) listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
if err != nil { if err != nil {
return err return err
} }
log.WithField("host", c.Settings.BindAddress).WithField("port", c.Settings.BindPort).Info("sftp subsystem listening for connections") log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections")
for { for {
conn, _ := listener.Accept() conn, _ := listener.Accept()
@ -171,7 +148,7 @@ func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerCon
// Creates a new SFTP handler for a given server. The directory argument should // Creates a new SFTP handler for a given server. The directory argument should
// be the base directory for a server. All actions done on the server will be // be the base directory for a server. All actions done on the server will be
// relative to that directory, and the user will not be able to escape out of it. // relative to that directory, and the user will not be able to escape out of it.
func (c SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers { func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
s := server.GetServers().Find(func(s *server.Server) bool { s := server.GetServers().Find(func(s *server.Server) bool {
return s.Id() == sc.Permissions.Extensions["uuid"] return s.Id() == sc.Permissions.Extensions["uuid"]
}) })
@ -196,17 +173,17 @@ func (c SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
} }
// Generates a private key that will be used by the SFTP server. // Generates a private key that will be used by the SFTP server.
func (c SFTPServer) generatePrivateKey() error { func (c *SFTPServer) generatePrivateKey() error {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return err return err
} }
if err := os.MkdirAll(path.Join(c.Settings.BasePath, ".sftp"), 0755); err != nil { if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
return err return err
} }
o, err := os.OpenFile(path.Join(c.Settings.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return err
} }
@ -223,3 +200,38 @@ func (c SFTPServer) generatePrivateKey() error {
return nil return nil
} }
// A function capable of validating user credentials with the Panel API.
func (c *SFTPServer) passwordCallback(conn ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
request := api.SftpAuthRequest{
User: conn.User(),
Pass: string(pass),
IP: conn.RemoteAddr().String(),
SessionID: conn.SessionID(),
ClientVersion: conn.ClientVersion(),
}
logger := log.WithFields(log.Fields{"subsystem": "sftp", "username": conn.User(), "ip": conn.RemoteAddr().String()})
logger.Debug("validating credentials for SFTP connection")
resp, err := api.New().ValidateSftpCredentials(request)
if err != nil {
if api.IsInvalidCredentialsError(err) {
logger.Warn("failed to validate user credentials (invalid username or password)")
} else {
logger.Error("encountered an error while trying to validate user credentials")
}
return nil, err
}
logger.WithField("server", resp.Server).Debug("credentials validated and matched to server instance")
sshPerm := &ssh.Permissions{
Extensions: map[string]string{
"uuid": resp.Server,
"user": conn.User(),
"permissions": strings.Join(resp.Permissions, ","),
},
}
return sshPerm, nil
}

View File

@ -1,67 +0,0 @@
package sftp
import (
"emperror.dev/errors"
"github.com/apex/log"
"github.com/pterodactyl/wings/api"
"github.com/pterodactyl/wings/config"
"github.com/pterodactyl/wings/server"
)
var noMatchingServerError = errors.New("no matching server with that UUID was found")
func Initialize(config config.SystemConfiguration) error {
s := &SFTPServer{
User: User{
Uid: config.User.Uid,
Gid: config.User.Gid,
},
Settings: Settings{
BasePath: config.Data,
ReadOnly: config.Sftp.ReadOnly,
BindAddress: config.Sftp.Address,
BindPort: config.Sftp.Port,
},
credentialValidator: validateCredentials,
}
// Initialize the SFTP server in a background thread since this is
// a long running operation.
go func(s *SFTPServer) {
if err := s.Initialize(); err != nil {
log.WithField("subsystem", "sftp").WithField("error", err).Error("failed to initialize SFTP subsystem")
}
}(s)
return nil
}
// Validates a set of credentials for a SFTP login against Pterodactyl Panel and returns
// the server's UUID if the credentials were valid.
func validateCredentials(c api.SftpAuthRequest) (*api.SftpAuthResponse, error) {
f := log.Fields{"subsystem": "sftp", "username": c.User, "ip": c.IP}
log.WithFields(f).Debug("validating credentials for SFTP connection")
resp, err := api.New().ValidateSftpCredentials(c)
if err != nil {
if api.IsInvalidCredentialsError(err) {
log.WithFields(f).Warn("failed to validate user credentials (invalid username or password)")
} else {
log.WithFields(f).Error("encountered an error while trying to validate user credentials")
}
return resp, err
}
s := server.GetServers().Find(func(server *server.Server) bool {
return server.Id() == resp.Server
})
if s == nil {
return resp, noMatchingServerError
}
s.Log().WithFields(f).Debug("credentials successfully validated and matched user to server instance")
return resp, err
}

View File

@ -5,9 +5,6 @@ import (
"os" "os"
) )
type ListerAt []os.FileInfo
type fxerr uint32
const ( const (
// Extends the default SFTP server to return a quota exceeded error to the client. // Extends the default SFTP server to return a quota exceeded error to the client.
// //
@ -15,6 +12,8 @@ const (
ErrSSHQuotaExceeded = fxerr(15) ErrSSHQuotaExceeded = fxerr(15)
) )
type ListerAt []os.FileInfo
// Returns the number of entries copied and an io.EOF error if we made it to the end of the file list. // Returns the number of entries copied and an io.EOF error if we made it to the end of the file list.
// Take a look at the pkg/sftp godoc for more information about how this function should work. // Take a look at the pkg/sftp godoc for more information about how this function should work.
func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) { func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) {
@ -29,6 +28,8 @@ func (l ListerAt) ListAt(f []os.FileInfo, offset int64) (int, error) {
} }
} }
type fxerr uint32
func (e fxerr) Error() string { func (e fxerr) Error() string {
switch e { switch e {
case ErrSSHQuotaExceeded: case ErrSSHQuotaExceeded: