cmd/compile: optimize cmp to cmn under conditions < and >= on arm64

Under the right conditions we can optimize cmp comparisons to cmn
comparisons, such as:
func foo(a, b int) int {
  var c int
  if a + b < 0 {
  	c = 1
  }
  return c
}

Previously it was compiled as:
  ADD     R1, R0, R1
  CMP     $0, R1
  CSET    LT, R0
With this CL it's compiled as:
  CMN     R1, R0
  CSET    MI, R0
Here we need to pay attention to the possibility that a+b overflows. The MI
condition means N==1; it does not honor the overflow flag V, and its value
depends only on the sign of the result. This matches the semantics of the
Go code, so the transformation is correct.

Similarly, this CL also optimizes the case of >= comparison
using the PL conditional flag.

Change-Id: I47179faba5b30cca84ea69bafa2ad5241bf6dfba
Reviewed-on: https://go-review.googlesource.com/c/go/+/476116
Run-TryBot: Eric Fang <eric.fang@arm.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
erifan01 2023-03-14 09:25:07 +08:00 committed by Eric Fang
parent e81cc9119f
commit 42f99b203d
7 changed files with 497 additions and 10 deletions

View file

@ -1105,7 +1105,9 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
ssa.OpARM64NotLessThanF,
ssa.OpARM64NotLessEqualF,
ssa.OpARM64NotGreaterThanF,
ssa.OpARM64NotGreaterEqualF:
ssa.OpARM64NotGreaterEqualF,
ssa.OpARM64LessThanNoov,
ssa.OpARM64GreaterEqualNoov:
// generate boolean values using CSET
p := s.Prog(arm64.ACSET)
p.From.Type = obj.TYPE_SPECIAL // assembler encodes conditional bits in Offset
@ -1196,6 +1198,9 @@ var condBits = map[ssa.Op]arm64.SpecialOperand{
ssa.OpARM64NotLessEqualF: arm64.SPOP_HI, // Greater than or unordered
ssa.OpARM64NotGreaterThanF: arm64.SPOP_LE, // Less than, equal to or unordered
ssa.OpARM64NotGreaterEqualF: arm64.SPOP_LT, // Less than or unordered
ssa.OpARM64LessThanNoov: arm64.SPOP_MI, // Less than but without honoring overflow
ssa.OpARM64GreaterEqualNoov: arm64.SPOP_PL, // Greater than or equal to but without honoring overflow
}
var blockJump = map[ssa.BlockKind]struct {

View file

@ -650,15 +650,15 @@
((Equal|NotEqual) (CMPW x z:(NEG y))) && z.Uses == 1 => ((Equal|NotEqual) (CMNW x y))
// For conditional instructions such as CSET, CSEL.
// TODO: add support for LT, LE, GT, GE, overflow needs to be considered.
((Equal|NotEqual) (CMPconst [0] x:(ADDconst [c] y))) && x.Uses == 1 => ((Equal|NotEqual) (CMNconst [c] y))
((Equal|NotEqual) (CMPWconst [0] x:(ADDconst [c] y))) && x.Uses == 1 => ((Equal|NotEqual) (CMNWconst [int32(c)] y))
((Equal|NotEqual) (CMPconst [0] z:(ADD x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMN x y))
((Equal|NotEqual) (CMPWconst [0] z:(ADD x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMNW x y))
((Equal|NotEqual) (CMPconst [0] z:(MADD a x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMN a (MUL <x.Type> x y)))
((Equal|NotEqual) (CMPconst [0] z:(MSUB a x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMP a (MUL <x.Type> x y)))
((Equal|NotEqual) (CMPWconst [0] z:(MADDW a x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMNW a (MULW <x.Type> x y)))
((Equal|NotEqual) (CMPWconst [0] z:(MSUBW a x y))) && z.Uses == 1 => ((Equal|NotEqual) (CMPW a (MULW <x.Type> x y)))
// TODO: add support for LE, GT, overflow needs to be considered.
((Equal|NotEqual|LessThan|GreaterEqual) (CMPconst [0] x:(ADDconst [c] y))) && x.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMNconst [c] y))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPWconst [0] x:(ADDconst [c] y))) && x.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMNWconst [int32(c)] y))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPconst [0] z:(ADD x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMN x y))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPWconst [0] z:(ADD x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMNW x y))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPconst [0] z:(MADD a x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMN a (MUL <x.Type> x y)))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPconst [0] z:(MSUB a x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMP a (MUL <x.Type> x y)))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPWconst [0] z:(MADDW a x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMNW a (MULW <x.Type> x y)))
((Equal|NotEqual|LessThan|GreaterEqual) (CMPWconst [0] z:(MSUBW a x y))) && z.Uses == 1 => ((Equal|NotEqual|LessThanNoov|GreaterEqualNoov) (CMPW a (MULW <x.Type> x y)))
((CMPconst|CMNconst) [c] y) && c < 0 && c != -1<<63 => ((CMNconst|CMPconst) [-c] y)
((CMPWconst|CMNWconst) [c] y) && c < 0 && c != -1<<31 => ((CMNWconst|CMPWconst) [-c] y)

