diff --git a/src/bufio/bufio.go b/src/bufio/bufio.go index 5a88def0c7..8469b9eff7 100644 --- a/src/bufio/bufio.go +++ b/src/bufio/bufio.go @@ -70,7 +70,14 @@ func (b *Reader) Size() int { return len(b.buf) } // the buffered reader to read from r. // Calling Reset on the zero value of Reader initializes the internal buffer // to the default size. +// Calling b.Reset(b) (that is, resetting a Reader to itself) does nothing. func (b *Reader) Reset(r io.Reader) { + // If a Reader r is passed to NewReader, NewReader will return r. + // Different layers of code may do that, and then later pass r + // to Reset. Avoid infinite recursion in that case. + if b == r { + return + } if b.buf == nil { b.buf = make([]byte, defaultBufSize) } @@ -608,7 +615,14 @@ func (b *Writer) Size() int { return len(b.buf) } // resets b to write its output to w. // Calling Reset on the zero value of Writer initializes the internal buffer // to the default size. +// Calling w.Reset(w) (that is, resetting a Writer to itself) does nothing. func (b *Writer) Reset(w io.Writer) { + // If a Writer w is passed to NewWriter, NewWriter will return w. + // Different layers of code may do that, and then later pass w + // to Reset. Avoid infinite recursion in that case. + if b == w { + return + } if b.buf == nil { b.buf = make([]byte, defaultBufSize) } diff --git a/src/bufio/bufio_test.go b/src/bufio/bufio_test.go index 64ccd025ea..a8c1e50397 100644 --- a/src/bufio/bufio_test.go +++ b/src/bufio/bufio_test.go @@ -1482,6 +1482,17 @@ func TestReadZero(t *testing.T) { } func TestReaderReset(t *testing.T) { + checkAll := func(r *Reader, want string) { + t.Helper() + all, err := io.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if string(all) != want { + t.Errorf("ReadAll returned %q, want %q", all, want) + } + } + r := NewReader(strings.NewReader("foo foo")) buf := make([]byte, 3) r.Read(buf) @@ -1490,27 +1501,23 @@ func TestReaderReset(t *testing.T) { } r.Reset(strings.NewReader("bar bar")) - all, err := io.ReadAll(r) - if err != nil { - t.Fatal(err) - } - if string(all) != "bar bar" { - t.Errorf("ReadAll = %q; want bar bar", all) - } + checkAll(r, "bar bar") *r = Reader{} // zero out the Reader r.Reset(strings.NewReader("bar bar")) - all, err = io.ReadAll(r) - if err != nil { - t.Fatal(err) - } - if string(all) != "bar bar" { - t.Errorf("ReadAll = %q; want bar bar", all) - } + checkAll(r, "bar bar") + + // Wrap a reader and then Reset to that reader. + r.Reset(strings.NewReader("recur")) + r2 := NewReader(r) + checkAll(r2, "recur") + r.Reset(strings.NewReader("recur2")) + r2.Reset(r) + checkAll(r2, "recur2") } func TestWriterReset(t *testing.T) { - var buf1, buf2, buf3 strings.Builder + var buf1, buf2, buf3, buf4, buf5 strings.Builder w := NewWriter(&buf1) w.WriteString("foo") @@ -1534,6 +1541,22 @@ func TestWriterReset(t *testing.T) { if buf3.String() != "bar" { t.Errorf("buf3 = %q; want bar", buf3.String()) } + + // Wrap a writer and then Reset to that writer. + w.Reset(&buf4) + w2 := NewWriter(w) + w2.WriteString("recur") + w2.Flush() + if buf4.String() != "recur" { + t.Errorf("buf4 = %q, want %q", buf4.String(), "recur") + } + w.Reset(&buf5) + w2.Reset(w) + w2.WriteString("recur2") + w2.Flush() + if buf5.String() != "recur2" { + t.Errorf("buf5 = %q, want %q", buf5.String(), "recur2") + } } func TestReaderDiscard(t *testing.T) {