math/big: use built-in clear to simplify code

This commit is contained in:
apocelipes 2024-03-14 19:59:18 +09:00
parent f1d60500bc
commit 2ba8c4c705
3 changed files with 16 additions and 24 deletions

View file

@ -533,10 +533,8 @@ func (x *Int) Bytes() []byte {
//
// If the absolute value of x doesn't fit in buf, FillBytes will panic.
func (x *Int) FillBytes(buf []byte) []byte {
// Clear whole buffer. (This gets optimized into a memclr.)
for i := range buf {
buf[i] = 0
}
// Clear whole buffer.
clear(buf)
x.abs.bytes(buf)
return buf
}

View file

@ -44,12 +44,6 @@ func (z nat) String() string {
return "0x" + string(z.itoa(false, 16))
}
func (z nat) clear() {
for i := range z {
z[i] = 0
}
}
func (z nat) norm() nat {
i := len(z)
for i > 0 && z[i-1] == 0 {
@ -196,7 +190,7 @@ func (z nat) mulAddWW(x nat, y, r Word) nat {
// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(z, x, y nat) {
z[0 : len(x)+len(y)].clear() // initialize z
clear(z[0 : len(x)+len(y)]) // initialize z
for i, d := range y {
if d != 0 {
z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
@ -222,7 +216,7 @@ func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
panic("math/big: mismatched montgomery number lengths")
}
z = z.make(n * 2)
z.clear()
clear(z)
var c Word
for i := 0; i < n; i++ {
d := y[i]
@ -443,8 +437,8 @@ func (z nat) mul(x, y nat) nat {
y0 := y[0:k] // y0 is not normalized
z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
karatsuba(z, x0, y0)
z = z[0 : m+n] // z has final length but may be incomplete
z[2*k:].clear() // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
z = z[0 : m+n] // z has final length but may be incomplete
clear(z[2*k:]) // upper portion of z is garbage (and 2*k <= m+n since k <= n <= m)
// If xh != 0 or yh != 0, add the missing terms to z. For
//
@ -497,7 +491,7 @@ func basicSqr(z, x nat) {
n := len(x)
tp := getNat(2 * n)
t := *tp // temporary variable to hold the products
t.clear()
clear(t)
z[1], z[0] = mulWW(x[0], x[0]) // the initial square
for i := 1; i < n; i++ {
d := x[i]
@ -592,7 +586,7 @@ func (z nat) sqr(x nat) nat {
z = z.make(max(6*k, 2*n))
karatsubaSqr(z, x0) // z = x0^2
z = z[0 : 2*n]
z[2*k:].clear()
clear(z[2*k:])
if k < n {
tp := getNat(2 * k)
@ -723,7 +717,7 @@ func (z nat) shl(x nat, s uint) nat {
n := m + int(s/_W)
z = z.make(n + 1)
z[n] = shlVU(z[n-m:n], x, s%_W)
z[0 : n-m].clear()
clear(z[0 : n-m])
return z.norm()
}
@ -769,7 +763,7 @@ func (z nat) setBit(x nat, i uint, b uint) nat {
case 1:
if j >= n {
z = z.make(j + 1)
z[n:].clear()
clear(z[n:])
} else {
z = z.make(n)
}

View file

@ -734,7 +734,7 @@ func (z nat) divRecursive(u, v nat) {
tmp := getNat(3 * len(v))
temps := make([]*nat, recDepth)
z.clear()
clear(z)
z.divRecursiveStep(u, v, 0, tmp, temps)
// Free temporaries.
@ -758,7 +758,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
u = u.norm()
v = v.norm()
if len(u) == 0 {
z.clear()
clear(z)
return
}
@ -816,7 +816,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
qhat := *temps[depth]
qhat.clear()
clear(qhat)
qhat.divRecursiveStep(uu[s:B+n], v[s:], depth+1, tmp, temps)
qhat = qhat.norm()
@ -833,7 +833,7 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// But we can do the subtraction directly, as in the comment above
// and in long division, because we know that q̂ is wrong by at most one.
qhatv := tmp.make(3 * n)
qhatv.clear()
clear(qhatv)
qhatv = qhatv.mul(qhat, v[:s])
for i := 0; i < 2; i++ {
e := qhatv.cmp(uu.norm())
@ -864,11 +864,11 @@ func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) {
// Choose shift = B-1 again.
s := B - 1
qhat := *temps[depth]
qhat.clear()
clear(qhat)
qhat.divRecursiveStep(u[s:].norm(), v[s:], depth+1, tmp, temps)
qhat = qhat.norm()
qhatv := tmp.make(3 * n)
qhatv.clear()
clear(qhatv)
qhatv = qhatv.mul(qhat, v[:s])
// Set the correct remainder as before.
for i := 0; i < 2; i++ {