From ff0c0dbca6a7a3a3d6528481829679be4c9d7e94 Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Tue, 27 Jul 2021 19:13:26 -0700 Subject: [PATCH] [dev.typeparams] cmd/compile/internal/types2: use type terms to represent unions This is just an internal representation change for now. Change-Id: I7e0126e9b17850ec020c2a60db13582761557bea Reviewed-on: https://go-review.googlesource.com/c/go/+/338092 Trust: Robert Griesemer Reviewed-by: Robert Findley --- src/cmd/compile/internal/types2/infer.go | 18 ++- src/cmd/compile/internal/types2/interface.go | 6 +- src/cmd/compile/internal/types2/operand.go | 6 +- src/cmd/compile/internal/types2/predicates.go | 16 +-- .../compile/internal/types2/sizeof_test.go | 3 +- src/cmd/compile/internal/types2/subst.go | 22 +++- .../types2/testdata/examples/constraints.go2 | 6 +- src/cmd/compile/internal/types2/type.go | 2 +- src/cmd/compile/internal/types2/typeset.go | 7 ++ src/cmd/compile/internal/types2/typestring.go | 6 +- src/cmd/compile/internal/types2/union.go | 117 +++++++++++------- 11 files changed, 130 insertions(+), 79 deletions(-) diff --git a/src/cmd/compile/internal/types2/infer.go b/src/cmd/compile/internal/types2/infer.go index 6e7a217709..00548b402e 100644 --- a/src/cmd/compile/internal/types2/infer.go +++ b/src/cmd/compile/internal/types2/infer.go @@ -308,7 +308,7 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) { } case *Union: - return w.isParameterizedList(t.types) + return w.isParameterizedTermList(t.terms) case *Signature: // t.tparams may not be nil if we are looking at a signature @@ -336,7 +336,7 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) { return w.isParameterized(t.elem) case *Named: - return w.isParameterizedList(t.targs) + return w.isParameterizedTypeList(t.targs) case *TypeParam: // t must be one of w.tparams @@ -349,7 +349,7 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) { return false } -func (w *tpWalker) isParameterizedList(list []Type) bool { +func (w *tpWalker) isParameterizedTypeList(list []Type) bool { for _, t := range list { if w.isParameterized(t) { return true @@ -358,6 +358,15 @@ func (w *tpWalker) isParameterizedList(list []Type) bool { return false } +func (w *tpWalker) isParameterizedTermList(list []*term) bool { + for _, t := range list { + if w.isParameterized(t.typ) { + return true + } + } + return false +} + // inferB returns the list of actual type arguments inferred from the type parameters' // bounds and an initial set of type arguments. If type inference is impossible because // unification fails, an error is reported if report is set to true, the resulting types @@ -466,7 +475,8 @@ func (check *Checker) structuralType(constraint Type) Type { if u, _ := types.(*Union); u != nil { if u.NumTerms() == 1 { // TODO(gri) do we need to respect tilde? - return u.types[0] + t, _ := u.Term(0) + return t } return nil } diff --git a/src/cmd/compile/internal/types2/interface.go b/src/cmd/compile/internal/types2/interface.go index cf8ec1a5e2..fc1f5ffe00 100644 --- a/src/cmd/compile/internal/types2/interface.go +++ b/src/cmd/compile/internal/types2/interface.go @@ -30,7 +30,7 @@ func (t *Interface) is(f func(Type, bool) bool) bool { // TODO(gri) should settle on top or nil to represent this case return false // we must have at least one type! (was bug) case *Union: - return t.is(func(typ Type, tilde bool) bool { return f(typ, tilde) }) + return t.is(func(t *term) bool { return f(t.typ, t.tilde) }) default: return f(t, false) } @@ -260,8 +260,8 @@ func (check *Checker) interfaceType(ityp *Interface, iface *syntax.InterfaceType sortMethods(ityp.methods) // Compute type set with a non-nil *Checker as soon as possible - // to report any errors. Subsequent uses of type sets should be - // using this computed type set and won't need to pass in a *Checker. + // to report any errors. Subsequent uses of type sets will use + // this computed type set and won't need to pass in a *Checker. check.later(func() { computeTypeSet(check, iface.Pos(), ityp) }) } diff --git a/src/cmd/compile/internal/types2/operand.go b/src/cmd/compile/internal/types2/operand.go index 83cc239d93..01c720d1f7 100644 --- a/src/cmd/compile/internal/types2/operand.go +++ b/src/cmd/compile/internal/types2/operand.go @@ -270,13 +270,13 @@ func (x *operand) assignableTo(check *Checker, T Type, reason *string) (bool, er // x is an untyped value representable by a value of type T. if isUntyped(Vu) { if t, ok := Tu.(*Union); ok { - return t.is(func(t Type, tilde bool) bool { + return t.is(func(t *term) bool { // TODO(gri) this could probably be more efficient - if tilde { + if t.tilde { // TODO(gri) We need to check assignability // for the underlying type of x. } - ok, _ := x.assignableTo(check, t, reason) + ok, _ := x.assignableTo(check, t.typ, reason) return ok }), _IncompatibleAssign } diff --git a/src/cmd/compile/internal/types2/predicates.go b/src/cmd/compile/internal/types2/predicates.go index e448ade9e5..cd9fa3f564 100644 --- a/src/cmd/compile/internal/types2/predicates.go +++ b/src/cmd/compile/internal/types2/predicates.go @@ -238,20 +238,8 @@ func identical(x, y Type, cmpTags bool, p *ifacePair) bool { // types - each type appears exactly once. Thus, two union types // must contain the same number of types to have chance of // being equal. - if y, ok := y.(*Union); ok && x.NumTerms() == y.NumTerms() { - // Every type in x.types must be in y.types. - // Quadratic algorithm, but probably good enough for now. - // TODO(gri) we need a fast quick type ID/hash for all types. - L: - for i, xt := range x.types { - for j, yt := range y.types { - if Identical(xt, yt) && x.tilde[i] == y.tilde[j] { - continue L // x is in y.types - } - } - return false // x is not in y.types - } - return true + if y, ok := y.(*Union); ok { + return identicalTerms(x.terms, y.terms) } case *Interface: diff --git a/src/cmd/compile/internal/types2/sizeof_test.go b/src/cmd/compile/internal/types2/sizeof_test.go index a62b7cb3e2..70cf3709e5 100644 --- a/src/cmd/compile/internal/types2/sizeof_test.go +++ b/src/cmd/compile/internal/types2/sizeof_test.go @@ -27,12 +27,13 @@ func TestSizeof(t *testing.T) { {Pointer{}, 8, 16}, {Tuple{}, 12, 24}, {Signature{}, 44, 88}, - {Union{}, 24, 48}, + {Union{}, 12, 24}, {Interface{}, 40, 80}, {Map{}, 16, 32}, {Chan{}, 12, 24}, {Named{}, 88, 168}, {TypeParam{}, 28, 48}, + {term{}, 12, 24}, {top{}, 0, 0}, // Objects diff --git a/src/cmd/compile/internal/types2/subst.go b/src/cmd/compile/internal/types2/subst.go index 87e3e3018e..fc71343431 100644 --- a/src/cmd/compile/internal/types2/subst.go +++ b/src/cmd/compile/internal/types2/subst.go @@ -145,12 +145,12 @@ func (subst *subster) typ(typ Type) Type { } case *Union: - types, copied := subst.typeList(t.types) + terms, copied := subst.termList(t.terms) if copied { // TODO(gri) Remove duplicates that may have crept in after substitution // (unlikely but possible). This matters for the Identical // predicate on unions. - return newUnion(types, t.tilde) + return &Union{terms} } case *Interface: @@ -386,3 +386,21 @@ func (subst *subster) typeList(in []Type) (out []Type, copied bool) { } return } + +func (subst *subster) termList(in []*term) (out []*term, copied bool) { + out = in + for i, t := range in { + if u := subst.typ(t.typ); u != t.typ { + if !copied { + // first function that got substituted => allocate new out slice + // and copy all functions + new := make([]*term, len(in)) + copy(new, out) + out = new + copied = true + } + out[i] = &term{t.tilde, u} + } + } + return +} diff --git a/src/cmd/compile/internal/types2/testdata/examples/constraints.go2 b/src/cmd/compile/internal/types2/testdata/examples/constraints.go2 index d9805fe694..28aa19bb12 100644 --- a/src/cmd/compile/internal/types2/testdata/examples/constraints.go2 +++ b/src/cmd/compile/internal/types2/testdata/examples/constraints.go2 @@ -31,9 +31,9 @@ type ( _ interface{int|~ /* ERROR duplicate term int */ int } _ interface{~int|~ /* ERROR duplicate term int */ int } - // For now we do not permit interfaces with ~ or in unions. - _ interface{~ /* ERROR cannot use interface */ interface{}} - _ interface{int|interface /* ERROR cannot use interface */ {}} + // For now we do not permit interfaces with methods in unions. + _ interface{~ /* ERROR invalid use of ~ */ interface{}} + _ interface{int|interface /* ERROR cannot use .* in union */ { m() }} ) type ( diff --git a/src/cmd/compile/internal/types2/type.go b/src/cmd/compile/internal/types2/type.go index b41b50393d..80054372bc 100644 --- a/src/cmd/compile/internal/types2/type.go +++ b/src/cmd/compile/internal/types2/type.go @@ -60,7 +60,7 @@ func optype(typ Type) Type { // If we have a union with a single entry, ignore // any tilde because under(~t) == under(t). if u, _ := a.(*Union); u != nil && u.NumTerms() == 1 { - a = u.types[0] + a, _ = u.Term(0) } if a != typ { // a != typ and a is a type parameter => under(a) != typ, so this is ok diff --git a/src/cmd/compile/internal/types2/typeset.go b/src/cmd/compile/internal/types2/typeset.go index cc28625070..5a334b2f53 100644 --- a/src/cmd/compile/internal/types2/typeset.go +++ b/src/cmd/compile/internal/types2/typeset.go @@ -42,6 +42,13 @@ func (s *TypeSet) IsComparable() bool { return s.comparable && tcomparable } +// TODO(gri) IsTypeSet is not a great name. Find a better one. + +// IsTypeSet reports whether the type set s is represented by a finite set of underlying types. +func (s *TypeSet) IsTypeSet() bool { + return !s.comparable && len(s.methods) == 0 +} + // NumMethods returns the number of methods available. func (s *TypeSet) NumMethods() int { return len(s.methods) } diff --git a/src/cmd/compile/internal/types2/typestring.go b/src/cmd/compile/internal/types2/typestring.go index 74d2f1dc51..1da3f7f8ed 100644 --- a/src/cmd/compile/internal/types2/typestring.go +++ b/src/cmd/compile/internal/types2/typestring.go @@ -162,14 +162,14 @@ func writeType(buf *bytes.Buffer, typ Type, qf Qualifier, visited []Type) { buf.WriteString("⊥") break } - for i, e := range t.types { + for i, t := range t.terms { if i > 0 { buf.WriteByte('|') } - if t.tilde[i] { + if t.tilde { buf.WriteByte('~') } - writeType(buf, e, qf, visited) + writeType(buf, t.typ, qf, visited) } case *Interface: diff --git a/src/cmd/compile/internal/types2/union.go b/src/cmd/compile/internal/types2/union.go index 5983a73ec6..1215ef9057 100644 --- a/src/cmd/compile/internal/types2/union.go +++ b/src/cmd/compile/internal/types2/union.go @@ -10,10 +10,8 @@ import "cmd/compile/internal/syntax" // API // A Union represents a union of terms. -// A term is a type with a ~ (tilde) flag. type Union struct { - types []Type // types are unique - tilde []bool // if tilde[i] is set, terms[i] is of the form ~T + terms []*term } // NewUnion returns a new Union type with the given terms (types[i], tilde[i]). @@ -21,9 +19,9 @@ type Union struct { // of no types. func NewUnion(types []Type, tilde []bool) *Union { return newUnion(types, tilde) } -func (u *Union) IsEmpty() bool { return len(u.types) == 0 } -func (u *Union) NumTerms() int { return len(u.types) } -func (u *Union) Term(i int) (Type, bool) { return u.types[i], u.tilde[i] } +func (u *Union) IsEmpty() bool { return len(u.terms) == 0 } +func (u *Union) NumTerms() int { return len(u.terms) } +func (u *Union) Term(i int) (Type, bool) { t := u.terms[i]; return t.typ, t.tilde } func (u *Union) Underlying() Type { return u } func (u *Union) String() string { return TypeString(u, nil) } @@ -39,18 +37,20 @@ func newUnion(types []Type, tilde []bool) *Union { return emptyUnion } t := new(Union) - t.types = types - t.tilde = tilde + t.terms = make([]*term, len(types)) + for i, typ := range types { + t.terms[i] = &term{tilde[i], typ} + } return t } -// is reports whether f returned true for all terms (type, tilde) of u. -func (u *Union) is(f func(Type, bool) bool) bool { +// is reports whether f returns true for all terms of u. +func (u *Union) is(f func(*term) bool) bool { if u.IsEmpty() { return false } - for i, t := range u.types { - if !f(t, u.tilde[i]) { + for _, t := range u.terms { + if !f(t) { return false } } @@ -62,8 +62,8 @@ func (u *Union) underIs(f func(Type) bool) bool { if u.IsEmpty() { return false } - for _, t := range u.types { - if !f(under(t)) { + for _, t := range u.terms { + if !f(under(t.typ)) { return false } } @@ -83,7 +83,7 @@ func parseUnion(check *Checker, tlist []syntax.Expr) Type { } // Ensure that each type is only present once in the type list. - // It's ok to do this check at the end because it's not a requirement + // It's ok to do this check later because it's not a requirement // for correctness of the code. // Note: This is a quadratic algorithm, but unions tend to be short. check.later(func() { @@ -96,7 +96,7 @@ func parseUnion(check *Checker, tlist []syntax.Expr) Type { x := tlist[i] pos := syntax.StartPos(x) // We may not know the position of x if it was a typechecker- - // introduced ~T type of a type list entry T. Use the position + // introduced ~T term for a type list entry T. Use the position // of T instead. // TODO(gri) remove this test once we don't support type lists anymore if !pos.IsKnown() { @@ -106,13 +106,24 @@ func parseUnion(check *Checker, tlist []syntax.Expr) Type { } u := under(t) - if tilde[i] && !Identical(u, t) { - check.errorf(x, "invalid use of ~ (underlying type of %s is %s)", t, u) - continue // don't report another error for t + f, _ := u.(*Interface) + if tilde[i] { + if f != nil { + check.errorf(x, "invalid use of ~ (%s is an interface)", t) + continue // don't report another error for t + } + + if !Identical(u, t) { + check.errorf(x, "invalid use of ~ (underlying type of %s is %s)", t, u) + continue // don't report another error for t + } } - if _, ok := u.(*Interface); ok { - // A single type with a ~ is a single-term union. - check.errorf(pos, "cannot use interface %s with ~ or inside a union (implementation restriction)", t) + + // Stand-alone embedded interfaces are ok and are handled by the single-type case + // in the beginning. Embedded interfaces with tilde are excluded above. If we reach + // here, we must have at least two terms in the union. + if f != nil && !f.typeSet().IsTypeSet() { + check.errorf(pos, "cannot use %s in union (interface contains methods)", t) continue // don't report another error for t } @@ -164,25 +175,7 @@ func intersect(x, y Type) (r Type) { yu, _ := y.(*Union) switch { case xu != nil && yu != nil: - // Quadratic algorithm, but good enough for now. - // TODO(gri) fix asymptotic performance - var types []Type - var tilde []bool - for j, y := range yu.types { - yt := yu.tilde[j] - if r, rt := xu.intersect(y, yt); r != nil { - // Terms x[i] and y[j] match: Select the one that - // is not a ~t because that is the intersection - // type. If both are ~t, they are identical: - // T ∩ T = T - // T ∩ ~t = T - // ~t ∩ T = T - // ~t ∩ ~t = ~t - types = append(types, r) - tilde = append(tilde, rt) - } - } - return newUnion(types, tilde) + return &Union{intersectTerms(xu.terms, yu.terms)} case xu != nil: if r, _ := xu.intersect(y, false); r != nil { @@ -216,14 +209,16 @@ func includes(list []Type, typ Type) bool { // intersect computes the intersection of the union u and term (y, yt) // and returns the intersection term, if any. Otherwise the result is // (nil, false). +// TODO(gri) this needs to cleaned up/removed once we switch to lazy +// union type set computation. func (u *Union) intersect(y Type, yt bool) (Type, bool) { under_y := under(y) - for i, x := range u.types { - xt := u.tilde[i] + for _, x := range u.terms { + xt := x.tilde // determine which types xx, yy to compare - xx := x + xx := x.typ if yt { - xx = under(x) + xx = under(xx) } yy := y if xt { @@ -239,3 +234,35 @@ func (u *Union) intersect(y Type, yt bool) (Type, bool) { } return nil, false } + +func identicalTerms(list1, list2 []*term) bool { + if len(list1) != len(list2) { + return false + } + // Every term in list1 must be in list2. + // Quadratic algorithm, but probably good enough for now. + // TODO(gri) we need a fast quick type ID/hash for all types. +L: + for _, x := range list1 { + for _, y := range list2 { + if x.equal(y) { + continue L // x is in list2 + } + } + return false + } + return true +} + +func intersectTerms(list1, list2 []*term) (list []*term) { + // Quadratic algorithm, but good enough for now. + // TODO(gri) fix asymptotic performance + for _, x := range list1 { + for _, y := range list2 { + if r := x.intersect(y); r != nil { + list = append(list, r) + } + } + } + return +}