go/types, types2: unify core types for unbound type parameters

NOTE: Should this change cause problems, the new functionality
can be disabled by setting the flag enableCoreTypeUnification
in unify.go to false.

In the code

func f1[M1 map[K1]int, K1 comparable](m1 M1) {}

func f2[M2 map[K2]int, K2 comparable](m2 M2) {
	f1(m2)
}

type inference attempts to unify the types of m1 and m2. This leads
to the unification attempt of M1 and M2. The result is that the type
argument for M1 is inferred to be M2. Since there is no furter function
argument to use, constraint type inference attempts to infer the type
for K1 which is still missing. Constraint type inference (inferB in
the trace below) compares the inferred type for M1 (i.e., M2) against
map[K1]int. M2 is bound to f2, not f1; with the existing algorithm
that means M2 is simply a named type without further information.
Unification fails and with that type inference, and the type checker
reports an error.

-- inferA [M1₁, K1₂] ➞ []
M1₁ ≡ M2₃
.  M1₁ ➞ M2₃
-- inferB [M1₁, K1₂] ➞ [M2₃, <nil>]
M1₁ ➞ M2₃
M1₁ ≡ map[K1₂]int
.  M2₃ ≡ map[K1₂]int
.  M2₃ ≢ map[K1₂]int
M1₁ ≢ map[K1₂]int
=> inferB [M1₁, K1₂] ➞ []
=> inferA [M1₁, K1₂] ➞ []

With this change, when attempting to unify M2 with map[K1]int,
rather than failing, the unifier now considers the core type of
M2 which is map[K2]int. This leads to the unification of K1 and
K2; so type inference successfully infers M2 for M1 and K2 for K1.

-- inferA [M1₁, K1₂] ➞ []
M1₁ ≡ M2₃
.  M1₁ ➞ M2₃
-- inferB [M1₁, K1₂] ➞ [M2₃, <nil>]
M1₁ ➞ M2₃
M1₁ ≡ map[K1₂]int
.  M2₃ ≡ map[K1₂]int
.  .  core M2₃ ≡ map[K1₂]int
.  .  map[K2₄]int ≡ map[K1₂]int
.  .  .  K2₄ ≡ K1₂
.  .  .  .  K1₂ ➞ K2₄
.  .  .  int ≡ int
=> inferB [M1₁, K1₂] ➞ [M2₃, K2₄]
=> inferA [M1₁, K1₂] ➞ [M2₃, K2₄]

The fix for this issue was provided by Rob Findley in CL 380375;
this change is a copy of that fix with some additional changes:

- Constraint type inference doesn't simply use a type parameter's
  core type. Instead, if the type parameter type set consists of
  a single, possibly named type, it uses that type. Factor out the
  existing code into a new function adjCoreType. This change is not
  strictly needed but makes it easier to think about the code.

- Tracing code is added for debugging type inference. All tracing
  code is guarded with the flag traceEnabled which is set to false
  by default.

- The change to the unification algorithm is guarded with the flag
  enableCoreTypeUnification.

- The sprintf function has a new type switch case for lists of
  type parameters. This is used for tracing output (and was also
  missing for a panic that was printing type parameter lists).

Fixes #50755.

Change-Id: Ie50c8f4540fcd446a71b07e2b451a95339b530ce
Reviewed-on: https://go-review.googlesource.com/c/go/+/385354
Trust: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Robert Griesemer 2022-02-11 16:47:58 -08:00
parent f14ad78e84
commit f03ab0e014
8 changed files with 280 additions and 22 deletions

View file

@ -124,6 +124,17 @@ func sprintf(qf Qualifier, debug bool, format string, args ...interface{}) strin
}
buf.WriteByte(']')
arg = buf.String()
case []*TypeParam:
var buf bytes.Buffer
buf.WriteByte('[')
for i, x := range a {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(typeString(x, qf, debug)) // use typeString so we get subscripts when debugging
}
buf.WriteByte(']')
arg = buf.String()
}
args[i] = arg
}

View file

