Skip to content

Commit

Permalink
Random should respect dtypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
alxmrs committed Jul 23, 2024
1 parent 2907001 commit 10b22b9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cubed/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def random(size, *, chunks=None, spec=None, device=None):
)


def _random(x, numblocks=None, root_seed=None, block_id=None):
def _random(x, numblocks=None, root_seed=None, block_id=None, dtype=None):
stream_id = block_id_to_offset(block_id, numblocks)
rg = Generator(Philox(key=root_seed + stream_id))
out = rg.random(x.shape)
out = numpy_array_to_backend_array(out)
out = rg.random(x.shape, dtype=dtype)
out = numpy_array_to_backend_array(out, dtype=dtype)
return out

0 comments on commit 10b22b9

Please sign in to comment.