Finish refactoring SFTP server logic
This commit is contained in:
parent
a48abc92ad
commit
0cb3b815d1
|
@ -273,7 +273,7 @@ func rootCmdRun(cmd *cobra.Command, _ []string) {
|
||||||
pool.StopWait()
|
pool.StopWait()
|
||||||
|
|
||||||
// Run the SFTP server.
|
// Run the SFTP server.
|
||||||
if err := sftp.NewServer().Run(); err != nil {
|
if err := sftp.New().Run(); err != nil {
|
||||||
log.WithError(err).Fatal("failed to initialize the sftp server")
|
log.WithError(err).Fatal("failed to initialize the sftp server")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,9 @@ import (
|
||||||
"emperror.dev/errors"
|
"emperror.dev/errors"
|
||||||
"github.com/apex/log"
|
"github.com/apex/log"
|
||||||
"github.com/pkg/sftp"
|
"github.com/pkg/sftp"
|
||||||
|
"github.com/pterodactyl/wings/config"
|
||||||
"github.com/pterodactyl/wings/server/filesystem"
|
"github.com/pterodactyl/wings/server/filesystem"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -30,6 +32,31 @@ type Handler struct {
|
||||||
ro bool
|
ro bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a new connection handler for the SFTP server. This allows a given user
|
||||||
|
// to access the underlying filesystem.
|
||||||
|
func NewHandler(sc *ssh.ServerConn, fs *filesystem.Filesystem) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
fs: fs,
|
||||||
|
ro: config.Get().System.Sftp.ReadOnly,
|
||||||
|
permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
|
||||||
|
logger: log.WithFields(log.Fields{
|
||||||
|
"subsystem": "sftp",
|
||||||
|
"username": sc.User(),
|
||||||
|
"ip": sc.RemoteAddr(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the sftp.Handlers for this struct.
|
||||||
|
func (h *Handler) Handlers() sftp.Handlers {
|
||||||
|
return sftp.Handlers{
|
||||||
|
FileGet: h,
|
||||||
|
FilePut: h,
|
||||||
|
FileCmd: h,
|
||||||
|
FileList: h,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fileread creates a reader for a file on the system and returns the reader back.
|
// Fileread creates a reader for a file on the system and returns the reader back.
|
||||||
func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) {
|
func (h *Handler) Fileread(request *sftp.Request) (io.ReaderAt, error) {
|
||||||
// Check first if the user can actually open and view a file. This permission is named
|
// Check first if the user can actually open and view a file. This permission is named
|
||||||
|
|
161
sftp/server.go
161
sftp/server.go
|
@ -5,12 +5,12 @@ import (
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"emperror.dev/errors"
|
"emperror.dev/errors"
|
||||||
|
@ -26,179 +26,138 @@ import (
|
||||||
type SFTPServer struct {
|
type SFTPServer struct {
|
||||||
BasePath string
|
BasePath string
|
||||||
ReadOnly bool
|
ReadOnly bool
|
||||||
BindPort int
|
Listen string
|
||||||
BindAddress string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var noMatchingServerError = errors.Sentinel("sftp: no matching server with UUID")
|
func New() *SFTPServer {
|
||||||
|
|
||||||
func NewServer() *SFTPServer {
|
|
||||||
cfg := config.Get().System
|
cfg := config.Get().System
|
||||||
return &SFTPServer{
|
return &SFTPServer{
|
||||||
BasePath: cfg.Data,
|
BasePath: cfg.Data,
|
||||||
ReadOnly: cfg.Sftp.ReadOnly,
|
ReadOnly: cfg.Sftp.ReadOnly,
|
||||||
BindAddress: cfg.Sftp.Address,
|
Listen: cfg.Sftp.Address + ":" + strconv.Itoa(cfg.Sftp.Port),
|
||||||
BindPort: cfg.Sftp.Port,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
|
// Starts the SFTP server and add a persistent listener to handle inbound SFTP connections.
|
||||||
func (c *SFTPServer) Run() error {
|
func (c *SFTPServer) Run() error {
|
||||||
serverConfig := &ssh.ServerConfig{
|
|
||||||
NoClientAuth: false,
|
|
||||||
MaxAuthTries: 6,
|
|
||||||
PasswordCallback: c.passwordCallback,
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
|
if _, err := os.Stat(path.Join(c.BasePath, ".sftp/id_rsa")); os.IsNotExist(err) {
|
||||||
if err := c.generatePrivateKey(); err != nil {
|
if err := c.generatePrivateKey(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "sftp/server: could not stat private key file")
|
||||||
}
|
}
|
||||||
|
pb, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
|
||||||
privateBytes, err := ioutil.ReadFile(path.Join(c.BasePath, ".sftp/id_rsa"))
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "sftp/server: could not read private key file")
|
||||||
|
}
|
||||||
|
private, err := ssh.ParsePrivateKey(pb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
conf := &ssh.ServerConfig{
|
||||||
|
NoClientAuth: false,
|
||||||
|
MaxAuthTries: 6,
|
||||||
|
PasswordCallback: c.passwordCallback,
|
||||||
|
}
|
||||||
|
conf.AddHostKey(private)
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", c.Listen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add our private key to the server configuration.
|
log.WithField("listen", c.Listen).Info("sftp server listening for connections")
|
||||||
serverConfig.AddHostKey(private)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", c.BindAddress, c.BindPort))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithField("host", c.BindAddress).WithField("port", c.BindPort).Info("sftp subsystem listening for connections")
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, _ := listener.Accept()
|
if conn, _ := listener.Accept(); conn != nil {
|
||||||
if conn != nil {
|
go func(conn net.Conn) {
|
||||||
go c.AcceptInboundConnection(conn, serverConfig)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handles an inbound connection to the instance and determines if we should serve the request
|
|
||||||
// or not.
|
|
||||||
func (c SFTPServer) AcceptInboundConnection(conn net.Conn, config *ssh.ServerConfig) {
|
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
c.AcceptInbound(conn, conf)
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handles an inbound connection to the instance and determines if we should serve the
|
||||||
|
// request or not.
|
||||||
|
func (c SFTPServer) AcceptInbound(conn net.Conn, config *ssh.ServerConfig) {
|
||||||
// Before beginning a handshake must be performed on the incoming net.Conn
|
// Before beginning a handshake must be performed on the incoming net.Conn
|
||||||
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
sconn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer sconn.Close()
|
defer sconn.Close()
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
for newChannel := range chans {
|
for ch := range chans {
|
||||||
// If its not a session channel we just move on because its not something we
|
// If its not a session channel we just move on because its not something we
|
||||||
// know how to handle at this point.
|
// know how to handle at this point.
|
||||||
if newChannel.ChannelType() != "session" {
|
if ch.ChannelType() != "session" {
|
||||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
ch.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, requests, err := newChannel.Accept()
|
channel, requests, err := ch.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channels have a type that is dependent on the protocol. For SFTP this is "subsystem"
|
|
||||||
// with a payload that (should) be "sftp". Discard anything else we receive ("pty", "shell", etc)
|
|
||||||
go func(in <-chan *ssh.Request) {
|
go func(in <-chan *ssh.Request) {
|
||||||
for req := range in {
|
for req := range in {
|
||||||
ok := false
|
// Channels have a type that is dependent on the protocol. For SFTP
|
||||||
|
// this is "subsystem" with a payload that (should) be "sftp". Discard
|
||||||
switch req.Type {
|
// anything else we receive ("pty", "shell", etc)
|
||||||
case "subsystem":
|
req.Reply(req.Type == "subsystem" && string(req.Payload[4:]) == "sftp", nil)
|
||||||
if string(req.Payload[4:]) == "sftp" {
|
|
||||||
ok = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Reply(ok, nil)
|
|
||||||
}
|
}
|
||||||
}(requests)
|
}(requests)
|
||||||
|
|
||||||
if sconn.Permissions.Extensions["uuid"] == "" {
|
// If no UUID has been set on this inbound request then we can assume we
|
||||||
|
// have screwed up something in the authentication code. This is a sanity
|
||||||
|
// check, but should never be encountered (ideally...).
|
||||||
|
//
|
||||||
|
// This will also attempt to match a specific server out of the global server
|
||||||
|
// store and return nil if there is no match.
|
||||||
|
uuid := sconn.Permissions.Extensions["uuid"]
|
||||||
|
srv := server.GetServers().Find(func(s *server.Server) bool {
|
||||||
|
if uuid == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.Id() == uuid
|
||||||
|
})
|
||||||
|
if srv == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new handler for the currently logged in user's server.
|
// Spin up a SFTP server instance for the authenticated user's server allowing
|
||||||
fs := c.newHandler(sconn)
|
// them access to the underlying filesystem.
|
||||||
|
handler := sftp.NewRequestServer(channel, NewHandler(sconn, srv.Filesystem()).Handlers())
|
||||||
// Create the server instance for the channel using the filesystem we created above.
|
|
||||||
handler := sftp.NewRequestServer(channel, fs)
|
|
||||||
if err := handler.Serve(); err == io.EOF {
|
if err := handler.Serve(); err == io.EOF {
|
||||||
handler.Close()
|
handler.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates a new SFTP handler for a given server. The directory argument should
|
|
||||||
// be the base directory for a server. All actions done on the server will be
|
|
||||||
// relative to that directory, and the user will not be able to escape out of it.
|
|
||||||
func (c *SFTPServer) newHandler(sc *ssh.ServerConn) sftp.Handlers {
|
|
||||||
s := server.GetServers().Find(func(s *server.Server) bool {
|
|
||||||
return s.Id() == sc.Permissions.Extensions["uuid"]
|
|
||||||
})
|
|
||||||
|
|
||||||
p := Handler{
|
|
||||||
fs: s.Filesystem(),
|
|
||||||
permissions: strings.Split(sc.Permissions.Extensions["permissions"], ","),
|
|
||||||
ro: config.Get().System.Sftp.ReadOnly,
|
|
||||||
logger: log.WithFields(log.Fields{
|
|
||||||
"subsystem": "sftp",
|
|
||||||
"username": sc.User(),
|
|
||||||
"ip": sc.RemoteAddr(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
return sftp.Handlers{
|
|
||||||
FileGet: &p,
|
|
||||||
FilePut: &p,
|
|
||||||
FileCmd: &p,
|
|
||||||
FileList: &p,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generates a private key that will be used by the SFTP server.
|
// Generates a private key that will be used by the SFTP server.
|
||||||
func (c *SFTPServer) generatePrivateKey() error {
|
func (c *SFTPServer) generatePrivateKey() error {
|
||||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
|
if err := os.MkdirAll(path.Join(c.BasePath, ".sftp"), 0755); err != nil {
|
||||||
return err
|
return errors.Wrap(err, "sftp/server: could not create .sftp directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
o, err := os.OpenFile(path.Join(c.BasePath, ".sftp/id_rsa"), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
defer o.Close()
|
defer o.Close()
|
||||||
|
|
||||||
pkey := &pem.Block{
|
err = pem.Encode(o, &pem.Block{
|
||||||
Type: "RSA PRIVATE KEY",
|
Type: "RSA PRIVATE KEY",
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
}
|
})
|
||||||
|
return errors.WithStack(err)
|
||||||
if err := pem.Encode(o, pkey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A function capable of validating user credentials with the Panel API.
|
// A function capable of validating user credentials with the Panel API.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user