Rewrite console throttling logic; drop complex timer usage and use a very simple throttle

This also removes the server process termination logic that previously kicked in when a server breached the output limits; Wings now simply continues to throttle the console output efficiently.
Dane Everitt 2022-01-30 19:31:04 -05:00
parent fb73d5dbbf
commit 2b2b5200eb
12 changed files with 482 additions and 386 deletions
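The entire behavioral change reduces to a single gate in the console output path (excerpted from the event-listener diff below): while the throttle is active the line is dropped, and the server process is otherwise left alone.

    if !s.Throttler().Allow() {
        return
    }
    s.Sink(LogSink).Push(v)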

@@ -222,26 +222,14 @@ type ConsoleThrottles struct
    // Whether or not the throttler is enabled for this instance.
    Enabled bool `json:"enabled" yaml:"enabled" default:"true"`
-   // The total number of lines that can be output in a given LineResetInterval period before
+   // The total number of lines that can be output in a given Period before
    // a warning is triggered and counted against the server.
    Lines uint64 `json:"lines" yaml:"lines" default:"2000"`
-   // The total number of throttle activations that can accumulate before a server is considered
-   // to be breaching and will be stopped. This value is decremented by one every DecayInterval.
-   MaximumTriggerCount uint64 `json:"maximum_trigger_count" yaml:"maximum_trigger_count" default:"5"`
    // The amount of time after which the number of lines processed is reset to 0. This runs in
    // a constant loop and is not affected by the current console output volumes. By default, this
    // will reset the processed line count back to 0 every 100ms.
-   LineResetInterval uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"`
-   // The amount of time in milliseconds that must pass without an output warning being triggered
-   // before a throttle activation is decremented.
-   DecayInterval uint64 `json:"decay_interval" yaml:"decay_interval" default:"10000"`
-   // The amount of time that a server is allowed to be stopping for before it is terminated
-   // forcefully if it triggers output throttles.
-   StopGracePeriod uint `json:"stop_grace_period" yaml:"stop_grace_period" default:"15"`
+   Period uint64 `json:"line_reset_interval" yaml:"line_reset_interval" default:"100"`
}

type Configuration struct {
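For reference, the throttle configuration now reduces to three knobs. A sketch of the resulting values, using the defaults from the struct tags above (the yaml tags map these to enabled, lines, and line_reset_interval in the wings configuration file):

    throttles := config.ConsoleThrottles{
        Enabled: true,
        Lines:   2000, // lines allowed per window before the throttle engages
        Period:  100,  // window length in milliseconds
    }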

@@ -1,15 +1,11 @@
package server

import (
-   "context"
    "fmt"
    "sync"
-   "sync/atomic"
    "time"

-   "emperror.dev/errors"
    "github.com/mitchellh/colorstring"

    "github.com/pterodactyl/wings/config"
    "github.com/pterodactyl/wings/system"
)
@@ -18,118 +14,8 @@ import (
// the configuration every time we need to send output along to the websocket for
// a server.
var appName string
var appNameSync sync.Once

-var ErrTooMuchConsoleData = errors.New("console is outputting too much data")
-
-type ConsoleThrottler struct {
-   mu sync.Mutex
-   config.ConsoleThrottles
-
-   // The total number of activations that have occurred thus far.
-   activations uint64
-
-   // The total number of lines that have been sent since the last reset timer period.
-   count uint64
-
-   // Whether or not the console output is being throttled. It is up to calling code to
-   // determine what to do if it is.
-   isThrottled *system.AtomicBool
-
-   // Cancel function for the timers started for this throttler.
-   timerCancel *context.CancelFunc
-}
-
-// Resets the state of the throttler.
-func (ct *ConsoleThrottler) Reset() {
-   atomic.StoreUint64(&ct.count, 0)
-   atomic.StoreUint64(&ct.activations, 0)
-   ct.isThrottled.Store(false)
-}
-
-// Triggers an activation for a server. You can also decrement the number of activations
-// by passing false for the increment argument.
-func (ct *ConsoleThrottler) markActivation(increment bool) uint64 {
-   if !increment {
-       if atomic.LoadUint64(&ct.activations) == 0 {
-           return 0
-       }
-       // Adding ^uint64(0) (all bits set) wraps around and subtracts one
-       // from the activation count.
-       return atomic.AddUint64(&ct.activations, ^uint64(0))
-   }
-   return atomic.AddUint64(&ct.activations, 1)
-}
-
-// Determines if the console is currently being throttled. Calls to this function can be used to
-// determine if output should be funneled along to the websocket processes.
-func (ct *ConsoleThrottler) Throttled() bool {
-   return ct.isThrottled.Load()
-}
-
-// Starts a timer that runs in a separate thread and will continually decrement the lines processed
-// and number of activations, regardless of the current console message volume. All of the timers
-// are canceled if the context passed through is canceled.
-func (ct *ConsoleThrottler) StartTimer(ctx context.Context) {
-   system.Every(ctx, time.Duration(int64(ct.LineResetInterval))*time.Millisecond, func(_ time.Time) {
-       ct.isThrottled.Store(false)
-       atomic.StoreUint64(&ct.count, 0)
-   })
-   system.Every(ctx, time.Duration(int64(ct.DecayInterval))*time.Millisecond, func(_ time.Time) {
-       ct.markActivation(false)
-   })
-}
-
-// Handles output from a server's console. This code ensures that a server is not outputting
-// an excessive amount of data to the console that could indicate a malicious or run-away process
-// and lead to performance issues for other users.
-//
-// This was much more of a problem for the NodeJS version of the daemon which struggled to handle
-// large volumes of output. However, this code is much more performant so I generally feel a lot
-// better about its abilities.
-//
-// However, extreme output is still somewhat of a DoS attack vector against this software since we
-// are still logging it to the disk temporarily and will want to avoid dumping a huge amount of
-// data all at once. These values are all configurable via the wings configuration file, however the
-// defaults have been in the wild for almost two years at the time of this writing, so I feel quite
-// confident in them.
-//
-// This function returns an error if the server should be stopped for violating throttle
-// constraints, and calls the onTrigger callback whenever a throttle activation occurs.
-func (ct *ConsoleThrottler) Increment(onTrigger func()) error {
-   if !ct.Enabled {
-       return nil
-   }
-   // Increment the line count and if we have now output more lines than are allowed, trigger a
-   // throttle activation. Once the throttle has been activated more times than the maximum
-   // allowed, return an error so the server gets stopped automatically.
-   if atomic.AddUint64(&ct.count, 1) >= ct.Lines && !ct.Throttled() {
-       ct.isThrottled.Store(true)
-       if ct.markActivation(true) >= ct.MaximumTriggerCount {
-           return ErrTooMuchConsoleData
-       }
-       onTrigger()
-   }
-   return nil
-}
-
-// Returns the throttler instance for the server or creates a new one.
-func (s *Server) Throttler() *ConsoleThrottler {
-   s.throttleOnce.Do(func() {
-       s.throttler = &ConsoleThrottler{
-           isThrottled:      system.NewAtomicBool(false),
-           ConsoleThrottles: config.Get().Throttles,
-       }
-   })
-   return s.throttler
-}
-
// PublishConsoleOutputFromDaemon sends output to the server console formatted
// to appear correctly as being sent from Wings.
func (s *Server) PublishConsoleOutputFromDaemon(data string) {
@@ -141,3 +27,55 @@ func (s *Server) PublishConsoleOutputFromDaemon(data string) {
        colorstring.Color(fmt.Sprintf("[yellow][bold][%s Daemon]:[default] %s", appName, data)),
    )
}
+
+// Throttler returns the throttler instance for the server or creates a new one.
+func (s *Server) Throttler() *ConsoleThrottle {
+   s.throttleOnce.Do(func() {
+       throttles := config.Get().Throttles
+       period := time.Duration(throttles.Period) * time.Millisecond
+       s.throttler = newConsoleThrottle(throttles.Lines, period)
+       s.throttler.strike = func() {
+           s.PublishConsoleOutputFromDaemon("Server is outputting console data too quickly -- throttling...")
+       }
+   })
+   return s.throttler
+}
+
+type ConsoleThrottle struct {
+   limit  *system.Rate
+   lock   *system.Locker
+   strike func()
+}
+
+func newConsoleThrottle(lines uint64, period time.Duration) *ConsoleThrottle {
+   return &ConsoleThrottle{
+       limit: system.NewRate(lines, period),
+       lock:  system.NewLocker(),
+   }
+}
+
+// Allow checks if the console is allowed to process more output data, or if too
+// much has already been sent over the line. If there is too much output, the
+// strike callback function is triggered, but only if it has not already been
+// triggered at this point in the process.
+//
+// If output is allowed, the lock on the throttler is released and the next time
+// the throttle is exceeded the strike function will be executed again.
+func (ct *ConsoleThrottle) Allow() bool {
+   if !ct.limit.Try() {
+       if err := ct.lock.Acquire(); err == nil {
+           if ct.strike != nil {
+               ct.strike()
+           }
+       }
+       return false
+   }
+   ct.lock.Release()
+   return true
+}
+
+// Reset resets the console throttle's internal rate limiter and overage counter.
+func (ct *ConsoleThrottle) Reset() {
+   ct.limit.Reset()
+}
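Design note: the Locker gives the strike callback once-per-burst semantics. While output stays over the limit the lock remains held, so repeated Allow failures do not re-fire the strike; the first successful Allow releases the lock and re-arms the callback for the next burst. The test file below exercises exactly this behavior.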

server/console_test.go (new file, 62 lines)

@@ -0,0 +1,62 @@
package server

import (
    "testing"
    "time"

    "github.com/franela/goblin"
)

func TestConsoleThrottle(t *testing.T) {
    g := goblin.Goblin(t)

    g.Describe("ConsoleThrottle", func() {
        g.It("keeps count of the number of overages in a time period", func() {
            t := newConsoleThrottle(1, time.Second)
            g.Assert(t.Allow()).IsTrue()
            g.Assert(t.Allow()).IsFalse()
            g.Assert(t.Allow()).IsFalse()
        })

        g.It("calls strike once per time period", func() {
            t := newConsoleThrottle(1, time.Millisecond*20)
            var times int
            t.strike = func() {
                times = times + 1
            }
            t.Allow()
            t.Allow()
            t.Allow()
            time.Sleep(time.Millisecond * 100)
            t.Allow()
            t.Reset()
            t.Allow()
            t.Allow()
            t.Allow()
            g.Assert(times).Equal(2)
        })

        g.It("is properly reset", func() {
            t := newConsoleThrottle(10, time.Second)
            for i := 0; i < 10; i++ {
                g.Assert(t.Allow()).IsTrue()
            }
            g.Assert(t.Allow()).IsFalse()
            t.Reset()
            g.Assert(t.Allow()).IsTrue()
        })
    })
}

func BenchmarkConsoleThrottle(b *testing.B) {
    t := newConsoleThrottle(10, time.Millisecond*10)
    b.ReportAllocs()
    for i := 0; i < b.N; i++ {
        t.Allow()
    }
}

@@ -8,7 +8,6 @@ import (
    "github.com/apex/log"

-   "github.com/pterodactyl/wings/config"
    "github.com/pterodactyl/wings/environment"
    "github.com/pterodactyl/wings/events"
    "github.com/pterodactyl/wings/remote"

@@ -57,45 +56,23 @@ func (dsl *diskSpaceLimiter) Trigger() {
// output lines to determine if the server is started yet, and if the output is
// not being throttled, will send the data over to the websocket.
func (s *Server) processConsoleOutputEvent(v []byte) {
-   t := s.Throttler()
-   err := t.Increment(func() {
-       s.PublishConsoleOutputFromDaemon("Your server is outputting too much data and is being throttled.")
-   })
-   // An error is only returned if the server has breached the thresholds set.
-   if err != nil {
-       // If the process is already stopping, just let it continue with that action rather than attempting
-       // to terminate again.
-       if s.Environment.State() != environment.ProcessStoppingState {
-           s.Environment.SetState(environment.ProcessStoppingState)
-           go func() {
-               s.Log().Warn("stopping server instance, violating throttle limits")
-               s.PublishConsoleOutputFromDaemon("Your server is being stopped for outputting too much data in a short period of time.")
-               // Completely skip over server power actions and terminate the running instance. This gives the
-               // server 15 seconds to finish stopping gracefully before it is forcefully terminated.
-               if err := s.Environment.WaitForStop(config.Get().Throttles.StopGracePeriod, true); err != nil {
-                   // If there is an error set the process back to running so that this throttler is called
-                   // again and hopefully kills the server.
-                   if s.Environment.State() != environment.ProcessOfflineState {
-                       s.Environment.SetState(environment.ProcessRunningState)
-                   }
-                   s.Log().WithField("error", err).Error("failed to terminate environment after triggering throttle")
-               }
-           }()
-       }
-   }
    // Always process the console output, but do this in a separate thread since we
    // don't really care about side-effects from this call, and don't want it to block
    // the console sending logic.
    go s.onConsoleOutput(v)

-   // If we are not throttled, go ahead and output the data.
-   if !t.Throttled() {
-       s.Sink(LogSink).Push(v)
-   }
+   // If the console is being throttled, do nothing else with it, we don't want
+   // to waste time. This code previously terminated server instances after violating
+   // different throttle limits. That code was clunky and difficult to reason about,
+   // in addition to being a consistent pain point for users.
+   //
+   // In the interest of building highly efficient software, that code has been removed
+   // here, and we'll rely on the host to detect bad actors through their own means.
+   if !s.Throttler().Allow() {
+       return
+   }
+   s.Sink(LogSink).Push(v)
}
// StartEventListeners adds all the internal event listeners we want to use for // StartEventListeners adds all the internal event listeners we want to use for

@@ -199,7 +199,6 @@ func (m *Manager) InitServer(data remote.ServerConfigurationResponse) (*Server,
    } else {
        s.Environment = env
        s.StartEventListeners()
-       s.Throttler().StartTimer(s.Context())
    }

    // If the server's data directory exists, force disk usage calculation.

@@ -4,7 +4,6 @@ import (
    "context"
    "fmt"
    "os"
-   "sync"
    "time"

    "emperror.dev/errors"

@@ -41,85 +40,6 @@ func (pa PowerAction) IsStart() bool {
    return pa == PowerActionStart || pa == PowerActionRestart
}
-type powerLocker struct {
-   mu sync.RWMutex
-   ch chan bool
-}
-
-func newPowerLocker() *powerLocker {
-   return &powerLocker{
-       ch: make(chan bool, 1),
-   }
-}
-
-type errPowerLockerLocked struct{}
-
-func (e errPowerLockerLocked) Error() string {
-   return "cannot acquire a lock on the power state: already locked"
-}
-
-var ErrPowerLockerLocked error = errPowerLockerLocked{}
-
-// IsLocked returns the current state of the locker channel. If there is
-// currently a value in the channel, it is assumed to be locked.
-func (pl *powerLocker) IsLocked() bool {
-   pl.mu.RLock()
-   defer pl.mu.RUnlock()
-   return len(pl.ch) == 1
-}
-
-// Acquire will acquire the power lock if it is not currently locked. If it is
-// already locked, Acquire will fail and return an error.
-func (pl *powerLocker) Acquire() error {
-   pl.mu.Lock()
-   defer pl.mu.Unlock()
-   select {
-   case pl.ch <- true:
-   default:
-       return errors.WithStack(ErrPowerLockerLocked)
-   }
-   return nil
-}
-
-// TryAcquire will attempt to acquire a power-lock until the context provided
-// is canceled.
-func (pl *powerLocker) TryAcquire(ctx context.Context) error {
-   select {
-   case pl.ch <- true:
-       return nil
-   case <-ctx.Done():
-       if err := ctx.Err(); err != nil {
-           return errors.WithStack(err)
-       }
-       return nil
-   }
-}
-
-// Release will drain the locker channel so that we can properly re-acquire it
-// at a later time. If the channel is not currently locked this function is a
-// no-op and will immediately return.
-func (pl *powerLocker) Release() {
-   pl.mu.Lock()
-   select {
-   case <-pl.ch:
-   default:
-   }
-   pl.mu.Unlock()
-}
-
-// Destroy cleans up the power locker by closing the channel.
-func (pl *powerLocker) Destroy() {
-   pl.mu.Lock()
-   if pl.ch != nil {
-       select {
-       case <-pl.ch:
-       default:
-       }
-       close(pl.ch)
-   }
-   pl.mu.Unlock()
-}
-
// ExecutingPowerAction checks if there is currently a power action being
// processed for the server.
func (s *Server) ExecutingPowerAction() bool {

@@ -1,154 +1,18 @@
package server

import (
-   "context"
    "testing"
-   "time"

-   "emperror.dev/errors"
    . "github.com/franela/goblin"
+
+   "github.com/pterodactyl/wings/system"
)
func TestPower(t *testing.T) {
    g := Goblin(t)

-   g.Describe("PowerLocker", func() {
-       var pl *powerLocker
-       g.BeforeEach(func() {
-           pl = newPowerLocker()
-       })
-
-       g.Describe("PowerLocker#IsLocked", func() {
-           g.It("should return false when the channel is empty", func() {
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(pl.IsLocked()).IsFalse()
-           })
-
-           g.It("should return true when the channel is at capacity", func() {
-               pl.ch <- true
-               g.Assert(pl.IsLocked()).IsTrue()
-               <-pl.ch
-               g.Assert(pl.IsLocked()).IsFalse()
-               // We don't care what the channel value is, just that there is
-               // something in it.
-               pl.ch <- false
-               g.Assert(pl.IsLocked()).IsTrue()
-               g.Assert(cap(pl.ch)).Equal(1)
-           })
-       })
-
-       g.Describe("PowerLocker#Acquire", func() {
-           g.It("should acquire a lock when channel is empty", func() {
-               err := pl.Acquire()
-               g.Assert(err).IsNil()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(1)
-           })
-
-           g.It("should return an error when the channel is full", func() {
-               pl.ch <- true
-               err := pl.Acquire()
-               g.Assert(err).IsNotNil()
-               g.Assert(errors.Is(err, ErrPowerLockerLocked)).IsTrue()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(1)
-           })
-       })
-
-       g.Describe("PowerLocker#TryAcquire", func() {
-           g.It("should acquire a lock when channel is empty", func() {
-               g.Timeout(time.Second)
-               err := pl.TryAcquire(context.Background())
-               g.Assert(err).IsNil()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(1)
-               g.Assert(pl.IsLocked()).IsTrue()
-           })
-
-           g.It("should block until context is canceled if channel is full", func() {
-               g.Timeout(time.Second)
-               ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
-               defer cancel()
-               pl.ch <- true
-               err := pl.TryAcquire(ctx)
-               g.Assert(err).IsNotNil()
-               g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(1)
-               g.Assert(pl.IsLocked()).IsTrue()
-           })
-
-           g.It("should block until lock can be acquired", func() {
-               g.Timeout(time.Second)
-               ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
-               defer cancel()
-               pl.Acquire()
-               go func() {
-                   time.AfterFunc(time.Millisecond*50, func() {
-                       pl.Release()
-                   })
-               }()
-               err := pl.TryAcquire(ctx)
-               g.Assert(err).IsNil()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(1)
-               g.Assert(pl.IsLocked()).IsTrue()
-           })
-       })
-
-       g.Describe("PowerLocker#Release", func() {
-           g.It("should release when channel is full", func() {
-               pl.Acquire()
-               g.Assert(pl.IsLocked()).IsTrue()
-               pl.Release()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(0)
-               g.Assert(pl.IsLocked()).IsFalse()
-           })
-
-           g.It("should release when channel is empty", func() {
-               g.Assert(pl.IsLocked()).IsFalse()
-               pl.Release()
-               g.Assert(cap(pl.ch)).Equal(1)
-               g.Assert(len(pl.ch)).Equal(0)
-               g.Assert(pl.IsLocked()).IsFalse()
-           })
-       })
-
-       g.Describe("PowerLocker#Destroy", func() {
-           g.It("should unlock and close the channel", func() {
-               pl.Acquire()
-               g.Assert(pl.IsLocked()).IsTrue()
-               pl.Destroy()
-               g.Assert(pl.IsLocked()).IsFalse()
-               defer func() {
-                   r := recover()
-                   g.Assert(r).IsNotNil()
-                   g.Assert(r.(error).Error()).Equal("send on closed channel")
-               }()
-               pl.Acquire()
-           })
-       })
-   })
-
    g.Describe("Server#ExecutingPowerAction", func() {
        g.It("should return based on locker status", func() {
-           s := &Server{powerLock: newPowerLocker()}
+           s := &Server{powerLock: system.NewLocker()}
            g.Assert(s.ExecutingPowerAction()).IsFalse()

            s.powerLock.Acquire()

@@ -30,9 +30,8 @@ type Server struct {
    ctx       context.Context
    ctxCancel *context.CancelFunc

    emitterLock sync.Mutex
-   powerLock    *powerLocker
-   throttleOnce sync.Once
+   powerLock   *system.Locker

    // Maintains the configuration for the server. This is the data that gets returned by the Panel
    // such as build settings and container images.
@@ -64,7 +63,8 @@ type Server struct {
    restoring *system.AtomicBool

    // The console throttler instance used to control outputs.
-   throttler *ConsoleThrottler
+   throttler    *ConsoleThrottle
+   throttleOnce sync.Once

    // Tracks open websocket connections for the server.
    wsBag *WebsocketBag
@@ -87,7 +87,7 @@ func New(client remote.Client) (*Server, error) {
        installing:   system.NewAtomicBool(false),
        transferring: system.NewAtomicBool(false),
        restoring:    system.NewAtomicBool(false),
-       powerLock:    newPowerLocker(),
+       powerLock:    system.NewLocker(),
        sinks: map[SinkName]*sinkPool{
            LogSink:     newSinkPool(),
            InstallSink: newSinkPool(),

system/locker.go (new file, 83 lines)

@@ -0,0 +1,83 @@
package system

import (
    "context"
    "sync"

    "emperror.dev/errors"
)

var ErrLockerLocked = errors.Sentinel("locker: cannot acquire lock, already locked")

type Locker struct {
    mu sync.RWMutex
    ch chan bool
}

// NewLocker returns a new Locker instance.
func NewLocker() *Locker {
    return &Locker{
        ch: make(chan bool, 1),
    }
}

// IsLocked returns the current state of the locker channel. If there is
// currently a value in the channel, it is assumed to be locked.
func (l *Locker) IsLocked() bool {
    l.mu.RLock()
    defer l.mu.RUnlock()
    return len(l.ch) == 1
}

// Acquire will acquire the lock if it is not currently held. If it is already
// held, Acquire fails and returns ErrLockerLocked.
func (l *Locker) Acquire() error {
    l.mu.Lock()
    defer l.mu.Unlock()
    select {
    case l.ch <- true:
    default:
        return ErrLockerLocked
    }
    return nil
}

// TryAcquire will attempt to acquire the lock until the context provided
// is canceled.
func (l *Locker) TryAcquire(ctx context.Context) error {
    select {
    case l.ch <- true:
        return nil
    case <-ctx.Done():
        if err := ctx.Err(); err != nil {
            return err
        }
        return nil
    }
}

// Release will drain the locker channel so that we can properly re-acquire it
// at a later time. If the channel is not currently locked this function is a
// no-op and will immediately return.
func (l *Locker) Release() {
    l.mu.Lock()
    select {
    case <-l.ch:
    default:
    }
    l.mu.Unlock()
}

// Destroy cleans up the locker by closing the channel.
func (l *Locker) Destroy() {
    l.mu.Lock()
    if l.ch != nil {
        select {
        case <-l.ch:
        default:
        }
        close(l.ch)
    }
    l.mu.Unlock()
}
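A minimal usage sketch of the Locker semantics shown above (a hypothetical standalone program, assuming the wings module is importable; expected output noted in comments):

    package main

    import (
        "context"
        "fmt"
        "time"

        "github.com/pterodactyl/wings/system"
    )

    func main() {
        l := system.NewLocker()

        // First Acquire succeeds; a second fails fast with ErrLockerLocked.
        fmt.Println(l.Acquire()) // <nil>
        fmt.Println(l.Acquire()) // locker: cannot acquire lock, already locked

        // TryAcquire waits for either a Release or the context deadline.
        ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
        defer cancel()
        go func() {
            time.Sleep(10 * time.Millisecond)
            l.Release()
        }()
        fmt.Println(l.TryAcquire(ctx)) // <nil>
    }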

system/locker_test.go (new file, 148 lines)

@@ -0,0 +1,148 @@
package system

import (
    "context"
    "testing"
    "time"

    "emperror.dev/errors"
    . "github.com/franela/goblin"
)

func TestLocker(t *testing.T) {
    g := Goblin(t)

    g.Describe("Locker", func() {
        var l *Locker
        g.BeforeEach(func() {
            l = NewLocker()
        })

        g.Describe("Locker#IsLocked", func() {
            g.It("should return false when the channel is empty", func() {
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(l.IsLocked()).IsFalse()
            })

            g.It("should return true when the channel is at capacity", func() {
                l.ch <- true
                g.Assert(l.IsLocked()).IsTrue()
                <-l.ch
                g.Assert(l.IsLocked()).IsFalse()
                // We don't care what the channel value is, just that there is
                // something in it.
                l.ch <- false
                g.Assert(l.IsLocked()).IsTrue()
                g.Assert(cap(l.ch)).Equal(1)
            })
        })

        g.Describe("Locker#Acquire", func() {
            g.It("should acquire a lock when channel is empty", func() {
                err := l.Acquire()
                g.Assert(err).IsNil()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(1)
            })

            g.It("should return an error when the channel is full", func() {
                l.ch <- true
                err := l.Acquire()
                g.Assert(err).IsNotNil()
                g.Assert(errors.Is(err, ErrLockerLocked)).IsTrue()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(1)
            })
        })

        g.Describe("Locker#TryAcquire", func() {
            g.It("should acquire a lock when channel is empty", func() {
                g.Timeout(time.Second)
                err := l.TryAcquire(context.Background())
                g.Assert(err).IsNil()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(1)
                g.Assert(l.IsLocked()).IsTrue()
            })

            g.It("should block until context is canceled if channel is full", func() {
                g.Timeout(time.Second)
                ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*500)
                defer cancel()
                l.ch <- true
                err := l.TryAcquire(ctx)
                g.Assert(err).IsNotNil()
                g.Assert(errors.Is(err, context.DeadlineExceeded)).IsTrue()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(1)
                g.Assert(l.IsLocked()).IsTrue()
            })

            g.It("should block until lock can be acquired", func() {
                g.Timeout(time.Second)
                ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
                defer cancel()
                l.Acquire()
                go func() {
                    time.AfterFunc(time.Millisecond*50, func() {
                        l.Release()
                    })
                }()
                err := l.TryAcquire(ctx)
                g.Assert(err).IsNil()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(1)
                g.Assert(l.IsLocked()).IsTrue()
            })
        })

        g.Describe("Locker#Release", func() {
            g.It("should release when channel is full", func() {
                l.Acquire()
                g.Assert(l.IsLocked()).IsTrue()
                l.Release()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(0)
                g.Assert(l.IsLocked()).IsFalse()
            })

            g.It("should release when channel is empty", func() {
                g.Assert(l.IsLocked()).IsFalse()
                l.Release()
                g.Assert(cap(l.ch)).Equal(1)
                g.Assert(len(l.ch)).Equal(0)
                g.Assert(l.IsLocked()).IsFalse()
            })
        })

        g.Describe("Locker#Destroy", func() {
            g.It("should unlock and close the channel", func() {
                l.Acquire()
                g.Assert(l.IsLocked()).IsTrue()
                l.Destroy()
                g.Assert(l.IsLocked()).IsFalse()
                defer func() {
                    r := recover()
                    g.Assert(r).IsNotNil()
                    g.Assert(r.(error).Error()).Equal("send on closed channel")
                }()
                l.Acquire()
            })
        })
    })
}

system/rate.go (new file, 50 lines)

@@ -0,0 +1,50 @@
package system

import (
    "sync"
    "time"
)

// Rate defines a rate limiter of n items (limit) per duration of time.
type Rate struct {
    mu       sync.Mutex
    limit    uint64
    duration time.Duration
    count    uint64
    last     time.Time
}

func NewRate(limit uint64, duration time.Duration) *Rate {
    return &Rate{
        limit:    limit,
        duration: duration,
        last:     time.Now(),
    }
}

// Try returns true if under the rate limit defined, or false if the rate limit
// has been exceeded for the current duration.
func (r *Rate) Try() bool {
    r.mu.Lock()
    defer r.mu.Unlock()
    now := time.Now()
    // If it has been more than the duration, reset the timer and count.
    if now.Sub(r.last) > r.duration {
        r.count = 0
        r.last = now
    }
    if (r.count + 1) > r.limit {
        return false
    }
    // Consume one slot from the current window and allow the item through.
    r.count++
    return true
}

// Reset resets the internal state of the rate limiter back to zero.
func (r *Rate) Reset() {
    r.mu.Lock()
    r.count = 0
    r.last = time.Now()
    r.mu.Unlock()
}
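Worth noting: Rate is a fixed-window counter rather than a token bucket, so a burst straddling a window boundary can briefly let through up to twice the limit; in exchange, the implementation needs nothing but a mutex and a counter, with no background timers, which is the whole point of this commit. A small sketch of the window behavior (hypothetical standalone program):

    package main

    import (
        "fmt"
        "time"

        "github.com/pterodactyl/wings/system"
    )

    func main() {
        r := system.NewRate(2, 50*time.Millisecond)

        // Two takes fit in the window, the third is rejected.
        fmt.Println(r.Try(), r.Try(), r.Try()) // true true false

        // Once a full period has elapsed the counter resets in one shot.
        time.Sleep(60 * time.Millisecond)
        fmt.Println(r.Try()) // true
    }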

system/rate_test.go (new file, 67 lines)

@@ -0,0 +1,67 @@
package system

import (
    "testing"
    "time"

    . "github.com/franela/goblin"
)

func TestRate(t *testing.T) {
    g := Goblin(t)

    g.Describe("Rate", func() {
        g.It("properly rate limits a bucket", func() {
            r := NewRate(10, time.Millisecond*100)
            for i := 0; i < 100; i++ {
                ok := r.Try()
                if i < 10 && !ok {
                    g.Failf("should have allowed take on try %d", i)
                } else if i >= 10 && ok {
                    g.Failf("should have blocked take on try %d", i)
                }
            }
        })

        g.It("handles rate limiting in chunks", func() {
            var out []int
            r := NewRate(12, time.Millisecond*10)
            for i := 0; i < 100; i++ {
                if i%20 == 0 {
                    // Give it time to recover.
                    time.Sleep(time.Millisecond * 10)
                }
                if r.Try() {
                    out = append(out, i)
                }
            }
            g.Assert(len(out)).Equal(60)
            g.Assert(out[0]).Equal(0)
            g.Assert(out[12]).Equal(20)
            g.Assert(out[len(out)-1]).Equal(91)
        })

        g.It("resets back to zero when called", func() {
            r := NewRate(10, time.Second)
            for i := 0; i < 100; i++ {
                if i%10 == 0 {
                    r.Reset()
                }
                g.Assert(r.Try()).IsTrue()
            }
            g.Assert(r.Try()).IsFalse("final attempt should not allow taking")
        })
    })
}

func BenchmarkRate_Try(b *testing.B) {
    r := NewRate(10, time.Millisecond*100)
    b.ReportAllocs()
    for i := 0; i < b.N; i++ {
        r.Try()
    }
}