Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid copy when the plugin returns #13

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions examples/hello_c/hello.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,37 @@
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern "C"
#else
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern
#endif

// ===
// Functions for the protocol

PROTOCOL_FUNCTION void
wasm_minimal_protocol_send_result_to_host(const uint8_t *ptr, size_t len);
PROTOCOL_FUNCTION void wasm_minimal_protocol_write_args_to_buffer(uint8_t *ptr);

EMSCRIPTEN_KEEPALIVE void wasm_minimal_protocol_free_byte_buffer(uint8_t *ptr,
size_t len) {
free(ptr);
}

// ===

EMSCRIPTEN_KEEPALIVE
int32_t hello(void) {
const char message[] = "Hello world !";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "Hello world !";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 0;
}

Expand All @@ -36,7 +50,6 @@ int32_t double_it(size_t arg_len) {
alloc_result[arg_len + i] = alloc_result[i];
}
wasm_minimal_protocol_send_result_to_host(alloc_result, result_len);
free(alloc_result);
return 0;
}

Expand Down Expand Up @@ -66,7 +79,6 @@ int32_t concatenate(size_t arg1_len, size_t arg2_len) {

wasm_minimal_protocol_send_result_to_host(result, total_len + 1);

free(result);
free(args);
return 0;
}
Expand Down Expand Up @@ -102,24 +114,27 @@ int32_t shuffle(size_t arg1_len, size_t arg2_len, size_t arg3_len) {

wasm_minimal_protocol_send_result_to_host(result, result_len);

free(result);
free(args);
return 0;
}

EMSCRIPTEN_KEEPALIVE
int32_t returns_ok() {
const char message[] = "This is an `Ok`";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "This is an `Ok`";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 0;
}

EMSCRIPTEN_KEEPALIVE
int32_t returns_err() {
const char message[] = "This is an `Err`";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "This is an `Err`";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 1;
}

Expand Down
49 changes: 32 additions & 17 deletions examples/hello_zig/hello.zig
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
const std = @import("std");
const allocator = std.heap.page_allocator;

// ===
// Functions for the protocol

extern "typst_env" fn wasm_minimal_protocol_send_result_to_host(ptr: [*]const u8, len: usize) void;
extern "typst_env" fn wasm_minimal_protocol_write_args_to_buffer(ptr: [*]u8) void;

export fn wasm_minimal_protocol_free_byte_buffer(ptr: [*]u8, len: usize) void {
var slice: []u8 = undefined;
slice.ptr = ptr;
slice.len = len;
allocator.free(slice);
}

// ===

