Skip to content

Commit

Permalink
[WIP] cranelift-runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
skewballfox committed Feb 22, 2025
1 parent 813f705 commit c0966de
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 94 deletions.
38 changes: 35 additions & 3 deletions crates/cubecl-cranelift/src/compiler/base.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,40 @@
use alloc::fmt::Debug;
use cranelift::prelude::FunctionBuilder;
use cranelift_codegen::ir::Function;
use cubecl_core::ExecutionMode;
use cubecl_core::{Compiler, ExecutionMode};

pub struct FunctionCompiler<'a> {
builder: FunctionBuilder<'a>,
use super::FfiFunction;

/// Compiler type of the Cranelift backend.
///
/// In this work-in-progress revision it stores only an [`ExecutionMode`]; the
/// Cranelift `FunctionBuilder` field is commented out.
#[derive(Clone)]
pub struct FunctionCompiler {
// WIP: the function builder is disabled for now.
// NOTE(review): presumably because of its lifetime parameter — confirm before re-enabling.
//builder: FunctionBuilder<'static>,
/// Execution mode stored alongside the compiler; printed by the `Debug` impl.
exec_mode: ExecutionMode,
}

// `Compiler` implementation for the Cranelift backend.
//
// Only the associated types are settled; both methods are unimplemented stubs
// that panic via `todo!()` when called.
impl Compiler for FunctionCompiler {
/// Result type of `compile`: an [`FfiFunction`].
type Representation = FfiFunction;

/// No compilation options are defined yet, so the unit type is used.
type CompilationOptions = ();

/// Compiles `kernel` into an [`FfiFunction`].
///
/// # Panics
///
/// Always panics: not implemented yet (`todo!()`).
fn compile(
&mut self,
kernel: cubecl_core::prelude::KernelDefinition,
compilation_options: &Self::CompilationOptions,
mode: ExecutionMode,
) -> Self::Representation {
todo!()
}

/// Returns a `usize` size for `elem`.
/// NOTE(review): presumably the size in bytes — confirm against the `Compiler` trait docs.
///
/// # Panics
///
/// Always panics: not implemented yet (`todo!()`).
fn elem_size(&self, elem: cubecl_core::ir::Elem) -> usize {
todo!()
}
}

impl Debug for FunctionCompiler {
    /// Formats the compiler as a `FunctionCompiler { exec_mode: .. }` debug struct.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited instead of silently omitting the new field.
        let Self { exec_mode } = self;

        let mut repr = f.debug_struct("FunctionCompiler");
        repr.field("exec_mode", exec_mode);
        repr.finish()
    }
}
21 changes: 21 additions & 0 deletions crates/cubecl-cranelift/src/compiler/ffi_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
Corresponds to wgpu/compiler/shader.rs. Represents a compiled, executable kernel function,
stored as a dynamically linked library.
*/

use alloc::fmt::Display;

use cubecl_core::compute::Binding;

/// Description of a compiled kernel function: its bindings and its name.
#[derive(Debug, Clone)]
pub struct FfiFunction {
/// Bindings recorded as the kernel's inputs.
/// NOTE(review): ordering/semantics of the bindings are not visible here — confirm with the compiler.
pub inputs: Vec<Binding>,
/// Bindings recorded as the kernel's outputs.
pub outputs: Vec<Binding>,
/// Name of the kernel; this is the only field shown by the `Display` impl.
pub kernel_name: String,
}

impl Display for FfiFunction {
    /// Writes a short, human-readable form that shows only the kernel name.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
        // The bindings are intentionally left out of the user-facing output.
        let Self { kernel_name, .. } = self;
        write!(f, "FfiFunction {{ kernel_name: {} }}", kernel_name)
    }
}
2 changes: 2 additions & 0 deletions crates/cubecl-cranelift/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ mod base;
// mod mma;
mod warp;

mod ffi_function;
pub use base::*;
pub use ffi_function::*;
// pub use body::*;
// pub use element::*;
// pub use instruction::*;
Expand Down
112 changes: 94 additions & 18 deletions crates/cubecl-cranelift/src/compute/server.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,42 @@
use std::future::Future;

use cranelift::prelude::FunctionBuilderContext;
use cubecl_core::{
compute::DebugInformation,
future,
prelude::*,
server::{Binding, Handle},
Feature, KernelId, MemoryConfiguration, WgpuCompilationOptions,
};
use cubecl_runtime::{
debug::{DebugLogger, ProfileLevel},
memory_management::MemoryDeviceProperties,
memory_management::{MemoryDeviceProperties, MemoryManagement},
server::{self, ComputeServer},
storage::BindingResource,
TimestampsError, TimestampsResult,
};
use std::future::Future;
use web_time::Instant;

