net/http: index patterns for faster conflict detection

Add an index so that pattern registration isn't always quadratic.

If there were no index, then every pattern that was registered would
have to be compared to every existing pattern for conflicts. This
would make registration quadratic in the number of patterns, in every
case.

The index in this CL should help most of the time. If a pattern has a
literal segment, it will weed out all other patterns that have a
different literal in that position.

The worst case will still be quadratic, but it is unlikely that a set
of such patterns would arise naturally.

One novel (to me) aspect of the CL is the use of fuzz testing on data
that is neither a string nor a byte slice. The test uses fuzzing to
generate a byte slice, then decodes the byte slice into a valid
pattern (most of the time). This test actually caught a bug: see
https://go.dev/cl/529119.

Change-Id: Ice0be6547decb5ce75a8062e4e17227815d5d0b0
Reviewed-on: https://go-review.googlesource.com/c/go/+/529121
Run-TryBot: Jonathan Amsterdam <jba@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Jonathan Amsterdam 2023-09-18 15:51:04 -04:00
parent bda5e6c3d0
commit be11422b1e
5 changed files with 346 additions and 9 deletions

View file

@ -0,0 +1,124 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import "math"
// A routingIndex optimizes conflict detection by indexing patterns.
//
// The basic idea is to rule out patterns that cannot conflict with a given
// pattern because they have a different literal in a corresponding segment.
// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
type routingIndex struct {
// map from a particular segment position and value to all registered patterns
// with that value in that position.
// For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
// but not "/a", "b/a", "/a/c" or "/a/{x}".
segments map[routingIndexKey][]*pattern
// All patterns that end in a multi wildcard (including trailing slash).
// We do not try to be clever about indexing multi patterns, because there
// are unlikely to be many of them.
multis []*pattern
}
type routingIndexKey struct {
pos int // 0-based segment position
s string // literal, or empty for wildcard
}
func (idx *routingIndex) addPattern(pat *pattern) {
if pat.lastSegment().multi {
idx.multis = append(idx.multis, pat)
} else {
if idx.segments == nil {
idx.segments = map[routingIndexKey][]*pattern{}
}
for pos, seg := range pat.segments {
key := routingIndexKey{pos: pos, s: ""}
if !seg.wild {
key.s = seg.s
}
idx.segments[key] = append(idx.segments[key], pat)
}
}
}
// possiblyConflictingPatterns calls f on all patterns that might conflict with
// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
// with that error.
//
// To be correct, possiblyConflictingPatterns must include all patterns that
// might conflict. But it may also include patterns that cannot conflict.
// For instance, an implementation that returns all registered patterns is correct.
// We use this fact throughout, simplifying the implementation by returning more
// patterns that we might need to.
func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
// Terminology:
// dollar pattern: one ending in "{$}"
// multi pattern: one ending in a trailing slash or "{x...}" wildcard
// ordinary pattern: neither of the above
// apply f to all the pats, stopping on error.
apply := func(pats []*pattern) error {
if err != nil {
return err
}
for _, p := range pats {
err = f(p)
if err != nil {
return err
}
}
return nil
}
// Our simple indexing scheme doesn't try to prune multi patterns; assume
// any of them can match the argument.
if err := apply(idx.multis); err != nil {
return err
}
if pat.lastSegment().s == "/" {
// All paths that a dollar pattern matches end in a slash; no paths that
// an ordinary pattern matches do. So only other dollar or multi
// patterns can conflict with a dollar pattern. Furthermore, conflicting
// dollar patterns must have the {$} in the same position.
return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
}
// For ordinary and multi patterns, the only conflicts can be with a multi,
// or a pattern that has the same literal or a wildcard at some literal
// position.
// We could intersect all the possible matches at each position, but we
// do something simpler: we find the position with the fewest patterns.
var lmin, wmin []*pattern
min := math.MaxInt
hasLit := false
for i, seg := range pat.segments {
if seg.multi {
break
}
if !seg.wild {
hasLit = true
lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
wpats := idx.segments[routingIndexKey{s: "", pos: i}]
if sum := len(lpats) + len(wpats); sum < min {
lmin = lpats
wmin = wpats
min = sum
}
}
}
if hasLit {
apply(lmin)
apply(wmin)
return err
}
// This pattern is all wildcards.
// Check it against everything.
for _, pats := range idx.segments {
apply(pats)
}
return err
}

View file

