From b29799682397202f8eef343601209838fb73f4f1 Mon Sep 17 00:00:00 2001 From: Igor Alentev Date: Sat, 7 Dec 2024 17:05:40 +0900 Subject: [PATCH] fix: adaprive dimensionality #16 --- jaxadi/_graph.py | 18 ++++------------- jaxadi/_ops.py | 2 +- tests/test_casadi_equality.py | 18 ++++++++--------- tests/test_input.py | 38 +++++++++++++++++++++++++++++++++++ 4 files changed, 52 insertions(+), 24 deletions(-) diff --git a/jaxadi/_graph.py b/jaxadi/_graph.py index 6dbcbf8..2940436 100644 --- a/jaxadi/_graph.py +++ b/jaxadi/_graph.py @@ -99,10 +99,10 @@ def compute_heights(func, graph, antigraph): def create_graph(func: Function): N = func.n_instructions() graph = [[] for _ in range(N)] - values = [None for _ in range(N)] + values = ["" for _ in range(N)] antigraph = [[] for _ in range(N)] output_map = {} - workers = [None for _ in range(func.sz_w())] + workers = [0 for _ in range(func.sz_w())] for i in range(N): op = func.instruction_id(i) @@ -113,12 +113,7 @@ def create_graph(func: Function): values[i] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(func.instruction_constant(i)) + "])" workers[o_idx[0]] = i elif op == OP_INPUT: - this_shape = func.size_in(i_idx[0]) - rows, cols = this_shape # Get the shape of the output - row_number = i_idx[1] % rows # Compute row index for JAX - column_number = i_idx[1] // rows # Compute column index for JAX - - values[i] = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number) + values[i] = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1]) workers[o_idx[0]] = i elif op == OP_OUTPUT: rows, cols = func.size_out(o_idx[0]) @@ -131,11 +126,6 @@ def create_graph(func: Function): parent = workers[i_idx[0]] graph[parent].append(i) antigraph[i].append(parent) - # rows, cols = func.size_out(o_idx[0]) - # row_number = o_idx[1] % rows # Compute row index for JAX - # column_number = o_idx[1] // rows # Compute column index for JAX - # output_map[i] = (o_idx[0], row_number, column_number) - # values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) elif op == OP_SQ: values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) graph[workers[i_idx[0]]].append(i) @@ -222,7 +212,7 @@ def translate(func: Function, add_jit=False, add_import=False): if add_jit: code += "@jax.jit\n" code += f"def evaluate_{func.name()}(*args):\n" - code += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n" + code += " inputs = [jnp.expand_dims(jnp.ravel(jnp.array(arg).T), axis=-1) for arg in args]\n" code += f" outputs = [jnp.zeros(out) for out in {[func.size_out(i) for i in range(func.n_out())]}]\n" code += f" work = jnp.zeros(({func.n_instructions()}, 1))\n" code += codegen(graph, antigraph, heights, output_map, values) diff --git a/jaxadi/_ops.py b/jaxadi/_ops.py index 26633ea..cb7bea1 100644 --- a/jaxadi/_ops.py +++ b/jaxadi/_ops.py @@ -94,7 +94,7 @@ OP_ATANH: "jnp.arctanh(work[{0}])", OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])", OP_CONST: "{0:.16f}", - OP_INPUT: "inputs[{0}][{1}, {2}]", + OP_INPUT: "inputs[{0}][{1}]", OP_OUTPUT: "work[{0}][0]", } OP_JAX_EXPAND_VALUE_DICT = { diff --git a/tests/test_casadi_equality.py b/tests/test_casadi_equality.py index 3fc2457..22d41d7 100644 --- a/tests/test_casadi_equality.py +++ b/tests/test_casadi_equality.py @@ -31,10 +31,10 @@ def compare_results(casadi_f, jax_f, *inputs): def test_simo_trig(): - x = ca.SX.sym("x", 1, 1) + x = ca.SX.sym("x", 1) casadi_f = ca.Function("simo_trig", [x], [ca.sin(x), ca.cos(x)]) jax_f = convert(casadi_f) - x_val = np.random.randn(1, 1) + x_val = np.random.randn(1) compare_results(casadi_f, jax_f, x_val) @@ -56,16 +56,16 @@ def test_structural_zeros(): Y = ca.jacobian(A @ X, X) casadi_f = ca.Function("foo", [X], [ca.densify(Y)]) - jax_f = convert(casadi_f, translate=expand_translate) + jax_f = convert(casadi_f) x_val = np.random.randn(2, 1) compare_results(casadi_f, jax_f, x_val) def test_simo_poly(): - x = ca.SX.sym("x", 1, 1) + x = ca.SX.sym("x", 1) casadi_f = ca.Function("simo_poly", [x], [x**2, x**3, ca.sqrt(x)]) jax_f = convert(casadi_f) - x_val = np.random.randn(1, 1) + x_val = np.random.randn(1) x_val = np.abs(x_val) # Ensure positive for sqrt compare_results(casadi_f, jax_f, x_val) @@ -195,12 +195,12 @@ def test_sum1(): def test_dot(): - x = ca.SX.sym("x", 3, 1) - y = ca.SX.sym("y", 3, 1) + x = ca.SX.sym("x", 3) + y = ca.SX.sym("y", 3) casadi_f = ca.Function("dot", [x, y], [ca.dot(x, y)]) jax_f = convert(casadi_f) - x_val = np.random.randn(3, 1) - y_val = np.random.randn(3, 1) + x_val = np.random.randn(3) + y_val = np.random.randn(3) compare_results(casadi_f, jax_f, x_val, y_val) diff --git a/tests/test_input.py b/tests/test_input.py index 372244a..4919d8c 100644 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -1,10 +1,48 @@ import casadi as cs +import jax import jax.numpy as jnp import numpy as np from jaxadi import convert +def test_1d_flat_behaviour(): + key = jax.random.key(0) + x = cs.SX.sym("x", 2) + Ax = cs.SX.sym("Ax", 3) + + Ax[0] = x[0] + x[1] + Ax[1] = -x[0] + Ax[2] = -x[1] + + cs_Ax = cs.Function("cs_Ax", [x], [Ax], ["x"], ["Ax"]) + + jax_Ax = convert(cs_Ax) + x = jax.random.uniform(key, (4, 2)) + + # VMAP should fail if dimensionality + # is incompatible with the translation + Ax = jax.vmap(jax_Ax)(x) + + +def test_1d_non_flat_behaviour(): + key = jax.random.key(0) + x = cs.SX.sym("x", 2) + Ax = cs.SX.sym("Ax", 3) + + Ax[0] = x[0] + x[1] + Ax[1] = -x[0] + Ax[2] = -x[1] + + cs_Ax = cs.Function("cs_Ax", [x], [Ax], ["x"], ["Ax"]) + + jax_Ax = convert(cs_Ax) + x = jax.random.uniform(key, (4, 2, 1)) + # VMAP should fail if dimensionality + # is incompatible with the translation + Ax = jax.vmap(jax_Ax)(x) + + def test_different_shapes(): x = cs.SX.sym("x", 2, 3) y = cs.SX.sym("y", 3, 2)