go/types: avoid infinite recursion in unification

This is an almost clean port of CL 352832 from types2 to go/types:
The nest files and unify.go where copied verbatim; unify.go was
adjusted with correct package name, a slightly different comment
was restored to what it was. The test files got adjustments for
error position. infer.go got a missing _Todo error code.

For #48619.
For #48656.

Change-Id: Ia1a2d09e8bb37a85032b4b7e7c7a0b08e8c793a5
Reviewed-on: https://go-review.googlesource.com/c/go/+/353029
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
This commit is contained in:
Robert Griesemer 2021-09-28 20:42:45 -07:00
parent e180e2c27c
commit 3f224bbf9a
4 changed files with 67 additions and 9 deletions

View file

@ -128,7 +128,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
if inferred != tpar { if inferred != tpar {
check.errorf(arg, _Todo, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar) check.errorf(arg, _Todo, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar)
} else { } else {
check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar) check.errorf(arg, _Todo, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar)
} }
} }

View file

@ -0,0 +1,22 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
func f[P any](a, _ P) {
var x int
f(a, x /* ERROR type int of x does not match P */)
f(x, a /* ERROR type P of a does not match inferred type int for P */)
}
func g[P any](a, b P) {
g(a, b)
g(&a, &b)
g([]P{}, []P{})
}
func h[P any](a, b P) {
h(&a, &b)
h([]P{a}, []P{b})
}

View file

@ -0,0 +1,13 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
// TODO(gri) Still need better error positions and message here.
// But this doesn't crash anymore.
func f[P /* ERROR does not match \*Q */ interface{*Q}, Q any](p P, q Q) {
_ = f[P]
_ = f /* ERROR cannot infer P */ [*P]
}

View file

@ -9,7 +9,6 @@ package types
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"sort"
) )
// The unifier maintains two separate sets of type parameters x and y // The unifier maintains two separate sets of type parameters x and y
@ -64,6 +63,10 @@ func (u *unifier) unify(x, y Type) bool {
type tparamsList struct { type tparamsList struct {
unifier *unifier unifier *unifier
tparams []*TypeParam tparams []*TypeParam
// For each tparams element, there is a corresponding mask bit in masks.
// If set, the corresponding type parameter is masked and doesn't appear
// as a type parameter with tparamsList.index.
masks []bool
// For each tparams element, there is a corresponding type slot index in indices. // For each tparams element, there is a corresponding type slot index in indices.
// index < 0: unifier.types[-index-1] == nil // index < 0: unifier.types[-index-1] == nil
// index == 0: no type slot allocated yet // index == 0: no type slot allocated yet
@ -104,9 +107,14 @@ func (d *tparamsList) init(tparams []*TypeParam) {
} }
} }
d.tparams = tparams d.tparams = tparams
d.masks = make([]bool, len(tparams))
d.indices = make([]int, len(tparams)) d.indices = make([]int, len(tparams))
} }
// mask and unmask permit the masking/unmasking of the i'th type parameter of d.
func (d *tparamsList) mask(i int) { d.masks[i] = true }
func (d *tparamsList) unmask(i int) { d.masks[i] = false }
// join unifies the i'th type parameter of x with the j'th type parameter of y. // join unifies the i'th type parameter of x with the j'th type parameter of y.
// If both type parameters already have a type associated with them and they are // If both type parameters already have a type associated with them and they are
// not joined, join fails and returns false. // not joined, join fails and returns false.
@ -138,19 +146,25 @@ func (u *unifier) join(i, j int) bool {
case ti > 0: case ti > 0:
// Only the type parameter for x has an inferred type. Use x slot for y. // Only the type parameter for x has an inferred type. Use x slot for y.
u.y.setIndex(j, ti) u.y.setIndex(j, ti)
// This case is handled like the default case.
// case tj > 0:
// // Only the type parameter for y has an inferred type. Use y slot for x.
// u.x.setIndex(i, tj)
default: default:
// Either the type parameter for y has an inferred type, or neither type // Neither type parameter has an inferred type. Use y slot for x
// parameter has an inferred type. In either case, use y slot for x. // (or x slot for y, it doesn't matter).
u.x.setIndex(i, tj) u.x.setIndex(i, tj)
} }
return true return true
} }
// If typ is a type parameter of d, index returns the type parameter index. // If typ is an unmasked type parameter of d, index returns the type parameter index.
// Otherwise, the result is < 0. // Otherwise, the result is < 0.
func (d *tparamsList) index(typ Type) int { func (d *tparamsList) index(typ Type) int {
if tpar, ok := typ.(*TypeParam); ok { if tpar, ok := typ.(*TypeParam); ok {
return tparamIndex(d.tparams, tpar) if i := tparamIndex(d.tparams, tpar); i >= 0 && !d.masks[i] {
return i
}
} }
return -1 return -1
} }
@ -243,7 +257,7 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
} }
} }
// Cases where at least one of x or y is a type parameter. // Cases where at least one of x or y is an (unmasked) type parameter.
switch i, j := u.x.index(x), u.y.index(y); { switch i, j := u.x.index(x), u.y.index(y); {
case i >= 0 && j >= 0: case i >= 0 && j >= 0:
// both x and y are type parameters // both x and y are type parameters
@ -256,6 +270,12 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
case i >= 0: case i >= 0:
// x is a type parameter, y is not // x is a type parameter, y is not
if tx := u.x.at(i); tx != nil { if tx := u.x.at(i); tx != nil {
// The inferred type tx may be or contain x again but we don't
// want to "unpack" it again when unifying tx with y: tx is the
// inferred type. Mask type parameter x for this recursion, so
// that subsequent encounters treat x like an ordinary type.
u.x.mask(i)
defer u.x.unmask(i)
return u.nifyEq(tx, y, p) return u.nifyEq(tx, y, p)
} }
// otherwise, infer type from y // otherwise, infer type from y
@ -265,6 +285,9 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
case j >= 0: case j >= 0:
// y is a type parameter, x is not // y is a type parameter, x is not
if ty := u.y.at(j); ty != nil { if ty := u.y.at(j); ty != nil {
// see comment above
u.y.mask(j)
defer u.y.unmask(j)
return u.nifyEq(x, ty, p) return u.nifyEq(x, ty, p)
} }
// otherwise, infer type from x // otherwise, infer type from x
@ -399,8 +422,8 @@ func (u *unifier) nify(x, y Type, p *ifacePair) bool {
p = p.prev p = p.prev
} }
if debug { if debug {
assert(sort.IsSorted(byUniqueMethodName(a))) assertSortedMethods(a)
assert(sort.IsSorted(byUniqueMethodName(b))) assertSortedMethods(b)
} }
for i, f := range a { for i, f := range a {
g := b[i] g := b[i]