Skip to content

Commit

Permalink
feat: support i32, i64, f32, f64 param/return types for Host functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mhmd-azeez committed Feb 6, 2025
1 parent 7a319fe commit db850c0
Show file tree
Hide file tree
Showing 13 changed files with 528 additions and 304 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
matrix:
include:
- name: linux
os: ubuntu-22.04
os: ubuntu-24.04
# Re-enable once we can build on Windows again
# - name: windows
# os: windows-latest
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
matrix:
include:
- name: linux
os: ubuntu-22.04
os: ubuntu-24.04
- name: windows
os: windows-latest
steps:
Expand Down
14 changes: 1 addition & 13 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,7 @@ clean-wasi-sdk:
test: compile-examples
@extism call examples/simple_js.wasm greet --wasi --input="Benjamin"
@extism call examples/bundled.wasm greet --wasi --input="Benjamin" --allow-host "example.com"
ifeq ($(OS),Windows_NT)
@python3 -m venv ./.venv && \
./.venv/Scripts/activate.bat && \
pip install -r examples/host_funcs/requirements.txt && \
python examples/host_funcs/host.py examples/host_funcs.wasm && \
./.venv/Scripts/deactivate.bat
else
@python3 -m venv ./.venv && \
. ./.venv/bin/activate && \
pip install -r examples/host_funcs/requirements.txt && \
python3 examples/host_funcs/host.py examples/host_funcs.wasm && \
deactivate
endif
cd ./examples/host_funcs && go run . ../host_funcs.wasm
@extism call examples/react.wasm render --wasi
@extism call examples/react.wasm setState --input='{"action": "SET_SETTING", "payload": { "backgroundColor": "tomato" }}' --wasi
@error_msg=$$(extism call examples/exception.wasm greet --wasi --input="Benjamin" 2>&1); \
Expand Down
287 changes: 243 additions & 44 deletions crates/cli/src/shims.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use crate::ts_parser::Interface;
use anyhow::Result;
use std::path::Path;
use wagen::{BlockType, Instr, ValType};

use crate::ts_parser::Interface;
use wagen::{Instr, ValType};
#[derive(PartialEq)]
enum TypeCode {
Void = 0,
I32 = 1,
I64 = 2,
F32 = 3,
F64 = 4,
}

