diff --git a/config/config.go b/config/config.go index 052abdb..f6314db 100644 --- a/config/config.go +++ b/config/config.go @@ -222,26 +222,14 @@ type ConsoleThrottles struct { // Whether or not the throttler is enabled for this instance. Enabled bool `json:"enabled" yaml:"enabled" default:"true"` - // The total number of lines that can be output in a given LineResetInterval period before + // The total number of lines that can be output in a given Period period before // a warning is triggered and counted against the server. Lines uint64 `json:"lines" yaml:"lines" default:"2000"` - // The total number of throttle activations that can accumulate before a server is considered - // to be breaching and will be stopped. This value is decremented by one every DecayInterval. - MaximumTriggerCount uint64 `json:"maximum_trigger_count" yaml:"maximum_trigger_count" default:"5"` - // The amount of time after which the number of lines processed is reset to 0. This runs in // a constant loop and is not affected by the current console output volumes. By default, this // will reset the processed line count back to 0 every 100ms. - LineResetInterval uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"` - - // The amount of time in milliseconds that must pass without an output warning being triggered - // before a throttle activation is decremented. - DecayInterval uint64 `json:"decay_interval" yaml:"decay_interval" default:"10000"` - - // The amount of time that a server is allowed to be stopping for before it is terminated - // forcefully if it triggers output throttles. - StopGracePeriod uint `json:"stop_grace_period" yaml:"stop_grace_period" default:"15"` + Period uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"` } type Configuration struct { diff --git a/server/console.go b/server/console.go index d5babaa..ee28986 100644 --- a/server/console.go +++ b/server/console.go @@ -1,15 +1,11 @@ package server import ( - "context" "fmt" "sync" - "sync/atomic" "time" - "emperror.dev/errors" "github.com/mitchellh/colorstring" - "github.com/pterodactyl/wings/config" "github.com/pterodactyl/wings/system" ) @@ -18,118 +14,8 @@ import ( // the configuration every time we need to send output along to the websocket for // a server. var appName string - var appNameSync sync.Once -var ErrTooMuchConsoleData = errors.New("console is outputting too much data") - -type ConsoleThrottler struct { - mu sync.Mutex - config.ConsoleThrottles - - // The total number of activations that have occurred thus far. - activations uint64 - - // The total number of lines that have been sent since the last reset timer period. - count uint64 - - // Wether or not the console output is being throttled. It is up to calling code to - // determine what to do if it is. - isThrottled *system.AtomicBool - - // The total number of lines processed so far during the given time period. - timerCancel *context.CancelFunc -} - -// Resets the state of the throttler. -func (ct *ConsoleThrottler) Reset() { - atomic.StoreUint64(&ct.count, 0) - atomic.StoreUint64(&ct.activations, 0) - ct.isThrottled.Store(false) -} - -// Triggers an activation for a server. You can also decrement the number of activations -// by passing a negative number. -func (ct *ConsoleThrottler) markActivation(increment bool) uint64 { - if !increment { - if atomic.LoadUint64(&ct.activations) == 0 { - return 0 - } - - // This weird dohicky subtracts 1 from the activation count. - return atomic.AddUint64(&ct.activations, ^uint64(0)) - } - - return atomic.AddUint64(&ct.activations, 1) -} - -// Determines if the console is currently being throttled. Calls to this function can be used to -// determine if output should be funneled along to the websocket processes. -func (ct *ConsoleThrottler) Throttled() bool { - return ct.isThrottled.Load() -} - -// Starts a timer that runs in a seperate thread and will continually decrement the lines processed -// and number of activations, regardless of the current console message volume. All of the timers -// are canceled if the context passed through is canceled. -func (ct *ConsoleThrottler) StartTimer(ctx context.Context) { - system.Every(ctx, time.Duration(int64(ct.LineResetInterval))*time.Millisecond, func(_ time.Time) { - ct.isThrottled.Store(false) - atomic.StoreUint64(&ct.count, 0) - }) - - system.Every(ctx, time.Duration(int64(ct.DecayInterval))*time.Millisecond, func(_ time.Time) { - ct.markActivation(false) - }) -} - -// Handles output from a server's console. This code ensures that a server is not outputting -// an excessive amount of data to the console that could indicate a malicious or run-away process -// and lead to performance issues for other users. -// -// This was much more of a problem for the NodeJS version of the daemon which struggled to handle -// large volumes of output. However, this code is much more performant so I generally feel a lot -// better about it's abilities. -// -// However, extreme output is still somewhat of a DoS attack vector against this software since we -// are still logging it to the disk temporarily and will want to avoid dumping a huge amount of -// data all at once. These values are all configurable via the wings configuration file, however the -// defaults have been in the wild for almost two years at the time of this writing, so I feel quite -// confident in them. -// -// This function returns an error if the server should be stopped due to violating throttle constraints -// and a boolean value indicating if a throttle is being violated when it is checked. -func (ct *ConsoleThrottler) Increment(onTrigger func()) error { - if !ct.Enabled { - return nil - } - - // Increment the line count and if we have now output more lines than are allowed, trigger a throttle - // activation. Once the throttle is triggered and has passed the kill at value we will trigger a server - // stop automatically. - if atomic.AddUint64(&ct.count, 1) >= ct.Lines && !ct.Throttled() { - ct.isThrottled.Store(true) - if ct.markActivation(true) >= ct.MaximumTriggerCount { - return ErrTooMuchConsoleData - } - - onTrigger() - } - - return nil -} - -// Returns the throttler instance for the server or creates a new one. -func (s *Server) Throttler() *ConsoleThrottler { - s.throttleOnce.Do(func() { - s.throttler = &ConsoleThrottler{ - isThrottled: system.NewAtomicBool(false), - ConsoleThrottles: config.Get().Throttles, - } - }) - return s.throttler -} - // PublishConsoleOutputFromDaemon sends output to the server console formatted // to appear correctly as being sent from Wings. func (s *Server) PublishConsoleOutputFromDaemon(data string) { @@ -141,3 +27,55 @@ func (s *Server) PublishConsoleOutputFromDaemon(data string) { colorstring.Color(fmt.Sprintf("[yellow][bold][%s Daemon]:[default] %s", appName, data)), ) } + +// Throttler returns the throttler instance for the server or creates a new one. +func (s *Server) Throttler() *ConsoleThrottle { + s.throttleOnce.Do(func() { + throttles := config.Get().Throttles + period := time.Duration(throttles.Period) * time.Millisecond + + s.throttler = newConsoleThrottle(throttles.Lines, period) + s.throttler.strike = func() { + s.PublishConsoleOutputFromDaemon(fmt.Sprintf("Server is outputting console data too quickly -- throttling...")) + } + }) + return s.throttler +} + +type ConsoleThrottle struct { + limit *system.Rate + lock *system.Locker + strike func() +} + +func newConsoleThrottle(lines uint64, period time.Duration) *ConsoleThrottle { + return &ConsoleThrottle{ + limit: system.NewRate(lines, period), + lock: system.NewLocker(), + } +} + +// Allow checks if the console is allowed to process more output data, or if too +// much has already been sent over the line. If there is too much output the +// strike callback function is triggered, but only if it has not already been +// triggered at this point in the process. +// +// If output is allowed, the lock on the throttler is released and the next time +// it is triggered the strike function will be re-executed. +func (ct *ConsoleThrottle) Allow() bool { + if !ct.limit.Try() { + if err := ct.lock.Acquire(); err == nil { + if ct.strike != nil { + ct.strike() + } + } + return false + } + ct.lock.Release() + return true +} + +// Reset resets the console throttler internal rate limiter and overage counter. +func (ct *ConsoleThrottle) Reset() { + ct.limit.Reset() +} diff --git a/server/console_test.go b/server/console_test.go new file mode 100644 index 0000000..b9ca986 --- /dev/null +++ b/server/console_test.go @@ -0,0 +1,62 @@ +package server + +import ( + "testing" + "time" + + "github.com/franela/goblin" +) + +func TestName(t *testing.T) { + g := goblin.Goblin(t) + + g.Describe("ConsoleThrottler", func() { + g.It("keeps count of the number of overages in a time period", func() { + t := newConsoleThrottle(1, time.Second) + g.Assert(t.Allow()).IsTrue() + g.Assert(t.Allow()).IsFalse() + g.Assert(t.Allow()).IsFalse() + }) + + g.It("calls strike once per time period", func() { + t := newConsoleThrottle(1, time.Millisecond * 20) + + var times int + t.strike = func() { + times = times + 1 + } + + t.Allow() + t.Allow() + t.Allow() + time.Sleep(time.Millisecond * 100) + t.Allow() + t.Reset() + t.Allow() + t.Allow() + t.Allow() + + g.Assert(times).Equal(2) + }) + + g.It("is properly reset", func() { + t := newConsoleThrottle(10, time.Second) + + for i := 0; i < 10; i++ { + g.Assert(t.Allow()).IsTrue() + } + g.Assert(t.Allow()).IsFalse() + t.Reset() + g.Assert(t.Allow()).IsTrue() + }) + }) +} + +func BenchmarkConsoleThrottle(b *testing.B) { + t := newConsoleThrottle(10, time.Millisecond * 10) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + t.Allow() + } +} \ No newline at end of file diff --git a/server/listeners.go b/server/listeners.go index e2540a0..5d0bfb7 100644 --- a/server/listeners.go +++ b/server/listeners.go @@ -8,7 +8,6 @@ import ( "github.com/apex/log" - "github.com/pterodactyl/wings/config" "github.com/pterodactyl/wings/environment" "github.com/pterodactyl/wings/events" "github.com/pterodactyl/wings/remote" @@ -57,45 +56,23 @@ func (dsl *diskSpaceLimiter) Trigger() { // output lines to determine if the server is started yet, and if the output is // not being throttled, will send the data over to the websocket. func (s *Server) processConsoleOutputEvent(v []byte) { - t := s.Throttler() - err := t.Increment(func() { - s.PublishConsoleOutputFromDaemon("Your server is outputting too much data and is being throttled.") - }) - // An error is only returned if the server has breached the thresholds set. - if err != nil { - // If the process is already stopping, just let it continue with that action rather than attempting - // to terminate again. - if s.Environment.State() != environment.ProcessStoppingState { - s.Environment.SetState(environment.ProcessStoppingState) - - go func() { - s.Log().Warn("stopping server instance, violating throttle limits") - s.PublishConsoleOutputFromDaemon("Your server is being stopped for outputting too much data in a short period of time.") - - // Completely skip over server power actions and terminate the running instance. This gives the - // server 15 seconds to finish stopping gracefully before it is forcefully terminated. - if err := s.Environment.WaitForStop(config.Get().Throttles.StopGracePeriod, true); err != nil { - // If there is an error set the process back to running so that this throttler is called - // again and hopefully kills the server. - if s.Environment.State() != environment.ProcessOfflineState { - s.Environment.SetState(environment.ProcessRunningState) - } - - s.Log().WithField("error", err).Error("failed to terminate environment after triggering throttle") - } - }() - } - } - // Always process the console output, but do this in a seperate thread since we // don't really care about side-effects from this call, and don't want it to block // the console sending logic. go s.onConsoleOutput(v) - // If we are not throttled, go ahead and output the data. - if !t.Throttled() { - s.Sink(LogSink).Push(v) + // If the console is being throttled, do nothing else with it, we don't want + // to waste time. This code previously terminated server instances after violating + // different throttle limits. That code was clunky and difficult to reason about, + // in addition to being a consistent pain point for users. + // + // In the interest of building highly efficient software, that code has been removed + // here, and we'll rely on the host to detect bad actors through their own means. + if !s.Throttler().Allow() { + return } + + s.Sink(LogSink).Push(v) } // StartEventListeners adds all the internal event listeners we want to use for diff --git a/server/manager.go b/server/manager.go index 815b362..8b3ed2d 100644 --- a/server/manager.go +++ b/server/manager.go @@ -199,7 +199,6 @@ func (m *Manager) InitServer(data remote.ServerConfigurationResponse) (*Server, } else { s.Environment = env s.StartEventListeners() - s.Throttler().StartTimer(s.Context()) } // If the server's data directory exists, force disk usage calculation. diff --git a/server/power.go b/server/power.go index 7f34506..c7c6003 100644 --- a/server/power.go +++ b/server/power.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "sync" "time" "emperror.dev/errors" @@ -41,85 +40,6 @@ func (pa PowerAction) IsStart() bool { return pa == PowerActionStart || pa == PowerActionRestart } -type powerLocker struct { - mu sync.RWMutex - ch chan bool -} - -func newPowerLocker() *powerLocker { - return &powerLocker{ - ch: make(chan bool, 1), - } -} - -type errPowerLockerLocked struct{} - -func (e errPowerLockerLocked) Error() string { - return "cannot acquire a lock on the power state: already locked" -} - -var ErrPowerLockerLocked error = errPowerLockerLocked{} - -// IsLocked returns the current state of the locker channel. If there is -// currently a value in the channel, it is assumed to be locked. -func (pl *powerLocker) IsLocked() bool { - pl.mu.RLock() - defer pl.mu.RUnlock() - return len(pl.ch) == 1 -} - -// Acquire will acquire the power lock if it is not currently locked. If it is -// already locked, acquire will fail to acquire the lock, and will return false. -func (pl *powerLocker) Acquire() error { - pl.mu.Lock() - defer pl.mu.Unlock() - select { - case pl.ch <- true: - default: - return errors.WithStack(ErrPowerLockerLocked) - } - return nil -} - -// TryAcquire will attempt to acquire a power-lock until the context provided -// is canceled. -func (pl *powerLocker) TryAcquire(ctx context.Context) error { - select { - case pl.ch <- true: - return nil - case <-ctx.Done(): - if err := ctx.Err(); err != nil { - return errors.WithStack(err) - } - return nil - } -} - -// Release will drain the locker channel so that we can properly re-acquire it -// at a later time. If the channel is not currently locked this function is a -// no-op and will immediately return. -func (pl *powerLocker) Release() { - pl.mu.Lock() - select { - case <-pl.ch: - default: - } - pl.mu.Unlock() -} - -// Destroy cleans up the power locker by closing the channel. -func (pl *powerLocker) Destroy() { - pl.mu.Lock() - if pl.ch != nil { - select { - case <-pl.ch: - default: - } - close(pl.ch) - } - pl.mu.Unlock() -} - // ExecutingPowerAction checks if there is currently a power action being // processed for the server. func (s *Server) ExecutingPowerAction() bool { diff --git a/server/power_test.go b/server/power_test.go index 02446e7..b7aa29a 100644 --- a/server/power_test.go +++ b/server/power_test.go @@ -1,154 +1,18 @@ package server import ( - "context" "testing" - "time" - "emperror.dev/errors" . "github.com/franela/goblin" + "github.com/pterodactyl/wings/system" ) func TestPower(t *testing.T) { g := Goblin(t) - g.Describe("PowerLocker", func() { - var pl *powerLocker - g.BeforeEach(func() { - pl = newPowerLocker() - }) - - g.Describe("PowerLocker#IsLocked", func() { - g.It("should return false when the channel is empty", func() { - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(pl.IsLocked()).IsFalse() - }) - - g.It("should return true when the channel is at capacity", func() { - pl.ch <- true - - g.Assert(pl.IsLocked()).IsTrue() - <-pl.ch - g.Assert(pl.IsLocked()).IsFalse() - - // We don't care what the channel value is, just that there is - // something in it. - pl.ch <- false - g.Assert(pl.IsLocked()).IsTrue() - g.Assert(cap(pl.ch)).Equal(1) - }) - }) - - g.Describe("PowerLocker#Acquire", func() { - g.It("should acquire a lock when channel is empty", func() { - err := pl.Acquire() - - g.Assert(err).IsNil() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(1) - }) - - g.It("should return an error when the channel is full", func() { - pl.ch <- true - - err := pl.Acquire() - - g.Assert(err).IsNotNil() - g.Assert(errors.Is(err, ErrPowerLockerLocked)).IsTrue() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(1) - }) - }) - - g.Describe("PowerLocker#TryAcquire", func() { - g.It("should acquire a lock when channel is empty", func() { - g.Timeout(time.Second) - - err := pl.TryAcquire(context.Background()) - - g.Assert(err).IsNil() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(1) - g.Assert(pl.IsLocked()).IsTrue() - }) - - g.It("should block until context is canceled if channel is full", func() { - g.Timeout(time.Second) - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) - defer cancel() - - pl.ch <- true - err := pl.TryAcquire(ctx) - - g.Assert(err).IsNotNil() - g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(1) - g.Assert(pl.IsLocked()).IsTrue() - }) - - g.It("should block until lock can be acquired", func() { - g.Timeout(time.Second) - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - - pl.Acquire() - go func() { - time.AfterFunc(time.Millisecond * 50, func() { - pl.Release() - }) - }() - - err := pl.TryAcquire(ctx) - g.Assert(err).IsNil() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(1) - g.Assert(pl.IsLocked()).IsTrue() - }) - }) - - g.Describe("PowerLocker#Release", func() { - g.It("should release when channel is full", func() { - pl.Acquire() - g.Assert(pl.IsLocked()).IsTrue() - pl.Release() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(0) - g.Assert(pl.IsLocked()).IsFalse() - }) - - g.It("should release when channel is empty", func() { - g.Assert(pl.IsLocked()).IsFalse() - pl.Release() - g.Assert(cap(pl.ch)).Equal(1) - g.Assert(len(pl.ch)).Equal(0) - g.Assert(pl.IsLocked()).IsFalse() - }) - }) - - g.Describe("PowerLocker#Destroy", func() { - g.It("should unlock and close the channel", func() { - pl.Acquire() - g.Assert(pl.IsLocked()).IsTrue() - pl.Destroy() - g.Assert(pl.IsLocked()).IsFalse() - - defer func() { - r := recover() - - g.Assert(r).IsNotNil() - g.Assert(r.(error).Error()).Equal("send on closed channel") - }() - - pl.Acquire() - }) - }) - }) - g.Describe("Server#ExecutingPowerAction", func() { g.It("should return based on locker status", func() { - s := &Server{powerLock: newPowerLocker()} + s := &Server{powerLock: system.NewLocker()} g.Assert(s.ExecutingPowerAction()).IsFalse() s.powerLock.Acquire() diff --git a/server/server.go b/server/server.go index af60daa..7027431 100644 --- a/server/server.go +++ b/server/server.go @@ -30,9 +30,8 @@ type Server struct { ctx context.Context ctxCancel *context.CancelFunc - emitterLock sync.Mutex - powerLock *powerLocker - throttleOnce sync.Once + emitterLock sync.Mutex + powerLock *system.Locker // Maintains the configuration for the server. This is the data that gets returned by the Panel // such as build settings and container images. @@ -64,7 +63,8 @@ type Server struct { restoring *system.AtomicBool // The console throttler instance used to control outputs. - throttler *ConsoleThrottler + throttler *ConsoleThrottle + throttleOnce sync.Once // Tracks open websocket connections for the server. wsBag *WebsocketBag @@ -87,7 +87,7 @@ func New(client remote.Client) (*Server, error) { installing: system.NewAtomicBool(false), transferring: system.NewAtomicBool(false), restoring: system.NewAtomicBool(false), - powerLock: newPowerLocker(), + powerLock: system.NewLocker(), sinks: map[SinkName]*sinkPool{ LogSink: newSinkPool(), InstallSink: newSinkPool(), diff --git a/system/locker.go b/system/locker.go new file mode 100644 index 0000000..eab2a15 --- /dev/null +++ b/system/locker.go @@ -0,0 +1,83 @@ +package system + +import ( + "context" + "sync" + + "emperror.dev/errors" +) + +var ErrLockerLocked = errors.Sentinel("locker: cannot acquire lock, already locked") + +type Locker struct { + mu sync.RWMutex + ch chan bool +} + +// NewLocker returns a new Locker instance. +func NewLocker() *Locker { + return &Locker{ + ch: make(chan bool, 1), + } +} + +// IsLocked returns the current state of the locker channel. If there is +// currently a value in the channel, it is assumed to be locked. +func (l *Locker) IsLocked() bool { + l.mu.RLock() + defer l.mu.RUnlock() + return len(l.ch) == 1 +} + +// Acquire will acquire the power lock if it is not currently locked. If it is +// already locked, acquire will fail to acquire the lock, and will return false. +func (l *Locker) Acquire() error { + l.mu.Lock() + defer l.mu.Unlock() + select { + case l.ch <- true: + default: + return ErrLockerLocked + } + return nil +} + + +// TryAcquire will attempt to acquire a power-lock until the context provided +// is canceled. +func (l *Locker) TryAcquire(ctx context.Context) error { + select { + case l.ch <- true: + return nil + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return err + } + return nil + } +} + +// Release will drain the locker channel so that we can properly re-acquire it +// at a later time. If the channel is not currently locked this function is a +// no-op and will immediately return. +func (l *Locker) Release() { + l.mu.Lock() + select { + case <-l.ch: + default: + } + l.mu.Unlock() +} + +// Destroy cleans up the power locker by closing the channel. +func (l *Locker) Destroy() { + l.mu.Lock() + if l.ch != nil { + select { + case <-l.ch: + default: + } + close(l.ch) + } + l.mu.Unlock() +} diff --git a/system/locker_test.go b/system/locker_test.go new file mode 100644 index 0000000..72789f4 --- /dev/null +++ b/system/locker_test.go @@ -0,0 +1,148 @@ +package system + +import ( + "context" + "testing" + "time" + + "emperror.dev/errors" + . "github.com/franela/goblin" +) + +func TestPower(t *testing.T) { + g := Goblin(t) + + g.Describe("Locker", func() { + var l *Locker + g.BeforeEach(func() { + l = NewLocker() + }) + + g.Describe("PowerLocker#IsLocked", func() { + g.It("should return false when the channel is empty", func() { + g.Assert(cap(l.ch)).Equal(1) + g.Assert(l.IsLocked()).IsFalse() + }) + + g.It("should return true when the channel is at capacity", func() { + l.ch <- true + + g.Assert(l.IsLocked()).IsTrue() + <-l.ch + g.Assert(l.IsLocked()).IsFalse() + + // We don't care what the channel value is, just that there is + // something in it. + l.ch <- false + g.Assert(l.IsLocked()).IsTrue() + g.Assert(cap(l.ch)).Equal(1) + }) + }) + + g.Describe("PowerLocker#Acquire", func() { + g.It("should acquire a lock when channel is empty", func() { + err := l.Acquire() + + g.Assert(err).IsNil() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(1) + }) + + g.It("should return an error when the channel is full", func() { + l.ch <- true + + err := l.Acquire() + + g.Assert(err).IsNotNil() + g.Assert(errors.Is(err, ErrLockerLocked)).IsTrue() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(1) + }) + }) + + g.Describe("PowerLocker#TryAcquire", func() { + g.It("should acquire a lock when channel is empty", func() { + g.Timeout(time.Second) + + err := l.TryAcquire(context.Background()) + + g.Assert(err).IsNil() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(1) + g.Assert(l.IsLocked()).IsTrue() + }) + + g.It("should block until context is canceled if channel is full", func() { + g.Timeout(time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500) + defer cancel() + + l.ch <- true + err := l.TryAcquire(ctx) + + g.Assert(err).IsNotNil() + g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(1) + g.Assert(l.IsLocked()).IsTrue() + }) + + g.It("should block until lock can be acquired", func() { + g.Timeout(time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + l.Acquire() + go func() { + time.AfterFunc(time.Millisecond * 50, func() { + l.Release() + }) + }() + + err := l.TryAcquire(ctx) + g.Assert(err).IsNil() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(1) + g.Assert(l.IsLocked()).IsTrue() + }) + }) + + g.Describe("PowerLocker#Release", func() { + g.It("should release when channel is full", func() { + l.Acquire() + g.Assert(l.IsLocked()).IsTrue() + l.Release() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(0) + g.Assert(l.IsLocked()).IsFalse() + }) + + g.It("should release when channel is empty", func() { + g.Assert(l.IsLocked()).IsFalse() + l.Release() + g.Assert(cap(l.ch)).Equal(1) + g.Assert(len(l.ch)).Equal(0) + g.Assert(l.IsLocked()).IsFalse() + }) + }) + + g.Describe("PowerLocker#Destroy", func() { + g.It("should unlock and close the channel", func() { + l.Acquire() + g.Assert(l.IsLocked()).IsTrue() + l.Destroy() + g.Assert(l.IsLocked()).IsFalse() + + defer func() { + r := recover() + + g.Assert(r).IsNotNil() + g.Assert(r.(error).Error()).Equal("send on closed channel") + }() + + l.Acquire() + }) + }) + }) +} diff --git a/system/rate.go b/system/rate.go new file mode 100644 index 0000000..270e90b --- /dev/null +++ b/system/rate.go @@ -0,0 +1,50 @@ +package system + +import ( + "sync" + "time" +) + +// Rate defines a rate limiter of n items (limit) per duration of time. +type Rate struct { + mu sync.Mutex + limit uint64 + duration time.Duration + count uint64 + last time.Time +} + +func NewRate(limit uint64, duration time.Duration) *Rate { + return &Rate{ + limit: limit, + duration: duration, + last: time.Now(), + } +} + +// Try returns true if under the rate limit defined, or false if the rate limit +// has been exceeded for the current duration. +func (r *Rate) Try() bool { + r.mu.Lock() + defer r.mu.Unlock() + now := time.Now() + // If it has been more than the duration, reset the timer and count. + if now.Sub(r.last) > r.duration { + r.count = 0 + r.last = now + } + if (r.count + 1) > r.limit { + return false + } + // Hit this once, and return. + r.count = r.count + 1 + return true +} + +// Reset resets the internal state of the rate limiter back to zero. +func (r *Rate) Reset() { + r.mu.Lock() + r.count = 0 + r.last = time.Now() + r.mu.Unlock() +} \ No newline at end of file diff --git a/system/rate_test.go b/system/rate_test.go new file mode 100644 index 0000000..3271723 --- /dev/null +++ b/system/rate_test.go @@ -0,0 +1,67 @@ +package system + +import ( + "testing" + "time" + + . "github.com/franela/goblin" +) + +func TestRate(t *testing.T) { + g := Goblin(t) + + g.Describe("Rate", func() { + g.It("properly rate limits a bucket", func() { + r := NewRate(10, time.Millisecond*100) + + for i := 0; i < 100; i++ { + ok := r.Try() + if i < 10 && !ok { + g.Failf("should not have allowed take on try %d", i) + } else if i >= 10 && ok { + g.Failf("should have blocked take on try %d", i) + } + } + }) + + g.It("handles rate limiting in chunks", func() { + var out []int + r := NewRate(12, time.Millisecond*10) + + for i := 0; i < 100; i++ { + if i%20 == 0 { + // Give it time to recover. + time.Sleep(time.Millisecond * 10) + } + if r.Try() { + out = append(out, i) + } + } + + g.Assert(len(out)).Equal(60) + g.Assert(out[0]).Equal(0) + g.Assert(out[12]).Equal(20) + g.Assert(out[len(out)-1]).Equal(91) + }) + + g.It("resets back to zero when called", func() { + r := NewRate(10, time.Second) + for i := 0; i < 100; i++ { + if i % 10 == 0 { + r.Reset() + } + g.Assert(r.Try()).IsTrue() + } + g.Assert(r.Try()).IsFalse("final attempt should not allow taking") + }) + }) +} + +func BenchmarkRate_Try(b *testing.B) { + r := NewRate(10, time.Millisecond*100) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + r.Try() + } +}