diff --git a/Cargo.lock b/Cargo.lock index 10a0a24..271d195 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,9 +136,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.86" +version = "1.0.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" +checksum = "02f341c093d19155a6e41631ce5971aac4e9a868262212153124c15fa22d1cdc" dependencies = [ "libc", ] @@ -1195,7 +1195,7 @@ checksum = "0770833d60a970638e989b3fa9fd2bb1aaadcf88963d1659fd7d9990196ed2d6" [[package]] name = "xla" version = "0.1.6" -source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#a1b250fd0cf8f613449f793dc261fa0c22da0bf4" +source = "git+https://github.com/Ebanflo42/xla-rs?branch=dev#4293f038ee1b8466da4bb0e0859413d6ea8a1aca" dependencies = [ "bindgen", "cc", diff --git a/src/core/graph/compile.rs b/src/core/graph/compile.rs index 4f3fea9..ca089ec 100644 --- a/src/core/graph/compile.rs +++ b/src/core/graph/compile.rs @@ -11,6 +11,9 @@ pub enum CompileError { #[error("Found an unused Parameter in the compute graph {0}")] UnusedParameter(Callsite), + #[error("Unable to compile a context that does not return")] + NoReturn, + #[error("XLA error: {0}")] Xla(#[from] xla::Error), } @@ -50,36 +53,47 @@ impl Context { } } - pub fn compile + Copy>( + pub fn compile + Copy, const N: usize>( &mut self, - a: A, name: &str, + returns: [A; N], client: &xla::PjRtClient, ) -> Result { // TODO: gate debug mode behind a feature flag - //self.autodiff(a, usize::MAX); - println!("{}", self.to_string(a)); - while self.autodiff(a, 1)? { - println!("{}", self.to_string(a)); + if returns.is_empty() { + Err(CompileError::NoReturn)?; } - //self.foldconsts(a, usize::MAX); - while self.foldconsts(a, 1)? { - println!("{}", self.to_string(a)); + for a in returns.iter() { + self.autodiff(*a, usize::MAX)?; } + //println!("{}", self.to_string(a)); + //while self.autodiff(a, 1)? { + // println!("{}", self.to_string(a)); + //} - //self.extract_subterms(a, usize::MAX); - while self.extract_subterms(a, 1)? { - println!("{}", self.to_string(a)); + for a in returns.iter() { + self.foldconsts(*a, usize::MAX)?; } + //while self.foldconsts(a, 1)? { + // println!("{}", self.to_string(a)); + //} + + for a in returns.iter() { + self.extract_subterms(*a, usize::MAX)?; + } + //while self.extract_subterms(a, 1)? { + // println!("{}", self.to_string(a)); + //} - let builder = xla::XlaBuilder::new(name); // Get the bottom-up dependencies of the compute graph let mut dependent_nodes = HashMap::new(); let mut constants = HashSet::new(); let mut parameters = HashSet::new(); - self.get_dependent_nodes(a, &mut dependent_nodes, &mut constants, &mut parameters)?; + for a in returns.iter() { + self.get_dependent_nodes(*a, &mut dependent_nodes, &mut constants, &mut parameters)?; + } // Prepare to loop through the unda compute graph and construct the XLA compute graph let mut xla_op_slotmap: SlotMap = SlotMap::with_key(); @@ -87,6 +101,8 @@ impl Context { let mut unda_xla_map: HashMap = HashMap::new(); let mut covered_ops: HashSet = HashSet::new(); + let builder = xla::XlaBuilder::new(name); + // declare parameters with the XLA builder for (i, unda_id) in self.param_indices.iter().enumerate() { let node = &self.nodes[*unda_id]; @@ -116,7 +132,6 @@ impl Context { } // Initialize constants for the XLA builder - // >1 dimensions not yet supported for constants for unda_id in constants.iter() { let node = &self.nodes[*unda_id]; @@ -182,7 +197,10 @@ impl Context { } } - let xla_computation = xla_op_slotmap[unda_xla_map[&a.into()]].build()?; + let xla_return_vec: Vec<&xla::XlaOp> = returns.into_iter().map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]]).collect(); + let xla_return_tuple = builder.tuple(&xla_return_vec.as_slice())?; + + let xla_computation = xla_return_tuple.build()?; Ok(xla_computation.compile(client)?) } diff --git a/src/core/graph/context.rs b/src/core/graph/context.rs index c8c2e7e..ea8c130 100644 --- a/src/core/graph/context.rs +++ b/src/core/graph/context.rs @@ -39,6 +39,9 @@ pub enum ContextError { #[error("Parameter \"{0}\" {1} already exists in the context at {2}")] DuplicateParameter(String, Callsite, Callsite), + + #[error("Tried to call Context::return more than once.")] + MultipleReturns(), } pub type Result = std::result::Result; diff --git a/src/core/graph/tests.rs b/src/core/graph/tests.rs index 21bf10b..bbc5b5e 100644 --- a/src/core/graph/tests.rs +++ b/src/core/graph/tests.rs @@ -19,7 +19,7 @@ mod tests { // client must be exposed to the user, it is very nice to control device, memory fraction, and pre-allocation let client = xla::PjRtClient::gpu(0.7, false).expect("client"); let name = "test"; - let executable = ctx.compile(sum, &name, &client).expect("executable"); + let executable = ctx.compile(&name, [sum], &client).expect("executable"); let x_input = xla::Literal::scalar(2f32); let y_input = xla::Literal::scalar(3f32); @@ -30,7 +30,8 @@ mod tests { let host_result = device_result[0][0] .to_literal_sync() .expect("to_literal_sync"); - let rust_result = host_result.to_vec::().expect("to_vec"); + let untupled_result = host_result.to_tuple1().expect("untuple"); + let rust_result = untupled_result.to_vec::().expect("to_vec"); println!("{:?}", rust_result); assert_eq!(rust_result[0], 9f32); @@ -49,13 +50,14 @@ mod tests { let barbaz = ctx.mul(bar, baz).expect("barbaz"); let client = xla::PjRtClient::gpu(0.7, false).expect("client"); - let executable = ctx.compile(barbaz, "test", &client).expect("executable"); + let executable = ctx.compile("test", [barbaz], &client).expect("executable"); let device_result = executable.execute::(&[]).expect("execute"); let host_result = device_result[0][0] .to_literal_sync() .expect("to_literal_sync"); - let f32_result = host_result + let untupled_result = host_result.to_tuple1().expect("untuple"); + let f32_result = untupled_result .convert(xla::ElementType::F32.primitive_type()) .expect("f32 conversion"); let rust_result = f32_result.to_vec::().expect("to_vec"); @@ -79,7 +81,7 @@ mod tests { let sum = ctx.add(my_const, my_param).expect("sum"); let client = xla::PjRtClient::gpu(0.7, false).expect("client"); - let executable = ctx.compile(sum, "test", &client).expect("executable"); + let executable = ctx.compile("test", [sum], &client).expect("executable"); let my_param_input = xla::Literal::read_npy("test.npy", &()).expect("my_param_input"); @@ -87,8 +89,49 @@ mod tests { let host_result = device_result[0][0] .to_literal_sync() .expect("to_literal_sync"); - let rust_result = host_result.to_vec::().expect("to_vec"); + let untupled_result = host_result.to_tuple1().expect("untuple"); + let rust_result = untupled_result.to_vec::().expect("to_vec"); println!("{:?}", rust_result); assert_eq!(rust_result.as_slice(), &[0, 2, 2, 0]); } + + #[test] + fn test_multiple_outputs() { + let mut ctx = Context::new(); + + let three = ctx.scalar(3, xla::ElementType::F32).expect("three"); + + let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x"); + let y = ctx.parameter("y", [], xla::ElementType::F32).expect("y"); + + let product = ctx.mul(x, three).expect("product"); + let sum = ctx.add(product, y).expect("sum"); + let sum2 = ctx.add(three, x).expect("sum2"); + + // output XLA + // client must be exposed to the user, it is very nice to control device, memory fraction, and pre-allocation + let client = xla::PjRtClient::gpu(0.7, false).expect("client"); + let name = "test"; + let executable = ctx.compile(&name, [sum, product, sum2], &client).expect("executable"); + + let x_input = xla::Literal::scalar(2f32); + let y_input = xla::Literal::scalar(3f32); + // args are just provided in the order they are defined, would be nice to pass a dict or something + // a pjrtbuffer is just an array slice on some device + // but im not sure why its a nested vector instead of just one vector + let device_result = executable.execute(&[x_input, y_input]).expect("execute"); + let host_result = device_result[0][0] + .to_literal_sync() + .expect("to_literal_sync"); + let (eval_sum, eval_product, eval_sum2) = host_result.to_tuple3().expect("untuple"); + let rust_result1 = eval_sum.to_vec::().expect("to_vec"); + println!("{:?}", rust_result1); + + assert_eq!(rust_result1[0], 9f32); + let rust_result2 = eval_product.to_vec::().expect("to_vec"); + assert_eq!(rust_result2[0], 6f32); + let rust_result3 = eval_sum2.to_vec::().expect("to_vec"); + assert_eq!(rust_result3[0], 5f32) + + } }