cmd/compile: combine x*n + y*n into (x+y)*n

There are a few cases where this can be useful. Apart from the obvious
(and silly)

  100*n + 200*n

where we generate one IMUL instead of two, consider:

  15*n + 31*n

Currently, the compiler strength-reduces both imuls, generating:

    0x0000 00000	MOVQ	"".n+8(SP), AX
	0x0005 00005 	MOVQ	AX, CX
	0x0008 00008 	SHLQ	$4, AX
	0x000c 00012 	SUBQ	CX, AX
	0x000f 00015 	MOVQ	CX, DX
	0x0012 00018 	SHLQ	$5, CX
	0x0016 00022 	SUBQ	DX, CX
	0x0019 00025 	ADDQ	CX, AX
	0x001c 00028 	MOVQ	AX, "".~r1+16(SP)
	0x0021 00033 	RET

But combining the imuls is both faster and shorter:

	0x0000 00000	MOVQ	"".n+8(SP), AX
	0x0005 00005 	IMULQ	$46, AX
	0x0009 00009	MOVQ	AX, "".~r1+16(SP)
	0x000e 00014 	RET

even without strength-reduction.

Moreover, consider:

  5*n + 7*(n+1) + 11*(n+2)

We already have a rule that rewrites 7(n+1) into 7n+7, so the
generated code (without imuls merging) looks like this:

	0x0000 00000 	MOVQ	"".n+8(SP), AX
	0x0005 00005 	LEAQ	(AX)(AX*4), CX
	0x0009 00009 	MOVQ	AX, DX
	0x000c 00012 	NEGQ	AX
	0x000f 00015 	LEAQ	(AX)(DX*8), AX
	0x0013 00019 	ADDQ	CX, AX
	0x0016 00022 	LEAQ	(DX)(CX*2), CX
	0x001a 00026 	LEAQ	29(AX)(CX*1), AX
	0x001f 00031 	MOVQ	AX, "".~r1+16(SP)

But with imuls merging, the 5n, 7n and 11n factors get merged, and the
generated code looks like this:

	0x0000 00000 	MOVQ	"".n+8(SP), AX
	0x0005 00005 	IMULQ	$23, AX
	0x0009 00009 	ADDQ	$29, AX
	0x000d 00013 	MOVQ	AX, "".~r1+16(SP)
	0x0012 00018 	RET

Which is both faster and shorter; that's also the exact same code that
clang and the intel c compiler generate for the above expression.

Change-Id: Ib4d5503f05d2f2efe31a1be14e2fe6cac33730a9
Reviewed-on: https://go-review.googlesource.com/55143
Reviewed-by: Keith Randall <khr@golang.org>
This commit is contained in:
Alberto Donizetti 2017-08-14 11:44:09 +02:00
parent e70fae8a64
commit a0453a180f
4 changed files with 1149 additions and 44 deletions

View file

