Lego (name in progress) [WIP]

Lego is a library experimenting with multi-stage programming in Rust.

Introduction

If you have been programming in Rust, it is safe to assume that you have used one form or another of meta-programming. In essence, meta-programming lets us write code that writes code for us. In Rust, there are two main methods for meta-programming:

  • macros, which use substitution to generate code, effectively manipulating AST nodes (even more evident with procedural macros). Here's an example of a macro that generates a my_add function for a variety of types:
trait MyAdd {
    fn my_add(self, other: Self) -> Self;
}

macro_rules! impl_my_add_for {
    ($ty:ident) => {
        impl MyAdd for $ty {
            fn my_add(self, other: Self) -> Self {
                self + other
            }
        }
    };
}

impl_my_add_for!(usize);
impl_my_add_for!(u64);
// ...

In this example, $ty is substituted for the argument passed to the macro invocation. For each invocation of impl_my_add_for!, new code is generated, which is in turn compiled to machine code. This example is very straightforward, but macros are a powerful tool for writing code that writes code.
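To see the substitution at work, here is a self-contained, runnable version of the example above (using the `$ty:ty` fragment specifier, one of the valid choices here), together with a usage check:

```rust
// Trait and macro from the text, made compilable end to end.
trait MyAdd {
    fn my_add(self, other: Self) -> Self;
}

macro_rules! impl_my_add_for {
    ($ty:ty) => {
        impl MyAdd for $ty {
            fn my_add(self, other: Self) -> Self {
                self + other
            }
        }
    };
}

// Each invocation below expands to a distinct impl block,
// which is then compiled to machine code like handwritten code.
impl_my_add_for!(usize);
impl_my_add_for!(u64);

fn main() {
    assert_eq!(2usize.my_add(3), 5);
    assert_eq!(40u64.my_add(2), 42);
    println!("ok");
}
```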

  • generic programming: Rust's traits and generics also let us write code that writes code. However, contrary to macros, where we manipulate the AST, generic programming lets us define constraints and lets the compiler generate implementations as needed:
trait MyAdd {
    fn my_add(self, other: Self) -> Self;
}

impl<T: Add<Output = T>> MyAdd for T {
    fn my_add(self, other: Self) -> Self {
        Add::add(self, other)
    }
}

In this example, we say that any type T that satisfies Add also satisfies MyAdd, and declare how. When the compiler finds a call to my_add on a type that satisfies this constraint, it generates code specialized for that particular type:

2u64.my_add(3u64); // u64 implements Add, therefore it implements MyAdd: the compiler generates a specialized implementation of my_add for u64.
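Put together, the blanket impl works as described; here is a self-contained sketch showing the compiler monomorphizing my_add for two different types:

```rust
use std::ops::Add;

trait MyAdd {
    fn my_add(self, other: Self) -> Self;
}

// Blanket impl: any T that can be added to itself gets my_add for free.
impl<T: Add<Output = T>> MyAdd for T {
    fn my_add(self, other: Self) -> Self {
        Add::add(self, other)
    }
}

fn main() {
    // The compiler generates a u64-specialized my_add here...
    assert_eq!(2u64.my_add(3u64), 5);
    // ...and a separate f32-specialized one here.
    assert_eq!(1.5f32.my_add(0.5), 2.0);
    println!("ok");
}
```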

Macros and traits are very often used together; they are powerful tools that help us write correct and fast code. However, they are limited to compile time. This means that we need to know, ahead of (run)time, all the implementations that we will need. For many things this constraint is just fine. When it's not, Rust provides programming constructs to get the job done. The most widely used has to be dynamic dispatch (aka dyn Trait). Note that we're now exiting the realm of meta-programming.
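For contrast with the compile-time mechanisms above, here is a minimal sketch of dynamic dispatch (the trait and types here are illustrative, not from lego):

```rust
trait Describe {
    fn describe(&self) -> String;
}

impl Describe for u64 {
    fn describe(&self) -> String {
        format!("u64: {self}")
    }
}

impl Describe for bool {
    fn describe(&self) -> String {
        format!("bool: {self}")
    }
}

fn main() {
    // The concrete type behind each Box is only known at runtime;
    // each call goes through a vtable instead of being monomorphized.
    let values: Vec<Box<dyn Describe>> = vec![Box::new(7u64), Box::new(true)];
    for v in &values {
        println!("{}", v.describe());
    }
}
```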

Let's make this more concrete, and hopefully illustrate the difference between runtime and compile time. To that end, we'll introduce a mini jq-like program that processes lists of JSON objects. Our jq-like processor has the following operations:

  • map: apply some transformation to the incoming JSON objects
  • filter: select incoming JSON objects that satisfy some predicate
  • reduce: aggregate incoming objects given some aggregation function

On top of that, we have some expression language that lets us manipulate JSON values. Here's an example expression:

// our json objects have the shape: { "x": number, "y": number };
map { x: x + 1, ..} | filter .y > 10 | reduce { output: 0 } { output: .output + .x }

In this example, for each object in the input:

  • we update the x field, adding 1 to it; the rest of the object is unchanged
  • we filter out any object whose y field is 10 or less
  • we collect the sum of the xs in the output field of the output object, providing an initial value { output: 0 }.

Let's ignore the lousy semantics of our language for a bit, and imagine what translating that expression to Rust would look like. Iterators are a great fit for this type of computation:

struct Object {
    x: f64,
    y: f64,
}

struct Output {
    output: f64
}

fn main() {
    let objects: Vec<Object> = read_json(..);

    objects
        .into_iter()
        // map { x: x + 1, ..}
        .map(|mut o| { o.x += 1.0; o})
        // filter .y > 10 
        .filter(|o| o.y > 10.0)
        // reduce { output: 0 } { output: .output + .x }
        .fold(Output { output: 0.0 }, |mut output, o| {
            output.output += o.x;
            output
        });
}

This example illustrates how trivially we can write a Rust expression equivalent to our jq-like expression. The generated code is very efficient, and the source is very readable. This is a short snippet, and yet quite a lot is happening behind the scenes: it will be compiled down to something roughly equivalent to:

let mut output = Output { output: 0.0 };
for mut o in objects {
    o.x += 1.0;
    if !(o.y > 10.0) {
        continue
    }

    output.output += o.x;
}

The reason for this is that the compiler has all the information it needs at compile time to generate a specialized iterator, perform inlining, and apply a plethora of other optimizations that output efficient machine code.

However, at runtime, we don't get any of that. We must be prepared for arbitrary combinations of map, filter, and reduce, arbitrary expressions for those operators, arbitrary input types, etc. All this generality has a cost. We obviously cannot generate efficient implementations for all combinations at compile time (there would be an infinite number of them), so we have to resort to dynamic approaches (in pseudo-rust):

enum Value {
    Number(f64),
    Bool(bool),
    Array(Vec<Self>),
    Map(HashMap<String, Self>),
    Null,
}

enum Operator {
    /// A source of values
    Source(Box<dyn Iterator<Item = Value>>),
    /// apply expr on every element of inner
    Map {
        inner: Box<Self>,
        expr: Expr,
    },
    /// keep only element for which Expr evaluates to true
    Filter {
        inner: Box<Self>,
        expr: Expr,
    },
    /// iteratively evaluate Expr with Value and the incoming object,
    /// updating value with the output of Expr on each iteration
    Reduce {
        inner: Box<Self>,
        init: Value,
        expr: Expr,
    },
}

impl Operator {
    // given a jq-like expression, build an operator, wrapping src
    fn parse(input: &str, src: Box<dyn Iterator<Item = Value>>) -> Self {
        // ...
    }

    // we can still use iterators for Operator
    // The operator chain is turned into a composition of Box<dyn Iterator<Value>>.
    fn into_iterator(self) -> Box<dyn Iterator<Item = Value>> {
        match self {
            // nothing to do on a source iterator, just return it
            Operator::Source(iterator) => iterator,
            // apply expr to every item from inner
            Operator::Map { inner, expr } => {
                let it = inner.into_iterator();
                Box::new(it.map(|o| expr.eval(&[o])))
            },
            // apply the filter expr to every item from inner
            Operator::Filter { inner, expr } => {
                let it = inner.into_iterator();
                Box::new(it.filter(|o| { 
                    expr.eval(&[o]).into_bool()
                }))
            },
            // reduce input iter
            Operator::Reduce { inner, init, expr } => {
                let it = inner.into_iterator();
                Box::new(std::iter::once(it.fold(init, |acc, o| {
                    expr.eval(&[acc, o])
                })))
            },
        }
    }
}

// I will not go into too much details in the Expr language. I'll show how a single operator could be implemented,
// and let you generalize.
enum Expr {
    Add,
    // .. any operator you can think of
}

impl Expr {
    fn eval(&self, args: &[Value]) -> Value {
        match self {
            Self::Add => {
                // match by reference so we don't try to move out of the slice
                match (&args[0], &args[1]) {
                    (Value::Number(l), Value::Number(r)) => Value::Number(l + r),
                    _ => unimplemented!("feel free to generalize"),
                }
            }
        }
    }
}

fn main() {
    let objects: Vec<Object> = read_json(..);
    let op = Operator::parse(input_expr, Box::new(objects.into_iter()));

    op.into_iterator().next().unwrap(); // <- the result of our chain
}

For our example expression, the operator would look something like:

Operator::Reduce {
    init: Value::Map(/* { output: 0.0 } */),
    expr: Expr::parse("{ output: .output + .x }"),
    inner: Operator::Filter {
        expr: Expr::parse(".y > 10.0"),
        inner: Operator::Map {
            expr: Expr::parse("{ x: .x + 1, ..}"),
            inner: Operator::Source(...),
        },
    },
}

Now, there is nothing wrong with this implementation. As a matter of fact, this is usually how you would implement something like this: it is essentially an interpreter for a DSL. Our implementation uses a tree-walking approach for expressions, and iterators and dynamic dispatch to compose pipeline elements. jq's implementation is more clever, and uses a bytecode interpreter to squeeze out more performance, but essentially these approaches are similar: we trade performance for runtime generality.
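To make the runtime-composition idea concrete without the full Expr machinery, here is a minimal, self-contained sketch of the example pipeline built from boxed iterators at runtime. Objects are simplified to (x, y) tuples instead of a Value enum, and the function name is illustrative:

```rust
// Our example pipeline, composed dynamically via Box<dyn Iterator>:
// map { x: x + 1, ..} | filter .y > 10 | reduce { output: 0 } { output: .output + .x }
fn run_pipeline(objects: Vec<(f64, f64)>) -> f64 {
    // map { x: x + 1, .. }
    let it: Box<dyn Iterator<Item = (f64, f64)>> =
        Box::new(objects.into_iter().map(|(x, y)| (x + 1.0, y)));
    // filter .y > 10
    let it: Box<dyn Iterator<Item = (f64, f64)>> =
        Box::new(it.filter(|&(_, y)| y > 10.0));
    // reduce { output: 0 } { output: .output + .x }
    it.fold(0.0, |acc, (x, _)| acc + x)
}

fn main() {
    let objects = vec![(1.0, 5.0), (2.0, 20.0), (3.0, 30.0)];
    // (1.0, 5.0) is filtered out after mapping; 3.0 + 4.0 = 7.0
    println!("{}", run_pipeline(objects)); // prints 7
}
```

Every stage here goes through a virtual call per item, which is exactly the generality-for-performance trade-off described above.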

What about Lego

Sorry for that (long) digression on meta-programming and runtime vs. compile time, but it's important to have these concepts in mind to understand what lego attempts to offer.

In essence, what lego does is bring meta-programming to the runtime stage, in a way that feels natural for Rust and provides semantics close to Rust's. Lego provides primitives that wrap Rust's and behave like them in appearance, but they are really just a way to describe computations that are then, at runtime, compiled to specialized machine code. In a way, you can think of lego as a way to JIT Rust at runtime, although lego is more constraining. An example speaks more than a thousand words, so let's dive in.

The entry points to lego are functions. Here's how you define a lego function:

// This is the context in which we will be building our functions. Functions and their data are attached to this context
let builder = Ctx::builder();
let mut ctx = builder.build();

// we create a new function that takes a slice of i32's and returns a single i32
let my_func = ctx.func::<&[i32], i32>(|items: Slice<i32>| -> Val<i32> {
    // TODO: function body
    0i32.value()
});

This doesn't do much yet, but there are a couple of things to point out already:

  • The type of items is Slice<i32>, and not &[i32]. Remember, we are not operating on the data itself, but on its representation.
  • the return type of the closure is Val<i32>: this is a representation of an i32.
  • 0i32.value() constructs a constant Val<i32> out of a Rust integer.
  • ctx.func will evaluate the passed closure immediately and generate a native function that performs the computation described by the closure body.

Let's make my_func compute the sum of the elements of items with 1 added to them:

// we create a new function that takes a slice of i32's and returns a single i32
let my_func = ctx.func::<&[i32], i32>(|items: Slice<i32>| -> Val<i32> {
    let sum = items
        .into_jiter()
        .map(|it| it.deref() + 1i32)
        .fold(0i32.value(), |acc, it| {
            acc + it
        });

    sum
});

Let's break this up:

  • into_jiter transforms the Slice<i32> of items into something that implements JITerator<Item = Ref<i32>>. JITerator is like a Rust Iterator, except that calling next on it will not return an immediate value, but rather a handle to the result of that computation.
  • Ref<T> is an immutable reference to a T. You can get a Val<T> out of a Ref<T> by calling deref() on it.
  • map wraps our jiterator into another jiterator that applies the passed closure to the items of the first iterator. In the closure, we call deref on it to get its value, and add 1 to it. + and Rust operators in general are overloaded in such a way that you can naturally mix wrapper types with the wrapped types that would usually take part in the operation, e.g. u32 + Val<u32> and Val<u32> + Val<u32> are both accepted.
  • fold accumulates the result of the expression, with 0 as the initial value.

We can now obtain a reference to the native function and invoke it:

// get a reference to the native function
let my_func_native = ctx.get_compiled_function(my_func);

let items = &[1, 2, 3, 4];
println!("{}", my_func_native.call(items));
// prints: 14

One thing to understand here is that my_func_native is effectively a thin wrapper around a raw function pointer, ensuring that the function is called with the correct types. The iterator in the function has been compiled down to a single loop, just like Rust would have done. If you are interested, here's the assembly Cranelift generates for this function:

  pushq   %rbp
  unwind PushFrameRegs { offset_upward_to_caller_sp: 16 }
  movq    %rsp, %rbp
  unwind DefineNewFrame { offset_upward_to_caller_sp: 16, offset_downward_to_clobbers: 0 }
block0:
  xorq    %rdx, %rdx, %rdx
  xorl    %eax, %eax, %eax
  jmp     label1
block1:
  cmpq    %rdi, %rdx
  jnz     label3; j label2
block2:
  movq    %rbp, %rsp
  popq    %rbp
  ret
block3:
  imulq   %rdx, 0x4, %r10
  lea     1(%rdx), %rdx
  movl    0(%rsi,%r10,1), %r10d
  lea     1(%rax,%r10,1), %eax
  jmp     label1

This simple example should give you a sense of how close to actual Rust writing lego functions feels. We benefit from Rust's type safety and operator overloading, and lego's high-level constructs let us write very simple code that specializes at runtime.

To fully grasp lego's power, it is necessary to understand that whatever is in the func closure's body is effectively partially evaluated: every expression that can be evaluated when the closure is invoked is evaluated. This means that we can combine Rust code that is evaluated right away with Rust code that will be evaluated later (when the native function is invoked), specializing the latter. Let's illustrate that:

fn main() {
    let builder = Ctx::builder();
    let mut ctx = builder.build();

    // we get n at runtime
    let n = std::env::args().nth(1).unwrap().parse::<i32>().unwrap();

    // my_pow generates a specialized pow(x, n) for an arbitrary n, known at runtime
    let my_pow = ctx.func::<i32, i32>(|x| {
        let mut out = Var::new(x);
        for _ in 1..n {
            out *= x;
        }

        out.value()
    });

    let main = ctx.get_compiled_function(my_pow);

    let x = std::env::args().nth(2).unwrap().parse::<i32>().unwrap();

    println!("{}", main.call(x));
    // cargo r -- 3 3; prints 27
}

This is the generated asm with x = 3 and n = 3:

  pushq   %rbp
  unwind PushFrameRegs { offset_upward_to_caller_sp: 16 }
  movq    %rsp, %rbp
  unwind DefineNewFrame { offset_upward_to_caller_sp: 16, offset_downward_to_clobbers: 0 }
block0:
  movq    %rdi, %rax
  imull   %eax, %edi, %eax
  imull   %eax, %edi, %eax
  movq    %rbp, %rsp
  popq    %rbp
  ret

This is the generated asm with x = 3 and n = 4:

  pushq   %rbp
  unwind PushFrameRegs { offset_upward_to_caller_sp: 16 }
  movq    %rsp, %rbp
  unwind DefineNewFrame { offset_upward_to_caller_sp: 16, offset_downward_to_clobbers: 0 }
block0:
  movq    %rdi, %rax
  imull   %eax, %edi, %eax
  imull   %eax, %edi, %eax
  imull   %eax, %edi, %eax
  movq    %rbp, %rsp
  popq    %rbp
  ret

The consecutive multiplications get unrolled. That's because all the information we need to evaluate the loop is known when we execute the closure.
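A plain-Rust analogy may help (make_pow is an illustrative name, not lego's API): we can already specialize pow for a runtime n by capturing it in a closure. The key difference is that the closure below still executes the loop on every call, whereas lego evaluates the loop once, at function-build time, and emits straight-line multiplications:

```rust
// Closure-based "specialization": n is fixed at closure-creation time,
// but the loop is still interpreted on every call.
fn make_pow(n: u32) -> impl Fn(i32) -> i32 {
    move |x| (0..n).fold(1, |acc, _| acc * x)
}

fn main() {
    let pow3 = make_pow(3);
    println!("{}", pow3(3)); // prints 27
}
```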

I have introduced lego as a multi-stage programming library. So far, we've seen two stages:

  • Code that is evaluated to an immediate result within func; let's call this stage runtime compile time.
  • Code that is evaluated when the generated function is invoked; let's call it runtime runtime.

But why stop here? We can hook into Rust's existing meta-programming and go one level deeper: we already have functions that generate functions at runtime, so why not a function that generates, at compile time, functions that generate functions at runtime?

fn do_pow_spec<T>(ctx: &mut Ctx, n: usize)
// ignore the type bounds, I'm hoping I can simplify them in the future
where
    T: Param + Primitive + FromStr + ToFFIParams + Display + IntMul,
    T::Out<Bottom>: ToFFIFunctionParams,
    T::Ty: AsVal<Ty = T> + Copy,
    Var<T>: MulAssign<T::Ty>,
{
    let my_pow = ctx.func::<T, T>(|x| {
        let mut out = Var::new(x);
        for _ in 1..n {
            out *= x;
        }

        out.value()
    });

    let x = std::env::args().nth(2).unwrap().parse::<T>().unwrap_or_else(|_| panic!());
    let res = ctx.get_compiled_function(my_pow).call(x);
    println!("{res}");
}

fn main() {
    let builder = Ctx::builder();
    let mut ctx = builder.build();

    let spec = std::env::args().nth(3).unwrap();

    let n = std::env::args().nth(1).unwrap().parse::<usize>().unwrap_or_else(|_| panic!());

    match spec.as_str() {
        "u64" => do_pow_spec::<u64>(&mut ctx, n),
        "i32" => do_pow_spec::<i32>(&mut ctx, n),
        _ => panic!("unsupported specialization"),
    }
}

do_pow_spec introduces a third stage: at compile time, it produces specializations of func for arbitrary integer types, as we can see in main.

Disclaimer

This work is very experimental. A lot of implementations are missing, the API is not in the shape I want it to be, and there are a lot of missing features and likely bugs. This is public for discussion only.

Inspiration