Rewrite console throttling logic; drop complex timer usage and use a very simple throttle
This also removes server process termination logic when a server is breaching the output limits. It simply continues to efficiently throttle the console output.
This commit is contained in:
		
							parent
							
								
									fb73d5dbbf
								
							
						
					
					
						commit
						2b2b5200eb
					
				| 
						 | 
				
			
			@ -222,26 +222,14 @@ type ConsoleThrottles struct {
 | 
			
		|||
	// Whether or not the throttler is enabled for this instance.
 | 
			
		||||
	Enabled bool `json:"enabled" yaml:"enabled" default:"true"`
 | 
			
		||||
 | 
			
		||||
	// The total number of lines that can be output in a given LineResetInterval period before
 | 
			
		||||
	// The total number of lines that can be output in a given Period period before
 | 
			
		||||
	// a warning is triggered and counted against the server.
 | 
			
		||||
	Lines uint64 `json:"lines" yaml:"lines" default:"2000"`
 | 
			
		||||
 | 
			
		||||
	// The total number of throttle activations that can accumulate before a server is considered
 | 
			
		||||
	// to be breaching and will be stopped. This value is decremented by one every DecayInterval.
 | 
			
		||||
	MaximumTriggerCount uint64 `json:"maximum_trigger_count" yaml:"maximum_trigger_count" default:"5"`
 | 
			
		||||
 | 
			
		||||
	// The amount of time after which the number of lines processed is reset to 0. This runs in
 | 
			
		||||
	// a constant loop and is not affected by the current console output volumes. By default, this
 | 
			
		||||
	// will reset the processed line count back to 0 every 100ms.
 | 
			
		||||
	LineResetInterval uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"`
 | 
			
		||||
 | 
			
		||||
	// The amount of time in milliseconds that must pass without an output warning being triggered
 | 
			
		||||
	// before a throttle activation is decremented.
 | 
			
		||||
	DecayInterval uint64 `json:"decay_interval" yaml:"decay_interval" default:"10000"`
 | 
			
		||||
 | 
			
		||||
	// The amount of time that a server is allowed to be stopping for before it is terminated
 | 
			
		||||
	// forcefully if it triggers output throttles.
 | 
			
		||||
	StopGracePeriod uint `json:"stop_grace_period" yaml:"stop_grace_period" default:"15"`
 | 
			
		||||
	Period uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Configuration struct {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,15 +1,11 @@
 | 
			
		|||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	"github.com/mitchellh/colorstring"
 | 
			
		||||
 | 
			
		||||
	"github.com/pterodactyl/wings/config"
 | 
			
		||||
	"github.com/pterodactyl/wings/system"
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -18,118 +14,8 @@ import (
 | 
			
		|||
// the configuration every time we need to send output along to the websocket for
 | 
			
		||||
// a server.
 | 
			
		||||
var appName string
 | 
			
		||||
 | 
			
		||||
var appNameSync sync.Once
 | 
			
		||||
 | 
			
		||||
var ErrTooMuchConsoleData = errors.New("console is outputting too much data")
 | 
			
		||||
 | 
			
		||||
type ConsoleThrottler struct {
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
	config.ConsoleThrottles
 | 
			
		||||
 | 
			
		||||
	// The total number of activations that have occurred thus far.
 | 
			
		||||
	activations uint64
 | 
			
		||||
 | 
			
		||||
	// The total number of lines that have been sent since the last reset timer period.
 | 
			
		||||
	count uint64
 | 
			
		||||
 | 
			
		||||
	// Wether or not the console output is being throttled. It is up to calling code to
 | 
			
		||||
	// determine what to do if it is.
 | 
			
		||||
	isThrottled *system.AtomicBool
 | 
			
		||||
 | 
			
		||||
	// The total number of lines processed so far during the given time period.
 | 
			
		||||
	timerCancel *context.CancelFunc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Resets the state of the throttler.
 | 
			
		||||
func (ct *ConsoleThrottler) Reset() {
 | 
			
		||||
	atomic.StoreUint64(&ct.count, 0)
 | 
			
		||||
	atomic.StoreUint64(&ct.activations, 0)
 | 
			
		||||
	ct.isThrottled.Store(false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Triggers an activation for a server. You can also decrement the number of activations
 | 
			
		||||
// by passing a negative number.
 | 
			
		||||
func (ct *ConsoleThrottler) markActivation(increment bool) uint64 {
 | 
			
		||||
	if !increment {
 | 
			
		||||
		if atomic.LoadUint64(&ct.activations) == 0 {
 | 
			
		||||
			return 0
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// This weird dohicky subtracts 1 from the activation count.
 | 
			
		||||
		return atomic.AddUint64(&ct.activations, ^uint64(0))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return atomic.AddUint64(&ct.activations, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Determines if the console is currently being throttled. Calls to this function can be used to
 | 
			
		||||
// determine if output should be funneled along to the websocket processes.
 | 
			
		||||
func (ct *ConsoleThrottler) Throttled() bool {
 | 
			
		||||
	return ct.isThrottled.Load()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Starts a timer that runs in a seperate thread and will continually decrement the lines processed
 | 
			
		||||
// and number of activations, regardless of the current console message volume. All of the timers
 | 
			
		||||
// are canceled if the context passed through is canceled.
 | 
			
		||||
func (ct *ConsoleThrottler) StartTimer(ctx context.Context) {
 | 
			
		||||
	system.Every(ctx, time.Duration(int64(ct.LineResetInterval))*time.Millisecond, func(_ time.Time) {
 | 
			
		||||
		ct.isThrottled.Store(false)
 | 
			
		||||
		atomic.StoreUint64(&ct.count, 0)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	system.Every(ctx, time.Duration(int64(ct.DecayInterval))*time.Millisecond, func(_ time.Time) {
 | 
			
		||||
		ct.markActivation(false)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Handles output from a server's console. This code ensures that a server is not outputting
 | 
			
		||||
// an excessive amount of data to the console that could indicate a malicious or run-away process
 | 
			
		||||
// and lead to performance issues for other users.
 | 
			
		||||
//
 | 
			
		||||
// This was much more of a problem for the NodeJS version of the daemon which struggled to handle
 | 
			
		||||
// large volumes of output. However, this code is much more performant so I generally feel a lot
 | 
			
		||||
// better about it's abilities.
 | 
			
		||||
//
 | 
			
		||||
// However, extreme output is still somewhat of a DoS attack vector against this software since we
 | 
			
		||||
// are still logging it to the disk temporarily and will want to avoid dumping a huge amount of
 | 
			
		||||
// data all at once. These values are all configurable via the wings configuration file, however the
 | 
			
		||||
// defaults have been in the wild for almost two years at the time of this writing, so I feel quite
 | 
			
		||||
// confident in them.
 | 
			
		||||
//
 | 
			
		||||
// This function returns an error if the server should be stopped due to violating throttle constraints
 | 
			
		||||
// and a boolean value indicating if a throttle is being violated when it is checked.
 | 
			
		||||
func (ct *ConsoleThrottler) Increment(onTrigger func()) error {
 | 
			
		||||
	if !ct.Enabled {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Increment the line count and if we have now output more lines than are allowed, trigger a throttle
 | 
			
		||||
	// activation. Once the throttle is triggered and has passed the kill at value we will trigger a server
 | 
			
		||||
	// stop automatically.
 | 
			
		||||
	if atomic.AddUint64(&ct.count, 1) >= ct.Lines && !ct.Throttled() {
 | 
			
		||||
		ct.isThrottled.Store(true)
 | 
			
		||||
		if ct.markActivation(true) >= ct.MaximumTriggerCount {
 | 
			
		||||
			return ErrTooMuchConsoleData
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		onTrigger()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Returns the throttler instance for the server or creates a new one.
 | 
			
		||||
func (s *Server) Throttler() *ConsoleThrottler {
 | 
			
		||||
	s.throttleOnce.Do(func() {
 | 
			
		||||
		s.throttler = &ConsoleThrottler{
 | 
			
		||||
			isThrottled:      system.NewAtomicBool(false),
 | 
			
		||||
			ConsoleThrottles: config.Get().Throttles,
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return s.throttler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PublishConsoleOutputFromDaemon sends output to the server console formatted
 | 
			
		||||
// to appear correctly as being sent from Wings.
 | 
			
		||||
func (s *Server) PublishConsoleOutputFromDaemon(data string) {
 | 
			
		||||
| 
						 | 
				
			
			@ -141,3 +27,55 @@ func (s *Server) PublishConsoleOutputFromDaemon(data string) {
 | 
			
		|||
		colorstring.Color(fmt.Sprintf("[yellow][bold][%s Daemon]:[default] %s", appName, data)),
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Throttler returns the throttler instance for the server or creates a new one.
 | 
			
		||||
func (s *Server) Throttler() *ConsoleThrottle {
 | 
			
		||||
	s.throttleOnce.Do(func() {
 | 
			
		||||
		throttles := config.Get().Throttles
 | 
			
		||||
		period := time.Duration(throttles.Period) * time.Millisecond
 | 
			
		||||
 | 
			
		||||
		s.throttler = newConsoleThrottle(throttles.Lines, period)
 | 
			
		||||
		s.throttler.strike = func() {
 | 
			
		||||
			s.PublishConsoleOutputFromDaemon(fmt.Sprintf("Server is outputting console data too quickly -- throttling..."))
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return s.throttler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ConsoleThrottle struct {
 | 
			
		||||
	limit  *system.Rate
 | 
			
		||||
	lock   *system.Locker
 | 
			
		||||
	strike func()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newConsoleThrottle(lines uint64, period time.Duration) *ConsoleThrottle {
 | 
			
		||||
	return &ConsoleThrottle{
 | 
			
		||||
		limit: system.NewRate(lines, period),
 | 
			
		||||
		lock:  system.NewLocker(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Allow checks if the console is allowed to process more output data, or if too
 | 
			
		||||
// much has already been sent over the line. If there is too much output the
 | 
			
		||||
// strike callback function is triggered, but only if it has not already been
 | 
			
		||||
// triggered at this point in the process.
 | 
			
		||||
//
 | 
			
		||||
// If output is allowed, the lock on the throttler is released and the next time
 | 
			
		||||
// it is triggered the strike function will be re-executed.
 | 
			
		||||
func (ct *ConsoleThrottle) Allow() bool {
 | 
			
		||||
	if !ct.limit.Try() {
 | 
			
		||||
		if err := ct.lock.Acquire(); err == nil {
 | 
			
		||||
			if ct.strike != nil {
 | 
			
		||||
				ct.strike()
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	ct.lock.Release()
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reset resets the console throttler internal rate limiter and overage counter.
 | 
			
		||||
func (ct *ConsoleThrottle) Reset() {
 | 
			
		||||
	ct.limit.Reset()
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										62
									
								
								server/console_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										62
									
								
								server/console_test.go
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,62 @@
 | 
			
		|||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/franela/goblin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestName(t *testing.T) {
 | 
			
		||||
	g := goblin.Goblin(t)
 | 
			
		||||
 | 
			
		||||
	g.Describe("ConsoleThrottler", func() {
 | 
			
		||||
		g.It("keeps count of the number of overages in a time period", func() {
 | 
			
		||||
			t := newConsoleThrottle(1, time.Second)
 | 
			
		||||
			g.Assert(t.Allow()).IsTrue()
 | 
			
		||||
			g.Assert(t.Allow()).IsFalse()
 | 
			
		||||
			g.Assert(t.Allow()).IsFalse()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.It("calls strike once per time period", func() {
 | 
			
		||||
			t := newConsoleThrottle(1, time.Millisecond * 20)
 | 
			
		||||
 | 
			
		||||
			var times int
 | 
			
		||||
			t.strike = func() {
 | 
			
		||||
				times = times + 1
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			time.Sleep(time.Millisecond * 100)
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			t.Reset()
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			t.Allow()
 | 
			
		||||
			t.Allow()
 | 
			
		||||
 | 
			
		||||
			g.Assert(times).Equal(2)
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.It("is properly reset", func() {
 | 
			
		||||
			t := newConsoleThrottle(10, time.Second)
 | 
			
		||||
 | 
			
		||||
			for i := 0; i < 10; i++ {
 | 
			
		||||
				g.Assert(t.Allow()).IsTrue()
 | 
			
		||||
			}
 | 
			
		||||
			g.Assert(t.Allow()).IsFalse()
 | 
			
		||||
			t.Reset()
 | 
			
		||||
			g.Assert(t.Allow()).IsTrue()
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkConsoleThrottle(b *testing.B) {
 | 
			
		||||
	t := newConsoleThrottle(10, time.Millisecond * 10)
 | 
			
		||||
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
    for i := 0; i < b.N; i++ {
 | 
			
		||||
        t.Allow()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -8,7 +8,6 @@ import (
 | 
			
		|||
 | 
			
		||||
	"github.com/apex/log"
 | 
			
		||||
 | 
			
		||||
	"github.com/pterodactyl/wings/config"
 | 
			
		||||
	"github.com/pterodactyl/wings/environment"
 | 
			
		||||
	"github.com/pterodactyl/wings/events"
 | 
			
		||||
	"github.com/pterodactyl/wings/remote"
 | 
			
		||||
| 
						 | 
				
			
			@ -57,45 +56,23 @@ func (dsl *diskSpaceLimiter) Trigger() {
 | 
			
		|||
// output lines to determine if the server is started yet, and if the output is
 | 
			
		||||
// not being throttled, will send the data over to the websocket.
 | 
			
		||||
func (s *Server) processConsoleOutputEvent(v []byte) {
 | 
			
		||||
	t := s.Throttler()
 | 
			
		||||
	err := t.Increment(func() {
 | 
			
		||||
		s.PublishConsoleOutputFromDaemon("Your server is outputting too much data and is being throttled.")
 | 
			
		||||
	})
 | 
			
		||||
	// An error is only returned if the server has breached the thresholds set.
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// If the process is already stopping, just let it continue with that action rather than attempting
 | 
			
		||||
		// to terminate again.
 | 
			
		||||
		if s.Environment.State() != environment.ProcessStoppingState {
 | 
			
		||||
			s.Environment.SetState(environment.ProcessStoppingState)
 | 
			
		||||
 | 
			
		||||
			go func() {
 | 
			
		||||
				s.Log().Warn("stopping server instance, violating throttle limits")
 | 
			
		||||
				s.PublishConsoleOutputFromDaemon("Your server is being stopped for outputting too much data in a short period of time.")
 | 
			
		||||
 | 
			
		||||
				// Completely skip over server power actions and terminate the running instance. This gives the
 | 
			
		||||
				// server 15 seconds to finish stopping gracefully before it is forcefully terminated.
 | 
			
		||||
				if err := s.Environment.WaitForStop(config.Get().Throttles.StopGracePeriod, true); err != nil {
 | 
			
		||||
					// If there is an error set the process back to running so that this throttler is called
 | 
			
		||||
					// again and hopefully kills the server.
 | 
			
		||||
					if s.Environment.State() != environment.ProcessOfflineState {
 | 
			
		||||
						s.Environment.SetState(environment.ProcessRunningState)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					s.Log().WithField("error", err).Error("failed to terminate environment after triggering throttle")
 | 
			
		||||
				}
 | 
			
		||||
			}()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Always process the console output, but do this in a seperate thread since we
 | 
			
		||||
	// don't really care about side-effects from this call, and don't want it to block
 | 
			
		||||
	// the console sending logic.
 | 
			
		||||
	go s.onConsoleOutput(v)
 | 
			
		||||
 | 
			
		||||
	// If we are not throttled, go ahead and output the data.
 | 
			
		||||
	if !t.Throttled() {
 | 
			
		||||
		s.Sink(LogSink).Push(v)
 | 
			
		||||
	// If the console is being throttled, do nothing else with it, we don't want
 | 
			
		||||
	// to waste time. This code previously terminated server instances after violating
 | 
			
		||||
	// different throttle limits. That code was clunky and difficult to reason about,
 | 
			
		||||
	// in addition to being a consistent pain point for users.
 | 
			
		||||
	//
 | 
			
		||||
	// In the interest of building highly efficient software, that code has been removed
 | 
			
		||||
	// here, and we'll rely on the host to detect bad actors through their own means.
 | 
			
		||||
	if !s.Throttler().Allow() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s.Sink(LogSink).Push(v)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StartEventListeners adds all the internal event listeners we want to use for
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -199,7 +199,6 @@ func (m *Manager) InitServer(data remote.ServerConfigurationResponse) (*Server,
 | 
			
		|||
	} else {
 | 
			
		||||
		s.Environment = env
 | 
			
		||||
		s.StartEventListeners()
 | 
			
		||||
		s.Throttler().StartTimer(s.Context())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// If the server's data directory exists, force disk usage calculation.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,7 +4,6 @@ import (
 | 
			
		|||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
| 
						 | 
				
			
			@ -41,85 +40,6 @@ func (pa PowerAction) IsStart() bool {
 | 
			
		|||
	return pa == PowerActionStart || pa == PowerActionRestart
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type powerLocker struct {
 | 
			
		||||
	mu sync.RWMutex
 | 
			
		||||
	ch chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newPowerLocker() *powerLocker {
 | 
			
		||||
	return &powerLocker{
 | 
			
		||||
		ch: make(chan bool, 1),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type errPowerLockerLocked struct{}
 | 
			
		||||
 | 
			
		||||
func (e errPowerLockerLocked) Error() string {
 | 
			
		||||
	return "cannot acquire a lock on the power state: already locked"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ErrPowerLockerLocked error = errPowerLockerLocked{}
 | 
			
		||||
 | 
			
		||||
// IsLocked returns the current state of the locker channel. If there is
 | 
			
		||||
// currently a value in the channel, it is assumed to be locked.
 | 
			
		||||
func (pl *powerLocker) IsLocked() bool {
 | 
			
		||||
	pl.mu.RLock()
 | 
			
		||||
	defer pl.mu.RUnlock()
 | 
			
		||||
	return len(pl.ch) == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Acquire will acquire the power lock if it is not currently locked. If it is
 | 
			
		||||
// already locked, acquire will fail to acquire the lock, and will return false.
 | 
			
		||||
func (pl *powerLocker) Acquire() error {
 | 
			
		||||
	pl.mu.Lock()
 | 
			
		||||
	defer pl.mu.Unlock()
 | 
			
		||||
	select {
 | 
			
		||||
	case pl.ch <- true:
 | 
			
		||||
	default:
 | 
			
		||||
		return errors.WithStack(ErrPowerLockerLocked)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TryAcquire will attempt to acquire a power-lock until the context provided
 | 
			
		||||
// is canceled.
 | 
			
		||||
func (pl *powerLocker) TryAcquire(ctx context.Context) error {
 | 
			
		||||
	select {
 | 
			
		||||
	case pl.ch <- true:
 | 
			
		||||
		return nil
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		if err := ctx.Err(); err != nil {
 | 
			
		||||
			return errors.WithStack(err)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Release will drain the locker channel so that we can properly re-acquire it
 | 
			
		||||
// at a later time. If the channel is not currently locked this function is a
 | 
			
		||||
// no-op and will immediately return.
 | 
			
		||||
func (pl *powerLocker) Release() {
 | 
			
		||||
	pl.mu.Lock()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-pl.ch:
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
	pl.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Destroy cleans up the power locker by closing the channel.
 | 
			
		||||
func (pl *powerLocker) Destroy() {
 | 
			
		||||
	pl.mu.Lock()
 | 
			
		||||
	if pl.ch != nil {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-pl.ch:
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
		close(pl.ch)
 | 
			
		||||
	}
 | 
			
		||||
	pl.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExecutingPowerAction checks if there is currently a power action being
 | 
			
		||||
// processed for the server.
 | 
			
		||||
func (s *Server) ExecutingPowerAction() bool {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,154 +1,18 @@
 | 
			
		|||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	. "github.com/franela/goblin"
 | 
			
		||||
	"github.com/pterodactyl/wings/system"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestPower(t *testing.T) {
 | 
			
		||||
	g := Goblin(t)
 | 
			
		||||
 | 
			
		||||
	g.Describe("PowerLocker", func() {
 | 
			
		||||
		var pl *powerLocker
 | 
			
		||||
		g.BeforeEach(func() {
 | 
			
		||||
			pl = newPowerLocker()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#IsLocked", func() {
 | 
			
		||||
			g.It("should return false when the channel is empty", func() {
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should return true when the channel is at capacity", func() {
 | 
			
		||||
				pl.ch <- true
 | 
			
		||||
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
				<-pl.ch
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
 | 
			
		||||
				// We don't care what the channel value is, just that there is
 | 
			
		||||
				// something in it.
 | 
			
		||||
				pl.ch <- false
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Acquire", func() {
 | 
			
		||||
			g.It("should acquire a lock when channel is empty", func() {
 | 
			
		||||
				err := pl.Acquire()
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should return an error when the channel is full", func() {
 | 
			
		||||
				pl.ch <- true
 | 
			
		||||
 | 
			
		||||
				err := pl.Acquire()
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNotNil()
 | 
			
		||||
				g.Assert(errors.Is(err, ErrPowerLockerLocked)).IsTrue()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#TryAcquire", func() {
 | 
			
		||||
			g.It("should acquire a lock when channel is empty", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
 | 
			
		||||
				err := pl.TryAcquire(context.Background())
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should block until context is canceled if channel is full", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
 | 
			
		||||
				defer cancel()
 | 
			
		||||
 | 
			
		||||
				pl.ch <- true
 | 
			
		||||
				err := pl.TryAcquire(ctx)
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNotNil()
 | 
			
		||||
				g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should block until lock can be acquired", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
 | 
			
		||||
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
 | 
			
		||||
				defer cancel()
 | 
			
		||||
 | 
			
		||||
				pl.Acquire()
 | 
			
		||||
				go func() {
 | 
			
		||||
					time.AfterFunc(time.Millisecond * 50, func() {
 | 
			
		||||
						pl.Release()
 | 
			
		||||
					})
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				err := pl.TryAcquire(ctx)
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Release", func() {
 | 
			
		||||
			g.It("should release when channel is full", func() {
 | 
			
		||||
				pl.Acquire()
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
				pl.Release()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(0)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should release when channel is empty", func() {
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
				pl.Release()
 | 
			
		||||
				g.Assert(cap(pl.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(pl.ch)).Equal(0)
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Destroy", func() {
 | 
			
		||||
			g.It("should unlock and close the channel", func() {
 | 
			
		||||
				pl.Acquire()
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsTrue()
 | 
			
		||||
				pl.Destroy()
 | 
			
		||||
				g.Assert(pl.IsLocked()).IsFalse()
 | 
			
		||||
 | 
			
		||||
				defer func() {
 | 
			
		||||
					r := recover()
 | 
			
		||||
 | 
			
		||||
					g.Assert(r).IsNotNil()
 | 
			
		||||
					g.Assert(r.(error).Error()).Equal("send on closed channel")
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				pl.Acquire()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	g.Describe("Server#ExecutingPowerAction", func() {
 | 
			
		||||
		g.It("should return based on locker status", func() {
 | 
			
		||||
			s := &Server{powerLock: newPowerLocker()}
 | 
			
		||||
			s := &Server{powerLock: system.NewLocker()}
 | 
			
		||||
 | 
			
		||||
			g.Assert(s.ExecutingPowerAction()).IsFalse()
 | 
			
		||||
			s.powerLock.Acquire()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,9 +30,8 @@ type Server struct {
 | 
			
		|||
	ctx       context.Context
 | 
			
		||||
	ctxCancel *context.CancelFunc
 | 
			
		||||
 | 
			
		||||
	emitterLock  sync.Mutex
 | 
			
		||||
	powerLock    *powerLocker
 | 
			
		||||
	throttleOnce sync.Once
 | 
			
		||||
	emitterLock sync.Mutex
 | 
			
		||||
	powerLock   *system.Locker
 | 
			
		||||
 | 
			
		||||
	// Maintains the configuration for the server. This is the data that gets returned by the Panel
 | 
			
		||||
	// such as build settings and container images.
 | 
			
		||||
| 
						 | 
				
			
			@ -64,7 +63,8 @@ type Server struct {
 | 
			
		|||
	restoring    *system.AtomicBool
 | 
			
		||||
 | 
			
		||||
	// The console throttler instance used to control outputs.
 | 
			
		||||
	throttler *ConsoleThrottler
 | 
			
		||||
	throttler    *ConsoleThrottle
 | 
			
		||||
	throttleOnce sync.Once
 | 
			
		||||
 | 
			
		||||
	// Tracks open websocket connections for the server.
 | 
			
		||||
	wsBag       *WebsocketBag
 | 
			
		||||
| 
						 | 
				
			
			@ -87,7 +87,7 @@ func New(client remote.Client) (*Server, error) {
 | 
			
		|||
		installing:   system.NewAtomicBool(false),
 | 
			
		||||
		transferring: system.NewAtomicBool(false),
 | 
			
		||||
		restoring:    system.NewAtomicBool(false),
 | 
			
		||||
		powerLock:    newPowerLocker(),
 | 
			
		||||
		powerLock:    system.NewLocker(),
 | 
			
		||||
		sinks: map[SinkName]*sinkPool{
 | 
			
		||||
			LogSink:     newSinkPool(),
 | 
			
		||||
			InstallSink: newSinkPool(),
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										83
									
								
								system/locker.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										83
									
								
								system/locker.go
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,83 @@
 | 
			
		|||
package system
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ErrLockerLocked = errors.Sentinel("locker: cannot acquire lock, already locked")
 | 
			
		||||
 | 
			
		||||
type Locker struct {
 | 
			
		||||
	mu sync.RWMutex
 | 
			
		||||
	ch chan bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewLocker returns a new Locker instance.
 | 
			
		||||
func NewLocker() *Locker {
 | 
			
		||||
	return &Locker{
 | 
			
		||||
		ch: make(chan bool, 1),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsLocked returns the current state of the locker channel. If there is
 | 
			
		||||
// currently a value in the channel, it is assumed to be locked.
 | 
			
		||||
func (l *Locker) IsLocked() bool {
 | 
			
		||||
	l.mu.RLock()
 | 
			
		||||
	defer l.mu.RUnlock()
 | 
			
		||||
	return len(l.ch) == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Acquire will acquire the power lock if it is not currently locked. If it is
 | 
			
		||||
// already locked, acquire will fail to acquire the lock, and will return false.
 | 
			
		||||
func (l *Locker) Acquire() error {
 | 
			
		||||
	l.mu.Lock()
 | 
			
		||||
	defer l.mu.Unlock()
 | 
			
		||||
	select {
 | 
			
		||||
	case l.ch <- true:
 | 
			
		||||
	default:
 | 
			
		||||
		return ErrLockerLocked
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// TryAcquire will attempt to acquire a power-lock until the context provided
 | 
			
		||||
// is canceled.
 | 
			
		||||
func (l *Locker) TryAcquire(ctx context.Context) error {
 | 
			
		||||
	select {
 | 
			
		||||
	case l.ch <- true:
 | 
			
		||||
		return nil
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		if err := ctx.Err(); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Release will drain the locker channel so that we can properly re-acquire it
 | 
			
		||||
// at a later time. If the channel is not currently locked this function is a
 | 
			
		||||
// no-op and will immediately return.
 | 
			
		||||
func (l *Locker) Release() {
 | 
			
		||||
	l.mu.Lock()
 | 
			
		||||
	select {
 | 
			
		||||
	case <-l.ch:
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
	l.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Destroy cleans up the power locker by closing the channel.
 | 
			
		||||
func (l *Locker) Destroy() {
 | 
			
		||||
	l.mu.Lock()
 | 
			
		||||
	if l.ch != nil {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-l.ch:
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
		close(l.ch)
 | 
			
		||||
	}
 | 
			
		||||
	l.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										148
									
								
								system/locker_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								system/locker_test.go
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,148 @@
 | 
			
		|||
package system
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"emperror.dev/errors"
 | 
			
		||||
	. "github.com/franela/goblin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestPower(t *testing.T) {
 | 
			
		||||
	g := Goblin(t)
 | 
			
		||||
 | 
			
		||||
	g.Describe("Locker", func() {
 | 
			
		||||
		var l *Locker
 | 
			
		||||
		g.BeforeEach(func() {
 | 
			
		||||
			l = NewLocker()
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#IsLocked", func() {
 | 
			
		||||
			g.It("should return false when the channel is empty", func() {
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should return true when the channel is at capacity", func() {
 | 
			
		||||
				l.ch <- true
 | 
			
		||||
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
				<-l.ch
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
 | 
			
		||||
				// We don't care what the channel value is, just that there is
 | 
			
		||||
				// something in it.
 | 
			
		||||
				l.ch <- false
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Acquire", func() {
 | 
			
		||||
			g.It("should acquire a lock when channel is empty", func() {
 | 
			
		||||
				err := l.Acquire()
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should return an error when the channel is full", func() {
 | 
			
		||||
				l.ch <- true
 | 
			
		||||
 | 
			
		||||
				err := l.Acquire()
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNotNil()
 | 
			
		||||
				g.Assert(errors.Is(err, ErrLockerLocked)).IsTrue()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(1)
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#TryAcquire", func() {
 | 
			
		||||
			g.It("should acquire a lock when channel is empty", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
 | 
			
		||||
				err := l.TryAcquire(context.Background())
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should block until context is canceled if channel is full", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
 | 
			
		||||
				defer cancel()
 | 
			
		||||
 | 
			
		||||
				l.ch <- true
 | 
			
		||||
				err := l.TryAcquire(ctx)
 | 
			
		||||
 | 
			
		||||
				g.Assert(err).IsNotNil()
 | 
			
		||||
				g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should block until lock can be acquired", func() {
 | 
			
		||||
				g.Timeout(time.Second)
 | 
			
		||||
 | 
			
		||||
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
 | 
			
		||||
				defer cancel()
 | 
			
		||||
 | 
			
		||||
				l.Acquire()
 | 
			
		||||
				go func() {
 | 
			
		||||
					time.AfterFunc(time.Millisecond * 50, func() {
 | 
			
		||||
						l.Release()
 | 
			
		||||
					})
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				err := l.TryAcquire(ctx)
 | 
			
		||||
				g.Assert(err).IsNil()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Release", func() {
 | 
			
		||||
			g.It("should release when channel is full", func() {
 | 
			
		||||
				l.Acquire()
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
				l.Release()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(0)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
 | 
			
		||||
			g.It("should release when channel is empty", func() {
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
				l.Release()
 | 
			
		||||
				g.Assert(cap(l.ch)).Equal(1)
 | 
			
		||||
				g.Assert(len(l.ch)).Equal(0)
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.Describe("PowerLocker#Destroy", func() {
 | 
			
		||||
			g.It("should unlock and close the channel", func() {
 | 
			
		||||
				l.Acquire()
 | 
			
		||||
				g.Assert(l.IsLocked()).IsTrue()
 | 
			
		||||
				l.Destroy()
 | 
			
		||||
				g.Assert(l.IsLocked()).IsFalse()
 | 
			
		||||
 | 
			
		||||
				defer func() {
 | 
			
		||||
					r := recover()
 | 
			
		||||
 | 
			
		||||
					g.Assert(r).IsNotNil()
 | 
			
		||||
					g.Assert(r.(error).Error()).Equal("send on closed channel")
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				l.Acquire()
 | 
			
		||||
			})
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								system/rate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								system/rate.go
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,50 @@
 | 
			
		|||
package system
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Rate defines a rate limiter of n items (limit) per duration of time.
 | 
			
		||||
type Rate struct {
 | 
			
		||||
	mu       sync.Mutex
 | 
			
		||||
	limit    uint64
 | 
			
		||||
	duration time.Duration
 | 
			
		||||
	count    uint64
 | 
			
		||||
	last     time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRate(limit uint64, duration time.Duration) *Rate {
 | 
			
		||||
	return &Rate{
 | 
			
		||||
		limit:    limit,
 | 
			
		||||
		duration: duration,
 | 
			
		||||
		last:     time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Try returns true if under the rate limit defined, or false if the rate limit
 | 
			
		||||
// has been exceeded for the current duration.
 | 
			
		||||
func (r *Rate) Try() bool {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
	defer r.mu.Unlock()
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	// If it has been more than the duration, reset the timer and count.
 | 
			
		||||
	if now.Sub(r.last) > r.duration {
 | 
			
		||||
		r.count = 0
 | 
			
		||||
		r.last = now
 | 
			
		||||
	}
 | 
			
		||||
	if (r.count + 1) > r.limit {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	// Hit this once, and return.
 | 
			
		||||
	r.count = r.count + 1
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reset resets the internal state of the rate limiter back to zero.
 | 
			
		||||
func (r *Rate) Reset() {
 | 
			
		||||
	r.mu.Lock()
 | 
			
		||||
    r.count = 0
 | 
			
		||||
    r.last = time.Now()
 | 
			
		||||
	r.mu.Unlock()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										67
									
								
								system/rate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								system/rate_test.go
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,67 @@
 | 
			
		|||
package system
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	. "github.com/franela/goblin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestRate(t *testing.T) {
 | 
			
		||||
	g := Goblin(t)
 | 
			
		||||
 | 
			
		||||
	g.Describe("Rate", func() {
 | 
			
		||||
		g.It("properly rate limits a bucket", func() {
 | 
			
		||||
			r := NewRate(10, time.Millisecond*100)
 | 
			
		||||
 | 
			
		||||
			for i := 0; i < 100; i++ {
 | 
			
		||||
				ok := r.Try()
 | 
			
		||||
				if i < 10 && !ok {
 | 
			
		||||
					g.Failf("should not have allowed take on try %d", i)
 | 
			
		||||
				} else if i >= 10 && ok {
 | 
			
		||||
					g.Failf("should have blocked take on try %d", i)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.It("handles rate limiting in chunks", func() {
 | 
			
		||||
			var out []int
 | 
			
		||||
			r := NewRate(12, time.Millisecond*10)
 | 
			
		||||
 | 
			
		||||
			for i := 0; i < 100; i++ {
 | 
			
		||||
				if i%20 == 0 {
 | 
			
		||||
					// Give it time to recover.
 | 
			
		||||
					time.Sleep(time.Millisecond * 10)
 | 
			
		||||
				}
 | 
			
		||||
				if r.Try() {
 | 
			
		||||
					out = append(out, i)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			g.Assert(len(out)).Equal(60)
 | 
			
		||||
			g.Assert(out[0]).Equal(0)
 | 
			
		||||
			g.Assert(out[12]).Equal(20)
 | 
			
		||||
			g.Assert(out[len(out)-1]).Equal(91)
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		g.It("resets back to zero when called", func() {
 | 
			
		||||
			r := NewRate(10, time.Second)
 | 
			
		||||
			for i := 0; i < 100; i++ {
 | 
			
		||||
				if i % 10 == 0 {
 | 
			
		||||
					r.Reset()
 | 
			
		||||
				}
 | 
			
		||||
				g.Assert(r.Try()).IsTrue()
 | 
			
		||||
			}
 | 
			
		||||
			g.Assert(r.Try()).IsFalse("final attempt should not allow taking")
 | 
			
		||||
		})
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkRate_Try(b *testing.B) {
 | 
			
		||||
	r := NewRate(10, time.Millisecond*100)
 | 
			
		||||
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		r.Try()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user