diff --git a/cmd/root.go b/cmd/root.go index 78c79cf..18dccb8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -273,7 +273,7 @@ func rootCmdRun(cmd *cobra.Command, _ []string) { pool.StopWait() // Run the SFTP server. - if err := sftp.NewServer().Run(); err != nil { + if err := sftp.New().Run(); err != nil { log.WithError(err).Fatal("failed to initialize the sftp server") return } diff --git a/sftp/handler.go b/sftp/handler.go index 8957583..1461876 100644 --- a/sftp/handler.go +++ b/sftp/handler.go @@ -11,7 +11,9 @@ import ( "emperror.dev/errors" "github.com/apex/log" "github.com/pkg/sftp" + "github.com/pterodactyl/wings/config" "github.com/pterodactyl/wings/server/filesystem" + "golang.org/x/crypto/ssh" ) const ( @@ -30,6 +32,31 @@ type Handler struct { ro bool } +// Returns a new connection handler for the SFTP server. This allows a given user +// to access the underlying filesystem. +func NewHandler(sc *ssh.ServerConn, fs *filesystem.Filesystem) *Handler { + return &Handler{ + fs: fs, + ro: config.Get().System.Sftp.ReadOnly, + permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","), + logger: log.WithFields(log.Fields{ + "subsystem": "sftp", + "username": sc.User(), + "ip": sc.RemoteAddr(), + }), + } +} + +// Returns the sftp.Handlers for this struct. +func (h *Handler) Handlers() sftp.Handlers { + return sftp.Handlers{ + FileGet: h, + FilePut: h, + FileCmd: h, + FileList: h, + } +} + // Fileread creates a reader for a file on the system and returns the reader back. func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) { // Check first if the user can actually open and view a file. This permission is named diff --git a/sftp/server.go b/sftp/server.go index 4697036..3604430 100644 --- a/sftp/server.go +++ b/sftp/server.go @@ -5,12 +5,12 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "fmt" "io" "io/ioutil" "net" "os" "path" + "strconv" "strings" "emperror.dev/errors" @@ -24,181 +24,140 @@ import ( //goland:noinspection GoNameStartsWithPackageName type SFTPServer struct { - BasePath string - ReadOnly bool - BindPort int - BindAddress string + BasePath string + ReadOnly bool + Listen string } -var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID") - -func NewServer() *SFTPServer { +func New() *SFTPServer { cfg := config.Get().System return &SFTPServer{ - BasePath: cfg.Data, - ReadOnly: cfg.Sftp.ReadOnly, - BindAddress: cfg.Sftp.Address, - BindPort: cfg.Sftp.Port, + BasePath: cfg.Data, + ReadOnly: cfg.Sftp.ReadOnly, + Listen: cfg.Sftp.Address + ":" + strconv.Itoa(cfg.Sftp.Port), } } // Starts the SFTP server and add a persistent listener to handle inbound SFTP connections. func (c *SFTPServer) Run() error { - serverConfig := &ssh.ServerConfig{ - NoClientAuth: false, - MaxAuthTries: 6, - PasswordCallback: c.passwordCallback, - } - if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) { if err := c.generatePrivateKey(); err != nil { return err } } else if err != nil { - return err + return errors.Wrap(err, "sftp/server: could not stat private key file") } - - privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa")) + pb, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa")) + if err != nil { + return errors.Wrap(err, "sftp/server: could not read private key file") + } + private, err := ssh.ParsePrivateKey(pb) if err != nil { return err } - private, err := ssh.ParsePrivateKey(privateBytes) + conf := &ssh.ServerConfig{ + NoClientAuth: false, + MaxAuthTries: 6, + PasswordCallback: c.passwordCallback, + } + conf.AddHostKey(private) + + listener, err := net.Listen("tcp", c.Listen) if err != nil { return err } - // Add our private key to the server configuration. - serverConfig.AddHostKey(private) - - listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort)) - if err != nil { - return err - } - - log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections") - + log.WithField("listen", c.Listen).Info("sftp server listening for connections") for { - conn, _ := listener.Accept() - if conn != nil { - go c.AcceptInboundConnection(conn, serverConfig) + if conn, _ := listener.Accept(); conn != nil { + go func(conn net.Conn) { + defer conn.Close() + c.AcceptInbound(conn, conf) + }(conn) } } } -// Handles an inbound connection to the instance and determines if we should serve the request -// or not. -func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { - defer conn.Close() - +// Handles an inbound connection to the instance and determines if we should serve the +// request or not. +func (c SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) { // Before beginning a handshake must be performed on the incoming net.Conn sconn, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { return } defer sconn.Close() - go ssh.DiscardRequests(reqs) - for newChannel := range chans { + for ch := range chans { // If its not a session channel we just move on because its not something we // know how to handle at this point. - if newChannel.ChannelType() != "session" { - newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + if ch.ChannelType() != "session" { + ch.Reject(ssh.UnknownChannelType, "unknown channel type") continue } - channel, requests, err := newChannel.Accept() + channel, requests, err := ch.Accept() if err != nil { continue } - // Channels have a type that is dependent on the protocol. For SFTP this is "subsystem" - // with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc) go func(in <-chan *ssh.Request) { for req := range in { - ok := false - - switch req.Type { - case "subsystem": - if string(req.Payload[4:]) == "sftp" { - ok = true - } - } - - req.Reply(ok, nil) + // Channels have a type that is dependent on the protocol. For SFTP + // this is "subsystem" with a payload that (should) be "sftp". Discard + // anything else we receive ("pty", "shell", etc) + req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil) } }(requests) - if sconn.Permissions.Extensions["uuid"] == "" { + // If no UUID has been set on this inbound request then we can assume we + // have screwed up something in the authentication code. This is a sanity + // check, but should never be encountered (ideally...). + // + // This will also attempt to match a specific server out of the global server + // store and return nil if there is no match. + uuid := sconn.Permissions.Extensions["uuid"] + srv := server.GetServers().Find(func(s *server.Server) bool { + if uuid == "" { + return false + } + return s.Id() == uuid + }) + if srv == nil { continue } - // Create a new handler for the currently logged in user's server. - fs := c.newHandler(sconn) - - // Create the server instance for the channel using the filesystem we created above. - handler := sftp.NewRequestServer(channel, fs) + // Spin up a SFTP server instance for the authenticated user's server allowing + // them access to the underlying filesystem. + handler := sftp.NewRequestServer(channel, NewHandler(sconn, srv.Filesystem()).Handlers()) if err := handler.Serve(); err == io.EOF { handler.Close() } } } -// Creates a new SFTP handler for a given server. The directory argument should -// be the base directory for a server. All actions done on the server will be -// relative to that directory, and the user will not be able to escape out of it. -func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers { - s := server.GetServers().Find(func(s *server.Server) bool { - return s.Id() == sc.Permissions.Extensions["uuid"] - }) - - p := Handler{ - fs: s.Filesystem(), - permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","), - ro: config.Get().System.Sftp.ReadOnly, - logger: log.WithFields(log.Fields{ - "subsystem": "sftp", - "username": sc.User(), - "ip": sc.RemoteAddr(), - }), - } - - return sftp.Handlers{ - FileGet: &p, - FilePut: &p, - FileCmd: &p, - FileList: &p, - } -} - // Generates a private key that will be used by the SFTP server. func (c *SFTPServer) generatePrivateKey() error { key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return errors.WithStack(err) } - if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil { - return err + return errors.Wrap(err, "sftp/server: could not create .sftp directory") } - o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { - return err + return errors.WithStack(err) } defer o.Close() - pkey := &pem.Block{ + err = pem.Encode(o, &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), - } - - if err := pem.Encode(o, pkey); err != nil { - return err - } - - return nil + }) + return errors.WithStack(err) } // A function capable of validating user credentials with the Panel API.