Skip to content

Commit

Permalink
wip: imports
Browse files Browse the repository at this point in the history
  • Loading branch information
zshipko committed Sep 12, 2024
1 parent c7cf667 commit 1711113
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 50 deletions.
9 changes: 1 addition & 8 deletions bin/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,5 @@ fn main() {
println!("cargo::rerun-if-changed=../lib/target/wasm32-wasi/release/core.wasm");

let out = std::path::PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("core.wasm");
std::process::Command::new("wasm-opt")
.arg("--disable-reference-types")
.arg("-O2")
.arg("../lib/target/wasm32-wasi/release/core.wasm")
.arg("-o")
.arg(out)
.status()
.unwrap();
std::fs::copy("../lib/target/wasm32-wasi/release/core.wasm", out).unwrap();
}
2 changes: 1 addition & 1 deletion bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() -> Result<(), Error> {
}

let mut user_code = std::fs::read_to_string(&opts.input_py)?;
user_code.push_str('\n');
user_code.push('\n');
user_code += INVOKE;

let tmp_dir = TempDir::new()?;
Expand Down
2 changes: 1 addition & 1 deletion bin/src/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ pub(crate) fn optimize_wasm_file(dest: impl AsRef<Path>) -> Result<(), Error> {
.arg("--enable-reference-types")
.arg("--enable-bulk-memory")
.arg("--strip")
.arg("-O3")
.arg("-O2")
.arg(dest.as_ref())
.arg("-o")
.arg(dest.as_ref())
Expand Down
2 changes: 1 addition & 1 deletion bin/src/py.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn get_import<R: std::fmt::Debug>(
}
}

println!("IMPORT {:?}::{:?}: {n_args} -> {has_return}", module, func);
// println!("IMPORT {:?}::{:?}: {n_args} -> {has_return}", module, func);
match (module, func) {
(Some(module), Some(func)) => Ok(Import {
module,
Expand Down
45 changes: 25 additions & 20 deletions bin/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@ pub(crate) fn generate(
) -> Result<(), Error> {
let mut module = wagen::Module::new();

let n_imports = imports.len();
let import_table = module.tables().push(wagen::TableType {
element_type: wagen::RefType::FUNCREF,
minimum: n_imports as u64,
maximum: None,
table64: false,
});

let __arg_start = module.import("core", "__arg_start", None, [], []);
let __arg_i32 = module.import("core", "__arg_i32", None, [ValType::I32], []);
let __arg_i64 = module.import("core", "__arg_i64", None, [ValType::I64], []);
let __arg_f32 = module.import("core", "__arg_f32", None, [ValType::F32], []);
let __arg_f64 = module.import("core", "__arg_f64", None, [ValType::F64], []);

let __invoke = module.import("core", "__invoke", None, [wagen::ValType::I32], []);

let __invoke_i64 = module.import(
Expand All @@ -26,20 +40,6 @@ pub(crate) fn generate(
[wagen::ValType::I32],
);

let __arg_start = module.import("core", "__arg_start", None, [], []);
let __arg_i32 = module.import("core", "__arg_i32", None, [ValType::I32], []);
let __arg_i64 = module.import("core", "__arg_i64", None, [ValType::I64], []);
let __arg_f32 = module.import("core", "__arg_f32", None, [ValType::F32], []);
let __arg_f64 = module.import("core", "__arg_f64", None, [ValType::F64], []);

let n_imports = imports.len();
let import_table = module.tables().push(wagen::TableType {
element_type: wagen::RefType::FUNCREF,
minimum: n_imports as u64,
maximum: None,
table64: false,
});

let mut import_elements = Vec::new();
for import in imports.iter() {
let index = module.import(
Expand All @@ -52,11 +52,6 @@ pub(crate) fn generate(
import_elements.push(index.index());
}

module.active_element(
Some(import_table),
wagen::Elements::Functions(&import_elements),
);

for p in 0..=5 {
for q in 0..=1 {
let indirect_type = module
Expand All @@ -80,11 +75,21 @@ pub(crate) fn generate(
ty: indirect_type,
table: import_table,
});
builder.push(Instr::Return);
}
}

module.active_element(
Some(import_table),
wagen::Elements::Functions(&import_elements),
);

for (index, export) in exports.iter().enumerate() {
if export.results.len() > 1 {
anyhow::bail!(
"Multiple return arguments are not currently supported but used in exported function {}",
export.name
);
}
let func = module
.func(
&export.name,
Expand Down
7 changes: 2 additions & 5 deletions lib/src/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

HttpRequest = ffi.HttpRequest

IMPORT_INDEX = 0

__exports = []

IMPORT_INDEX = 0


class Codec:
def __init__(self, value):
Expand Down Expand Up @@ -77,20 +77,17 @@ def import_fn(module, name):

def inner(func):
def wrapper(*args):
print("ARGS", args)
args = [_alloc(a) for a in args]
if "return" in func.__annotations__:
ret = func.__annotations__["return"]
print("RETURN", func, ret, module, name, idx, args)
res = ffi.__invoke_host_func(idx, *args)
print("AFTER")
return _read(ret, res)
else:
print("NO RETURN", func, module, name, idx, args)
ffi.__invoke_host_func0(idx, *args)

return wrapper

IMPORT_INDEX += 1
return inner

Expand Down
21 changes: 7 additions & 14 deletions lib/src/py_module.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use pyo3::{
exceptions::PyException,
prelude::*,
types::{PyBytes, PyModule, PyTuple},
types::{PyBytes, PyInt, PyModule, PyTuple},
PyErr, PyResult,
};

Expand Down Expand Up @@ -226,11 +226,11 @@ pub fn memory_alloc(data: &[u8]) -> PyResult<MemoryHandle> {
}

#[pyfunction]
#[pyo3(signature = (index, *args))]
#[pyo3(signature = (i, *args))]
#[pyo3(name = "__invoke_host_func")]
fn invoke_host_func(index: u32, args: &Bound<'_, PyTuple>) -> PyResult<Option<MemoryHandle>> {
fn invoke_host_func(i: &Bound<'_, PyInt>, args: &Bound<'_, PyTuple>) -> PyResult<u64> {
let length = args.len();

let index = i.extract::<'_, u32>()?;
let offs = unsafe {
match length {
0 => __invokeHostFunc_0_1(index),
Expand Down Expand Up @@ -269,22 +269,15 @@ fn invoke_host_func(index: u32, args: &Bound<'_, PyTuple>) -> PyResult<Option<Me
}
};

println!("OFFS: {offs}");
if let Some(mem) = extism_pdk::Memory::find(offs) {
Ok(Some(MemoryHandle {
offset: mem.offset(),
length: mem.len() as u64,
}))
} else {
Ok(None)
}
Ok(offs)
}

#[pyfunction]
#[pyo3(signature = (index, *args))]
#[pyo3(name = "__invoke_host_func0")]
fn invoke_host_func0(index: u32, args: &Bound<'_, PyTuple>) -> PyResult<()> {
fn invoke_host_func0(index: &Bound<'_, PyInt>, args: &Bound<'_, PyTuple>) -> PyResult<()> {
let length = args.len();
let index = index.extract::<'_, u32>()?;

unsafe {
match length {
Expand Down

0 comments on commit 1711113

Please sign in to comment.