diff --git a/src/cmd/compile/internal/gc/subr.go b/src/cmd/compile/internal/gc/subr.go index 365cfad81f..566403bcde 100644 --- a/src/cmd/compile/internal/gc/subr.go +++ b/src/cmd/compile/internal/gc/subr.go @@ -552,19 +552,6 @@ func methtype(t *types.Type) *types.Type { return nil } -func cplxsubtype(et types.EType) types.EType { - switch et { - case TCOMPLEX64: - return TFLOAT32 - - case TCOMPLEX128: - return TFLOAT64 - } - - Fatalf("cplxsubtype: %v\n", et) - return 0 -} - // eqtype reports whether t1 and t2 are identical, following the spec rules. // // Any cyclic type must go through a named type, and if one is diff --git a/src/cmd/compile/internal/gc/typecheck.go b/src/cmd/compile/internal/gc/typecheck.go index b3dfe9dc8c..f21cc8f826 100644 --- a/src/cmd/compile/internal/gc/typecheck.go +++ b/src/cmd/compile/internal/gc/typecheck.go @@ -1365,7 +1365,7 @@ OpSwitch: ok = okforcap[t.Etype] } if !ok { - yyerror("invalid argument %L for %v", n.Left, n.Op) + yyerror("invalid argument %L for %v", l, n.Op) n.Type = nil return n } @@ -1401,7 +1401,6 @@ OpSwitch: } n.Left = typecheck(n.Left, Erv) - n.Left = defaultlit(n.Left, nil) l := n.Left t := l.Type if t == nil { @@ -1409,22 +1408,57 @@ OpSwitch: return n } - if !t.IsComplex() { - yyerror("invalid argument %L for %v", n.Left, n.Op) + if t.Etype != TIDEAL && !t.IsComplex() { + yyerror("invalid argument %L for %v", l, n.Op) n.Type = nil return n } - if Isconst(l, CTCPLX) { - r := n - if n.Op == OREAL { - n = nodfltconst(&l.Val().U.(*Mpcplx).Real) - } else { - n = nodfltconst(&l.Val().U.(*Mpcplx).Imag) + + // if the argument is a constant, the result is a constant + // (any untyped numeric constant can be represented as a + // complex number) + if l.Op == OLITERAL { + var re, im *Mpflt + switch consttype(l) { + case CTINT, CTRUNE: + re = newMpflt() + re.SetInt(l.Val().U.(*Mpint)) + // im = 0 + case CTFLT: + re = l.Val().U.(*Mpflt) + // im = 0 + case CTCPLX: + re = &l.Val().U.(*Mpcplx).Real + im = &l.Val().U.(*Mpcplx).Imag + default: + yyerror("invalid argument %L for %v", l, n.Op) + n.Type = nil + return n } - n.Orig = r + if n.Op == OIMAG { + if im == nil { + im = newMpflt() + } + re = im + } + orig := n + n = nodfltconst(re) + n.Orig = orig } - n.Type = types.Types[cplxsubtype(t.Etype)] + // determine result type + et := t.Etype + switch et { + case TIDEAL: + // result is ideal + case TCOMPLEX64: + et = TFLOAT32 + case TCOMPLEX128: + et = TFLOAT64 + default: + Fatalf("unexpected Etype: %v\n", et) + } + n.Type = types.Types[et] break OpSwitch case OCOMPLEX: diff --git a/test/fixedbugs/issue11945.go b/test/fixedbugs/issue11945.go new file mode 100644 index 0000000000..510b6555c6 --- /dev/null +++ b/test/fixedbugs/issue11945.go @@ -0,0 +1,71 @@ +// run + +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "fmt" + +// issue 17446 +const ( + _ = real(0) // from bug report + _ = imag(0) // from bug report + + // if the arguments are untyped, the results must be untyped + // (and compatible with types that can represent the values) + _ int = real(1) + _ int = real('a') + _ int = real(2.0) + _ int = real(3i) + + _ float32 = real(1) + _ float32 = real('a') + _ float32 = real(2.1) + _ float32 = real(3.2i) + + _ float64 = real(1) + _ float64 = real('a') + _ float64 = real(2.1) + _ float64 = real(3.2i) + + _ int = imag(1) + _ int = imag('a') + _ int = imag(2.1 + 3i) + _ int = imag(3i) + + _ float32 = imag(1) + _ float32 = imag('a') + _ float32 = imag(2.1 + 3.1i) + _ float32 = imag(3i) + + _ float64 = imag(1) + _ float64 = imag('a') + _ float64 = imag(2.1 + 3.1i) + _ float64 = imag(3i) +) + +var tests = []struct { + code string + got, want interface{} +}{ + {"real(1)", real(1), 1.0}, + {"real('a')", real('a'), float64('a')}, + {"real(2.0)", real(2.0), 2.0}, + {"real(3.2i)", real(3.2i), 0.0}, + + {"imag(1)", imag(1), 0.0}, + {"imag('a')", imag('a'), 0.0}, + {"imag(2.1 + 3.1i)", imag(2.1 + 3.1i), 3.1}, + {"imag(3i)", imag(3i), 3.0}, +} + +func main() { + // verify compile-time evaluated constant expressions + for _, test := range tests { + if test.got != test.want { + panic(fmt.Sprintf("%s: %v (%T) != %v (%T)", test.code, test.got, test.got, test.want, test.want)) + } + } +}