Skip to content
This repository has been archived by the owner on Jan 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #9 from TheSouthFrog/branch1
Browse files Browse the repository at this point in the history
Fix the bugs from the param type of take_along_axis
  • Loading branch information
HuiwenChang authored Nov 18, 2022
2 parents cf615d4 + d7c8ec0 commit 1db2359
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions maskgit/libml/parallel_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def mask_by_random_topk(rng, mask_len, probs, temperature=1.0):
rng, probs.shape)
sorted_confidence = jnp.sort(confidence, axis=-1)
# Obtains cut off threshold given the mask lengths.
cut_off = jnp.take_along_axis(sorted_confidence, mask_len, axis=-1)
cut_off = jnp.take_along_axis(sorted_confidence, mask_len.astype(jnp.int32), axis=-1)
# Masks tokens with lower confidence.
masking = (confidence < cut_off)
return masking
Expand Down Expand Up @@ -140,7 +140,7 @@ def loop_body_fn(state):
# Computes the probabilities of each selected tokens.
probs = jax.nn.softmax(logits, axis=-1)
selected_probs = jnp.squeeze(
jnp.take_along_axis(probs, jnp.expand_dims(sampled_ids, -1), -1), -1)
jnp.take_along_axis(probs, jnp.expand_dims(sampled_ids.astype(jnp.int32), -1), -1), -1)
# Ignores the tokens given in the input by overwriting their confidence.
selected_probs = jnp.where(unknown_map, selected_probs,
_CONFIDENCE_OF_KNOWN_TOKENS)
Expand Down

0 comments on commit 1db2359

Please sign in to comment.