diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index e6c209dda0..9ca24db1de 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -1817,3 +1817,26 @@ func f(x T) T { return foo.F(x) } } } } + +func TestInstantiate(t *testing.T) { + // eventually we like more tests but this is a start + const src = genericPkg + "p; type T[P any] *T[P]" + pkg, err := pkgFor(".", src, nil) + if err != nil { + t.Fatal(err) + } + + // type T should have one type parameter + T := pkg.Scope().Lookup("T").Type().(*Named) + if n := len(T.TParams()); n != 1 { + t.Fatalf("expected 1 type parameter; found %d", n) + } + + // instantiation should succeed (no endless recursion) + res := Instantiate(token.NoPos, T, []Type{Typ[Int]}) + + // instantiated type should point to itself + if res.Underlying().(*Pointer).Elem() != res { + t.Fatalf("unexpected result type: %s", res) + } +} diff --git a/src/go/types/subst.go b/src/go/types/subst.go index 4809b8c47a..64146be27e 100644 --- a/src/go/types/subst.go +++ b/src/go/types/subst.go @@ -237,15 +237,27 @@ func (check *Checker) subst(pos token.Pos, typ Type, smap *substMap) Type { } // general case - subst := subster{check, pos, make(map[Type]Type), smap} + var subst subster + subst.pos = pos + subst.smap = smap + if check != nil { + subst.check = check + subst.typMap = check.typMap + } else { + // If we don't have a *Checker and its global type map, + // use a local version. Besides avoiding duplicate work, + // the type map prevents infinite recursive substitution + // for recursive types (example: type T[P any] *T[P]). + subst.typMap = make(map[string]*Named) + } return subst.typ(typ) } type subster struct { - check *Checker - pos token.Pos - cache map[Type]Type - smap *substMap + pos token.Pos + smap *substMap + check *Checker // nil if called via Instantiate + typMap map[string]*Named } func (subst *subster) typ(typ Type) Type { @@ -390,22 +402,16 @@ func (subst *subster) typ(typ Type) Type { // before creating a new named type, check if we have this one already h := instantiatedHash(t, newTargs) dump(">>> new type hash: %s", h) - if subst.check != nil { - if named, found := subst.check.typMap[h]; found { - dump(">>> found %s", named) - subst.cache[t] = named - return named - } + if named, found := subst.typMap[h]; found { + dump(">>> found %s", named) + return named } - // create a new named type and populate caches to avoid endless recursion + // create a new named type and populate typMap to avoid endless recursion tname := NewTypeName(subst.pos, t.obj.pkg, t.obj.name, nil) named := subst.check.newNamed(tname, t, t.Underlying(), t.TParams(), t.methods) // method signatures are updated lazily named.targs = newTargs - if subst.check != nil { - subst.check.typMap[h] = named - } - subst.cache[t] = named + subst.typMap[h] = named // do the substitution dump(">>> subst %s with %s (new: %s)", t.underlying, subst.smap, newTargs)