diff --git a/events/events.go b/events/events.go index 3e7f699..16c4861 100644 --- a/events/events.go +++ b/events/events.go @@ -3,6 +3,8 @@ package events import ( "strings" "sync" + + "github.com/apex/log" ) type Listener chan Event @@ -31,8 +33,15 @@ func (b *Bus) Off(listener Listener, topics ...string) { b.listenersMx.Lock() defer b.listenersMx.Unlock() + var closed bool + for _, topic := range topics { - b.off(topic, listener) + ok := b.off(topic, listener) + if !closed && ok { + log.Debug("closing event channel!") + close(listener) + closed = true + } } } @@ -116,11 +125,30 @@ func (b *Bus) Destroy() { b.listenersMx.Lock() defer b.listenersMx.Unlock() + // Track what listeners have already been closed. Because the same listener + // can be listening on multiple topics, we need a way to essentially + // "de-duplicate" all the listeners across all the topics. + var closed []Listener + for _, listeners := range b.listeners { for _, listener := range listeners { + if contains(closed, listener) { + continue + } + close(listener) + closed = append(closed, listener) } } b.listeners = make(map[string][]Listener) } + +func contains(closed []Listener, listener Listener) bool { + for _, c := range closed { + if c == listener { + return true + } + } + return false +} diff --git a/events/events_test.go b/events/events_test.go index 91e6fea..542a8e2 100644 --- a/events/events_test.go +++ b/events/events_test.go @@ -36,8 +36,6 @@ func TestBus_Off(t *testing.T) { bus.Off(listener, topic) g.Assert(len(bus.listeners[topic])).Equal(0, "Topic still has one or more listeners") - - close(listener) }) g.It("unregisters correct listener", func() { @@ -62,9 +60,6 @@ func TestBus_Off(t *testing.T) { // Cleanup bus.Off(listener2, topic) - close(listener) - close(listener2) - close(listener3) }) }) } @@ -91,7 +86,6 @@ func TestBus_On(t *testing.T) { // Cleanup bus.Off(listener, topic) - close(listener) }) }) } @@ -127,7 +121,6 @@ func TestBus_Publish(t *testing.T) { <-done // Cleanup - close(listener) bus.Off(listener, topic) }) @@ -172,9 +165,6 @@ func TestBus_Publish(t *testing.T) { bus.Off(listener, topic) bus.Off(listener2, topic) bus.Off(listener3, topic) - close(listener) - close(listener2) - close(listener3) }) }) } diff --git a/router/websocket/listeners.go b/router/websocket/listeners.go index e4b16b7..aae8855 100644 --- a/router/websocket/listeners.go +++ b/router/websocket/listeners.go @@ -146,12 +146,10 @@ func (h *Handler) listenForServerEvents(ctx context.Context) error { break } + // These functions will automatically close the channel if it hasn't been already. h.server.Events().Off(eventChan, e...) h.server.LogSink().Off(logOutput) h.server.InstallSink().Off(installOutput) - close(eventChan) - close(logOutput) - close(installOutput) // If the internal context is stopped it is either because the parent context // got canceled or because we ran into an error. If the "err" variable is nil diff --git a/server/sink.go b/server/sink.go index c6f2b20..908748e 100644 --- a/server/sink.go +++ b/server/sink.go @@ -16,14 +16,6 @@ func newSinkPool() *sinkPool { return &sinkPool{} } -// On adds a sink on the pool. -func (p *sinkPool) On(c chan []byte) { - p.mx.Lock() - defer p.mx.Unlock() - - p.sinks = append(p.sinks, c) -} - // Off removes a sink from the pool. func (p *sinkPool) Off(c chan []byte) { p.mx.Lock() @@ -39,10 +31,19 @@ func (p *sinkPool) Off(c chan []byte) { sinks[len(sinks)-1] = nil sinks = sinks[:len(sinks)-1] p.sinks = sinks + close(c) return } } +// On adds a sink on the pool. +func (p *sinkPool) On(c chan []byte) { + p.mx.Lock() + defer p.mx.Unlock() + + p.sinks = append(p.sinks, c) +} + // Destroy destroys the pool by removing and closing all sinks. func (p *sinkPool) Destroy() { p.mx.Lock()