diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 70a2c7b97f..83abee1dd2 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -357,8 +357,7 @@ func (g *irgen) buildClosure(outer *ir.Func, x ir.Node) ir.Node { // } // Make a new internal function. - fn := ir.NewClosureFunc(pos, outer != nil) - ir.NameClosure(fn.OClosure, outer) + fn, formalParams, formalResults := startClosure(pos, outer, typ) // This is the dictionary we want to use. // It may be a constant, or it may be a dictionary acquired from the outer function's dictionary. @@ -395,38 +394,6 @@ func (g *irgen) buildClosure(outer *ir.Func, x ir.Node) ir.Node { outer.Dcl = append(outer.Dcl, rcvrVar) } - // Build formal argument and return lists. - var formalParams []*types.Field // arguments of closure - var formalResults []*types.Field // returns of closure - for i := 0; i < typ.NumParams(); i++ { - t := typ.Params().Field(i).Type - arg := ir.NewNameAt(pos, typecheck.LookupNum("a", i)) - arg.Class = ir.PPARAM - typed(t, arg) - arg.Curfn = fn - fn.Dcl = append(fn.Dcl, arg) - f := types.NewField(pos, arg.Sym(), t) - f.Nname = arg - formalParams = append(formalParams, f) - } - for i := 0; i < typ.NumResults(); i++ { - t := typ.Results().Field(i).Type - result := ir.NewNameAt(pos, typecheck.LookupNum("r", i)) // TODO: names not needed? - result.Class = ir.PPARAMOUT - typed(t, result) - result.Curfn = fn - fn.Dcl = append(fn.Dcl, result) - f := types.NewField(pos, result.Sym(), t) - f.Nname = result - formalResults = append(formalResults, f) - } - - // Build an internal function with the right signature. - closureType := types.NewSignature(x.Type().Pkg(), nil, nil, formalParams, formalResults) - typed(closureType, fn.Nname) - typed(x.Type(), fn.OClosure) - fn.SetTypecheck(1) - // Build body of closure. This involves just calling the wrapped function directly // with the additional dictionary argument. @@ -1092,14 +1059,15 @@ func getDictionaryEntry(pos src.XPos, dict *ir.Name, i int, size int) ir.Node { return r } -// getDictionaryType returns a *runtime._type from the dictionary entry i -// (which refers to a type param or a derived type that uses type params). -func (subst *subster) getDictionaryType(pos src.XPos, i int) ir.Node { - if i < 0 || i >= subst.info.startSubDict { +// getDictionaryType returns a *runtime._type from the dictionary entry i (which +// refers to a type param or a derived type that uses type params). It uses the +// specified dictionary dictParam, rather than the one in info.dictParam. +func getDictionaryType(info *instInfo, dictParam *ir.Name, pos src.XPos, i int) ir.Node { + if i < 0 || i >= info.startSubDict { base.Fatalf(fmt.Sprintf("bad dict index %d", i)) } - r := getDictionaryEntry(pos, subst.info.dictParam, i, subst.info.startSubDict) + r := getDictionaryEntry(pos, info.dictParam, i, info.startSubDict) // change type of retrieved dictionary entry to *byte, which is the // standard typing of a *runtime._type in the compiler typed(types.Types[types.TUINT8].PtrTo(), r) @@ -1235,9 +1203,9 @@ func (subst *subster) node(n ir.Node) ir.Node { // The only dot on a shape type value are methods. if mse.X.Op() == ir.OTYPE { // Method expression T.M - // Fall back from shape type to concrete type. - src = subst.unshapifyTyp(src) - mse.X = ir.TypeNode(src) + m = subst.g.buildClosure2(subst.newf, subst.info, m, x) + // No need for transformDot - buildClosure2 has already + // transformed to OCALLINTER/ODOTINTER. } else { // Implement x.M as a conversion-to-bound-interface // 1) convert x to the bound interface @@ -1247,12 +1215,11 @@ func (subst *subster) node(n ir.Node) ir.Node { if dst.HasTParam() { dst = subst.ts.Typ(dst) } - mse.X = subst.convertUsingDictionary(m.Pos(), mse.X, x, dst, gsrc) + mse.X = convertUsingDictionary(subst.info, subst.info.dictParam, m.Pos(), mse.X, x, dst, gsrc) + transformDot(mse, false) } - } - transformDot(mse, false) - if mse.Op() == ir.OMETHEXPR && mse.X.Type().HasShape() { - mse.X = ir.TypeNodeAt(mse.X.Pos(), subst.unshapifyTyp(mse.X.Type())) + } else { + transformDot(mse, false) } m.SetTypecheck(1) @@ -1341,11 +1308,14 @@ func (subst *subster) node(n ir.Node) ir.Node { // outer function. Since the dictionary is shared, use the // same entries for startSubDict, dictLen, dictEntryMap. cinfo := &instInfo{ - fun: newfn, - dictParam: cdict, - startSubDict: subst.info.startSubDict, - dictLen: subst.info.dictLen, - dictEntryMap: subst.info.dictEntryMap, + fun: newfn, + dictParam: cdict, + gf: subst.info.gf, + gfInfo: subst.info.gfInfo, + startSubDict: subst.info.startSubDict, + startItabConv: subst.info.startItabConv, + dictLen: subst.info.dictLen, + dictEntryMap: subst.info.dictEntryMap, } subst.g.instInfoMap[newfn.Nname.Sym()] = cinfo @@ -1353,8 +1323,11 @@ func (subst *subster) node(n ir.Node) ir.Node { typed(newfn.Nname.Type(), newfn.OClosure) newfn.SetTypecheck(1) + outerinfo := subst.info + subst.info = cinfo // Make sure type of closure function is set before doing body. newfn.Body = subst.list(oldfn.Body) + subst.info = outerinfo subst.newf = saveNewf ir.CurFunc = saveNewf @@ -1366,15 +1339,15 @@ func (subst *subster) node(n ir.Node) ir.Node { // Note: x's argument is still typed as a type parameter. // m's argument now has an instantiated type. if x.X.Type().HasTParam() { - m = subst.convertUsingDictionary(m.Pos(), m.(*ir.ConvExpr).X, x, m.Type(), x.X.Type()) + m = convertUsingDictionary(subst.info, subst.info.dictParam, m.Pos(), m.(*ir.ConvExpr).X, x, m.Type(), x.X.Type()) } case ir.ODOTTYPE, ir.ODOTTYPE2: dt := m.(*ir.TypeAssertExpr) var rt ir.Node if dt.Type().IsInterface() || dt.X.Type().IsEmptyInterface() { - ix := subst.findDictType(x.Type()) + ix := findDictType(subst.info, x.Type()) assert(ix >= 0) - rt = subst.getDictionaryType(dt.Pos(), ix) + rt = getDictionaryType(subst.info, subst.info.dictParam, dt.Pos(), ix) } else { // nonempty interface to noninterface. Need an itab. ix := -1 @@ -1395,9 +1368,6 @@ func (subst *subster) node(n ir.Node) ir.Node { m.SetType(dt.Type()) m.SetTypecheck(1) - case ir.OMETHEXPR: - se := m.(*ir.SelectorExpr) - se.X = ir.TypeNodeAt(se.X.Pos(), subst.unshapifyTyp(se.X.Type())) case ir.OFUNCINST: inst := m.(*ir.InstExpr) targs2 := make([]ir.Node, len(inst.Targs)) @@ -1414,17 +1384,17 @@ func (subst *subster) node(n ir.Node) ir.Node { } // findDictType looks for type t in the typeparams or derived types in the generic -// function info subst.info.gfInfo. This will indicate the dictionary entry with the +// function info.gfInfo. This will indicate the dictionary entry with the // correct concrete type for the associated instantiated function. -func (subst *subster) findDictType(t *types.Type) int { - for i, dt := range subst.info.gfInfo.tparams { +func findDictType(info *instInfo, t *types.Type) int { + for i, dt := range info.gfInfo.tparams { if dt == t { return i } } - for i, dt := range subst.info.gfInfo.derivedTypes { + for i, dt := range info.gfInfo.derivedTypes { if types.Identical(dt, t) { - return i + len(subst.info.gfInfo.tparams) + return i + len(info.gfInfo.tparams) } } return -1 @@ -1435,7 +1405,7 @@ func (subst *subster) findDictType(t *types.Type) int { // is the generic (not shape) type, and gn is the original generic node of the // CONVIFACE node or XDOT node (for a bound method call) that is causing the // conversion. -func (subst *subster) convertUsingDictionary(pos src.XPos, v ir.Node, gn ir.Node, dst, src *types.Type) ir.Node { +func convertUsingDictionary(info *instInfo, dictParam *ir.Name, pos src.XPos, v ir.Node, gn ir.Node, dst, src *types.Type) ir.Node { assert(src.HasTParam()) assert(dst.IsInterface()) @@ -1445,19 +1415,19 @@ func (subst *subster) convertUsingDictionary(pos src.XPos, v ir.Node, gn ir.Node // will be more efficient than converting to an empty interface first // and then type asserting to dst. ix := -1 - for i, ic := range subst.info.gfInfo.itabConvs { + for i, ic := range info.gfInfo.itabConvs { if ic == gn { - ix = subst.info.startItabConv + i + ix = info.startItabConv + i break } } assert(ix >= 0) - rt = getDictionaryEntry(pos, subst.info.dictParam, ix, subst.info.dictLen) + rt = getDictionaryEntry(pos, dictParam, ix, info.dictLen) } else { - ix := subst.findDictType(src) + ix := findDictType(info, src) assert(ix >= 0) // Load the actual runtime._type of the type parameter from the dictionary. - rt = subst.getDictionaryType(pos, ix) + rt = getDictionaryType(info, dictParam, pos, ix) } // Figure out what the data field of the interface will be. @@ -1670,7 +1640,7 @@ func (g *irgen) getDictionarySym(gf *ir.Name, targs []*types.Type, isMeth bool) case ir.OXDOT: selExpr := n.(*ir.SelectorExpr) - subtargs := selExpr.X.Type().RParams() + subtargs := deref(selExpr.X.Type()).RParams() s2targs := make([]*types.Type, len(subtargs)) for i, t := range subtargs { s2targs[i] = subst.Typ(t) @@ -1842,7 +1812,7 @@ func (g *irgen) getGfInfo(gn *ir.Name) *gfInfo { } } else if n.Op() == ir.OXDOT && !n.(*ir.SelectorExpr).Implicit() && n.(*ir.SelectorExpr).Selection != nil && - len(n.(*ir.SelectorExpr).X.Type().RParams()) > 0 { + len(deref(n.(*ir.SelectorExpr).X.Type()).RParams()) > 0 { if n.(*ir.SelectorExpr).X.Op() == ir.OTYPE { infoPrint(" Closure&subdictionary required at generic meth expr %v\n", n) } else { @@ -2017,3 +1987,105 @@ func parameterizedBy1(t *types.Type, params []*types.Type, visited map[*types.Ty return true } } + +// startClosures starts creation of a closure that has the function type typ. It +// creates all the formal params and results according to the type typ. On return, +// the body and closure variables of the closure must still be filled in, and +// ir.UseClosure() called. +func startClosure(pos src.XPos, outer *ir.Func, typ *types.Type) (*ir.Func, []*types.Field, []*types.Field) { + // Make a new internal function. + fn := ir.NewClosureFunc(pos, outer != nil) + ir.NameClosure(fn.OClosure, outer) + + // Build formal argument and return lists. + var formalParams []*types.Field // arguments of closure + var formalResults []*types.Field // returns of closure + for i := 0; i < typ.NumParams(); i++ { + t := typ.Params().Field(i).Type + arg := ir.NewNameAt(pos, typecheck.LookupNum("a", i)) + arg.Class = ir.PPARAM + typed(t, arg) + arg.Curfn = fn + fn.Dcl = append(fn.Dcl, arg) + f := types.NewField(pos, arg.Sym(), t) + f.Nname = arg + formalParams = append(formalParams, f) + } + for i := 0; i < typ.NumResults(); i++ { + t := typ.Results().Field(i).Type + result := ir.NewNameAt(pos, typecheck.LookupNum("r", i)) // TODO: names not needed? + result.Class = ir.PPARAMOUT + typed(t, result) + result.Curfn = fn + fn.Dcl = append(fn.Dcl, result) + f := types.NewField(pos, result.Sym(), t) + f.Nname = result + formalResults = append(formalResults, f) + } + + // Build an internal function with the right signature. + closureType := types.NewSignature(typ.Pkg(), nil, nil, formalParams, formalResults) + typed(closureType, fn.Nname) + typed(typ, fn.OClosure) + fn.SetTypecheck(1) + return fn, formalParams, formalResults + +} + +// buildClosure2 makes a closure to implement a method expression m (generic form x) +// which has a shape type as receiver. If the receiver is exactly a shape (i.e. from +// a typeparam), then the body of the closure converts the first argument (the +// receiver) to the interface bound type, and makes an interface call with the +// remaining arguments. +// +// The returned closure is fully substituted and has already has any needed +// transformations done. +func (g *irgen) buildClosure2(outer *ir.Func, info *instInfo, m, x ir.Node) ir.Node { + pos := m.Pos() + typ := m.Type() // type of the closure + + fn, formalParams, formalResults := startClosure(pos, outer, typ) + + // Capture dictionary calculated in the outer function + dictVar := ir.CaptureName(pos, fn, info.dictParam) + typed(types.Types[types.TUINTPTR], dictVar) + + // Build arguments to call inside the closure. + var args []ir.Node + for i := 0; i < typ.NumParams(); i++ { + args = append(args, formalParams[i].Nname.(*ir.Name)) + } + + // Build call itself. This involves converting the first argument to the + // bound type (an interface) using the dictionary, and then making an + // interface call with the remaining arguments. + var innerCall ir.Node + rcvr := args[0] + args = args[1:] + assert(m.(*ir.SelectorExpr).X.Type().IsShape()) + rcvr = convertUsingDictionary(info, dictVar, pos, rcvr, x, x.(*ir.SelectorExpr).X.Type().Bound(), x.(*ir.SelectorExpr).X.Type()) + dot := ir.NewSelectorExpr(pos, ir.ODOTINTER, rcvr, x.(*ir.SelectorExpr).Sel) + dot.Selection = typecheck.Lookdot1(dot, dot.Sel, dot.X.Type(), dot.X.Type().AllMethods(), 1) + + typed(x.(*ir.SelectorExpr).Selection.Type, dot) + innerCall = ir.NewCallExpr(pos, ir.OCALLINTER, dot, args) + t := m.Type() + if t.NumResults() == 0 { + innerCall.SetTypecheck(1) + } else if t.NumResults() == 1 { + typed(t.Results().Field(0).Type, innerCall) + } else { + typed(t.Results(), innerCall) + } + if len(formalResults) > 0 { + innerCall = ir.NewReturnStmt(pos, []ir.Node{innerCall}) + innerCall.SetTypecheck(1) + } + fn.Body = []ir.Node{innerCall} + + // We're all done with the captured dictionary + ir.FinishCaptureNames(pos, outer, fn) + + // Do final checks on closure and return it. + return ir.UseClosure(fn.OClosure, g.target) +} diff --git a/src/cmd/compile/internal/noder/transform.go b/src/cmd/compile/internal/noder/transform.go index 2fe55a6852..9c791d8a7b 100644 --- a/src/cmd/compile/internal/noder/transform.go +++ b/src/cmd/compile/internal/noder/transform.go @@ -664,7 +664,11 @@ func transformMethodExpr(n *ir.SelectorExpr) (res ir.Node) { s := n.Sel m := typecheck.Lookdot1(n, s, t, ms, 0) - assert(m != nil) + if !t.HasShape() { + // It's OK to not find the method if t is instantiated by shape types, + // because we will use the methods on the generic type anyway. + assert(m != nil) + } n.SetOp(ir.OMETHEXPR) n.Selection = m diff --git a/test/typeparam/boundmethod.go b/test/typeparam/boundmethod.go index 3deabbcdce..22f416422d 100644 --- a/test/typeparam/boundmethod.go +++ b/test/typeparam/boundmethod.go @@ -29,32 +29,78 @@ type Stringer interface { func stringify[T Stringer](s []T) (ret []string) { for _, v := range s { + // Test normal bounds method call on type param + x1 := v.String() + + // Test converting type param to its bound interface first + v1 := Stringer(v) + x2 := v1.String() + + // Test method expression with type param type + f1 := T.String + x3 := f1(v) + + // Test creating and calling closure equivalent to the method expression + f2 := func(v1 T) string { + return Stringer(v1).String() + } + x4 := f2(v) + + if x1 != x2 || x2 != x3 || x3 != x4 { + panic(fmt.Sprintf("Mismatched values %v, %v, %v, %v\n", x1, x2, x3, x4)) + } + ret = append(ret, v.String()) } return ret } -type StringInt[T any] T +type Ints interface { + ~int32 | ~int +} + +type StringInt[T Ints] T //go:noinline func (m StringInt[T]) String() string { - return "aa" + return strconv.Itoa(int(m)) +} + +type StringStruct[T Ints] struct { + f T +} + +func (m StringStruct[T]) String() string { + return strconv.Itoa(int(m.f)) } func main() { x := []myint{myint(1), myint(2), myint(3)} + // stringify on a normal type, whose bound method is associated with the base type. got := stringify(x) want := []string{"1", "2", "3"} if !reflect.DeepEqual(got, want) { panic(fmt.Sprintf("got %s, want %s", got, want)) } - x2 := []StringInt[myint]{StringInt[myint](1), StringInt[myint](2), StringInt[myint](3)} + x2 := []StringInt[myint]{StringInt[myint](5), StringInt[myint](7), StringInt[myint](6)} + // stringify on an instantiated type, whose bound method is associated with + // the generic type StringInt[T], which maps directly to T. got2 := stringify(x2) - want2 := []string{"aa", "aa", "aa"} + want2 := []string{ "5", "7", "6" } if !reflect.DeepEqual(got2, want2) { panic(fmt.Sprintf("got %s, want %s", got2, want2)) } + + // stringify on an instantiated type, whose bound method is associated with + // the generic type StringStruct[T], which maps to a struct containing T. + x3 := []StringStruct[myint]{StringStruct[myint]{f: 11}, StringStruct[myint]{f: 10}, StringStruct[myint]{f: 9}} + + got3 := stringify(x3) + want3 := []string{ "11", "10", "9" } + if !reflect.DeepEqual(got3, want3) { + panic(fmt.Sprintf("got %s, want %s", got3, want3)) + } }