diff --git a/src/cmd/compile/internal/importer/iimport.go b/src/cmd/compile/internal/importer/iimport.go index fd48bfc179..fb39e93073 100644 --- a/src/cmd/compile/internal/importer/iimport.go +++ b/src/cmd/compile/internal/importer/iimport.go @@ -51,6 +51,11 @@ const ( iexportVersionCurrent = iexportVersionGenerics + 1 ) +type ident struct { + pkg string + name string +} + const predeclReserved = 32 type itag uint64 @@ -124,6 +129,9 @@ func iImportData(imports map[string]*types2.Package, data []byte, path string) ( declData: declData, pkgIndex: make(map[*types2.Package]map[string]uint64), typCache: make(map[uint64]types2.Type), + // Separate map for typeparams, keyed by their package and unique + // name (name with subscript). + tparamIndex: make(map[ident]types2.Type), } for i, pt := range predeclared { @@ -202,9 +210,10 @@ type iimporter struct { pkgCache map[uint64]*types2.Package posBaseCache map[uint64]*syntax.PosBase - declData []byte - pkgIndex map[*types2.Package]map[string]uint64 - typCache map[uint64]types2.Type + declData []byte + pkgIndex map[*types2.Package]map[string]uint64 + typCache map[uint64]types2.Type + tparamIndex map[ident]types2.Type interfaceList []*types2.Interface } @@ -358,6 +367,28 @@ func (r *importReader) obj(name string) { } } + case 'P': + // We need to "declare" a typeparam in order to have a name that + // can be referenced recursively (if needed) in the type param's + // bound. + if r.p.exportVersion < iexportVersionGenerics { + errorf("unexpected type param type") + } + index := int(r.int64()) + name0, sub := parseSubscript(name) + tn := types2.NewTypeName(pos, r.currPkg, name0, nil) + t := (*types2.Checker)(nil).NewTypeParam(tn, index, nil) + if sub == 0 { + errorf("missing subscript") + } + t.SetId(sub) + // To handle recursive references to the typeparam within its + // bound, save the partial type in tparamIndex before reading the bounds. + id := ident{r.currPkg.Name(), name} + r.p.tparamIndex[id] = t + + t.SetBound(r.typ()) + case 'V': typ := r.typ() @@ -617,34 +648,15 @@ func (r *importReader) doType(base *types2.Named) types2.Type { if r.p.exportVersion < iexportVersionGenerics { errorf("unexpected type param type") } - r.currPkg = r.pkg() - pos := r.pos() - name := r.string() - - // Extract the subscript value from the type param name. We export - // and import the subscript value, so that all type params have - // unique names. - sub := uint64(0) - startsub := -1 - for i, r := range name { - if '₀' <= r && r < '₀'+10 { - if startsub == -1 { - startsub = i - } - sub = sub*10 + uint64(r-'₀') - } + pkg, name := r.qualifiedIdent() + id := ident{pkg.Name(), name} + if t, ok := r.p.tparamIndex[id]; ok { + // We're already in the process of importing this typeparam. + return t } - if startsub >= 0 { - name = name[:startsub] - } - index := int(r.int64()) - bound := r.typ() - tn := types2.NewTypeName(pos, r.currPkg, name, nil) - t := (*types2.Checker)(nil).NewTypeParam(tn, index, bound) - if sub >= 0 { - t.SetId(sub) - } - return t + // Otherwise, import the definition of the typeparam now. + r.p.doDecl(pkg, name) + return r.p.tparamIndex[id] case instType: if r.p.exportVersion < iexportVersionGenerics { @@ -753,3 +765,23 @@ func baseType(typ types2.Type) *types2.Named { n, _ := typ.(*types2.Named) return n } + +func parseSubscript(name string) (string, uint64) { + // Extract the subscript value from the type param name. We export + // and import the subscript value, so that all type params have + // unique names. + sub := uint64(0) + startsub := -1 + for i, r := range name { + if '₀' <= r && r < '₀'+10 { + if startsub == -1 { + startsub = i + } + sub = sub*10 + uint64(r-'₀') + } + } + if startsub >= 0 { + name = name[:startsub] + } + return name, sub +} diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 08c09c6fb1..3ba364f67c 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -15,7 +15,6 @@ import ( "cmd/compile/internal/types" "fmt" "go/constant" - "strings" ) // For catching problems as we add more features @@ -425,13 +424,6 @@ func (g *irgen) instantiateMethods() { } -// genericSym returns the name of the base generic type for the type named by -// sym. It simply returns the name obtained by removing everything after the -// first bracket ("["). -func genericTypeName(sym *types.Sym) string { - return sym.Name[0:strings.Index(sym.Name, "[")] -} - // getInstantiationForNode returns the function/method instantiation for a // InstExpr node inst. func (g *irgen) getInstantiationForNode(inst *ir.InstExpr) *ir.Func { @@ -479,11 +471,7 @@ type subster struct { g *irgen isMethod bool // If a method is being instantiated newf *ir.Func // Func node for the new stenciled function - tparams []*types.Type - targs []*types.Type - // The substitution map from name nodes in the generic function to the - // name nodes in the new stenciled function. - vars map[*ir.Name]*ir.Name + ts typecheck.Tsubster } // genericSubst returns a new function with name newsym. The function is an @@ -526,9 +514,11 @@ func (g *irgen) genericSubst(newsym *types.Sym, nameNode *ir.Name, targs []*type g: g, isMethod: isMethod, newf: newf, - tparams: tparams, - targs: targs, - vars: make(map[*ir.Name]*ir.Name), + ts: typecheck.Tsubster{ + Tparams: tparams, + Targs: targs, + Vars: make(map[*ir.Name]*ir.Name), + }, } newf.Dcl = make([]*ir.Name, 0, len(gf.Dcl)+1) @@ -574,6 +564,9 @@ func (g *irgen) genericSubst(newsym *types.Sym, nameNode *ir.Name, targs []*type newf.Body.Prepend(g.checkDictionary(dictionaryName, targs)...) ir.CurFunc = savef + // Add any new, fully instantiated types seen during the substitution to + // g.instTypeList. + g.instTypeList = append(g.instTypeList, subst.ts.InstTypeList...) return newf } @@ -586,13 +579,13 @@ func (subst *subster) localvar(name *ir.Name) *ir.Name { if name.IsClosureVar() { m.SetIsClosureVar(true) } - m.SetType(subst.typ(name.Type())) + m.SetType(subst.ts.Typ(name.Type())) m.BuiltinOp = name.BuiltinOp m.Curfn = subst.newf m.Class = name.Class assert(name.Class != ir.PEXTERN && name.Class != ir.PFUNC) m.Func = name.Func - subst.vars[name] = m + subst.ts.Vars[name] = m m.SetTypecheck(1) return m } @@ -635,7 +628,7 @@ func (g *irgen) checkDictionary(name *ir.Name, targs []*types.Type) (code []ir.N return } -// node is like DeepCopy(), but substitutes ONAME nodes based on subst.vars, and +// node is like DeepCopy(), but substitutes ONAME nodes based on subst.ts.vars, and // also descends into closures. It substitutes type arguments for type parameters // in all the new nodes. func (subst *subster) node(n ir.Node) ir.Node { @@ -644,10 +637,10 @@ func (subst *subster) node(n ir.Node) ir.Node { edit = func(x ir.Node) ir.Node { switch x.Op() { case ir.OTYPE: - return ir.TypeNode(subst.typ(x.Type())) + return ir.TypeNode(subst.ts.Typ(x.Type())) case ir.ONAME: - if v := subst.vars[x.(*ir.Name)]; v != nil { + if v := subst.ts.Vars[x.(*ir.Name)]; v != nil { return v } return x @@ -673,7 +666,7 @@ func (subst *subster) node(n ir.Node) ir.Node { base.Fatalf(fmt.Sprintf("Nil type for %v", x)) } } else if x.Op() != ir.OCLOSURE { - m.SetType(subst.typ(x.Type())) + m.SetType(subst.ts.Typ(x.Type())) } } ir.EditChildren(m, edit) @@ -815,7 +808,7 @@ func (subst *subster) node(n ir.Node) ir.Node { m.(*ir.ClosureExpr).Func = newfn // Closure name can already have brackets, if it derives // from a generic method - newsym := typecheck.MakeInstName(oldfn.Nname.Sym(), subst.targs, subst.isMethod) + newsym := typecheck.MakeInstName(oldfn.Nname.Sym(), subst.ts.Targs, subst.isMethod) newfn.Nname = ir.NewNameAt(oldfn.Nname.Pos(), newsym) newfn.Nname.Func = newfn newfn.Nname.Defn = newfn @@ -828,7 +821,7 @@ func (subst *subster) node(n ir.Node) ir.Node { newfn.Dcl = subst.namelist(oldfn.Dcl) newfn.ClosureVars = subst.namelist(oldfn.ClosureVars) - typed(subst.typ(oldfn.Nname.Type()), newfn.Nname) + typed(subst.ts.Typ(oldfn.Nname.Type()), newfn.Nname) typed(newfn.Nname.Type(), m) newfn.SetTypecheck(1) @@ -867,298 +860,6 @@ func (subst *subster) list(l []ir.Node) []ir.Node { return s } -// tstruct substitutes type params in types of the fields of a structure type. For -// each field, tstruct copies the Nname, and translates it if Nname is in -// subst.vars. To always force the creation of a new (top-level) struct, -// regardless of whether anything changed with the types or names of the struct's -// fields, set force to true. -func (subst *subster) tstruct(t *types.Type, force bool) *types.Type { - if t.NumFields() == 0 { - if t.HasTParam() { - // For an empty struct, we need to return a new type, - // since it may now be fully instantiated (HasTParam - // becomes false). - return types.NewStruct(t.Pkg(), nil) - } - return t - } - var newfields []*types.Field - if force { - newfields = make([]*types.Field, t.NumFields()) - } - for i, f := range t.Fields().Slice() { - t2 := subst.typ(f.Type) - if (t2 != f.Type || f.Nname != nil) && newfields == nil { - newfields = make([]*types.Field, t.NumFields()) - for j := 0; j < i; j++ { - newfields[j] = t.Field(j) - } - } - if newfields != nil { - // TODO(danscales): make sure this works for the field - // names of embedded types (which should keep the name of - // the type param, not the instantiated type). - newfields[i] = types.NewField(f.Pos, f.Sym, t2) - if f.Nname != nil { - v := subst.vars[f.Nname.(*ir.Name)] - if v != nil { - // This is the case where we are - // translating the type of the function we - // are substituting, so its dcls are in - // the subst.vars table, and we want to - // change to reference the new dcl. - newfields[i].Nname = v - } else { - // This is the case where we are - // translating the type of a function - // reference inside the function we are - // substituting, so we leave the Nname - // value as is. - newfields[i].Nname = f.Nname - } - } - } - } - if newfields != nil { - return types.NewStruct(t.Pkg(), newfields) - } - return t - -} - -// tinter substitutes type params in types of the methods of an interface type. -func (subst *subster) tinter(t *types.Type) *types.Type { - if t.Methods().Len() == 0 { - return t - } - var newfields []*types.Field - for i, f := range t.Methods().Slice() { - t2 := subst.typ(f.Type) - if (t2 != f.Type || f.Nname != nil) && newfields == nil { - newfields = make([]*types.Field, t.Methods().Len()) - for j := 0; j < i; j++ { - newfields[j] = t.Methods().Index(j) - } - } - if newfields != nil { - newfields[i] = types.NewField(f.Pos, f.Sym, t2) - } - } - if newfields != nil { - return types.NewInterface(t.Pkg(), newfields) - } - return t -} - -// typ computes the type obtained by substituting any type parameter in t with the -// corresponding type argument in subst. If t contains no type parameters, the -// result is t; otherwise the result is a new type. It deals with recursive types -// by using TFORW types and finding partially or fully created types via sym.Def. -func (subst *subster) typ(t *types.Type) *types.Type { - if !t.HasTParam() && t.Kind() != types.TFUNC { - // Note: function types need to be copied regardless, as the - // types of closures may contain declarations that need - // to be copied. See #45738. - return t - } - - if t.Kind() == types.TTYPEPARAM { - for i, tp := range subst.tparams { - if tp == t { - return subst.targs[i] - } - } - // If t is a simple typeparam T, then t has the name/symbol 'T' - // and t.Underlying() == t. - // - // However, consider the type definition: 'type P[T any] T'. We - // might use this definition so we can have a variant of type T - // that we can add new methods to. Suppose t is a reference to - // P[T]. t has the name 'P[T]', but its kind is TTYPEPARAM, - // because P[T] is defined as T. If we look at t.Underlying(), it - // is different, because the name of t.Underlying() is 'T' rather - // than 'P[T]'. But the kind of t.Underlying() is also TTYPEPARAM. - // In this case, we do the needed recursive substitution in the - // case statement below. - if t.Underlying() == t { - // t is a simple typeparam that didn't match anything in tparam - return t - } - // t is a more complex typeparam (e.g. P[T], as above, whose - // definition is just T). - assert(t.Sym() != nil) - } - - var newsym *types.Sym - var neededTargs []*types.Type - var forw *types.Type - - if t.Sym() != nil { - // Translate the type params for this type according to - // the tparam/targs mapping from subst. - neededTargs = make([]*types.Type, len(t.RParams())) - for i, rparam := range t.RParams() { - neededTargs[i] = subst.typ(rparam) - } - // For a named (defined) type, we have to change the name of the - // type as well. We do this first, so we can look up if we've - // already seen this type during this substitution or other - // definitions/substitutions. - genName := genericTypeName(t.Sym()) - newsym = t.Sym().Pkg.Lookup(typecheck.InstTypeName(genName, neededTargs)) - if newsym.Def != nil { - // We've already created this instantiated defined type. - return newsym.Def.Type() - } - - // In order to deal with recursive generic types, create a TFORW - // type initially and set the Def field of its sym, so it can be - // found if this type appears recursively within the type. - forw = typecheck.NewIncompleteNamedType(t.Pos(), newsym) - //println("Creating new type by sub", newsym.Name, forw.HasTParam()) - forw.SetRParams(neededTargs) - // Copy the OrigSym from the re-instantiated type (which is the sym of - // the base generic type). - assert(t.OrigSym != nil) - forw.OrigSym = t.OrigSym - } - - var newt *types.Type - - switch t.Kind() { - case types.TTYPEPARAM: - if t.Sym() == newsym { - // The substitution did not change the type. - return t - } - // Substitute the underlying typeparam (e.g. T in P[T], see - // the example describing type P[T] above). - newt = subst.typ(t.Underlying()) - assert(newt != t) - - case types.TARRAY: - elem := t.Elem() - newelem := subst.typ(elem) - if newelem != elem { - newt = types.NewArray(newelem, t.NumElem()) - } - - case types.TPTR: - elem := t.Elem() - newelem := subst.typ(elem) - if newelem != elem { - newt = types.NewPtr(newelem) - } - - case types.TSLICE: - elem := t.Elem() - newelem := subst.typ(elem) - if newelem != elem { - newt = types.NewSlice(newelem) - } - - case types.TSTRUCT: - newt = subst.tstruct(t, false) - if newt == t { - newt = nil - } - - case types.TFUNC: - newrecvs := subst.tstruct(t.Recvs(), false) - newparams := subst.tstruct(t.Params(), false) - newresults := subst.tstruct(t.Results(), false) - if newrecvs != t.Recvs() || newparams != t.Params() || newresults != t.Results() { - // If any types have changed, then the all the fields of - // of recv, params, and results must be copied, because they have - // offset fields that are dependent, and so must have an - // independent copy for each new signature. - var newrecv *types.Field - if newrecvs.NumFields() > 0 { - if newrecvs == t.Recvs() { - newrecvs = subst.tstruct(t.Recvs(), true) - } - newrecv = newrecvs.Field(0) - } - if newparams == t.Params() { - newparams = subst.tstruct(t.Params(), true) - } - if newresults == t.Results() { - newresults = subst.tstruct(t.Results(), true) - } - newt = types.NewSignature(t.Pkg(), newrecv, t.TParams().FieldSlice(), newparams.FieldSlice(), newresults.FieldSlice()) - } - - case types.TINTER: - newt = subst.tinter(t) - if newt == t { - newt = nil - } - - case types.TMAP: - newkey := subst.typ(t.Key()) - newval := subst.typ(t.Elem()) - if newkey != t.Key() || newval != t.Elem() { - newt = types.NewMap(newkey, newval) - } - - case types.TCHAN: - elem := t.Elem() - newelem := subst.typ(elem) - if newelem != elem { - newt = types.NewChan(newelem, t.ChanDir()) - if !newt.HasTParam() { - // TODO(danscales): not sure why I have to do this - // only for channels..... - types.CheckSize(newt) - } - } - } - if newt == nil { - // Even though there were typeparams in the type, there may be no - // change if this is a function type for a function call (which will - // have its own tparams/targs in the function instantiation). - return t - } - - if t.Sym() == nil { - // Not a named type, so there was no forwarding type and there are - // no methods to substitute. - assert(t.Methods().Len() == 0) - return newt - } - - forw.SetUnderlying(newt) - newt = forw - - if t.Kind() != types.TINTER && t.Methods().Len() > 0 { - // Fill in the method info for the new type. - var newfields []*types.Field - newfields = make([]*types.Field, t.Methods().Len()) - for i, f := range t.Methods().Slice() { - t2 := subst.typ(f.Type) - oldsym := f.Nname.Sym() - newsym := typecheck.MakeInstName(oldsym, subst.targs, true) - // TODO: use newsym? - var nname *ir.Name - if newsym.Def != nil { - nname = newsym.Def.(*ir.Name) - } else { - nname = ir.NewNameAt(f.Pos, oldsym) - nname.SetType(t2) - oldsym.Def = nname - } - newfields[i] = types.NewField(f.Pos, f.Sym, t2) - newfields[i].Nname = nname - } - newt.Methods().Set(newfields) - if !newt.HasTParam() { - // Generate all the methods for a new fully-instantiated type. - subst.g.instTypeList = append(subst.g.instTypeList, newt) - } - } - return newt -} - // fields sets the Nname field for the Field nodes inside a type signature, based // on the corresponding in/out parameters in dcl. It depends on the in and out // parameters being in order in dcl. @@ -1178,7 +879,7 @@ func (subst *subster) fields(class ir.Class, oldfields []*types.Field, dcl []*ir newfields := make([]*types.Field, len(oldfields)) for j := range oldfields { newfields[j] = oldfields[j].Copy() - newfields[j].Type = subst.typ(oldfields[j].Type) + newfields[j].Type = subst.ts.Typ(oldfields[j].Type) // A PPARAM field will be missing from dcl if its name is // unspecified or specified as "_". So, we compare the dcl sym // with the field sym (or sym of the field's Nname node). (Unnamed diff --git a/src/cmd/compile/internal/noder/types.go b/src/cmd/compile/internal/noder/types.go index f0061e79d7..b37793b2d0 100644 --- a/src/cmd/compile/internal/noder/types.go +++ b/src/cmd/compile/internal/noder/types.go @@ -316,13 +316,16 @@ func (g *irgen) fillinMethods(typ *types2.Named, ntyp *types.Type) { tparams[i] = g.typ1(rparam.Type()) } assert(len(tparams) == len(targs)) - subst := &subster{ - g: g, - tparams: tparams, - targs: targs, + ts := typecheck.Tsubster{ + Tparams: tparams, + Targs: targs, } // Do the substitution of the type - meth2.SetType(subst.typ(meth.Type())) + meth2.SetType(ts.Typ(meth.Type())) + // Add any new fully instantiated types + // seen during the substitution to + // g.instTypeList. + g.instTypeList = append(g.instTypeList, ts.InstTypeList...) newsym.Def = meth2 } meth = meth2 diff --git a/src/cmd/compile/internal/typecheck/iexport.go b/src/cmd/compile/internal/typecheck/iexport.go index 66c356ee7c..f635b79ada 100644 --- a/src/cmd/compile/internal/typecheck/iexport.go +++ b/src/cmd/compile/internal/typecheck/iexport.go @@ -506,6 +506,20 @@ func (p *iexporter) doDecl(n *ir.Name) { w.constExt(n) case ir.OTYPE: + if n.Type().Kind() == types.TTYPEPARAM && n.Type().Underlying() == n.Type() { + // Even though it has local scope, a typeparam requires a + // declaration via its package and unique name, because it + // may be referenced within its type bound during its own + // definition. + w.tag('P') + // A typeparam has a name, and has a type bound rather + // than an underlying type. + w.pos(n.Pos()) + w.int64(int64(n.Type().Index())) + w.typ(n.Type().Bound()) + break + } + if n.Alias() { // Alias. w.tag('A') @@ -519,9 +533,10 @@ func (p *iexporter) doDecl(n *ir.Name) { w.pos(n.Pos()) if base.Flag.G > 0 { - // Export any new typeparams needed for this type + // Export type parameters, if any, needed for this type w.typeList(n.Type().RParams()) } + underlying := n.Type().Underlying() if underlying == types.ErrorType.Underlying() { // For "type T error", use error as the @@ -837,26 +852,6 @@ func (w *exportWriter) startType(k itag) { } func (w *exportWriter) doTyp(t *types.Type) { - if t.Kind() == types.TTYPEPARAM { - assert(base.Flag.G > 0) - // A typeparam has a name, but doesn't have an underlying type. - // Just write out the details of the type param here. All other - // uses of this typeparam type will be written out as its unique - // type offset. - w.startType(typeParamType) - s := t.Sym() - w.setPkg(s.Pkg, true) - w.pos(t.Pos()) - - // We are writing out the name with the subscript, so that the - // typeparam name is unique. - w.string(s.Name) - w.int64(int64(t.Index())) - - w.typ(t.Bound()) - return - } - s := t.Sym() if s != nil && t.OrigSym != nil { assert(base.Flag.G > 0) @@ -880,6 +875,21 @@ func (w *exportWriter) doTyp(t *types.Type) { return } + // The 't.Underlying() == t' check is to confirm this is a base typeparam + // type, rather than a defined type with typeparam underlying type, like: + // type orderedAbs[T any] T + if t.Kind() == types.TTYPEPARAM && t.Underlying() == t { + assert(base.Flag.G > 0) + if s.Pkg == types.BuiltinPkg || s.Pkg == ir.Pkgs.Unsafe { + base.Fatalf("builtin type missing from typIndex: %v", t) + } + // Write out the first use of a type param as a qualified ident. + // This will force a "declaration" of the type param. + w.startType(typeParamType) + w.qualifiedIdent(t.Obj().(*ir.Name)) + return + } + if s != nil { if s.Pkg == types.BuiltinPkg || s.Pkg == ir.Pkgs.Unsafe { base.Fatalf("builtin type missing from typIndex: %v", t) @@ -1325,14 +1335,20 @@ func (w *exportWriter) funcExt(n *ir.Name) { // Inline body. if n.Type().HasTParam() { if n.Func.Inl != nil { - base.FatalfAt(n.Pos(), "generic function is marked inlineable") - } - // Populate n.Func.Inl, so body of exported generic function will - // be written out. - n.Func.Inl = &ir.Inline{ - Cost: 1, - Dcl: n.Func.Dcl, - Body: n.Func.Body, + // n.Func.Inl may already be set on a generic function if + // we imported it from another package, but shouldn't be + // set for a generic function in the local package. + if n.Sym().Pkg == types.LocalPkg { + base.FatalfAt(n.Pos(), "generic function is marked inlineable") + } + } else { + // Populate n.Func.Inl, so body of exported generic function will + // be written out. + n.Func.Inl = &ir.Inline{ + Cost: 1, + Dcl: n.Func.Dcl, + Body: n.Func.Body, + } } } if n.Func.Inl != nil { diff --git a/src/cmd/compile/internal/typecheck/iimport.go b/src/cmd/compile/internal/typecheck/iimport.go index 96107b657b..9e6115cbf7 100644 --- a/src/cmd/compile/internal/typecheck/iimport.go +++ b/src/cmd/compile/internal/typecheck/iimport.go @@ -342,19 +342,22 @@ func (r *importReader) doDecl(sym *types.Sym) *ir.Name { // declaration before recursing. n := importtype(r.p.ipkg, pos, sym) t := n.Type() - - // We also need to defer width calculations until - // after the underlying type has been assigned. - types.DeferCheckSize() - underlying := r.typ() - t.SetUnderlying(underlying) - types.ResumeCheckSize() - if rparams != nil { t.SetRParams(rparams) } + // We also need to defer width calculations until + // after the underlying type has been assigned. + types.DeferCheckSize() + deferDoInst() + underlying := r.typ() + t.SetUnderlying(underlying) + if underlying.IsInterface() { + // Finish up all type instantiations and CheckSize calls + // now that a top-level type is fully constructed. + resumeDoInst() + types.ResumeCheckSize() r.typeExt(t) return n } @@ -380,12 +383,38 @@ func (r *importReader) doDecl(sym *types.Sym) *ir.Name { } t.Methods().Set(ms) + // Finish up all instantiations and CheckSize calls now + // that a top-level type is fully constructed. + resumeDoInst() + types.ResumeCheckSize() + r.typeExt(t) for _, m := range ms { r.methExt(m) } return n + case 'P': + if r.p.exportVersion < iexportVersionGenerics { + base.Fatalf("unexpected type param type") + } + if sym.Def != nil { + // Make sure we use the same type param type for the same + // name, whether it is created during types1-import or + // this types2-to-types1 translation. + return sym.Def.(*ir.Name) + } + index := int(r.int64()) + t := types.NewTypeParam(sym, index) + // Nname needed to save the pos. + nname := ir.NewDeclNameAt(pos, ir.OTYPE, sym) + sym.Def = nname + nname.SetType(t) + t.SetNod(nname) + + t.SetBound(r.typ()) + return nname + case 'V': typ := r.typ() @@ -545,7 +574,12 @@ func (r *importReader) pos() src.XPos { } func (r *importReader) typ() *types.Type { - return r.p.typAt(r.uint64()) + // If this is a top-level type call, defer type instantiations until the + // type is fully constructed. + deferDoInst() + t := r.p.typAt(r.uint64()) + resumeDoInst() + return t } func (r *importReader) exoticType() *types.Type { @@ -683,7 +717,13 @@ func (p *iimporter) typAt(off uint64) *types.Type { // are pushed to compile queue, then draining from the queue for compiling. // During this process, the size calculation is disabled, so it is not safe for // calculating size during SSA generation anymore. See issue #44732. - types.CheckSize(t) + // + // No need to calc sizes for re-instantiated generic types, and + // they are not necessarily resolved until the top-level type is + // defined (because of recursive types). + if t.OrigSym == nil || !t.HasTParam() { + types.CheckSize(t) + } p.typCache[off] = t } return t @@ -779,27 +819,18 @@ func (r *importReader) typ1() *types.Type { if r.p.exportVersion < iexportVersionGenerics { base.Fatalf("unexpected type param type") } - r.setPkg() - pos := r.pos() - name := r.string() - sym := r.currPkg.Lookup(name) - index := int(r.int64()) - bound := r.typ() - if sym.Def != nil { - // Make sure we use the same type param type for the same - // name, whether it is created during types1-import or - // this types2-to-types1 translation. - return sym.Def.Type() + // Similar to code for defined types, since we "declared" + // typeparams to deal with recursion (typeparam is used within its + // own type bound). + ident := r.qualifiedIdent() + if ident.Sym().Def != nil { + return ident.Sym().Def.(*ir.Name).Type() } - t := types.NewTypeParam(sym, index) - // Nname needed to save the pos. - nname := ir.NewDeclNameAt(pos, ir.OTYPE, sym) - sym.Def = nname - nname.SetType(t) - t.SetNod(nname) - - t.SetBound(bound) - return t + n := expandDecl(ident) + if n.Op() != ir.OTYPE { + base.Fatalf("expected OTYPE, got %v: %v, %v", n.Op(), n.Sym(), n) + } + return n.Type() case instType: if r.p.exportVersion < iexportVersionGenerics { @@ -1758,20 +1789,91 @@ func Instantiate(pos src.XPos, baseType *types.Type, targs []*types.Type) *types name := InstTypeName(baseSym.Name, targs) instSym := baseSym.Pkg.Lookup(name) if instSym.Def != nil { + // May match existing type from previous import or + // types2-to-types1 conversion, or from in-progress instantiation + // in the current type import stack. return instSym.Def.Type() } t := NewIncompleteNamedType(baseType.Pos(), instSym) t.SetRParams(targs) - // baseType may not yet be complete (since we are in the middle of - // importing it), but its underlying type will be updated when baseType's - // underlying type is finished. - t.SetUnderlying(baseType.Underlying()) - - // As with types2, the methods are the generic method signatures (without - // substitution). - t.Methods().Set(baseType.Methods().Slice()) t.OrigSym = baseSym + // baseType may still be TFORW or its methods may not be fully filled in + // (since we are in the middle of importing it). So, delay call to + // substInstType until we get back up to the top of the current top-most + // type import. + deferredInstStack = append(deferredInstStack, t) + return t } + +var deferredInstStack []*types.Type +var deferInst int + +// deferDoInst defers substitution on instantiated types until we are at the +// top-most defined type, so the base types are fully defined. +func deferDoInst() { + deferInst++ +} + +func resumeDoInst() { + if deferInst == 1 { + for len(deferredInstStack) > 0 { + t := deferredInstStack[0] + deferredInstStack = deferredInstStack[1:] + substInstType(t, t.OrigSym.Def.(*ir.Name).Type(), t.RParams()) + } + } + deferInst-- +} + +// doInst creates a new instantiation type (which will be added to +// deferredInstStack for completion later) for an incomplete type encountered +// during a type substitution for an instantiation. This is needed for +// instantiations of mutually recursive types. +func doInst(t *types.Type) *types.Type { + return Instantiate(t.Pos(), t.OrigSym.Def.(*ir.Name).Type(), t.RParams()) +} + +// substInstType completes the instantiation of a generic type by doing a +// substitution on the underlying type itself and any methods. t is the +// instantiation being created, baseType is the base generic type, and targs are +// the type arguments that baseType is being instantiated with. +func substInstType(t *types.Type, baseType *types.Type, targs []*types.Type) { + subst := Tsubster{ + Tparams: baseType.RParams(), + Targs: targs, + SubstForwFunc: doInst, + } + t.SetUnderlying(subst.Typ(baseType.Underlying())) + + newfields := make([]*types.Field, baseType.Methods().Len()) + for i, f := range baseType.Methods().Slice() { + recvType := f.Type.Recv().Type + if recvType.IsPtr() { + recvType = recvType.Elem() + } + // Substitute in the method using the type params used in the + // method (not the type params in the definition of the generic type). + subst := Tsubster{ + Tparams: recvType.RParams(), + Targs: targs, + SubstForwFunc: doInst, + } + t2 := subst.Typ(f.Type) + oldsym := f.Nname.Sym() + newsym := MakeInstName(oldsym, targs, true) + var nname *ir.Name + if newsym.Def != nil { + nname = newsym.Def.(*ir.Name) + } else { + nname = ir.NewNameAt(f.Pos, newsym) + nname.SetType(t2) + newsym.Def = nname + } + newfields[i] = types.NewField(f.Pos, f.Sym, t2) + newfields[i].Nname = nname + } + t.Methods().Set(newfields) +} diff --git a/src/cmd/compile/internal/typecheck/subr.go b/src/cmd/compile/internal/typecheck/subr.go index 3e7799b35b..8ef49f91c8 100644 --- a/src/cmd/compile/internal/typecheck/subr.go +++ b/src/cmd/compile/internal/typecheck/subr.go @@ -890,7 +890,6 @@ func TypesOf(x []ir.Node) []*types.Type { // based on the name of the function fnsym and the targs. It replaces any // existing bracket type list in the name. makeInstName asserts that fnsym has // brackets in its name if and only if hasBrackets is true. -// TODO(danscales): remove the assertions and the hasBrackets argument later. // // Names of declared generic functions have no brackets originally, so hasBrackets // should be false. Names of generic methods already have brackets, since the new @@ -902,12 +901,24 @@ func TypesOf(x []ir.Node) []*types.Type { func MakeInstName(fnsym *types.Sym, targs []*types.Type, hasBrackets bool) *types.Sym { b := bytes.NewBufferString("") - // marker to distinguish generic instantiations from fully stenciled wrapper functions. + // Determine if the type args are concrete types or new typeparams. + hasTParam := false + for _, targ := range targs { + if hasTParam { + assert(targ.HasTParam()) + } else if targ.HasTParam() { + hasTParam = true + } + } + + // Marker to distinguish generic instantiations from fully stenciled wrapper functions. // Once we move to GC shape implementations, this prefix will not be necessary as the // GC shape naming will distinguish them. // e.g. f[8bytenonpointer] vs. f[int]. // For now, we use .inst.f[int] vs. f[int]. - b.WriteString(".inst.") + if !hasTParam { + b.WriteString(".inst.") + } name := fnsym.Name i := strings.Index(name, "[") @@ -942,10 +953,325 @@ func MakeInstName(fnsym *types.Sym, targs []*types.Type, hasBrackets bool) *type return fnsym.Pkg.Lookup(b.String()) } -// For catching problems as we add more features -// TODO(danscales): remove assertions or replace with base.FatalfAt() func assert(p bool) { if !p { panic("assertion failed") } } + +// General type substituter, for replacing typeparams with type args. +type Tsubster struct { + Tparams []*types.Type + Targs []*types.Type + // If non-nil, the substitution map from name nodes in the generic function to the + // name nodes in the new stenciled function. + Vars map[*ir.Name]*ir.Name + // New fully-instantiated generic types whose methods should be instantiated. + InstTypeList []*types.Type + // If non-nil, function to substitute an incomplete (TFORW) type. + SubstForwFunc func(*types.Type) *types.Type +} + +// Typ computes the type obtained by substituting any type parameter in t with the +// corresponding type argument in subst. If t contains no type parameters, the +// result is t; otherwise the result is a new type. It deals with recursive types +// by using TFORW types and finding partially or fully created types via sym.Def. +func (ts *Tsubster) Typ(t *types.Type) *types.Type { + if !t.HasTParam() && t.Kind() != types.TFUNC { + // Note: function types need to be copied regardless, as the + // types of closures may contain declarations that need + // to be copied. See #45738. + return t + } + + if t.Kind() == types.TTYPEPARAM { + for i, tp := range ts.Tparams { + if tp == t { + return ts.Targs[i] + } + } + // If t is a simple typeparam T, then t has the name/symbol 'T' + // and t.Underlying() == t. + // + // However, consider the type definition: 'type P[T any] T'. We + // might use this definition so we can have a variant of type T + // that we can add new methods to. Suppose t is a reference to + // P[T]. t has the name 'P[T]', but its kind is TTYPEPARAM, + // because P[T] is defined as T. If we look at t.Underlying(), it + // is different, because the name of t.Underlying() is 'T' rather + // than 'P[T]'. But the kind of t.Underlying() is also TTYPEPARAM. + // In this case, we do the needed recursive substitution in the + // case statement below. + if t.Underlying() == t { + // t is a simple typeparam that didn't match anything in tparam + return t + } + // t is a more complex typeparam (e.g. P[T], as above, whose + // definition is just T). + assert(t.Sym() != nil) + } + + var newsym *types.Sym + var neededTargs []*types.Type + var forw *types.Type + + if t.Sym() != nil { + // Translate the type params for this type according to + // the tparam/targs mapping from subst. + neededTargs = make([]*types.Type, len(t.RParams())) + for i, rparam := range t.RParams() { + neededTargs[i] = ts.Typ(rparam) + } + // For a named (defined) type, we have to change the name of the + // type as well. We do this first, so we can look up if we've + // already seen this type during this substitution or other + // definitions/substitutions. + genName := genericTypeName(t.Sym()) + newsym = t.Sym().Pkg.Lookup(InstTypeName(genName, neededTargs)) + if newsym.Def != nil { + // We've already created this instantiated defined type. + return newsym.Def.Type() + } + + // In order to deal with recursive generic types, create a TFORW + // type initially and set the Def field of its sym, so it can be + // found if this type appears recursively within the type. + forw = NewIncompleteNamedType(t.Pos(), newsym) + //println("Creating new type by sub", newsym.Name, forw.HasTParam()) + forw.SetRParams(neededTargs) + // Copy the OrigSym from the re-instantiated type (which is the sym of + // the base generic type). + assert(t.OrigSym != nil) + forw.OrigSym = t.OrigSym + } + + var newt *types.Type + + switch t.Kind() { + case types.TTYPEPARAM: + if t.Sym() == newsym { + // The substitution did not change the type. + return t + } + // Substitute the underlying typeparam (e.g. T in P[T], see + // the example describing type P[T] above). + newt = ts.Typ(t.Underlying()) + assert(newt != t) + + case types.TARRAY: + elem := t.Elem() + newelem := ts.Typ(elem) + if newelem != elem { + newt = types.NewArray(newelem, t.NumElem()) + } + + case types.TPTR: + elem := t.Elem() + newelem := ts.Typ(elem) + if newelem != elem { + newt = types.NewPtr(newelem) + } + + case types.TSLICE: + elem := t.Elem() + newelem := ts.Typ(elem) + if newelem != elem { + newt = types.NewSlice(newelem) + } + + case types.TSTRUCT: + newt = ts.tstruct(t, false) + if newt == t { + newt = nil + } + + case types.TFUNC: + newrecvs := ts.tstruct(t.Recvs(), false) + newparams := ts.tstruct(t.Params(), false) + newresults := ts.tstruct(t.Results(), false) + if newrecvs != t.Recvs() || newparams != t.Params() || newresults != t.Results() { + // If any types have changed, then the all the fields of + // of recv, params, and results must be copied, because they have + // offset fields that are dependent, and so must have an + // independent copy for each new signature. + var newrecv *types.Field + if newrecvs.NumFields() > 0 { + if newrecvs == t.Recvs() { + newrecvs = ts.tstruct(t.Recvs(), true) + } + newrecv = newrecvs.Field(0) + } + if newparams == t.Params() { + newparams = ts.tstruct(t.Params(), true) + } + if newresults == t.Results() { + newresults = ts.tstruct(t.Results(), true) + } + newt = types.NewSignature(t.Pkg(), newrecv, t.TParams().FieldSlice(), newparams.FieldSlice(), newresults.FieldSlice()) + } + + case types.TINTER: + newt = ts.tinter(t) + if newt == t { + newt = nil + } + + case types.TMAP: + newkey := ts.Typ(t.Key()) + newval := ts.Typ(t.Elem()) + if newkey != t.Key() || newval != t.Elem() { + newt = types.NewMap(newkey, newval) + } + + case types.TCHAN: + elem := t.Elem() + newelem := ts.Typ(elem) + if newelem != elem { + newt = types.NewChan(newelem, t.ChanDir()) + if !newt.HasTParam() { + // TODO(danscales): not sure why I have to do this + // only for channels..... + types.CheckSize(newt) + } + } + case types.TFORW: + if ts.SubstForwFunc != nil { + newt = ts.SubstForwFunc(t) + } else { + assert(false) + } + } + if newt == nil { + // Even though there were typeparams in the type, there may be no + // change if this is a function type for a function call (which will + // have its own tparams/targs in the function instantiation). + return t + } + + if t.Sym() == nil { + // Not a named type, so there was no forwarding type and there are + // no methods to substitute. + assert(t.Methods().Len() == 0) + return newt + } + + forw.SetUnderlying(newt) + newt = forw + + if t.Kind() != types.TINTER && t.Methods().Len() > 0 { + // Fill in the method info for the new type. + var newfields []*types.Field + newfields = make([]*types.Field, t.Methods().Len()) + for i, f := range t.Methods().Slice() { + t2 := ts.Typ(f.Type) + oldsym := f.Nname.Sym() + newsym := MakeInstName(oldsym, ts.Targs, true) + var nname *ir.Name + if newsym.Def != nil { + nname = newsym.Def.(*ir.Name) + } else { + nname = ir.NewNameAt(f.Pos, newsym) + nname.SetType(t2) + newsym.Def = nname + } + newfields[i] = types.NewField(f.Pos, f.Sym, t2) + newfields[i].Nname = nname + } + newt.Methods().Set(newfields) + if !newt.HasTParam() { + // Generate all the methods for a new fully-instantiated type. + ts.InstTypeList = append(ts.InstTypeList, newt) + } + } + return newt +} + +// tstruct substitutes type params in types of the fields of a structure type. For +// each field, tstruct copies the Nname, and translates it if Nname is in +// ts.vars. To always force the creation of a new (top-level) struct, +// regardless of whether anything changed with the types or names of the struct's +// fields, set force to true. +func (ts *Tsubster) tstruct(t *types.Type, force bool) *types.Type { + if t.NumFields() == 0 { + if t.HasTParam() { + // For an empty struct, we need to return a new type, + // since it may now be fully instantiated (HasTParam + // becomes false). + return types.NewStruct(t.Pkg(), nil) + } + return t + } + var newfields []*types.Field + if force { + newfields = make([]*types.Field, t.NumFields()) + } + for i, f := range t.Fields().Slice() { + t2 := ts.Typ(f.Type) + if (t2 != f.Type || f.Nname != nil) && newfields == nil { + newfields = make([]*types.Field, t.NumFields()) + for j := 0; j < i; j++ { + newfields[j] = t.Field(j) + } + } + if newfields != nil { + // TODO(danscales): make sure this works for the field + // names of embedded types (which should keep the name of + // the type param, not the instantiated type). + newfields[i] = types.NewField(f.Pos, f.Sym, t2) + if f.Nname != nil && ts.Vars != nil { + v := ts.Vars[f.Nname.(*ir.Name)] + if v != nil { + // This is the case where we are + // translating the type of the function we + // are substituting, so its dcls are in + // the subst.ts.vars table, and we want to + // change to reference the new dcl. + newfields[i].Nname = v + } else { + // This is the case where we are + // translating the type of a function + // reference inside the function we are + // substituting, so we leave the Nname + // value as is. + newfields[i].Nname = f.Nname + } + } + } + } + if newfields != nil { + return types.NewStruct(t.Pkg(), newfields) + } + return t + +} + +// tinter substitutes type params in types of the methods of an interface type. +func (ts *Tsubster) tinter(t *types.Type) *types.Type { + if t.Methods().Len() == 0 { + return t + } + var newfields []*types.Field + for i, f := range t.Methods().Slice() { + t2 := ts.Typ(f.Type) + if (t2 != f.Type || f.Nname != nil) && newfields == nil { + newfields = make([]*types.Field, t.Methods().Len()) + for j := 0; j < i; j++ { + newfields[j] = t.Methods().Index(j) + } + } + if newfields != nil { + newfields[i] = types.NewField(f.Pos, f.Sym, t2) + } + } + if newfields != nil { + return types.NewInterface(t.Pkg(), newfields) + } + return t +} + +// genericSym returns the name of the base generic type for the type named by +// sym. It simply returns the name obtained by removing everything after the +// first bracket ("["). +func genericTypeName(sym *types.Sym) string { + return sym.Name[0:strings.Index(sym.Name, "[")] +} diff --git a/src/cmd/compile/internal/types2/builtins.go b/src/cmd/compile/internal/types2/builtins.go index 20c4ff62a1..8e13b0ff18 100644 --- a/src/cmd/compile/internal/types2/builtins.go +++ b/src/cmd/compile/internal/types2/builtins.go @@ -767,8 +767,10 @@ func (check *Checker) applyTypeFunc(f func(Type) Type, x Type) Type { // uses of real() where the result is used to // define type and initialize a variable? - // construct a suitable new type parameter - tpar := NewTypeName(nopos, nil /* = Universe pkg */, "", nil) + // Construct a suitable new type parameter for the sum type. The + // type param is placed in the current package so export/import + // works as expected. + tpar := NewTypeName(nopos, check.pkg, "", nil) ptyp := check.NewTypeParam(tpar, 0, &emptyInterface) // assigns type to tpar as a side-effect tsum := newUnion(rtypes, tildes) ptyp.bound = &Interface{allMethods: markComplete, allTypes: tsum} diff --git a/src/cmd/compile/internal/types2/type.go b/src/cmd/compile/internal/types2/type.go index 604520d27f..10cb651d0c 100644 --- a/src/cmd/compile/internal/types2/type.go +++ b/src/cmd/compile/internal/types2/type.go @@ -640,9 +640,8 @@ type TypeParam struct { // Obj returns the type name for the type parameter t. func (t *TypeParam) Obj() *TypeName { return t.obj } -// NewTypeParam returns a new TypeParam. +// NewTypeParam returns a new TypeParam. bound can be nil (and set later). func (check *Checker) NewTypeParam(obj *TypeName, index int, bound Type) *TypeParam { - assert(bound != nil) // Always increment lastID, even if it is not used. id := nextID() if check != nil { @@ -679,11 +678,20 @@ func (t *TypeParam) Bound() *Interface { return iface } -// optype returns a type's operational type. Except for type parameters, -// the operational type is the same as the underlying type (as returned -// by under). For Type parameters, the operational type is determined -// by the corresponding type constraint. The result may be the top type, -// but it is never the incoming type parameter. +func (t *TypeParam) SetBound(bound Type) { + if bound == nil { + panic("types2.TypeParam.SetBound: bound must not be nil") + } + t.bound = bound +} + +// optype returns a type's operational type. Except for +// type parameters, the operational type is the same +// as the underlying type (as returned by under). For +// Type parameters, the operational type is determined +// by the corresponding type bound's type list. The +// result may be the bottom or top type, but it is never +// the incoming type parameter. func optype(typ Type) Type { if t := asTypeParam(typ); t != nil { // If the optype is typ, return the top type as we have diff --git a/test/typeparam/absdiff.go b/test/typeparam/absdiff.go index ecaa907795..e76a998b4d 100644 --- a/test/typeparam/absdiff.go +++ b/test/typeparam/absdiff.go @@ -48,8 +48,7 @@ type Complex interface { type orderedAbs[T orderedNumeric] T func (a orderedAbs[T]) Abs() orderedAbs[T] { - // TODO(danscales): orderedAbs[T] conversion shouldn't be needed - if a < orderedAbs[T](0) { + if a < 0 { return -a } return a diff --git a/test/typeparam/absdiffimp.dir/a.go b/test/typeparam/absdiffimp.dir/a.go new file mode 100644 index 0000000000..df81dcf538 --- /dev/null +++ b/test/typeparam/absdiffimp.dir/a.go @@ -0,0 +1,75 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package a + +import ( + "math" +) + +type Numeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~complex64 | ~complex128 +} + +// numericAbs matches numeric types with an Abs method. +type numericAbs[T any] interface { + Numeric + Abs() T +} + +// AbsDifference computes the absolute value of the difference of +// a and b, where the absolute value is determined by the Abs method. +func absDifference[T numericAbs[T]](a, b T) T { + d := a - b + return d.Abs() +} + +// orderedNumeric matches numeric types that support the < operator. +type orderedNumeric interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 +} + +// Complex matches the two complex types, which do not have a < operator. +type Complex interface { + ~complex64 | ~complex128 +} + +// orderedAbs is a helper type that defines an Abs method for +// ordered numeric types. +type orderedAbs[T orderedNumeric] T + +func (a orderedAbs[T]) Abs() orderedAbs[T] { + if a < 0 { + return -a + } + return a +} + +// complexAbs is a helper type that defines an Abs method for +// complex types. +type complexAbs[T Complex] T + +func (a complexAbs[T]) Abs() complexAbs[T] { + r := float64(real(a)) + i := float64(imag(a)) + d := math.Sqrt(r * r + i * i) + return complexAbs[T](complex(d, 0)) +} + +// OrderedAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of an ordered type. +func OrderedAbsDifference[T orderedNumeric](a, b T) T { + return T(absDifference(orderedAbs[T](a), orderedAbs[T](b))) +} + +// ComplexAbsDifference returns the absolute value of the difference +// between a and b, where a and b are of a complex type. +func ComplexAbsDifference[T Complex](a, b T) T { + return T(absDifference(complexAbs[T](a), complexAbs[T](b))) +} diff --git a/test/typeparam/absdiffimp.dir/main.go b/test/typeparam/absdiffimp.dir/main.go new file mode 100644 index 0000000000..8eefdbdf38 --- /dev/null +++ b/test/typeparam/absdiffimp.dir/main.go @@ -0,0 +1,29 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "a" + "fmt" +) + +func main() { + if got, want := a.OrderedAbsDifference(1.0, -2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.OrderedAbsDifference(-1.0, 2.0), 3.0; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.OrderedAbsDifference(-20, 15), 35; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + + if got, want := a.ComplexAbsDifference(5.0+2.0i, 2.0-2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := a.ComplexAbsDifference(2.0-2.0i, 5.0+2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } +} diff --git a/test/typeparam/absdiffimp.go b/test/typeparam/absdiffimp.go new file mode 100644 index 0000000000..76930e5e4f --- /dev/null +++ b/test/typeparam/absdiffimp.go @@ -0,0 +1,7 @@ +// rundir -G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ignored diff --git a/test/typeparam/orderedmapsimp.dir/a.go b/test/typeparam/orderedmapsimp.dir/a.go new file mode 100644 index 0000000000..1b5827b4bb --- /dev/null +++ b/test/typeparam/orderedmapsimp.dir/a.go @@ -0,0 +1,226 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package a + +import ( + "context" + "runtime" +) + +type Ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | + ~float32 | ~float64 | + ~string +} + +// Map is an ordered map. +type Map[K, V any] struct { + root *node[K, V] + compare func(K, K) int +} + +// node is the type of a node in the binary tree. +type node[K, V any] struct { + key K + val V + left, right *node[K, V] +} + +// New returns a new map. It takes a comparison function that compares two +// keys and returns < 0 if the first is less, == 0 if they are equal, +// > 0 if the first is greater. +func New[K, V any](compare func(K, K) int) *Map[K, V] { + return &Map[K, V]{compare: compare} +} + +// NewOrdered returns a new map whose key is an ordered type. +// This is like New, but does not require providing a compare function. +// The map compare function uses the obvious key ordering. +func NewOrdered[K Ordered, V any]() *Map[K, V] { + return New[K, V](func(k1, k2 K) int { + switch { + case k1 < k2: + return -1 + case k1 > k2: + return 1 + default: + return 0 + } + }) +} + +// find looks up key in the map, returning either a pointer to the slot of the +// node holding key, or a pointer to the slot where a node would go. +func (m *Map[K, V]) find(key K) **node[K, V] { + pn := &m.root + for *pn != nil { + switch cmp := m.compare(key, (*pn).key); { + case cmp < 0: + pn = &(*pn).left + case cmp > 0: + pn = &(*pn).right + default: + return pn + } + } + return pn +} + +// Insert inserts a new key/value into the map. +// If the key is already present, the value is replaced. +// Reports whether this is a new key. +func (m *Map[K, V]) Insert(key K, val V) bool { + pn := m.find(key) + if *pn != nil { + (*pn).val = val + return false + } + *pn = &node[K, V]{key: key, val: val} + return true +} + +// Find returns the value associated with a key, or the zero value +// if not present. The second result reports whether the key was found. +func (m *Map[K, V]) Find(key K) (V, bool) { + pn := m.find(key) + if *pn == nil { + var zero V + return zero, false + } + return (*pn).val, true +} + +// keyValue is a pair of key and value used while iterating. +type keyValue[K, V any] struct { + key K + val V +} + +// iterate returns an iterator that traverses the map. +// func (m *Map[K, V]) Iterate() *Iterator[K, V] { +// sender, receiver := Ranger[keyValue[K, V]]() +// var f func(*node[K, V]) bool +// f = func(n *node[K, V]) bool { +// if n == nil { +// return true +// } +// // Stop the traversal if Send fails, which means that +// // nothing is listening to the receiver. +// return f(n.left) && +// sender.Send(context.Background(), keyValue[K, V]{n.key, n.val}) && +// f(n.right) +// } +// go func() { +// f(m.root) +// sender.Close() +// }() +// return &Iterator[K, V]{receiver} +// } + +// Iterator is used to iterate over the map. +type Iterator[K, V any] struct { + r *Receiver[keyValue[K, V]] +} + +// Next returns the next key and value pair, and a boolean that reports +// whether they are valid. If not valid, we have reached the end of the map. +func (it *Iterator[K, V]) Next() (K, V, bool) { + keyval, ok := it.r.Next(context.Background()) + if !ok { + var zerok K + var zerov V + return zerok, zerov, false + } + return keyval.key, keyval.val, true +} + +// Equal reports whether two slices are equal: the same length and all +// elements equal. All floating point NaNs are considered equal. +func SliceEqual[Elem comparable](s1, s2 []Elem) bool { + if len(s1) != len(s2) { + return false + } + for i, v1 := range s1 { + v2 := s2[i] + if v1 != v2 { + isNaN := func(f Elem) bool { return f != f } + if !isNaN(v1) || !isNaN(v2) { + return false + } + } + } + return true +} + +// Ranger returns a Sender and a Receiver. The Receiver provides a +// Next method to retrieve values. The Sender provides a Send method +// to send values and a Close method to stop sending values. The Next +// method indicates when the Sender has been closed, and the Send +// method indicates when the Receiver has been freed. +// +// This is a convenient way to exit a goroutine sending values when +// the receiver stops reading them. +func Ranger[Elem any]() (*Sender[Elem], *Receiver[Elem]) { + c := make(chan Elem) + d := make(chan struct{}) + s := &Sender[Elem]{ + values: c, + done: d, + } + r := &Receiver[Elem] { + values: c, + done: d, + } + runtime.SetFinalizer(r, (*Receiver[Elem]).finalize) + return s, r +} + +// A Sender is used to send values to a Receiver. +type Sender[Elem any] struct { + values chan<- Elem + done <-chan struct{} +} + +// Send sends a value to the receiver. It reports whether the value was sent. +// The value will not be sent if the context is closed or the receiver +// is freed. +func (s *Sender[Elem]) Send(ctx context.Context, v Elem) bool { + select { + case <-ctx.Done(): + return false + case s.values <- v: + return true + case <-s.done: + return false + } +} + +// Close tells the receiver that no more values will arrive. +// After Close is called, the Sender may no longer be used. +func (s *Sender[Elem]) Close() { + close(s.values) +} + +// A Receiver receives values from a Sender. +type Receiver[Elem any] struct { + values <-chan Elem + done chan<- struct{} +} + +// Next returns the next value from the channel. The bool result indicates +// whether the value is valid. +func (r *Receiver[Elem]) Next(ctx context.Context) (v Elem, ok bool) { + select { + case <-ctx.Done(): + case v, ok = <-r.values: + } + return v, ok +} + +// finalize is a finalizer for the receiver. +func (r *Receiver[Elem]) finalize() { + close(r.done) +} diff --git a/test/typeparam/orderedmapsimp.dir/main.go b/test/typeparam/orderedmapsimp.dir/main.go new file mode 100644 index 0000000000..77869ad9fc --- /dev/null +++ b/test/typeparam/orderedmapsimp.dir/main.go @@ -0,0 +1,67 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "a" + "bytes" + "fmt" +) + +func TestMap() { + m := a.New[[]byte, int](bytes.Compare) + + if _, found := m.Find([]byte("a")); found { + panic(fmt.Sprintf("unexpectedly found %q in empty map", []byte("a"))) + } + + for _, c := range []int{ 'a', 'c', 'b' } { + if !m.Insert([]byte(string(c)), c) { + panic(fmt.Sprintf("key %q unexpectedly already present", []byte(string(c)))) + } + } + if m.Insert([]byte("c"), 'x') { + panic(fmt.Sprintf("key %q unexpectedly not present", []byte("c"))) + } + + if v, found := m.Find([]byte("a")); !found { + panic(fmt.Sprintf("did not find %q", []byte("a"))) + } else if v != 'a' { + panic(fmt.Sprintf("key %q returned wrong value %c, expected %c", []byte("a"), v, 'a')) + } + if v, found := m.Find([]byte("c")); !found { + panic(fmt.Sprintf("did not find %q", []byte("c"))) + } else if v != 'x' { + panic(fmt.Sprintf("key %q returned wrong value %c, expected %c", []byte("c"), v, 'x')) + } + + if _, found := m.Find([]byte("d")); found { + panic(fmt.Sprintf("unexpectedly found %q", []byte("d"))) + } + + // TODO(danscales): Iterate() has some things to be fixed with inlining in + // stenciled functions and using closures across packages. + + // gather := func(it *a.Iterator[[]byte, int]) []int { + // var r []int + // for { + // _, v, ok := it.Next() + // if !ok { + // return r + // } + // r = append(r, v) + // } + // } + // got := gather(m.Iterate()) + // want := []int{'a', 'b', 'x'} + // if !a.SliceEqual(got, want) { + // panic(fmt.Sprintf("Iterate returned %v, want %v", got, want)) + // } + +} + +func main() { + TestMap() +} diff --git a/test/typeparam/orderedmapsimp.go b/test/typeparam/orderedmapsimp.go new file mode 100644 index 0000000000..76930e5e4f --- /dev/null +++ b/test/typeparam/orderedmapsimp.go @@ -0,0 +1,7 @@ +// rundir -G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ignored diff --git a/test/typeparam/value.go b/test/typeparam/value.go index 5dd7449d9c..6c6dabcf7c 100644 --- a/test/typeparam/value.go +++ b/test/typeparam/value.go @@ -12,7 +12,7 @@ type value[T any] struct { val T } -func get[T2 any](v *value[T2]) T2 { +func get[T any](v *value[T]) T { return v.val } @@ -20,11 +20,11 @@ func set[T any](v *value[T], val T) { v.val = val } -func (v *value[T2]) set(val T2) { +func (v *value[T]) set(val T) { v.val = val } -func (v *value[T2]) get() T2 { +func (v *value[T]) get() T { return v.val }