export fn hello() i32 {
const message = "Hello world !";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn double_it(arg1_len: usize) i32 {
var alloc_result = allocator.alloc(u8, arg1_len * 2) catch return 1;
defer allocator.free(alloc_result);
wasm_minimal_protocol_write_args_to_buffer(alloc_result.ptr);
var result = allocator.alloc(u8, arg1_len * 2) catch return 1;
wasm_minimal_protocol_write_args_to_buffer(result.ptr);
for (0..arg1_len) |i| {
alloc_result[i + arg1_len] = alloc_result[i];
result[i + arg1_len] = result[i];
}
wasm_minimal_protocol_send_result_to_host(alloc_result.ptr, alloc_result.len);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

Expand All @@ -27,7 +40,6 @@ export fn concatenate(arg1_len: usize, arg2_len: usize) i32 {
wasm_minimal_protocol_write_args_to_buffer(args.ptr);

var result = allocator.alloc(u8, arg1_len + arg2_len + 1) catch return 1;
defer allocator.free(result);
for (0..arg1_len) |i| {
result[i] = args[i];
}
Expand All @@ -49,27 +61,30 @@ export fn shuffle(arg1_len: usize, arg2_len: usize, arg3_len: usize) i32 {
var arg2 = args[arg1_len .. arg1_len + arg2_len];
var arg3 = args[arg1_len + arg2_len .. args.len];

var result: std.ArrayList(u8) = std.ArrayList(u8).initCapacity(allocator, args_len + 2) catch return 1;
defer result.deinit();
result.appendSlice(arg3) catch return 1;
result.append('-') catch return 1;
result.appendSlice(arg1) catch return 1;
result.append('-') catch return 1;
result.appendSlice(arg2) catch return 1;
var result = allocator.alloc(u8, arg1_len + arg2_len + arg3_len + 2) catch return 1;
@memcpy(result[0..arg3.len], arg3);
result[arg3.len] = '-';
@memcpy(result[arg3.len + 1 ..][0..arg1.len], arg1);
result[arg3.len + arg1.len + 1] = '-';
@memcpy(result[arg3.len + arg1.len + 2 ..][0..arg2.len], arg2);

wasm_minimal_protocol_send_result_to_host(result.items.ptr, result.items.len);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn returns_ok() i32 {
const message = "This is an `Ok`";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn returns_err() i32 {
const message = "This is an `Err`";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 1;
}

Expand Down
121 changes: 79 additions & 42 deletions examples/host-wasmi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,67 @@
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Module, Value};
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Memory, Module, Value};

type Store = wasmi::Store<PersistentData>;

/// Reference to a slice of memory returned after
/// [calling a wasm function](PluginInstance::call).
///
/// # Drop
/// On [`Drop`], this will free the slice of memory inside the plugin.
///
/// As such, this structure mutably borrows the [`PluginInstance`], which prevents
/// another function from being called.
pub struct ReturnedData<'a> {
memory: Memory,
ptr: u32,
len: u32,
free_function: &'a Function,
context_mut: &'a mut Store,
}

impl<'a> ReturnedData<'a> {
/// Get a reference to the returned slice of data.
///
/// # Panic
/// This may panic if the function returned an invalid `(ptr, len)` pair.
pub fn get(&self) -> &[u8] {
&self.memory.data(&*self.context_mut)[self.ptr as usize..(self.ptr + self.len) as usize]
}
}

impl Drop for ReturnedData<'_> {
fn drop(&mut self) {
self.free_function
.call(
&mut *self.context_mut,
&[Value::I32(self.ptr as _), Value::I32(self.len as _)],
&mut [],
)
.unwrap();
}
}

#[derive(Debug, Clone)]
struct PersistentData {
result_data: Vec<u8>,
result_ptr: u32,
result_len: u32,
arg_buffer: Vec<u8>,
}

#[derive(Debug)]
pub struct PluginInstance {
store: Store,
memory: Memory,
free_function: Function,
functions: Vec<(String, Function)>,
}

impl PluginInstance {
pub fn new_from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
let engine = Engine::default();
let data = PersistentData {
result_data: Vec::new(),
arg_buffer: Vec::new(),
result_ptr: 0,
result_len: 0,
};
let mut store = Store::new(&engine, data);

Expand All @@ -32,11 +74,8 @@ impl PluginInstance {
"typst_env",
"wasm_minimal_protocol_send_result_to_host",
move |mut caller: Caller<PersistentData>, ptr: u32, len: u32| {
let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
let mut buffer = std::mem::take(&mut caller.data_mut().result_data);
buffer.resize(len as usize, 0);
memory.read(&caller, ptr as _, &mut buffer).unwrap();
caller.data_mut().result_data = buffer;
caller.data_mut().result_ptr = ptr;
caller.data_mut().result_len = len;
},
)
.unwrap()
Expand All @@ -51,54 +90,44 @@ impl PluginInstance {
},
)
.unwrap()
// hack to accept wasi file
// https://github.com/near/wasi-stub is preferred
/*
.func_wrap(
"wasi_snapshot_preview1",
"fd_write",
|_: i32, _: i32, _: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"environ_get",
|_: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"environ_sizes_get",
|_: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"proc_exit",
|_: i32| {},
)
.unwrap()
*/
.instantiate(&mut store, &module)
.map_err(|e| format!("{e}"))?
.start(&mut store)
.map_err(|e| format!("{e}"))?;

let mut free_function = None;
let functions = instance
.exports(&store)
.filter_map(|e| {
let name = e.name().to_owned();
e.into_func().map(|func| (name, func))

e.into_func().map(|func| {
if name == "wasm_minimal_protocol_free_byte_buffer" {
free_function = Some(func);
}
(name, func)
})
})
.collect::<Vec<_>>();
Ok(Self { store, functions })
let free_function = free_function.unwrap();
let memory = instance
.get_export(&store, "memory")
.unwrap()
.into_memory()
.unwrap();
Ok(Self {
store,
memory,
free_function,
functions,
})
}

fn write(&mut self, args: &[&[u8]]) {
self.store.data_mut().arg_buffer = args.concat();
}

pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<Vec<u8>, String> {
pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<ReturnedData, String> {
self.write(args);

let (_, function) = self
Expand All @@ -122,11 +151,19 @@ impl PluginInstance {
code.first().cloned().unwrap_or(Value::I32(3)) // if the function returns nothing
};

let s = std::mem::take(&mut self.store.data_mut().result_data);
let (ptr, len) = (self.store.data().result_ptr, self.store.data().result_len);

let result = ReturnedData {
memory: self.memory,
ptr,
len,
free_function: &self.free_function,
context_mut: &mut self.store,
};

match code {
Value::I32(0) => Ok(s),
Value::I32(1) => Err(match String::from_utf8(s) {
Value::I32(0) => Ok(result),
Value::I32(1) => Err(match std::str::from_utf8(result.get()) {
Ok(err) => format!("plugin errored with: '{}'", err,),
Err(_) => String::from("plugin errored and did not return valid UTF-8"),
}),
Expand Down
7 changes: 3 additions & 4 deletions examples/test-runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
// you need to build the hello example first

use anyhow::Result;
use std::process::Command;

use host_wasmi::PluginInstance;
use std::process::Command;

#[cfg(not(feature = "wasi"))]
mod consts {
Expand Down Expand Up @@ -118,7 +117,7 @@ fn main() -> Result<()> {
return Ok(());
}
};
match String::from_utf8(result) {
match std::str::from_utf8(result.get()) {
Ok(s) => println!("{s}"),
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
}
Expand All @@ -141,7 +140,7 @@ fn main() -> Result<()> {
continue;
}
};
match String::from_utf8(result) {
match std::str::from_utf8(result.get()) {
Ok(s) => println!("{s}"),
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
}
Expand Down
Loading