mirror of
https://github.com/golang/go
synced 2024-09-18 15:32:18 +00:00
net/http: index patterns for faster conflict detection
Add an index so that pattern registration isn't always quadratic. If there were no index, then every pattern that was registered would have to be compared to every existing pattern for conflicts. This would make registration quadratic in the number of patterns, in every case. The index in this CL should help most of the time. If a pattern has a literal segment, it will weed out all other patterns that have a different literal in that position. The worst case will still be quadratic, but it is unlikely that a set of such patterns would arise naturally. One novel (to me) aspect of the CL is the use of fuzz testing on data that is neither a string nor a byte slice. The test uses fuzzing to generate a byte slice, then decodes the byte slice into a valid pattern (most of the time). This test actually caught a bug: see https://go.dev/cl/529119. Change-Id: Ice0be6547decb5ce75a8062e4e17227815d5d0b0 Reviewed-on: https://go-review.googlesource.com/c/go/+/529121 Run-TryBot: Jonathan Amsterdam <jba@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
parent
bda5e6c3d0
commit
be11422b1e
124
src/net/http/routing_index.go
Normal file
124
src/net/http/routing_index.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import "math"
|
||||
|
||||
// A routingIndex optimizes conflict detection by indexing patterns.
|
||||
//
|
||||
// The basic idea is to rule out patterns that cannot conflict with a given
|
||||
// pattern because they have a different literal in a corresponding segment.
|
||||
// See the comments in [routingIndex.possiblyConflictingPatterns] for more details.
|
||||
type routingIndex struct {
|
||||
// map from a particular segment position and value to all registered patterns
|
||||
// with that value in that position.
|
||||
// For example, the key {1, "b"} would hold the patterns "/a/b" and "/a/b/c"
|
||||
// but not "/a", "b/a", "/a/c" or "/a/{x}".
|
||||
segments map[routingIndexKey][]*pattern
|
||||
// All patterns that end in a multi wildcard (including trailing slash).
|
||||
// We do not try to be clever about indexing multi patterns, because there
|
||||
// are unlikely to be many of them.
|
||||
multis []*pattern
|
||||
}
|
||||
|
||||
// A routingIndexKey names one segment value at one segment position.
// It is the key type of routingIndex.segments.
type routingIndexKey struct {
	pos int    // segment position, counting from 0
	s   string // the segment's literal ("/" for a final {$}); "" means wildcard
}
|
||||
|
||||
func (idx *routingIndex) addPattern(pat *pattern) {
|
||||
if pat.lastSegment().multi {
|
||||
idx.multis = append(idx.multis, pat)
|
||||
} else {
|
||||
if idx.segments == nil {
|
||||
idx.segments = map[routingIndexKey][]*pattern{}
|
||||
}
|
||||
for pos, seg := range pat.segments {
|
||||
key := routingIndexKey{pos: pos, s: ""}
|
||||
if !seg.wild {
|
||||
key.s = seg.s
|
||||
}
|
||||
idx.segments[key] = append(idx.segments[key], pat)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// possiblyConflictingPatterns calls f on all patterns that might conflict with
|
||||
// pat. If f returns a non-nil error, possiblyConflictingPatterns returns immediately
|
||||
// with that error.
|
||||
//
|
||||
// To be correct, possiblyConflictingPatterns must include all patterns that
|
||||
// might conflict. But it may also include patterns that cannot conflict.
|
||||
// For instance, an implementation that returns all registered patterns is correct.
|
||||
// We use this fact throughout, simplifying the implementation by returning more
|
||||
// patterns that we might need to.
|
||||
func (idx *routingIndex) possiblyConflictingPatterns(pat *pattern, f func(*pattern) error) (err error) {
|
||||
// Terminology:
|
||||
// dollar pattern: one ending in "{$}"
|
||||
// multi pattern: one ending in a trailing slash or "{x...}" wildcard
|
||||
// ordinary pattern: neither of the above
|
||||
|
||||
// apply f to all the pats, stopping on error.
|
||||
apply := func(pats []*pattern) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, p := range pats {
|
||||
err = f(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Our simple indexing scheme doesn't try to prune multi patterns; assume
|
||||
// any of them can match the argument.
|
||||
if err := apply(idx.multis); err != nil {
|
||||
return err
|
||||
}
|
||||
if pat.lastSegment().s == "/" {
|
||||
// All paths that a dollar pattern matches end in a slash; no paths that
|
||||
// an ordinary pattern matches do. So only other dollar or multi
|
||||
// patterns can conflict with a dollar pattern. Furthermore, conflicting
|
||||
// dollar patterns must have the {$} in the same position.
|
||||
return apply(idx.segments[routingIndexKey{s: "/", pos: len(pat.segments) - 1}])
|
||||
}
|
||||
// For ordinary and multi patterns, the only conflicts can be with a multi,
|
||||
// or a pattern that has the same literal or a wildcard at some literal
|
||||
// position.
|
||||
// We could intersect all the possible matches at each position, but we
|
||||
// do something simpler: we find the position with the fewest patterns.
|
||||
var lmin, wmin []*pattern
|
||||
min := math.MaxInt
|
||||
hasLit := false
|
||||
for i, seg := range pat.segments {
|
||||
if seg.multi {
|
||||
break
|
||||
}
|
||||
if !seg.wild {
|
||||
hasLit = true
|
||||
lpats := idx.segments[routingIndexKey{s: seg.s, pos: i}]
|
||||
wpats := idx.segments[routingIndexKey{s: "", pos: i}]
|
||||
if sum := len(lpats) + len(wpats); sum < min {
|
||||
lmin = lpats
|
||||
wmin = wpats
|
||||
min = sum
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasLit {
|
||||
apply(lmin)
|
||||
apply(wmin)
|
||||
return err
|
||||
}
|
||||
|
||||
// This pattern is all wildcards.
|
||||
// Check it against everything.
|
||||
for _, pats := range idx.segments {
|
||||
apply(pats)
|
||||
}
|
||||
return err
|
||||
}
|
207
src/net/http/routing_index_test.go
Normal file
207
src/net/http/routing_index_test.go
Normal file
|
@ -0,0 +1,207 @@
|
|||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIndex(t *testing.T) {
|
||||
pats := []string{"HEAD /", "/a"}
|
||||
|
||||
var patterns []*pattern
|
||||
var idx routingIndex
|
||||
for _, p := range pats {
|
||||
pat := mustParsePattern(t, p)
|
||||
patterns = append(patterns, pat)
|
||||
idx.addPattern(pat)
|
||||
}
|
||||
|
||||
compare := func(pat *pattern) {
|
||||
t.Helper()
|
||||
got := indexConflicts(pat, &idx)
|
||||
want := trueConflicts(pat, patterns)
|
||||
if !slices.Equal(got, want) {
|
||||
t.Errorf("%q:\ngot %q\nwant %q", pat, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
compare(mustParsePattern(t, "GET /foo"))
|
||||
compare(mustParsePattern(t, "GET /{x}"))
|
||||
}
|
||||
|
||||
// This test works by comparing possiblyConflictingPatterns with
|
||||
// an exhaustive loop through all patterns.
|
||||
func FuzzIndex(f *testing.F) {
|
||||
inits := []string{"/a", "/a/b", "/{x0}", "/{x0}/b", "/a/{x0}", "/a/{$}", "/a/b/{$}",
|
||||
"/a/", "/a/b/", "/{x}/b/c/{$}", "GET /{x0}/", "HEAD /a"}
|
||||
|
||||
var patterns []*pattern
|
||||
var idx routingIndex
|
||||
|
||||
// compare takes a fatalf function because fuzzing doesn't like
|
||||
// it when the fuzz function calls f.Fatalf.
|
||||
compare := func(pat *pattern, fatalf func(string, ...any)) {
|
||||
got := indexConflicts(pat, &idx)
|
||||
want := trueConflicts(pat, patterns)
|
||||
if !slices.Equal(got, want) {
|
||||
fatalf("%q:\ngot %q\nwant %q", pat, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range inits {
|
||||
pat, err := parsePattern(p)
|
||||
if err != nil {
|
||||
f.Fatal(err)
|
||||
}
|
||||
compare(pat, f.Fatalf)
|
||||
patterns = append(patterns, pat)
|
||||
idx.addPattern(pat)
|
||||
f.Add(bytesFromPattern(pat))
|
||||
}
|
||||
|
||||
f.Fuzz(func(t *testing.T, pb []byte) {
|
||||
pat := bytesToPattern(pb)
|
||||
if pat == nil {
|
||||
return
|
||||
}
|
||||
compare(pat, t.Fatalf)
|
||||
})
|
||||
}
|
||||
|
||||
func trueConflicts(pat *pattern, pats []*pattern) []string {
|
||||
var s []string
|
||||
for _, p := range pats {
|
||||
if pat.conflictsWith(p) {
|
||||
s = append(s, p.String())
|
||||
}
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func indexConflicts(pat *pattern, idx *routingIndex) []string {
|
||||
var s []string
|
||||
idx.possiblyConflictingPatterns(pat, func(p *pattern) error {
|
||||
if pat.conflictsWith(p) {
|
||||
s = append(s, p.String())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
sort.Strings(s)
|
||||
return slices.Compact(s)
|
||||
}
|
||||
|
||||
// TODO: incorporate host and method; make encoding denser.
|
||||
func bytesToPattern(bs []byte) *pattern {
|
||||
if len(bs) == 0 {
|
||||
return nil
|
||||
}
|
||||
var sb strings.Builder
|
||||
wc := 0
|
||||
for _, b := range bs[:len(bs)-1] {
|
||||
sb.WriteByte('/')
|
||||
switch b & 0x3 {
|
||||
case 0:
|
||||
fmt.Fprintf(&sb, "{x%d}", wc)
|
||||
wc++
|
||||
case 1:
|
||||
sb.WriteString("a")
|
||||
case 2:
|
||||
sb.WriteString("b")
|
||||
case 3:
|
||||
sb.WriteString("c")
|
||||
}
|
||||
}
|
||||
sb.WriteByte('/')
|
||||
switch bs[len(bs)-1] & 0x7 {
|
||||
case 0:
|
||||
fmt.Fprintf(&sb, "{x%d}", wc)
|
||||
case 1:
|
||||
sb.WriteString("a")
|
||||
case 2:
|
||||
sb.WriteString("b")
|
||||
case 3:
|
||||
sb.WriteString("c")
|
||||
case 4, 5:
|
||||
fmt.Fprintf(&sb, "{x%d...}", wc)
|
||||
default:
|
||||
sb.WriteString("{$}")
|
||||
}
|
||||
pat, err := parsePattern(sb.String())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pat
|
||||
}
|
||||
|
||||
func bytesFromPattern(p *pattern) []byte {
|
||||
var bs []byte
|
||||
for _, s := range p.segments {
|
||||
var b byte
|
||||
switch {
|
||||
case s.multi:
|
||||
b = 4
|
||||
case s.wild:
|
||||
b = 0
|
||||
case s.s == "/":
|
||||
b = 7
|
||||
case s.s == "a":
|
||||
b = 1
|
||||
case s.s == "b":
|
||||
b = 2
|
||||
case s.s == "c":
|
||||
b = 3
|
||||
default:
|
||||
panic("bad pattern")
|
||||
}
|
||||
bs = append(bs, b)
|
||||
}
|
||||
return bs
|
||||
}
|
||||
|
||||
func TestBytesPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
bs []byte
|
||||
pat string
|
||||
}{
|
||||
{[]byte{0, 1, 2, 3}, "/{x0}/a/b/c"},
|
||||
{[]byte{16, 17, 18, 19}, "/{x0}/a/b/c"},
|
||||
{[]byte{4, 4}, "/{x0}/{x1...}"},
|
||||
{[]byte{6, 7}, "/b/{$}"},
|
||||
}
|
||||
t.Run("To", func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
p := bytesToPattern(test.bs)
|
||||
got := p.String()
|
||||
if got != test.pat {
|
||||
t.Errorf("%v: got %q, want %q", test.bs, got, test.pat)
|
||||
}
|
||||
}
|
||||
})
|
||||
t.Run("From", func(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
p, err := parsePattern(test.pat)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := bytesFromPattern(p)
|
||||
var want []byte
|
||||
for _, b := range test.bs[:len(test.bs)-1] {
|
||||
want = append(want, b%4)
|
||||
|
||||
}
|
||||
want = append(want, test.bs[len(test.bs)-1]%8)
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("%s: got %v, want %v", test.pat, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
|
@ -2347,7 +2347,8 @@ func RedirectHandler(url string, code int) Handler {
|
|||
type ServeMux struct {
|
||||
mu sync.RWMutex
|
||||
tree routingNode
|
||||
patterns []*pattern
|
||||
index routingIndex
|
||||
patterns []*pattern // TODO(jba): remove if possible
|
||||
}
|
||||
|
||||
// NewServeMux allocates and returns a new ServeMux.
|
||||
|
@ -2624,8 +2625,8 @@ func (mux *ServeMux) register(pattern string, handler Handler) {
|
|||
}
|
||||
}
|
||||
|
||||
func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
|
||||
if pattern == "" {
|
||||
func (mux *ServeMux) registerErr(patstr string, handler Handler) error {
|
||||
if patstr == "" {
|
||||
return errors.New("http: invalid pattern")
|
||||
}
|
||||
if handler == nil {
|
||||
|
@ -2635,9 +2636,9 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
|
|||
return errors.New("http: nil handler")
|
||||
}
|
||||
|
||||
pat, err := parsePattern(pattern)
|
||||
pat, err := parsePattern(patstr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing %q: %w", pattern, err)
|
||||
return fmt.Errorf("parsing %q: %w", patstr, err)
|
||||
}
|
||||
|
||||
// Get the caller's location, for better conflict error messages.
|
||||
|
@ -2652,16 +2653,17 @@ func (mux *ServeMux) registerErr(pattern string, handler Handler) error {
|
|||
mux.mu.Lock()
|
||||
defer mux.mu.Unlock()
|
||||
// Check for conflict.
|
||||
// This makes a quadratic number of calls to conflictsWith: we check
|
||||
// each pattern against every other pattern.
|
||||
// TODO(jba): add indexing to speed this up.
|
||||
for _, pat2 := range mux.patterns {
|
||||
if err := mux.index.possiblyConflictingPatterns(pat, func(pat2 *pattern) error {
|
||||
if pat.conflictsWith(pat2) {
|
||||
return fmt.Errorf("pattern %q (registered at %s) conflicts with pattern %q (registered at %s)",
|
||||
pat, pat.loc, pat2, pat2.loc)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
mux.tree.addPattern(pat, handler)
|
||||
mux.index.addPattern(pat)
|
||||
mux.patterns = append(mux.patterns, pat)
|
||||
return nil
|
||||
}
|
||||
|
|
2
src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da
vendored
Normal file
2
src/net/http/testdata/fuzz/FuzzIndex/48161038f0c8b2da
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
go test fuzz v1
|
||||
[]byte("101$")
|
2
src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3
vendored
Normal file
2
src/net/http/testdata/fuzz/FuzzIndex/716514f590ce7ab3
vendored
Normal file
|
@ -0,0 +1,2 @@
|
|||
go test fuzz v1
|
||||
[]byte("1010")
|
Loading…
Reference in a new issue