/// Timestamp-collection state for kernel execution.
#[derive(Debug)]
enum KernelTimestamps {
/// Timestamps are enabled; `start_time` holds the instant captured when this
/// state was entered.
Inferred { start_time: Instant },
/// Timestamps are not being collected.
Disabled,
}

impl KernelTimestamps {
    /// Switches to `Inferred`, capturing the current instant as the start time.
    ///
    /// Does nothing when timestamps are already enabled, so the original start
    /// time is kept across repeated calls.
    fn enable(&mut self) {
        match *self {
            // Already enabled: keep the existing start time.
            Self::Inferred { .. } => {}
            Self::Disabled => {
                let start_time = Instant::now();
                *self = Self::Inferred { start_time };
            }
        }
    }

    /// Switches to `Disabled`, discarding any recorded start time.
    fn disable(&mut self) {
        *self = Self::Disabled;
    }
}

struct CompiledKernel {
cube_dim: CubeDim,
Expand All @@ -26,41 +49,79 @@ use hashbrown::HashMap;

use crate::compiler::FunctionCompiler;

use super::storage::CraneliftStorage;
use super::{storage::CraneliftStorage, CraneliftResource};
#[derive(Debug)]
pub struct CraneliftServer {
context: CraneLiftContext,
context: CraneliftContext,
logger: DebugLogger,
}

pub(crate) struct CraneLiftContext {
//Contains the state for
pub(crate) struct CraneliftContext {
builder_context: FunctionBuilderContext,
codegen_context: cranelift_codegen::Context,
timestamp: KernelTimestamps,
memory_management: MemoryManagement<CraneliftStorage>,
modules: HashMap<KernelId, cranelift_jit::JITModule>,
}

impl CraneliftContext {
/// Executes the kernel identified by `kernel_id` with the given `resources`.
///
/// # Panics
///
/// Always panics: not implemented yet (`todo!()`).
fn execute_task(&mut self, kernel_id: KernelId, resources: Vec<CraneliftResource>) {
// WIP sketch of the intended lookup.
// NOTE(review): `modules` is declared as a map to `JITModule`, not `CompiledKernel` — confirm the intended value type.
//let kernel: &CompiledKernel = self.modules.get(&kernel_id);
todo!()
}
}

impl ComputeServer for CraneliftServer {
type Kernel = Box<dyn for<'a> CubeTask<FunctionCompiler<'a>>>;
type Kernel = Box<dyn CubeTask<FunctionCompiler>>;
type Storage = CraneliftStorage;
type Feature = Feature;

fn read(
&mut self,
bindings: Vec<Binding>,
) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static {
todo!()
let mut result = Vec::with_capacity(bindings.len());
result.extend(bindings.into_iter().map(|binding| {
let rb = self.get_resource(binding);
let resource = rb.resource();
Vec::<u8>::from(resource)
}));
async move { result }
}

fn get_resource(&mut self, binding: Binding) -> BindingResource<Self> {
todo!()
BindingResource::new(
binding.clone(),
self.context
.memory_management
.get_resource(binding.memory, binding.offset_start, binding.offset_end)
.expect("Failed to find resource"),
)
}

fn create(&mut self, data: &[u8]) -> Handle {
todo!()
let alloc_handle = self.empty(data.len());
let alloc_binding = alloc_handle.clone().binding();
// maybe use rayon here?
let resource_dest = self
.context
.memory_management
.get_resource(
alloc_binding.memory,
alloc_binding.offset_start,
alloc_binding.offset_end,
)
.expect("Failed to find resource");
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), resource_dest.ptr, data.len());
}
alloc_handle
}

fn empty(&mut self, size: usize) -> Handle {
todo!()
let alloc_handle = self.context.memory_management.reserve(size as u64);
Handle::new(alloc_handle, None, None, size as u64)
}

unsafe fn execute(
Expand All @@ -70,37 +131,52 @@ impl ComputeServer for CraneliftServer {
bindings: Vec<Binding>,
kind: ExecutionMode,
) {
todo!()
// Note: Maybe this can be a function/trait in cubecl-core?
// Check for any profiling work to be done before execution.
let profile_level = self.logger.profile_level();
let profile_info = if profile_level.is_some() {
Some((kernel.name(), kernel.id()))
} else {
None
};

if let Some(level) = profile_level {
todo!()
} else {
}

//match count if needed
}

fn flush(&mut self) {
todo!()
}

fn sync(&mut self) -> impl std::future::Future<Output = ()> + Send + 'static {
todo!()
self.logger.profile_summary();
async move { todo!() }
}

