diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 12c2580c95..a8e43d0114 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -1189,15 +1189,38 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { } v.Op = ctzNonZeroOp[v.Op] } - + case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64, + OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64, + OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64, + OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64: + // Check whether, for a >> b, we know that a is non-negative + // and b is all of a's bits except the MSB. If so, a is shifted to zero. + bits := 8 * v.Type.Size() + if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) { + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op) + } + switch bits { + case 64: + v.reset(OpConst64) + case 32: + v.reset(OpConst32) + case 16: + v.reset(OpConst16) + case 8: + v.reset(OpConst8) + default: + panic("unexpected integer size") + } + v.AuxInt = 0 + continue // Be sure not to fallthrough - this is no longer OpRsh. + } + // If the Rsh hasn't been replaced with 0, still check if it is bounded. + fallthrough case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64, OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64, OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64, OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64, - OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64, - OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64, - OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64, - OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64, OpRsh8Ux8, OpRsh8Ux16, OpRsh8Ux32, OpRsh8Ux64, OpRsh16Ux8, OpRsh16Ux16, OpRsh16Ux32, OpRsh16Ux64, OpRsh32Ux8, OpRsh32Ux16, OpRsh32Ux32, OpRsh32Ux64, diff --git a/test/codegen/arithmetic.go b/test/codegen/arithmetic.go index a076664e8e..8f25974376 100644 --- a/test/codegen/arithmetic.go +++ b/test/codegen/arithmetic.go @@ -451,3 +451,14 @@ func addSpecial(a, b, c uint32) (uint32, uint32, uint32) { c += 128 return a, b, c } + + +// Divide -> shift rules usually require fixup for negative inputs. +// If the input is non-negative, make sure the fixup is eliminated. +func divInt(v int64) int64 { + if v < 0 { + return 0 + } + // amd64:-`.*SARQ.*63,`, -".*SHRQ", ".*SARQ.*[$]9," + return v / 512 +} diff --git a/test/prove.go b/test/prove.go index e5636a452e..d37021d283 100644 --- a/test/prove.go +++ b/test/prove.go @@ -956,6 +956,70 @@ func negIndex2(n int) { useSlice(c) } +// Check that prove is zeroing these right shifts of positive ints by bit-width - 1. +// e.g (Rsh64x64 n (Const64 [63])) && ft.isNonNegative(n) -> 0 +func sh64(n int64) int64 { + if n < 0 { + return n + } + return n >> 63 // ERROR "Proved Rsh64x64 shifts to zero" +} + +func sh32(n int32) int32 { + if n < 0 { + return n + } + return n >> 31 // ERROR "Proved Rsh32x64 shifts to zero" +} + +func sh32x64(n int32) int32 { + if n < 0 { + return n + } + return n >> uint64(31) // ERROR "Proved Rsh32x64 shifts to zero" +} + +func sh16(n int16) int16 { + if n < 0 { + return n + } + return n >> 15 // ERROR "Proved Rsh16x64 shifts to zero" +} + +func sh64noopt(n int64) int64 { + return n >> 63 // not optimized; n could be negative +} + +// These cases are division of a positive signed integer by a power of 2. +// The opt pass doesnt have sufficient information to see that n is positive. +// So, instead, opt rewrites the division with a less-than-optimal replacement. +// Prove, which can see that n is nonnegative, cannot see the division because +// opt, an earlier pass, has already replaced it. +// The fix for this issue allows prove to zero a right shift that was added as +// part of the less-than-optimal reqwrite. That change by prove then allows +// lateopt to clean up all the unneccesary parts of the original division +// replacement. See issue #36159. +func divShiftClean(n int) int { + if n < 0 { + return n + } + return n / int(8) // ERROR "Proved Rsh64x64 shifts to zero" +} + +func divShiftClean64(n int64) int64 { + if n < 0 { + return n + } + return n / int64(16) // ERROR "Proved Rsh64x64 shifts to zero" +} + +func divShiftClean32(n int32) int32 { + if n < 0 { + return n + } + return n / int32(16) // ERROR "Proved Rsh32x64 shifts to zero" +} + //go:noinline func useInt(a int) { }