Skip to content

Commit

Permalink
Merge pull request #36 from Ebanflo42/graph/output
Browse files Browse the repository at this point in the history
Multiple outputs for compiled compute graph
  • Loading branch information
BradenEverson authored Feb 25, 2024
2 parents 8850017 + 8ef60f2 commit 0e0c847
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 25 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 34 additions & 16 deletions src/core/graph/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub enum CompileError {
#[error("Found an unused Parameter in the compute graph {0}")]
UnusedParameter(Callsite),

#[error("Unable to compile a context that does not return")]
NoReturn,

#[error("XLA error: {0}")]
Xla(#[from] xla::Error),
}
Expand Down Expand Up @@ -50,43 +53,56 @@ impl Context {
}
}

pub fn compile<A: Into<NodeIdentifier> + Copy>(
pub fn compile<A: Into<NodeIdentifier> + Copy, const N: usize>(
&mut self,
a: A,
name: &str,
returns: [A; N],
client: &xla::PjRtClient,
) -> Result<xla::PjRtLoadedExecutable> {
// TODO: gate debug mode behind a feature flag

//self.autodiff(a, usize::MAX);
println!("{}", self.to_string(a));
while self.autodiff(a, 1)? {
println!("{}", self.to_string(a));
if returns.is_empty() {
Err(CompileError::NoReturn)?;
}

//self.foldconsts(a, usize::MAX);
while self.foldconsts(a, 1)? {
println!("{}", self.to_string(a));
for a in returns.iter() {
self.autodiff(*a, usize::MAX)?;
}
//println!("{}", self.to_string(a));
//while self.autodiff(a, 1)? {
// println!("{}", self.to_string(a));
//}

//self.extract_subterms(a, usize::MAX);
while self.extract_subterms(a, 1)? {
println!("{}", self.to_string(a));
for a in returns.iter() {
self.foldconsts(*a, usize::MAX)?;
}
//while self.foldconsts(a, 1)? {
// println!("{}", self.to_string(a));
//}

for a in returns.iter() {
self.extract_subterms(*a, usize::MAX)?;
}
//while self.extract_subterms(a, 1)? {
// println!("{}", self.to_string(a));
//}

let builder = xla::XlaBuilder::new(name);
// Get the bottom-up dependencies of the compute graph
let mut dependent_nodes = HashMap::new();
let mut constants = HashSet::new();
let mut parameters = HashSet::new();
self.get_dependent_nodes(a, &mut dependent_nodes, &mut constants, &mut parameters)?;
for a in returns.iter() {
self.get_dependent_nodes(*a, &mut dependent_nodes, &mut constants, &mut parameters)?;
}

// Prepare to loop through the unda compute graph and construct the XLA compute graph
let mut xla_op_slotmap: SlotMap<NodeIdentifier, xla::XlaOp> = SlotMap::with_key();
let mut unda_op_queue: VecDeque<NodeIdentifier> = VecDeque::new();
let mut unda_xla_map: HashMap<NodeIdentifier, NodeIdentifier> = HashMap::new();
let mut covered_ops: HashSet<NodeIdentifier> = HashSet::new();

let builder = xla::XlaBuilder::new(name);

// declare parameters with the XLA builder
for (i, unda_id) in self.param_indices.iter().enumerate() {
let node = &self.nodes[*unda_id];
Expand Down Expand Up @@ -116,7 +132,6 @@ impl Context {
}

// Initialize constants for the XLA builder
// >1 dimensions not yet supported for constants
for unda_id in constants.iter() {
let node = &self.nodes[*unda_id];

Expand Down Expand Up @@ -182,7 +197,10 @@ impl Context {
}
}

let xla_computation = xla_op_slotmap[unda_xla_map[&a.into()]].build()?;
let xla_return_vec: Vec<&xla::XlaOp> = returns.into_iter().map(|i| &xla_op_slotmap[unda_xla_map[&i.into()]]).collect();
let xla_return_tuple = builder.tuple(&xla_return_vec.as_slice())?;

let xla_computation = xla_return_tuple.build()?;

Ok(xla_computation.compile(client)?)
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/graph/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ pub enum ContextError {

#[error("Parameter \"{0}\" {1} already exists in the context at {2}")]
DuplicateParameter(String, Callsite, Callsite),

#[error("Tried to call Context::return more than once.")]
MultipleReturns(),
}

pub type Result<T> = std::result::Result<T, ContextError>;
Expand Down
55 changes: 49 additions & 6 deletions src/core/graph/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod tests {
// client must be exposed to the user, it is very nice to control device, memory fraction, and pre-allocation
let client = xla::PjRtClient::gpu(0.7, false).expect("client");
let name = "test";
let executable = ctx.compile(sum, &name, &client).expect("executable");
let executable = ctx.compile(&name, [sum], &client).expect("executable");

let x_input = xla::Literal::scalar(2f32);
let y_input = xla::Literal::scalar(3f32);
Expand All @@ -30,7 +30,8 @@ mod tests {
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let rust_result = host_result.to_vec::<f32>().expect("to_vec");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<f32>().expect("to_vec");
println!("{:?}", rust_result);

assert_eq!(rust_result[0], 9f32);
Expand All @@ -49,13 +50,14 @@ mod tests {
let barbaz = ctx.mul(bar, baz).expect("barbaz");

let client = xla::PjRtClient::gpu(0.7, false).expect("client");
let executable = ctx.compile(barbaz, "test", &client).expect("executable");
let executable = ctx.compile("test", [barbaz], &client).expect("executable");

let device_result = executable.execute::<xla::Literal>(&[]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let f32_result = host_result
let untupled_result = host_result.to_tuple1().expect("untuple");
let f32_result = untupled_result
.convert(xla::ElementType::F32.primitive_type())
.expect("f32 conversion");
let rust_result = f32_result.to_vec::<f32>().expect("to_vec");
Expand All @@ -79,16 +81,57 @@ mod tests {
let sum = ctx.add(my_const, my_param).expect("sum");

let client = xla::PjRtClient::gpu(0.7, false).expect("client");
let executable = ctx.compile(sum, "test", &client).expect("executable");
let executable = ctx.compile("test", [sum], &client).expect("executable");

let my_param_input = xla::Literal::read_npy("test.npy", &()).expect("my_param_input");

let device_result = executable.execute(&[my_param_input]).expect("execute");
let host_result = device_result[0][0]
.to_literal_sync()
.expect("to_literal_sync");
let rust_result = host_result.to_vec::<i64>().expect("to_vec");
let untupled_result = host_result.to_tuple1().expect("untuple");
let rust_result = untupled_result.to_vec::<i64>().expect("to_vec");
println!("{:?}", rust_result);
assert_eq!(rust_result.as_slice(), &[0, 2, 2, 0]);
}

#[test]
fn test_multiple_outputs() {
    // Verifies that a context compiled with several return nodes yields an
    // XLA tuple whose elements evaluate each requested output independently.
    let mut ctx = Context::new();

    let three = ctx.scalar(3, xla::ElementType::F32).expect("three");

    let x = ctx.parameter("x", [], xla::ElementType::F32).expect("x");
    let y = ctx.parameter("y", [], xla::ElementType::F32).expect("y");

    // Three outputs that share subexpressions:
    //   product = 3 * x, sum = product + y, sum2 = 3 + x
    let product = ctx.mul(x, three).expect("product");
    let sum = ctx.add(product, y).expect("sum");
    let sum2 = ctx.add(three, x).expect("sum2");

    // The client must be exposed to the user; it controls device choice,
    // memory fraction, and pre-allocation.
    let client = xla::PjRtClient::gpu(0.7, false).expect("client");
    let executable = ctx
        .compile("test", [sum, product, sum2], &client)
        .expect("executable");

    // Arguments are provided in the order the parameters were defined: x, then y.
    let x_input = xla::Literal::scalar(2f32);
    let y_input = xla::Literal::scalar(3f32);
    let device_result = executable.execute(&[x_input, y_input]).expect("execute");
    let host_result = device_result[0][0]
        .to_literal_sync()
        .expect("to_literal_sync");

    // Untuple in the same order the nodes were passed to `compile`.
    let (eval_sum, eval_product, eval_sum2) = host_result.to_tuple3().expect("untuple");

    let sum_result = eval_sum.to_vec::<f32>().expect("to_vec");
    assert_eq!(sum_result[0], 9f32); // 3*2 + 3

    let product_result = eval_product.to_vec::<f32>().expect("to_vec");
    assert_eq!(product_result[0], 6f32); // 3*2

    let sum2_result = eval_sum2.to_vec::<f32>().expect("to_vec");
    assert_eq!(sum2_result[0], 5f32); // 3 + 2
}
}

0 comments on commit 0e0c847

Please sign in to comment.