fn sync_elapsed(
&mut self,
) -> impl std::future::Future<Output = TimestampsResult> + Send + 'static {
todo!()
async move { todo!() }
}

fn memory_usage(&self) -> cubecl_core::MemoryUsage {
todo!()
self.context.memory_management.memory_usage()
}

fn enable_timestamps(&mut self) {
todo!()
self.context.timestamp.enable();
}

fn disable_timestamps(&mut self) {
todo!()
self.context.timestamp.disable();
}
}

impl alloc::fmt::Debug for CraneLiftContext {
impl alloc::fmt::Debug for CraneliftContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
//None of the fields implement Debug. It might be possible to
//display some state, but I haven't worked out how to do that yet.
Expand Down
32 changes: 20 additions & 12 deletions crates/cubecl-cranelift/src/compute/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,47 @@ pub struct CraneliftStorage {
// SAFETY: NOTE(review): no justification is recorded for this impl. It asserts that
// `CraneliftStorage` may be moved to another thread; presumably the stored raw pointers
// are only ever accessed through the storage itself — confirm.
unsafe impl Send for CraneliftStorage {}

pub struct CraneliftResource {
ptr: *mut u8,
pub ptr: *mut u8,
offset: u64,
size: u64,
}

impl From<&CraneliftResource> for Vec<u8> {
    /// Copies the bytes described by `resource` into a freshly allocated `Vec<u8>`.
    ///
    /// The bytes read are `resource.size` bytes starting `resource.offset` bytes
    /// past `resource.ptr`. The resource's memory is only borrowed for the copy:
    /// the returned vector owns its own allocation, so dropping it never frees
    /// memory that belongs to the storage.
    ///
    /// # Panics
    ///
    /// Panics if `offset` or `size` does not fit in `usize`.
    fn from(resource: &CraneliftResource) -> Self {
        let offset = usize::try_from(resource.offset).expect("resource offset exceeds usize");
        let len = usize::try_from(resource.size).expect("resource size exceeds usize");

        // Nothing to read; also avoids building a slice from a possibly invalid pointer.
        if len == 0 {
            return Vec::new();
        }

        // SAFETY: relies on the resource invariant that `ptr + offset .. ptr + offset + size`
        // lies inside one live allocation and is not written to during the copy.
        // Unlike `Vec::from_raw_parts`, a slice only borrows the memory, so the
        // global allocator is never asked to free an interior pointer it did not hand out.
        let bytes = unsafe { core::slice::from_raw_parts(resource.ptr.add(offset), len) };
        bytes.to_vec()
    }
}

// `CraneliftResource` holds a raw `*mut u8`, so it is not `Send` automatically.
// SAFETY: NOTE(review): no justification is recorded for this impl. It is sound only if the
// pointed-to memory outlives the resource and is not accessed concurrently from another
// thread without synchronization — confirm.
unsafe impl Send for CraneliftResource {}

impl ComputeStorage for CraneliftStorage {
type Resource = CraneliftResource;

const ALIGNMENT: u64 = 8;
const ALIGNMENT: u64 = 8; //FIXME

fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
let ptr = self.memory.get(&handle.id).unwrap();
let offset = handle.offset();
let size = handle.size();

todo!()
CraneliftResource {
ptr: *ptr,
offset,
size,
}
}

fn alloc(&mut self, size: u64) -> StorageHandle {
let id = StorageId::new();
let ptr = unsafe {
alloc(Layout::from_size_align_unchecked(
next_2_power(size),
(size as usize).next_power_of_two(),
Self::ALIGNMENT as usize,
))
};
Expand All @@ -50,11 +66,3 @@ impl ComputeStorage for CraneliftStorage {
(self.memory.remove(&id).unwrap());
}
}

/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// Both `0` and `1` map to `1`.
///
/// # Panics
///
/// Panics if the result cannot be represented in a `u64` (`n > 2^63`) or does
/// not fit in `usize` on the current target. The previous shift loop wrapped to
/// zero and spun forever for such inputs.
fn next_2_power(n: u64) -> usize {
    let power = n
        .checked_next_power_of_two()
        .expect("next power of two overflows u64");
    usize::try_from(power).expect("next power of two does not fit in usize")
}
Loading

0 comments on commit c0966de

Please sign in to comment.