From 20bdf8d8c577e6fc0ba1837ff190e87daf8ecba4 Mon Sep 17 00:00:00 2001 From: Aske Simon Christensen Date: Thu, 13 Oct 2022 21:04:17 +0000 Subject: [PATCH] [dart2wasm] Runtime generic function instantiation Adds a slot to the vtable of all generic closures with a function taking a generic closure and type arguments and returning a closure representing the instantiation of the generic closure with those type arguments. Instantiation operations then call this instantiation function, fetched from the vtable of the closure to be instantiated. The context of an instantiation closure contains the original closure and the type arguments. The vtable of an instantiation closure contains trampoline functions that call the corresponding entry point in the original closure with the instantiated type arguments. The instantiation functions are shared between all closures with the same representation. The trampolines are shared across representations for the same vtable entries. For now, the instantiation closure just inherits the runtime type from the original closure, which means that its runtime type will be incorrect. When we support generic function types with runtime type substitution, we can perform such substitution when instantiating a closure. Change-Id: I5d3a4d623c0673f9c2188f8a7ddd5b28b9404ac4 Cq-Include-Trybots: luci.dart.try:dart2wasm-linux-x64-d8-try Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/262201 Reviewed-by: Joshua Litt Commit-Queue: Aske Simon Christensen --- pkg/dart2wasm/lib/class_info.dart | 3 + pkg/dart2wasm/lib/closures.dart | 283 ++++++++++++++++++++++++-- pkg/dart2wasm/lib/code_generator.dart | 44 +++- pkg/dart2wasm/lib/constants.dart | 3 + pkg/dart2wasm/lib/translator.dart | 3 + 5 files changed, 319 insertions(+), 17 deletions(-) diff --git a/pkg/dart2wasm/lib/class_info.dart b/pkg/dart2wasm/lib/class_info.dart index 7cf72dd8fd5..535d0800ff4 100644 --- a/pkg/dart2wasm/lib/class_info.dart +++ b/pkg/dart2wasm/lib/class_info.dart @@ -26,6 +26,9 @@ class FieldIndex { static const closureContext = 2; static const closureVtable = 3; static const closureRuntimeType = 4; + static const vtableInstantiationFunction = 0; + static const instantiationContextInner = 0; + static const instantiationContextTypeArgumentsBase = 1; static const typeIsNullable = 2; static const interfaceTypeTypeArguments = 4; static const functionTypeNamedParameters = 6; diff --git a/pkg/dart2wasm/lib/closures.dart b/pkg/dart2wasm/lib/closures.dart index 30868e07ce6..83d1f518c2a 100644 --- a/pkg/dart2wasm/lib/closures.dart +++ b/pkg/dart2wasm/lib/closures.dart @@ -6,6 +6,7 @@ import 'dart:collection'; import 'dart:math' show min; import 'package:dart2wasm/code_generator.dart'; +import 'package:dart2wasm/class_info.dart'; import 'package:dart2wasm/translator.dart'; import 'package:kernel/ast.dart'; @@ -40,7 +41,7 @@ class ClosureImplementation { class ClosureRepresentation { /// The struct field index in the vtable struct at which the function /// entries start. - final int vtableBaseIndex; + final int typeCount; /// The Wasm struct type for the vtable. final w.StructType vtableStruct; @@ -50,8 +51,33 @@ class ClosureRepresentation { final Map? _indexOfCombination; - ClosureRepresentation(this.vtableBaseIndex, this.vtableStruct, - this.closureStruct, this._indexOfCombination); + /// The struct type for the context of an instantiated closure. + final w.StructType? instantiationContextStruct; + + /// Entry point functions for instantiations of this generic closure. + late final List instantiationTrampolines = + _instantiationTrampolinesThunk!(); + List Function()? _instantiationTrampolinesThunk; + + /// The function that instantiates this generic closure. + late final w.DefinedFunction instantiationFunction = + _instantiationFunctionThunk!(); + w.DefinedFunction Function()? _instantiationFunctionThunk; + + /// The signature of the function that instantiates this generic closure. + w.FunctionType get instantiationFunctionType => + (vtableStruct.fields[FieldIndex.vtableInstantiationFunction].type + as w.RefType) + .heapType as w.FunctionType; + + ClosureRepresentation(this.typeCount, this.vtableStruct, this.closureStruct, + this._indexOfCombination, this.instantiationContextStruct); + + bool get isGeneric => typeCount > 0; + + int get vtableBaseIndex => isGeneric + ? ClosureLayouter.vtableBaseIndexGeneric + : ClosureLayouter.vtableBaseIndexNonGeneric; /// The field index in the vtable struct for the function entry to use when /// calling the closure with the given number of positional arguments and the @@ -114,6 +140,9 @@ class ClosureLayouter extends RecursiveVisitor { Set visitedConstants = Set.identity(); + static const int vtableBaseIndexNonGeneric = 0; + static const int vtableBaseIndexGeneric = 1; + // Base struct for vtables. late final w.StructType vtableBaseStruct = m.addStructType("#VtableBase"); @@ -189,9 +218,8 @@ class ClosureLayouter extends RecursiveVisitor { final representations = _representationsForCounts(typeCount, positionalCount); if (representations.withoutNamed == null) { - ClosureRepresentation parent = positionalCount == 0 - ? ClosureRepresentation(vtableBaseStruct.fields.length, - vtableBaseStruct, closureBaseStruct, null) + ClosureRepresentation? parent = positionalCount == 0 + ? null : getClosureRepresentation(typeCount, positionalCount - 1, const [])!; representations.withoutNamed = _createRepresentation(typeCount, positionalCount, const [], parent, null, [positionalCount]); @@ -216,14 +244,57 @@ class ClosureLayouter extends RecursiveVisitor { int typeCount, int positionalCount, List names, - ClosureRepresentation parent, + ClosureRepresentation? parent, Map? indexOfCombination, Iterable paramCounts) { List nameTags = ["$typeCount", "$positionalCount", ...names]; String vtableName = ["#Vtable", ...nameTags].join("-"); String closureName = ["#Closure", ...nameTags].join("-"); w.StructType vtableStruct = m.addStructType(vtableName, - fields: parent.vtableStruct.fields, superType: parent.vtableStruct); + fields: [...?parent?.vtableStruct.fields], + superType: parent?.vtableStruct ?? vtableBaseStruct); + w.StructType closureStruct = _makeClosureStruct( + closureName, vtableStruct, parent?.closureStruct ?? closureBaseStruct); + + ClosureRepresentation? instantiatedRepresentation; + w.StructType? instantiationContextStruct; + if (typeCount > 0) { + // Add or set vtable field for the instantiation function. + instantiatedRepresentation = + getClosureRepresentation(0, positionalCount, names)!; + w.RefType inputType = w.RefType.def(closureBaseStruct, nullable: false); + w.RefType outputType = w.RefType.def( + instantiatedRepresentation.closureStruct, + nullable: false); + w.FunctionType instantiationFunctionType = m.addFunctionType( + [inputType, ...List.filled(typeCount, typeType)], [outputType], + superType: parent?.instantiationFunctionType); + w.FieldType functionFieldType = w.FieldType( + w.RefType.def(instantiationFunctionType, nullable: false), + mutable: false); + if (parent == null) { + assert(vtableStruct.fields.length == + FieldIndex.vtableInstantiationFunction); + vtableStruct.fields.add(functionFieldType); + } else { + vtableStruct.fields[FieldIndex.vtableInstantiationFunction] = + functionFieldType; + } + + // Build layout for the context of instantiated closures, containing the + // original closure plus the type arguments. + String instantiationContextName = + ["#InstantiationContext", ...nameTags].join("-"); + instantiationContextStruct = m.addStructType(instantiationContextName, + fields: [ + w.FieldType(w.RefType.def(closureStruct, nullable: false), + mutable: false), + ...List.filled(typeCount, w.FieldType(typeType, mutable: false)) + ], + superType: parent?.instantiationContextStruct); + } + + // Add vtable fields for additional entry points relative to the parent. for (int paramCount in paramCounts) { w.FunctionType entry = m.addFunctionType([ w.RefType.data(nullable: false), @@ -235,10 +306,198 @@ class ClosureLayouter extends RecursiveVisitor { vtableStruct.fields.add( w.FieldType(w.RefType.def(entry, nullable: false), mutable: false)); } - w.StructType closureStruct = - _makeClosureStruct(closureName, vtableStruct, parent.closureStruct); - return ClosureRepresentation(vtableBaseStruct.fields.length, vtableStruct, - closureStruct, indexOfCombination); + + ClosureRepresentation representation = ClosureRepresentation( + typeCount, + vtableStruct, + closureStruct, + indexOfCombination, + instantiationContextStruct); + + if (typeCount > 0) { + // The instantiation trampolines and the instantiation function can't be + // produced now, since we might not have added the module imports yet, and + // we can't define any functions before we have added the imports. + // Therefore, we set thunks in the representation which will be called + // when the instantiation function is needed, which will be during code + // generation, after the imports have been added. + + representation._instantiationTrampolinesThunk = () { + List instantiationTrampolines = [ + ...?parent?.instantiationTrampolines + ]; + if (names.isEmpty) { + // Add trampoline to the corresponding entry in the generic closure. + w.DefinedFunction trampoline = _createInstantiationTrampoline( + typeCount, + closureStruct, + instantiationContextStruct!, + instantiatedRepresentation!.vtableStruct, + vtableBaseIndexNonGeneric + instantiationTrampolines.length, + vtableStruct, + vtableBaseIndexGeneric + instantiationTrampolines.length); + instantiationTrampolines.add(trampoline); + } else { + // For each name combination in the instantiated closure, add a + // trampoline to the entry for the same name combination in the + // generic closure, or a dummy entry if the generic closure does not + // have that name combination. + for (NameCombination combination + in instantiatedRepresentation!._indexOfCombination!.keys) { + int? genericIndex = indexOfCombination![combination]; + w.DefinedFunction trampoline = genericIndex != null + ? _createInstantiationTrampoline( + typeCount, + closureStruct, + instantiationContextStruct!, + instantiatedRepresentation.vtableStruct, + vtableBaseIndexNonGeneric + instantiationTrampolines.length, + vtableStruct, + vtableBaseIndexGeneric + + (positionalCount + 1) + + genericIndex) + : translator.globals.getDummyFunction( + (instantiatedRepresentation + .vtableStruct + .fields[vtableBaseIndexNonGeneric + + instantiationTrampolines.length] + .type as w.RefType) + .heapType as w.FunctionType); + instantiationTrampolines.add(trampoline); + } + } + return instantiationTrampolines; + }; + + representation._instantiationFunctionThunk = () { + String instantiationFunctionName = + ["#Instantiation", ...nameTags].join("-"); + return _createInstantiationFunction( + typeCount, + instantiatedRepresentation!, + representation.instantiationTrampolines, + representation.instantiationFunctionType, + instantiationContextStruct!, + closureStruct, + instantiationFunctionName); + }; + } + + return representation; + } + + w.DefinedFunction _createInstantiationTrampoline( + int typeCount, + w.StructType genericClosureStruct, + w.StructType contextStruct, + w.StructType instantiatedVtableStruct, + int instantiatedVtableFieldIndex, + w.StructType genericVtableStruct, + int genericVtableFieldIndex) { + assert(contextStruct.fields.length == 1 + typeCount); + w.FunctionType instantiatedFunctionType = (instantiatedVtableStruct + .fields[instantiatedVtableFieldIndex].type as w.RefType) + .heapType as w.FunctionType; + w.FunctionType genericFunctionType = + (genericVtableStruct.fields[genericVtableFieldIndex].type as w.RefType) + .heapType as w.FunctionType; + assert(genericFunctionType.inputs.length == + instantiatedFunctionType.inputs.length + typeCount); + + w.DefinedFunction trampoline = m.addFunction(instantiatedFunctionType); + w.Instructions b = trampoline.body; + + // Cast context reference to actual context type. + w.RefType contextType = w.RefType.def(contextStruct, nullable: false); + w.Local contextLocal = trampoline.addLocal(contextType); + b.local_get(trampoline.locals[0]); + b.ref_cast(contextStruct); + b.local_tee(contextLocal); + + // Push inner context + b.struct_get(contextStruct, FieldIndex.instantiationContextInner); + b.struct_get(genericClosureStruct, FieldIndex.closureContext); + + // Push type arguments + for (int t = 0; t < typeCount; t++) { + b.local_get(contextLocal); + b.struct_get( + contextStruct, FieldIndex.instantiationContextTypeArgumentsBase + t); + } + + // Push arguments + for (int p = 1; p < instantiatedFunctionType.inputs.length; p++) { + b.local_get(trampoline.locals[p]); + } + + // Call inner + b.local_get(contextLocal); + b.struct_get(contextStruct, FieldIndex.instantiationContextInner); + b.struct_get(genericClosureStruct, FieldIndex.closureVtable); + b.struct_get(genericVtableStruct, genericVtableFieldIndex); + b.call_ref(); + b.end(); + + return trampoline; + } + + w.DefinedFunction _createInstantiationFunction( + int typeCount, + ClosureRepresentation instantiatedRepresentation, + List instantiationTrampolines, + w.FunctionType functionType, + w.StructType contextStruct, + w.StructType genericClosureStruct, + String name) { + assert(typeCount > 0); + w.RefType genericClosureType = + w.RefType.def(genericClosureStruct, nullable: false); + w.RefType instantiatedClosureType = w.RefType.def( + instantiatedRepresentation.closureStruct, + nullable: false); + assert(functionType.outputs.single == instantiatedClosureType); + + // Create vtable for the instantiated closure, containing the trampolines. + w.DefinedGlobal vtable = m.addGlobal(w.GlobalType( + w.RefType.def(instantiatedRepresentation.vtableStruct, nullable: false), + mutable: false)); + w.Instructions ib = vtable.initializer; + for (w.DefinedFunction trampoline in instantiationTrampolines) { + ib.ref_func(trampoline); + } + ib.struct_new(instantiatedRepresentation.vtableStruct); + ib.end(); + + ClassInfo info = translator.classInfo[translator.functionClass]!; + + w.DefinedFunction instantiationFunction = m.addFunction(functionType, name); + w.Local preciseClosure = instantiationFunction.addLocal(genericClosureType); + w.Instructions b = instantiationFunction.body; + + // Header for the closure struct + b.i32_const(info.classId); + b.i32_const(initialIdentityHash); + + // Context for the instantiated closure, containing the original closure and + // the type arguments + b.local_get(instantiationFunction.locals[0]); + b.ref_cast(genericClosureStruct); + b.local_tee(preciseClosure); + for (int i = 0; i < typeCount; i++) { + b.local_get(instantiationFunction.locals[1 + i]); + } + b.struct_new(contextStruct); + + // The rest of the closure struct + b.global_get(vtable); + // TODO(askesc): Substitute type arguments into type + b.local_get(preciseClosure); + b.struct_get(genericClosureStruct, FieldIndex.closureRuntimeType); + b.struct_new(instantiatedRepresentation.closureStruct); + + b.end(); + + return instantiationFunction; } ClosureRepresentationsForParameterCount _representationsForCounts( diff --git a/pkg/dart2wasm/lib/code_generator.dart b/pkg/dart2wasm/lib/code_generator.dart index 1226a29e225..110063e5035 100644 --- a/pkg/dart2wasm/lib/code_generator.dart +++ b/pkg/dart2wasm/lib/code_generator.dart @@ -2150,6 +2150,45 @@ class CodeGenerator extends ExpressionVisitor1 return translator.outputOrVoid(lambda.function.type.outputs); } + @override + w.ValueType visitInstantiation(Instantiation node, w.ValueType expectedType) { + DartType type = dartTypeOf(node.expression); + if (type is FunctionType) { + int typeCount = type.typeParameters.length; + int posArgCount = type.positionalParameters.length; + List argNames = type.namedParameters.map((a) => a.name).toList(); + ClosureRepresentation representation = translator.closureLayouter + .getClosureRepresentation(typeCount, posArgCount, argNames)!; + + // Operand closure + w.RefType closureType = + w.RefType.def(representation.closureStruct, nullable: false); + w.Local closureTemp = addLocal(closureType); + wrap(node.expression, closureType); + b.local_tee(closureTemp); + + // Type arguments + for (DartType typeArg in node.typeArguments) { + types.makeType(this, typeArg); + } + + // Instantiation function + b.local_get(closureTemp); + b.struct_get(representation.closureStruct, FieldIndex.closureVtable); + b.struct_get( + representation.vtableStruct, FieldIndex.vtableInstantiationFunction); + + // Call instantiation function + b.call_ref(); + return representation.instantiationFunctionType.outputs.single; + } else { + // Only other alternative is `NeverType`. + assert(type is NeverType); + b.unreachable(); + return voidMarker; + } + } + @override w.ValueType visitLogicalExpression( LogicalExpression node, w.ValueType expectedType) { @@ -2274,11 +2313,6 @@ class CodeGenerator extends ExpressionVisitor1 return expectedType; } - @override - w.ValueType visitInstantiation(Instantiation node, w.ValueType expectedType) { - throw "Not supported: Generic function instantiation at ${node.location}"; - } - @override w.ValueType visitConstantExpression( ConstantExpression node, w.ValueType expectedType) { diff --git a/pkg/dart2wasm/lib/constants.dart b/pkg/dart2wasm/lib/constants.dart index 4ba4963f949..c71b9b01e7a 100644 --- a/pkg/dart2wasm/lib/constants.dart +++ b/pkg/dart2wasm/lib/constants.dart @@ -694,6 +694,9 @@ class ConstantCreator extends ConstantVisitor { } void makeVtable() { + if (representation.isGeneric) { + b.ref_func(representation.instantiationFunction); + } for (int posArgCount = 0; posArgCount <= positionalCount; posArgCount++) { diff --git a/pkg/dart2wasm/lib/translator.dart b/pkg/dart2wasm/lib/translator.dart index 86c9cb506c5..036f5600d17 100644 --- a/pkg/dart2wasm/lib/translator.dart +++ b/pkg/dart2wasm/lib/translator.dart @@ -803,6 +803,9 @@ class Translator { w.RefType.def(representation.vtableStruct, nullable: false), mutable: false)); w.Instructions ib = vtable.initializer; + if (representation.isGeneric) { + ib.ref_func(representation.instantiationFunction); + } for (int posArgCount = 0; posArgCount <= positionalCount; posArgCount++) { fillVtableEntry(ib, posArgCount, const []); }