diff --git a/src/bytes/bytes.go b/src/bytes/bytes.go index 529d95a888..9e6b68eaf4 100644 --- a/src/bytes/bytes.go +++ b/src/bytes/bytes.go @@ -21,7 +21,7 @@ func Equal(a, b []byte) bool { } // Compare returns an integer comparing two byte slices lexicographically. -// The result will be 0 if a==b, -1 if a < b, and +1 if a > b. +// The result will be 0 if a == b, -1 if a < b, and +1 if a > b. // A nil argument is equivalent to an empty slice. func Compare(a, b []byte) int { return bytealg.Compare(a, b) diff --git a/src/crypto/x509/root_darwin.go b/src/crypto/x509/root_darwin.go index 164ad9dc77..ef051efd31 100644 --- a/src/crypto/x509/root_darwin.go +++ b/src/crypto/x509/root_darwin.go @@ -10,11 +10,11 @@ import ( "bytes" macOS "crypto/x509/internal/macos" "fmt" + "internal/godebug" "os" - "strings" ) -var debugDarwinRoots = strings.Contains(os.Getenv("GODEBUG"), "x509roots=1") +var debugDarwinRoots = godebug.Get("x509roots") == "1" func (c *Certificate) systemVerify(opts *VerifyOptions) (chains [][]*Certificate, err error) { return nil, nil diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index f4a92f8be4..1dd65d60d9 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -173,7 +173,7 @@ var depsRules = ` io/fs < embed; - unicode, fmt !< os, os/signal; + unicode, fmt !< net, os, os/signal; os/signal, STR < path/filepath @@ -187,6 +187,8 @@ var depsRules = ` OS < golang.org/x/sys/cpu; + os < internal/godebug; + # FMT is OS (which includes string routines) plus reflect and fmt. # It does not include package log, which should be avoided in core packages. strconv, unicode @@ -352,6 +354,13 @@ var depsRules = ` golang.org/x/net/lif, golang.org/x/net/route; + os, runtime, strconv, sync, unsafe, + internal/godebug + < internal/intern; + + internal/bytealg, internal/intern, internal/itoa, math/bits, sort, strconv + < net/netip; + # net is unavoidable when doing any networking, # so large dependencies must be kept out. # This is a long-looking list but most of these @@ -360,10 +369,12 @@ var depsRules = ` golang.org/x/net/dns/dnsmessage, golang.org/x/net/lif, golang.org/x/net/route, + internal/godebug, internal/nettrace, internal/poll, internal/singleflight, internal/race, + net/netip, os < net; @@ -515,7 +526,8 @@ var depsRules = ` FMT, DEBUG, flag, runtime/trace, internal/sysinfo, math/rand < testing; - FMT, crypto/sha256, encoding/json, go/ast, go/parser, go/token, math/rand, encoding/hex, crypto/sha256 + FMT, crypto/sha256, encoding/json, go/ast, go/parser, go/token, + internal/godebug, math/rand, encoding/hex, crypto/sha256 < internal/fuzz; internal/fuzz, internal/testlog, runtime/pprof, regexp diff --git a/src/internal/fuzz/fuzz.go b/src/internal/fuzz/fuzz.go index 78319a7496..2ebe2a64db 100644 --- a/src/internal/fuzz/fuzz.go +++ b/src/internal/fuzz/fuzz.go @@ -12,6 +12,7 @@ import ( "crypto/sha256" "errors" "fmt" + "internal/godebug" "io" "io/ioutil" "math/bits" @@ -1063,13 +1064,7 @@ var ( func shouldPrintDebugInfo() bool { debugInfoOnce.Do(func() { - debug := strings.Split(os.Getenv("GODEBUG"), ",") - for _, f := range debug { - if f == "fuzzdebug=1" { - debugInfo = true - break - } - } + debugInfo = godebug.Get("fuzzdebug") == "1" }) return debugInfo } diff --git a/src/internal/godebug/godebug.go b/src/internal/godebug/godebug.go new file mode 100644 index 0000000000..ac434e5fd8 --- /dev/null +++ b/src/internal/godebug/godebug.go @@ -0,0 +1,34 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package godebug parses the GODEBUG environment variable. +package godebug + +import "os" + +// Get returns the value for the provided GODEBUG key. +func Get(key string) string { + return get(os.Getenv("GODEBUG"), key) +} + +// get returns the value part of key=value in s (a GODEBUG value). +func get(s, key string) string { + for i := 0; i < len(s)-len(key)-1; i++ { + if i > 0 && s[i-1] != ',' { + continue + } + afterKey := s[i+len(key):] + if afterKey[0] != '=' || s[i:i+len(key)] != key { + continue + } + val := afterKey[1:] + for i, b := range val { + if b == ',' { + return val[:i] + } + } + return val + } + return "" +} diff --git a/src/internal/godebug/godebug_test.go b/src/internal/godebug/godebug_test.go new file mode 100644 index 0000000000..41b9117b73 --- /dev/null +++ b/src/internal/godebug/godebug_test.go @@ -0,0 +1,34 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package godebug + +import "testing" + +func TestGet(t *testing.T) { + tests := []struct { + godebug string + key string + want string + }{ + {"", "", ""}, + {"", "foo", ""}, + {"foo=bar", "foo", "bar"}, + {"foo=bar,after=x", "foo", "bar"}, + {"before=x,foo=bar,after=x", "foo", "bar"}, + {"before=x,foo=bar", "foo", "bar"}, + {",,,foo=bar,,,", "foo", "bar"}, + {"foodecoy=wrong,foo=bar", "foo", "bar"}, + {"foo=", "foo", ""}, + {"foo", "foo", ""}, + {",foo", "foo", ""}, + {"foo=bar,baz", "loooooooong", ""}, + } + for _, tt := range tests { + got := get(tt.godebug, tt.key) + if got != tt.want { + t.Errorf("get(%q, %q) = %q; want %q", tt.godebug, tt.key, got, tt.want) + } + } +} diff --git a/src/internal/intern/intern.go b/src/internal/intern/intern.go new file mode 100644 index 0000000000..666caa6d2f --- /dev/null +++ b/src/internal/intern/intern.go @@ -0,0 +1,178 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package intern lets you make smaller comparable values by boxing +// a larger comparable value (such as a 16 byte string header) down +// into a globally unique 8 byte pointer. +// +// The globally unique pointers are garbage collected with weak +// references and finalizers. This package hides that. +package intern + +import ( + "internal/godebug" + "runtime" + "sync" + "unsafe" +) + +// A Value pointer is the handle to an underlying comparable value. +// See func Get for how Value pointers may be used. +type Value struct { + _ [0]func() // prevent people from accidentally using value type as comparable + cmpVal interface{} + // resurrected is guarded by mu (for all instances of Value). + // It is set true whenever v is synthesized from a uintptr. + resurrected bool +} + +// Get returns the comparable value passed to the Get func +// that returned v. +func (v *Value) Get() interface{} { return v.cmpVal } + +// key is a key in our global value map. +// It contains type-specialized fields to avoid allocations +// when converting common types to empty interfaces. +type key struct { + s string + cmpVal interface{} + // isString reports whether key contains a string. + // Without it, the zero value of key is ambiguous. + isString bool +} + +// keyFor returns a key to use with cmpVal. +func keyFor(cmpVal interface{}) key { + if s, ok := cmpVal.(string); ok { + return key{s: s, isString: true} + } + return key{cmpVal: cmpVal} +} + +// Value returns a *Value built from k. +func (k key) Value() *Value { + if k.isString { + return &Value{cmpVal: k.s} + } + return &Value{cmpVal: k.cmpVal} +} + +var ( + // mu guards valMap, a weakref map of *Value by underlying value. + // It also guards the resurrected field of all *Values. + mu sync.Mutex + valMap = map[key]uintptr{} // to uintptr(*Value) + valSafe = safeMap() // non-nil in safe+leaky mode +) + +// safeMap returns a non-nil map if we're in safe-but-leaky mode, +// as controlled by GODEBUG=intern=leaky +func safeMap() map[key]*Value { + if godebug.Get("intern") == "leaky" { + return map[key]*Value{} + } + return nil +} + +// Get returns a pointer representing the comparable value cmpVal. +// +// The returned pointer will be the same for Get(v) and Get(v2) +// if and only if v == v2, and can be used as a map key. +func Get(cmpVal interface{}) *Value { + return get(keyFor(cmpVal)) +} + +// GetByString is identical to Get, except that it is specialized for strings. +// This avoids an allocation from putting a string into an interface{} +// to pass as an argument to Get. +func GetByString(s string) *Value { + return get(key{s: s, isString: true}) +} + +// We play unsafe games that violate Go's rules (and assume a non-moving +// collector). So we quiet Go here. +// See the comment below Get for more implementation details. +//go:nocheckptr +func get(k key) *Value { + mu.Lock() + defer mu.Unlock() + + var v *Value + if valSafe != nil { + v = valSafe[k] + } else if addr, ok := valMap[k]; ok { + v = (*Value)(unsafe.Pointer(addr)) + v.resurrected = true + } + if v != nil { + return v + } + v = k.Value() + if valSafe != nil { + valSafe[k] = v + } else { + // SetFinalizer before uintptr conversion (theoretical concern; + // see https://github.com/go4org/intern/issues/13) + runtime.SetFinalizer(v, finalize) + valMap[k] = uintptr(unsafe.Pointer(v)) + } + return v +} + +func finalize(v *Value) { + mu.Lock() + defer mu.Unlock() + if v.resurrected { + // We lost the race. Somebody resurrected it while we + // were about to finalize it. Try again next round. + v.resurrected = false + runtime.SetFinalizer(v, finalize) + return + } + delete(valMap, keyFor(v.cmpVal)) +} + +// Interning is simple if you don't require that unused values be +// garbage collectable. But we do require that; we don't want to be +// DOS vector. We do this by using a uintptr to hide the pointer from +// the garbage collector, and using a finalizer to eliminate the +// pointer when no other code is using it. +// +// The obvious implementation of this is to use a +// map[interface{}]uintptr-of-*interface{}, and set up a finalizer to +// delete from the map. Unfortunately, this is racy. Because pointers +// are being created in violation of Go's unsafety rules, it's +// possible to create a pointer to a value concurrently with the GC +// concluding that the value can be collected. There are other races +// that break the equality invariant as well, but the use-after-free +// will cause a runtime crash. +// +// To make this work, the finalizer needs to know that no references +// have been unsafely created since the finalizer was set up. To do +// this, values carry a "resurrected" sentinel, which gets set +// whenever a pointer is unsafely created. If the finalizer encounters +// the sentinel, it clears the sentinel and delays collection for one +// additional GC cycle, by re-installing itself as finalizer. This +// ensures that the unsafely created pointer is visible to the GC, and +// will correctly prevent collection. +// +// This technique does mean that interned values that get reused take +// at least 3 GC cycles to fully collect (1 to clear the sentinel, 1 +// to clean up the unsafe map, 1 to be actually deleted). +// +// @ianlancetaylor commented in +// https://github.com/golang/go/issues/41303#issuecomment-717401656 +// that it is possible to implement weak references in terms of +// finalizers without unsafe. Unfortunately, the approach he outlined +// does not work here, for two reasons. First, there is no way to +// construct a strong pointer out of a weak pointer; our map stores +// weak pointers, but we must return strong pointers to callers. +// Second, and more fundamentally, we must return not just _a_ strong +// pointer to callers, but _the same_ strong pointer to callers. In +// order to return _the same_ strong pointer to callers, we must track +// it, which is exactly what we cannot do with strong pointers. +// +// See https://github.com/inetaf/netaddr/issues/53 for more +// discussion, and https://github.com/go4org/intern/issues/2 for an +// illustration of the subtleties at play. diff --git a/src/internal/intern/intern_test.go b/src/internal/intern/intern_test.go new file mode 100644 index 0000000000..d1e409ef95 --- /dev/null +++ b/src/internal/intern/intern_test.go @@ -0,0 +1,199 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package intern + +import ( + "fmt" + "runtime" + "testing" +) + +func TestBasics(t *testing.T) { + clearMap() + foo := Get("foo") + bar := Get("bar") + empty := Get("") + nilEface := Get(nil) + i := Get(0x7777777) + foo2 := Get("foo") + bar2 := Get("bar") + empty2 := Get("") + nilEface2 := Get(nil) + i2 := Get(0x7777777) + foo3 := GetByString("foo") + empty3 := GetByString("") + + if foo.Get() != foo2.Get() { + t.Error("foo/foo2 values differ") + } + if foo.Get() != foo3.Get() { + t.Error("foo/foo3 values differ") + } + if foo.Get() != "foo" { + t.Error("foo.Get not foo") + } + if foo != foo2 { + t.Error("foo/foo2 pointers differ") + } + if foo != foo3 { + t.Error("foo/foo3 pointers differ") + } + + if bar.Get() != bar2.Get() { + t.Error("bar values differ") + } + if bar.Get() != "bar" { + t.Error("bar.Get not bar") + } + if bar != bar2 { + t.Error("bar pointers differ") + } + + if i.Get() != i.Get() { + t.Error("i values differ") + } + if i.Get() != 0x7777777 { + t.Error("i.Get not 0x7777777") + } + if i != i2 { + t.Error("i pointers differ") + } + + if empty.Get() != empty2.Get() { + t.Error("empty/empty2 values differ") + } + if empty.Get() != empty.Get() { + t.Error("empty/empty3 values differ") + } + if empty.Get() != "" { + t.Error("empty.Get not empty string") + } + if empty != empty2 { + t.Error("empty/empty2 pointers differ") + } + if empty != empty3 { + t.Error("empty/empty3 pointers differ") + } + + if nilEface.Get() != nilEface2.Get() { + t.Error("nilEface values differ") + } + if nilEface.Get() != nil { + t.Error("nilEface.Get not nil") + } + if nilEface != nilEface2 { + t.Error("nilEface pointers differ") + } + + if n := mapLen(); n != 5 { + t.Errorf("map len = %d; want 4", n) + } + + wantEmpty(t) +} + +func wantEmpty(t testing.TB) { + t.Helper() + const gcTries = 5000 + for try := 0; try < gcTries; try++ { + runtime.GC() + n := mapLen() + if n == 0 { + break + } + if try == gcTries-1 { + t.Errorf("map len = %d after (%d GC tries); want 0, contents: %v", n, gcTries, mapKeys()) + } + } +} + +func TestStress(t *testing.T) { + iters := 10000 + if testing.Short() { + iters = 1000 + } + var sink []byte + for i := 0; i < iters; i++ { + _ = Get("foo") + sink = make([]byte, 1<<20) + } + _ = sink +} + +func BenchmarkStress(b *testing.B) { + done := make(chan struct{}) + defer close(done) + go func() { + for { + select { + case <-done: + return + default: + } + runtime.GC() + } + }() + + clearMap() + v1 := Get("foo") + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + v2 := Get("foo") + if v1 != v2 { + b.Fatal("wrong value") + } + // And also a key we don't retain: + _ = Get("bar") + } + }) + runtime.GC() + wantEmpty(b) +} + +func mapLen() int { + mu.Lock() + defer mu.Unlock() + return len(valMap) +} + +func mapKeys() (keys []string) { + mu.Lock() + defer mu.Unlock() + for k := range valMap { + keys = append(keys, fmt.Sprint(k)) + } + return keys +} + +func clearMap() { + mu.Lock() + defer mu.Unlock() + for k := range valMap { + delete(valMap, k) + } +} + +var ( + globalString = "not a constant" + sink string +) + +func TestGetByStringAllocs(t *testing.T) { + allocs := int(testing.AllocsPerRun(100, func() { + GetByString(globalString) + })) + if allocs != 0 { + t.Errorf("GetString allocated %d objects, want 0", allocs) + } +} + +func BenchmarkGetByString(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + v := GetByString(globalString) + sink = v.Get().(string) + } +} diff --git a/src/net/conf.go b/src/net/conf.go index 1115699ab0..415caedacc 100644 --- a/src/net/conf.go +++ b/src/net/conf.go @@ -8,6 +8,7 @@ package net import ( "internal/bytealg" + "internal/godebug" "os" "runtime" "sync" @@ -286,7 +287,7 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde // cgo+2 // same, but debug level 2 // etc. func goDebugNetDNS() (dnsMode string, debugLevel int) { - goDebug := goDebugString("netdns") + goDebug := godebug.Get("netdns") parsePart := func(s string) { if s == "" { return diff --git a/src/net/http/server.go b/src/net/http/server.go index e9b0b4d9bd..91fad68694 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "math/rand" @@ -20,7 +21,6 @@ import ( "net/textproto" "net/url" urlpkg "net/url" - "os" "path" "runtime" "sort" @@ -3296,7 +3296,7 @@ func (srv *Server) onceSetNextProtoDefaults_Serve() { // configured otherwise. (by setting srv.TLSNextProto non-nil) // It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*). func (srv *Server) onceSetNextProtoDefaults() { - if omitBundledHTTP2 || strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { + if omitBundledHTTP2 || godebug.Get("http2server") == "0" { return } // Enable HTTP/2 by default if the user hasn't otherwise diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 0e60992e6c..05a1659136 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -17,6 +17,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "net" @@ -24,7 +25,6 @@ import ( "net/http/internal/ascii" "net/textproto" "net/url" - "os" "reflect" "strings" "sync" @@ -360,7 +360,7 @@ func (t *Transport) hasCustomTLSDialer() bool { // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { t.tlsNextProtoWasNil = (t.TLSNextProto == nil) - if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { + if godebug.Get("http2client") == "0" { return } diff --git a/src/net/lookup.go b/src/net/lookup.go index fe573b8a27..e10c71ae75 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -8,6 +8,7 @@ import ( "context" "internal/nettrace" "internal/singleflight" + "net/netip" "sync" ) @@ -232,6 +233,28 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, er return ips, nil } +// LookupNetIP looks up host using the local resolver. +// It returns a slice of that host's IP addresses of the type specified by +// network. +// The network must be one of "ip", "ip4" or "ip6". +func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + // TODO(bradfitz): make this efficient, making the internal net package + // type throughout be netip.Addr and only converting to the net.IP slice + // version at the edge. But for now (2021-10-20), this is a wrapper around + // the old way. + ips, err := r.LookupIP(ctx, network, host) + if err != nil { + return nil, err + } + ret := make([]netip.Addr, 0, len(ips)) + for _, ip := range ips { + if a, ok := netip.AddrFromSlice(ip); ok { + ret = append(ret, a) + } + } + return ret, nil +} + // onlyValuesCtx is a context that uses an underlying context // for value lookup if the underlying context hasn't yet expired. type onlyValuesCtx struct { diff --git a/src/net/netip/export_test.go b/src/net/netip/export_test.go new file mode 100644 index 0000000000..59971fa2e4 --- /dev/null +++ b/src/net/netip/export_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import "internal/intern" + +var ( + Z0 = z0 + Z4 = z4 + Z6noz = z6noz +) + +type Uint128 = uint128 + +func Mk128(hi, lo uint64) Uint128 { + return uint128{hi, lo} +} + +func MkAddr(u Uint128, z *intern.Value) Addr { + return Addr{u, z} +} + +func IPv4(a, b, c, d uint8) Addr { return AddrFrom4([4]byte{a, b, c, d}) } + +var TestAppendToMarshal = testAppendToMarshal + +func (a Addr) IsZero() bool { return a.isZero() } +func (p Prefix) IsZero() bool { return p.isZero() } diff --git a/src/net/netip/inlining_test.go b/src/net/netip/inlining_test.go new file mode 100644 index 0000000000..107fe1f083 --- /dev/null +++ b/src/net/netip/inlining_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "internal/testenv" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + "testing" +) + +func TestInlining(t *testing.T) { + testenv.MustHaveGoBuild(t) + t.Parallel() + var exe string + if runtime.GOOS == "windows" { + exe = ".exe" + } + out, err := exec.Command( + filepath.Join(runtime.GOROOT(), "bin", "go"+exe), + "build", + "--gcflags=-m", + "net/netip").CombinedOutput() + if err != nil { + t.Fatalf("go build: %v, %s", err, out) + } + got := map[string]bool{} + regexp.MustCompile(` can inline (\S+)`).ReplaceAllFunc(out, func(match []byte) []byte { + got[strings.TrimPrefix(string(match), " can inline ")] = true + return nil + }) + wantInlinable := []string{ + "(*uint128).halves", + "Addr.BitLen", + "Addr.hasZone", + "Addr.Is4", + "Addr.Is4In6", + "Addr.Is6", + "Addr.IsLoopback", + "Addr.IsMulticast", + "Addr.IsInterfaceLocalMulticast", + "Addr.IsValid", + "Addr.IsUnspecified", + "Addr.Less", + "Addr.lessOrEq", + "Addr.Unmap", + "Addr.Zone", + "Addr.v4", + "Addr.v6", + "Addr.v6u16", + "Addr.withoutZone", + "AddrPortFrom", + "AddrPort.Addr", + "AddrPort.Port", + "AddrPort.IsValid", + "Prefix.IsSingleIP", + "Prefix.Masked", + "Prefix.IsValid", + "PrefixFrom", + "Prefix.Addr", + "Prefix.Bits", + "AddrFrom4", + "IPv6LinkLocalAllNodes", + "IPv6Unspecified", + "MustParseAddr", + "MustParseAddrPort", + "MustParsePrefix", + "appendDecimal", + "appendHex", + "uint128.addOne", + "uint128.and", + "uint128.bitsClearedFrom", + "uint128.bitsSetFrom", + "uint128.isZero", + "uint128.not", + "uint128.or", + "uint128.subOne", + "uint128.xor", + } + switch runtime.GOARCH { + case "amd64", "arm64": + // These don't inline on 32-bit. + wantInlinable = append(wantInlinable, + "u64CommonPrefixLen", + "uint128.commonPrefixLen", + "Addr.Next", + "Addr.Prev", + ) + } + + for _, want := range wantInlinable { + if !got[want] { + t.Errorf("%q is no longer inlinable", want) + continue + } + delete(got, want) + } + for sym := range got { + if strings.Contains(sym, ".func") { + continue + } + t.Logf("not in expected set, but also inlinable: %q", sym) + + } +} diff --git a/src/net/netip/leaf_alts.go b/src/net/netip/leaf_alts.go new file mode 100644 index 0000000000..c51f7dfa54 --- /dev/null +++ b/src/net/netip/leaf_alts.go @@ -0,0 +1,43 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Stuff that exists in std, but we can't use due to being a dependency +// of net, for go/build deps_test policy reasons. + +package netip + +func stringsLastIndexByte(s string, b byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == b { + return i + } + } + return -1 +} + +func beUint64(b []byte) uint64 { + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | + uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 +} + +func bePutUint64(b []byte, v uint64) { + _ = b[7] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func bePutUint32(b []byte, v uint32) { + _ = b[3] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} diff --git a/src/net/netip/netip.go b/src/net/netip/netip.go new file mode 100644 index 0000000000..4ef3b4bb68 --- /dev/null +++ b/src/net/netip/netip.go @@ -0,0 +1,1414 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package netip defines a IP address type that's a small value type. +// Building on that Addr type, the package also defines AddrPort (an +// IP address and a port), and Prefix (a IP address and a bit length +// prefix). +// +// Compared to the net.IP type, this package's Addr type takes less +// memory, is immutable, and is comparable (supports == and being a +// map key). +package netip + +import ( + "errors" + "math" + "strconv" + + "internal/bytealg" + "internal/intern" + "internal/itoa" +) + +// Sizes: (64-bit) +// net.IP: 24 byte slice header + {4, 16} = 28 to 40 bytes +// net.IPAddr: 40 byte slice header + {4, 16} = 44 to 56 bytes + zone length +// netip.Addr: 24 bytes (zone is per-name singleton, shared across all users) + +// Addr represents an IPv4 or IPv6 address (with or without a scoped +// addressing zone), similar to net.IP or net.IPAddr. +// +// Unlike net.IP or net.IPAddr, Addr is a comparable value +// type (it supports == and can be a map key) and is immutable. +type Addr struct { + // addr is the hi and lo bits of an IPv6 address. If z==z4, + // hi and lo contain the IPv4-mapped IPv6 address. + // + // hi and lo are constructed by interpreting a 16-byte IPv6 + // address as a big-endian 128-bit number. The most significant + // bits of that number go into hi, the rest into lo. + // + // For example, 0011:2233:4455:6677:8899:aabb:ccdd:eeff is stored as: + // addr.hi = 0x0011223344556677 + // addr.lo = 0x8899aabbccddeeff + // + // We store IPs like this, rather than as [16]byte, because it + // turns most operations on IPs into arithmetic and bit-twiddling + // operations on 64-bit registers, which is much faster than + // bytewise processing. + addr uint128 + + // z is a combination of the address family and the IPv6 zone. + // + // nil means invalid IP address (for a zero Addr). + // z4 means an IPv4 address. + // z6noz means an IPv6 address without a zone. + // + // Otherwise it's the interned zone name string. + z *intern.Value +} + +// z0, z4, and z6noz are sentinel IP.z values. +// See the IP type's field docs. +var ( + z0 = (*intern.Value)(nil) + z4 = new(intern.Value) + z6noz = new(intern.Value) +) + +// IPv6LinkLocalAllNodes returns the IPv6 link-local all nodes multicast +// address ff02::1. +func IPv6LinkLocalAllNodes() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x01}) } + +// IPv6Unspecified returns the IPv6 unspecified address "::". +func IPv6Unspecified() Addr { return Addr{z: z6noz} } + +// AddrFrom4 returns the address of the IPv4 address given by the bytes in addr. +func AddrFrom4(addr [4]byte) Addr { + return Addr{ + addr: uint128{0, 0xffff00000000 | uint64(addr[0])<<24 | uint64(addr[1])<<16 | uint64(addr[2])<<8 | uint64(addr[3])}, + z: z4, + } +} + +// AddrFrom16 returns the IPv6 address given by the bytes in addr. +// An IPv6-mapped IPv4 address is left as an IPv6 address. +// (Use Unmap to convert them if needed.) +func AddrFrom16(addr [16]byte) Addr { + return Addr{ + addr: uint128{ + beUint64(addr[:8]), + beUint64(addr[8:]), + }, + z: z6noz, + } +} + +// ipv6Slice is like IPv6Raw, but operates on a 16-byte slice. Assumes +// slice is 16 bytes, caller must enforce this. +func ipv6Slice(addr []byte) Addr { + return Addr{ + addr: uint128{ + beUint64(addr[:8]), + beUint64(addr[8:]), + }, + z: z6noz, + } +} + +// ParseAddr parses s as an IP address, returning the result. The string +// s can be in dotted decimal ("192.0.2.1"), IPv6 ("2001:db8::68"), +// or IPv6 with a scoped addressing zone ("fe80::1cc0:3e8c:119f:c2e1%ens18"). +func ParseAddr(s string) (Addr, error) { + for i := 0; i < len(s); i++ { + switch s[i] { + case '.': + return parseIPv4(s) + case ':': + return parseIPv6(s) + case '%': + // Assume that this was trying to be an IPv6 address with + // a zone specifier, but the address is missing. + return Addr{}, parseAddrError{in: s, msg: "missing IPv6 address"} + } + } + return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"} +} + +// MustParseAddr calls ParseAddr(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParseAddr(s string) Addr { + ip, err := ParseAddr(s) + if err != nil { + panic(err) + } + return ip +} + +type parseAddrError struct { + in string // the string given to ParseAddr + msg string // an explanation of the parse failure + at string // optionally, the unparsed portion of in at which the error occurred. +} + +func (err parseAddrError) Error() string { + q := strconv.Quote + if err.at != "" { + return "ParseAddr(" + q(err.in) + "): " + err.msg + " (at " + q(err.at) + ")" + } + return "ParseAddr(" + q(err.in) + "): " + err.msg +} + +// parseIPv4 parses s as an IPv4 address (in form "192.168.0.1"). +func parseIPv4(s string) (ip Addr, err error) { + var fields [4]uint8 + var val, pos int + for i := 0; i < len(s); i++ { + if s[i] >= '0' && s[i] <= '9' { + val = val*10 + int(s[i]) - '0' + if val > 255 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 field has value >255"} + } + } else if s[i] == '.' { + // .1.2.3 + // 1.2.3. + // 1..2.3 + if i == 0 || i == len(s)-1 || s[i-1] == '.' { + return Addr{}, parseAddrError{in: s, msg: "IPv4 field must have at least one digit", at: s[i:]} + } + // 1.2.3.4.5 + if pos == 3 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 address too long"} + } + fields[pos] = uint8(val) + pos++ + val = 0 + } else { + return Addr{}, parseAddrError{in: s, msg: "unexpected character", at: s[i:]} + } + } + if pos < 3 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 address too short"} + } + fields[3] = uint8(val) + return AddrFrom4(fields), nil +} + +// parseIPv6 parses s as an IPv6 address (in form "2001:db8::68"). +func parseIPv6(in string) (Addr, error) { + s := in + + // Split off the zone right from the start. Yes it's a second scan + // of the string, but trying to handle it inline makes a bunch of + // other inner loop conditionals more expensive, and it ends up + // being slower. + zone := "" + i := bytealg.IndexByteString(s, '%') + if i != -1 { + s, zone = s[:i], s[i+1:] + if zone == "" { + // Not allowed to have an empty zone if explicitly specified. + return Addr{}, parseAddrError{in: in, msg: "zone must be a non-empty string"} + } + } + + var ip [16]byte + ellipsis := -1 // position of ellipsis in ip + + // Might have leading ellipsis + if len(s) >= 2 && s[0] == ':' && s[1] == ':' { + ellipsis = 0 + s = s[2:] + // Might be only ellipsis + if len(s) == 0 { + return IPv6Unspecified().WithZone(zone), nil + } + } + + // Loop, parsing hex numbers followed by colon. + i = 0 + for i < 16 { + // Hex number. Similar to parseIPv4, inlining the hex number + // parsing yields a significant performance increase. + off := 0 + acc := uint32(0) + for ; off < len(s); off++ { + c := s[off] + if c >= '0' && c <= '9' { + acc = (acc << 4) + uint32(c-'0') + } else if c >= 'a' && c <= 'f' { + acc = (acc << 4) + uint32(c-'a'+10) + } else if c >= 'A' && c <= 'F' { + acc = (acc << 4) + uint32(c-'A'+10) + } else { + break + } + if acc > math.MaxUint16 { + // Overflow, fail. + return Addr{}, parseAddrError{in: in, msg: "IPv6 field has value >=2^16", at: s} + } + } + if off == 0 { + // No digits found, fail. + return Addr{}, parseAddrError{in: in, msg: "each colon-separated field must have at least one digit", at: s} + } + + // If followed by dot, might be in trailing IPv4. + if off < len(s) && s[off] == '.' { + if ellipsis < 0 && i != 12 { + // Not the right place. + return Addr{}, parseAddrError{in: in, msg: "embedded IPv4 address must replace the final 2 fields of the address", at: s} + } + if i+4 > 16 { + // Not enough room. + return Addr{}, parseAddrError{in: in, msg: "too many hex fields to fit an embedded IPv4 at the end of the address", at: s} + } + // TODO: could make this a bit faster by having a helper + // that parses to a [4]byte, and have both parseIPv4 and + // parseIPv6 use it. + ip4, err := parseIPv4(s) + if err != nil { + return Addr{}, parseAddrError{in: in, msg: err.Error(), at: s} + } + ip[i] = ip4.v4(0) + ip[i+1] = ip4.v4(1) + ip[i+2] = ip4.v4(2) + ip[i+3] = ip4.v4(3) + s = "" + i += 4 + break + } + + // Save this 16-bit chunk. + ip[i] = byte(acc >> 8) + ip[i+1] = byte(acc) + i += 2 + + // Stop at end of string. + s = s[off:] + if len(s) == 0 { + break + } + + // Otherwise must be followed by colon and more. + if s[0] != ':' { + return Addr{}, parseAddrError{in: in, msg: "unexpected character, want colon", at: s} + } else if len(s) == 1 { + return Addr{}, parseAddrError{in: in, msg: "colon must be followed by more characters", at: s} + } + s = s[1:] + + // Look for ellipsis. + if s[0] == ':' { + if ellipsis >= 0 { // already have one + return Addr{}, parseAddrError{in: in, msg: "multiple :: in address", at: s} + } + ellipsis = i + s = s[1:] + if len(s) == 0 { // can be at end + break + } + } + } + + // Must have used entire string. + if len(s) != 0 { + return Addr{}, parseAddrError{in: in, msg: "trailing garbage after address", at: s} + } + + // If didn't parse enough, expand ellipsis. + if i < 16 { + if ellipsis < 0 { + return Addr{}, parseAddrError{in: in, msg: "address string too short"} + } + n := 16 - i + for j := i - 1; j >= ellipsis; j-- { + ip[j+n] = ip[j] + } + for j := ellipsis + n - 1; j >= ellipsis; j-- { + ip[j] = 0 + } + } else if ellipsis >= 0 { + // Ellipsis must represent at least one 0 group. + return Addr{}, parseAddrError{in: in, msg: "the :: must expand to at least one field of zeros"} + } + return AddrFrom16(ip).WithZone(zone), nil +} + +// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address. +// Note that a net.IP can be passed directly as the []byte argument. +// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false. +func AddrFromSlice(slice []byte) (ip Addr, ok bool) { + switch len(slice) { + case 4: + return AddrFrom4(*(*[4]byte)(slice)), true + case 16: + return ipv6Slice(slice), true + } + return Addr{}, false +} + +// v4 returns the i'th byte of ip. If ip is not an IPv4, v4 returns +// unspecified garbage. +func (ip Addr) v4(i uint8) uint8 { + return uint8(ip.addr.lo >> ((3 - i) * 8)) +} + +// v6 returns the i'th byte of ip. If ip is an IPv4 address, this +// accesses the IPv4-mapped IPv6 address form of the IP. +func (ip Addr) v6(i uint8) uint8 { + return uint8(*(ip.addr.halves()[(i/8)%2]) >> ((7 - i%8) * 8)) +} + +// v6u16 returns the i'th 16-bit word of ip. If ip is an IPv4 address, +// this accesses the IPv4-mapped IPv6 address form of the IP. +func (ip Addr) v6u16(i uint8) uint16 { + return uint16(*(ip.addr.halves()[(i/4)%2]) >> ((3 - i%4) * 16)) +} + +// isZero reports whether ip is the zero value of the IP type. +// The zero value is not a valid IP address of any type. +// +// Note that "0.0.0.0" and "::" are not the zero value. Use IsUnspecified to +// check for these values instead. +func (ip Addr) isZero() bool { + // Faster than comparing ip == Addr{}, but effectively equivalent, + // as there's no way to make an IP with a nil z from this package. + return ip.z == z0 +} + +// IsValid reports whether the Addr is an initialized address (not the zero Addr). +// +// Note that "0.0.0.0" and "::" are both valid values. +func (ip Addr) IsValid() bool { return ip.z != z0 } + +// BitLen returns the number of bits in the IP address: +// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr. +// +// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses +// and therefore have bit length 128. +func (ip Addr) BitLen() int { + switch ip.z { + case z0: + return 0 + case z4: + return 32 + } + return 128 +} + +// Zone returns ip's IPv6 scoped addressing zone, if any. +func (ip Addr) Zone() string { + if ip.z == nil { + return "" + } + zone, _ := ip.z.Get().(string) + return zone +} + +// Compare returns an integer comparing two IPs. +// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2. +// The definition of "less than" is the same as the Less method. +func (ip Addr) Compare(ip2 Addr) int { + f1, f2 := ip.BitLen(), ip2.BitLen() + if f1 < f2 { + return -1 + } + if f1 > f2 { + return 1 + } + hi1, hi2 := ip.addr.hi, ip2.addr.hi + if hi1 < hi2 { + return -1 + } + if hi1 > hi2 { + return 1 + } + lo1, lo2 := ip.addr.lo, ip2.addr.lo + if lo1 < lo2 { + return -1 + } + if lo1 > lo2 { + return 1 + } + if ip.Is6() { + za, zb := ip.Zone(), ip2.Zone() + if za < zb { + return -1 + } + if za > zb { + return 1 + } + } + return 0 +} + +// Less reports whether ip sorts before ip2. +// IP addresses sort first by length, then their address. +// IPv6 addresses with zones sort just after the same address without a zone. +func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 } + +func (ip Addr) lessOrEq(ip2 Addr) bool { return ip.Compare(ip2) <= 0 } + +// ipZone returns the standard library net.IP from ip, as well +// as the zone. +// The optional reuse IP provides memory to reuse. +func (ip Addr) ipZone(reuse []byte) (stdIP []byte, zone string) { + base := reuse[:0] + switch { + case ip.z == z0: + return nil, "" + case ip.Is4(): + a4 := ip.As4() + return append(base, a4[:]...), "" + default: + a16 := ip.As16() + return append(base, a16[:]...), ip.Zone() + } +} + +// IPAddrParts returns the net.IPAddr representation of an Addr. +// +// The slice will be nil if ip is the zero Addr. +// The zone is the empty string if there is no zone. +func (ip Addr) IPAddrParts() (slice []byte, zone string) { + return ip.ipZone(nil) +} + +// Is4 reports whether ip is an IPv4 address. +// +// It returns false for IP4-mapped IPv6 addresses. See IP.Unmap. +func (ip Addr) Is4() bool { + return ip.z == z4 +} + +// Is4In6 reports whether ip is an IPv4-mapped IPv6 address. +func (ip Addr) Is4In6() bool { + return ip.Is6() && ip.addr.hi == 0 && ip.addr.lo>>32 == 0xffff +} + +// Is6 reports whether ip is an IPv6 address, including IPv4-mapped +// IPv6 addresses. +func (ip Addr) Is6() bool { + return ip.z != z0 && ip.z != z4 +} + +// Unmap returns ip with any IPv4-mapped IPv6 address prefix removed. +// +// That is, if ip is an IPv6 address wrapping an IPv4 adddress, it +// returns the wrapped IPv4 address. Otherwise it returns ip unmodified. +func (ip Addr) Unmap() Addr { + if ip.Is4In6() { + ip.z = z4 + } + return ip +} + +// WithZone returns an IP that's the same as ip but with the provided +// zone. If zone is empty, the zone is removed. If ip is an IPv4 +// address, WithZone is a no-op and returns ip unchanged. +func (ip Addr) WithZone(zone string) Addr { + if !ip.Is6() { + return ip + } + if zone == "" { + ip.z = z6noz + return ip + } + ip.z = intern.GetByString(zone) + return ip +} + +// withoutZone unconditionally strips the zone from IP. +// It's similar to WithZone, but small enough to be inlinable. +func (ip Addr) withoutZone() Addr { + if !ip.Is6() { + return ip + } + ip.z = z6noz + return ip +} + +// hasZone reports whether IP has an IPv6 zone. +func (ip Addr) hasZone() bool { + return ip.z != z0 && ip.z != z4 && ip.z != z6noz +} + +// IsLinkLocalUnicast reports whether ip is a link-local unicast address. +func (ip Addr) IsLinkLocalUnicast() bool { + // Dynamic Configuration of IPv4 Link-Local Addresses + // https://datatracker.ietf.org/doc/html/rfc3927#section-2.1 + if ip.Is4() { + return ip.v4(0) == 169 && ip.v4(1) == 254 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.v6u16(0)&0xffc0 == 0xfe80 + } + return false // zero value +} + +// IsLoopback reports whether ip is a loopback address. +func (ip Addr) IsLoopback() bool { + // Requirements for Internet Hosts -- Communication Layers (3.2.1.3 Addressing) + // https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3 + if ip.Is4() { + return ip.v4(0) == 127 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.addr.hi == 0 && ip.addr.lo == 1 + } + return false // zero value +} + +// IsMulticast reports whether ip is a multicast address. +func (ip Addr) IsMulticast() bool { + // Host Extensions for IP Multicasting (4. HOST GROUP ADDRESSES) + // https://datatracker.ietf.org/doc/html/rfc1112#section-4 + if ip.Is4() { + return ip.v4(0)&0xf0 == 0xe0 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.addr.hi>>(64-8) == 0xff // ip.v6(0) == 0xff + } + return false // zero value +} + +// IsInterfaceLocalMulticast reports whether ip is an IPv6 interface-local +// multicast address. +func (ip Addr) IsInterfaceLocalMulticast() bool { + // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1 + if ip.Is6() { + return ip.v6u16(0)&0xff0f == 0xff01 + } + return false // zero value +} + +// IsLinkLocalMulticast reports whether ip is a link-local multicast address. +func (ip Addr) IsLinkLocalMulticast() bool { + // IPv4 Multicast Guidelines (4. Local Network Control Block (224.0.0/24)) + // https://datatracker.ietf.org/doc/html/rfc5771#section-4 + if ip.Is4() { + return ip.v4(0) == 224 && ip.v4(1) == 0 && ip.v4(2) == 0 + } + // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1 + if ip.Is6() { + return ip.v6u16(0)&0xff0f == 0xff02 + } + return false // zero value +} + +// IsGlobalUnicast reports whether ip is a global unicast address. +// +// It returns true for IPv6 addresses which fall outside of the current +// IANA-allocated 2000::/3 global unicast space, with the exception of the +// link-local address space. It also returns true even if ip is in the IPv4 +// private address space or IPv6 unique local address space. +// It returns false for the zero Addr. +// +// For reference, see RFC 1122, RFC 4291, and RFC 4632. +func (ip Addr) IsGlobalUnicast() bool { + if ip.z == z0 { + // Invalid or zero-value. + return false + } + + // Match package net's IsGlobalUnicast logic. Notably private IPv4 addresses + // and ULA IPv6 addresses are still considered "global unicast". + if ip.Is4() && (ip == AddrFrom4([4]byte{}) || ip == AddrFrom4([4]byte{255, 255, 255, 255})) { + return false + } + + return ip != IPv6Unspecified() && + !ip.IsLoopback() && + !ip.IsMulticast() && + !ip.IsLinkLocalUnicast() +} + +// IsPrivate reports whether ip is a private address, according to RFC 1918 +// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether +// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the +// same as net.IP.IsPrivate. +func (ip Addr) IsPrivate() bool { + // Match the stdlib's IsPrivate logic. + if ip.Is4() { + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + // private IPv4 address subnets. + return ip.v4(0) == 10 || + (ip.v4(0) == 172 && ip.v4(1)&0xf0 == 16) || + (ip.v4(0) == 192 && ip.v4(1) == 168) + } + + if ip.Is6() { + // RFC 4193 allocates fc00::/7 as the unique local unicast IPv6 address + // subnet. + return ip.v6(0)&0xfe == 0xfc + } + + return false // zero value +} + +// IsUnspecified reports whether ip is an unspecified address, either the IPv4 +// address "0.0.0.0" or the IPv6 address "::". +// +// Note that the zero Addr is not an unspecified address. +func (ip Addr) IsUnspecified() bool { + return ip == AddrFrom4([4]byte{}) || ip == IPv6Unspecified() +} + +// Prefix keeps only the top b bits of IP, producing a Prefix +// of the specified length. +// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error. +// Otherwise, if bits is less than zero or greater than ip.BitLen(), +// Prefix returns an error. +func (ip Addr) Prefix(b int) (Prefix, error) { + if b < 0 { + return Prefix{}, errors.New("negative Prefix bits") + } + effectiveBits := b + switch ip.z { + case z0: + return Prefix{}, nil + case z4: + if b > 32 { + return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv4") + } + effectiveBits += 96 + default: + if b > 128 { + return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv6") + } + } + ip.addr = ip.addr.and(mask6(effectiveBits)) + return PrefixFrom(ip, b), nil +} + +const ( + netIPv4len = 4 + netIPv6len = 16 +) + +// As16 returns the IP address in its 16-byte representation. +// IPv4 addresses are returned in their v6-mapped form. +// IPv6 addresses with zones are returned without their zone (use the +// Zone method to get it). +// The ip zero value returns all zeroes. +func (ip Addr) As16() [16]byte { + var ret [16]byte + bePutUint64(ret[:8], ip.addr.hi) + bePutUint64(ret[8:], ip.addr.lo) + return ret +} + +// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation. +// If ip is the zero Addr or an IPv6 address, As4 panics. +// Note that 0.0.0.0 is not the zero Addr. +func (ip Addr) As4() [4]byte { + if ip.z == z4 || ip.Is4In6() { + var ret [4]byte + bePutUint32(ret[:], uint32(ip.addr.lo)) + return ret + } + if ip.z == z0 { + panic("As4 called on IP zero value") + } + panic("As4 called on IPv6 address") +} + +// Next returns the address following ip. +// If there is none, it returns the zero Addr. +func (ip Addr) Next() Addr { + ip.addr = ip.addr.addOne() + if ip.Is4() { + if uint32(ip.addr.lo) == 0 { + // Overflowed. + return Addr{} + } + } else { + if ip.addr.isZero() { + // Overflowed + return Addr{} + } + } + return ip +} + +// Prev returns the IP before ip. +// If there is none, it returns the IP zero value. +func (ip Addr) Prev() Addr { + if ip.Is4() { + if uint32(ip.addr.lo) == 0 { + return Addr{} + } + } else if ip.addr.isZero() { + return Addr{} + } + ip.addr = ip.addr.subOne() + return ip +} + +// String returns the string form of the IP address ip. +// It returns one of 5 forms: +// +// - "invalid IP", if ip is the zero Addr +// - IPv4 dotted decimal ("192.0.2.1") +// - IPv6 ("2001:db8::1") +// - "::ffff:1.2.3.4" (if Is4In6) +// - IPv6 with zone ("fe80:db8::1%eth0") +// +// Note that unlike package net's IP.String method, +// IP4-mapped IPv6 addresses format with a "::ffff:" +// prefix before the dotted quad. +func (ip Addr) String() string { + switch ip.z { + case z0: + return "invalid IP" + case z4: + return ip.string4() + default: + if ip.Is4In6() { + // TODO(bradfitz): this could alloc less. + return "::ffff:" + ip.Unmap().String() + } + return ip.string6() + } +} + +// AppendTo appends a text encoding of ip, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (ip Addr) AppendTo(b []byte) []byte { + switch ip.z { + case z0: + return b + case z4: + return ip.appendTo4(b) + default: + if ip.Is4In6() { + b = append(b, "::ffff:"...) + return ip.Unmap().appendTo4(b) + } + return ip.appendTo6(b) + } +} + +// digits is a string of the hex digits from 0 to f. It's used in +// appendDecimal and appendHex to format IP addresses. +const digits = "0123456789abcdef" + +// appendDecimal appends the decimal string representation of x to b. +func appendDecimal(b []byte, x uint8) []byte { + // Using this function rather than strconv.AppendUint makes IPv4 + // string building 2x faster. + + if x >= 100 { + b = append(b, digits[x/100]) + } + if x >= 10 { + b = append(b, digits[x/10%10]) + } + return append(b, digits[x%10]) +} + +// appendHex appends the hex string representation of x to b. +func appendHex(b []byte, x uint16) []byte { + // Using this function rather than strconv.AppendUint makes IPv6 + // string building 2x faster. + + if x >= 0x1000 { + b = append(b, digits[x>>12]) + } + if x >= 0x100 { + b = append(b, digits[x>>8&0xf]) + } + if x >= 0x10 { + b = append(b, digits[x>>4&0xf]) + } + return append(b, digits[x&0xf]) +} + +// appendHexPad appends the fully padded hex string representation of x to b. +func appendHexPad(b []byte, x uint16) []byte { + return append(b, digits[x>>12], digits[x>>8&0xf], digits[x>>4&0xf], digits[x&0xf]) +} + +func (ip Addr) string4() string { + const max = len("255.255.255.255") + ret := make([]byte, 0, max) + ret = ip.appendTo4(ret) + return string(ret) +} + +func (ip Addr) appendTo4(ret []byte) []byte { + ret = appendDecimal(ret, ip.v4(0)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(1)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(2)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(3)) + return ret +} + +// string6 formats ip in IPv6 textual representation. It follows the +// guidelines in section 4 of RFC 5952 +// (https://tools.ietf.org/html/rfc5952#section-4): no unnecessary +// zeros, use :: to elide the longest run of zeros, and don't use :: +// to compact a single zero field. +func (ip Addr) string6() string { + // Use a zone with a "plausibly long" name, so that most zone-ful + // IP addresses won't require additional allocation. + // + // The compiler does a cool optimization here, where ret ends up + // stack-allocated and so the only allocation this function does + // is to construct the returned string. As such, it's okay to be a + // bit greedy here, size-wise. + const max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0") + ret := make([]byte, 0, max) + ret = ip.appendTo6(ret) + return string(ret) +} + +func (ip Addr) appendTo6(ret []byte) []byte { + zeroStart, zeroEnd := uint8(255), uint8(255) + for i := uint8(0); i < 8; i++ { + j := i + for j < 8 && ip.v6u16(j) == 0 { + j++ + } + if l := j - i; l >= 2 && l > zeroEnd-zeroStart { + zeroStart, zeroEnd = i, j + } + } + + for i := uint8(0); i < 8; i++ { + if i == zeroStart { + ret = append(ret, ':', ':') + i = zeroEnd + if i >= 8 { + break + } + } else if i > 0 { + ret = append(ret, ':') + } + + ret = appendHex(ret, ip.v6u16(i)) + } + + if ip.z != z6noz { + ret = append(ret, '%') + ret = append(ret, ip.Zone()...) + } + return ret +} + +// StringExpanded is like String but IPv6 addresses are expanded with leading +// zeroes and no "::" compression. For example, "2001:db8::1" becomes +// "2001:0db8:0000:0000:0000:0000:0000:0001". +func (ip Addr) StringExpanded() string { + switch ip.z { + case z0, z4: + return ip.String() + } + + const size = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + ret := make([]byte, 0, size) + for i := uint8(0); i < 8; i++ { + if i > 0 { + ret = append(ret, ':') + } + + ret = appendHexPad(ret, ip.v6u16(i)) + } + + if ip.z != z6noz { + // The addition of a zone will cause a second allocation, but when there + // is no zone the ret slice will be stack allocated. + ret = append(ret, '%') + ret = append(ret, ip.Zone()...) + } + return string(ret) +} + +// MarshalText implements the encoding.TextMarshaler interface, +// The encoding is the same as returned by String, with one exception: +// If ip is the zero Addr, the encoding is the empty string. +func (ip Addr) MarshalText() ([]byte, error) { + switch ip.z { + case z0: + return []byte(""), nil + case z4: + max := len("255.255.255.255") + b := make([]byte, 0, max) + return ip.appendTo4(b), nil + default: + max := len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0") + b := make([]byte, 0, max) + if ip.Is4In6() { + b = append(b, "::ffff:"...) + return ip.Unmap().appendTo4(b), nil + } + return ip.appendTo6(b), nil + } +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The IP address is expected in a form accepted by ParseAddr. +// +// If text is empty, UnmarshalText sets *ip to the zero Addr and +// returns no error. +func (ip *Addr) UnmarshalText(text []byte) error { + if len(text) == 0 { + *ip = Addr{} + return nil + } + var err error + *ip, err = ParseAddr(string(text)) + return err +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +// It returns a zero-length slice for the zero Addr, +// the 4-byte form for an IPv4 address, +// and the 16-byte form with zone appended for an IPv6 address. +func (ip Addr) MarshalBinary() ([]byte, error) { + switch ip.z { + case z0: + return nil, nil + case z4: + b := ip.As4() + return b[:], nil + default: + b16 := ip.As16() + b := b16[:] + if z := ip.Zone(); z != "" { + b = append(b, []byte(z)...) + } + return b, nil + } +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It expects data in the form generated by MarshalBinary. +func (ip *Addr) UnmarshalBinary(b []byte) error { + n := len(b) + switch { + case n == 0: + *ip = Addr{} + return nil + case n == 4: + *ip = AddrFrom4(*(*[4]byte)(b)) + return nil + case n == 16: + *ip = ipv6Slice(b) + return nil + case n > 16: + *ip = ipv6Slice(b[:16]).WithZone(string(b[16:])) + return nil + } + return errors.New("unexpected slice size") +} + +// AddrPort is an IP and a port number. +type AddrPort struct { + ip Addr + port uint16 +} + +// AddrPortFrom returns an AddrPort with the provided IP and port. +// It does not allocate. +func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} } + +// Addr returns p's IP address. +func (p AddrPort) Addr() Addr { return p.ip } + +// Port returns p's port. +func (p AddrPort) Port() uint16 { return p.port } + +// splitAddrPort splits s into an IP address string and a port +// string. It splits strings shaped like "foo:bar" or "[foo]:bar", +// without further validating the substrings. v6 indicates whether the +// ip string should parse as an IPv6 address or an IPv4 address, in +// order for s to be a valid ip:port string. +func splitAddrPort(s string) (ip, port string, v6 bool, err error) { + i := stringsLastIndexByte(s, ':') + if i == -1 { + return "", "", false, errors.New("not an ip:port") + } + + ip, port = s[:i], s[i+1:] + if len(ip) == 0 { + return "", "", false, errors.New("no IP") + } + if len(port) == 0 { + return "", "", false, errors.New("no port") + } + if ip[0] == '[' { + if len(ip) < 2 || ip[len(ip)-1] != ']' { + return "", "", false, errors.New("missing ]") + } + ip = ip[1 : len(ip)-1] + v6 = true + } + + return ip, port, v6, nil +} + +// ParseAddrPort parses s as an AddrPort. +// +// It doesn't do any name resolution: both the address and the port +// must be numeric. +func ParseAddrPort(s string) (AddrPort, error) { + var ipp AddrPort + ip, port, v6, err := splitAddrPort(s) + if err != nil { + return ipp, err + } + port16, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return ipp, errors.New("invalid port " + strconv.Quote(port) + " parsing " + strconv.Quote(s)) + } + ipp.port = uint16(port16) + ipp.ip, err = ParseAddr(ip) + if err != nil { + return AddrPort{}, err + } + if v6 && ipp.ip.Is4() { + return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", square brackets can only be used with IPv6 addresses") + } else if !v6 && ipp.ip.Is6() { + return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", IPv6 addresses must be surrounded by square brackets") + } + return ipp, nil +} + +// MustParseAddrPort calls ParseAddrPort(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParseAddrPort(s string) AddrPort { + ip, err := ParseAddrPort(s) + if err != nil { + panic(err) + } + return ip +} + +// isZero reports whether p is the zero AddrPort. +func (p AddrPort) isZero() bool { return p == AddrPort{} } + +// IsValid reports whether p.IP() is valid. +// All ports are valid, including zero. +func (p AddrPort) IsValid() bool { return p.ip.IsValid() } + +func (p AddrPort) String() string { + switch p.ip.z { + case z0: + return "invalid AddrPort" + case z4: + a := p.ip.As4() + buf := make([]byte, 0, 21) + for i := range a { + buf = strconv.AppendUint(buf, uint64(a[i]), 10) + buf = append(buf, "...:"[i]) + } + buf = strconv.AppendUint(buf, uint64(p.port), 10) + return string(buf) + default: + // TODO: this could be more efficient allocation-wise: + return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port))) + } +} + +func joinHostPort(host, port string) string { + // We assume that host is a literal IPv6 address if host has + // colons. + if bytealg.IndexByteString(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +// AppendTo appends a text encoding of p, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (p AddrPort) AppendTo(b []byte) []byte { + switch p.ip.z { + case z0: + return b + case z4: + b = p.ip.appendTo4(b) + default: + b = append(b, '[') + b = p.ip.appendTo6(b) + b = append(b, ']') + } + b = append(b, ':') + b = strconv.AppendInt(b, int64(p.port), 10) + return b +} + +// MarshalText implements the encoding.TextMarshaler interface. The +// encoding is the same as returned by String, with one exception: if +// p.Addr() is the zero Addr, the encoding is the empty string. +func (p AddrPort) MarshalText() ([]byte, error) { + var max int + switch p.ip.z { + case z0: + case z4: + max = len("255.255.255.255:65535") + default: + max = len("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535") + } + b := make([]byte, 0, max) + b = p.AppendTo(b) + return b, nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler +// interface. The AddrPort is expected in a form +// generated by MarshalText or accepted by ParseAddrPort. +func (p *AddrPort) UnmarshalText(text []byte) error { + if len(text) == 0 { + *p = AddrPort{} + return nil + } + var err error + *p, err = ParseAddrPort(string(text)) + return err +} + +// Prefix is an IP address prefix (CIDR) representing an IP network. +// +// The first Bits() of Addr() are specified. The remaining bits match any address. +// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6. +type Prefix struct { + ip Addr + + // bits is logically a uint8 (storing [0,128]) but also + // encodes an "invalid" bit, currently represented by the + // invalidPrefixBits sentinel value. It could be packed into + // the uint8 more with more comlicated expressions in the + // accessors, but the extra byte (in padding anyway) doesn't + // hurt and simplifies code below. + bits int16 +} + +// invalidPrefixBits is the Prefix.bits value used when PrefixFrom is +// outside the range of a uint8. It's returned as the int -1 in the +// public API. +const invalidPrefixBits = -1 + +// PrefixFrom returns an Prefix with the provided IP address and bit +// prefix length. +// +// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask +// off the host bits of ip. +// +// If bits is less than zero or greater than ip.BitLen, Prefix.Bits +// will return an invalid value -1. +func PrefixFrom(ip Addr, bits int) Prefix { + if bits < 0 || bits > ip.BitLen() { + bits = invalidPrefixBits + } + b16 := int16(bits) + return Prefix{ + ip: ip.withoutZone(), + bits: b16, + } +} + +// Addr returns p's IP address. +func (p Prefix) Addr() Addr { return p.ip } + +// Bits returns p's prefix length. +// +// It reports -1 if invalid. +func (p Prefix) Bits() int { return int(p.bits) } + +// IsValid reports whether whether p.Bits() has a valid range for p.IP(). +// If p.Addr() is the zero Addr, IsValid returns false. +// Note that if p is the zero Prefix, then p.IsValid() == false. +func (p Prefix) IsValid() bool { return !p.ip.isZero() && p.bits >= 0 && int(p.bits) <= p.ip.BitLen() } + +func (p Prefix) isZero() bool { return p == Prefix{} } + +// IsSingleIP reports whether p contains exactly one IP. +func (p Prefix) IsSingleIP() bool { return p.bits != 0 && int(p.bits) == p.ip.BitLen() } + +// ParsePrefix parses s as an IP address prefix. +// The string can be in the form "192.168.1.0/24" or "2001::db8::/32", +// the CIDR notation defined in RFC 4632 and RFC 4291. +// +// Note that masked address bits are not zeroed. Use Masked for that. +func ParsePrefix(s string) (Prefix, error) { + i := stringsLastIndexByte(s, '/') + if i < 0 { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'") + } + ip, err := ParseAddr(s[:i]) + if err != nil { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): " + err.Error()) + } + bitsStr := s[i+1:] + bits, err := strconv.Atoi(bitsStr) + if err != nil { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": bad bits after slash: " + strconv.Quote(bitsStr)) + } + maxBits := 32 + if ip.Is6() { + maxBits = 128 + } + if bits < 0 || bits > maxBits { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": prefix length out of range") + } + return PrefixFrom(ip, bits), nil +} + +// MustParsePrefix calls ParsePrefix(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParsePrefix(s string) Prefix { + ip, err := ParsePrefix(s) + if err != nil { + panic(err) + } + return ip +} + +// Masked returns p in its canonical form, with all but the high +// p.Bits() bits of p.Addr() masked off. +// +// If p is zero or otherwise invalid, Masked returns the zero Prefix. +func (p Prefix) Masked() Prefix { + if m, err := p.ip.Prefix(int(p.bits)); err == nil { + return m + } + return Prefix{} +} + +// Contains reports whether the network p includes ip. +// +// An IPv4 address will not match an IPv6 prefix. +// A v6-mapped IPv6 address will not match an IPv4 prefix. +// A zero-value IP will not match any prefix. +// If ip has an IPv6 zone, Contains returns false, +// because Prefixes strip zones. +func (p Prefix) Contains(ip Addr) bool { + if !p.IsValid() || ip.hasZone() { + return false + } + if f1, f2 := p.ip.BitLen(), ip.BitLen(); f1 == 0 || f2 == 0 || f1 != f2 { + return false + } + if ip.Is4() { + // xor the IP addresses together; mismatched bits are now ones. + // Shift away the number of bits we don't care about. + // Shifts in Go are more efficient if the compiler can prove + // that the shift amount is smaller than the width of the shifted type (64 here). + // We know that p.bits is in the range 0..32 because p is Valid; + // the compiler doesn't know that, so mask with 63 to help it. + // Now truncate to 32 bits, because this is IPv4. + // If all the bits we care about are equal, the result will be zero. + return uint32((ip.addr.lo^p.ip.addr.lo)>>((32-p.bits)&63)) == 0 + } else { + // xor the IP addresses together. + // Mask away the bits we don't care about. + // If all the bits we care about are equal, the result will be zero. + return ip.addr.xor(p.ip.addr).and(mask6(int(p.bits))).isZero() + } +} + +// Overlaps reports whether p and o contain any IP addresses in common. +// +// If p and o are of different address families or either have a zero +// IP, it reports false. Like the Contains method, a prefix with a +// v6-mapped IPv4 IP is still treated as an IPv6 mask. +func (p Prefix) Overlaps(o Prefix) bool { + if !p.IsValid() || !o.IsValid() { + return false + } + if p == o { + return true + } + if p.ip.Is4() != o.ip.Is4() { + return false + } + var minBits int16 + if p.bits < o.bits { + minBits = p.bits + } else { + minBits = o.bits + } + if minBits == 0 { + return true + } + // One of these Prefix calls might look redundant, but we don't require + // that p and o values are normalized (via Prefix.Masked) first, + // so the Prefix call on the one that's already minBits serves to zero + // out any remaining bits in IP. + var err error + if p, err = p.ip.Prefix(int(minBits)); err != nil { + return false + } + if o, err = o.ip.Prefix(int(minBits)); err != nil { + return false + } + return p.ip == o.ip +} + +// AppendTo appends a text encoding of p, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (p Prefix) AppendTo(b []byte) []byte { + if p.isZero() { + return b + } + if !p.IsValid() { + return append(b, "invalid Prefix"...) + } + + // p.ip is non-nil, because p is valid. + if p.ip.z == z4 { + b = p.ip.appendTo4(b) + } else { + b = p.ip.appendTo6(b) + } + + b = append(b, '/') + b = appendDecimal(b, uint8(p.bits)) + return b +} + +// MarshalText implements the encoding.TextMarshaler interface, +// The encoding is the same as returned by String, with one exception: +// If p is the zero value, the encoding is the empty string. +func (p Prefix) MarshalText() ([]byte, error) { + var max int + switch p.ip.z { + case z0: + case z4: + max = len("255.255.255.255/32") + default: + max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0/128") + } + b := make([]byte, 0, max) + b = p.AppendTo(b) + return b, nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The IP address is expected in a form accepted by ParsePrefix +// or generated by MarshalText. +func (p *Prefix) UnmarshalText(text []byte) error { + if len(text) == 0 { + *p = Prefix{} + return nil + } + var err error + *p, err = ParsePrefix(string(text)) + return err +} + +// String returns the CIDR notation of p: "/". +func (p Prefix) String() string { + if !p.IsValid() { + return "invalid Prefix" + } + return p.ip.String() + "/" + itoa.Itoa(int(p.bits)) +} diff --git a/src/net/netip/netip_pkg_test.go b/src/net/netip/netip_pkg_test.go new file mode 100644 index 0000000000..f5cd9ee86d --- /dev/null +++ b/src/net/netip/netip_pkg_test.go @@ -0,0 +1,359 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "bytes" + "encoding" + "encoding/json" + "strings" + "testing" +) + +var ( + mustPrefix = MustParsePrefix + mustIP = MustParseAddr +) + +func TestPrefixValid(t *testing.T) { + v4 := MustParseAddr("1.2.3.4") + v6 := MustParseAddr("::1") + tests := []struct { + ipp Prefix + want bool + }{ + {Prefix{v4, -2}, false}, + {Prefix{v4, -1}, false}, + {Prefix{v4, 0}, true}, + {Prefix{v4, 32}, true}, + {Prefix{v4, 33}, false}, + + {Prefix{v6, -2}, false}, + {Prefix{v6, -1}, false}, + {Prefix{v6, 0}, true}, + {Prefix{v6, 32}, true}, + {Prefix{v6, 128}, true}, + {Prefix{v6, 129}, false}, + + {Prefix{Addr{}, -2}, false}, + {Prefix{Addr{}, -1}, false}, + {Prefix{Addr{}, 0}, false}, + {Prefix{Addr{}, 32}, false}, + {Prefix{Addr{}, 128}, false}, + } + for _, tt := range tests { + got := tt.ipp.IsValid() + if got != tt.want { + t.Errorf("(%v).IsValid() = %v want %v", tt.ipp, got, tt.want) + } + } +} + +var nextPrevTests = []struct { + ip Addr + next Addr + prev Addr +}{ + {mustIP("10.0.0.1"), mustIP("10.0.0.2"), mustIP("10.0.0.0")}, + {mustIP("10.0.0.255"), mustIP("10.0.1.0"), mustIP("10.0.0.254")}, + {mustIP("127.0.0.1"), mustIP("127.0.0.2"), mustIP("127.0.0.0")}, + {mustIP("254.255.255.255"), mustIP("255.0.0.0"), mustIP("254.255.255.254")}, + {mustIP("255.255.255.255"), Addr{}, mustIP("255.255.255.254")}, + {mustIP("0.0.0.0"), mustIP("0.0.0.1"), Addr{}}, + {mustIP("::"), mustIP("::1"), Addr{}}, + {mustIP("::%x"), mustIP("::1%x"), Addr{}}, + {mustIP("::1"), mustIP("::2"), mustIP("::")}, + {mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), Addr{}, mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")}, +} + +func TestIPNextPrev(t *testing.T) { + doNextPrev(t) + + for _, ip := range []Addr{ + mustIP("0.0.0.0"), + mustIP("::"), + } { + got := ip.Prev() + if !got.isZero() { + t.Errorf("IP(%v).Prev = %v; want zero", ip, got) + } + } + + var allFF [16]byte + for i := range allFF { + allFF[i] = 0xff + } + + for _, ip := range []Addr{ + mustIP("255.255.255.255"), + AddrFrom16(allFF), + } { + got := ip.Next() + if !got.isZero() { + t.Errorf("IP(%v).Next = %v; want zero", ip, got) + } + } +} + +func BenchmarkIPNextPrev(b *testing.B) { + for i := 0; i < b.N; i++ { + doNextPrev(b) + } +} + +func doNextPrev(t testing.TB) { + for _, tt := range nextPrevTests { + gnext, gprev := tt.ip.Next(), tt.ip.Prev() + if gnext != tt.next { + t.Errorf("IP(%v).Next = %v; want %v", tt.ip, gnext, tt.next) + } + if gprev != tt.prev { + t.Errorf("IP(%v).Prev = %v; want %v", tt.ip, gprev, tt.prev) + } + if !tt.ip.Next().isZero() && tt.ip.Next().Prev() != tt.ip { + t.Errorf("IP(%v).Next.Prev = %v; want %v", tt.ip, tt.ip.Next().Prev(), tt.ip) + } + if !tt.ip.Prev().isZero() && tt.ip.Prev().Next() != tt.ip { + t.Errorf("IP(%v).Prev.Next = %v; want %v", tt.ip, tt.ip.Prev().Next(), tt.ip) + } + } +} + +func TestIPBitLen(t *testing.T) { + tests := []struct { + ip Addr + want int + }{ + {Addr{}, 0}, + {mustIP("0.0.0.0"), 32}, + {mustIP("10.0.0.1"), 32}, + {mustIP("::"), 128}, + {mustIP("fed0::1"), 128}, + {mustIP("::ffff:10.0.0.1"), 128}, + } + for _, tt := range tests { + got := tt.ip.BitLen() + if got != tt.want { + t.Errorf("BitLen(%v) = %d; want %d", tt.ip, got, tt.want) + } + } +} + +func TestPrefixContains(t *testing.T) { + tests := []struct { + ipp Prefix + ip Addr + want bool + }{ + {mustPrefix("9.8.7.6/0"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.8.6.4"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.9.7.6"), false}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false}, + {mustPrefix("::1/0"), mustIP("::1"), true}, + {mustPrefix("::1/0"), mustIP("::2"), true}, + {mustPrefix("::1/127"), mustIP("::1"), true}, + {mustPrefix("::1/127"), mustIP("::2"), false}, + {mustPrefix("::1/128"), mustIP("::1"), true}, + {mustPrefix("::1/127"), mustIP("::2"), false}, + // zones support + {mustPrefix("::1%a/128"), mustIP("::1"), true}, // prefix zones are stripped... + {mustPrefix("::1%a/128"), mustIP("::1%a"), false}, // but ip zones are not + // invalid IP + {mustPrefix("::1/0"), Addr{}, false}, + {mustPrefix("1.2.3.4/0"), Addr{}, false}, + // invalid Prefix + {Prefix{mustIP("::1"), 129}, mustIP("::1"), false}, + {Prefix{mustIP("1.2.3.4"), 33}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 0}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 32}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 128}, mustIP("::1"), false}, + // wrong IP family + {mustPrefix("::1/0"), mustIP("1.2.3.4"), false}, + {mustPrefix("1.2.3.4/0"), mustIP("::1"), false}, + } + for _, tt := range tests { + got := tt.ipp.Contains(tt.ip) + if got != tt.want { + t.Errorf("(%v).Contains(%v) = %v want %v", tt.ipp, tt.ip, got, tt.want) + } + } +} + +func TestParseIPError(t *testing.T) { + tests := []struct { + ip string + errstr string + }{ + { + ip: "localhost", + }, + { + ip: "500.0.0.1", + errstr: "field has value >255", + }, + { + ip: "::gggg%eth0", + errstr: "must have at least one digit", + }, + { + ip: "fe80::1cc0:3e8c:119f:c2e1%", + errstr: "zone must be a non-empty string", + }, + { + ip: "%eth0", + errstr: "missing IPv6 address", + }, + } + for _, test := range tests { + t.Run(test.ip, func(t *testing.T) { + _, err := ParseAddr(test.ip) + if err == nil { + t.Fatal("no error") + } + if _, ok := err.(parseAddrError); !ok { + t.Errorf("error type is %T, want parseIPError", err) + } + if test.errstr == "" { + test.errstr = "unable to parse IP" + } + if got := err.Error(); !strings.Contains(got, test.errstr) { + t.Errorf("error is missing substring %q: %s", test.errstr, got) + } + }) + } +} + +func TestParseAddrPort(t *testing.T) { + tests := []struct { + in string + want AddrPort + wantErr bool + }{ + {in: "1.2.3.4:1234", want: AddrPort{mustIP("1.2.3.4"), 1234}}, + {in: "1.1.1.1:123456", wantErr: true}, + {in: "1.1.1.1:-123", wantErr: true}, + {in: "[::1]:1234", want: AddrPort{mustIP("::1"), 1234}}, + {in: "[1.2.3.4]:1234", wantErr: true}, + {in: "fe80::1:1234", wantErr: true}, + {in: ":0", wantErr: true}, // if we need to parse this form, there should be a separate function that explicitly allows it + } + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + got, err := ParseAddrPort(test.in) + if err != nil { + if test.wantErr { + return + } + t.Fatal(err) + } + if got != test.want { + t.Errorf("got %v; want %v", got, test.want) + } + if got.String() != test.in { + t.Errorf("String = %q; want %q", got.String(), test.in) + } + }) + + t.Run(test.in+"/AppendTo", func(t *testing.T) { + got, err := ParseAddrPort(test.in) + if err == nil { + testAppendToMarshal(t, got) + } + }) + + // TextMarshal and TextUnmarshal mostly behave like + // ParseAddrPort and String. Divergent behavior are handled in + // TestAddrPortMarshalUnmarshal. + t.Run(test.in+"/Marshal", func(t *testing.T) { + var got AddrPort + jsin := `"` + test.in + `"` + err := json.Unmarshal([]byte(jsin), &got) + if err != nil { + if test.wantErr { + return + } + t.Fatal(err) + } + if got != test.want { + t.Errorf("got %v; want %v", got, test.want) + } + gotb, err := json.Marshal(got) + if err != nil { + t.Fatal(err) + } + if string(gotb) != jsin { + t.Errorf("Marshal = %q; want %q", string(gotb), jsin) + } + }) + } +} + +func TestAddrPortMarshalUnmarshal(t *testing.T) { + tests := []struct { + in string + want AddrPort + }{ + {"", AddrPort{}}, + } + + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + orig := `"` + test.in + `"` + + var ipp AddrPort + if err := json.Unmarshal([]byte(orig), &ipp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + ippb, err := json.Marshal(ipp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(ippb) + if orig != back { + t.Errorf("Marshal = %q; want %q", back, orig) + } + + testAppendToMarshal(t, ipp) + }) + } +} + +type appendMarshaler interface { + encoding.TextMarshaler + AppendTo([]byte) []byte +} + +// testAppendToMarshal tests that x's AppendTo and MarshalText methods yield the same results. +// x's MarshalText method must not return an error. +func testAppendToMarshal(t *testing.T, x appendMarshaler) { + t.Helper() + m, err := x.MarshalText() + if err != nil { + t.Fatalf("(%v).MarshalText: %v", x, err) + } + a := make([]byte, 0, len(m)) + a = x.AppendTo(a) + if !bytes.Equal(m, a) { + t.Errorf("(%v).MarshalText = %q, (%v).AppendTo = %q", x, m, x, a) + } +} + +func TestIPv6Accessor(t *testing.T) { + var a [16]byte + for i := range a { + a[i] = uint8(i) + 1 + } + ip := AddrFrom16(a) + for i := range a { + if got, want := ip.v6(uint8(i)), uint8(i)+1; got != want { + t.Errorf("v6(%v) = %v; want %v", i, got, want) + } + } +} diff --git a/src/net/netip/netip_test.go b/src/net/netip/netip_test.go new file mode 100644 index 0000000000..5d935c8fd3 --- /dev/null +++ b/src/net/netip/netip_test.go @@ -0,0 +1,1798 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip_test + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "internal/intern" + "net" + . "net/netip" + "reflect" + "sort" + "strings" + "testing" +) + +var long = flag.Bool("long", false, "run long tests") + +type uint128 = Uint128 + +var ( + mustPrefix = MustParsePrefix + mustIP = MustParseAddr +) + +func TestParseAddr(t *testing.T) { + var validIPs = []struct { + in string + ip Addr // output of ParseAddr() + str string // output of String(). If "", use in. + }{ + // Basic zero IPv4 address. + { + in: "0.0.0.0", + ip: MkAddr(Mk128(0, 0xffff00000000), Z4), + }, + // Basic non-zero IPv4 address. + { + in: "192.168.140.255", + ip: MkAddr(Mk128(0, 0xffffc0a88cff), Z4), + }, + // IPv4 address in windows-style "print all the digits" form. + { + in: "010.000.015.001", + ip: MkAddr(Mk128(0, 0xffff0a000f01), Z4), + str: "10.0.15.1", + }, + // IPv4 address with a silly amount of leading zeros. + { + in: "000001.00000002.00000003.000000004", + ip: MkAddr(Mk128(0, 0xffff01020304), Z4), + str: "1.2.3.4", + }, + // Basic zero IPv6 address. + { + in: "::", + ip: MkAddr(Mk128(0, 0), Z6noz), + }, + // Localhost IPv6. + { + in: "::1", + ip: MkAddr(Mk128(0, 1), Z6noz), + }, + // Fully expanded IPv6 address. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), Z6noz), + }, + // IPv6 with elided fields in the middle. + { + in: "fd7a:115c::626b:430b", + ip: MkAddr(Mk128(0xfd7a115c00000000, 0x00000000626b430b), Z6noz), + }, + // IPv6 with elided fields at the end. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96::", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd9600000000), Z6noz), + }, + // IPv6 with single elided field at the end. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b::", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b0000), Z6noz), + str: "fd7a:115c:a1e0:ab12:4843:cd96:626b:0", + }, + // IPv6 with single elided field in the middle. + { + in: "fd7a:115c:a1e0::4843:cd96:626b:430b", + ip: MkAddr(Mk128(0xfd7a115ca1e00000, 0x4843cd96626b430b), Z6noz), + str: "fd7a:115c:a1e0:0:4843:cd96:626b:430b", + }, + // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6) + { + in: "::ffff:192.168.140.255", + ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), Z6noz), + str: "::ffff:192.168.140.255", + }, + // IPv6 with a zone specifier. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), intern.Get("eth0")), + }, + // IPv6 with dotted decimal and zone specifier. + { + in: "1:2::ffff:192.168.140.255%eth1", + ip: MkAddr(Mk128(0x0001000200000000, 0x0000ffffc0a88cff), intern.Get("eth1")), + str: "1:2::ffff:c0a8:8cff%eth1", + }, + // IPv6 with capital letters. + { + in: "FD9E:1A04:F01D::1", + ip: MkAddr(Mk128(0xfd9e1a04f01d0000, 0x1), Z6noz), + str: "fd9e:1a04:f01d::1", + }, + } + + for _, test := range validIPs { + t.Run(test.in, func(t *testing.T) { + got, err := ParseAddr(test.in) + if err != nil { + t.Fatal(err) + } + if got != test.ip { + t.Errorf("ParseAddr(%q) got %#v, want %#v", test.in, got, test.ip) + } + + // Check that ParseAddr is a pure function. + got2, err := ParseAddr(test.in) + if err != nil { + t.Fatal(err) + } + if got != got2 { + t.Errorf("ParseAddr(%q) got 2 different results: %#v, %#v", test.in, got, got2) + } + + // Check that ParseAddr(ip.String()) is the identity function. + s := got.String() + got3, err := ParseAddr(s) + if err != nil { + t.Fatal(err) + } + if got != got3 { + t.Errorf("ParseAddr(%q) != ParseAddr(ParseIP(%q).String()). Got %#v, want %#v", test.in, test.in, got3, got) + } + + // Check that the slow-but-readable parser produces the same result. + slow, err := parseIPSlow(test.in) + if err != nil { + t.Fatal(err) + } + if got != slow { + t.Errorf("ParseAddr(%q) = %#v, parseIPSlow(%q) = %#v", test.in, got, test.in, slow) + } + + // Check that the parsed IP formats as expected. + s = got.String() + wants := test.str + if wants == "" { + wants = test.in + } + if s != wants { + t.Errorf("ParseAddr(%q).String() got %q, want %q", test.in, s, wants) + } + + // Check that AppendTo matches MarshalText. + TestAppendToMarshal(t, got) + + // Check that MarshalText/UnmarshalText work similarly to + // ParseAddr/String (see TestIPMarshalUnmarshal for + // marshal-specific behavior that's not common with + // ParseAddr/String). + js := `"` + test.in + `"` + var jsgot Addr + if err := json.Unmarshal([]byte(js), &jsgot); err != nil { + t.Fatal(err) + } + if jsgot != got { + t.Errorf("json.Unmarshal(%q) = %#v, want %#v", test.in, jsgot, got) + } + jsb, err := json.Marshal(jsgot) + if err != nil { + t.Fatal(err) + } + jswant := `"` + wants + `"` + jsback := string(jsb) + if jsback != jswant { + t.Errorf("Marshal(Unmarshal(%q)) = %s, want %s", test.in, jsback, jswant) + } + }) + } + + var invalidIPs = []string{ + // Empty string + "", + // Garbage non-IP + "bad", + // Single number. Some parsers accept this as an IPv4 address in + // big-endian uint32 form, but we don't. + "1234", + // IPv4 with a zone specifier + "1.2.3.4%eth0", + // IPv4 field must have at least one digit + ".1.2.3", + "1.2.3.", + "1..2.3", + // IPv4 address too long + "1.2.3.4.5", + // IPv4 in dotted octal form + "0300.0250.0214.0377", + // IPv4 in dotted hex form + "0xc0.0xa8.0x8c.0xff", + // IPv4 in class B form + "192.168.12345", + // IPv4 in class B form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.0.1", + // IPv4 in class A form + "192.1234567", + // IPv4 in class A form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.1", + // IPv4 field has value >255 + "192.168.300.1", + // IPv4 with too many fields + "192.168.0.1.5.6", + // IPv6 with not enough fields + "1:2:3:4:5:6:7", + // IPv6 with too many fields + "1:2:3:4:5:6:7:8:9", + // IPv6 with 8 fields and a :: expander + "1:2:3:4::5:6:7:8", + // IPv6 with a field bigger than 2b + "fe801::1", + // IPv6 with non-hex values in field + "fe80:tail:scal:e::", + // IPv6 with a zone delimiter but no zone. + "fe80::1%", + // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4. + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4. + "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 with invalid embedded IPv4. + "::ffff:192.168.140.bad", + // IPv6 with multiple ellipsis ::. + "fe80::1::1", + // IPv6 with invalid non hex/colon character. + "fe80:1?:1", + // IPv6 with truncated bytes after single colon. + "fe80:", + } + + for _, s := range invalidIPs { + t.Run(s, func(t *testing.T) { + got, err := ParseAddr(s) + if err == nil { + t.Errorf("ParseAddr(%q) = %#v, want error", s, got) + } + + slow, err := parseIPSlow(s) + if err == nil { + t.Errorf("parseIPSlow(%q) = %#v, want error", s, slow) + } + + std := net.ParseIP(s) + if std != nil { + t.Errorf("net.ParseIP(%q) = %#v, want error", s, std) + } + + if s == "" { + // Don't test unmarshaling of "" here, do it in + // IPMarshalUnmarshal. + return + } + var jsgot Addr + js := []byte(`"` + s + `"`) + if err := json.Unmarshal(js, &jsgot); err == nil { + t.Errorf("json.Unmarshal(%q) = %#v, want error", s, jsgot) + } + }) + } +} + +func TestIPv4Constructors(t *testing.T) { + if AddrFrom4([4]byte{1, 2, 3, 4}) != MustParseAddr("1.2.3.4") { + t.Errorf("don't match") + } +} + +func TestAddrMarshalUnmarshalBinary(t *testing.T) { + tests := []struct { + ip string + wantSize int + }{ + {"", 0}, // zero IP + {"1.2.3.4", 4}, + {"fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", 16}, + {"::ffff:c000:0280", 16}, + {"::ffff:c000:0280%eth0", 20}, + } + for _, tc := range tests { + var ip Addr + if len(tc.ip) > 0 { + ip = mustIP(tc.ip) + } + b, err := ip.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) != tc.wantSize { + t.Fatalf("%q encoded to size %d; want %d", tc.ip, len(b), tc.wantSize) + } + var ip2 Addr + if err := ip2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if ip != ip2 { + t.Fatalf("got %v; want %v", ip2, ip) + } + } + + // Cannot unmarshal from unexpected IP length. + for _, l := range []int{3, 5} { + var ip2 Addr + if err := ip2.UnmarshalBinary(bytes.Repeat([]byte{1}, l)); err == nil { + t.Fatalf("unmarshaled from unexpected IP length %d", l) + } + } +} + +func TestAddrMarshalUnmarshal(t *testing.T) { + // This only tests the cases where Marshal/Unmarshal diverges from + // the behavior of ParseAddr/String. For the rest of the test cases, + // see TestParseAddr above. + orig := `""` + var ip Addr + if err := json.Unmarshal([]byte(orig), &ip); err != nil { + t.Fatalf("Unmarshal(%q) got error %v", orig, err) + } + if ip != (Addr{}) { + t.Errorf("Unmarshal(%q) is not the zero Addr", orig) + } + + jsb, err := json.Marshal(ip) + if err != nil { + t.Fatalf("Marshal(%v) got error %v", ip, err) + } + back := string(jsb) + if back != orig { + t.Errorf("Marshal(Unmarshal(%q)) got %q, want %q", orig, back, orig) + } +} + +func TestAddrFrom16(t *testing.T) { + tests := []struct { + name string + in [16]byte + want Addr + }{ + { + name: "v6-raw", + in: [...]byte{15: 1}, + want: MkAddr(Mk128(0, 1), Z6noz), + }, + { + name: "v4-raw", + in: [...]byte{10: 0xff, 11: 0xff, 12: 1, 13: 2, 14: 3, 15: 4}, + want: MkAddr(Mk128(0, 0xffff01020304), Z6noz), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AddrFrom16(tt.in) + if got != tt.want { + t.Errorf("got %#v; want %#v", got, tt.want) + } + }) + } +} + +func TestIPProperties(t *testing.T) { + var ( + nilIP Addr + + unicast4 = mustIP("192.0.2.1") + unicast6 = mustIP("2001:db8::1") + unicastZone6 = mustIP("2001:db8::1%eth0") + unicast6Unassigned = mustIP("4000::1") // not in 2000::/3. + + multicast4 = mustIP("224.0.0.1") + multicast6 = mustIP("ff02::1") + multicastZone6 = mustIP("ff02::1%eth0") + + llu4 = mustIP("169.254.0.1") + llu6 = mustIP("fe80::1") + llu6Last = mustIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + lluZone6 = mustIP("fe80::1%eth0") + + loopback4 = mustIP("127.0.0.1") + loopback6 = mustIP("::1") + + ilm6 = mustIP("ff01::1") + ilmZone6 = mustIP("ff01::1%eth0") + + private4a = mustIP("10.0.0.1") + private4b = mustIP("172.16.0.1") + private4c = mustIP("192.168.1.1") + private6 = mustIP("fd00::1") + + unspecified4 = AddrFrom4([4]byte{}) + unspecified6 = IPv6Unspecified() + ) + + tests := []struct { + name string + ip Addr + globalUnicast bool + interfaceLocalMulticast bool + linkLocalMulticast bool + linkLocalUnicast bool + loopback bool + multicast bool + private bool + unspecified bool + }{ + { + name: "nil", + ip: nilIP, + }, + { + name: "unicast v4Addr", + ip: unicast4, + globalUnicast: true, + }, + { + name: "unicast v6Addr", + ip: unicast6, + globalUnicast: true, + }, + { + name: "unicast v6AddrZone", + ip: unicastZone6, + globalUnicast: true, + }, + { + name: "unicast v6Addr unassigned", + ip: unicast6Unassigned, + globalUnicast: true, + }, + { + name: "multicast v4Addr", + ip: multicast4, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "multicast v6Addr", + ip: multicast6, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "multicast v6AddrZone", + ip: multicastZone6, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "link-local unicast v4Addr", + ip: llu4, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6Addr", + ip: llu6, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6Addr upper bound", + ip: llu6Last, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6AddrZone", + ip: lluZone6, + linkLocalUnicast: true, + }, + { + name: "loopback v4Addr", + ip: loopback4, + loopback: true, + }, + { + name: "loopback v6Addr", + ip: loopback6, + loopback: true, + }, + { + name: "interface-local multicast v6Addr", + ip: ilm6, + interfaceLocalMulticast: true, + multicast: true, + }, + { + name: "interface-local multicast v6AddrZone", + ip: ilmZone6, + interfaceLocalMulticast: true, + multicast: true, + }, + { + name: "private v4Addr 10/8", + ip: private4a, + globalUnicast: true, + private: true, + }, + { + name: "private v4Addr 172.16/12", + ip: private4b, + globalUnicast: true, + private: true, + }, + { + name: "private v4Addr 192.168/16", + ip: private4c, + globalUnicast: true, + private: true, + }, + { + name: "private v6Addr", + ip: private6, + globalUnicast: true, + private: true, + }, + { + name: "unspecified v4Addr", + ip: unspecified4, + unspecified: true, + }, + { + name: "unspecified v6Addr", + ip: unspecified6, + unspecified: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gu := tt.ip.IsGlobalUnicast() + if gu != tt.globalUnicast { + t.Errorf("IsGlobalUnicast(%v) = %v; want %v", tt.ip, gu, tt.globalUnicast) + } + + ilm := tt.ip.IsInterfaceLocalMulticast() + if ilm != tt.interfaceLocalMulticast { + t.Errorf("IsInterfaceLocalMulticast(%v) = %v; want %v", tt.ip, ilm, tt.interfaceLocalMulticast) + } + + llu := tt.ip.IsLinkLocalUnicast() + if llu != tt.linkLocalUnicast { + t.Errorf("IsLinkLocalUnicast(%v) = %v; want %v", tt.ip, llu, tt.linkLocalUnicast) + } + + llm := tt.ip.IsLinkLocalMulticast() + if llm != tt.linkLocalMulticast { + t.Errorf("IsLinkLocalMulticast(%v) = %v; want %v", tt.ip, llm, tt.linkLocalMulticast) + } + + lo := tt.ip.IsLoopback() + if lo != tt.loopback { + t.Errorf("IsLoopback(%v) = %v; want %v", tt.ip, lo, tt.loopback) + } + + multicast := tt.ip.IsMulticast() + if multicast != tt.multicast { + t.Errorf("IsMulticast(%v) = %v; want %v", tt.ip, multicast, tt.multicast) + } + + private := tt.ip.IsPrivate() + if private != tt.private { + t.Errorf("IsPrivate(%v) = %v; want %v", tt.ip, private, tt.private) + } + + unspecified := tt.ip.IsUnspecified() + if unspecified != tt.unspecified { + t.Errorf("IsUnspecified(%v) = %v; want %v", tt.ip, unspecified, tt.unspecified) + } + }) + } +} + +func TestAddrWellKnown(t *testing.T) { + tests := []struct { + name string + ip Addr + std net.IP + }{ + { + name: "IPv6 link-local all nodes", + ip: IPv6LinkLocalAllNodes(), + std: net.IPv6linklocalallnodes, + }, + { + name: "IPv6 unspecified", + ip: IPv6Unspecified(), + std: net.IPv6unspecified, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + want := tt.std.String() + got := tt.ip.String() + + if got != want { + t.Fatalf("got %s, want %s", got, want) + } + }) + } +} + +func TestLessCompare(t *testing.T) { + tests := []struct { + a, b Addr + want bool + }{ + {Addr{}, Addr{}, false}, + {Addr{}, mustIP("1.2.3.4"), true}, + {mustIP("1.2.3.4"), Addr{}, false}, + + {mustIP("1.2.3.4"), mustIP("0102:0304::0"), true}, + {mustIP("0102:0304::0"), mustIP("1.2.3.4"), false}, + {mustIP("1.2.3.4"), mustIP("1.2.3.4"), false}, + + {mustIP("::1"), mustIP("::2"), true}, + {mustIP("::1"), mustIP("::1%foo"), true}, + {mustIP("::1%foo"), mustIP("::2"), true}, + {mustIP("::2"), mustIP("::3"), true}, + + {mustIP("::"), mustIP("0.0.0.0"), false}, + {mustIP("0.0.0.0"), mustIP("::"), true}, + + {mustIP("::1%a"), mustIP("::1%b"), true}, + {mustIP("::1%a"), mustIP("::1%a"), false}, + {mustIP("::1%b"), mustIP("::1%a"), false}, + } + for _, tt := range tests { + got := tt.a.Less(tt.b) + if got != tt.want { + t.Errorf("Less(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + cmp := tt.a.Compare(tt.b) + if got && cmp != -1 { + t.Errorf("Less(%q, %q) = true, but Compare = %v (not -1)", tt.a, tt.b, cmp) + } + if cmp < -1 || cmp > 1 { + t.Errorf("bogus Compare return value %v", cmp) + } + if cmp == 0 && tt.a != tt.b { + t.Errorf("Compare(%q, %q) = 0; but not equal", tt.a, tt.b) + } + if cmp == 1 && !tt.b.Less(tt.a) { + t.Errorf("Compare(%q, %q) = 1; but b.Less(a) isn't true", tt.a, tt.b) + } + + // Also check inverse. + if got == tt.want && got { + got2 := tt.b.Less(tt.a) + if got2 { + t.Errorf("Less(%q, %q) was correctly %v, but so was Less(%q, %q)", tt.a, tt.b, got, tt.b, tt.a) + } + } + } + + // And just sort. + values := []Addr{ + mustIP("::1"), + mustIP("::2"), + Addr{}, + mustIP("1.2.3.4"), + mustIP("8.8.8.8"), + mustIP("::1%foo"), + } + sort.Slice(values, func(i, j int) bool { return values[i].Less(values[j]) }) + got := fmt.Sprintf("%s", values) + want := `[invalid IP 1.2.3.4 8.8.8.8 ::1 ::1%foo ::2]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + +func TestIPStringExpanded(t *testing.T) { + tests := []struct { + ip Addr + s string + }{ + { + ip: Addr{}, + s: "invalid IP", + }, + { + ip: mustIP("192.0.2.1"), + s: "192.0.2.1", + }, + { + ip: mustIP("::ffff:192.0.2.1"), + s: "0000:0000:0000:0000:0000:ffff:c000:0201", + }, + { + ip: mustIP("2001:db8::1"), + s: "2001:0db8:0000:0000:0000:0000:0000:0001", + }, + { + ip: mustIP("2001:db8::1%eth0"), + s: "2001:0db8:0000:0000:0000:0000:0000:0001%eth0", + }, + } + + for _, tt := range tests { + t.Run(tt.ip.String(), func(t *testing.T) { + want := tt.s + got := tt.ip.StringExpanded() + + if got != want { + t.Fatalf("got %s, want %s", got, want) + } + }) + } +} + +func TestPrefixMasking(t *testing.T) { + type subtest struct { + ip Addr + bits uint8 + p Prefix + ok bool + } + + // makeIPv6 produces a set of IPv6 subtests with an optional zone identifier. + makeIPv6 := func(zone string) []subtest { + if zone != "" { + zone = "%" + zone + } + + return []subtest{ + { + ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)), + bits: 255, + }, + { + ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)), + bits: 32, + p: mustPrefix(fmt.Sprintf("2001:db8::%s/32", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("fe80::dead:beef:dead:beef%s", zone)), + bits: 96, + p: mustPrefix(fmt.Sprintf("fe80::dead:beef:0:0%s/96", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("aaaa::%s", zone)), + bits: 4, + p: mustPrefix(fmt.Sprintf("a000::%s/4", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("::%s", zone)), + bits: 63, + p: mustPrefix(fmt.Sprintf("::%s/63", zone)), + ok: true, + }, + } + } + + tests := []struct { + family string + subtests []subtest + }{ + { + family: "nil", + subtests: []subtest{ + { + bits: 255, + ok: true, + }, + { + bits: 16, + ok: true, + }, + }, + }, + { + family: "IPv4", + subtests: []subtest{ + { + ip: mustIP("192.0.2.0"), + bits: 255, + }, + { + ip: mustIP("192.0.2.0"), + bits: 16, + p: mustPrefix("192.0.0.0/16"), + ok: true, + }, + { + ip: mustIP("255.255.255.255"), + bits: 20, + p: mustPrefix("255.255.240.0/20"), + ok: true, + }, + { + // Partially masking one byte that contains both + // 1s and 0s on either side of the mask limit. + ip: mustIP("100.98.156.66"), + bits: 10, + p: mustPrefix("100.64.0.0/10"), + ok: true, + }, + }, + }, + { + family: "IPv6", + subtests: makeIPv6(""), + }, + { + family: "IPv6 zone", + subtests: makeIPv6("eth0"), + }, + } + + for _, tt := range tests { + t.Run(tt.family, func(t *testing.T) { + for _, st := range tt.subtests { + t.Run(st.p.String(), func(t *testing.T) { + // Ensure st.ip is not mutated. + orig := st.ip.String() + + p, err := st.ip.Prefix(int(st.bits)) + if st.ok && err != nil { + t.Fatalf("failed to produce prefix: %v", err) + } + if !st.ok && err == nil { + t.Fatal("expected an error, but none occurred") + } + if err != nil { + t.Logf("err: %v", err) + return + } + + if !reflect.DeepEqual(p, st.p) { + t.Errorf("prefix = %q, want %q", p, st.p) + } + + if got := st.ip.String(); got != orig { + t.Errorf("IP was mutated: %q, want %q", got, orig) + } + }) + } + }) + } +} + +func TestPrefixMarshalUnmarshal(t *testing.T) { + tests := []string{ + "", + "1.2.3.4/32", + "0.0.0.0/0", + "::/0", + "::1/128", + "::ffff:c000:1234/128", + "2001:db8::/32", + } + + for _, s := range tests { + t.Run(s, func(t *testing.T) { + // Ensure that JSON (and by extension, text) marshaling is + // sane by entering quoted input. + orig := `"` + s + `"` + + var p Prefix + if err := json.Unmarshal([]byte(orig), &p); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + pb, err := json.Marshal(p) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(pb) + if orig != back { + t.Errorf("Marshal = %q; want %q", back, orig) + } + }) + } +} + +func TestPrefixMarshalUnmarshalZone(t *testing.T) { + orig := `"fe80::1cc0:3e8c:119f:c2e1%ens18/128"` + unzoned := `"fe80::1cc0:3e8c:119f:c2e1/128"` + + var p Prefix + if err := json.Unmarshal([]byte(orig), &p); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + pb, err := json.Marshal(p) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(pb) + if back != unzoned { + t.Errorf("Marshal = %q; want %q", back, unzoned) + } +} + +func TestPrefixUnmarshalTextNonZero(t *testing.T) { + ip := mustPrefix("fe80::/64") + if err := ip.UnmarshalText([]byte("xxx")); err == nil { + t.Fatal("unmarshaled into non-empty Prefix") + } +} + +func TestIs4AndIs6(t *testing.T) { + tests := []struct { + ip Addr + is4 bool + is6 bool + }{ + {Addr{}, false, false}, + {mustIP("1.2.3.4"), true, false}, + {mustIP("127.0.0.2"), true, false}, + {mustIP("::1"), false, true}, + {mustIP("::ffff:192.0.2.128"), false, true}, + {mustIP("::fffe:c000:0280"), false, true}, + {mustIP("::1%eth0"), false, true}, + } + for _, tt := range tests { + got4 := tt.ip.Is4() + if got4 != tt.is4 { + t.Errorf("Is4(%q) = %v; want %v", tt.ip, got4, tt.is4) + } + + got6 := tt.ip.Is6() + if got6 != tt.is6 { + t.Errorf("Is6(%q) = %v; want %v", tt.ip, got6, tt.is6) + } + } +} + +func TestIs4In6(t *testing.T) { + tests := []struct { + ip Addr + want bool + wantUnmap Addr + }{ + {Addr{}, false, Addr{}}, + {mustIP("::ffff:c000:0280"), true, mustIP("192.0.2.128")}, + {mustIP("::ffff:192.0.2.128"), true, mustIP("192.0.2.128")}, + {mustIP("::ffff:192.0.2.128%eth0"), true, mustIP("192.0.2.128")}, + {mustIP("::fffe:c000:0280"), false, mustIP("::fffe:c000:0280")}, + {mustIP("::ffff:127.001.002.003"), true, mustIP("127.1.2.3")}, + {mustIP("::ffff:7f01:0203"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0:0000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0:000000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0::ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("::1"), false, mustIP("::1")}, + {mustIP("1.2.3.4"), false, mustIP("1.2.3.4")}, + } + for _, tt := range tests { + got := tt.ip.Is4In6() + if got != tt.want { + t.Errorf("Is4In6(%q) = %v; want %v", tt.ip, got, tt.want) + } + u := tt.ip.Unmap() + if u != tt.wantUnmap { + t.Errorf("Unmap(%q) = %v; want %v", tt.ip, u, tt.wantUnmap) + } + } +} + +func TestPrefixMasked(t *testing.T) { + tests := []struct { + prefix Prefix + masked Prefix + }{ + { + prefix: mustPrefix("192.168.0.255/24"), + masked: mustPrefix("192.168.0.0/24"), + }, + { + prefix: mustPrefix("2100::/3"), + masked: mustPrefix("2000::/3"), + }, + { + prefix: PrefixFrom(mustIP("2000::"), 129), + masked: Prefix{}, + }, + { + prefix: PrefixFrom(mustIP("1.2.3.4"), 33), + masked: Prefix{}, + }, + } + for _, test := range tests { + t.Run(test.prefix.String(), func(t *testing.T) { + got := test.prefix.Masked() + if got != test.masked { + t.Errorf("Masked=%s, want %s", got, test.masked) + } + }) + } +} + +func TestPrefix(t *testing.T) { + tests := []struct { + prefix string + ip Addr + bits int + str string + contains []Addr + notContains []Addr + }{ + { + prefix: "192.168.0.0/24", + ip: mustIP("192.168.0.0"), + bits: 24, + contains: mustIPs("192.168.0.1", "192.168.0.55"), + notContains: mustIPs("192.168.1.1", "1.1.1.1"), + }, + { + prefix: "192.168.1.1/32", + ip: mustIP("192.168.1.1"), + bits: 32, + contains: mustIPs("192.168.1.1"), + notContains: mustIPs("192.168.1.2"), + }, + { + prefix: "100.64.0.0/10", // CGNAT range; prefix not multiple of 8 + ip: mustIP("100.64.0.0"), + bits: 10, + contains: mustIPs("100.64.0.0", "100.64.0.1", "100.81.251.94", "100.100.100.100", "100.127.255.254", "100.127.255.255"), + notContains: mustIPs("100.63.255.255", "100.128.0.0"), + }, + { + prefix: "2001:db8::/96", + ip: mustIP("2001:db8::"), + bits: 96, + contains: mustIPs("2001:db8::aaaa:bbbb", "2001:db8::1"), + notContains: mustIPs("2001:db8::1:aaaa:bbbb", "2001:db9::"), + }, + { + prefix: "0.0.0.0/0", + ip: mustIP("0.0.0.0"), + bits: 0, + contains: mustIPs("192.168.0.1", "1.1.1.1"), + notContains: append(mustIPs("2001:db8::1"), Addr{}), + }, + { + prefix: "::/0", + ip: mustIP("::"), + bits: 0, + contains: mustIPs("::1", "2001:db8::1"), + notContains: mustIPs("192.0.2.1"), + }, + { + prefix: "2000::/3", + ip: mustIP("2000::"), + bits: 3, + contains: mustIPs("2001:db8::1"), + notContains: mustIPs("fe80::1"), + }, + { + prefix: "::%0/00/80", + ip: mustIP("::"), + bits: 80, + str: "::/80", + contains: mustIPs("::"), + notContains: mustIPs("ff::%0/00", "ff::%1/23", "::%0/00", "::%1/23"), + }, + } + for _, test := range tests { + t.Run(test.prefix, func(t *testing.T) { + prefix, err := ParsePrefix(test.prefix) + if err != nil { + t.Fatal(err) + } + if prefix.Addr() != test.ip { + t.Errorf("IP=%s, want %s", prefix.Addr(), test.ip) + } + if prefix.Bits() != test.bits { + t.Errorf("bits=%d, want %d", prefix.Bits(), test.bits) + } + for _, ip := range test.contains { + if !prefix.Contains(ip) { + t.Errorf("does not contain %s", ip) + } + } + for _, ip := range test.notContains { + if prefix.Contains(ip) { + t.Errorf("contains %s", ip) + } + } + want := test.str + if want == "" { + want = test.prefix + } + if got := prefix.String(); got != want { + t.Errorf("prefix.String()=%q, want %q", got, want) + } + + TestAppendToMarshal(t, prefix) + }) + } +} + +func TestPrefixFromInvalidBits(t *testing.T) { + v4 := MustParseAddr("1.2.3.4") + v6 := MustParseAddr("66::66") + tests := []struct { + ip Addr + in, want int + }{ + {v4, 0, 0}, + {v6, 0, 0}, + {v4, 1, 1}, + {v4, 33, -1}, + {v6, 33, 33}, + {v6, 127, 127}, + {v6, 128, 128}, + {v4, 254, -1}, + {v4, 255, -1}, + {v4, -1, -1}, + {v6, -1, -1}, + {v4, -5, -1}, + {v6, -5, -1}, + } + for _, tt := range tests { + p := PrefixFrom(tt.ip, tt.in) + if got := p.Bits(); got != tt.want { + t.Errorf("for (%v, %v), Bits out = %v; want %v", tt.ip, tt.in, got, tt.want) + } + } +} + +func TestParsePrefixAllocs(t *testing.T) { + tests := []struct { + ip string + slash string + }{ + {"192.168.1.0", "/24"}, + {"aaaa:bbbb:cccc::", "/24"}, + } + for _, test := range tests { + prefix := test.ip + test.slash + t.Run(prefix, func(t *testing.T) { + ipAllocs := int(testing.AllocsPerRun(5, func() { + ParseAddr(test.ip) + })) + prefixAllocs := int(testing.AllocsPerRun(5, func() { + ParsePrefix(prefix) + })) + if got := prefixAllocs - ipAllocs; got != 0 { + t.Errorf("allocs=%d, want 0", got) + } + }) + } +} + +func TestParsePrefixError(t *testing.T) { + tests := []struct { + prefix string + errstr string + }{ + { + prefix: "192.168.0.0", + errstr: "no '/'", + }, + { + prefix: "1.257.1.1/24", + errstr: "value >255", + }, + { + prefix: "1.1.1.0/q", + errstr: "bad bits", + }, + { + prefix: "1.1.1.0/-1", + errstr: "out of range", + }, + { + prefix: "1.1.1.0/33", + errstr: "out of range", + }, + { + prefix: "2001::/129", + errstr: "out of range", + }, + } + for _, test := range tests { + t.Run(test.prefix, func(t *testing.T) { + _, err := ParsePrefix(test.prefix) + if err == nil { + t.Fatal("no error") + } + if got := err.Error(); !strings.Contains(got, test.errstr) { + t.Errorf("error is missing substring %q: %s", test.errstr, got) + } + }) + } +} + +func TestPrefixIsSingleIP(t *testing.T) { + tests := []struct { + ipp Prefix + want bool + }{ + {ipp: mustPrefix("127.0.0.1/32"), want: true}, + {ipp: mustPrefix("127.0.0.1/31"), want: false}, + {ipp: mustPrefix("127.0.0.1/0"), want: false}, + {ipp: mustPrefix("::1/128"), want: true}, + {ipp: mustPrefix("::1/127"), want: false}, + {ipp: mustPrefix("::1/0"), want: false}, + {ipp: Prefix{}, want: false}, + } + for _, tt := range tests { + got := tt.ipp.IsSingleIP() + if got != tt.want { + t.Errorf("IsSingleIP(%v) = %v want %v", tt.ipp, got, tt.want) + } + } +} + +func mustIPs(strs ...string) []Addr { + var res []Addr + for _, s := range strs { + res = append(res, mustIP(s)) + } + return res +} + +func BenchmarkBinaryMarshalRoundTrip(b *testing.B) { + b.ReportAllocs() + tests := []struct { + name string + ip string + }{ + {"ipv4", "1.2.3.4"}, + {"ipv6", "2001:db8::1"}, + {"ipv6+zone", "2001:db8::1%eth0"}, + } + for _, tc := range tests { + b.Run(tc.name, func(b *testing.B) { + ip := mustIP(tc.ip) + for i := 0; i < b.N; i++ { + bt, err := ip.MarshalBinary() + if err != nil { + b.Fatal(err) + } + var ip2 Addr + if err := ip2.UnmarshalBinary(bt); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkStdIPv4(b *testing.B) { + b.ReportAllocs() + ips := []net.IP{} + for i := 0; i < b.N; i++ { + ip := net.IPv4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv4(b *testing.B) { + b.ReportAllocs() + ips := []Addr{} + for i := 0; i < b.N; i++ { + ip := IPv4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +// ip4i was one of the possible representations of IP that came up in +// discussions, inlining IPv4 addresses, but having an "overflow" +// interface for IPv6 or IPv6 + zone. This is here for benchmarking. +type ip4i struct { + ip4 [4]byte + flags1 byte + flags2 byte + flags3 byte + flags4 byte + ipv6 interface{} +} + +func newip4i_v4(a, b, c, d byte) ip4i { + return ip4i{ip4: [4]byte{a, b, c, d}} +} + +// BenchmarkIPv4_inline benchmarks the candidate representation, ip4i. +func BenchmarkIPv4_inline(b *testing.B) { + b.ReportAllocs() + ips := []ip4i{} + for i := 0; i < b.N; i++ { + ip := newip4i_v4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkStdIPv6(b *testing.B) { + b.ReportAllocs() + ips := []net.IP{} + for i := 0; i < b.N; i++ { + ip := net.ParseIP("2001:db8::1") + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv6(b *testing.B) { + b.ReportAllocs() + ips := []Addr{} + for i := 0; i < b.N; i++ { + ip := mustIP("2001:db8::1") + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv4Contains(b *testing.B) { + b.ReportAllocs() + prefix := PrefixFrom(IPv4(192, 168, 1, 0), 24) + ip := IPv4(192, 168, 1, 1) + for i := 0; i < b.N; i++ { + prefix.Contains(ip) + } +} + +func BenchmarkIPv6Contains(b *testing.B) { + b.ReportAllocs() + prefix := MustParsePrefix("::1/128") + ip := MustParseAddr("::1") + for i := 0; i < b.N; i++ { + prefix.Contains(ip) + } +} + +var parseBenchInputs = []struct { + name string + ip string +}{ + {"v4", "192.168.1.1"}, + {"v6", "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b"}, + {"v6_ellipsis", "fd7a:115c::626b:430b"}, + {"v6_v4", "::ffff:192.168.140.255"}, + {"v6_zone", "1:2::ffff:192.168.140.255%eth1"}, +} + +func BenchmarkParseAddr(b *testing.B) { + sinkInternValue = intern.Get("eth1") // Pin to not benchmark the intern package + for _, test := range parseBenchInputs { + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkIP, _ = ParseAddr(test.ip) + } + }) + } +} + +func BenchmarkStdParseIP(b *testing.B) { + for _, test := range parseBenchInputs { + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkStdIP = net.ParseIP(test.ip) + } + }) + } +} + +func BenchmarkIPString(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ip.String() + } + }) + } +} + +func BenchmarkIPStringExpanded(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ip.StringExpanded() + } + }) + } +} + +func BenchmarkIPMarshalText(b *testing.B) { + b.ReportAllocs() + ip := MustParseAddr("66.55.44.33") + for i := 0; i < b.N; i++ { + sinkBytes, _ = ip.MarshalText() + } +} + +func BenchmarkAddrPortString(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + ipp := AddrPortFrom(ip, 60000) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ipp.String() + } + }) + } +} + +func BenchmarkAddrPortMarshalText(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + ipp := AddrPortFrom(ip, 60000) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkBytes, _ = ipp.MarshalText() + } + }) + } +} + +func BenchmarkPrefixMasking(b *testing.B) { + tests := []struct { + name string + ip Addr + bits int + }{ + { + name: "IPv4 /32", + ip: IPv4(192, 0, 2, 0), + bits: 32, + }, + { + name: "IPv4 /17", + ip: IPv4(192, 0, 2, 0), + bits: 17, + }, + { + name: "IPv4 /0", + ip: IPv4(192, 0, 2, 0), + bits: 0, + }, + { + name: "IPv6 /128", + ip: mustIP("2001:db8::1"), + bits: 128, + }, + { + name: "IPv6 /65", + ip: mustIP("2001:db8::1"), + bits: 65, + }, + { + name: "IPv6 /0", + ip: mustIP("2001:db8::1"), + bits: 0, + }, + { + name: "IPv6 zone /128", + ip: mustIP("2001:db8::1%eth0"), + bits: 128, + }, + { + name: "IPv6 zone /65", + ip: mustIP("2001:db8::1%eth0"), + bits: 65, + }, + { + name: "IPv6 zone /0", + ip: mustIP("2001:db8::1%eth0"), + bits: 0, + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sinkPrefix, _ = tt.ip.Prefix(tt.bits) + } + }) + } +} + +func BenchmarkPrefixMarshalText(b *testing.B) { + b.ReportAllocs() + ipp := MustParsePrefix("66.55.44.33/22") + for i := 0; i < b.N; i++ { + sinkBytes, _ = ipp.MarshalText() + } +} + +func BenchmarkParseAddrPort(b *testing.B) { + for _, test := range parseBenchInputs { + var ipp string + if strings.HasPrefix(test.name, "v6") { + ipp = fmt.Sprintf("[%s]:1234", test.ip) + } else { + ipp = fmt.Sprintf("%s:1234", test.ip) + } + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sinkAddrPort, _ = ParseAddrPort(ipp) + } + }) + } +} + +func TestAs4(t *testing.T) { + tests := []struct { + ip Addr + want [4]byte + wantPanic bool + }{ + { + ip: mustIP("1.2.3.4"), + want: [4]byte{1, 2, 3, 4}, + }, + { + ip: AddrFrom16(mustIP("1.2.3.4").As16()), // IPv4-in-IPv6 + want: [4]byte{1, 2, 3, 4}, + }, + { + ip: mustIP("0.0.0.0"), + want: [4]byte{0, 0, 0, 0}, + }, + { + ip: Addr{}, + wantPanic: true, + }, + { + ip: mustIP("::1"), + wantPanic: true, + }, + } + as4 := func(ip Addr) (v [4]byte, gotPanic bool) { + defer func() { + if recover() != nil { + gotPanic = true + return + } + }() + v = ip.As4() + return + } + for i, tt := range tests { + got, gotPanic := as4(tt.ip) + if gotPanic != tt.wantPanic { + t.Errorf("%d. panic on %v = %v; want %v", i, tt.ip, gotPanic, tt.wantPanic) + continue + } + if got != tt.want { + t.Errorf("%d. %v = %v; want %v", i, tt.ip, got, tt.want) + } + } +} + +func TestPrefixOverlaps(t *testing.T) { + pfx := mustPrefix + tests := []struct { + a, b Prefix + want bool + }{ + {Prefix{}, pfx("1.2.0.0/16"), false}, // first zero + {pfx("1.2.0.0/16"), Prefix{}, false}, // second zero + {pfx("::0/3"), pfx("0.0.0.0/3"), false}, // different families + + {pfx("1.2.0.0/16"), pfx("1.2.0.0/16"), true}, // equal + + {pfx("1.2.0.0/16"), pfx("1.2.3.0/24"), true}, + {pfx("1.2.3.0/24"), pfx("1.2.0.0/16"), true}, + + {pfx("1.2.0.0/16"), pfx("1.2.3.0/32"), true}, + {pfx("1.2.3.0/32"), pfx("1.2.0.0/16"), true}, + + // Match /0 either order + {pfx("1.2.3.0/32"), pfx("0.0.0.0/0"), true}, + {pfx("0.0.0.0/0"), pfx("1.2.3.0/32"), true}, + + {pfx("1.2.3.0/32"), pfx("5.5.5.5/0"), true}, // normalization not required; /0 means true + + // IPv6 overlapping + {pfx("5::1/128"), pfx("5::0/8"), true}, + {pfx("5::0/8"), pfx("5::1/128"), true}, + + // IPv6 not overlapping + {pfx("1::1/128"), pfx("2::2/128"), false}, + {pfx("0100::0/8"), pfx("::1/128"), false}, + + // v6-mapped v4 should not overlap with IPv4. + {PrefixFrom(AddrFrom16(mustIP("1.2.0.0").As16()), 16), pfx("1.2.3.0/24"), false}, + + // Invalid prefixes + {PrefixFrom(mustIP("1.2.3.4"), 33), pfx("1.2.3.0/24"), false}, + {PrefixFrom(mustIP("2000::"), 129), pfx("2000::/64"), false}, + } + for i, tt := range tests { + if got := tt.a.Overlaps(tt.b); got != tt.want { + t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.a, tt.b, got, tt.want) + } + // Overlaps is commutative + if got := tt.b.Overlaps(tt.a); got != tt.want { + t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.b, tt.a, got, tt.want) + } + } +} + +// Sink variables are here to force the compiler to not elide +// seemingly useless work in benchmarks and allocation tests. If you +// were to just `_ = foo()` within a test function, the compiler could +// correctly deduce that foo() does nothing and doesn't need to be +// called. By writing results to a global variable, we hide that fact +// from the compiler and force it to keep the code under test. +var ( + sinkIP Addr + sinkStdIP net.IP + sinkAddrPort AddrPort + sinkPrefix Prefix + sinkPrefixSlice []Prefix + sinkInternValue *intern.Value + sinkIP16 [16]byte + sinkIP4 [4]byte + sinkBool bool + sinkString string + sinkBytes []byte + sinkUDPAddr = &net.UDPAddr{IP: make(net.IP, 0, 16)} +) + +func TestNoAllocs(t *testing.T) { + // Wrappers that panic on error, to prove that our alloc-free + // methods are returning successfully. + panicIP := func(ip Addr, err error) Addr { + if err != nil { + panic(err) + } + return ip + } + panicPfx := func(pfx Prefix, err error) Prefix { + if err != nil { + panic(err) + } + return pfx + } + panicIPP := func(ipp AddrPort, err error) AddrPort { + if err != nil { + panic(err) + } + return ipp + } + test := func(name string, f func()) { + t.Run(name, func(t *testing.T) { + n := testing.AllocsPerRun(1000, f) + if n != 0 { + t.Fatalf("allocs = %d; want 0", int(n)) + } + }) + } + + // IP constructors + test("IPv4", func() { sinkIP = IPv4(1, 2, 3, 4) }) + test("AddrFrom4", func() { sinkIP = AddrFrom4([4]byte{1, 2, 3, 4}) }) + test("AddrFrom16", func() { sinkIP = AddrFrom16([16]byte{}) }) + test("ParseAddr/4", func() { sinkIP = panicIP(ParseAddr("1.2.3.4")) }) + test("ParseAddr/6", func() { sinkIP = panicIP(ParseAddr("::1")) }) + test("MustParseAddr", func() { sinkIP = MustParseAddr("1.2.3.4") }) + test("IPv6LinkLocalAllNodes", func() { sinkIP = IPv6LinkLocalAllNodes() }) + test("IPv6Unspecified", func() { sinkIP = IPv6Unspecified() }) + + // IP methods + test("IP.IsZero", func() { sinkBool = MustParseAddr("1.2.3.4").IsZero() }) + test("IP.BitLen", func() { sinkBool = MustParseAddr("1.2.3.4").BitLen() == 8 }) + test("IP.Zone/4", func() { sinkBool = MustParseAddr("1.2.3.4").Zone() == "" }) + test("IP.Zone/6", func() { sinkBool = MustParseAddr("fe80::1").Zone() == "" }) + test("IP.Zone/6zone", func() { sinkBool = MustParseAddr("fe80::1%zone").Zone() == "" }) + test("IP.Compare", func() { + a := MustParseAddr("1.2.3.4") + b := MustParseAddr("2.3.4.5") + sinkBool = a.Compare(b) == 0 + }) + test("IP.Less", func() { + a := MustParseAddr("1.2.3.4") + b := MustParseAddr("2.3.4.5") + sinkBool = a.Less(b) + }) + test("IP.Is4", func() { sinkBool = MustParseAddr("1.2.3.4").Is4() }) + test("IP.Is6", func() { sinkBool = MustParseAddr("fe80::1").Is6() }) + test("IP.Is4In6", func() { sinkBool = MustParseAddr("fe80::1").Is4In6() }) + test("IP.Unmap", func() { sinkIP = MustParseAddr("ffff::2.3.4.5").Unmap() }) + test("IP.WithZone", func() { sinkIP = MustParseAddr("fe80::1").WithZone("") }) + test("IP.IsGlobalUnicast", func() { sinkBool = MustParseAddr("2001:db8::1").IsGlobalUnicast() }) + test("IP.IsInterfaceLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsInterfaceLocalMulticast() }) + test("IP.IsLinkLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalMulticast() }) + test("IP.IsLinkLocalUnicast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalUnicast() }) + test("IP.IsLoopback", func() { sinkBool = MustParseAddr("fe80::1").IsLoopback() }) + test("IP.IsMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsMulticast() }) + test("IP.IsPrivate", func() { sinkBool = MustParseAddr("fd00::1").IsPrivate() }) + test("IP.IsUnspecified", func() { sinkBool = IPv6Unspecified().IsUnspecified() }) + test("IP.Prefix/4", func() { sinkPrefix = panicPfx(MustParseAddr("1.2.3.4").Prefix(20)) }) + test("IP.Prefix/6", func() { sinkPrefix = panicPfx(MustParseAddr("fe80::1").Prefix(64)) }) + test("IP.As16", func() { sinkIP16 = MustParseAddr("1.2.3.4").As16() }) + test("IP.As4", func() { sinkIP4 = MustParseAddr("1.2.3.4").As4() }) + test("IP.Next", func() { sinkIP = MustParseAddr("1.2.3.4").Next() }) + test("IP.Prev", func() { sinkIP = MustParseAddr("1.2.3.4").Prev() }) + + // AddrPort constructors + test("AddrPortFrom", func() { sinkAddrPort = AddrPortFrom(IPv4(1, 2, 3, 4), 22) }) + test("ParseAddrPort", func() { sinkAddrPort = panicIPP(ParseAddrPort("[::1]:1234")) }) + test("MustParseAddrPort", func() { sinkAddrPort = MustParseAddrPort("[::1]:1234") }) + + // Prefix constructors + test("PrefixFrom", func() { sinkPrefix = PrefixFrom(IPv4(1, 2, 3, 4), 32) }) + test("ParsePrefix/4", func() { sinkPrefix = panicPfx(ParsePrefix("1.2.3.4/20")) }) + test("ParsePrefix/6", func() { sinkPrefix = panicPfx(ParsePrefix("fe80::1/64")) }) + test("MustParsePrefix", func() { sinkPrefix = MustParsePrefix("1.2.3.4/20") }) + + // Prefix methods + test("Prefix.Contains", func() { sinkBool = MustParsePrefix("1.2.3.0/24").Contains(MustParseAddr("1.2.3.4")) }) + test("Prefix.Overlaps", func() { + a, b := MustParsePrefix("1.2.3.0/24"), MustParsePrefix("1.2.0.0/16") + sinkBool = a.Overlaps(b) + }) + test("Prefix.IsZero", func() { sinkBool = MustParsePrefix("1.2.0.0/16").IsZero() }) + test("Prefix.IsSingleIP", func() { sinkBool = MustParsePrefix("1.2.3.4/32").IsSingleIP() }) + test("IPPRefix.Masked", func() { sinkPrefix = MustParsePrefix("1.2.3.4/16").Masked() }) +} + +func TestPrefixString(t *testing.T) { + tests := []struct { + ipp Prefix + want string + }{ + {Prefix{}, "invalid Prefix"}, + {PrefixFrom(Addr{}, 8), "invalid Prefix"}, + {PrefixFrom(MustParseAddr("1.2.3.4"), 88), "invalid Prefix"}, + } + + for _, tt := range tests { + if got := tt.ipp.String(); got != tt.want { + t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want) + } + } +} + +func TestInvalidAddrPortString(t *testing.T) { + tests := []struct { + ipp AddrPort + want string + }{ + {AddrPort{}, "invalid AddrPort"}, + {AddrPortFrom(Addr{}, 80), "invalid AddrPort"}, + } + + for _, tt := range tests { + if got := tt.ipp.String(); got != tt.want { + t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want) + } + } +} diff --git a/src/net/netip/slow_test.go b/src/net/netip/slow_test.go new file mode 100644 index 0000000000..5b46a39a83 --- /dev/null +++ b/src/net/netip/slow_test.go @@ -0,0 +1,190 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip_test + +import ( + "fmt" + . "net/netip" + "strconv" + "strings" +) + +// zeros is a slice of eight stringified zeros. It's used in +// parseIPSlow to construct slices of specific amounts of zero fields, +// from 1 to 8. +var zeros = []string{"0", "0", "0", "0", "0", "0", "0", "0"} + +// parseIPSlow is like ParseIP, but aims for readability above +// speed. It's the reference implementation for correctness checking +// and against which we measure optimized parsers. +// +// parseIPSlow understands the following forms of IP addresses: +// - Regular IPv4: 1.2.3.4 +// - IPv4 with many leading zeros: 0000001.0000002.0000003.0000004 +// - Regular IPv6: 1111:2222:3333:4444:5555:6666:7777:8888 +// - IPv6 with many leading zeros: 00000001:0000002:0000003:0000004:0000005:0000006:0000007:0000008 +// - IPv6 with zero blocks elided: 1111:2222::7777:8888 +// - IPv6 with trailing 32 bits expressed as IPv4: 1111:2222:3333:4444:5555:6666:77.77.88.88 +// +// It does not process the following IP address forms, which have been +// varyingly accepted by some programs due to an under-specification +// of the shapes of IPv4 addresses: +// +// - IPv4 as a single 32-bit uint: 4660 (same as "1.2.3.4") +// - IPv4 with octal numbers: 0300.0250.0.01 (same as "192.168.0.1") +// - IPv4 with hex numbers: 0xc0.0xa8.0x0.0x1 (same as "192.168.0.1") +// - IPv4 in "class-B style": 1.2.52 (same as "1.2.3.4") +// - IPv4 in "class-A style": 1.564 (same as "1.2.3.4") +func parseIPSlow(s string) (Addr, error) { + // Identify and strip out the zone, if any. There should be 0 or 1 + // '%' in the string. + var zone string + fs := strings.Split(s, "%") + switch len(fs) { + case 1: + // No zone, that's fine. + case 2: + s, zone = fs[0], fs[1] + if zone == "" { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): no zone after zone specifier", s) + } + default: + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): too many zone specifiers", s) // TODO: less specific? + } + + // IPv4 by itself is easy to do in a helper. + if strings.Count(s, ":") == 0 { + if zone != "" { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): IPv4 addresses cannot have a zone", s) + } + return parseIPv4Slow(s) + } + + normal, err := normalizeIPv6Slow(s) + if err != nil { + return Addr{}, err + } + + // At this point, we've normalized the address back into 8 hex + // fields of 16 bits each. Parse that. + fs = strings.Split(normal, ":") + if len(fs) != 8 { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): wrong size address", s) + } + var ret [16]byte + for i, f := range fs { + a, b, err := parseWord(f) + if err != nil { + return Addr{}, err + } + ret[i*2] = a + ret[i*2+1] = b + } + + return AddrFrom16(ret).WithZone(zone), nil +} + +// normalizeIPv6Slow expands s, which is assumed to be an IPv6 +// address, to its canonical text form. +// +// The canonical form of an IPv6 address is 8 colon-separated fields, +// where each field should be a hex value from 0 to ffff. This +// function does not verify the contents of each field. +// +// This function performs two transformations: +// - The last 32 bits of an IPv6 address may be represented in +// IPv4-style dotted quad form, as in 1:2:3:4:5:6:7.8.9.10. That +// address is transformed to its hex equivalent, +// e.g. 1:2:3:4:5:6:708:90a. +// - An address may contain one "::", which expands into as many +// 16-bit blocks of zeros as needed to make the address its correct +// full size. For example, fe80::1:2 expands to fe80:0:0:0:0:0:1:2. +// +// Both short forms may be present in a single address, +// e.g. fe80::1.2.3.4. +func normalizeIPv6Slow(orig string) (string, error) { + s := orig + + // Find and convert an IPv4 address in the final field, if any. + i := strings.LastIndex(s, ":") + if i == -1 { + return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig) + } + if strings.Contains(s[i+1:], ".") { + ip, err := parseIPv4Slow(s[i+1:]) + if err != nil { + return "", err + } + a4 := ip.As4() + s = fmt.Sprintf("%s:%02x%02x:%02x%02x", s[:i], a4[0], a4[1], a4[2], a4[3]) + } + + // Find and expand a ::, if any. + fs := strings.Split(s, "::") + switch len(fs) { + case 1: + // No ::, nothing to do. + case 2: + lhs, rhs := fs[0], fs[1] + // Found a ::, figure out how many zero blocks need to be + // inserted. + nblocks := strings.Count(lhs, ":") + strings.Count(rhs, ":") + if lhs != "" { + nblocks++ + } + if rhs != "" { + nblocks++ + } + if nblocks > 7 { + return "", fmt.Errorf("netaddr.ParseIP(%q): address too long", orig) + } + fs = nil + // Either side of the :: can be empty. We don't want empty + // fields to feature in the final normalized address. + if lhs != "" { + fs = append(fs, lhs) + } + fs = append(fs, zeros[:8-nblocks]...) + if rhs != "" { + fs = append(fs, rhs) + } + s = strings.Join(fs, ":") + default: + // Too many :: + return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig) + } + + return s, nil +} + +// parseIPv4Slow parses and returns an IPv4 address in dotted quad +// form, e.g. "192.168.0.1". It is slow but easy to read, and the +// reference implementation against which we compare faster +// implementations for correctness. +func parseIPv4Slow(s string) (Addr, error) { + fs := strings.Split(s, ".") + if len(fs) != 4 { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", s) + } + var ret [4]byte + for i := range ret { + val, err := strconv.ParseUint(fs[i], 10, 8) + if err != nil { + return Addr{}, err + } + ret[i] = uint8(val) + } + return AddrFrom4([4]byte{ret[0], ret[1], ret[2], ret[3]}), nil +} + +// parseWord converts a 16-bit hex string into its corresponding +// two-byte value. +func parseWord(s string) (byte, byte, error) { + ret, err := strconv.ParseUint(s, 16, 16) + if err != nil { + return 0, 0, err + } + return uint8(ret >> 8), uint8(ret), nil +} diff --git a/src/net/netip/uint128.go b/src/net/netip/uint128.go new file mode 100644 index 0000000000..738939d7de --- /dev/null +++ b/src/net/netip/uint128.go @@ -0,0 +1,92 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import "math/bits" + +// uint128 represents a uint128 using two uint64s. +// +// When the methods below mention a bit number, bit 0 is the most +// significant bit (in hi) and bit 127 is the lowest (lo&1). +type uint128 struct { + hi uint64 + lo uint64 +} + +// mask6 returns a uint128 bitmask with the topmost n bits of a +// 128-bit number. +func mask6(n int) uint128 { + return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)} +} + +// isZero reports whether u == 0. +// +// It's faster than u == (uint128{}) because the compiler (as of Go +// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in +// its eq alg's generated code. +func (u uint128) isZero() bool { return u.hi|u.lo == 0 } + +// and returns the bitwise AND of u and m (u&m). +func (u uint128) and(m uint128) uint128 { + return uint128{u.hi & m.hi, u.lo & m.lo} +} + +// xor returns the bitwise XOR of u and m (u^m). +func (u uint128) xor(m uint128) uint128 { + return uint128{u.hi ^ m.hi, u.lo ^ m.lo} +} + +// or returns the bitwise OR of u and m (u|m). +func (u uint128) or(m uint128) uint128 { + return uint128{u.hi | m.hi, u.lo | m.lo} +} + +// not returns the bitwise NOT of u. +func (u uint128) not() uint128 { + return uint128{^u.hi, ^u.lo} +} + +// subOne returns u - 1. +func (u uint128) subOne() uint128 { + lo, borrow := bits.Sub64(u.lo, 1, 0) + return uint128{u.hi - borrow, lo} +} + +// addOne returns u + 1. +func (u uint128) addOne() uint128 { + lo, carry := bits.Add64(u.lo, 1, 0) + return uint128{u.hi + carry, lo} +} + +func u64CommonPrefixLen(a, b uint64) uint8 { + return uint8(bits.LeadingZeros64(a ^ b)) +} + +func (u uint128) commonPrefixLen(v uint128) (n uint8) { + if n = u64CommonPrefixLen(u.hi, v.hi); n == 64 { + n += u64CommonPrefixLen(u.lo, v.lo) + } + return +} + +// halves returns the two uint64 halves of the uint128. +// +// Logically, think of it as returning two uint64s. +// It only returns pointers for inlining reasons on 32-bit platforms. +func (u *uint128) halves() [2]*uint64 { + return [2]*uint64{&u.hi, &u.lo} +} + +// bitsSetFrom returns a copy of u with the given bit +// and all subsequent ones set. +func (u uint128) bitsSetFrom(bit uint8) uint128 { + return u.or(mask6(int(bit)).not()) +} + +// bitsClearedFrom returns a copy of u with the given bit +// and all subsequent ones cleared. +func (u uint128) bitsClearedFrom(bit uint8) uint128 { + return u.and(mask6(int(bit))) +} diff --git a/src/net/netip/uint128_test.go b/src/net/netip/uint128_test.go new file mode 100644 index 0000000000..dd1ae0ec79 --- /dev/null +++ b/src/net/netip/uint128_test.go @@ -0,0 +1,89 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "testing" +) + +func TestUint128AddSub(t *testing.T) { + const add1 = 1 + const sub1 = -1 + tests := []struct { + in uint128 + op int // +1 or -1 to add vs subtract + want uint128 + }{ + {uint128{0, 0}, add1, uint128{0, 1}}, + {uint128{0, 1}, add1, uint128{0, 2}}, + {uint128{1, 0}, add1, uint128{1, 1}}, + {uint128{0, ^uint64(0)}, add1, uint128{1, 0}}, + {uint128{^uint64(0), ^uint64(0)}, add1, uint128{0, 0}}, + + {uint128{0, 0}, sub1, uint128{^uint64(0), ^uint64(0)}}, + {uint128{0, 1}, sub1, uint128{0, 0}}, + {uint128{0, 2}, sub1, uint128{0, 1}}, + {uint128{1, 0}, sub1, uint128{0, ^uint64(0)}}, + {uint128{1, 1}, sub1, uint128{1, 0}}, + } + for _, tt := range tests { + var got uint128 + switch tt.op { + case add1: + got = tt.in.addOne() + case sub1: + got = tt.in.subOne() + default: + panic("bogus op") + } + if got != tt.want { + t.Errorf("%v add %d = %v; want %v", tt.in, tt.op, got, tt.want) + } + } +} + +func TestBitsSetFrom(t *testing.T) { + tests := []struct { + bit uint8 + want uint128 + }{ + {0, uint128{^uint64(0), ^uint64(0)}}, + {1, uint128{^uint64(0) >> 1, ^uint64(0)}}, + {63, uint128{1, ^uint64(0)}}, + {64, uint128{0, ^uint64(0)}}, + {65, uint128{0, ^uint64(0) >> 1}}, + {127, uint128{0, 1}}, + {128, uint128{0, 0}}, + } + for _, tt := range tests { + var zero uint128 + got := zero.bitsSetFrom(tt.bit) + if got != tt.want { + t.Errorf("0.bitsSetFrom(%d) = %064b want %064b", tt.bit, got, tt.want) + } + } +} + +func TestBitsClearedFrom(t *testing.T) { + tests := []struct { + bit uint8 + want uint128 + }{ + {0, uint128{0, 0}}, + {1, uint128{1 << 63, 0}}, + {63, uint128{^uint64(0) &^ 1, 0}}, + {64, uint128{^uint64(0), 0}}, + {65, uint128{^uint64(0), 1 << 63}}, + {127, uint128{^uint64(0), ^uint64(0) &^ 1}}, + {128, uint128{^uint64(0), ^uint64(0)}}, + } + for _, tt := range tests { + ones := uint128{^uint64(0), ^uint64(0)} + got := ones.bitsClearedFrom(tt.bit) + if got != tt.want { + t.Errorf("ones.bitsClearedFrom(%d) = %064b want %064b", tt.bit, got, tt.want) + } + } +} diff --git a/src/net/parse.go b/src/net/parse.go index 0d7cce12e6..ee2890fe2c 100644 --- a/src/net/parse.go +++ b/src/net/parse.go @@ -341,26 +341,3 @@ func readFull(r io.Reader) (all []byte, err error) { } } } - -// goDebugString returns the value of the named GODEBUG key. -// GODEBUG is of the form "key=val,key2=val2" -func goDebugString(key string) string { - s := os.Getenv("GODEBUG") - for i := 0; i < len(s)-len(key)-1; i++ { - if i > 0 && s[i-1] != ',' { - continue - } - afterKey := s[i+len(key):] - if afterKey[0] != '=' || s[i:i+len(key)] != key { - continue - } - val := afterKey[1:] - for i, b := range val { - if b == ',' { - return val[:i] - } - } - return val - } - return "" -} diff --git a/src/net/parse_test.go b/src/net/parse_test.go index c5f8bfd198..97716d769a 100644 --- a/src/net/parse_test.go +++ b/src/net/parse_test.go @@ -51,33 +51,6 @@ func TestReadLine(t *testing.T) { } } -func TestGoDebugString(t *testing.T) { - defer os.Setenv("GODEBUG", os.Getenv("GODEBUG")) - tests := []struct { - godebug string - key string - want string - }{ - {"", "foo", ""}, - {"foo=", "foo", ""}, - {"foo=bar", "foo", "bar"}, - {"foo=bar,", "foo", "bar"}, - {"foo,foo=bar,", "foo", "bar"}, - {"foo1=bar,foo=bar,", "foo", "bar"}, - {"foo=bar,foo=bar,", "foo", "bar"}, - {"foo=", "foo", ""}, - {"foo", "foo", ""}, - {",foo", "foo", ""}, - {"foo=bar,baz", "loooooooong", ""}, - } - for _, tt := range tests { - os.Setenv("GODEBUG", tt.godebug) - if got := goDebugString(tt.key); got != tt.want { - t.Errorf("for %q, goDebugString(%q) = %q; want %q", tt.godebug, tt.key, got, tt.want) - } - } -} - func TestDtoi(t *testing.T) { for _, tt := range []struct { in string diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 19a90143f3..fddb018aab 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -8,6 +8,7 @@ import ( "context" "internal/itoa" "io" + "net/netip" "os" "syscall" "time" @@ -23,6 +24,20 @@ type TCPAddr struct { Zone string // IPv6 scoped addressing zone } +// AddrPort returns the TCPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *TCPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + // Network returns the address's network name, "tcp". func (a *TCPAddr) Network() string { return "tcp" } diff --git a/src/net/udpsock.go b/src/net/udpsock.go index 70f2ce226a..95ffa85939 100644 --- a/src/net/udpsock.go +++ b/src/net/udpsock.go @@ -7,6 +7,7 @@ package net import ( "context" "internal/itoa" + "net/netip" "syscall" ) @@ -26,6 +27,20 @@ type UDPAddr struct { Zone string // IPv6 scoped addressing zone } +// AddrPort returns the UDPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *UDPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + // Network returns the address's network name, "udp". func (a *UDPAddr) Network() string { return "udp" } @@ -84,6 +99,21 @@ func ResolveUDPAddr(network, address string) (*UDPAddr, error) { return addrs.forResolve(network, address).(*UDPAddr), nil } +// UDPAddrFromAddrPort returns addr as a UDPAddr. +// +// If addr is not valid, it returns nil. +func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr { + if !addr.IsValid() { + return nil + } + ip16 := addr.Addr().As16() + return &UDPAddr{ + IP: IP(ip16[:]), + Zone: addr.Addr().Zone(), + Port: int(addr.Port()), + } +} + // UDPConn is the implementation of the Conn and PacketConn interfaces // for UDP network connections. type UDPConn struct { @@ -148,6 +178,18 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, return } +// ReadMsgUDPAddrPort is like ReadMsgUDP but returns an netip.AddrPort instead of a UDPAddr. +func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { + // TODO(bradfitz): make this efficient, making the internal net package + // type throughout be netip.Addr and only converting to the net.IP slice + // version at the edge. But for now (2021-10-20), this is a wrapper around + // the old way. + var ua *UDPAddr + n, oobn, flags, ua, err = c.ReadMsgUDP(b, oob) + addr = ua.AddrPort() + return +} + // WriteToUDP acts like WriteTo but takes a UDPAddr. func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() { @@ -160,6 +202,15 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { return n, err } +// WriteToUDPAddrPort acts like WriteTo but takes a netip.AddrPort. +func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + // TODO(bradfitz): make this efficient, making the internal net package + // type throughout be netip.Addr and only converting to the net.IP slice + // version at the edge. But for now (2021-10-20), this is a wrapper around + // the old way. + return c.WriteToUDP(b, UDPAddrFromAddrPort(addr)) +} + // WriteTo implements the PacketConn WriteTo method. func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { @@ -195,6 +246,15 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er return } +// WriteMsgUDPAddrPort is like WriteMsgUDP but takes a netip.AddrPort instead of a UDPAddr. +func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) { + // TODO(bradfitz): make this efficient, making the internal net package + // type throughout be netip.Addr and only converting to the net.IP slice + // version at the edge. But for now (2021-10-20), this is a wrapper around + // the old way. + return c.WriteMsgUDP(b, oob, UDPAddrFromAddrPort(addr)) +} + func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // DialUDP acts like Dial for UDP networks. diff --git a/src/strings/compare.go b/src/strings/compare.go index 1fe6b8d89a..2bd4a243db 100644 --- a/src/strings/compare.go +++ b/src/strings/compare.go @@ -5,7 +5,7 @@ package strings // Compare returns an integer comparing two strings lexicographically. -// The result will be 0 if a==b, -1 if a < b, and +1 if a > b. +// The result will be 0 if a == b, -1 if a < b, and +1 if a > b. // // Compare is included only for symmetry with package bytes. // It is usually clearer and always faster to use the built-in