[dart2wasm] Runtime generic function instantiation

Adds a slot to the vtable of all generic closures with a function
taking a generic closure and type arguments and returning a closure
representing the instantiation of the generic closure with those type
arguments. Instantiation operations then call this instantiation
function, fetched from the vtable of the closure to be instantiated.

The context of an instantiation closure contains the original closure
and the type arguments. The vtable of an instantiation closure
contains trampoline functions that call the corresponding entry point
in the original closure with the instantiated type arguments.

The instantiation functions are shared between all closures with the
same representation. The trampolines are shared across representations
for the same vtable entries.

For now, the instantiation closure just inherits the runtime type from
the original closure, which means that its runtime type will be
incorrect. When we support generic function types with runtime type
substitution, we can perform such substitution when instantiating
a closure.

Change-Id: I5d3a4d623c0673f9c2188f8a7ddd5b28b9404ac4
Cq-Include-Trybots: luci.dart.try:dart2wasm-linux-x64-d8-try
Reviewed-on: https://dart-review.googlesource.com/c/sdk/+/262201
Reviewed-by: Joshua Litt <joshualitt@google.com>
Commit-Queue: Aske Simon Christensen <askesc@google.com>
This commit is contained in:
Aske Simon Christensen 2022-10-13 21:04:17 +00:00 committed by Commit Queue
parent 6f8d1e4859
commit 20bdf8d8c5
5 changed files with 319 additions and 17 deletions

View file

