From c52db4eec0d41b10aab15aa914935b6e16587054 Mon Sep 17 00:00:00 2001 From: Dane Everitt Date: Sun, 23 Jan 2022 10:41:12 -0500 Subject: [PATCH] Add test coverage for sinks; prevent panic on nil channels --- server/sink.go | 19 +++-- server/sink_test.go | 189 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 server/sink_test.go diff --git a/server/sink.go b/server/sink.go index a6f5db9..41a5e93 100644 --- a/server/sink.go +++ b/server/sink.go @@ -52,10 +52,12 @@ func (p *sinkPool) Off(c chan []byte) { copy(sinks[i:], sinks[i+1:]) sinks[len(sinks)-1] = nil sinks = sinks[:len(sinks)-1] - - // Update our tracked sinks, and close the matched channel. p.sinks = sinks - close(c) + + // Avoid a panic if the sink channel is nil at this point. + if c != nil { + close(c) + } return } @@ -68,7 +70,9 @@ func (p *sinkPool) Destroy() { defer p.mu.Unlock() for _, c := range p.sinks { - close(c) + if c != nil { + close(c) + } } p.sinks = nil @@ -77,10 +81,15 @@ func (p *sinkPool) Destroy() { // Push sends a given message to each of the channels registered in the pool. func (p *sinkPool) Push(data []byte) { p.mu.RLock() + // Attempt to send the data over to the channels. If the channel buffer is full, + // or otherwise blocked for some reason (such as being a nil channel), just discard + // the event data and move on to the next channel in the slice. If you don't + // implement the "default" on the select you'll block execution until the channel + // becomes unblocked, which is not what we want to do here. for _, c := range p.sinks { select { - // Send the event data over to the channels. case c <- data: + default: } } p.mu.RUnlock() diff --git a/server/sink_test.go b/server/sink_test.go new file mode 100644 index 0000000..5713cee --- /dev/null +++ b/server/sink_test.go @@ -0,0 +1,189 @@ +package server + +import ( + "reflect" + "sync" + "testing" + + . "github.com/franela/goblin" +) + +func MutexLocked(m *sync.RWMutex) bool { + v := reflect.ValueOf(m).Elem() + + state := v.FieldByName("w").FieldByName("state") + + return state.Int()&1 == 1 || v.FieldByName("readerCount").Int() > 0 +} + +func Test(t *testing.T) { + g := Goblin(t) + + g.Describe("SinkPool#On", func() { + g.It("pushes additional channels to a sink", func() { + pool := &sinkPool{} + + g.Assert(pool.sinks).IsZero() + + c1 := make(chan []byte, 1) + pool.On(c1) + + g.Assert(len(pool.sinks)).Equal(1) + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + }) + + g.Describe("SinkPool#Off", func() { + var pool *sinkPool + g.BeforeEach(func() { + pool = &sinkPool{} + }) + + g.It("works when no sinks are registered", func() { + ch := make(chan []byte, 1) + + g.Assert(pool.sinks).IsZero() + pool.Off(ch) + + g.Assert(pool.sinks).IsZero() + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + + g.It("does not remove any sinks when the channel does not match", func() { + ch := make(chan []byte, 1) + ch2 := make(chan []byte, 1) + + pool.On(ch) + g.Assert(len(pool.sinks)).Equal(1) + + pool.Off(ch2) + g.Assert(len(pool.sinks)).Equal(1) + g.Assert(pool.sinks[0]).Equal(ch) + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + + g.It("removes a channel and maintains the order", func() { + channels := make([]chan []byte, 8) + for i := 0; i < len(channels); i++ { + channels[i] = make(chan []byte, 1) + pool.On(channels[i]) + } + + g.Assert(len(pool.sinks)).Equal(8) + + pool.Off(channels[2]) + g.Assert(len(pool.sinks)).Equal(7) + g.Assert(pool.sinks[1]).Equal(channels[1]) + g.Assert(pool.sinks[2]).Equal(channels[3]) + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + + g.It("does not panic if a nil channel is provided", func() { + ch := make([]chan []byte, 1) + + defer func () { + if r := recover(); r != nil { + g.Fail("removing a nil channel should not cause a panic") + } + }() + + pool.On(ch[0]) + pool.Off(ch[0]) + + g.Assert(len(pool.sinks)).Equal(0) + }) + }) + + g.Describe("SinkPool#Push", func() { + var pool *sinkPool + g.BeforeEach(func() { + pool = &sinkPool{} + }) + + g.It("works when no sinks are registered", func() { + g.Assert(len(pool.sinks)).IsZero() + pool.Push([]byte("test")) + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + + g.It("sends data to every registered sink", func() { + ch1 := make(chan []byte, 1) + ch2 := make(chan []byte, 1) + + pool.On(ch1) + pool.On(ch2) + + g.Assert(len(pool.sinks)).Equal(2) + b := []byte("test") + pool.Push(b) + + g.Assert(MutexLocked(&pool.mu)).IsFalse() + g.Assert(<-ch1).Equal(b) + g.Assert(<-ch2).Equal(b) + g.Assert(len(pool.sinks)).Equal(2) + }) + + g.It("does not block if a channel is nil or otherwise full", func() { + ch := make([]chan []byte, 2) + ch[1] = make(chan []byte, 1) + ch[1] <- []byte("test") + + pool.On(ch[0]) + pool.On(ch[1]) + + pool.Push([]byte("testing")) + + g.Assert(MutexLocked(&pool.mu)).IsFalse() + g.Assert(<-ch[1]).Equal([]byte("test")) + + pool.Push([]byte("test2")) + g.Assert(<-ch[1]).Equal([]byte("test2")) + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + }) + + g.Describe("SinkPool#Destroy", func() { + var pool *sinkPool + g.BeforeEach(func() { + pool = &sinkPool{} + }) + + g.It("works if no sinks are registered", func() { + pool.Destroy() + + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + + g.It("closes all channels fully", func() { + ch1 := make(chan []byte, 1) + ch2 := make(chan []byte, 1) + + pool.On(ch1) + pool.On(ch2) + + g.Assert(len(pool.sinks)).Equal(2) + pool.Destroy() + g.Assert(pool.sinks).IsZero() + + defer func() { + r := recover() + + g.Assert(r).IsNotNil() + g.Assert(r.(error).Error()).Equal("send on closed channel") + }() + + ch1 <- []byte("test") + }) + + g.It("works when a sink channel is nil", func() { + ch := make([]chan []byte, 2) + + pool.On(ch[0]) + pool.On(ch[1]) + + pool.Destroy() + + g.Assert(MutexLocked(&pool.mu)).IsFalse() + }) + }) +} \ No newline at end of file