iter: propagate runtime.Goexit from iterator passed to Pull

This change propagates a runtime.Goexit initiated by the iterator into
the caller of next and/or stop.

Fixes #67712.

Change-Id: I5bb8d22f749fce39ce4f587148c5fc71aee2af65
Reviewed-on: https://go-review.googlesource.com/c/go/+/589137
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Austin Clements <austin@google.com>
Reviewed-by: David Chase <drchase@google.com>
This commit is contained in:
Michael Anthony Knyszek 2024-05-31 20:22:32 +00:00 committed by Michael Knyszek
parent 1471978bac
commit 9d2aeae72d
3 changed files with 133 additions and 13 deletions

View file

@ -8,6 +8,7 @@ package iter
import (
"internal/race"
"runtime"
"unsafe"
)
@ -56,6 +57,7 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
yieldNext bool
racer int
panicValue any
seqDone bool // to detect Goexit
)
c := newcoro(func(c *coro) {
race.Acquire(unsafe.Pointer(&racer))
@ -76,15 +78,17 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
// Recover and propagate panics from seq.
defer func() {
if p := recover(); p != nil {
done = true // Invalidate iterator.
panicValue = p
} else if !seqDone {
panicValue = goexitPanicValue
}
done = true // Invalidate iterator
race.Release(unsafe.Pointer(&racer))
}()
seq(yield)
var v0 V
v, ok = v0, false
done = true
seqDone = true
})
next = func() (v1 V, ok1 bool) {
race.Write(unsafe.Pointer(&racer)) // detect races
@ -100,9 +104,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
// Propagate panics from seq.
// Propagate panics and goexits from seq.
if panicValue != nil {
panic(panicValue)
if panicValue == goexitPanicValue {
// Propagate runtime.Goexit from seq.
runtime.Goexit()
} else {
panic(panicValue)
}
}
return v, ok
}
@ -115,9 +124,14 @@ func Pull[V any](seq Seq[V]) (next func() (V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
// Propagate panics from seq.
// Propagate panics and goexits from seq.
if panicValue != nil {
panic(panicValue)
if panicValue == goexitPanicValue {
// Propagate runtime.Goexit from seq.
runtime.Goexit()
} else {
panic(panicValue)
}
}
}
}
@ -152,6 +166,7 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
yieldNext bool
racer int
panicValue any
seqDone bool
)
c := newcoro(func(c *coro) {
race.Acquire(unsafe.Pointer(&racer))
@ -172,16 +187,18 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
// Recover and propagate panics from seq.
defer func() {
if p := recover(); p != nil {
done = true // Invalidate iterator.
panicValue = p
} else if !seqDone {
panicValue = goexitPanicValue
}
done = true // Invalidate iterator.
race.Release(unsafe.Pointer(&racer))
}()
seq(yield)
var k0 K
var v0 V
k, v, ok = k0, v0, false
done = true
seqDone = true
})
next = func() (k1 K, v1 V, ok1 bool) {
race.Write(unsafe.Pointer(&racer)) // detect races
@ -197,9 +214,14 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
// Propagate panics from seq.
// Propagate panics and goexits from seq.
if panicValue != nil {
panic(panicValue)
if panicValue == goexitPanicValue {
// Propagate runtime.Goexit from seq.
runtime.Goexit()
} else {
panic(panicValue)
}
}
return k, v, ok
}
@ -212,11 +234,20 @@ func Pull2[K, V any](seq Seq2[K, V]) (next func() (K, V, bool), stop func()) {
coroswitch(c)
race.Acquire(unsafe.Pointer(&racer))
// Propagate panics from seq.
// Propagate panics and goexits from seq.
if panicValue != nil {
panic(panicValue)
if panicValue == goexitPanicValue {
// Propagate runtime.Goexit from seq.
runtime.Goexit()
} else {
panic(panicValue)
}
}
}
}
return next, stop
}
// goexitPanicValue is a sentinel value indicating that an iterator
// exited via runtime.Goexit.
var goexitPanicValue any = new(int)

View file

@ -320,3 +320,92 @@ func panicsWith(v any, f func()) (panicked bool) {
f()
return
}
func TestPullGoexit(t *testing.T) {
t.Run("next", func(t *testing.T) {
var next func() (int, bool)
var stop func()
if !goexits(t, func() {
next, stop = Pull(goexitSeq())
next()
}) {
t.Fatal("failed to Goexit from next")
}
if x, ok := next(); x != 0 || ok {
t.Fatal("iterator returned valid value after Goexit")
}
stop()
})
t.Run("stop", func(t *testing.T) {
var next func() (int, bool)
var stop func()
if !goexits(t, func() {
next, stop = Pull(goexitSeq())
stop()
}) {
t.Fatal("failed to Goexit from stop")
}
if x, ok := next(); x != 0 || ok {
t.Fatal("iterator returned valid value after Goexit")
}
stop()
})
}
func goexitSeq() Seq[int] {
return func(yield func(int) bool) {
runtime.Goexit()
}
}
func TestPull2Goexit(t *testing.T) {
t.Run("next", func(t *testing.T) {
var next func() (int, int, bool)
var stop func()
if !goexits(t, func() {
next, stop = Pull2(goexitSeq2())
next()
}) {
t.Fatal("failed to Goexit from next")
}
if x, y, ok := next(); x != 0 || y != 0 || ok {
t.Fatal("iterator returned valid value after Goexit")
}
stop()
})
t.Run("stop", func(t *testing.T) {
var next func() (int, int, bool)
var stop func()
if !goexits(t, func() {
next, stop = Pull2(goexitSeq2())
stop()
}) {
t.Fatal("failed to Goexit from stop")
}
if x, y, ok := next(); x != 0 || y != 0 || ok {
t.Fatal("iterator returned valid value after Goexit")
}
stop()
})
}
func goexitSeq2() Seq2[int, int] {
return func(yield func(int, int) bool) {
runtime.Goexit()
}
}
func goexits(t *testing.T, f func()) bool {
t.Helper()
exit := make(chan bool)
go func() {
cleanExit := false
defer func() {
exit <- recover() == nil && !cleanExit
}()
f()
cleanExit = true
}()
return <-exit
}

View file

@ -68,8 +68,8 @@ func corostart() {
c := gp.coroarg
gp.coroarg = nil
defer coroexit(c)
c.f(c)
coroexit(c)
}
// coroexit is like coroswitch but closes the coro