@ -0,0 +1,207 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
import (
"bytes"
"fmt"
"slices"
"sort"
"strings"
"testing"
)
func TestIndex(t *testing.T) {
pats := []string{"HEAD /", "/a"}
var patterns []*pattern
var idx routingIndex
for _, p := range pats {
pat := mustParsePattern(t, p)
patterns = append(patterns, pat)
idx.addPattern(pat)
}
compare := func(pat *pattern) {
t.Helper()
got := indexConflicts(pat, &idx)
want := trueConflicts(pat, patterns)
if !slices.Equal(got, want) {
t.Errorf("%q:\ngot %q\nwant %q", pat, got, want)
}
}
compare(mustParsePattern(t, "GET /foo"))
compare(mustParsePattern(t, "GET /{x}"))
}
// This test works by comparing possiblyConflictingPatterns with
// an exhaustive loop through all patterns.
func FuzzIndex(f *testing.F) {
inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}",
"/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"}
var patterns []*pattern
var idx routingIndex
// compare takes a fatalf function because fuzzing doesn't like
// it when the fuzz function calls f.Fatalf.
compare := func(pat *pattern, fatalf func(string, ...any)) {
got := indexConflicts(pat, &idx)
want := trueConflicts(pat, patterns)
if !slices.Equal(got, want) {
fatalf("%q:\ngot %q\nwant %q", pat, got, want)
}
}
for _, p := range inits {
pat, err := parsePattern(p)
if err != nil {
f.Fatal(err)
}
compare(pat, f.Fatalf)
patterns = append(patterns, pat)
idx.addPattern(pat)
f.Add(bytesFromPattern(pat))
}
f.Fuzz(func(t *testing.T, pb []byte) {
pat := bytesToPattern(pb)
if pat == nil {
return
}
compare(pat, t.Fatalf)
})
}
func trueConflicts(pat *pattern, pats []*pattern) []string {
var s []string
for _, p := range pats {
if pat.conflictsWith(p) {
s = append(s, p.String())
}
}
sort.Strings(s)
return s
}
func indexConflicts(pat *pattern, idx *routingIndex) []string {
var s []string
idx.possiblyConflictingPatterns(pat, func(p *pattern) error {
if pat.conflictsWith(p) {
s = append(s, p.String())
}
return nil
})
sort.Strings(s)
return slices.Compact(s)
}
// TODO: incorporate host and method; make encoding denser.
func bytesToPattern(bs []byte) *pattern {
if len(bs) == 0 {
return nil
}
var sb strings.Builder
wc := 0
for _, b := range bs[:len(bs)-1] {
sb.WriteByte('/')
switch b & 0x3 {
case 0:
fmt.Fprintf(&sb, "{x%d}", wc)
wc++
case 1:
sb.WriteString("a")
case 2:
sb.WriteString("b")
case 3:
sb.WriteString("c")
}
}
sb.WriteByte('/')
switch bs[len(bs)-1] & 0x7 {
case 0:
fmt.Fprintf(&sb, "{x%d}", wc)
case 1:
sb.WriteString("a")
case 2:
sb.WriteString("b")
case 3:
sb.WriteString("c")
case 4, 5:
fmt.Fprintf(&sb, "{x%d...}", wc)
default:
sb.WriteString("{$}")
}
pat, err := parsePattern(sb.String())
if err != nil {
panic(err)
}
return pat
}
func bytesFromPattern(p *pattern) []byte {
var bs []byte
for _, s := range p.segments {
var b byte
switch {
case s.multi:
b = 4
case s.wild:
b = 0
case s.s == "/":
b = 7
case s.s == "a":
b = 1
case s.s == "b":
b = 2
case s.s == "c":
b = 3
default:
panic("bad pattern")
}
bs = append(bs, b)
}
return bs
}
func TestBytesPattern(t *testing.T) {
tests := []struct {
bs []byte
pat string
}{
{[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"},
{[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"},
{[]byte{4, 4}, "/{x0}/{x1...}"},
{[]byte{6, 7}, "/b/{$}"},
}
t.Run("To", func(t *testing.T) {
for _, test := range tests {
p := bytesToPattern(test.bs)
got := p.String()
if got != test.pat {
t.Errorf("%v: got %q, want %q", test.bs, got, test.pat)
}
}
})
t.Run("From", func(t *testing.T) {
for _, test := range tests {
p, err := parsePattern(test.pat)
if err != nil {
t.Fatal(err)
}
got := bytesFromPattern(p)
var want []byte
for _, b := range test.bs[:len(test.bs)-1] {
want = append(want, b%4)
}
want = append(want, test.bs[len(test.bs)-1]%8)
if !bytes.Equal(got, want) {
t.Errorf("%s: got %v, want %v", test.pat, got, want)
}
}
})
}

View file

@ -2347,7 +2347,8 @@ func RedirectHandler(url string, code int) Handler {
type ServeMux struct {
mu sync.RWMutex
tree routingNode
patterns []*pattern
index routingIndex
patterns []*pattern // TODO(jba): remove if possible
}
// NewServeMux allocates and returns a new ServeMux.
@ -2624,8 +2625,8 @@ func (mux *ServeMux) register(pattern string, handler Handler) {
}
}
func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
if pattern == "" {
func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
if patstr == "" {
return errors.New("http: invalid pattern")
}
if handler == nil {
@ -2635,9 +2636,9 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
return errors.New("http: nil handler")
}
pat, err := parsePattern(pattern)
pat, err := parsePattern(patstr)
if err != nil {
return fmt.Errorf("parsing %q: %w", pattern, err)
return fmt.Errorf("parsing %q: %w", patstr, err)
}
// Get the caller's location, for better conflict error messages.
@ -2652,16 +2653,17 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
mux.mu.Lock()
defer mux.mu.Unlock()
// Check for conflict.
// This makes a quadratic number of calls to conflictsWith: we check
// each pattern against every other pattern.
// TODO(jba): add indexing to speed this up.
for _, pat2 := range mux.patterns {
if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error {
if pat.conflictsWith(pat2) {
return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)",
pat, pat.loc, pat2, pat2.loc)
}
return nil
}); err != nil {
return err
}
mux.tree.addPattern(pat, handler)
mux.index.addPattern(pat)
mux.patterns = append(mux.patterns, pat)
return nil
}

View file

@ -0,0 +1,2 @@
go test fuzz v1
[]byte("101$")

View file

@ -0,0 +1,2 @@
go test fuzz v1
[]byte("1010")