Finish refactoring SFTP server logic

This commit is contained in:
Dane Everitt 2021-01-10 15:06:06 -08:00
parent a48abc92ad
commit 0cb3b815d1
No known key found for this signature in database
GPG Key ID: EEA66103B3D71F53
3 changed files with 90 additions and 104 deletions

View File

@ -273,7 +273,7 @@ func rootCmdRun(cmd *cobra.Command, _ []string) {
pool.StopWait() pool.StopWait()
// Run the SFTP server. // Run the SFTP server.
if err := sftp.NewServer().Run(); err != nil { if err := sftp.New().Run(); err != nil {
log.WithError(err).Fatal("failed to initialize the sftp server") log.WithError(err).Fatal("failed to initialize the sftp server")
return return
} }

View File

@ -11,7 +11,9 @@ import (
"emperror.dev/errors" "emperror.dev/errors"
"github.com/apex/log" "github.com/apex/log"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"github.com/pterodactyl/wings/config"
"github.com/pterodactyl/wings/server/filesystem" "github.com/pterodactyl/wings/server/filesystem"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -30,6 +32,31 @@ type Handler struct {
ro bool ro bool
} }
// Returns a new connection handler for the SFTP server. This allows a given user
// to access the underlying filesystem.
func NewHandler(sc *ssh.ServerConn, fs *filesystem.Filesystem) *Handler {
return &Handler{
fs: fs,
ro: config.Get().System.Sftp.ReadOnly,
permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
logger: log.WithFields(log.Fields{
"subsystem": "sftp",
"username": sc.User(),
"ip": sc.RemoteAddr(),
}),
}
}
// Returns the sftp.Handlers for this struct.
func (h *Handler) Handlers() sftp.Handlers {
return sftp.Handlers{
FileGet: h,
FilePut: h,
FileCmd: h,
FileList: h,
}
}
// Fileread creates a reader for a file on the system and returns the reader back. // Fileread creates a reader for a file on the system and returns the reader back.
func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) { func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) {
// Check first if the user can actually open and view a file. This permission is named // Check first if the user can actually open and view a file. This permission is named

View File

@ -5,12 +5,12 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"path" "path"
"strconv"
"strings" "strings"
"emperror.dev/errors" "emperror.dev/errors"
@ -24,181 +24,140 @@ import (
//goland:noinspection GoNameStartsWithPackageName //goland:noinspection GoNameStartsWithPackageName
type SFTPServer struct { type SFTPServer struct {
BasePath string BasePath string
ReadOnly bool ReadOnly bool
BindPort int Listen string
BindAddress string
} }
var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID") func New() *SFTPServer {
func NewServer() *SFTPServer {
cfg := config.Get().System cfg := config.Get().System
return &SFTPServer{ return &SFTPServer{
BasePath: cfg.Data, BasePath: cfg.Data,
ReadOnly: cfg.Sftp.ReadOnly, ReadOnly: cfg.Sftp.ReadOnly,
BindAddress: cfg.Sftp.Address, Listen: cfg.Sftp.Address + ":" + strconv.Itoa(cfg.Sftp.Port),
BindPort: cfg.Sftp.Port,
} }
} }
// Starts the SFTP server and add a persistent listener to handle inbound SFTP connections. // Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
func (c *SFTPServer) Run() error { func (c *SFTPServer) Run() error {
serverConfig := &ssh.ServerConfig{
NoClientAuth: false,
MaxAuthTries: 6,
PasswordCallback: c.passwordCallback,
}
if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) { if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
if err := c.generatePrivateKey(); err != nil { if err := c.generatePrivateKey(); err != nil {
return err return err
} }
} else if err != nil { } else if err != nil {
return err return errors.Wrap(err, "sftp/server: could not stat private key file")
} }
pb, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa")) if err != nil {
return errors.Wrap(err, "sftp/server: could not read private key file")
}
private, err := ssh.ParsePrivateKey(pb)
if err != nil { if err != nil {
return err return err
} }
private, err := ssh.ParsePrivateKey(privateBytes) conf := &ssh.ServerConfig{
NoClientAuth: false,
MaxAuthTries: 6,
PasswordCallback: c.passwordCallback,
}
conf.AddHostKey(private)
listener, err := net.Listen("tcp", c.Listen)
if err != nil { if err != nil {
return err return err
} }
// Add our private key to the server configuration. log.WithField("listen", c.Listen).Info("sftp server listening for connections")
serverConfig.AddHostKey(private)
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
if err != nil {
return err
}
log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections")
for { for {
conn, _ := listener.Accept() if conn, _ := listener.Accept(); conn != nil {
if conn != nil { go func(conn net.Conn) {
go c.AcceptInboundConnection(conn, serverConfig) defer conn.Close()
c.AcceptInbound(conn, conf)
}(conn)
} }
} }
} }
// Handles an inbound connection to the instance and determines if we should serve the request // Handles an inbound connection to the instance and determines if we should serve the
// or not. // request or not.
func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) { func (c SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) {
defer conn.Close()
// Before beginning a handshake must be performed on the incoming net.Conn // Before beginning a handshake must be performed on the incoming net.Conn
sconn, chans, reqs, err := ssh.NewServerConn(conn, config) sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil { if err != nil {
return return
} }
defer sconn.Close() defer sconn.Close()
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
for newChannel := range chans { for ch := range chans {
// If its not a session channel we just move on because its not something we // If its not a session channel we just move on because its not something we
// know how to handle at this point. // know how to handle at this point.
if newChannel.ChannelType() != "session" { if ch.ChannelType() != "session" {
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") ch.Reject(ssh.UnknownChannelType, "unknown channel type")
continue continue
} }
channel, requests, err := newChannel.Accept() channel, requests, err := ch.Accept()
if err != nil { if err != nil {
continue continue
} }
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
go func(in <-chan *ssh.Request) { go func(in <-chan *ssh.Request) {
for req := range in { for req := range in {
ok := false // Channels have a type that is dependent on the protocol. For SFTP
// this is "subsystem" with a payload that (should) be "sftp". Discard
switch req.Type { // anything else we receive ("pty", "shell", etc)
case "subsystem": req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil)
if string(req.Payload[4:]) == "sftp" {
ok = true
}
}
req.Reply(ok, nil)
} }
}(requests) }(requests)
if sconn.Permissions.Extensions["uuid"] == "" { // If no UUID has been set on this inbound request then we can assume we
// have screwed up something in the authentication code. This is a sanity
// check, but should never be encountered (ideally...).
//
// This will also attempt to match a specific server out of the global server
// store and return nil if there is no match.
uuid := sconn.Permissions.Extensions["uuid"]
srv := server.GetServers().Find(func(s *server.Server) bool {
if uuid == "" {
return false
}
return s.Id() == uuid
})
if srv == nil {
continue continue
} }
// Create a new handler for the currently logged in user's server. // Spin up a SFTP server instance for the authenticated user's server allowing
fs := c.newHandler(sconn) // them access to the underlying filesystem.
handler := sftp.NewRequestServer(channel, NewHandler(sconn, srv.Filesystem()).Handlers())
// Create the server instance for the channel using the filesystem we created above.
handler := sftp.NewRequestServer(channel, fs)
if err := handler.Serve(); err == io.EOF { if err := handler.Serve(); err == io.EOF {
handler.Close() handler.Close()
} }
} }
} }
// Creates a new SFTP handler for a given server. The directory argument should
// be the base directory for a server. All actions done on the server will be
// relative to that directory, and the user will not be able to escape out of it.
func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
s := server.GetServers().Find(func(s *server.Server) bool {
return s.Id() == sc.Permissions.Extensions["uuid"]
})
p := Handler{
fs: s.Filesystem(),
permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
ro: config.Get().System.Sftp.ReadOnly,
logger: log.WithFields(log.Fields{
"subsystem": "sftp",
"username": sc.User(),
"ip": sc.RemoteAddr(),
}),
}
return sftp.Handlers{
FileGet: &p,
FilePut: &p,
FileCmd: &p,
FileList: &p,
}
}
// Generates a private key that will be used by the SFTP server. // Generates a private key that will be used by the SFTP server.
func (c *SFTPServer) generatePrivateKey() error { func (c *SFTPServer) generatePrivateKey() error {
key, err := rsa.GenerateKey(rand.Reader, 2048) key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil { if err != nil {
return err return errors.WithStack(err)
} }
if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil { if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
return err return errors.Wrap(err, "sftp/server: could not create .sftp directory")
} }
o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return errors.WithStack(err)
} }
defer o.Close() defer o.Close()
pkey := &pem.Block{ err = pem.Encode(o, &pem.Block{
Type: "RSA PRIVATE KEY", Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key), Bytes: x509.MarshalPKCS1PrivateKey(key),
} })
return errors.WithStack(err)
if err := pem.Encode(o, pkey); err != nil {
return err
}
return nil
} }
// A function capable of validating user credentials with the Panel API. // A function capable of validating user credentials with the Panel API.