@ -741,6 +741,29 @@ var linuxAMD64Tests = []*asmTest{
}`,
[]string{"\tPOPCNTQ\t", "support_popcnt"},
},
// multiplication merging tests
{
`
func mul1(n int) int {
return 15*n + 31*n
}`,
[]string{"\tIMULQ\t[$]46"}, // 46*n
},
{
`
func mul2(n int) int {
return 5*n + 7*(n+1) + 11*(n+2)
}`,
[]string{"\tIMULQ\t[$]23", "\tADDQ\t[$]29"}, // 23*n + 29
},
{
`
func mul3(a, n int) int {
return a*n + 19*n
}`,
[]string{"\tADDQ\t[$]19", "\tIMULQ"}, // (a+19)*n
},
// see issue 19595.
// We want to merge load+op in f58, but not in f59.
{
@ -928,6 +951,21 @@ var linux386Tests = []*asmTest{
`,
[]string{"\tMOVL\t\\(.*\\)\\(.*\\*1\\),"},
},
// multiplication merging tests
{
`
func mul1(n int) int {
return 9*n + 14*n
}`,
[]string{"\tIMULL\t[$]23"}, // 23*n
},
{
`
func mul2(a, n int) int {
return 19*a + a*n
}`,
[]string{"\tADDL\t[$]19", "\tIMULL"}, // (n+19)*a
},
}
var linuxS390XTests = []*asmTest{

View file

@ -322,6 +322,12 @@
(Mul32 (Const32 <t> [c]) (Add32 <t> (Const32 <t> [d]) x)) ->
(Add32 (Const32 <t> [int64(int32(c*d))]) (Mul32 <t> (Const32 <t> [c]) x))
// Rewrite x*y + x*z to x*(y+z)
(Add64 <t> (Mul64 x y) (Mul64 x z)) -> (Mul64 x (Add64 <t> y z))
(Add32 <t> (Mul32 x y) (Mul32 x z)) -> (Mul32 x (Add32 <t> y z))
(Add16 <t> (Mul16 x y) (Mul16 x z)) -> (Mul16 x (Add16 <t> y z))
(Add8 <t> (Mul8 x y) (Mul8 x z)) -> (Mul8 x (Add8 <t> y z))
// rewrite shifts of 8/16/32 bit consts into 64 bit consts to reduce
// the number of the other rewrite rules for const shifts
(Lsh64x32 <t> x (Const32 [c])) -> (Lsh64x64 x (Const64 <t> [int64(uint32(c))]))

File diff suppressed because it is too large Load diff

81
test/mergemul.go Normal file
View file

@ -0,0 +1,81 @@
// runoutput
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import "fmt"
// Check that expressions like (c*n + d*(n+k)) get correctly merged by
// the compiler into (c+d)*n + d*k (with c+d and d*k computed at
// compile time).
//
// The merging is performed by a combination of the multiplication
// merge rules
// (c*n + d*n) -> (c+d)*n
// and the distributive multiplication rules
// c * (d+x) -> c*d + c*x
// Generate a MergeTest that looks like this:
//
// a8, b8 = m1*n8 + m2*(n8+k), (m1+m2)*n8 + m2*k
// if a8 != b8 {
// // print error msg and panic
// }
func makeMergeTest(m1, m2, k int, size string) string {
model := " a" + size + ", b" + size
model += fmt.Sprintf(" = %%d*n%s + %%d*(n%s+%%d), (%%d+%%d)*n%s + (%%d*%%d)", size, size, size)
test := fmt.Sprintf(model, m1, m2, k, m1, m2, m2, k)
test += fmt.Sprintf(`
if a%s != b%s {
fmt.Printf("MergeTest(%d, %d, %d, %s) failed\n")
fmt.Printf("%%d != %%d\n", a%s, b%s)
panic("FAIL")
}
`, size, size, m1, m2, k, size, size, size)
return test + "\n"
}
func makeAllSizes(m1, m2, k int) string {
var tests string
tests += makeMergeTest(m1, m2, k, "8")
tests += makeMergeTest(m1, m2, k, "16")
tests += makeMergeTest(m1, m2, k, "32")
tests += makeMergeTest(m1, m2, k, "64")
tests += "\n"
return tests
}
func main() {
fmt.Println(`package main
import "fmt"
var n8 int8 = 42
var n16 int16 = 42
var n32 int32 = 42
var n64 int64 = 42
func main() {
var a8, b8 int8
var a16, b16 int16
var a32, b32 int32
var a64, b64 int64
`)
fmt.Println(makeAllSizes(03, 05, 0)) // 3*n + 5*n
fmt.Println(makeAllSizes(17, 33, 0))
fmt.Println(makeAllSizes(80, 45, 0))
fmt.Println(makeAllSizes(32, 64, 0))
fmt.Println(makeAllSizes(7, 11, +1)) // 7*n + 11*(n+1)
fmt.Println(makeAllSizes(9, 13, +2))
fmt.Println(makeAllSizes(11, 16, -1))
fmt.Println(makeAllSizes(17, 9, -2))
fmt.Println("}")
}