[dev.typeparams] cmd/compile: add stenciling of simple generic functions

Allow full compilation and running of simple programs with generic
functions by stenciling on the fly the needed generic functions. Deal
with some simple derived types based on type params.

Include a few new typeparam tests min.go and add.go which involve
fully compiling and running simple generic code.

Change-Id: Ifc2a64ecacdbd860faaeee800e2ef49ffef9df5e
Reviewed-on: https://go-review.googlesource.com/c/go/+/289630
Run-TryBot: Dan Scales <danscales@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Dan Scales <danscales@google.com>
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Griesemer <gri@golang.org>
This commit is contained in:
Dan Scales 2021-02-03 15:45:26 -08:00
parent f37b0c6c12
commit dcb5e0392e
5 changed files with 380 additions and 0 deletions

View file

@ -32,4 +32,7 @@ type Package struct {
// Exported (or re-exported) symbols.
Exports []*Name
// Map from function names of stencils to already-created stencils.
Stencils map[*types.Sym]*Func
}

View file

@ -186,6 +186,21 @@ Outer:
return false
})
}
// Create any needed stencils of generic functions
g.stencil()
// For now, remove all generic functions from g.target.Decl, since they
// have been used for stenciling, but don't compile. TODO: We will
// eventually export any exportable generic functions.
j := 0
for i, decl := range g.target.Decls {
if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() == 0 {
g.target.Decls[j] = g.target.Decls[i]
j++
}
}
g.target.Decls = g.target.Decls[:j]
}
func (g *irgen) unhandled(what string, p poser) {

View file

@ -0,0 +1,280 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file will evolve, since we plan to do a mix of stenciling and passing
// around dictionaries.
package noder
import (
"bytes"
"cmd/compile/internal/base"
"cmd/compile/internal/ir"
"cmd/compile/internal/typecheck"
"cmd/compile/internal/types"
"fmt"
)
// stencil scans functions for instantiated generic function calls and
// creates the required stencils for simple generic functions.
func (g *irgen) stencil() {
g.target.Stencils = make(map[*types.Sym]*ir.Func)
for _, decl := range g.target.Decls {
if decl.Op() != ir.ODCLFUNC || decl.Type().NumTParams() > 0 {
// Skip any non-function declarations and skip generic functions
continue
}
// For each non-generic function, search for any function calls using
// generic function instantiations. (We don't yet handle generic
// function instantiations that are not immediately called.)
// Then create the needed instantiated function if it hasn't been
// created yet, and change to calling that function directly.
f := decl.(*ir.Func)
modified := false
ir.VisitList(f.Body, func(n ir.Node) {
if n.Op() != ir.OCALLFUNC || n.(*ir.CallExpr).X.Op() != ir.OFUNCINST {
return
}
// We have found a function call using a generic function
// instantiation.
call := n.(*ir.CallExpr)
inst := call.X.(*ir.InstExpr)
sym := makeInstName(inst)
//fmt.Printf("Found generic func call in %v to %v\n", f, s)
st := g.target.Stencils[sym]
if st == nil {
// If instantiation doesn't exist yet, create it and add
// to the list of decls.
st = genericSubst(sym, inst)
g.target.Stencils[sym] = st
g.target.Decls = append(g.target.Decls, st)
if base.Flag.W > 1 {
ir.Dump(fmt.Sprintf("\nstenciled %v", st), st)
}
}
// Replace the OFUNCINST with a direct reference to the
// new stenciled function
call.X = st.Nname
modified = true
})
if base.Flag.W > 1 && modified {
ir.Dump(fmt.Sprintf("\nmodified %v", decl), decl)
}
}
}
// makeInstName makes the unique name for a stenciled generic function, based on
// the name of the function and the types of the type params.
func makeInstName(inst *ir.InstExpr) *types.Sym {
b := bytes.NewBufferString("#")
b.WriteString(inst.X.(*ir.Name).Name().Sym().Name)
b.WriteString("[")
for i, targ := range inst.Targs {
if i > 0 {
b.WriteString(",")
}
b.WriteString(targ.Name().Sym().Name)
}
b.WriteString("]")
return typecheck.Lookup(b.String())
}
// Struct containing info needed for doing the substitution as we create the
// instantiation of a generic function with specified type arguments.
type subster struct {
newf *ir.Func // Func node for the new stenciled function
tparams *types.Fields
targs []ir.Node
// The substitution map from name nodes in the generic function to the
// name nodes in the new stenciled function.
vars map[*ir.Name]*ir.Name
}
// genericSubst returns a new function with the specified name. The function is an
// instantiation of a generic function with type params, as specified by inst.
func genericSubst(name *types.Sym, inst *ir.InstExpr) *ir.Func {
// Similar to noder.go: funcDecl
nameNode := inst.X.(*ir.Name)
gf := nameNode.Func
newf := ir.NewFunc(inst.Pos())
newf.Nname = ir.NewNameAt(inst.Pos(), name)
newf.Nname.Func = newf
newf.Nname.Defn = newf
subst := &subster{
newf: newf,
tparams: nameNode.Type().TParams().Fields(),
targs: inst.Targs,
vars: make(map[*ir.Name]*ir.Name),
}
newf.Dcl = make([]*ir.Name, len(gf.Dcl))
for i, n := range gf.Dcl {
newf.Dcl[i] = subst.node(n).(*ir.Name)
}
newf.Body = subst.list(gf.Body)
// Ugly: we have to insert the Name nodes of the parameters/results into
// the function type. The current function type has no Nname fields set,
// because it came via conversion from the types2 type.
oldt := inst.Type()
newt := types.NewSignature(oldt.Pkg(), nil, nil, subst.fields(ir.PPARAM, oldt.Params(), newf.Dcl),
subst.fields(ir.PPARAMOUT, oldt.Results(), newf.Dcl))
newf.Nname.Ntype = ir.TypeNode(newt)
newf.Nname.SetType(newt)
ir.MarkFunc(newf.Nname)
newf.SetTypecheck(1)
newf.Nname.SetTypecheck(1)
// TODO(danscales) - remove later, but avoid confusion for now.
newf.Pragma = ir.Noinline
return newf
}
// node is like DeepCopy(), but creates distinct ONAME nodes, and also descends
// into closures. It substitutes type arguments for type parameters in all the new
// nodes.
func (subst *subster) node(n ir.Node) ir.Node {
// Use closure to capture all state needed by the ir.EditChildren argument.
var edit func(ir.Node) ir.Node
edit = func(x ir.Node) ir.Node {
switch x.Op() {
case ir.ONAME:
name := x.(*ir.Name)
if v := subst.vars[name]; v != nil {
return v
}
m := ir.NewNameAt(name.Pos(), name.Sym())
t := x.Type()
newt := subst.typ(t)
m.SetType(newt)
m.Curfn = subst.newf
m.Class = name.Class
subst.vars[name] = m
m.SetTypecheck(1)
return m
case ir.OLITERAL, ir.ONIL:
if x.Sym() != nil {
return x
}
}
m := ir.Copy(x)
if _, isExpr := m.(ir.Expr); isExpr {
m.SetType(subst.typ(x.Type()))
}
ir.EditChildren(m, edit)
if x.Op() == ir.OCLOSURE {
x := x.(*ir.ClosureExpr)
// Need to save/duplicate x.Func.Nname,
// x.Func.Nname.Ntype, x.Func.Dcl, x.Func.ClosureVars, and
// x.Func.Body.
oldfn := x.Func
newfn := ir.NewFunc(oldfn.Pos())
if oldfn.ClosureCalled() {
newfn.SetClosureCalled(true)
}
m.(*ir.ClosureExpr).Func = newfn
newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), oldfn.Nname.Sym())
newfn.Nname.SetType(oldfn.Nname.Type())
newfn.Nname.Ntype = subst.node(oldfn.Nname.Ntype).(ir.Ntype)
newfn.Body = subst.list(oldfn.Body)
// Make shallow copy of the Dcl and ClosureVar slices
newfn.Dcl = append([]*ir.Name(nil), oldfn.Dcl...)
newfn.ClosureVars = append([]*ir.Name(nil), oldfn.ClosureVars...)
}
return m
}
return edit(n)
}
func (subst *subster) list(l []ir.Node) []ir.Node {
s := make([]ir.Node, len(l))
for i, n := range l {
s[i] = subst.node(n)
}
return s
}
// typ substitutes any type parameter found with the corresponding type argument.
func (subst *subster) typ(t *types.Type) *types.Type {
for i, tp := range subst.tparams.Slice() {
if tp.Type == t {
return subst.targs[i].Type()
}
}
switch t.Kind() {
case types.TARRAY:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
return types.NewArray(newelem, t.NumElem())
}
case types.TPTR:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
return types.NewPtr(newelem)
}
case types.TSLICE:
elem := t.Elem()
newelem := subst.typ(elem)
if subst.typ(elem) != elem {
return types.NewSlice(newelem)
}
case types.TSTRUCT:
newfields := make([]*types.Field, t.NumFields())
change := false
for i, f := range t.Fields().Slice() {
t2 := subst.typ(f.Type)
if t2 != f.Type {
change = true
}
newfields[i] = types.NewField(f.Pos, f.Sym, t2)
}
if change {
return types.NewStruct(t.Pkg(), newfields)
}
// TODO: case TFUNC
// TODO: case TCHAN
// TODO: case TMAP
// TODO: case TINTER
}
return t
}
// fields sets the Nname field for the Field nodes inside a type signature, based
// on the corresponding in/out parameters in dcl. It depends on the in and out
// parameters being in order in dcl.
func (subst *subster) fields(class ir.Class, oldt *types.Type, dcl []*ir.Name) []*types.Field {
oldfields := oldt.FieldSlice()
newfields := make([]*types.Field, len(oldfields))
var i int
// Find the starting index in dcl of declarations of the class (either
// PPARAM or PPARAMOUT).
for i = range dcl {
if dcl[i].Class == class {
break
}
}
// Create newfields nodes that are copies of the oldfields nodes, but
// with substitution for any type params, and with Nname set to be the node in
// Dcl for the corresponding PPARAM or PPARAMOUT.
for j := range oldfields {
newfields[j] = oldfields[j].Copy()
newfields[j].Type = subst.typ(oldfields[j].Type)
newfields[j].Nname = dcl[i]
i++
}
return newfields
}

50
test/typeparam/add.go Normal file
View file

@ -0,0 +1,50 @@
// run -gcflags=-G=3
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
)
func add[T interface{ type int, float64 }](vec []T) T {
var sum T
for _, elt := range vec {
sum = sum + elt
}
return sum
}
func abs(f float64) float64 {
if f < 0.0 {
return -f
}
return f
}
func main() {
vec1 := []int{3, 4}
vec2 := []float64{5.8, 9.6}
want := vec1[0] + vec1[1]
got := add[int](vec1)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
}
got = add(vec1)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
}
fwant := vec2[0] + vec2[1]
fgot := add[float64](vec2)
if abs(fgot - fwant) > 1e-10 {
panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
}
fgot = add(vec2)
if abs(fgot - fwant) > 1e-10 {
panic(fmt.Sprintf("Want %f, got %f", fwant, fgot))
}
}

32
test/typeparam/min.go Normal file
View file

@ -0,0 +1,32 @@
// run -gcflags=-G=3
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
)
func min[T interface{ type int }](x, y T) T {
if x < y {
return x
}
return y
}
func main() {
want := 2
got := min[int](2, 3)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
}
got = min(2, 3)
if want != got {
panic(fmt.Sprintf("Want %d, got %d", want, got))
}
}