Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Port pykeio/ort#218 #5

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 142 additions & 110 deletions src/operator/bound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
kernel::{Kernel, KernelAttributes, KernelContext},
DummyOperator, Operator
};
use crate::error::IntoStatus;
use crate::{error::IntoStatus, extern_system_fn};

#[repr(C)]
#[derive(Clone)]
Expand Down Expand Up @@ -62,115 +62,147 @@ impl<O: Operator> BoundOperator<O> {
&*op.cast()
}

pub(crate) unsafe extern "C" fn CreateKernelV2(
_: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let kernel = match O::create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast();
Ok(()).into_status()
}

pub(crate) unsafe extern "C" fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::<O::Kernel>() }, &context).into_status()
}

pub(crate) unsafe extern "C" fn KernelDestroy(op_kernel: *mut ort_sys::c_void) {
drop(Box::from_raw(op_kernel.cast::<O::Kernel>()));
}

pub(crate) unsafe extern "C" fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
pub(crate) unsafe extern "C" fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}

pub(crate) unsafe extern "C" fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::min_version()
}
pub(crate) unsafe extern "C" fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::max_version()
}

pub(crate) unsafe extern "C" fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType {
O::inputs()[index as usize].memory_type.into()
}
pub(crate) unsafe extern "C" fn GetInputCharacteristic(
_: *const ort_sys::OrtCustomOp,
index: ort_sys::size_t
) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index as usize].characteristic.into()
}
pub(crate) unsafe extern "C" fn GetOutputCharacteristic(
_: *const ort_sys::OrtCustomOp,
index: ort_sys::size_t
) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index as usize].characteristic.into()
}
pub(crate) unsafe extern "C" fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::inputs().len() as _
}
pub(crate) unsafe extern "C" fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::outputs().len() as _
}
pub(crate) unsafe extern "C" fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) unsafe extern "C" fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
pub(crate) unsafe extern "C" fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
pub(crate) unsafe extern "C" fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
pub(crate) unsafe extern "C" fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
pub(crate) unsafe extern "C" fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}

pub(crate) unsafe extern "C" fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status()
extern_system_fn! {
pub(crate) unsafe fn CreateKernelV2(
_: *const ort_sys::OrtCustomOp,
_: *const ort_sys::OrtApi,
info: *const ort_sys::OrtKernelInfo,
kernel_ptr: *mut *mut ort_sys::c_void
) -> *mut ort_sys::OrtStatus {
let kernel = match O::create_kernel(&KernelAttributes::new(info)) {
Ok(kernel) => kernel,
e => return e.into_status()
};
*kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast();
Ok(()).into_status()
}
}

extern_system_fn! {
pub(crate) unsafe fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus {
let context = KernelContext::new(context);
O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::<O::Kernel>() }, &context).into_status()
}
}

extern_system_fn! {
pub(crate) unsafe fn KernelDestroy(op_kernel: *mut ort_sys::c_void) {
drop(Box::from_raw(op_kernel.cast::<O::Kernel>()));
}
}

extern_system_fn! {
pub(crate) unsafe fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.name.as_ptr()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char {
let safe = Self::safe(op);
safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
}
}

extern_system_fn! {
pub(crate) unsafe fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::min_version()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::max_version()
}
}

extern_system_fn! {
pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType {
O::inputs()[index as usize].memory_type.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::inputs()[index as usize].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic {
O::outputs()[index as usize].characteristic.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::inputs().len() as _
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t {
O::outputs().len() as _
}
}
extern_system_fn! {
pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::inputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType {
O::outputs()[index as usize]
.r#type
.map(|c| c.into())
.unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("input minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::inputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_min_arity)
.unwrap_or(1)
.try_into()
.expect("output minimum arity overflows i32")
}
}
extern_system_fn! {
pub(crate) unsafe fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int {
O::outputs()
.into_iter()
.find(|c| c.characteristic == InputOutputCharacteristic::Variadic)
.and_then(|c| c.variadic_homogeneity)
.unwrap_or(false)
.into()
}
}

extern_system_fn! {
pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus {
O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status()
}
}
}

Expand Down
Loading