Skip to content

Commit

Permalink
jax-specific fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bwpriest committed Dec 20, 2023
1 parent 13b6633 commit 54c0270
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions MuyGPyS/_src/gp/muygps/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def _muygps_posterior_mean(
batch_shape = Kin.shape[: -2 * len(in_shape)] # (b_1, ..., b_b)
extra_shape = batch_nn_targets.shape[len(batch_shape) + len(in_shape) :]

in_size = jnp.prod(in_shape, dtype=int)
out_size = jnp.prod(out_shape, dtype=int)
extra_size = jnp.prod(extra_shape, dtype=int)
in_size = jnp.prod(jnp.array(in_shape), dtype=int)
out_size = jnp.prod(jnp.array(out_shape), dtype=int)
extra_size = jnp.prod(jnp.array(extra_shape), dtype=int)

batch_nn_targets_flat = batch_nn_targets.reshape(
batch_shape + (in_size, extra_size)
Expand Down Expand Up @@ -60,8 +60,8 @@ def _muygps_diagonal_variance(
in_shape = Kin.shape[batch_size + in_dim_count :]
out_shape = Kcross.shape[batch_size + in_dim_count :]

in_size = jnp.prod(in_shape, dtype=int)
out_size = jnp.prod(out_shape, dtype=int)
in_size = jnp.prod(jnp.array(in_shape), dtype=int)
out_size = jnp.prod(jnp.array(out_shape), dtype=int)

Kin_flat = Kin.reshape(batch_shape + (in_size, in_size))
Kcross_flat = Kcross.reshape(batch_shape + (in_size, out_size))
Expand Down
4 changes: 2 additions & 2 deletions MuyGPyS/_src/optimize/scale/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def _analytic_scale_optim(
batch_shape = Kin.shape[:batch_dim_count]
in_shape = Kin.shape[batch_dim_count + in_dim_count :]

batch_size = jnp.prod(batch_shape, dtype=int)
in_size = jnp.prod(in_shape, dtype=int)
batch_size = jnp.prod(jnp.array(batch_shape), dtype=int)
in_size = jnp.prod(jnp.array(in_shape), dtype=int)

Kin_flat = Kin.reshape(batch_shape + (in_size, in_size))
nn_targets_flat = nn_targets.reshape(batch_shape + (in_size, 1))
Expand Down

0 comments on commit 54c0270

Please sign in to comment.