@ -41,6 +41,13 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
}()
}
if traceInference {
check.dump("-- inferA %s ➞ %s", tparams, targs)
defer func() {
check.dump("=> inferA %s ➞ %s", tparams, result)
}()
}
// There must be at least one type parameter, and no more type arguments than type parameters.
n := len(tparams)
assert(n > 0 && len(targs) <= n)
@ -403,6 +410,13 @@ func (w *tpWalker) isParameterizedTypeList(list []Type) bool {
func (check *Checker) inferB(pos syntax.Pos, tparams []*TypeParam, targs []Type) (types []Type, index int) {
assert(len(tparams) >= len(targs) && len(targs) > 0)
if traceInference {
check.dump("-- inferB %s ➞ %s", tparams, targs)
defer func() {
check.dump("=> inferB %s ➞ %s", tparams, types)
}()
}
// Setup bidirectional unification between constraints
// and the corresponding type arguments (which may be nil!).
u := newUnifier(false)
@ -418,17 +432,11 @@ func (check *Checker) inferB(pos syntax.Pos, tparams []*TypeParam, targs []Type)
// If a constraint has a core type, unify the corresponding type parameter with it.
for _, tpar := range tparams {
sbound := coreType(tpar)
if sbound != nil {
// If the core type is the underlying type of a single
// defined type in the constraint, use that defined type instead.
if named, _ := tpar.singleType().(*Named); named != nil {
sbound = named
}
if !u.unify(tpar, sbound) {
if ctype := adjCoreType(tpar); ctype != nil {
if !u.unify(tpar, ctype) {
// TODO(gri) improve error message by providing the type arguments
// which we know already
check.errorf(pos, "%s does not match %s", tpar, sbound)
check.errorf(pos, "%s does not match %s", tpar, ctype)
return nil, 0
}
}
@ -525,6 +533,19 @@ func (check *Checker) inferB(pos syntax.Pos, tparams []*TypeParam, targs []Type)
return
}
func adjCoreType(tpar *TypeParam) Type {
// If the type parameter embeds a single, possibly named
// type, use that one instead of the core type (which is
// always the underlying type of that single type).
if single := tpar.singleType(); single != nil {
if debug {
assert(under(single) == coreType(tpar))
}
return single
}
return coreType(tpar)
}
type cycleFinder struct {
tparams []*TypeParam
types []Type

View file

@ -0,0 +1,27 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
func f1[M1 map[K1]int, K1 comparable](m1 M1) {}
func f2[M2 map[K2]int, K2 comparable](m2 M2) {
f1(m2)
}
// test case from issue
func Copy[MC ~map[KC]VC, KC comparable, VC any](dst, src MC) {
for k, v := range src {
dst[k] = v
}
}
func Merge[MM ~map[KM]VM, KM comparable, VM any](ms ...MM) MM {
result := MM{}
for _, m := range ms {
Copy(result, m)
}
return result
}

View file

@ -9,6 +9,7 @@ package types2
import (
"bytes"
"fmt"
"strings"
)
// The unifier maintains two separate sets of type parameters x and y
@ -41,6 +42,19 @@ const (
// Whether to panic when unificationDepthLimit is reached. Turn on when
// investigating infinite recursion.
panicAtUnificationDepthLimit = false
// If enableCoreTypeUnification is set, unification will consider
// the core types, if any, of non-local (unbound) type parameters.
enableCoreTypeUnification = true
// If traceInference is set, unification will print a trace of its operation.
// Interpretation of trace:
// x ≡ y attempt to unify types x and y
// p ➞ y type parameter p is set to type y (p is inferred to be y)
// p ⇄ q type parameters p and q match (p is inferred to be q and vice versa)
// x ≢ y types x and y cannot be unified
// [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types
traceInference = false
)
// A unifier maintains the current type parameters for x and y
@ -58,6 +72,7 @@ type unifier struct {
// exactly. If exact is not set, a named type's underlying type
// is considered if unification would fail otherwise, and the
// direction of channels is ignored.
// TODO(gri) exact is not set anymore by a caller. Consider removing it.
func newUnifier(exact bool) *unifier {
u := &unifier{exact: exact}
u.x.unifier = u
@ -70,6 +85,10 @@ func (u *unifier) unify(x, y Type) bool {
return u.nify(x, y, nil)
}
func (u *unifier) tracef(format string, args ...interface{}) {
fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, true, format, args...))
}
// A tparamsList describes a list of type parameters and the types inferred for them.
type tparamsList struct {
unifier *unifier
@ -121,6 +140,9 @@ func (d *tparamsList) init(tparams []*TypeParam) {
// If both type parameters already have a type associated with them and they are
// not joined, join fails and returns false.
func (u *unifier) join(i, j int) bool {
if traceInference {
u.tracef("%s ⇄ %s", u.x.tparams[i], u.y.tparams[j])
}
ti := u.x.indices[i]
tj := u.y.indices[j]
switch {
@ -210,6 +232,9 @@ func (d *tparamsList) at(i int) Type {
func (d *tparamsList) set(i int, typ Type) {
assert(typ != nil)
u := d.unifier
if traceInference {
u.tracef("%s ➞ %s", d.tparams[i], typ)
}
switch ti := d.indices[i]; {
case ti < 0:
u.types[-ti-1] = typ
@ -247,9 +272,16 @@ func (u *unifier) nifyEq(x, y Type, p *ifacePair) bool {
// adapted version of Checker.identical. For changes to that
// code the corresponding changes should be made here.
// Must not be called directly from outside the unifier.
func (u *unifier) nify(x, y Type, p *ifacePair) bool {
func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) {
if traceInference {
u.tracef("%s ≡ %s", x, y)
}
// Stop gap for cases where unification fails.
if u.depth >= unificationDepthLimit {
if traceInference {
u.tracef("depth %d >= %d", u.depth, unificationDepthLimit)
}
if panicAtUnificationDepthLimit {
panic("unification reached recursion depth limit")
}
@ -258,6 +290,9 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
u.depth++
defer func() {
u.depth--
if traceInference && !result {
u.tracef("%s ≢ %s", x, y)
}
}()
if !u.exact {
@ -267,8 +302,14 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
// (We use !hasName to exclude any type with a name, including
// basic types and type parameters; the rest are unamed types.)
if nx, _ := x.(*Named); nx != nil && !hasName(y) {
if traceInference {
u.tracef("under %s ≡ %s", nx, y)
}
return u.nify(nx.under(), y, p)
} else if ny, _ := y.(*Named); ny != nil && !hasName(x) {
if traceInference {
u.tracef("%s ≡ under %s", x, ny)
}
return u.nify(x, ny.under(), p)
}
}
@ -302,6 +343,35 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
return true
}
// If we get here and x or y is a type parameter, they are type parameters
// from outside our declaration list. Try to unify their core types, if any
// (see issue #50755 for a test case).
if enableCoreTypeUnification && !u.exact {
if isTypeParam(x) && !hasName(y) {
// When considering the type parameter for unification
// we look at the adjusted core type (adjCoreType).
// If the adjusted core type is a named type N; the
// corresponding core type is under(N). Since !u.exact
// and y doesn't have a name, unification will end up
// comparing under(N) to y, so we can just use the core
// type instead. Optimization.
if cx := coreType(x); cx != nil {
if traceInference {
u.tracef("core %s ≡ %s", x, y)
}
return u.nify(cx, y, p)
}
} else if isTypeParam(y) && !hasName(x) {
// see comment above
if cy := coreType(y); cy != nil {
if traceInference {
u.tracef("%s ≡ core %s", x, y)
}
return u.nify(x, cy, p)
}
}
}
// For type unification, do not shortcut (x == y) for identical
// types. Instead keep comparing them element-wise to unify the
// matching (and equal type parameter types). A simple test case
@ -490,7 +560,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
// avoid a crash in case of nil type
default:
panic(fmt.Sprintf("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
panic(sprintf(nil, true, "u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
}
return false

View file

@ -110,6 +110,17 @@ func sprintf(fset *token.FileSet, qf Qualifier, debug bool, format string, args
}
buf.WriteByte(']')
arg = buf.String()
case []*TypeParam:
var buf bytes.Buffer
buf.WriteByte('[')
for i, x := range a {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(typeString(x, qf, debug)) // use typeString so we get subscripts when debugging
}
buf.WriteByte(']')
arg = buf.String()
}
args[i] = arg
}

View file

@ -40,6 +40,13 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
}()
}
if traceInference {
check.dump("-- inferA %s ➞ %s", tparams, targs)
defer func() {
check.dump("=> inferA %s ➞ %s", tparams, result)
}()
}
// There must be at least one type parameter, and no more type arguments than type parameters.
n := len(tparams)
assert(n > 0 && len(targs) <= n)
@ -402,6 +409,13 @@ func (w *tpWalker) isParameterizedTypeList(list []Type) bool {
func (check *Checker) inferB(posn positioner, tparams []*TypeParam, targs []Type) (types []Type, index int) {
assert(len(tparams) >= len(targs) && len(targs) > 0)
if traceInference {
check.dump("-- inferB %s ➞ %s", tparams, targs)
defer func() {
check.dump("=> inferB %s ➞ %s", tparams, types)
}()
}
// Setup bidirectional unification between constraints
// and the corresponding type arguments (which may be nil!).
u := newUnifier(false)
@ -417,17 +431,11 @@ func (check *Checker) inferB(posn positioner, tparams []*TypeParam, targs []Type
// If a constraint has a core type, unify the corresponding type parameter with it.
for _, tpar := range tparams {
sbound := coreType(tpar)
if sbound != nil {
// If the core type is the underlying type of a single
// defined type in the constraint, use that defined type instead.
if named, _ := tpar.singleType().(*Named); named != nil {
sbound = named
}
if !u.unify(tpar, sbound) {
if ctype := adjCoreType(tpar); ctype != nil {
if !u.unify(tpar, ctype) {
// TODO(gri) improve error message by providing the type arguments
// which we know already
check.errorf(posn, _InvalidTypeArg, "%s does not match %s", tpar, sbound)
check.errorf(posn, _InvalidTypeArg, "%s does not match %s", tpar, ctype)
return nil, 0
}
}
@ -524,6 +532,19 @@ func (check *Checker) inferB(posn positioner, tparams []*TypeParam, targs []Type
return
}
func adjCoreType(tpar *TypeParam) Type {
// If the type parameter embeds a single, possibly named
// type, use that one instead of the core type (which is
// always the underlying type of that single type).
if single := tpar.singleType(); single != nil {
if debug {
assert(under(single) == coreType(tpar))
}
return single
}
return coreType(tpar)
}
type cycleFinder struct {
tparams []*TypeParam
types []Type

View file

@ -0,0 +1,27 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
func f1[M1 map[K1]int, K1 comparable](m1 M1) {}
func f2[M2 map[K2]int, K2 comparable](m2 M2) {
f1(m2)
}
// test case from issue
func Copy[MC ~map[KC]VC, KC comparable, VC any](dst, src MC) {
for k, v := range src {
dst[k] = v
}
}
func Merge[MM ~map[KM]VM, KM comparable, VM any](ms ...MM) MM {
result := MM{}
for _, m := range ms {
Copy(result, m)
}
return result
}

View file

@ -9,6 +9,7 @@ package types
import (
"bytes"
"fmt"
"strings"
)
// The unifier maintains two separate sets of type parameters x and y
@ -41,6 +42,19 @@ const (
// Whether to panic when unificationDepthLimit is reached. Turn on when
// investigating infinite recursion.
panicAtUnificationDepthLimit = false
// If enableCoreTypeUnification is set, unification will consider
// the core types, if any, of non-local (unbound) type parameters.
enableCoreTypeUnification = true
// If traceInference is set, unification will print a trace of its operation.
// Interpretation of trace:
// x ≡ y attempt to unify types x and y
// p ➞ y type parameter p is set to type y (p is inferred to be y)
// p ⇄ q type parameters p and q match (p is inferred to be q and vice versa)
// x ≢ y types x and y cannot be unified
// [p, q, ...] ➞ [x, y, ...] mapping from type parameters to types
traceInference = false
)
// A unifier maintains the current type parameters for x and y
@ -58,6 +72,7 @@ type unifier struct {
// exactly. If exact is not set, a named type's underlying type
// is considered if unification would fail otherwise, and the
// direction of channels is ignored.
// TODO(gri) exact is not set anymore by a caller. Consider removing it.
func newUnifier(exact bool) *unifier {
u := &unifier{exact: exact}
u.x.unifier = u
@ -70,6 +85,10 @@ func (u *unifier) unify(x, y Type) bool {
return u.nify(x, y, nil)
}
func (u *unifier) tracef(format string, args ...interface{}) {
fmt.Println(strings.Repeat(". ", u.depth) + sprintf(nil, nil, true, format, args...))
}
// A tparamsList describes a list of type parameters and the types inferred for them.
type tparamsList struct {
unifier *unifier
@ -121,6 +140,9 @@ func (d *tparamsList) init(tparams []*TypeParam) {
// If both type parameters already have a type associated with them and they are
// not joined, join fails and returns false.
func (u *unifier) join(i, j int) bool {
if traceInference {
u.tracef("%s ⇄ %s", u.x.tparams[i], u.y.tparams[j])
}
ti := u.x.indices[i]
tj := u.y.indices[j]
switch {
@ -210,6 +232,9 @@ func (d *tparamsList) at(i int) Type {
func (d *tparamsList) set(i int, typ Type) {
assert(typ != nil)
u := d.unifier
if traceInference {
u.tracef("%s ➞ %s", d.tparams[i], typ)
}
switch ti := d.indices[i]; {
case ti < 0:
u.types[-ti-1] = typ
@ -247,9 +272,16 @@ func (u *unifier) nifyEq(x, y Type, p *ifacePair) bool {
// adapted version of Checker.identical. For changes to that
// code the corresponding changes should be made here.
// Must not be called directly from outside the unifier.
func (u *unifier) nify(x, y Type, p *ifacePair) bool {
func (u *unifier) nify(x, y Type, p *ifacePair) (result bool) {
if traceInference {
u.tracef("%s ≡ %s", x, y)
}
// Stop gap for cases where unification fails.
if u.depth >= unificationDepthLimit {
if traceInference {
u.tracef("depth %d >= %d", u.depth, unificationDepthLimit)
}
if panicAtUnificationDepthLimit {
panic("unification reached recursion depth limit")
}
@ -258,6 +290,9 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
u.depth++
defer func() {
u.depth--
if traceInference && !result {
u.tracef("%s ≢ %s", x, y)
}
}()
if !u.exact {
@ -267,8 +302,14 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
// (We use !hasName to exclude any type with a name, including
// basic types and type parameters; the rest are unamed types.)
if nx, _ := x.(*Named); nx != nil && !hasName(y) {
if traceInference {
u.tracef("under %s ≡ %s", nx, y)
}
return u.nify(nx.under(), y, p)
} else if ny, _ := y.(*Named); ny != nil && !hasName(x) {
if traceInference {
u.tracef("%s ≡ under %s", x, ny)
}
return u.nify(x, ny.under(), p)
}
}
@ -302,6 +343,35 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
return true
}
// If we get here and x or y is a type parameter, they are type parameters
// from outside our declaration list. Try to unify their core types, if any
// (see issue #50755 for a test case).
if enableCoreTypeUnification && !u.exact {
if isTypeParam(x) && !hasName(y) {
// When considering the type parameter for unification
// we look at the adjusted core type (adjCoreType).
// If the adjusted core type is a named type N; the
// corresponding core type is under(N). Since !u.exact
// and y doesn't have a name, unification will end up
// comparing under(N) to y, so we can just use the core
// type instead. Optimization.
if cx := coreType(x); cx != nil {
if traceInference {
u.tracef("core %s ≡ %s", x, y)
}
return u.nify(cx, y, p)
}
} else if isTypeParam(y) && !hasName(x) {
// see comment above
if cy := coreType(y); cy != nil {
if traceInference {
u.tracef("%s ≡ core %s", x, y)
}
return u.nify(x, cy, p)
}
}
}
// For type unification, do not shortcut (x == y) for identical
// types. Instead keep comparing them element-wise to unify the
// matching (and equal type parameter types). A simple test case
@ -490,7 +560,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
// avoid a crash in case of nil type
default:
panic(fmt.Sprintf("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
panic(sprintf(nil, nil, true, "u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams))
}
return false