From 1471978bacf91bd9de273e6f265ed31def83711a Mon Sep 17 00:00:00 2001 From: Michael Anthony Knyszek Date: Fri, 31 May 2024 20:10:09 +0000 Subject: [PATCH] iter: propagate panics from the iterator passed to Pull This change propagates panics from the iterator passed to Pull through next and stop. Once the panic occurs, next and stop become no-ops (the iterator is invalidated). For #67712. Change-Id: I05e45601d4d10acdf51b53e3164bd891c1b324ac Reviewed-on: https://go-review.googlesource.com/c/go/+/589136 Reviewed-by: David Chase Reviewed-by: Austin Clements LUCI-TryBot-Result: Go LUCI --- src/iter/iter.go | 66 +++++++++++++++++++++++++++++------- src/iter/pull_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 13 deletions(-) diff --git a/src/iter/iter.go b/src/iter/iter.go index af360aeb07..3e93f3bdb7 100644 --- a/src/iter/iter.go +++ b/src/iter/iter.go @@ -50,11 +50,12 @@ func coroswitch(*coro) // simultaneously. func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { var ( - v V - ok bool - done bool - yieldNext bool - racer int + v V + ok bool + done bool + yieldNext bool + racer int + panicValue any ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -72,14 +73,22 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { race.Acquire(unsafe.Pointer(&racer)) return !done } + // Recover and propagate panics from seq. + defer func() { + if p := recover(); p != nil { + done = true // Invalidate iterator. + panicValue = p + } + race.Release(unsafe.Pointer(&racer)) + }() seq(yield) var v0 V v, ok = v0, false done = true - race.Release(unsafe.Pointer(&racer)) }) next = func() (v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races + if done { return } @@ -90,15 +99,26 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) + + // Propagate panics from seq. + if panicValue != nil { + panic(panicValue) + } return v, ok } stop = func() { race.Write(unsafe.Pointer(&racer)) // detect races + if !done { done = true race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) + + // Propagate panics from seq. + if panicValue != nil { + panic(panicValue) + } } } return next, stop @@ -125,12 +145,13 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) { // simultaneously. func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { var ( - k K - v V - ok bool - done bool - yieldNext bool - racer int + k K + v V + ok bool + done bool + yieldNext bool + racer int + panicValue any ) c := newcoro(func(c *coro) { race.Acquire(unsafe.Pointer(&racer)) @@ -148,15 +169,23 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { race.Acquire(unsafe.Pointer(&racer)) return !done } + // Recover and propagate panics from seq. + defer func() { + if p := recover(); p != nil { + done = true // Invalidate iterator. + panicValue = p + } + race.Release(unsafe.Pointer(&racer)) + }() seq(yield) var k0 K var v0 V k, v, ok = k0, v0, false done = true - race.Release(unsafe.Pointer(&racer)) }) next = func() (k1 K, v1 V, ok1 bool) { race.Write(unsafe.Pointer(&racer)) // detect races + if done { return } @@ -167,15 +196,26 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) { race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) + + // Propagate panics from seq. + if panicValue != nil { + panic(panicValue) + } return k, v, ok } stop = func() { race.Write(unsafe.Pointer(&racer)) // detect races + if !done { done = true race.Release(unsafe.Pointer(&racer)) coroswitch(c) race.Acquire(unsafe.Pointer(&racer)) + + // Propagate panics from seq. + if panicValue != nil { + panic(panicValue) + } } } return next, stop diff --git a/src/iter/pull_test.go b/src/iter/pull_test.go index c39574b959..09f2270fa1 100644 --- a/src/iter/pull_test.go +++ b/src/iter/pull_test.go @@ -241,3 +241,82 @@ func storeYield2() Seq2[int, int] { } var yieldSlot2 func(int, int) bool + +func TestPullPanic(t *testing.T) { + t.Run("next", func(t *testing.T) { + next, stop := Pull(panicSeq()) + if !panicsWith("boom", func() { next() }) { + t.Fatal("failed to propagate panic on first next") + } + // Make sure we don't panic again if we try to call next or stop. + if _, ok := next(); ok { + t.Fatal("next returned true after iterator panicked") + } + // Calling stop again should be a no-op. + stop() + }) + t.Run("stop", func(t *testing.T) { + next, stop := Pull(panicSeq()) + if !panicsWith("boom", func() { stop() }) { + t.Fatal("failed to propagate panic on first stop") + } + // Make sure we don't panic again if we try to call next or stop. + if _, ok := next(); ok { + t.Fatal("next returned true after iterator panicked") + } + // Calling stop again should be a no-op. + stop() + }) +} + +func panicSeq() Seq[int] { + return func(yield func(int) bool) { + panic("boom") + } +} + +func TestPull2Panic(t *testing.T) { + t.Run("next", func(t *testing.T) { + next, stop := Pull2(panicSeq2()) + if !panicsWith("boom", func() { next() }) { + t.Fatal("failed to propagate panic on first next") + } + // Make sure we don't panic again if we try to call next or stop. + if _, _, ok := next(); ok { + t.Fatal("next returned true after iterator panicked") + } + // Calling stop again should be a no-op. + stop() + }) + t.Run("stop", func(t *testing.T) { + next, stop := Pull2(panicSeq2()) + if !panicsWith("boom", func() { stop() }) { + t.Fatal("failed to propagate panic on first stop") + } + // Make sure we don't panic again if we try to call next or stop. + if _, _, ok := next(); ok { + t.Fatal("next returned true after iterator panicked") + } + // Calling stop again should be a no-op. + stop() + }) +} + +func panicSeq2() Seq2[int, int] { + return func(yield func(int, int) bool) { + panic("boom") + } +} + +func panicsWith(v any, f func()) (panicked bool) { + defer func() { + if r := recover(); r != nil { + if r != v { + panic(r) + } + panicked = true + } + }() + f() + return +}