Skip to content

Commit

Permalink
[GR-49541] Fix passing multi-value result between different WasmLangu…
Browse files Browse the repository at this point in the history
…age instances.

PullRequest: graal/15837
  • Loading branch information
woess committed Oct 27, 2023
2 parents f9ab576 + d9df9a5 commit 9c05445
Show file tree
Hide file tree
Showing 7 changed files with 309 additions and 242 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ public Object execute(VirtualFrame frame) {
final Object b = WebAssembly.instanceExport(instance, "b");
lib.execute(writeTable, a, 0, f);
final Object readValue = lib.execute(readTable, b, 0);
Assert.assertEquals("Written function should correspond ro read function", 42, lib.asInt(lib.execute(readValue)));
Assert.assertEquals("Written function should correspond to read function", 42, lib.asInt(lib.execute(readValue)));
} catch (UnsupportedMessageException | UnknownIdentifierException | UnsupportedTypeException | ArityException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -1293,21 +1293,15 @@ public void testMultiValueReferencePassThrough() throws IOException, Interrupted
""");
runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);
final WasmFunctionInstance func = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
return 0;
}
}.getCallTarget());
final WasmFunctionInstance f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
final Object[] result = new Object[2];
result[0] = func;
result[1] = "foo";
return InteropArray.create(result);
}
}.getCallTarget());
final var func = new Executable((args) -> {
return 0;
});
final var f = new Executable((args) -> {
final Object[] result = new Object[2];
result[0] = func;
result[1] = "foo";
return InteropArray.create(result);
});
final Dictionary importObject = Dictionary.create(new Object[]{"m", Dictionary.create(new Object[]{"f", f})});
final WasmInstance instance = moduleInstantiate(wasm, source, importObject);
final Object main = WebAssembly.instanceExport(instance, "main");
Expand Down Expand Up @@ -1633,13 +1627,10 @@ public void testImportMultiValue() throws IOException, InterruptedException {

runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);
final Object f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
final Object[] arr = {1, 2, 3};
return InteropArray.create(arr);
}
}.getCallTarget());
final Object f = new Executable((args) -> {
final Object[] arr = {1, 2, 3};
return InteropArray.create(arr);
});
final Dictionary d = new Dictionary();
d.addMember("m", Dictionary.create(new Object[]{
"f", f
Expand Down Expand Up @@ -1671,12 +1662,9 @@ public void testImportMultiValueNotArray() throws IOException, InterruptedExcept
""");
runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);
final Object f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
return 0;
}
}.getCallTarget());
final Object f = new Executable((args) -> {
return 0;
});
final Dictionary d = new Dictionary();
d.addMember("m", Dictionary.create(new Object[]{
"f", f
Expand Down Expand Up @@ -1712,12 +1700,9 @@ public void testImportMultiValueInvalidArraySize() throws IOException, Interrupt
runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);

final Object f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
return InteropArray.create(new Object[]{1, 2});
}
}.getCallTarget());
final Object f = new Executable((args) -> {
return InteropArray.create(new Object[]{1, 2});
});
final Dictionary d = new Dictionary();
d.addMember("m", Dictionary.create(new Object[]{
"f", f
Expand Down Expand Up @@ -1753,12 +1738,9 @@ public void testImportMultiValueTypeMismatch() throws IOException, InterruptedEx
runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);

final Object f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
return InteropArray.create(new Object[]{0, 1.1, 2});
}
}.getCallTarget());
final Object f = new Executable((args) -> {
return InteropArray.create(new Object[]{0, 1.1, 2});
});
final Dictionary d = new Dictionary();
d.addMember("m", Dictionary.create(new Object[]{
"f", f
Expand Down Expand Up @@ -1818,13 +1800,10 @@ public void testImportExportMultiValue() throws IOException, InterruptedExceptio
""");
runTest(context -> {
final WebAssembly wasm = new WebAssembly(context);
final Object f = new WasmFunctionInstance(context, new RootNode(context.language()) {
@Override
public Object execute(VirtualFrame frame) {
final Object[] arr = {1, 2, 3};
return InteropArray.create(arr);
}
}.getCallTarget());
final Object f = new Executable((args) -> {
final Object[] arr = {1, 2, 3};
return InteropArray.create(arr);
});
final Dictionary d = new Dictionary();
d.addMember("m", Dictionary.create(new Object[]{
"f", f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ public void testMultiValueReturn() throws IOException, InterruptedException {
return
)
(func (export "test") (result i32 i32 f64)
i32.const 0
i32.const 1
f64.const 3.14
i32.const 6
i32.const 8
f64.const 2.72
return
)
)
Expand Down Expand Up @@ -549,9 +549,7 @@ public void testMultiValueReturn() throws IOException, InterruptedException {
Value v2 = context.asValue(instance2);
Value main2 = v2.getMember("main");
Value result2 = main2.execute();
if (Boolean.FALSE) { // GR-49541
Assert.assertEquals("Return value of main", List.of(42L, 0, 1, 3.14), List.copyOf(result2.as(List.class)));
}
Assert.assertEquals("Return value of main", List.of(42L, 6, 8, 2.72), List.copyOf(result2.as(List.class)));
}));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,11 @@ private Object multiValueStackAsArray(WasmLanguage language) {
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> referenceMultiValueStack[i];
case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> {
Object ref = referenceMultiValueStack[i];
referenceMultiValueStack[i] = null;
yield ref;
}
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,48 +41,181 @@
package org.graalvm.wasm.api;

import org.graalvm.wasm.WasmArguments;
import org.graalvm.wasm.WasmConstant;
import org.graalvm.wasm.WasmContext;
import org.graalvm.wasm.WasmInstance;
import org.graalvm.wasm.WasmLanguage;
import org.graalvm.wasm.WasmModule;
import org.graalvm.wasm.WasmType;
import org.graalvm.wasm.exception.Failure;
import org.graalvm.wasm.exception.WasmException;
import org.graalvm.wasm.predefined.WasmBuiltinRootNode;

import com.oracle.truffle.api.CompilerAsserts;
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.TruffleContext;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.interop.ArityException;
import com.oracle.truffle.api.interop.InteropException;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.interop.UnsupportedTypeException;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.profiles.BranchProfile;

public class ExecuteInParentContextNode extends WasmBuiltinRootNode {
private final Object executable;
@CompilationFinal private int functionTypeIndex = -1;
private final BranchProfile errorBranch = BranchProfile.create();
@Child private InteropLibrary functionInterop;
@Child private InteropLibrary arrayInterop;
@Child private InteropLibrary resultInterop;

public ExecuteInParentContextNode(WasmLanguage language, WasmModule module, Object executable) {
public ExecuteInParentContextNode(WasmLanguage language, WasmModule module, Object executable, int resultCount) {
super(language, module);
this.executable = executable;
this.functionInterop = InteropLibrary.getUncached(executable);
if (resultCount > 1) {
this.arrayInterop = InteropLibrary.getFactory().createDispatched(5);
}
}

void setFunctionTypeIndex(int functionTypeIndex) {
this.functionTypeIndex = functionTypeIndex;
}

@Override
public Object executeWithContext(VirtualFrame frame, WasmContext context, WasmInstance instance) {
// Imported executables come from the parent context
TruffleContext truffleContext = context.environment().getContext().getParent();
Object prev = truffleContext.enter(this);
Object[] arguments = WasmArguments.getArguments(frame.getArguments());
try {
return InteropLibrary.getUncached().execute(executable, WasmArguments.getArguments(frame.getArguments()));
Object prev = truffleContext.enter(this);
Object result;
try {
result = functionInterop.execute(executable, arguments);
} finally {
truffleContext.leave(this, prev);
}
int resultCount = module().symbolTable().functionTypeResultCount(functionTypeIndex);
CompilerAsserts.partialEvaluationConstant(resultCount);
if (resultCount == 0) {
return WasmConstant.VOID;
} else if (resultCount == 1) {
byte resultType = module().symbolTable().functionTypeResultTypeAt(functionTypeIndex, 0);
return convertResult(result, resultType);
} else {
pushMultiValueResult(result, resultCount);
return WasmConstant.MULTI_VALUE;
}
} catch (UnsupportedTypeException | UnsupportedMessageException | ArityException e) {
errorBranch.enter();
throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Call failed: %s", getMessage(e));
} finally {
truffleContext.leave(this, prev);
}
}

/**
* Convert result to a WebAssembly value. Note: In most cases, the result values will already be
* of the correct boxed type because they're already converted on the JS side, so we only need
* to unbox and and can forego InteropLibrary.
*/
private Object convertResult(Object result, byte resultType) throws UnsupportedMessageException {
CompilerAsserts.partialEvaluationConstant(resultType);
return switch (resultType) {
case WasmType.I32_TYPE -> asInt(result);
case WasmType.I64_TYPE -> asLong(result);
case WasmType.F32_TYPE -> asFloat(result);
case WasmType.F64_TYPE -> asDouble(result);
case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> result;
default -> {
throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType);
}
};
}

@ExplodeLoop
private void pushMultiValueResult(Object result, int resultCount) {
CompilerAsserts.partialEvaluationConstant(resultCount);
if (!arrayInterop.hasArrayElements(result)) {
errorBranch.enter();
throw WasmException.create(Failure.UNSUPPORTED_MULTI_VALUE_TYPE);
}
try {
final int size = (int) arrayInterop.getArraySize(result);
if (size != resultCount) {
errorBranch.enter();
throw WasmException.create(Failure.INVALID_MULTI_VALUE_ARITY);
}
final var multiValueStack = WasmLanguage.get(this).multiValueStack();
multiValueStack.resize(resultCount);
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
final Object[] referenceMultiValueStack = multiValueStack.referenceStack();
for (int i = 0; i < resultCount; i++) {
byte resultType = module().symbolTable().functionTypeResultTypeAt(functionTypeIndex, i);
CompilerAsserts.partialEvaluationConstant(resultType);
Object value = arrayInterop.readArrayElement(result, i);
switch (resultType) {
case WasmType.I32_TYPE -> primitiveMultiValueStack[i] = asInt(value);
case WasmType.I64_TYPE -> primitiveMultiValueStack[i] = asLong(value);
case WasmType.F32_TYPE -> primitiveMultiValueStack[i] = Float.floatToRawIntBits(asFloat(value));
case WasmType.F64_TYPE -> primitiveMultiValueStack[i] = Double.doubleToRawLongBits(asDouble(value));
case WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> referenceMultiValueStack[i] = value;
default -> {
errorBranch.enter();
throw WasmException.format(Failure.UNSPECIFIED_TRAP, this, "Unknown result type: %d", resultType);
}
}
}
} catch (UnsupportedMessageException | InvalidArrayIndexException e) {
errorBranch.enter();
throw WasmException.create(Failure.INVALID_TYPE_IN_MULTI_VALUE);
}
}

private int asInt(Object result) throws UnsupportedMessageException {
if (result instanceof Integer i) {
return i;
} else {
return resultInterop().asInt(result);
}
}

private long asLong(Object result) throws UnsupportedMessageException {
if (result instanceof Long l) {
return l;
} else {
return resultInterop().asLong(result);
}
}

private float asFloat(Object result) throws UnsupportedMessageException {
if (result instanceof Float f) {
return f;
} else {
return resultInterop().asFloat(result);
}
}

private double asDouble(Object result) throws UnsupportedMessageException {
if (result instanceof Double d) {
return d;
} else {
return resultInterop().asDouble(result);
}
}

private InteropLibrary resultInterop() {
InteropLibrary interop = resultInterop;
if (interop == null) {
CompilerDirectives.transferToInterpreterAndInvalidate();
interop = insert(InteropLibrary.getFactory().createDispatched(5));
}
return interop;
}

@TruffleBoundary
private static String getMessage(InteropException e) {
return e.getMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,14 @@ protected WasmInstance createInstance(WasmLanguage language, WasmContext context
final String functionName = entry.getKey();
final Pair<WasmFunction, Object> info = entry.getValue();
final WasmFunction function = info.getLeft();
final Object executable = info.getRight();
final SymbolTable.FunctionType type = function.type();
if (info.getRight() instanceof WasmFunctionInstance) {
defineExportedFunction(instance, functionName, type.paramTypes(), type.resultTypes(), (WasmFunctionInstance) info.getRight());
if (executable instanceof WasmFunctionInstance) {
defineExportedFunction(instance, functionName, type.paramTypes(), type.resultTypes(), (WasmFunctionInstance) executable);
} else {
var executableWrapper = new ExecuteInParentContextNode(context.language(), module, info.getRight());
var executableWrapper = new ExecuteInParentContextNode(context.language(), module, executable, function.resultCount());
WasmFunction exported = defineFunction(context, module, functionName, type.paramTypes(), type.resultTypes(), executableWrapper);
executableWrapper.setFunctionTypeIndex(exported.index());
instance.setTarget(exported.index(), executableWrapper.getCallTarget());
}
}
Expand Down
Loading

0 comments on commit 9c05445

Please sign in to comment.