diff --git a/internal/concurrentmap/concurrentmap.go b/internal/concurrentmap/concurrentmap.go index badb823..730c77c 100644 --- a/internal/concurrentmap/concurrentmap.go +++ b/internal/concurrentmap/concurrentmap.go @@ -37,3 +37,20 @@ func (cmap *ConcurrentMap[T]) Delete(key string) { delete(cmap.nonConcurrentMap, key) } + +func (cmap *ConcurrentMap[T]) DeleteIf(key string, predicate func(T) bool) bool { + cmap.mtx.Lock() + defer cmap.mtx.Unlock() + + value, ok := cmap.nonConcurrentMap[key] + if !ok { + return false + } + if !predicate(value) { + return false + } + + delete(cmap.nonConcurrentMap, key) + + return true +} diff --git a/internal/concurrentmap/concurrentmap_test.go b/internal/concurrentmap/concurrentmap_test.go new file mode 100644 index 0000000..a9ecca9 --- /dev/null +++ b/internal/concurrentmap/concurrentmap_test.go @@ -0,0 +1,34 @@ +package concurrentmap + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDeleteIf(t *testing.T) { + cmap := NewConcurrentMap[int]() + cmap.Store("a", 1) + + deleted := cmap.DeleteIf("a", func(value int) bool { + return value == 1 + }) + require.True(t, deleted) + + _, ok := cmap.Load("a") + require.False(t, ok) +} + +func TestDeleteIfPredicateFalse(t *testing.T) { + cmap := NewConcurrentMap[int]() + cmap.Store("a", 1) + + deleted := cmap.DeleteIf("a", func(value int) bool { + return value == 2 + }) + require.False(t, deleted) + + value, ok := cmap.Load("a") + require.True(t, ok) + require.Equal(t, 1, value) +} diff --git a/internal/controller/notifier/notifier.go b/internal/controller/notifier/notifier.go index df7e6f7..07cfe32 100644 --- a/internal/controller/notifier/notifier.go +++ b/internal/controller/notifier/notifier.go @@ -31,16 +31,19 @@ func NewNotifier(logger *zap.SugaredLogger) *Notifier { func (watcher *Notifier) Register(ctx context.Context, worker string) (chan *rpc.WatchInstruction, func()) { subCtx, cancel := context.WithCancel(ctx) workerCh := make(chan *rpc.WatchInstruction) - - watcher.logger.Debugf("registering worker %s", worker) - watcher.workers.Store(worker, &WorkerSlot{ + slot := &WorkerSlot{ ctx: subCtx, ch: workerCh, - }) + } + + watcher.logger.Debugf("registering worker %s", worker) + watcher.workers.Store(worker, slot) return workerCh, func() { watcher.logger.Debugf("deleting worker %s", worker) - watcher.workers.Delete(worker) + watcher.workers.DeleteIf(worker, func(current *WorkerSlot) bool { + return current == slot + }) cancel() } }