View file

@ -514,6 +514,9 @@ func init() {
{name: "NotLessEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x>y || x is unordered with y, false otherwise.
{name: "NotGreaterThanF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<=y || x is unordered with y, false otherwise.
{name: "NotGreaterEqualF", argLength: 1, reg: readflags}, // bool, true flags encode floating-point x<y || x is unordered with y, false otherwise.
{name: "LessThanNoov", argLength: 1, reg: readflags}, // bool, true flags encode signed x<y but without honoring overflow, false otherwise.
{name: "GreaterEqualNoov", argLength: 1, reg: readflags}, // bool, true flags encode signed x>=y but without honoring overflow, false otherwise.
// duffzero
// arg0 = address of memory to zero
// arg1 = mem

View file

@ -1670,6 +1670,8 @@ const (
OpARM64NotLessEqualF
OpARM64NotGreaterThanF
OpARM64NotGreaterEqualF
OpARM64LessThanNoov
OpARM64GreaterEqualNoov
OpARM64DUFFZERO
OpARM64LoweredZero
OpARM64DUFFCOPY
@ -22276,6 +22278,24 @@ var opcodeTable = [...]opInfo{
},
},
},
{
name: "LessThanNoov",
argLen: 1,
reg: regInfo{
outputs: []outputInfo{
{0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30
},
},
},
{
name: "GreaterEqualNoov",
argLen: 1,
reg: regInfo{
outputs: []outputInfo{
{0, 670826495}, // R0 R1 R2 R3 R4 R5 R6 R7 R8 R9 R10 R11 R12 R13 R14 R15 R16 R17 R19 R20 R21 R22 R23 R24 R25 R26 R30
},
},
},
{
name: "DUFFZERO",
auxType: auxInt64,

View file

@ -5721,6 +5721,196 @@ func rewriteValueARM64_OpARM64GreaterEqual(v *Value) bool {
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPconst [0] x:(ADDconst [c] y)))
// cond: x.Uses == 1
// result: (GreaterEqualNoov (CMNconst [c] y))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
x := v_0.Args[0]
if x.Op != OpARM64ADDconst {
break
}
c := auxIntToInt64(x.AuxInt)
y := x.Args[0]
if !(x.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNconst, types.TypeFlags)
v0.AuxInt = int64ToAuxInt(c)
v0.AddArg(y)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPWconst [0] x:(ADDconst [c] y)))
// cond: x.Uses == 1
// result: (GreaterEqualNoov (CMNWconst [int32(c)] y))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
x := v_0.Args[0]
if x.Op != OpARM64ADDconst {
break
}
c := auxIntToInt64(x.AuxInt)
y := x.Args[0]
if !(x.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNWconst, types.TypeFlags)
v0.AuxInt = int32ToAuxInt(int32(c))
v0.AddArg(y)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPconst [0] z:(ADD x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMN x y))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64ADD {
break
}
y := z.Args[1]
x := z.Args[0]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMN, types.TypeFlags)
v0.AddArg2(x, y)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPWconst [0] z:(ADD x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMNW x y))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64ADD {
break
}
y := z.Args[1]
x := z.Args[0]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNW, types.TypeFlags)
v0.AddArg2(x, y)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPconst [0] z:(MADD a x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMN a (MUL <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MADD {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMN, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MUL, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPconst [0] z:(MSUB a x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMP a (MUL <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MSUB {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMP, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MUL, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPWconst [0] z:(MADDW a x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMNW a (MULW <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MADDW {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNW, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MULW, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (CMPWconst [0] z:(MSUBW a x y)))
// cond: z.Uses == 1
// result: (GreaterEqualNoov (CMPW a (MULW <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MSUBW {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64GreaterEqualNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMPW, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MULW, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (GreaterEqual (FlagConstant [fc]))
// result: (MOVDconst [b2i(fc.ge())])
for {
@ -6245,6 +6435,196 @@ func rewriteValueARM64_OpARM64LessThan(v *Value) bool {
v.AddArg(v0)
return true
}
// match: (LessThan (CMPconst [0] x:(ADDconst [c] y)))
// cond: x.Uses == 1
// result: (LessThanNoov (CMNconst [c] y))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
x := v_0.Args[0]
if x.Op != OpARM64ADDconst {
break
}
c := auxIntToInt64(x.AuxInt)
y := x.Args[0]
if !(x.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNconst, types.TypeFlags)
v0.AuxInt = int64ToAuxInt(c)
v0.AddArg(y)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPWconst [0] x:(ADDconst [c] y)))
// cond: x.Uses == 1
// result: (LessThanNoov (CMNWconst [int32(c)] y))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
x := v_0.Args[0]
if x.Op != OpARM64ADDconst {
break
}
c := auxIntToInt64(x.AuxInt)
y := x.Args[0]
if !(x.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNWconst, types.TypeFlags)
v0.AuxInt = int32ToAuxInt(int32(c))
v0.AddArg(y)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPconst [0] z:(ADD x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMN x y))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64ADD {
break
}
y := z.Args[1]
x := z.Args[0]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMN, types.TypeFlags)
v0.AddArg2(x, y)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPWconst [0] z:(ADD x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMNW x y))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64ADD {
break
}
y := z.Args[1]
x := z.Args[0]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNW, types.TypeFlags)
v0.AddArg2(x, y)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPconst [0] z:(MADD a x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMN a (MUL <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MADD {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMN, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MUL, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPconst [0] z:(MSUB a x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMP a (MUL <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPconst || auxIntToInt64(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MSUB {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMP, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MUL, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPWconst [0] z:(MADDW a x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMNW a (MULW <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MADDW {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMNW, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MULW, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (LessThan (CMPWconst [0] z:(MSUBW a x y)))
// cond: z.Uses == 1
// result: (LessThanNoov (CMPW a (MULW <x.Type> x y)))
for {
if v_0.Op != OpARM64CMPWconst || auxIntToInt32(v_0.AuxInt) != 0 {
break
}
z := v_0.Args[0]
if z.Op != OpARM64MSUBW {
break
}
y := z.Args[2]
a := z.Args[0]
x := z.Args[1]
if !(z.Uses == 1) {
break
}
v.reset(OpARM64LessThanNoov)
v0 := b.NewValue0(v.Pos, OpARM64CMPW, types.TypeFlags)
v1 := b.NewValue0(v.Pos, OpARM64MULW, x.Type)
v1.AddArg2(x, y)
v0.AddArg2(a, v1)
v.AddArg(v0)
return true
}
// match: (LessThan (FlagConstant [fc]))
// result: (MOVDconst [b2i(fc.lt())])
for {

View file

@ -36,6 +36,7 @@ var crTests = []struct {
{"AddConst64", testAddConst64},
{"AddConst32", testAddConst32},
{"AddVar64", testAddVar64},
{"AddVar64Cset", testAddVar64Cset},
{"AddVar32", testAddVar32},
{"MAddVar64", testMAddVar64},
{"MAddVar32", testMAddVar32},
@ -198,6 +199,41 @@ func testAddVar64(t *testing.T) {
}
}
// var + var, cset
func testAddVar64Cset(t *testing.T) {
var a int
if x64+v64 < 0 {
a = 1
}
if a != 1 {
t.Errorf("'%#x + %#x < 0' failed", x64, v64)
}
a = 0
if y64+v64_n >= 0 {
a = 1
}
if a != 1 {
t.Errorf("'%#x + %#x >= 0' failed", y64, v64_n)
}
a = 1
if x64+v64 >= 0 {
a = 0
}
if a == 0 {
t.Errorf("'%#x + %#x >= 0' failed", x64, v64)
}
a = 1
if y64+v64_n < 0 {
a = 0
}
if a == 0 {
t.Errorf("'%#x + %#x < 0' failed", y64, v64_n)
}
}
// 32-bit var+var
func testAddVar32(t *testing.T) {
if x32+v32 < 0 {

View file

@ -456,6 +456,7 @@ func CmpToZero_ex5(e, f int32, u uint32) int {
}
return 0
}
func UintLtZero(a uint8, b uint16, c uint32, d uint64) int {
// amd64: -`(TESTB|TESTW|TESTL|TESTQ|JCC|JCS)`
// arm64: -`(CMPW|CMP|BHS|BLO)`
@ -704,3 +705,45 @@ func cmpToCmn(a, b, c, d int) int {
}
return c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c10 + c11
}
func cmpToCmnLessThan(a, b, c, d int) int {
var c1, c2, c3, c4 int
// arm64:`CMN`,`CSET\tMI`,-`CMP`
if a+1 < 0 {
c1 = 1
}
// arm64:`CMN`,`CSET\tMI`,-`CMP`
if a+b < 0 {
c2 = 1
}
// arm64:`CMN`,`CSET\tMI`,-`CMP`
if a*b+c < 0 {
c3 = 1
}
// arm64:`CMP`,`CSET\tMI`,-`CMN`
if a-b*c < 0 {
c4 = 1
}
return c1 + c2 + c3 + c4
}
func cmpToCmnGreaterThanEqual(a, b, c, d int) int {
var c1, c2, c3, c4 int
// arm64:`CMN`,`CSET\tPL`,-`CMP`
if a+1 >= 0 {
c1 = 1
}
// arm64:`CMN`,`CSET\tPL`,-`CMP`
if a+b >= 0 {
c2 = 1
}
// arm64:`CMN`,`CSET\tPL`,-`CMP`
if a*b+c >= 0 {
c3 = 1
}
// arm64:`CMP`,`CSET\tPL`,-`CMN`
if a-b*c >= 0 {
c4 = 1
}
return c1 + c2 + c3 + c4
}