crypto/internal/boring/bcache: make Cache type-safe using generics
Generics lets us write Cache[K, V] instead of using unsafe.Pointer,
which lets us remove all uses of package unsafe around the cache.
I tried Cache[*K, *V] instead of Cache[K, V], but that was not possible.

Change-Id: If3b54cf4c8d2a44879a5f343fd91ecff096537e9
Reviewed-on: https://go-review.googlesource.com/c/go/+/423357
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Russ Cox <rsc@golang.org>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Auto-Submit: Russ Cox <rsc@golang.org>
parent 9709d92bfa
commit b30ba3df9f
4 changed files with 64 additions and 65 deletions
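To make the change concrete before the diff: a minimal sketch of how a caller uses the new generic API. The names Key, cached, and derived are hypothetical, and bcache is internal to the standard library, so this only compiles for packages under crypto/ (the real callers, crypto/ecdsa and crypto/rsa, appear in the diff below).

	package examplecaller // hypothetical; real callers live under crypto/

	import "crypto/internal/boring/bcache"

	type Key struct{ id int }     // hypothetical key type (the real ones are PublicKey/PrivateKey)
	type cached struct{ sum int } // hypothetical value type (the real ones hold BoringCrypto handles)

	// The type parameters are the pointed-to types: Get and Put traffic in
	// *Key and *cached, so call sites need no unsafe.Pointer conversions.
	var keyCache bcache.Cache[Key, cached]

	func init() {
		keyCache.Register() // required: lets the runtime clear the table at each GC
	}

	func derived(k *Key) *cached {
		if c := keyCache.Get(k); c != nil { // typed *cached, or nil on a miss
			return c
		}
		c := &cached{sum: k.id * 2} // stand-in for the expensive conversion being cached
		keyCache.Put(k, c)
		return c
	}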
src/crypto/ecdsa/boring.go

@@ -11,7 +11,6 @@ import (
 	"crypto/internal/boring/bbig"
 	"crypto/internal/boring/bcache"
 	"math/big"
-	"unsafe"
 )

 // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto.
@@ -27,8 +26,8 @@ import (
 // still matches before using the cached key. The theory is that the real
 // operations are significantly more expensive than the comparison.

-var pubCache bcache.Cache
-var privCache bcache.Cache
+var pubCache bcache.Cache[PublicKey, boringPub]
+var privCache bcache.Cache[PrivateKey, boringPriv]

 func init() {
 	pubCache.Register()
@@ -41,7 +40,7 @@ type boringPub struct {
 }

 func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) {
-	b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub)))
+	b := pubCache.Get(pub)
 	if b != nil && publicKeyEqual(&b.orig, pub) {
 		return b.key, nil
 	}
@@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyECDSA, error) {
 		return nil, err
 	}
 	b.key = key
-	pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b))
+	pubCache.Put(pub, b)
 	return key, nil
 }

@@ -63,7 +62,7 @@ type boringPriv struct {
 }

 func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) {
-	b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv)))
+	b := privCache.Get(priv)
 	if b != nil && privateKeyEqual(&b.orig, priv) {
 		return b.key, nil
 	}
@@ -75,7 +74,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyECDSA, error) {
 		return nil, err
 	}
 	b.key = key
-	privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b))
+	privCache.Put(priv, b)
 	return key, nil
 }
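The hunks above are easier to follow as one function. A consolidated sketch assembled from the new lines of this diff: the cache is keyed by the key's address, and the comment above allows for the possibility that the PublicKey at that address no longer holds the same material, so b.orig is rechecked before the cached handle is trusted. buildBoringPub is a hypothetical stand-in for the unshown middle of the real function, which this commit does not change.

	func boringPublicKeySketch(pub *PublicKey) (*boring.PublicKeyECDSA, error) {
		// Fast path: recheck the stored copy (b.orig) against the live key;
		// the conversion avoided here is assumed to dwarf this comparison.
		if b := pubCache.Get(pub); b != nil && publicKeyEqual(&b.orig, pub) {
			return b.key, nil
		}

		// Slow path: copy the key and build the BoringCrypto handle.
		// buildBoringPub is hypothetical, standing in for elided code.
		b, key, err := buildBoringPub(pub)
		if err != nil {
			return nil, err
		}
		b.key = key
		pubCache.Put(pub, b)
		return key, nil
	}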
src/crypto/internal/boring/bcache/cache.go

@@ -22,20 +22,18 @@ import (
 // This means that clients need to be able to cope with cache entries
 // disappearing, but it also means that clients don't need to worry about
 // cache entries keeping the keys from being collected.
-//
-// TODO(rsc): Make Cache generic once consumers can handle that.
-type Cache struct {
-	// ptable is an atomic *[cacheSize]unsafe.Pointer,
-	// where each unsafe.Pointer is an atomic *cacheEntry.
+type Cache[K, V any] struct {
 	// The runtime atomically stores nil to ptable at the start of each GC.
-	ptable unsafe.Pointer
+	ptable atomic.Pointer[cacheTable[K, V]]
 }

+type cacheTable[K, V any] [cacheSize]atomic.Pointer[cacheEntry[K, V]]
+
 // A cacheEntry is a single entry in the linked list for a given hash table entry.
-type cacheEntry struct {
-	k    unsafe.Pointer // immutable once created
-	v    unsafe.Pointer // read and written atomically to allow updates
-	next *cacheEntry    // immutable once linked into table
+type cacheEntry[K, V any] struct {
+	k    *K                // immutable once created
+	v    atomic.Pointer[V] // read and written atomically to allow updates
+	next *cacheEntry[K, V] // immutable once linked into table
 }

 func registerCache(unsafe.Pointer) // provided by runtime
@@ -43,7 +41,7 @@ func registerCache(unsafe.Pointer) // provided by runtime
 // Register registers the cache with the runtime,
 // so that c.ptable can be cleared at the start of each GC.
 // Register must be called during package initialization.
-func (c *Cache) Register() {
+func (c *Cache[K, V]) Register() {
 	registerCache(unsafe.Pointer(&c.ptable))
 }

@@ -54,45 +52,45 @@ const cacheSize = 1021

 // table returns a pointer to the current cache hash table,
 // coping with the possibility of the GC clearing it out from under us.
-func (c *Cache) table() *[cacheSize]unsafe.Pointer {
+func (c *Cache[K, V]) table() *cacheTable[K, V] {
 	for {
-		p := atomic.LoadPointer(&c.ptable)
+		p := c.ptable.Load()
 		if p == nil {
-			p = unsafe.Pointer(new([cacheSize]unsafe.Pointer))
-			if !atomic.CompareAndSwapPointer(&c.ptable, nil, p) {
+			p = new(cacheTable[K, V])
+			if !c.ptable.CompareAndSwap(nil, p) {
 				continue
 			}
 		}
-		return (*[cacheSize]unsafe.Pointer)(p)
+		return p
 	}
 }

 // Clear clears the cache.
 // The runtime does this automatically at each garbage collection;
 // this method is exposed only for testing.
-func (c *Cache) Clear() {
+func (c *Cache[K, V]) Clear() {
 	// The runtime does this at the start of every garbage collection
 	// (itself, not by calling this function).
-	atomic.StorePointer(&c.ptable, nil)
+	c.ptable.Store(nil)
 }

 // Get returns the cached value associated with v,
 // which is either the value v corresponding to the most recent call to Put(k, v)
 // or nil if that cache entry has been dropped.
-func (c *Cache) Get(k unsafe.Pointer) unsafe.Pointer {
-	head := &c.table()[uintptr(k)%cacheSize]
-	e := (*cacheEntry)(atomic.LoadPointer(head))
+func (c *Cache[K, V]) Get(k *K) *V {
+	head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]
+	e := head.Load()
 	for ; e != nil; e = e.next {
 		if e.k == k {
-			return atomic.LoadPointer(&e.v)
+			return e.v.Load()
 		}
 	}
 	return nil
 }

 // Put sets the cached value associated with k to v.
-func (c *Cache) Put(k, v unsafe.Pointer) {
-	head := &c.table()[uintptr(k)%cacheSize]
+func (c *Cache[K, V]) Put(k *K, v *V) {
+	head := &c.table()[uintptr(unsafe.Pointer(k))%cacheSize]

 	// Strategy is to walk the linked list at head,
 	// same as in Get, to look for existing entry.
@@ -112,20 +110,21 @@ func (c *Cache) Put(k, v unsafe.Pointer) {
 	//
 	// 2. We only allocate the entry to be added once,
 	//    saving it in add for the next attempt.
-	var add, noK *cacheEntry
+	var add, noK *cacheEntry[K, V]
 	n := 0
 	for {
-		e := (*cacheEntry)(atomic.LoadPointer(head))
+		e := head.Load()
 		start := e
 		for ; e != nil && e != noK; e = e.next {
 			if e.k == k {
-				atomic.StorePointer(&e.v, v)
+				e.v.Store(v)
 				return
 			}
 			n++
 		}
 		if add == nil {
-			add = &cacheEntry{k, v, nil}
+			add = &cacheEntry[K, V]{k: k}
+			add.v.Store(v)
 		}
 		add.next = start
 		if n >= 1000 {
@@ -133,7 +132,7 @@ func (c *Cache) Put(k, v unsafe.Pointer) {
 		// throw it away to avoid quadratic lookup behavior.
 		add.next = nil
 	}
-	if atomic.CompareAndSwapPointer(head, unsafe.Pointer(start), unsafe.Pointer(add)) {
+	if head.CompareAndSwap(start, add) {
 		return
 	}
 	noK = start
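For readers unfamiliar with the typed atomics added in Go 1.19: the rewrite above works because atomic.Pointer[T] carries its element type, so every Load, Store, and CompareAndSwap is checked statically instead of being re-asserted with a cast at each call site. A self-contained before/after sketch, independent of this package:

	package main

	import (
		"fmt"
		"sync/atomic"
		"unsafe"
	)

	type node struct{ v int }

	func main() {
		// Before: an untyped atomic slot; every access needs a cast.
		var p1 unsafe.Pointer // dynamically a *node, but the compiler cannot check that
		atomic.StorePointer(&p1, unsafe.Pointer(&node{v: 1}))
		n1 := (*node)(atomic.LoadPointer(&p1)) // caller re-asserts the type on every load

		// After: a typed atomic slot; operations are checked statically.
		var p2 atomic.Pointer[node]
		p2.Store(&node{v: 2})
		n2 := p2.Load() // already a *node, no cast
		p2.CompareAndSwap(n2, &node{v: 3})

		fmt.Println(n1.v, p2.Load().v) // prints: 1 3
	}

Note the one conversion the generic version still needs: uintptr(unsafe.Pointer(k)) % cacheSize in Get and Put, which turns the key pointer into an integer purely for bucket selection.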
src/crypto/internal/boring/bcache/cache_test.go

@@ -10,31 +10,39 @@ import (
 	"sync"
 	"sync/atomic"
 	"testing"
-	"unsafe"
 )

-var registeredCache Cache
+var registeredCache Cache[int, int32]

 func init() {
 	registeredCache.Register()
 }

+var seq atomic.Uint32
+
+func next[T int | int32]() *T {
+	x := new(T)
+	*x = T(seq.Add(1))
+	return x
+}
+
+func str[T int | int32](x *T) string {
+	if x == nil {
+		return "nil"
+	}
+	return fmt.Sprint(*x)
+}
+
 func TestCache(t *testing.T) {
 	// Use unregistered cache for functionality tests,
 	// to keep the runtime from clearing behind our backs.
-	c := new(Cache)
+	c := new(Cache[int, int32])

 	// Create many entries.
-	seq := uint32(0)
-	next := func() unsafe.Pointer {
-		x := new(int)
-		*x = int(atomic.AddUint32(&seq, 1))
-		return unsafe.Pointer(x)
-	}
-	m := make(map[unsafe.Pointer]unsafe.Pointer)
+	m := make(map[*int]*int32)
 	for i := 0; i < 10000; i++ {
-		k := next()
-		v := next()
+		k := next[int]()
+		v := next[int32]()
 		m[k] = v
 		c.Put(k, v)
 	}
@@ -42,7 +50,7 @@ func TestCache(t *testing.T) {
 	// Overwrite a random 20% of those.
 	n := 0
 	for k := range m {
-		v := next()
+		v := next[int32]()
 		m[k] = v
 		c.Put(k, v)
 		if n++; n >= 2000 {
@@ -51,12 +59,6 @@ func TestCache(t *testing.T) {
 	}

 	// Check results.
-	str := func(p unsafe.Pointer) string {
-		if p == nil {
-			return "nil"
-		}
-		return fmt.Sprint(*(*int)(p))
-	}
 	for k, v := range m {
 		if cv := c.Get(k); cv != v {
 			t.Fatalf("c.Get(%v) = %v, want %v", str(k), str(cv), str(v))
@@ -86,7 +88,7 @@ func TestCache(t *testing.T) {
 	// Lists are discarded if they reach 1000 entries,
 	// and there are cacheSize list heads, so we should be
 	// able to do 100 * cacheSize entries with no problem at all.
-	c = new(Cache)
+	c = new(Cache[int, int32])
 	var barrier, wg sync.WaitGroup
 	const N = 100
 	barrier.Add(N)
@@ -96,9 +98,9 @@ func TestCache(t *testing.T) {
 		go func() {
 			defer wg.Done()

-			m := make(map[unsafe.Pointer]unsafe.Pointer)
+			m := make(map[*int]*int32)
 			for j := 0; j < cacheSize; j++ {
-				k, v := next(), next()
+				k, v := next[int](), next[int32]()
 				m[k] = v
 				c.Put(k, v)
 			}
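One detail worth calling out in the new test helpers: the type parameter of next appears only in its result, so Go cannot infer it from any argument and every call must instantiate it explicitly, as in next[int]() and next[int32](). A standalone, runnable sketch of the same pattern:

	package main

	import (
		"fmt"
		"sync/atomic"
	)

	var seq atomic.Uint32 // package-level counter, same shape as in the test

	// T occurs only in the result type, so callers must write
	// next[int]() or next[int32]() explicitly.
	func next[T int | int32]() *T {
		x := new(T)
		*x = T(seq.Add(1))
		return x
	}

	func main() {
		k := next[int]()    // *int
		v := next[int32]()  // *int32
		fmt.Println(*k, *v) // prints: 1 2
	}

Using two distinct key and value types in the test (Cache[int, int32]) also means the compiler now rejects a swapped c.Put(v, k), which the old unsafe.Pointer signature would have accepted silently.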
src/crypto/rsa/boring.go

@@ -11,7 +11,6 @@ import (
 	"crypto/internal/boring/bbig"
 	"crypto/internal/boring/bcache"
 	"math/big"
-	"unsafe"
 )

 // Cached conversions from Go PublicKey/PrivateKey to BoringCrypto.
@@ -32,8 +31,8 @@ type boringPub struct {
 	orig PublicKey
 }

-var pubCache bcache.Cache
-var privCache bcache.Cache
+var pubCache bcache.Cache[PublicKey, boringPub]
+var privCache bcache.Cache[PrivateKey, boringPriv]

 func init() {
 	pubCache.Register()
@@ -41,7 +40,7 @@ func init() {
 }

 func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) {
-	b := (*boringPub)(pubCache.Get(unsafe.Pointer(pub)))
+	b := pubCache.Get(pub)
 	if b != nil && publicKeyEqual(&b.orig, pub) {
 		return b.key, nil
 	}
@@ -53,7 +52,7 @@ func boringPublicKey(pub *PublicKey) (*boring.PublicKeyRSA, error) {
 		return nil, err
 	}
 	b.key = key
-	pubCache.Put(unsafe.Pointer(pub), unsafe.Pointer(b))
+	pubCache.Put(pub, b)
 	return key, nil
 }

@@ -63,7 +62,7 @@ type boringPriv struct {
 }

 func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) {
-	b := (*boringPriv)(privCache.Get(unsafe.Pointer(priv)))
+	b := privCache.Get(priv)
 	if b != nil && privateKeyEqual(&b.orig, priv) {
 		return b.key, nil
 	}
@@ -87,7 +86,7 @@ func boringPrivateKey(priv *PrivateKey) (*boring.PrivateKeyRSA, error) {
 		return nil, err
 	}
 	b.key = key
-	privCache.Put(unsafe.Pointer(priv), unsafe.Pointer(b))
+	privCache.Put(priv, b)
 	return key, nil
 }