diff --git a/src/internal/singleflight/singleflight.go b/src/internal/singleflight/singleflight.go index 19d5a94a0bb..755bf1c3500 100644 --- a/src/internal/singleflight/singleflight.go +++ b/src/internal/singleflight/singleflight.go @@ -94,7 +94,9 @@ func (g *Group) doCall(c *call, key string, fn func() (any, error)) { c.wg.Done() g.mu.Lock() - delete(g.m, key) + if g.m[key] == c { + delete(g.m, key) + } for _, ch := range c.chans { ch <- Result{c.val, c.err, c.dups > 0} } diff --git a/src/internal/singleflight/singleflight_test.go b/src/internal/singleflight/singleflight_test.go index 99713c9e149..c8b4a81d52c 100644 --- a/src/internal/singleflight/singleflight_test.go +++ b/src/internal/singleflight/singleflight_test.go @@ -85,3 +85,59 @@ func TestDoDupSuppress(t *testing.T) { t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) } } + +func TestForgetUnshared(t *testing.T) { + var g Group + + var firstStarted, firstFinished sync.WaitGroup + + firstStarted.Add(1) + firstFinished.Add(1) + + key := "key" + firstCh := make(chan struct{}) + go func() { + g.Do(key, func() (i interface{}, e error) { + firstStarted.Done() + <-firstCh + firstFinished.Done() + return + }) + }() + + firstStarted.Wait() + g.ForgetUnshared(key) // from this point no two function using same key should be executed concurrently + + secondCh := make(chan struct{}) + go func() { + g.Do(key, func() (i interface{}, e error) { + // Notify that we started + secondCh <- struct{}{} + <-secondCh + return 2, nil + }) + }() + + <-secondCh + + resultCh := g.DoChan(key, func() (i interface{}, e error) { + panic("third must not be started") + }) + + if g.ForgetUnshared(key) { + t.Errorf("Before first goroutine finished, key %q is shared, should return false", key) + } + + close(firstCh) + firstFinished.Wait() + + if g.ForgetUnshared(key) { + t.Errorf("After first goroutine finished, key %q is still shared, should return false", key) + } + + secondCh <- struct{}{} + + if result := <-resultCh; result.Val != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", result.Val) + } +}