@ -26,6 +26,9 @@ class FieldIndex {
static const closureContext = 2;
static const closureVtable = 3;
static const closureRuntimeType = 4;
static const vtableInstantiationFunction = 0;
static const instantiationContextInner = 0;
static const instantiationContextTypeArgumentsBase = 1;
static const typeIsNullable = 2;
static const interfaceTypeTypeArguments = 4;
static const functionTypeNamedParameters = 6;

View file

@ -6,6 +6,7 @@ import 'dart:collection';
import 'dart:math' show min;
import 'package:dart2wasm/code_generator.dart';
import 'package:dart2wasm/class_info.dart';
import 'package:dart2wasm/translator.dart';
import 'package:kernel/ast.dart';
@ -40,7 +41,7 @@ class ClosureImplementation {
class ClosureRepresentation {
/// The struct field index in the vtable struct at which the function
/// entries start.
final int vtableBaseIndex;
final int typeCount;
/// The Wasm struct type for the vtable.
final w.StructType vtableStruct;
@ -50,8 +51,33 @@ class ClosureRepresentation {
final Map<NameCombination, int>? _indexOfCombination;
ClosureRepresentation(this.vtableBaseIndex, this.vtableStruct,
this.closureStruct, this._indexOfCombination);
/// The struct type for the context of an instantiated closure.
final w.StructType? instantiationContextStruct;
/// Entry point functions for instantiations of this generic closure.
late final List<w.DefinedFunction> instantiationTrampolines =
_instantiationTrampolinesThunk!();
List<w.DefinedFunction> Function()? _instantiationTrampolinesThunk;
/// The function that instantiates this generic closure.
late final w.DefinedFunction instantiationFunction =
_instantiationFunctionThunk!();
w.DefinedFunction Function()? _instantiationFunctionThunk;
/// The signature of the function that instantiates this generic closure.
w.FunctionType get instantiationFunctionType =>
(vtableStruct.fields[FieldIndex.vtableInstantiationFunction].type
as w.RefType)
.heapType as w.FunctionType;
ClosureRepresentation(this.typeCount, this.vtableStruct, this.closureStruct,
this._indexOfCombination, this.instantiationContextStruct);
bool get isGeneric => typeCount > 0;
int get vtableBaseIndex => isGeneric
? ClosureLayouter.vtableBaseIndexGeneric
: ClosureLayouter.vtableBaseIndexNonGeneric;
/// The field index in the vtable struct for the function entry to use when
/// calling the closure with the given number of positional arguments and the
@ -114,6 +140,9 @@ class ClosureLayouter extends RecursiveVisitor {
Set<Constant> visitedConstants = Set.identity();
static const int vtableBaseIndexNonGeneric = 0;
static const int vtableBaseIndexGeneric = 1;
// Base struct for vtables.
late final w.StructType vtableBaseStruct = m.addStructType("#VtableBase");
@ -189,9 +218,8 @@ class ClosureLayouter extends RecursiveVisitor {
final representations =
_representationsForCounts(typeCount, positionalCount);
if (representations.withoutNamed == null) {
ClosureRepresentation parent = positionalCount == 0
? ClosureRepresentation(vtableBaseStruct.fields.length,
vtableBaseStruct, closureBaseStruct, null)
ClosureRepresentation? parent = positionalCount == 0
? null
: getClosureRepresentation(typeCount, positionalCount - 1, const [])!;
representations.withoutNamed = _createRepresentation(typeCount,
positionalCount, const [], parent, null, [positionalCount]);
@ -216,14 +244,57 @@ class ClosureLayouter extends RecursiveVisitor {
int typeCount,
int positionalCount,
List<String> names,
ClosureRepresentation parent,
ClosureRepresentation? parent,
Map<NameCombination, int>? indexOfCombination,
Iterable<int> paramCounts) {
List<String> nameTags = ["$typeCount", "$positionalCount", ...names];
String vtableName = ["#Vtable", ...nameTags].join("-");
String closureName = ["#Closure", ...nameTags].join("-");
w.StructType vtableStruct = m.addStructType(vtableName,
fields: parent.vtableStruct.fields, superType: parent.vtableStruct);
fields: [...?parent?.vtableStruct.fields],
superType: parent?.vtableStruct ?? vtableBaseStruct);
w.StructType closureStruct = _makeClosureStruct(
closureName, vtableStruct, parent?.closureStruct ?? closureBaseStruct);
ClosureRepresentation? instantiatedRepresentation;
w.StructType? instantiationContextStruct;
if (typeCount > 0) {
// Add or set vtable field for the instantiation function.
instantiatedRepresentation =
getClosureRepresentation(0, positionalCount, names)!;
w.RefType inputType = w.RefType.def(closureBaseStruct, nullable: false);
w.RefType outputType = w.RefType.def(
instantiatedRepresentation.closureStruct,
nullable: false);
w.FunctionType instantiationFunctionType = m.addFunctionType(
[inputType, ...List.filled(typeCount, typeType)], [outputType],
superType: parent?.instantiationFunctionType);
w.FieldType functionFieldType = w.FieldType(
w.RefType.def(instantiationFunctionType, nullable: false),
mutable: false);
if (parent == null) {
assert(vtableStruct.fields.length ==
FieldIndex.vtableInstantiationFunction);
vtableStruct.fields.add(functionFieldType);
} else {
vtableStruct.fields[FieldIndex.vtableInstantiationFunction] =
functionFieldType;
}
// Build layout for the context of instantiated closures, containing the
// original closure plus the type arguments.
String instantiationContextName =
["#InstantiationContext", ...nameTags].join("-");
instantiationContextStruct = m.addStructType(instantiationContextName,
fields: [
w.FieldType(w.RefType.def(closureStruct, nullable: false),
mutable: false),
...List.filled(typeCount, w.FieldType(typeType, mutable: false))
],
superType: parent?.instantiationContextStruct);
}
// Add vtable fields for additional entry points relative to the parent.
for (int paramCount in paramCounts) {
w.FunctionType entry = m.addFunctionType([
w.RefType.data(nullable: false),
@ -235,10 +306,198 @@ class ClosureLayouter extends RecursiveVisitor {
vtableStruct.fields.add(
w.FieldType(w.RefType.def(entry, nullable: false), mutable: false));
}
w.StructType closureStruct =
_makeClosureStruct(closureName, vtableStruct, parent.closureStruct);
return ClosureRepresentation(vtableBaseStruct.fields.length, vtableStruct,
closureStruct, indexOfCombination);
ClosureRepresentation representation = ClosureRepresentation(
typeCount,
vtableStruct,
closureStruct,
indexOfCombination,
instantiationContextStruct);
if (typeCount > 0) {
// The instantiation trampolines and the instantiation function can't be
// produced now, since we might not have added the module imports yet, and
// we can't define any functions before we have added the imports.
// Therefore, we set thunks in the representation which will be called
// when the instantiation function is needed, which will be during code
// generation, after the imports have been added.
representation._instantiationTrampolinesThunk = () {
List<w.DefinedFunction> instantiationTrampolines = [
...?parent?.instantiationTrampolines
];
if (names.isEmpty) {
// Add trampoline to the corresponding entry in the generic closure.
w.DefinedFunction trampoline = _createInstantiationTrampoline(
typeCount,
closureStruct,
instantiationContextStruct!,
instantiatedRepresentation!.vtableStruct,
vtableBaseIndexNonGeneric + instantiationTrampolines.length,
vtableStruct,
vtableBaseIndexGeneric + instantiationTrampolines.length);
instantiationTrampolines.add(trampoline);
} else {
// For each name combination in the instantiated closure, add a
// trampoline to the entry for the same name combination in the
// generic closure, or a dummy entry if the generic closure does not
// have that name combination.
for (NameCombination combination
in instantiatedRepresentation!._indexOfCombination!.keys) {
int? genericIndex = indexOfCombination![combination];
w.DefinedFunction trampoline = genericIndex != null
? _createInstantiationTrampoline(
typeCount,
closureStruct,
instantiationContextStruct!,
instantiatedRepresentation.vtableStruct,
vtableBaseIndexNonGeneric + instantiationTrampolines.length,
vtableStruct,
vtableBaseIndexGeneric +
(positionalCount + 1) +
genericIndex)
: translator.globals.getDummyFunction(
(instantiatedRepresentation
.vtableStruct
.fields[vtableBaseIndexNonGeneric +
instantiationTrampolines.length]
.type as w.RefType)
.heapType as w.FunctionType);
instantiationTrampolines.add(trampoline);
}
}
return instantiationTrampolines;
};
representation._instantiationFunctionThunk = () {
String instantiationFunctionName =
["#Instantiation", ...nameTags].join("-");
return _createInstantiationFunction(
typeCount,
instantiatedRepresentation!,
representation.instantiationTrampolines,
representation.instantiationFunctionType,
instantiationContextStruct!,
closureStruct,
instantiationFunctionName);
};
}
return representation;
}
w.DefinedFunction _createInstantiationTrampoline(
int typeCount,
w.StructType genericClosureStruct,
w.StructType contextStruct,
w.StructType instantiatedVtableStruct,
int instantiatedVtableFieldIndex,
w.StructType genericVtableStruct,
int genericVtableFieldIndex) {
assert(contextStruct.fields.length == 1 + typeCount);
w.FunctionType instantiatedFunctionType = (instantiatedVtableStruct
.fields[instantiatedVtableFieldIndex].type as w.RefType)
.heapType as w.FunctionType;
w.FunctionType genericFunctionType =
(genericVtableStruct.fields[genericVtableFieldIndex].type as w.RefType)
.heapType as w.FunctionType;
assert(genericFunctionType.inputs.length ==
instantiatedFunctionType.inputs.length + typeCount);
w.DefinedFunction trampoline = m.addFunction(instantiatedFunctionType);
w.Instructions b = trampoline.body;
// Cast context reference to actual context type.
w.RefType contextType = w.RefType.def(contextStruct, nullable: false);
w.Local contextLocal = trampoline.addLocal(contextType);
b.local_get(trampoline.locals[0]);
b.ref_cast(contextStruct);
b.local_tee(contextLocal);
// Push inner context
b.struct_get(contextStruct, FieldIndex.instantiationContextInner);
b.struct_get(genericClosureStruct, FieldIndex.closureContext);
// Push type arguments
for (int t = 0; t < typeCount; t++) {
b.local_get(contextLocal);
b.struct_get(
contextStruct, FieldIndex.instantiationContextTypeArgumentsBase + t);
}
// Push arguments
for (int p = 1; p < instantiatedFunctionType.inputs.length; p++) {
b.local_get(trampoline.locals[p]);
}
// Call inner
b.local_get(contextLocal);
b.struct_get(contextStruct, FieldIndex.instantiationContextInner);
b.struct_get(genericClosureStruct, FieldIndex.closureVtable);
b.struct_get(genericVtableStruct, genericVtableFieldIndex);
b.call_ref();
b.end();
return trampoline;
}
w.DefinedFunction _createInstantiationFunction(
int typeCount,
ClosureRepresentation instantiatedRepresentation,
List<w.DefinedFunction> instantiationTrampolines,
w.FunctionType functionType,
w.StructType contextStruct,
w.StructType genericClosureStruct,
String name) {
assert(typeCount > 0);
w.RefType genericClosureType =
w.RefType.def(genericClosureStruct, nullable: false);
w.RefType instantiatedClosureType = w.RefType.def(
instantiatedRepresentation.closureStruct,
nullable: false);
assert(functionType.outputs.single == instantiatedClosureType);
// Create vtable for the instantiated closure, containing the trampolines.
w.DefinedGlobal vtable = m.addGlobal(w.GlobalType(
w.RefType.def(instantiatedRepresentation.vtableStruct, nullable: false),
mutable: false));
w.Instructions ib = vtable.initializer;
for (w.DefinedFunction trampoline in instantiationTrampolines) {
ib.ref_func(trampoline);
}
ib.struct_new(instantiatedRepresentation.vtableStruct);
ib.end();
ClassInfo info = translator.classInfo[translator.functionClass]!;
w.DefinedFunction instantiationFunction = m.addFunction(functionType, name);
w.Local preciseClosure = instantiationFunction.addLocal(genericClosureType);
w.Instructions b = instantiationFunction.body;
// Header for the closure struct
b.i32_const(info.classId);
b.i32_const(initialIdentityHash);
// Context for the instantiated closure, containing the original closure and
// the type arguments
b.local_get(instantiationFunction.locals[0]);
b.ref_cast(genericClosureStruct);
b.local_tee(preciseClosure);
for (int i = 0; i < typeCount; i++) {
b.local_get(instantiationFunction.locals[1 + i]);
}
b.struct_new(contextStruct);
// The rest of the closure struct
b.global_get(vtable);
// TODO(askesc): Substitute type arguments into type
b.local_get(preciseClosure);
b.struct_get(genericClosureStruct, FieldIndex.closureRuntimeType);
b.struct_new(instantiatedRepresentation.closureStruct);
b.end();
return instantiationFunction;
}
ClosureRepresentationsForParameterCount _representationsForCounts(

View file

@ -2150,6 +2150,45 @@ class CodeGenerator extends ExpressionVisitor1<w.ValueType, w.ValueType>
return translator.outputOrVoid(lambda.function.type.outputs);
}
@override
w.ValueType visitInstantiation(Instantiation node, w.ValueType expectedType) {
DartType type = dartTypeOf(node.expression);
if (type is FunctionType) {
int typeCount = type.typeParameters.length;
int posArgCount = type.positionalParameters.length;
List<String> argNames = type.namedParameters.map((a) => a.name).toList();
ClosureRepresentation representation = translator.closureLayouter
.getClosureRepresentation(typeCount, posArgCount, argNames)!;
// Operand closure
w.RefType closureType =
w.RefType.def(representation.closureStruct, nullable: false);
w.Local closureTemp = addLocal(closureType);
wrap(node.expression, closureType);
b.local_tee(closureTemp);
// Type arguments
for (DartType typeArg in node.typeArguments) {
types.makeType(this, typeArg);
}
// Instantiation function
b.local_get(closureTemp);
b.struct_get(representation.closureStruct, FieldIndex.closureVtable);
b.struct_get(
representation.vtableStruct, FieldIndex.vtableInstantiationFunction);
// Call instantiation function
b.call_ref();
return representation.instantiationFunctionType.outputs.single;
} else {
// Only other alternative is `NeverType`.
assert(type is NeverType);
b.unreachable();
return voidMarker;
}
}
@override
w.ValueType visitLogicalExpression(
LogicalExpression node, w.ValueType expectedType) {
@ -2274,11 +2313,6 @@ class CodeGenerator extends ExpressionVisitor1<w.ValueType, w.ValueType>
return expectedType;
}
@override
w.ValueType visitInstantiation(Instantiation node, w.ValueType expectedType) {
throw "Not supported: Generic function instantiation at ${node.location}";
}
@override
w.ValueType visitConstantExpression(
ConstantExpression node, w.ValueType expectedType) {

View file

@ -694,6 +694,9 @@ class ConstantCreator extends ConstantVisitor<ConstantInfo?> {
}
void makeVtable() {
if (representation.isGeneric) {
b.ref_func(representation.instantiationFunction);
}
for (int posArgCount = 0;
posArgCount <= positionalCount;
posArgCount++) {

View file

@ -803,6 +803,9 @@ class Translator {
w.RefType.def(representation.vtableStruct, nullable: false),
mutable: false));
w.Instructions ib = vtable.initializer;
if (representation.isGeneric) {
ib.ref_func(representation.instantiationFunction);
}
for (int posArgCount = 0; posArgCount <= positionalCount; posArgCount++) {
fillVtableEntry(ib, posArgCount, const []);
}