Skip to content

Commit

Permalink
fix: adaprive dimensionality #16
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Dec 7, 2024
1 parent 0e5b0da commit b297996
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 24 deletions.
18 changes: 4 additions & 14 deletions jaxadi/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def compute_heights(func, graph, antigraph):
def create_graph(func: Function):
N = func.n_instructions()
graph = [[] for _ in range(N)]
values = [None for _ in range(N)]
values = ["" for _ in range(N)]
antigraph = [[] for _ in range(N)]
output_map = {}
workers = [None for _ in range(func.sz_w())]
workers = [0 for _ in range(func.sz_w())]

for i in range(N):
op = func.instruction_id(i)
Expand All @@ -113,12 +113,7 @@ def create_graph(func: Function):
values[i] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(func.instruction_constant(i)) + "])"
workers[o_idx[0]] = i
elif op == OP_INPUT:
this_shape = func.size_in(i_idx[0])
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX

values[i] = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
values[i] = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1])
workers[o_idx[0]] = i
elif op == OP_OUTPUT:
rows, cols = func.size_out(o_idx[0])
Expand All @@ -131,11 +126,6 @@ def create_graph(func: Function):
parent = workers[i_idx[0]]
graph[parent].append(i)
antigraph[i].append(parent)
# rows, cols = func.size_out(o_idx[0])
# row_number = o_idx[1] % rows # Compute row index for JAX
# column_number = o_idx[1] // rows # Compute column index for JAX
# output_map[i] = (o_idx[0], row_number, column_number)
# values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
elif op == OP_SQ:
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
graph[workers[i_idx[0]]].append(i)
Expand Down Expand Up @@ -222,7 +212,7 @@ def translate(func: Function, add_jit=False, add_import=False):
if add_jit:
code += "@jax.jit\n"
code += f"def evaluate_{func.name()}(*args):\n"
code += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n"
code += " inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]\n"
code += f" outputs = [jnp.zeros(out) for out in {[func.size_out(i) for i in range(func.n_out())]}]\n"
code += f" work = jnp.zeros(({func.n_instructions()}, 1))\n"
code += codegen(graph, antigraph, heights, output_map, values)
Expand Down
2 changes: 1 addition & 1 deletion jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
OP_ATANH: "jnp.arctanh(work[{0}])",
OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])",
OP_CONST: "{0:.16f}",
OP_INPUT: "inputs[{0}][{1}, {2}]",
OP_INPUT: "inputs[{0}][{1}]",
OP_OUTPUT: "work[{0}][0]",
}
OP_JAX_EXPAND_VALUE_DICT = {
Expand Down
18 changes: 9 additions & 9 deletions tests/test_casadi_equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def compare_results(casadi_f, jax_f, *inputs):


def test_simo_trig():
x = ca.SX.sym("x", 1, 1)
x = ca.SX.sym("x", 1)
casadi_f = ca.Function("simo_trig", [x], [ca.sin(x), ca.cos(x)])
jax_f = convert(casadi_f)
x_val = np.random.randn(1, 1)
x_val = np.random.randn(1)
compare_results(casadi_f, jax_f, x_val)


Expand All @@ -56,16 +56,16 @@ def test_structural_zeros():
Y = ca.jacobian(A @ X, X)

casadi_f = ca.Function("foo", [X], [ca.densify(Y)])
jax_f = convert(casadi_f, translate=expand_translate)
jax_f = convert(casadi_f)
x_val = np.random.randn(2, 1)
compare_results(casadi_f, jax_f, x_val)


def test_simo_poly():
x = ca.SX.sym("x", 1, 1)
x = ca.SX.sym("x", 1)
casadi_f = ca.Function("simo_poly", [x], [x**2, x**3, ca.sqrt(x)])
jax_f = convert(casadi_f)
x_val = np.random.randn(1, 1)
x_val = np.random.randn(1)
x_val = np.abs(x_val) # Ensure positive for sqrt
compare_results(casadi_f, jax_f, x_val)

Expand Down Expand Up @@ -195,12 +195,12 @@ def test_sum1():


def test_dot():
x = ca.SX.sym("x", 3, 1)
y = ca.SX.sym("y", 3, 1)
x = ca.SX.sym("x", 3)
y = ca.SX.sym("y", 3)
casadi_f = ca.Function("dot", [x, y], [ca.dot(x, y)])
jax_f = convert(casadi_f)
x_val = np.random.randn(3, 1)
y_val = np.random.randn(3, 1)
x_val = np.random.randn(3)
y_val = np.random.randn(3)
compare_results(casadi_f, jax_f, x_val, y_val)


Expand Down
38 changes: 38 additions & 0 deletions tests/test_input.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,48 @@
import casadi as cs
import jax
import jax.numpy as jnp
import numpy as np

from jaxadi import convert


def test_1d_flat_behaviour():
key = jax.random.key(0)
x = cs.SX.sym("x", 2)
Ax = cs.SX.sym("Ax", 3)

Ax[0] = x[0] + x[1]
Ax[1] = -x[0]
Ax[2] = -x[1]

cs_Ax = cs.Function("cs_Ax", [x], [Ax], ["x"], ["Ax"])

jax_Ax = convert(cs_Ax)
x = jax.random.uniform(key, (4, 2))

# VMAP should fail if dimensionality
# is incompatible with the translation
Ax = jax.vmap(jax_Ax)(x)


def test_1d_non_flat_behaviour():
key = jax.random.key(0)
x = cs.SX.sym("x", 2)
Ax = cs.SX.sym("Ax", 3)

Ax[0] = x[0] + x[1]
Ax[1] = -x[0]
Ax[2] = -x[1]

cs_Ax = cs.Function("cs_Ax", [x], [Ax], ["x"], ["Ax"])

jax_Ax = convert(cs_Ax)
x = jax.random.uniform(key, (4, 2, 1))
# VMAP should fail if dimensionality
# is incompatible with the translation
Ax = jax.vmap(jax_Ax)(x)


def test_different_shapes():
x = cs.SX.sym("x", 2, 3)
y = cs.SX.sym("y", 3, 2)
Expand Down

0 comments on commit b297996

Please sign in to comment.