diff --git a/src/go/types/infer.go b/src/go/types/infer.go new file mode 100644 index 0000000000..c0b1a4b71a --- /dev/null +++ b/src/go/types/infer.go @@ -0,0 +1,359 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements type parameter inference given +// a list of concrete arguments and a parameter list. + +package types + +import ( + "go/token" + "strings" +) + +// infer returns the list of actual type arguments for the given list of type parameters tparams +// by inferring them from the actual arguments args for the parameters params. If type inference +// is impossible because unification fails, an error is reported and the resulting types list is +// nil, and index is 0. Otherwise, types is the list of inferred type arguments, and index is +// the index of the first type argument in that list that couldn't be inferred (and thus is nil). +// If all type arguments were inferred successfully, index is < 0. +func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand) (types []Type, index int) { + assert(params.Len() == len(args)) + + u := newUnifier(check, false) + u.x.init(tparams) + + errorf := func(kind string, tpar, targ Type, arg *operand) { + // provide a better error message if we can + targs, failed := u.x.types() + if failed == 0 { + // The first type parameter couldn't be inferred. + // If none of them could be inferred, don't try + // to provide the inferred type in the error msg. + allFailed := true + for _, targ := range targs { + if targ != nil { + allFailed = false + break + } + } + if allFailed { + check.errorf(arg, 0, "%s %s of %s does not match %s (cannot infer %s)", kind, targ, arg.expr, tpar, typeNamesString(tparams)) + return + } + } + smap := makeSubstMap(tparams, targs) + // TODO(rFindley): pass a positioner here, rather than arg.Pos(). + inferred := check.subst(arg.Pos(), tpar, smap) + if inferred != tpar { + check.errorf(arg, 0, "%s %s of %s does not match inferred type %s for %s", kind, targ, arg.expr, inferred, tpar) + } else { + check.errorf(arg, 0, "%s %s of %s does not match %s", kind, targ, arg.expr, tpar) + } + } + + // Terminology: generic parameter = function parameter with a type-parameterized type + + // 1st pass: Unify parameter and argument types for generic parameters with typed arguments + // and collect the indices of generic parameters with untyped arguments. + var indices []int + for i, arg := range args { + par := params.At(i) + // If we permit bidirectional unification, this conditional code needs to be + // executed even if par.typ is not parameterized since the argument may be a + // generic function (for which we want to infer // its type arguments). + if isParameterized(tparams, par.typ) { + if arg.mode == invalid { + // An error was reported earlier. Ignore this targ + // and continue, we may still be able to infer all + // targs resulting in fewer follon-on errors. + continue + } + if targ := arg.typ; isTyped(targ) { + // If we permit bidirectional unification, and targ is + // a generic function, we need to initialize u.y with + // the respective type parameters of targ. + if !u.unify(par.typ, targ) { + errorf("type", par.typ, targ, arg) + return nil, 0 + } + } else { + indices = append(indices, i) + } + } + } + + // Some generic parameters with untyped arguments may have been given a type + // indirectly through another generic parameter with a typed argument; we can + // ignore those now. (This only means that we know the types for those generic + // parameters; it doesn't mean untyped arguments can be passed safely. We still + // need to verify that assignment of those arguments is valid when we check + // function parameter passing external to infer.) + j := 0 + for _, i := range indices { + par := params.At(i) + // Since untyped types are all basic (i.e., non-composite) types, an + // untyped argument will never match a composite parameter type; the + // only parameter type it can possibly match against is a *TypeParam. + // Thus, only keep the indices of generic parameters that are not of + // composite types and which don't have a type inferred yet. + if tpar, _ := par.typ.(*TypeParam); tpar != nil && u.x.at(tpar.index) == nil { + indices[j] = i + j++ + } + } + indices = indices[:j] + + // 2nd pass: Unify parameter and default argument types for remaining generic parameters. + for _, i := range indices { + par := params.At(i) + arg := args[i] + targ := Default(arg.typ) + // The default type for an untyped nil is untyped nil. We must not + // infer an untyped nil type as type parameter type. Ignore untyped + // nil by making sure all default argument types are typed. + if isTyped(targ) && !u.unify(par.typ, targ) { + errorf("default type", par.typ, targ, arg) + return nil, 0 + } + } + + return u.x.types() +} + +// typeNamesString produces a string containing all the +// type names in list suitable for human consumption. +func typeNamesString(list []*TypeName) string { + // common cases + n := len(list) + switch n { + case 0: + return "" + case 1: + return list[0].name + case 2: + return list[0].name + " and " + list[1].name + } + + // general case (n > 2) + var b strings.Builder + for i, tname := range list[:n-1] { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(tname.name) + } + b.WriteString(", and ") + b.WriteString(list[n-1].name) + return b.String() +} + +// IsParameterized reports whether typ contains any of the type parameters of tparams. +func isParameterized(tparams []*TypeName, typ Type) bool { + w := tpWalker{ + seen: make(map[Type]bool), + tparams: tparams, + } + return w.isParameterized(typ) +} + +type tpWalker struct { + seen map[Type]bool + tparams []*TypeName +} + +func (w *tpWalker) isParameterized(typ Type) (res bool) { + // detect cycles + if x, ok := w.seen[typ]; ok { + return x + } + w.seen[typ] = false + defer func() { + w.seen[typ] = res + }() + + switch t := typ.(type) { + case nil, *Basic: // TODO(gri) should nil be handled here? + break + + case *Array: + return w.isParameterized(t.elem) + + case *Slice: + return w.isParameterized(t.elem) + + case *Struct: + for _, fld := range t.fields { + if w.isParameterized(fld.typ) { + return true + } + } + + case *Pointer: + return w.isParameterized(t.base) + + case *Tuple: + n := t.Len() + for i := 0; i < n; i++ { + if w.isParameterized(t.At(i).typ) { + return true + } + } + + case *Sum: + return w.isParameterizedList(t.types) + + case *Signature: + // t.tparams may not be nil if we are looking at a signature + // of a generic function type (or an interface method) that is + // part of the type we're testing. We don't care about these type + // parameters. + // Similarly, the receiver of a method may declare (rather then + // use) type parameters, we don't care about those either. + // Thus, we only need to look at the input and result parameters. + return w.isParameterized(t.params) || w.isParameterized(t.results) + + case *Interface: + if t.allMethods != nil { + // TODO(rFindley) at some point we should enforce completeness here + for _, m := range t.allMethods { + if w.isParameterized(m.typ) { + return true + } + } + return w.isParameterizedList(unpackType(t.allTypes)) + } + + return t.iterate(func(t *Interface) bool { + for _, m := range t.methods { + if w.isParameterized(m.typ) { + return true + } + } + return w.isParameterizedList(unpackType(t.types)) + }, nil) + + case *Map: + return w.isParameterized(t.key) || w.isParameterized(t.elem) + + case *Chan: + return w.isParameterized(t.elem) + + case *Named: + return w.isParameterizedList(t.targs) + + case *TypeParam: + // t must be one of w.tparams + return t.index < len(w.tparams) && w.tparams[t.index].typ == t + + case *instance: + return w.isParameterizedList(t.targs) + + default: + unreachable() + } + + return false +} + +func (w *tpWalker) isParameterizedList(list []Type) bool { + for _, t := range list { + if w.isParameterized(t) { + return true + } + } + return false +} + +// inferB returns the list of actual type arguments inferred from the type parameters' +// bounds and an initial set of type arguments. If type inference is impossible because +// unification fails, an error is reported, the resulting types list is nil, and index is 0. +// Otherwise, types is the list of inferred type arguments, and index is the index of the +// first type argument in that list that couldn't be inferred (and thus is nil). If all +// type arguments where inferred successfully, index is < 0. The number of type arguments +// provided may be less than the number of type parameters, but there must be at least one. +func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, index int) { + assert(len(tparams) >= len(targs) && len(targs) > 0) + + // Setup bidirectional unification between those structural bounds + // and the corresponding type arguments (which may be nil!). + u := newUnifier(check, false) + u.x.init(tparams) + u.y = u.x // type parameters between LHS and RHS of unification are identical + + // Set the type arguments which we know already. + for i, targ := range targs { + if targ != nil { + u.x.set(i, targ) + } + } + + // Unify type parameters with their structural constraints, if any. + for _, tpar := range tparams { + typ := tpar.typ.(*TypeParam) + sbound := check.structuralType(typ.bound) + if sbound != nil { + if !u.unify(typ, sbound) { + check.errorf(tpar, 0, "%s does not match %s", tpar, sbound) + return nil, 0 + } + } + } + + // u.x.types() now contains the incoming type arguments plus any additional type + // arguments for which there were structural constraints. The newly inferred non- + // nil entries may still contain references to other type parameters. For instance, + // for [A any, B interface{type []C}, C interface{type *A}], if A == int + // was given, unification produced the type list [int, []C, *A]. We eliminate the + // remaining type parameters by substituting the type parameters in this type list + // until nothing changes anymore. + types, index = u.x.types() + if debug { + for i, targ := range targs { + assert(targ == nil || types[i] == targ) + } + } + + // dirty tracks the indices of all types that may still contain type parameters. + // We know that nil type entries and entries corresponding to provided (non-nil) + // type arguments are clean, so exclude them from the start. + var dirty []int + for i, typ := range types { + if typ != nil && (i >= len(targs) || targs[i] == nil) { + dirty = append(dirty, i) + } + } + + for len(dirty) > 0 { + // TODO(gri) Instead of creating a new substMap for each iteration, + // provide an update operation for substMaps and only change when + // needed. Optimization. + smap := makeSubstMap(tparams, types) + n := 0 + for _, index := range dirty { + t0 := types[index] + if t1 := check.subst(token.NoPos, t0, smap); t1 != t0 { + types[index] = t1 + dirty[n] = index + n++ + } + } + dirty = dirty[:n] + } + + return +} + +// structuralType returns the structural type of a constraint, if any. +func (check *Checker) structuralType(constraint Type) Type { + if iface, _ := under(constraint).(*Interface); iface != nil { + check.completeInterface(token.NoPos, iface) + types := unpackType(iface.allTypes) + if len(types) == 1 { + return types[0] + } + return nil + } + return constraint +} diff --git a/src/go/types/unify.go b/src/go/types/unify.go new file mode 100644 index 0000000000..ab18febbdf --- /dev/null +++ b/src/go/types/unify.go @@ -0,0 +1,452 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file implements type unification. + +package types + +import ( + "go/token" + "sort" +) + +// The unifier maintains two separate sets of type parameters x and y +// which are used to resolve type parameters in the x and y arguments +// provided to the unify call. For unidirectional unification, only +// one of these sets (say x) is provided, and then type parameters are +// only resolved for the x argument passed to unify, not the y argument +// (even if that also contains possibly the same type parameters). This +// is crucial to infer the type parameters of self-recursive calls: +// +// func f[P any](a P) { f(a) } +// +// For the call f(a) we want to infer that the type argument for P is P. +// During unification, the parameter type P must be resolved to the type +// parameter P ("x" side), but the argument type P must be left alone so +// that unification resolves the type parameter P to P. +// +// For bidirection unification, both sets are provided. This enables +// unification to go from argument to parameter type and vice versa. +// For constraint type inference, we use bidirectional unification +// where both the x and y type parameters are identical. This is done +// by setting up one of them (using init) and then assigning its value +// to the other. + +// A unifier maintains the current type parameters for x and y +// and the respective types inferred for each type parameter. +// A unifier is created by calling newUnifier. +type unifier struct { + check *Checker + exact bool + x, y tparamsList // x and y must initialized via tparamsList.init + types []Type // inferred types, shared by x and y +} + +// newUnifier returns a new unifier. +// If exact is set, unification requires unified types to match +// exactly. If exact is not set, a named type's underlying type +// is considered if unification would fail otherwise, and the +// direction of channels is ignored. +func newUnifier(check *Checker, exact bool) *unifier { + u := &unifier{check: check, exact: exact} + u.x.unifier = u + u.y.unifier = u + return u +} + +// unify attempts to unify x and y and reports whether it succeeded. +func (u *unifier) unify(x, y Type) bool { + return u.nify(x, y, nil) +} + +// A tparamsList describes a list of type parameters and the types inferred for them. +type tparamsList struct { + unifier *unifier + tparams []*TypeName + // For each tparams element, there is a corresponding type slot index in indices. + // index < 0: unifier.types[-index-1] == nil + // index == 0: no type slot allocated yet + // index > 0: unifier.types[index-1] == typ + // Joined tparams elements share the same type slot and thus have the same index. + // By using a negative index for nil types we don't need to check unifier.types + // to see if we have a type or not. + indices []int // len(d.indices) == len(d.tparams) +} + +// init initializes d with the given type parameters. +// The type parameters must be in the order in which they appear in their declaration +// (this ensures that the tparams indices match the respective type parameter index). +func (d *tparamsList) init(tparams []*TypeName) { + if len(tparams) == 0 { + return + } + if debug { + for i, tpar := range tparams { + assert(i == tpar.typ.(*TypeParam).index) + } + } + d.tparams = tparams + d.indices = make([]int, len(tparams)) +} + +// join unifies the i'th type parameter of x with the j'th type parameter of y. +// If both type parameters already have a type associated with them and they are +// not joined, join fails and return false. +func (u *unifier) join(i, j int) bool { + ti := u.x.indices[i] + tj := u.y.indices[j] + switch { + case ti == 0 && tj == 0: + // Neither type parameter has a type slot associated with them. + // Allocate a new joined nil type slot (negative index). + u.types = append(u.types, nil) + u.x.indices[i] = -len(u.types) + u.y.indices[j] = -len(u.types) + case ti == 0: + // The type parameter for x has no type slot yet. Use slot of y. + u.x.indices[i] = tj + case tj == 0: + // The type parameter for y has no type slot yet. Use slot of x. + u.y.indices[j] = ti + + // Both type parameters have a slot: ti != 0 && tj != 0. + case ti == tj: + // Both type parameters already share the same slot. Nothing to do. + break + case ti > 0 && tj > 0: + // Both type parameters have (possibly different) inferred types. Cannot join. + return false + case ti > 0: + // Only the type parameter for x has an inferred type. Use x slot for y. + u.y.setIndex(j, ti) + default: + // Either the type parameter for y has an inferred type, or neither type + // parameter has an inferred type. In either case, use y slot for x. + u.x.setIndex(i, tj) + } + return true +} + +// If typ is a type parameter of d, index returns the type parameter index. +// Otherwise, the result is < 0. +func (d *tparamsList) index(typ Type) int { + if t, ok := typ.(*TypeParam); ok { + if i := t.index; i < len(d.tparams) && d.tparams[i].typ == t { + return i + } + } + return -1 +} + +// setIndex sets the type slot index for the i'th type parameter +// (and all its joined parameters) to tj. The type parameter +// must have a (possibly nil) type slot associated with it. +func (d *tparamsList) setIndex(i, tj int) { + ti := d.indices[i] + assert(ti != 0 && tj != 0) + for k, tk := range d.indices { + if tk == ti { + d.indices[k] = tj + } + } +} + +// at returns the type set for the i'th type parameter; or nil. +func (d *tparamsList) at(i int) Type { + if ti := d.indices[i]; ti > 0 { + return d.unifier.types[ti-1] + } + return nil +} + +// set sets the type typ for the i'th type parameter; +// typ must not be nil and it must not have been set before. +func (d *tparamsList) set(i int, typ Type) { + assert(typ != nil) + u := d.unifier + switch ti := d.indices[i]; { + case ti < 0: + u.types[-ti-1] = typ + d.setIndex(i, -ti) + case ti == 0: + u.types = append(u.types, typ) + d.indices[i] = len(u.types) + default: + panic("type already set") + } +} + +// types returns the list of inferred types (via unification) for the type parameters +// described by d, and an index. If all types were inferred, the returned index is < 0. +// Otherwise, it is the index of the first type parameter which couldn't be inferred; +// i.e., for which list[index] is nil. +func (d *tparamsList) types() (list []Type, index int) { + list = make([]Type, len(d.tparams)) + index = -1 + for i := range d.tparams { + t := d.at(i) + list[i] = t + if index < 0 && t == nil { + index = i + } + } + return +} + +func (u *unifier) nifyEq(x, y Type, p *ifacePair) bool { + return x == y || u.nify(x, y, p) +} + +// nify implements the core unification algorithm which is an +// adapted version of Checker.identical0. For changes to that +// code the corresponding changes should be made here. +// Must not be called directly from outside the unifier. +func (u *unifier) nify(x, y Type, p *ifacePair) bool { + // types must be expanded for comparison + x = expand(x) + y = expand(y) + + if !u.exact { + // If exact unification is known to fail because we attempt to + // match a type name against an unnamed type literal, consider + // the underlying type of the named type. + // (Subtle: We use isNamed to include any type with a name (incl. + // basic types and type parameters. We use asNamed() because we only + // want *Named types.) + switch { + case !isNamed(x) && y != nil && asNamed(y) != nil: + return u.nify(x, under(y), p) + case x != nil && asNamed(x) != nil && !isNamed(y): + return u.nify(under(x), y, p) + } + } + + // Cases where at least one of x or y is a type parameter. + switch i, j := u.x.index(x), u.y.index(y); { + case i >= 0 && j >= 0: + // both x and y are type parameters + if u.join(i, j) { + return true + } + // both x and y have an inferred type - they must match + return u.nifyEq(u.x.at(i), u.y.at(j), p) + + case i >= 0: + // x is a type parameter, y is not + if tx := u.x.at(i); tx != nil { + return u.nifyEq(tx, y, p) + } + // otherwise, infer type from y + u.x.set(i, y) + return true + + case j >= 0: + // y is a type parameter, x is not + if ty := u.y.at(j); ty != nil { + return u.nifyEq(x, ty, p) + } + // otherwise, infer type from x + u.y.set(j, x) + return true + } + + // For type unification, do not shortcut (x == y) for identical + // types. Instead keep comparing them element-wise to unify the + // matching (and equal type parameter types). A simple test case + // where this matters is: func f[P any](a P) { f(a) } . + + switch x := x.(type) { + case *Basic: + // Basic types are singletons except for the rune and byte + // aliases, thus we cannot solely rely on the x == y check + // above. See also comment in TypeName.IsAlias. + if y, ok := y.(*Basic); ok { + return x.kind == y.kind + } + + case *Array: + // Two array types are identical if they have identical element types + // and the same array length. + if y, ok := y.(*Array); ok { + // If one or both array lengths are unknown (< 0) due to some error, + // assume they are the same to avoid spurious follow-on errors. + return (x.len < 0 || y.len < 0 || x.len == y.len) && u.nify(x.elem, y.elem, p) + } + + case *Slice: + // Two slice types are identical if they have identical element types. + if y, ok := y.(*Slice); ok { + return u.nify(x.elem, y.elem, p) + } + + case *Struct: + // Two struct types are identical if they have the same sequence of fields, + // and if corresponding fields have the same names, and identical types, + // and identical tags. Two embedded fields are considered to have the same + // name. Lower-case field names from different packages are always different. + if y, ok := y.(*Struct); ok { + if x.NumFields() == y.NumFields() { + for i, f := range x.fields { + g := y.fields[i] + if f.embedded != g.embedded || + x.Tag(i) != y.Tag(i) || + !f.sameId(g.pkg, g.name) || + !u.nify(f.typ, g.typ, p) { + return false + } + } + return true + } + } + + case *Pointer: + // Two pointer types are identical if they have identical base types. + if y, ok := y.(*Pointer); ok { + return u.nify(x.base, y.base, p) + } + + case *Tuple: + // Two tuples types are identical if they have the same number of elements + // and corresponding elements have identical types. + if y, ok := y.(*Tuple); ok { + if x.Len() == y.Len() { + if x != nil { + for i, v := range x.vars { + w := y.vars[i] + if !u.nify(v.typ, w.typ, p) { + return false + } + } + } + return true + } + } + + case *Signature: + // Two function types are identical if they have the same number of parameters + // and result values, corresponding parameter and result types are identical, + // and either both functions are variadic or neither is. Parameter and result + // names are not required to match. + // TODO(gri) handle type parameters or document why we can ignore them. + if y, ok := y.(*Signature); ok { + return x.variadic == y.variadic && + u.nify(x.params, y.params, p) && + u.nify(x.results, y.results, p) + } + + case *Sum: + // This should not happen with the current internal use of sum types. + panic("type inference across sum types not implemented") + + case *Interface: + // Two interface types are identical if they have the same set of methods with + // the same names and identical function types. Lower-case method names from + // different packages are always different. The order of the methods is irrelevant. + if y, ok := y.(*Interface); ok { + // If identical0 is called (indirectly) via an external API entry point + // (such as Identical, IdenticalIgnoreTags, etc.), check is nil. But in + // that case, interfaces are expected to be complete and lazy completion + // here is not needed. + if u.check != nil { + u.check.completeInterface(token.NoPos, x) + u.check.completeInterface(token.NoPos, y) + } + a := x.allMethods + b := y.allMethods + if len(a) == len(b) { + // Interface types are the only types where cycles can occur + // that are not "terminated" via named types; and such cycles + // can only be created via method parameter types that are + // anonymous interfaces (directly or indirectly) embedding + // the current interface. Example: + // + // type T interface { + // m() interface{T} + // } + // + // If two such (differently named) interfaces are compared, + // endless recursion occurs if the cycle is not detected. + // + // If x and y were compared before, they must be equal + // (if they were not, the recursion would have stopped); + // search the ifacePair stack for the same pair. + // + // This is a quadratic algorithm, but in practice these stacks + // are extremely short (bounded by the nesting depth of interface + // type declarations that recur via parameter types, an extremely + // rare occurrence). An alternative implementation might use a + // "visited" map, but that is probably less efficient overall. + q := &ifacePair{x, y, p} + for p != nil { + if p.identical(q) { + return true // same pair was compared before + } + p = p.prev + } + if debug { + assert(sort.IsSorted(byUniqueMethodName(a))) + assert(sort.IsSorted(byUniqueMethodName(b))) + } + for i, f := range a { + g := b[i] + if f.Id() != g.Id() || !u.nify(f.typ, g.typ, q) { + return false + } + } + return true + } + } + + case *Map: + // Two map types are identical if they have identical key and value types. + if y, ok := y.(*Map); ok { + return u.nify(x.key, y.key, p) && u.nify(x.elem, y.elem, p) + } + + case *Chan: + // Two channel types are identical if they have identical value types. + if y, ok := y.(*Chan); ok { + return (!u.exact || x.dir == y.dir) && u.nify(x.elem, y.elem, p) + } + + case *Named: + // Two named types are identical if their type names originate + // in the same type declaration. + // if y, ok := y.(*Named); ok { + // return x.obj == y.obj + // } + if y, ok := y.(*Named); ok { + // TODO(gri) This is not always correct: two types may have the same names + // in the same package if one of them is nested in a function. + // Extremely unlikely but we need an always correct solution. + if x.obj.pkg == y.obj.pkg && x.obj.name == y.obj.name { + assert(len(x.targs) == len(y.targs)) + for i, x := range x.targs { + if !u.nify(x, y.targs[i], p) { + return false + } + } + return true + } + } + + case *TypeParam: + // Two type parameters (which are not part of the type parameters of the + // enclosing type as those are handled in the beginning of this function) + // are identical if they originate in the same declaration. + return x == y + + // case *instance: + // unreachable since types are expanded + + case nil: + // avoid a crash in case of nil type + + default: + u.check.dump("### u.nify(%s, %s), u.x.tparams = %s", x, y, u.x.tparams) + unreachable() + } + + return false +}