Skip to content

Commit

Permalink
progress
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsanbear committed Mar 15, 2024
1 parent e38c90a commit 5761311
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 2 deletions.
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ repository = "https://github.com/tomsanbear/candle-einops"
candle-core = { version = "0.4" }
candle-nn = { version = "0.4" }
candle-einops-macros = { path = "candle-einops-macros", version = "0.1.0" }
iter_tools = "0.10.0"
itertools = "0.12.1"

[dev-dependencies]
anyhow = "1"

[package.metadata.docs.rs]
no-default-features = true

[package.syn]
features = ["proc_macro_quote"]
1 change: 1 addition & 0 deletions candle-einops-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ proc-macro = true
proc-macro2 = "1"
quote = "1"
syn = { version = "2", features = ["full", "extra-traits"] }
itertools = "0.12"
83 changes: 83 additions & 0 deletions candle-einops-macros/src/einsum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use itertools::Itertools;
use proc_macro2::{Ident, Literal, TokenStream};
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::{Ref, Token};
use syn::{braced, token, Expr, Field, LitStr, Token};

pub fn einsum(input: proc_macro2::TokenStream) -> syn::Result<TokenStream> {
let parsed_expression: ParsedExpression = syn::parse2(input)?;
let code = quote! {
#parsed_expression
};
Ok(code)
}

/// Parses syntax for an einsum expression
/// einsum!("a b c, d e f -> a b c f", &x, &y)
/// where each comma delimited string before the arrow is an input tensor and the string after the arrow is the output tensor
struct ParsedExpression {}

impl Parse for ParsedExpression {
fn parse(input: ParseStream) -> syn::Result<Self> {
// gets the first argument, this determines how many input tensors we need
if !input.peek(LitStr) {
return Err(input.error("first argument must be a string literal"));
}
let expression: LitStr = input.parse()?;

// Iterate across the stream until we collect all tensors
let mut tensors: Vec<Ident> = vec![];
for _ in 0.. {
// advance past the comma
if !input.peek(Token![,]) {
break;
}
input.parse::<Token![,]>()?;

// parse the next tensor, should be an identifier or reference to an identifier
tensors.push(input.parse()?);
}

// Extract LHS and RHS of the expression, split on ->
let expression_str = expression.value();
let (lhs, rhs) = expression_str
.split("->")
.collect_tuple()
.ok_or_else(|| input.error("expected expression to contain ->"))?;
let lhs = lhs.trim();
let rhs = rhs.trim();

// Gather the tensors information from the lhs and rhs
let input_tensors = lhs.split(',').into_iter().map(|s| s.trim()).collect_vec();
let output_tensor = rhs.trim();

// Ensure the number of tensors provided matches the number of tensors in the expression
if tensors.len() != input_tensors.len() {
return Err(input.error(format!(
"expected {} input tensors, got {}",
input_tensors.len(),
tensors.len()
)));
}

// Get the shape of each input

return Err(input.error(format!(
"lhs {:?} rhs {:?} tensor count {:?}",
lhs,
rhs,
input_tensors.len()
)));
}
}

impl quote::ToTokens for ParsedExpression {
fn to_tokens(&self, tokens: &mut TokenStream) {
let expression = quote! { "a b c, d e f -> a b c f" };
tokens.extend(quote! {
let expression = #expression;
});
}
}
9 changes: 9 additions & 0 deletions candle-einops-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod einops;
mod einsum;

/// Macro to perform tensor transformations using simple expressions
///
Expand All @@ -13,3 +14,11 @@ pub fn einops(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
.unwrap_or_else(|e| e.to_compile_error())
.into()
}

/// Macro to perform einsum, equivalent to torch.einsum
#[proc_macro]
pub fn einsum(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
einsum::einsum(input.into())
.unwrap_or_else(|e| e.to_compile_error())
.into()
}
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
mod backend;

pub use candle_einops_macros::einops;

pub use backend::Backend;
pub use candle_einops_macros::einops;
pub use candle_einops_macros::einsum;

/// Specifies the operation used to reduce an axis
#[derive(Copy, Clone, Debug)]
Expand Down
12 changes: 12 additions & 0 deletions tests/einsum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_einops::einsum;
#[test]
fn simple() -> Result<()> {
let device = Device::Cpu;
let x = &Tensor::new(&[1, 2, 3, 4], &device)?;
let y = &Tensor::new(&[1, 2, 3, 4], &device)?;
einsum!("a b c, d e f -> a e f", x, y)?;

Ok(())
}

0 comments on commit 5761311

Please sign in to comment.