diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 238297f945..b894be0813 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -1424,3 +1424,59 @@ func testWriteHeader0(t *testing.T, h2 bool) { t.Error("expected panic in handler") } } + +// Issue 23010: don't be super strict checking WriteHeader's code if +// it's not even valid to call WriteHeader then anyway. +func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } +func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } +func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { + setParallel(t) + defer afterTest(t) + + var errorLog lockedBytesBuffer + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + if hijack { + conn, _, _ := w.(Hijacker).Hijack() + defer conn.Close() + conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nfoo")) + w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 + conn.Write([]byte("bar")) + return + } + io.WriteString(w, "foo") + w.(Flusher).Flush() + w.WriteHeader(0) // verify this doesn't panic if there's already output; Issue 23010 + io.WriteString(w, "bar") + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(&errorLog, "", 0) + }) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(body), "foobar"; got != want { + t.Errorf("got = %q; want %q", got, want) + } + + // Also check the stderr output: + if h2 { + // TODO: also emit this log message for HTTP/2? + // We historically haven't, so don't check. + return + } + gotLog := strings.TrimSpace(errorLog.String()) + wantLog := "http: multiple response.WriteHeader calls" + if hijack { + wantLog = "http: response.WriteHeader on hijacked connection" + } + if gotLog != wantLog { + t.Errorf("stderr output = %q; want %q", gotLog, wantLog) + } +} diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index e6e164467d..161a1ed137 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -6204,7 +6204,6 @@ func http2checkWriteHeaderCode(code int) { } func (w *http2responseWriter) WriteHeader(code int) { - http2checkWriteHeaderCode(code) rws := w.rws if rws == nil { panic("WriteHeader called after Handler finished") @@ -6214,6 +6213,7 @@ func (w *http2responseWriter) WriteHeader(code int) { func (rws *http2responseWriterState) writeHeader(code int) { if !rws.wroteHeader { + http2checkWriteHeaderCode(code) rws.wroteHeader = true rws.status = code if len(rws.handlerHeader) > 0 { diff --git a/src/net/http/server.go b/src/net/http/server.go index e1698ccfa3..ceb1a047cf 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -1072,7 +1072,6 @@ func checkWriteHeaderCode(code int) { } func (w *response) WriteHeader(code int) { - checkWriteHeaderCode(code) if w.conn.hijacked() { w.conn.server.logf("http: response.WriteHeader on hijacked connection") return @@ -1081,6 +1080,7 @@ func (w *response) WriteHeader(code int) { w.conn.server.logf("http: multiple response.WriteHeader calls") return } + checkWriteHeaderCode(code) w.wroteHeader = true w.status = code