Skip to content

Commit

Permalink
fix: windows free
Browse files Browse the repository at this point in the history
  • Loading branch information
mworzala committed Jul 28, 2024
1 parent 542c51c commit 7789c36
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 35 deletions.
5 changes: 4 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,14 @@ signing {
}

task<JExtractTask>("jextractLuacode") {
header = file("native/luau/Compiler/include/luacode.h")
// Note: we use the header from the build directory here because we need to catch "luau_ext_free" which is added
// by native/build.gradle.kts
header = file("native/build/root/luau/Compiler/include/luacode.h")
targetPackage = "net.hollowcube.luau.internal.compiler"
extraArgs = listOf(
"--define-macro", "LUA_API=\"extern \\\"C\\\"\"",
"--include-function", "luau_compile",
"--include-function", "luau_ext_free",
"--include-struct", "lua_CompileOptions",
)
}
Expand Down
33 changes: 23 additions & 10 deletions native/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,35 @@ task<Copy>("copyForModification") {
into(buildProjectDir)
}

fun edit(file: File, edit: (String) -> String) {
val originalText = file.readText()
val modifiedText = edit(originalText)
if (originalText != modifiedText) {
file.writeText(modifiedText)
}
}

task("luauStaticToShared") {
description = "Make Luau.Compiler and Luau.VM compile as shared libraries"
dependsOn("copyForModification")

val targetFile = buildProjectDir.resolve("luau/CMakeLists.txt")
inputs.file(targetFile)
outputs.file(targetFile)
val cmakeLists = buildProjectDir.resolve("luau/CMakeLists.txt")
val luacodeHeader = buildProjectDir.resolve("luau/Compiler/include/luacode.h")
val luacodeSource = buildProjectDir.resolve("luau/Compiler/src/lcode.cpp")
inputs.files(cmakeLists, luacodeHeader, luacodeSource)
outputs.files(cmakeLists, luacodeHeader, luacodeSource)

doLast {
val originalText = targetFile.readText()
val modifiedText = originalText
.replace("add_library(Luau.Compiler STATIC)", "add_library(Luau.Compiler SHARED)")
.replace("add_library(Luau.VM STATIC)", "add_library(Luau.VM SHARED)")

if (originalText != modifiedText) {
targetFile.writeText(modifiedText)
edit(cmakeLists) {
return@edit it
.replace("add_library(Luau.Compiler STATIC)", "add_library(Luau.Compiler SHARED)")
.replace("add_library(Luau.VM STATIC)", "add_library(Luau.VM SHARED)")
}
edit(luacodeHeader) {
return@edit "$it\n\nLUACODE_API void luau_ext_free(char *bytecode);"
}
edit(luacodeSource) {
return@edit "$it\n\nvoid luau_ext_free(char *bytecode) {\n free(bytecode);\n}"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,62 @@ public static MemorySegment luau_compile(MemorySegment source, long size, Memory
throw new AssertionError("should not reach here", ex$);
}
}

private static class luau_ext_free {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
luacode_h.C_POINTER
);

public static final MemorySegment ADDR = luacode_h.findOrThrow("luau_ext_free");

public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC);
}

/**
* Function descriptor for:
* {@snippet lang=c :
* extern void luau_ext_free(char *bytecode)
* }
*/
public static FunctionDescriptor luau_ext_free$descriptor() {
return luau_ext_free.DESC;
}

/**
* Downcall method handle for:
* {@snippet lang=c :
* extern void luau_ext_free(char *bytecode)
* }
*/
public static MethodHandle luau_ext_free$handle() {
return luau_ext_free.HANDLE;
}

/**
* Address for:
* {@snippet lang=c :
* extern void luau_ext_free(char *bytecode)
* }
*/
public static MemorySegment luau_ext_free$address() {
return luau_ext_free.ADDR;
}

/**
* {@snippet lang=c :
* extern void luau_ext_free(char *bytecode)
* }
*/
public static void luau_ext_free(MemorySegment bytecode) {
var mh$ = luau_ext_free.HANDLE;
try {
if (TRACE_DOWNCALLS) {
traceDowncall("luau_ext_free", bytecode);
}
mh$.invokeExact(bytecode);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}
}

1 change: 1 addition & 0 deletions src/main/java/net/hollowcube/luau/LuaStateImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ public Object toLightUserDataTagged(int index) {
@Override
public Object toUserData(int index) {
final MemorySegment ud = lua_touserdata(L, index);
//todo NULL can be returned. //todo this call is not valid against userDataInt ud values. We need to include some metadata in
return GlobalRef.get(ud.get(ValueLayout.JAVA_LONG, 0));
}

Expand Down
30 changes: 6 additions & 24 deletions src/main/java/net/hollowcube/luau/compiler/LuauCompilerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.charset.StandardCharsets;
import java.util.List;

import static net.hollowcube.luau.internal.compiler.luacode_h.luau_ext_free;

@SuppressWarnings("preview")
record LuauCompilerImpl(
@NotNull OptimizationLevel optimizationLevel,
Expand All @@ -27,19 +30,6 @@ record LuauCompilerImpl(
NativeLibraryLoader.loadLibrary("compiler");
}

private static final MethodHandle FREE_HANDLE;

static {
// This is basically just a manually inlined version of what jextract would generate for `free`.

final SymbolLookup symbolLookup = SymbolLookup.loaderLookup().or(Linker.nativeLinker().defaultLookup());
final MemorySegment freeAddr = symbolLookup.find("free")
.orElseThrow(() -> new UnsatisfiedLinkError("unresolved symbol: free"));
final FunctionDescriptor freeDesc = FunctionDescriptor.ofVoid(ValueLayout.ADDRESS);

FREE_HANDLE = Linker.nativeLinker().downcallHandle(freeAddr, freeDesc);
}

@Override
public byte[] compile(byte @NotNull [] source) throws LuauCompileException {
try (Arena arena = Arena.ofConfined()) {
Expand All @@ -50,7 +40,7 @@ public byte[] compile(byte @NotNull [] source) throws LuauCompileException {
final MemorySegment result = luacode_h.luau_compile(sourceStr, source.length, compileOpts, bytecodeSize);
final long length = bytecodeSize.get(ValueLayout.JAVA_LONG, 0);
final byte[] bytecode = result.asSlice(0, length).toArray(ValueLayout.JAVA_BYTE);
free(result);
luau_ext_free(result);

// Bytecode now contains either an error or valid luau bytecode.
// A zero in the first byte indicates that the rest is an error.
Expand Down Expand Up @@ -93,14 +83,6 @@ public byte[] compile(byte @NotNull [] source) throws LuauCompileException {
return opts;
}

private void free(@NotNull MemorySegment segment) {
try {
FREE_HANDLE.invokeExact(segment);
} catch (Throwable ex) {
throw new AssertionError("should not reach here", ex);
}
}

static final class BuilderImpl implements Builder {
private OptimizationLevel optimizationLevel = OptimizationLevel.BASELINE;
private DebugLevel debugLevel = DebugLevel.BACKTRACE;
Expand Down

0 comments on commit 7789c36

Please sign in to comment.