Finish refactoring SFTP server logic
This commit is contained in:
		
							parent
							
								
									a48abc92ad
								
							
						
					
					
						commit
						0cb3b815d1
					
				| 
						 | 
				
			
			@ -273,7 +273,7 @@ func rootCmdRun(cmd *cobra.Command, _ []string) {
 | 
			
		|||
	pool.StopWait()
 | 
			
		||||
 | 
			
		||||
	// Run the SFTP server.
 | 
			
		||||
	if err := sftp.NewServer().Run(); err != nil {
 | 
			
		||||
	if err := sftp.New().Run(); err != nil {
 | 
			
		||||
		log.WithError(err).Fatal("failed to initialize the sftp server")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,7 +11,9 @@ import (
 | 
			
		|||
	"emperror.dev/errors"
 | 
			
		||||
	"github.com/apex/log"
 | 
			
		||||
	"github.com/pkg/sftp"
 | 
			
		||||
	"github.com/pterodactyl/wings/config"
 | 
			
		||||
	"github.com/pterodactyl/wings/server/filesystem"
 | 
			
		||||
	"golang.org/x/crypto/ssh"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
| 
						 | 
				
			
			@ -30,6 +32,31 @@ type Handler struct {
 | 
			
		|||
	ro          bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns a new connection handler for the SFTP server. This allows a given user
 | 
			
		||||
// to access the underlying filesystem.
 | 
			
		||||
func NewHandler(sc *ssh.ServerConn, fs *filesystem.Filesystem) *Handler {
 | 
			
		||||
	return &Handler{
 | 
			
		||||
		fs:          fs,
 | 
			
		||||
		ro:          config.Get().System.Sftp.ReadOnly,
 | 
			
		||||
		permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
 | 
			
		||||
		logger: log.WithFields(log.Fields{
 | 
			
		||||
			"subsystem": "sftp",
 | 
			
		||||
			"username":  sc.User(),
 | 
			
		||||
			"ip":        sc.RemoteAddr(),
 | 
			
		||||
		}),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns the sftp.Handlers for this struct.
 | 
			
		||||
func (h *Handler) Handlers() sftp.Handlers {
 | 
			
		||||
	return sftp.Handlers{
 | 
			
		||||
		FileGet:  h,
 | 
			
		||||
		FilePut:  h,
 | 
			
		||||
		FileCmd:  h,
 | 
			
		||||
		FileList: h,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fileread creates a reader for a file on the system and returns the reader back.
 | 
			
		||||
func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) {
 | 
			
		||||
	// Check first if the user can actually open and view a file. This permission is named
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										165
									
								
								sftp/server.go
									
									
									
									
									
								
							
							
						
						
									
										165
									
								
								sftp/server.go
									
									
									
									
									
								
							| 
						 | 
				
			
			@ -5,12 +5,12 @@ import (
 | 
			
		|||
	"crypto/rsa"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"encoding/pem"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"net"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
| 
						 | 
				
			
			@ -24,181 +24,140 @@ import (
 | 
			
		|||
 | 
			
		||||
//goland:noinspection GoNameStartsWithPackageName
 | 
			
		||||
type SFTPServer struct {
 | 
			
		||||
	BasePath    string
 | 
			
		||||
	ReadOnly    bool
 | 
			
		||||
	BindPort    int
 | 
			
		||||
	BindAddress string
 | 
			
		||||
	BasePath string
 | 
			
		||||
	ReadOnly bool
 | 
			
		||||
	Listen   string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID")
 | 
			
		||||
 | 
			
		||||
func NewServer() *SFTPServer {
 | 
			
		||||
func New() *SFTPServer {
 | 
			
		||||
	cfg := config.Get().System
 | 
			
		||||
	return &SFTPServer{
 | 
			
		||||
		BasePath:    cfg.Data,
 | 
			
		||||
		ReadOnly:    cfg.Sftp.ReadOnly,
 | 
			
		||||
		BindAddress: cfg.Sftp.Address,
 | 
			
		||||
		BindPort:    cfg.Sftp.Port,
 | 
			
		||||
		BasePath: cfg.Data,
 | 
			
		||||
		ReadOnly: cfg.Sftp.ReadOnly,
 | 
			
		||||
		Listen:   cfg.Sftp.Address + ":" + strconv.Itoa(cfg.Sftp.Port),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
 | 
			
		||||
func (c *SFTPServer) Run() error {
 | 
			
		||||
	serverConfig := &ssh.ServerConfig{
 | 
			
		||||
		NoClientAuth:     false,
 | 
			
		||||
		MaxAuthTries:     6,
 | 
			
		||||
		PasswordCallback: c.passwordCallback,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
 | 
			
		||||
		if err := c.generatePrivateKey(); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	} else if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errors.Wrap(err, "sftp/server: could not stat private key file")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
 | 
			
		||||
	pb, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errors.Wrap(err, "sftp/server: could not read private key file")
 | 
			
		||||
	}
 | 
			
		||||
	private, err := ssh.ParsePrivateKey(pb)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	private, err := ssh.ParsePrivateKey(privateBytes)
 | 
			
		||||
	conf := &ssh.ServerConfig{
 | 
			
		||||
		NoClientAuth:     false,
 | 
			
		||||
		MaxAuthTries:     6,
 | 
			
		||||
		PasswordCallback: c.passwordCallback,
 | 
			
		||||
	}
 | 
			
		||||
	conf.AddHostKey(private)
 | 
			
		||||
 | 
			
		||||
	listener, err := net.Listen("tcp", c.Listen)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Add our private key to the server configuration.
 | 
			
		||||
	serverConfig.AddHostKey(private)
 | 
			
		||||
 | 
			
		||||
	listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections")
 | 
			
		||||
 | 
			
		||||
	log.WithField("listen", c.Listen).Info("sftp server listening for connections")
 | 
			
		||||
	for {
 | 
			
		||||
		conn, _ := listener.Accept()
 | 
			
		||||
		if conn != nil {
 | 
			
		||||
			go c.AcceptInboundConnection(conn, serverConfig)
 | 
			
		||||
		if conn, _ := listener.Accept(); conn != nil {
 | 
			
		||||
			go func(conn net.Conn) {
 | 
			
		||||
				defer conn.Close()
 | 
			
		||||
				c.AcceptInbound(conn, conf)
 | 
			
		||||
			}(conn)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handles an inbound connection to the instance and determines if we should serve the request
 | 
			
		||||
// or not.
 | 
			
		||||
func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
 | 
			
		||||
	defer conn.Close()
 | 
			
		||||
 | 
			
		||||
// Handles an inbound connection to the instance and determines if we should serve the
 | 
			
		||||
// request or not.
 | 
			
		||||
func (c SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) {
 | 
			
		||||
	// Before beginning a handshake must be performed on the incoming net.Conn
 | 
			
		||||
	sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	defer sconn.Close()
 | 
			
		||||
 | 
			
		||||
	go ssh.DiscardRequests(reqs)
 | 
			
		||||
 | 
			
		||||
	for newChannel := range chans {
 | 
			
		||||
	for ch := range chans {
 | 
			
		||||
		// If its not a session channel we just move on because its not something we
 | 
			
		||||
		// know how to handle at this point.
 | 
			
		||||
		if newChannel.ChannelType() != "session" {
 | 
			
		||||
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
 | 
			
		||||
		if ch.ChannelType() != "session" {
 | 
			
		||||
			ch.Reject(ssh.UnknownChannelType, "unknown channel type")
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		channel, requests, err := newChannel.Accept()
 | 
			
		||||
		channel, requests, err := ch.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
 | 
			
		||||
		// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
 | 
			
		||||
		go func(in <-chan *ssh.Request) {
 | 
			
		||||
			for req := range in {
 | 
			
		||||
				ok := false
 | 
			
		||||
 | 
			
		||||
				switch req.Type {
 | 
			
		||||
				case "subsystem":
 | 
			
		||||
					if string(req.Payload[4:]) == "sftp" {
 | 
			
		||||
						ok = true
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				req.Reply(ok, nil)
 | 
			
		||||
				// Channels have a type that is dependent on the protocol. For SFTP
 | 
			
		||||
				// this is "subsystem" with a payload that (should) be "sftp". Discard
 | 
			
		||||
				// anything else we receive ("pty", "shell", etc)
 | 
			
		||||
				req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil)
 | 
			
		||||
			}
 | 
			
		||||
		}(requests)
 | 
			
		||||
 | 
			
		||||
		if sconn.Permissions.Extensions["uuid"] == "" {
 | 
			
		||||
		// If no UUID has been set on this inbound request then we can assume we
 | 
			
		||||
		// have screwed up something in the authentication code. This is a sanity
 | 
			
		||||
		// check, but should never be encountered (ideally...).
 | 
			
		||||
		//
 | 
			
		||||
		// This will also attempt to match a specific server out of the global server
 | 
			
		||||
		// store and return nil if there is no match.
 | 
			
		||||
		uuid := sconn.Permissions.Extensions["uuid"]
 | 
			
		||||
		srv := server.GetServers().Find(func(s *server.Server) bool {
 | 
			
		||||
			if uuid == "" {
 | 
			
		||||
				return false
 | 
			
		||||
			}
 | 
			
		||||
			return s.Id() == uuid
 | 
			
		||||
		})
 | 
			
		||||
		if srv == nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Create a new handler for the currently logged in user's server.
 | 
			
		||||
		fs := c.newHandler(sconn)
 | 
			
		||||
 | 
			
		||||
		// Create the server instance for the channel using the filesystem we created above.
 | 
			
		||||
		handler := sftp.NewRequestServer(channel, fs)
 | 
			
		||||
		// Spin up a SFTP server instance for the authenticated user's server allowing
 | 
			
		||||
		// them access to the underlying filesystem.
 | 
			
		||||
		handler := sftp.NewRequestServer(channel, NewHandler(sconn, srv.Filesystem()).Handlers())
 | 
			
		||||
		if err := handler.Serve(); err == io.EOF {
 | 
			
		||||
			handler.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Creates a new SFTP handler for a given server. The directory argument should
 | 
			
		||||
// be the base directory for a server. All actions done on the server will be
 | 
			
		||||
// relative to that directory, and the user will not be able to escape out of it.
 | 
			
		||||
func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
 | 
			
		||||
	s := server.GetServers().Find(func(s *server.Server) bool {
 | 
			
		||||
		return s.Id() == sc.Permissions.Extensions["uuid"]
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	p := Handler{
 | 
			
		||||
		fs:          s.Filesystem(),
 | 
			
		||||
		permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
 | 
			
		||||
		ro:          config.Get().System.Sftp.ReadOnly,
 | 
			
		||||
		logger: log.WithFields(log.Fields{
 | 
			
		||||
			"subsystem": "sftp",
 | 
			
		||||
			"username":  sc.User(),
 | 
			
		||||
			"ip":        sc.RemoteAddr(),
 | 
			
		||||
		}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sftp.Handlers{
 | 
			
		||||
		FileGet:  &p,
 | 
			
		||||
		FilePut:  &p,
 | 
			
		||||
		FileCmd:  &p,
 | 
			
		||||
		FileList: &p,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Generates a private key that will be used by the SFTP server.
 | 
			
		||||
func (c *SFTPServer) generatePrivateKey() error {
 | 
			
		||||
	key, err := rsa.GenerateKey(rand.Reader, 2048)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errors.WithStack(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errors.Wrap(err, "sftp/server: could not create .sftp directory")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
		return errors.WithStack(err)
 | 
			
		||||
	}
 | 
			
		||||
	defer o.Close()
 | 
			
		||||
 | 
			
		||||
	pkey := &pem.Block{
 | 
			
		||||
	err = pem.Encode(o, &pem.Block{
 | 
			
		||||
		Type:  "RSA PRIVATE KEY",
 | 
			
		||||
		Bytes: x509.MarshalPKCS1PrivateKey(key),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := pem.Encode(o, pkey); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
	})
 | 
			
		||||
	return errors.WithStack(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A function capable of validating user credentials with the Panel API.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user