/// Generates the wasm shim for the exports
pub fn generate_wasm_shims(
path: impl AsRef<Path>,
exports: &Interface,
Expand All @@ -23,64 +30,247 @@ pub fn generate_wasm_shims(
let __invoke_f64 = module.import("core", "__invoke_f64", None, [ValType::I32], [ValType::F64]);
let __invoke = module.import("core", "__invoke", None, [ValType::I32], []);

let mut n_imports = 0;
let mut import_elements = Vec::new();
let mut import_items = vec![];
for import in imports.iter() {
for _ in import.functions.iter() {
n_imports += 1;
for f in import.functions.iter() {
let params: Vec<_> = f.params.iter().map(|x| x.ptype).collect();
let results: Vec<_> = f.results.iter().map(|x| x.ptype).collect();
let index = module.import(&import.name, &f.name, None, params.clone(), results.clone());
import_items.push((f.name.clone(), index, params, results));
}
}
import_items.sort_by_key(|x| x.0.to_string());

for (_name, index, _params, _results) in &import_items {
import_elements.push(index.index());
}

let table_min = import_elements.len() as u32;

let import_table = module.tables().push(wagen::TableType {
element_type: wagen::RefType::FUNCREF,
minimum: n_imports,
minimum: table_min,
maximum: None,
});

let mut import_elements = Vec::new();
let mut import_items = vec![];
for import in imports.iter() {
for f in import.functions.iter() {
let params: Vec<_> = f.params.iter().map(|x| x.ptype).collect();
let results: Vec<_> = f.results.iter().map(|x| x.ptype).collect();
let index = module.import(&import.name, &f.name, None, params, results);
import_items.push((f.name.clone(), index));
let mut get_function_return_type_builder = wagen::Builder::default();

for (func_idx, (_name, _index, _params, results)) in import_items.iter().enumerate() {
let type_code = results.first().map_or(TypeCode::Void, |val_type| match val_type {
ValType::I32 => TypeCode::I32,
ValType::I64 => TypeCode::I64,
ValType::F32 => TypeCode::F32,
ValType::F64 => TypeCode::F64,
_ => TypeCode::Void,
});

if type_code == TypeCode::Void {
continue;
}

// Compare the input function index with the current index.
get_function_return_type_builder.push(Instr::LocalGet(0)); // load requested function index
get_function_return_type_builder.push(Instr::I32Const(func_idx as i32)); // load func_idx
get_function_return_type_builder.push(Instr::I32Eq); // compare
get_function_return_type_builder.push(Instr::If(BlockType::Empty)); // if true
get_function_return_type_builder.push(Instr::I32Const(type_code as i32)); // load type code
get_function_return_type_builder.push(Instr::Return); // early return if match
get_function_return_type_builder.push(Instr::End);
}
import_items.sort_by_key(|x| x.0.to_string());

for (_f, index) in import_items {
import_elements.push(index.index());
get_function_return_type_builder.push(Instr::I32Const(0)); // Default to 0
get_function_return_type_builder.push(Instr::Return);

let get_function_return_type_func = module.func(
"__get_function_return_type",
vec![ValType::I32], // takes function index
vec![ValType::I32], // returns type code
vec![],
);
get_function_return_type_func.export("__get_function_return_type");
get_function_return_type_func.body = get_function_return_type_builder;

let mut get_function_arg_type_builder = wagen::Builder::default();

for (func_idx, (_name, _index, params, _results)) in import_items.iter().enumerate() {
for arg_idx in 0..params.len() {
let type_code = match params[arg_idx] {
ValType::I32 => TypeCode::I32,
ValType::I64 => TypeCode::I64,
ValType::F32 => TypeCode::F32,
ValType::F64 => TypeCode::F64,
_ => panic!("Unsupported argument type for function {} at index {}", func_idx, arg_idx),
};

// Compare both function index and argument index
get_function_arg_type_builder.push(Instr::LocalGet(0)); // function index
get_function_arg_type_builder.push(Instr::I32Const(func_idx as i32));
get_function_arg_type_builder.push(Instr::I32Eq);

get_function_arg_type_builder.push(Instr::LocalGet(1)); // argument index
get_function_arg_type_builder.push(Instr::I32Const(arg_idx as i32));
get_function_arg_type_builder.push(Instr::I32Eq);

get_function_arg_type_builder.push(Instr::I32And); // Both must match

// If both match, return the type code
get_function_arg_type_builder.push(Instr::If(BlockType::Empty));
get_function_arg_type_builder.push(Instr::I32Const(type_code as i32));
get_function_arg_type_builder.push(Instr::Return);
get_function_arg_type_builder.push(Instr::End);
}
}

for p in 0..=5 {
for q in 0..=1 {
let indirect_type = module
.types()
.push(|t| t.function(vec![ValType::I64; p], vec![ValType::I64; q]));
let name = format!("__invokeHostFunc_{p}_{q}");
let mut params = vec![ValType::I32];
for _ in 0..p {
params.push(ValType::I64);
// Default return if no match
get_function_arg_type_builder.push(Instr::I32Const(0));
get_function_arg_type_builder.push(Instr::Return);

let get_function_arg_type_func = module.func(
"__get_function_arg_type",
vec![ValType::I32, ValType::I32], // takes (function_index, arg_index)
vec![ValType::I32], // returns type code
vec![],
);
get_function_arg_type_func.export("__get_function_arg_type");
get_function_arg_type_func.body = get_function_arg_type_builder;

// Create converters for each host function to reinterpret the I64 bit pattern as the expected type
let mut converter_indices = Vec::new();
for (_, (name, _index, params, results)) in import_items.iter().enumerate() {
let import_type = module
.types()
.push(|t| t.function(params.clone(), results.clone()));

let mut builder = wagen::Builder::default();

// Convert input parameters
for (i, param) in params.iter().enumerate() {
builder.push(Instr::LocalGet((i + 1) as u32)); // skip function index param

match param {
ValType::I32 => {
builder.push(Instr::I32WrapI64);
}
ValType::I64 => {
// No conversion needed - already i64
}
ValType::F32 => {
// Input is already the bit pattern from globals.rs convert_to_u64_bits
// First truncate to i32 then reinterpret as f32
builder.push(Instr::I32WrapI64);
builder.push(Instr::F32ReinterpretI32);
}
ValType::F64 => {
// Input is already the bit pattern from JS DataView
// Just reinterpret the i64 as f64
builder.push(Instr::F64ReinterpretI64);
}
r => {
anyhow::bail!("Unsupported param type: {:?}", r);
}
}
let invoke_host = module
.func(&name, params, vec![ValType::I64; q], [])
.export(&name);
let builder = invoke_host.builder();
for i in 1..=p {
builder.push(Instr::LocalGet(i as u32));
}

// Call the imported function
builder.push(Instr::LocalGet(0));
builder.push(Instr::CallIndirect {
ty: import_type,
table: import_table,
});

// Convert result back to i64 bits for JS
if let Some(result) = results.first() {
match result {
ValType::I32 => {
builder.push(Instr::I64ExtendI32U);
}
ValType::I64 => {
// Already i64, no conversion needed
}
ValType::F32 => {
// Convert f32 to its bit pattern
builder.push(Instr::I32ReinterpretF32);
builder.push(Instr::I64ExtendI32U);
}
ValType::F64 => {
// Convert f64 to its bit pattern
builder.push(Instr::I64ReinterpretF64);
}
r => {
anyhow::bail!("Unsupported result type: {:?}", r);
}
}
builder.push(Instr::LocalGet(0));
builder.push(Instr::CallIndirect {
ty: indirect_type,
table: import_table,
});
} else {
// No return value, push 0
builder.push(Instr::I64Const(0));
}

// Create the converter function
let mut shim_params = vec![ValType::I32]; // Function index
shim_params.extend(std::iter::repeat(ValType::I64).take(params.len()));

let conv_func = module.func(
&format!("__conv_{}", name),
shim_params,
vec![ValType::I64],
vec![],
);
conv_func.export(&format!("__conv_{}", name));
conv_func.body = builder;

converter_indices.push(conv_func.index);
}

let router = module.func(
"__invokeHostFunc",
vec![
ValType::I32, // func_idx
ValType::I64, // args[0]
ValType::I64, // args[1]
ValType::I64, // args[2]
ValType::I64, // args[3]
ValType::I64, // args[4]
],
vec![ValType::I64],
vec![],
);

// Similar builder logic as before but simplified to one function
let mut router_builder = wagen::Builder::default();

for (func_idx, (_name, _index, params, _results)) in import_items.iter().enumerate() {
router_builder.push(Instr::LocalGet(0)); // func index
router_builder.push(Instr::I32Const(func_idx as i32));
router_builder.push(Instr::I32Eq);
router_builder.push(Instr::If(BlockType::Empty));

// First push func_idx for converter
router_builder.push(Instr::LocalGet(0));

// Then push remaining args from router's inputs
for (i, _) in params.iter().enumerate() {
router_builder.push(Instr::LocalGet((i + 1) as u32));
}

router_builder.push(Instr::Call(converter_indices[func_idx]));
router_builder.push(Instr::Return);
router_builder.push(Instr::End);
}

router_builder.push(Instr::I64Const(0));
router_builder.push(Instr::Return);

router.export("__invokeHostFunc");
router.body = router_builder;

// Set up the table
module.active_element(
Some(import_table),
wagen::Elements::Functions(&import_elements),
);

// Generate exports
for (idx, export) in exports.functions.iter().enumerate() {
let params: Vec<_> = export.params.iter().map(|x| x.ptype).collect();
let results: Vec<_> = export.results.iter().map(|x| x.ptype).collect();
Expand All @@ -90,12 +280,11 @@ pub fn generate_wasm_shims(
export.name
);
}
let func = module
.func(&export.name, params.clone(), results.clone(), [])
.export(&export.name);
let builder = func.builder();

let mut builder = wagen::Builder::default();
builder.push(Instr::Call(__arg_start.index()));
for (parami, param) in params.into_iter().enumerate() {

for (parami, param) in params.iter().enumerate() {
builder.push(Instr::LocalGet(parami as u32));

match param {
Expand Down Expand Up @@ -138,8 +327,18 @@ pub fn generate_wasm_shims(
anyhow::bail!("Unsupported result type: {:?}", r);
}
}

let f = module.func(&export.name, params, results, vec![]);
f.export(&export.name);
f.body = builder;
}

// Validation with debug output
if let Err(error) = module.clone().validate_save(path.as_ref()) {
eprintln!("Validation failed: {:?}", error);
module.save("/tmp/wizer/incomplete_shim.wasm")?;
return Err(error);
}

module.validate_save(path.as_ref())?;
Ok(())
}
Loading

0 comments on commit db850c0

Please sign in to comment.