Skip to content

Commit

Permalink
Better instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
tachyonicbytes committed Nov 26, 2023
1 parent b5d6b30 commit 95b36c5
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 129 deletions.
150 changes: 66 additions & 84 deletions Cargo.lock

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions dozer-sql/expression/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use dozer_types::models::udf_config::OnnxConfig;
#[cfg(feature = "wasm")]
use dozer_types::models::udf_config::WasmConfig;

#[derive(Clone, PartialEq, Debug)]
#[derive(Clone, Debug)]
pub struct ExpressionBuilder {
// Must be an aggregation function
pub aggregations: Vec<Expression>,
Expand Down Expand Up @@ -577,6 +577,7 @@ impl ExpressionBuilder {
schema,
udfs,
)
.await
}

#[cfg(not(feature = "wasm"))]
Expand Down Expand Up @@ -1007,7 +1008,7 @@ impl ExpressionBuilder {
}

#[cfg(feature = "wasm")]
fn parse_wasm_udf(
async fn parse_wasm_udf(
&mut self,
name: String,
config: &WasmConfig,
Expand All @@ -1020,35 +1021,36 @@ impl ExpressionBuilder {
use crate::wasm::utils::wasm_validate_input_and_return;
use std::path::Path;

let args = function
.args
.iter()
.map(|argument| self.parse_sql_function_arg(false, argument, schema, udfs))
.collect::<Result<Vec<_>, Error>>()?;
let mut args = vec![];
for argument in &function.args {
let arg = self
.parse_sql_function_arg(false, argument, schema, udfs)
.await?;
args.push(arg);
}

let (value_types, return_type) = wasm_validate_input_and_return(
let session = wasm_validate_input_and_return(
schema,
name.as_str(),
Path::new(&config.path.clone()),
&args,
)
.unwrap();
let return_type = match return_type {

let return_type = match session.return_type {
wasmtime::ValType::I32 => FieldType::Int,
wasmtime::ValType::I64 => FieldType::Int,
wasmtime::ValType::F32 => FieldType::Float,
wasmtime::ValType::F64 => FieldType::Float,
wasmtime::ValType::V128 => todo!(),
wasmtime::ValType::FuncRef => todo!(),
wasmtime::ValType::ExternRef => todo!(),
_ => todo!(),
};

Ok(Expression::WasmUDF {
name: name.to_string(),
module: config.path.clone(),
args,
value_types,
return_type,
session,
})
}

Expand Down
9 changes: 5 additions & 4 deletions dozer-sql/expression/src/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ use dozer_types::types::Record;
use dozer_types::types::{Field, FieldType, Schema, SourceDefinition};

#[cfg(feature = "wasm")]
use wasmtime::ValType;
use crate::wasm::WasmSession;


#[derive(Clone, Debug, PartialEq)]
pub enum Expression {
Expand Down Expand Up @@ -106,8 +107,8 @@ pub enum Expression {
name: String,
module: String,
args: Vec<Expression>,
value_types: Vec<ValType>,
return_type: FieldType,
session: WasmSession,
},
}

Expand Down Expand Up @@ -362,10 +363,10 @@ impl Expression {

#[cfg(feature = "wasm")]
Expression::WasmUDF {
name, module, args, ..
name: _name, module: _module, args, return_type: _, session
} => {
use crate::wasm::udf::evaluate_wasm_udf;
evaluate_wasm_udf(schema, name, module, args, record)
evaluate_wasm_udf(schema, args, record, session)
}
Expression::UnaryOperator { operator, arg } => operator.evaluate(schema, arg, record),
Expression::AggregateFunction { fun, args: _ } => {
Expand Down
2 changes: 2 additions & 0 deletions dozer-sql/expression/src/wasm/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ pub enum Error {
WasmInputTypeSizeMismatch(usize, usize),
#[error("The WASM function {0} is missing from the module {1}")]
WasmFunctionMissing(String, String),
#[error("Could not instantiate WASM module {0}")]
WasmInstantiateError(String),
}
34 changes: 34 additions & 0 deletions dozer-sql/expression/src/wasm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,37 @@
pub mod error;
pub mod udf;
pub mod utils;
use std::fmt;

use wasmtime::*;

#[derive(Clone)]
pub struct WasmSession {
/// Used just for printing errors
pub module_path: String,
pub function_name: String,
pub instance_pre: InstancePre<()>,
pub engine: Engine,
pub value_types: Vec<ValType>,
pub return_type: ValType,
}

/// Debug implementation for WasmSession.
/// `instance_pre` is omitted from this implementation.
impl fmt::Debug for WasmSession {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WasmSession")
.field("module_path", &self.function_name)
.field("function_name", &self.function_name)
// Omit the debug values for `InstancePre` and `Engine`
.field("value_types", &self.value_types)
.field("return_type", &self.return_type)
.finish()
}
}

impl PartialEq for WasmSession {
fn eq(&self, other: &Self) -> bool {
self.value_types == other.value_types && self.return_type == other.return_type
}
}
38 changes: 15 additions & 23 deletions dozer-sql/expression/src/wasm/udf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::error::Error::{WasmFunctionMissing, WasmTrap};
use super::error::Error::{WasmTrap, WasmInstantiateError};
use crate::error::Error::{self, Wasm};
use dozer_types::ordered_float::OrderedFloat;
use dozer_types::types::{Field, Record, Schema};
Expand All @@ -7,36 +7,27 @@ use wasmtime::*;

use crate::execution::Expression;

use super::WasmSession;

pub fn evaluate_wasm_udf(
schema: &Schema,
name: &str,
config: &str,
args: &[Expression],
args: &mut [Expression],
record: &Record,
session: &WasmSession,
) -> Result<Field, Error> {
let input_values = args
.iter()
.iter_mut()
.map(|arg| arg.evaluate(record, schema))
.collect::<Result<Vec<_>, Error>>()?;

let engine = Engine::default();
let module = Module::from_file(&engine, config).unwrap();
let mut store = Store::new(&engine, ());
let instance = Instance::new(&mut store, &module, &[]).unwrap();

let wasm_udf_func;
match instance.get_func(&mut store, name) {
Some(func) => {
wasm_udf_func = func;
}
None => {
return Err(Wasm(WasmFunctionMissing(
name.to_string(),
config.to_string(),
)));
}
}
// Instantiate again, because we cannot pass as `Store` in the WasmSession struct
let mut store = Store::new(&session.engine, ());
let instance = session.instance_pre.instantiate(&mut store)
.map_err(|_| (WasmInstantiateError(session.module_path.clone())))?;

// Type checking already checked the name of the function
// Get the Func, FuncType, inputs and output of the wasm function
let wasm_udf_func = instance.get_func(&mut store, session.function_name.as_str()).unwrap();
let func_type = wasm_udf_func.ty(&mut store);
let param_types = func_type.params();
let mut result_type = func_type.results();
Expand Down Expand Up @@ -78,13 +69,14 @@ pub fn evaluate_wasm_udf(
})
.collect();

// Type checking verified this
let result = result_type.next().unwrap();
let mut results: [Val; 1] = [Val::I64(0)];

match wasm_udf_func.call(&mut store, &values, &mut results) {
Ok(()) => {}
Err(trap) => {
return Err(Wasm(WasmTrap(name.to_string(), trap.to_string())));
return Err(Wasm(WasmTrap(session.function_name.clone(), trap.to_string())));
}
}

Expand Down
2 changes: 0 additions & 2 deletions dozer-sql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ use dozer_types::thiserror::Error;
use dozer_types::types::{Field, FieldType};
use std::fmt::{Display, Formatter};

use super::utils::serialize::DeserializationError;

#[derive(Debug, Clone)]
pub struct FieldTypes {
types: Vec<FieldType>,
Expand Down
1 change: 0 additions & 1 deletion dozer-types/src/models/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ pub struct Config {
/// Dozer Cloud specific configuration
pub cloud: Cloud,

#[prost(message, repeated, tag = "15")]
/// UDF specific configuration (eg. !Onnx, Wasm)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub udfs: Vec<UdfConfig>,
Expand Down
1 change: 0 additions & 1 deletion dozer-types/src/models/udf_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ pub struct JavaScriptConfig {

#[derive(Debug, Serialize, Deserialize, JsonSchema, Eq, PartialEq, Clone)]
pub struct WasmConfig {
#[prost(string)]
/// path to the module file
pub path: String,
}
2 changes: 1 addition & 1 deletion json_schemas/dozer.json
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
]
},
"udfs": {
"description": "UDF specific configuration (eg. !Onnx, !Wasm)",
"description": "UDF specific configuration (eg. !Onnx, Wasm)",
"type": "array",
"items": {
"$ref": "#/definitions/UdfConfig"
Expand Down

0 comments on commit 95b36c5

Please sign in to comment.