Commit f71e5b9

Remove unnecessary handling of the no-longer-supported RandomState

1 parent: 4378d48

File tree: 10 files changed (+15, -37 lines)
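The same one-line migration appears in every file below: np.random.RandomState(seed) becomes np.random.default_rng(seed). As a minimal illustrative sketch (not taken from the diff), note that the two constructors use different bit generators, so the same seed does not reproduce the old streams:

    import numpy as np

    # Legacy interface whose handling this commit removes (MT19937-based)
    legacy_rng = np.random.RandomState(43)

    # Replacement used throughout the changed files (PCG64-based Generator)
    rng = np.random.default_rng(43)

    # Same seed, different algorithms: the draws differ
    print(legacy_rng.uniform(size=3))
    print(rng.uniform(size=3))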

doc/extending/extending_pytensor_solution_1.py (+3, -3)

@@ -118,7 +118,7 @@ def setup_method(self):
         self.op_class = SumDiffOp

     def test_perform(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)
         x = matrix()
         y = matrix()
         f = pytensor.function([x, y], self.op_class()(x, y))
@@ -128,7 +128,7 @@ def test_perform(self):
         assert np.allclose([x_val + y_val, x_val - y_val], out)

     def test_gradient(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)

         def output_0(x, y):
             return self.op_class()(x, y)[0]
@@ -150,7 +150,7 @@ def output_1(x, y):
         )

     def test_infer_shape(self):
-        rng = np.random.RandomState(43)
+        rng = np.random.default_rng(43)

         x = dmatrix()
         y = dmatrix()

doc/library/d3viz/index.ipynb (+1, -1)

@@ -95,7 +95,7 @@
 "noutputs = 10\n",
 "nhiddens = 50\n",
 "\n",
-"rng = np.random.RandomState(0)\n",
+"rng = np.random.default_rng(0)\n",
 "x = pt.dmatrix('x')\n",
 "wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
 "bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",

doc/library/d3viz/index.rst (+1, -1)

@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
     noutputs = 10
     nhiddens = 50

-    rng = np.random.RandomState(0)
+    rng = np.random.default_rng(0)
     x = pt.dmatrix('x')
     wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
     bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
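For context, the surrounding d3viz example builds a small MLP from these shared variables and then renders its function graph. A rough sketch of how the rest of that example might look (the layer wiring, the nfeatures value, and the d3viz call below are reconstructed for illustration and assume pydot/graphviz are installed; they are not part of this diff):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    import pytensor.d3viz as d3v
    from pytensor.tensor.special import softmax

    nfeatures, nhiddens, noutputs = 100, 50, 10  # nfeatures assumed here
    rng = np.random.default_rng(0)

    x = pt.dmatrix("x")
    wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
    bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
    h = pt.tanh(pt.dot(x, wh) + bh)                      # hidden layer
    wy = pytensor.shared(rng.normal(0, 1, (nhiddens, noutputs)), borrow=True)
    by = pytensor.shared(np.zeros(noutputs), borrow=True)
    y = softmax(pt.dot(h, wy) + by, axis=-1)             # softmax output layer

    predict = pytensor.function([x], y)
    d3v.d3viz(predict, "mlp.html")  # writes an interactive HTML visualization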

doc/optimizations.rst (+1, -1)

@@ -239,7 +239,7 @@ Optimization o4 o3 o2
     See :func:`insert_inplace_optimizer`

 inplace_random
-    Typically when a graph uses random numbers, the RandomState is stored
+    Typically when a graph uses random numbers, the random Generator is stored
     in a shared variable, used once per call and, updated after each function
     call. In this common case, it makes sense to update the random number generator in-place.
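To illustrate the pattern this optimization targets, here is a minimal sketch (not part of the commit): a RandomStream keeps its Generator in a shared variable and registers an update that swaps in the advanced state after every call, which the inplace rewrite can then apply without copying the generator:

    import pytensor
    from pytensor.tensor.random.utils import RandomStream

    srng = RandomStream(seed=2023)
    x = srng.normal(0, 1, size=(2,))  # the Generator lives in a shared variable
    f = pytensor.function([], x)      # its per-call update is registered automatically

    print(f())  # each call draws fresh values ...
    print(f())  # ... because the shared Generator state advances between calls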

pytensor/compile/monitormode.py (+1, -4)

@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
     from pytensor.printing import debugprint

     for output in fn.outputs:
-        if (
-            not isinstance(output[0], np.random.RandomState | np.random.Generator)
-            and np.isnan(output[0]).any()
-        ):
+        if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
             print("*** NaN detected ***")  # noqa: T201
             debugprint(node)
             print(f"Inputs : {[input[0] for input in fn.inputs]}")  # noqa: T201
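As a usage sketch (assuming MonitorMode's post_func hook, the documented way to attach per-node callbacks like this one), a NaN-producing graph can be monitored as follows; the inline detect_nan simply mirrors the helper above:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.compile.monitormode import MonitorMode


    def detect_nan(fgraph, i, node, fn):
        # Skip Generator outputs (they are not numeric) and flag NaNs anywhere else
        for output in fn.outputs:
            if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
                print("*** NaN detected ***", node)


    x = pt.dscalar("x")
    f = pytensor.function([x], pt.log(x), mode=MonitorMode(post_func=detect_nan))
    f(-1.0)  # log of a negative number yields NaN, triggering the callback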

pytensor/compile/nanguardmode.py (+1, -1)

@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):

     if isinstance(arr, _cdata_type):
         return False
-    elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
+    elif isinstance(arr, np.random.Generator):
         return False
     elif var is not None and isinstance(var.type, RandomType):
         return False

pytensor/link/jax/linker.py (+3, -3)

@@ -1,6 +1,6 @@
 import warnings

-from numpy.random import Generator, RandomState
+from numpy.random import Generator

 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):

         # Replace any shared RNG inputs so that their values can be updated in place
         # without affecting the original RNG container. This is necessary because
-        # JAX does not accept RandomState/Generators as inputs, and they will have to
+        # JAX does not accept Generators as inputs, and they will have to
         # be tipyfied
         if shared_rng_inputs:
             warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState | Generator):
+            if isinstance(sinput[0], Generator):
                 new_value = jax_typify(
                     sinput[0], dtype=getattr(sinput[0], "dtype", None)
                 )
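This code path runs whenever a graph holding a shared Generator is compiled for the JAX backend. A minimal sketch (assuming JAX is installed and the registered "JAX" mode is available):

    import pytensor
    from pytensor.tensor.random.utils import RandomStream

    srng = RandomStream(seed=123)
    x = srng.normal(0, 1, size=(3,))

    # At compile time the shared Generator is copied and converted via jax_typify,
    # so JAX itself never receives a NumPy Generator object.
    f = pytensor.function([], x, mode="JAX")
    print(f())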

pytensor/link/numba/linker.py (+1, -19)

@@ -16,22 +16,4 @@ def jit_compile(self, fn):
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
-        from numpy.random import RandomState
-
-        from pytensor.link.numba.dispatch import numba_typify
-
-        thunk_inputs = []
-        for n in self.fgraph.inputs:
-            sinput = storage_map[n]
-            if isinstance(sinput[0], RandomState):
-                new_value = numba_typify(
-                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
-                )
-                # We need to remove the reference-based connection to the
-                # original `RandomState`/shared variable's storage, because
-                # subsequent attempts to use the same shared variable within
-                # other non-Numba-fied graphs will have problems.
-                sinput = [new_value]
-            thunk_inputs.append(sinput)
-
-        return thunk_inputs
+        return [storage_map[n] for n in self.fgraph.inputs]

pytensor/tensor/random/type.py (+2, -2)

@@ -6,7 +6,7 @@
 from pytensor.graph.type import Type


-T = TypeVar("T", np.random.RandomState, np.random.Generator)
+T = TypeVar("T", np.random.Generator)


 gen_states_keys = {
@@ -24,7 +24,7 @@


 class RandomType(Type[T]):
-    r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
+    r"""A Type wrapper for `numpy.random.Generator."""

     @staticmethod
     def may_share_memory(a: T, b: T):
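To see where this type appears in practice, a small sketch (not part of the commit): wrapping a NumPy Generator in a shared variable gives it a Generator-backed RandomType, and may_share_memory reports whether two values alias the same underlying state:

    import numpy as np
    import pytensor

    rng = np.random.default_rng(42)
    rng_sv = pytensor.shared(rng, name="rng")

    print(rng_sv.type)                              # a Generator-backed RandomType instance
    print(rng_sv.type.may_share_memory(rng, rng))   # True: the very same Generator object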

tests/unittest_tools.py (+1, -2)

@@ -27,8 +27,7 @@ def fetch_seed(pseed=None):
     If config.unittest.rseed is set to "random", it will seed the rng with
     None, which is equivalent to seeding with a random seed.

-    Useful for seeding RandomState or Generator objects.
-    >>> rng = np.random.RandomState(fetch_seed())
+    Useful for seeding Generator objects.
     >>> rng = np.random.default_rng(fetch_seed())
     """
