Remove unnecessary handling of no longer supported RandomState #1300

Merged
6 changes: 3 additions & 3 deletions doc/extending/extending_pytensor_solution_1.py
@@ -118,7 +118,7 @@ def setup_method(self):
self.op_class = SumDiffOp

def test_perform(self):
- rng = np.random.RandomState(43)
+ rng = np.random.default_rng(43)
x = matrix()
y = matrix()
f = pytensor.function([x, y], self.op_class()(x, y))
@@ -128,7 +128,7 @@ def test_perform(self):
assert np.allclose([x_val + y_val, x_val - y_val], out)

def test_gradient(self):
- rng = np.random.RandomState(43)
+ rng = np.random.default_rng(43)

def output_0(x, y):
return self.op_class()(x, y)[0]
@@ -150,7 +150,7 @@ def output_1(x, y):
)

def test_infer_shape(self):
- rng = np.random.RandomState(43)
+ rng = np.random.default_rng(43)

x = dmatrix()
y = dmatrix()
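These documentation and test updates follow the standard NumPy migration from the legacy `RandomState` class to the `Generator` API. A minimal sketch of the correspondence (plain NumPy, not part of the diff; the method names shown are the usual legacy/modern pairs):

```python
import numpy as np

# Legacy API, whose special handling this PR removes:
legacy_rng = np.random.RandomState(43)
a = legacy_rng.rand(4, 5)            # uniform draws, legacy method name

# Modern API used throughout the updated docs and tests:
rng = np.random.default_rng(43)
b = rng.random((4, 5))               # uniform draws, Generator method name
c = rng.normal(0, 1, size=(4, 5))    # most distribution methods keep their names
```

Note that the two APIs use different bit streams, so the same seed does not reproduce the same draws across them.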
2 changes: 1 addition & 1 deletion doc/library/d3viz/index.ipynb
@@ -95,7 +95,7 @@
"noutputs = 10\n",
"nhiddens = 50\n",
"\n",
"rng = np.random.RandomState(0)\n",
"rng = np.random.default_rng(0)\n",
"x = pt.dmatrix('x')\n",
"wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
"bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",
2 changes: 1 addition & 1 deletion doc/library/d3viz/index.rst
@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
noutputs = 10
nhiddens = 50

- rng = np.random.RandomState(0)
+ rng = np.random.default_rng(0)
x = pt.dmatrix('x')
wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
2 changes: 1 addition & 1 deletion doc/optimizations.rst
@@ -239,7 +239,7 @@ Optimization o4 o3 o2
See :func:`insert_inplace_optimizer`

inplace_random
- Typically when a graph uses random numbers, the RandomState is stored
+ Typically when a graph uses random numbers, the random Generator is stored
in a shared variable, used once per call, and updated after each function
call. In this common case, it makes sense to update the random number generator in-place.

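For context, the in-place update described above corresponds to the following user-level pattern (a minimal sketch assuming the current `pytensor.tensor.random` API; not part of the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor.random as ptr

# The Generator is stored in a shared variable ...
rng = pytensor.shared(np.random.default_rng(0), name="rng")

# ... used once per call: a RandomVariable node also returns the next RNG state ...
next_rng, x = ptr.normal(0, 1, size=(3,), rng=rng).owner.outputs

# ... and updated after each call. inplace_random can rewrite this update to
# mutate the generator in place instead of allocating a new one.
f = pytensor.function([], x, updates={rng: next_rng})
print(f(), f())  # different draws; the shared rng advances between calls
```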
5 changes: 1 addition & 4 deletions pytensor/compile/monitormode.py
@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
from pytensor.printing import debugprint

for output in fn.outputs:
- if (
-     not isinstance(output[0], np.random.RandomState | np.random.Generator)
-     and np.isnan(output[0]).any()
- ):
+ if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
print("*** NaN detected ***") # noqa: T201
debugprint(node)
print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201
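`detect_nan` is the hook that `MonitorMode` runs after each node; the change above drops the `RandomState` branch from its "skip non-numeric outputs" guard. A usage sketch (illustrative only, assuming a toy graph):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.monitormode import MonitorMode, detect_nan

x = pt.dvector("x")
y = pt.log(x)  # NaN for negative inputs

# detect_nan inspects every node's outputs after execution and prints the
# offending Apply node; Generator outputs are skipped by the guard above.
f = pytensor.function([x], y, mode=MonitorMode(post_func=detect_nan))
f(np.array([-1.0, 2.0]))  # triggers the "*** NaN detected ***" report
```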
2 changes: 1 addition & 1 deletion pytensor/compile/nanguardmode.py
@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):

if isinstance(arr, _cdata_type):
return False
- elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
+ elif isinstance(arr, np.random.Generator):
return False
elif var is not None and isinstance(var.type, RandomType):
return False
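`_is_numeric_value` is how `NanGuardMode` decides which runtime values can be checked at all; RNG objects are now only `Generator`s and are simply skipped. A usage sketch (illustrative only):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.nanguardmode import NanGuardMode

x = pt.dvector("x")
y = x / pt.sum(x)

# NanGuardMode raises as soon as a NaN, inf, or very large value appears;
# non-numeric values such as random Generators are ignored.
f = pytensor.function(
    [x], y, mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=True)
)
f(np.zeros(3))  # sum(x) == 0, so 0/0 produces NaN and the call errors out
```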
6 changes: 3 additions & 3 deletions pytensor/link/jax/linker.py
@@ -1,6 +1,6 @@
import warnings

- from numpy.random import Generator, RandomState
+ from numpy.random import Generator

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):

# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
- # JAX does not accept RandomState/Generators as inputs, and they will have to
+ # JAX does not accept Generators as inputs, and they will have to
# be typified
if shared_rng_inputs:
warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
- if isinstance(sinput[0], RandomState | Generator):
+ if isinstance(sinput[0], Generator):
new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
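The comment in this hunk is about leaving the user's original RNG container untouched while the JAX backend consumes a converted copy. The underlying idea, illustrated with plain NumPy (a sketch unrelated to `jax_typify` itself):

```python
import copy
import numpy as np

rng = np.random.default_rng(0)      # the "original RNG container"
working_copy = copy.deepcopy(rng)   # independent state handed to the backend

working_copy.standard_normal(3)     # advances only the copy
# The original generator has not been consumed:
assert rng.bit_generator.state == np.random.default_rng(0).bit_generator.state
```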
20 changes: 1 addition & 19 deletions pytensor/link/numba/linker.py
@@ -16,22 +16,4 @@ def jit_compile(self, fn):
return jitted_fn

def create_thunk_inputs(self, storage_map):
- from numpy.random import RandomState
-
- from pytensor.link.numba.dispatch import numba_typify
-
- thunk_inputs = []
- for n in self.fgraph.inputs:
-     sinput = storage_map[n]
-     if isinstance(sinput[0], RandomState):
-         new_value = numba_typify(
-             sinput[0], dtype=getattr(sinput[0], "dtype", None)
-         )
-         # We need to remove the reference-based connection to the
-         # original `RandomState`/shared variable's storage, because
-         # subsequent attempts to use the same shared variable within
-         # other non-Numba-fied graphs will have problems.
-         sinput = [new_value]
-     thunk_inputs.append(sinput)
-
- return thunk_inputs
+ return [storage_map[n] for n in self.fgraph.inputs]
17 changes: 9 additions & 8 deletions pytensor/tensor/random/type.py
@@ -1,12 +1,13 @@
from typing import TypeVar

import numpy as np
+ from numpy.random import Generator

import pytensor
from pytensor.graph.type import Type


T = TypeVar("T", np.random.RandomState, np.random.Generator)
T = TypeVar("T")


gen_states_keys = {
@@ -24,14 +25,10 @@


class RandomType(Type[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""

@staticmethod
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
r"""A Type wrapper for `numpy.random.Generator."""


- class RandomGeneratorType(RandomType[np.random.Generator]):
+ class RandomGeneratorType(RandomType[Generator]):
r"""A Type wrapper for `numpy.random.Generator`.

The reason this exists (and `Generic` doesn't suffice) is that
@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
def __repr__(self):
return "RandomGeneratorType"

+ @staticmethod
+ def may_share_memory(a: Generator, b: Generator):
+     return a._bit_generator is b._bit_generator  # type: ignore[attr-defined]

def filter(self, data, strict=False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
Expand All @@ -58,7 +59,7 @@ def filter(self, data, strict=False, allow_downcast=None):
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
- if isinstance(data, np.random.Generator):
+ if isinstance(data, Generator):
return data

if not strict and isinstance(data, dict):
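`RandomGeneratorType.may_share_memory` (moved here from the base class) treats two `Generator`s as aliasing when they wrap the same `BitGenerator`, mirroring the identity check `a._bit_generator is b._bit_generator`. A small illustration with plain NumPy (for clarity only):

```python
import numpy as np

bit_gen = np.random.PCG64(123)
g1 = np.random.Generator(bit_gen)
g2 = np.random.Generator(bit_gen)  # wraps the *same* BitGenerator as g1
g3 = np.random.default_rng(123)    # has its own, independent BitGenerator

assert g1.bit_generator is g2.bit_generator      # would "share memory"
assert g1.bit_generator is not g3.bit_generator  # would not
```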
3 changes: 1 addition & 2 deletions tests/unittest_tools.py
@@ -27,8 +27,7 @@ def fetch_seed(pseed=None):
If config.unittest.rseed is set to "random", it will seed the rng with
None, which is equivalent to seeding with a random seed.

- Useful for seeding RandomState or Generator objects.
- >>> rng = np.random.RandomState(fetch_seed())
+ Useful for seeding Generator objects.
+ >>> rng = np.random.default_rng(fetch_seed())
"""
