Remove _asarray
ricardoV94 committed Oct 10, 2024
1 parent 02cea48 commit 25147b8
Showing 29 changed files with 134 additions and 216 deletions.
57 changes: 0 additions & 57 deletions pytensor/misc/safe_asarray.py

This file was deleted.
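Aside (not part of the commit): the call sites below replace the deleted helper with plain NumPy, either np.asarray(x, dtype=d) or np.asarray(x).astype(d). A minimal sketch of the pattern, assuming _asarray(x, dtype=d) amounted to a NumPy conversion plus resolution of the symbolic "floatX" dtype (the explicit check added in pytensor/scalar/basic.py below); config_floatX and converted are hypothetical stand-ins:

import numpy as np

config_floatX = "float32"  # stand-in for pytensor.config.floatX

def converted(x, dtype):
    # Resolve the symbolic "floatX" dtype before casting, mirroring the
    # check this commit adds to convert() in pytensor/scalar/basic.py.
    if dtype == "floatX":
        dtype = config_floatX
    return np.asarray(x).astype(dtype)

print(converted([1, 2], "floatX").dtype)  # float32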

13 changes: 7 additions & 6 deletions pytensor/scalar/basic.py
@@ -32,7 +32,6 @@
from pytensor.graph.utils import MetaObject, MethodNotDefined
from pytensor.link.c.op import COp
from pytensor.link.c.type import CType
- from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import pprint
from pytensor.utils import (
apply_across_args,
@@ -150,7 +149,7 @@ def __call__(self, x):
and rval.dtype in ("float64", "float32")
and rval.dtype != config.floatX
):
- rval = _asarray(rval, dtype=config.floatX)
+ rval = np.asarray(rval, dtype=config.floatX)
return rval

# The following is the original code, corresponding to the 'custom'
@@ -176,15 +175,15 @@ def __call__(self, x):
and config.floatX in self.dtypes
and config.floatX != "float64"
):
- return _asarray(x, dtype=config.floatX)
+ return np.asarray(x, dtype=config.floatX)

# Don't autocast to float16 unless config.floatX is float16
try_dtypes = [
d for d in self.dtypes if config.floatX == "float16" or d != "float16"
]

for dtype in try_dtypes:
- x_ = _asarray(x, dtype=dtype)
+ x_ = np.asarray(x).astype(dtype=dtype)
if np.all(x == x_):
break
# returns either an exact x_==x, or the last cast x_
@@ -245,7 +244,7 @@ def convert(x, dtype=None):

if dtype is not None:
# in this case, the semantics are that the caller is forcing the dtype
- x_ = _asarray(x, dtype=dtype)
+ if dtype == "floatX":
+     dtype = config.floatX
+ x_ = np.asarray(x).astype(dtype)
else:
# In this case, this function should infer the dtype according to the
# autocasting rules. See autocasting above.
@@ -256,7 +257,7 @@
except OverflowError:
# This is to imitate numpy behavior which tries to fit
# bigger numbers into a uint64.
x_ = _asarray(x, dtype="uint64")
x_ = np.asarray(x, dtype="uint64")
elif isinstance(x, builtins.float):
x_ = autocast_float(x)
elif isinstance(x, np.ndarray):
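Aside (not part of the commit): a runnable sketch of the uint64 fallback in convert() above; a Python int one past the int64 maximum still fits in a uint64, imitating NumPy:

import numpy as np

x = 2**63  # too large for int64
try:
    x_ = np.asarray(x, dtype="int64")
except OverflowError:
    # Imitate NumPy, which tries to fit bigger numbers into a uint64.
    x_ = np.asarray(x, dtype="uint64")

print(x_, x_.dtype)  # 9223372036854775808 uint64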
23 changes: 11 additions & 12 deletions pytensor/sparse/basic.py
@@ -24,7 +24,6 @@
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.type import generic
- from pytensor.misc.safe_asarray import _asarray
from pytensor.sparse.type import SparseTensorType, _is_sparse
from pytensor.sparse.utils import hash_from_sparse
from pytensor.tensor import basic as ptb
@@ -595,11 +594,11 @@ def perform(self, node, inputs, out):
(csm,) = inputs
out[0][0] = csm.data
if str(csm.data.dtype) == "int32":
- out[0][0] = _asarray(out[0][0], dtype="int32")
+ out[0][0] = np.asarray(out[0][0], dtype="int32")
# backport
- out[1][0] = _asarray(csm.indices, dtype="int32")
- out[2][0] = _asarray(csm.indptr, dtype="int32")
- out[3][0] = _asarray(csm.shape, dtype="int32")
+ out[1][0] = np.asarray(csm.indices, dtype="int32")
+ out[2][0] = np.asarray(csm.indptr, dtype="int32")
+ out[3][0] = np.asarray(csm.shape, dtype="int32")

def grad(self, inputs, g):
# g[1:] is all integers, so their Jacobian in this op
@@ -698,17 +697,17 @@ def make_node(self, data, indices, indptr, shape):

if not isinstance(indices, Variable):
indices_ = np.asarray(indices)
- indices_32 = _asarray(indices, dtype="int32")
+ indices_32 = np.asarray(indices, dtype="int32")
assert (indices_ == indices_32).all()
indices = indices_32
if not isinstance(indptr, Variable):
indptr_ = np.asarray(indptr)
- indptr_32 = _asarray(indptr, dtype="int32")
+ indptr_32 = np.asarray(indptr, dtype="int32")
assert (indptr_ == indptr_32).all()
indptr = indptr_32
if not isinstance(shape, Variable):
shape_ = np.asarray(shape)
- shape_32 = _asarray(shape, dtype="int32")
+ shape_32 = np.asarray(shape, dtype="int32")
assert (shape_ == shape_32).all()
shape = shape_32

@@ -1461,7 +1460,7 @@ def perform(self, node, inputs, outputs):
(x, ind1, ind2) = inputs
(out,) = outputs
assert _is_sparse(x)
- out[0] = _asarray(x[ind1, ind2], x.dtype)
+ out[0] = np.asarray(x[ind1, ind2], x.dtype)


get_item_scalar = GetItemScalar()
@@ -2142,7 +2141,7 @@ def perform(self, node, inputs, outputs):

# The asarray is needed as in some case, this return a
# numpy.matrixlib.defmatrix.matrix object and not an ndarray.
- out[0] = _asarray(x + y, dtype=node.outputs[0].type.dtype)
+ out[0] = np.asarray(x + y, dtype=node.outputs[0].type.dtype)

def grad(self, inputs, gout):
(x, y) = inputs
@@ -3497,7 +3496,7 @@ def perform(self, node, inputs, outputs):

# The cast is needed as otherwise we hit the bug mentioned into
# _asarray function documentation.
- out[0] = _asarray(variable, str(variable.dtype))
+ out[0] = np.asarray(variable, str(variable.dtype))

def grad(self, inputs, gout):
# a is sparse, b is dense, g_out is dense
@@ -4012,7 +4011,7 @@ def perform(self, node, inputs, out):
if x_is_sparse and y_is_sparse:
rval = rval.toarray()

- out[0] = _asarray(rval, dtype=node.outputs[0].dtype)
+ out[0] = np.asarray(rval, dtype=node.outputs[0].dtype)

def grad(self, inputs, gout):
(x, y) = inputs
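Aside (not part of the commit): a sketch of why the asarray around x + y in the perform method above matters; adding a SciPy sparse matrix to a dense array can return numpy.matrix, and np.asarray demotes it back to a plain ndarray:

import numpy as np
import scipy.sparse as sp

x = sp.csr_matrix(np.eye(2, dtype="float64"))
y = np.ones((2, 2), dtype="float64")

print(type(x + y).__name__)              # matrix
print(type(np.asarray(x + y)).__name__)  # ndarray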
4 changes: 2 additions & 2 deletions pytensor/sparse/rewriting.py
@@ -1,3 +1,4 @@
+ import numpy as np
import scipy

import pytensor
@@ -10,7 +11,6 @@
node_rewriter,
)
from pytensor.link.c.op import COp, _NoPythonCOp
- from pytensor.misc.safe_asarray import _asarray
from pytensor.sparse import basic as sparse
from pytensor.sparse.basic import (
CSC,
@@ -283,7 +283,7 @@ def perform(self, node, inputs, outputs):
(a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False
)
# out[0] = a.dot(b)
- out[0] = _asarray(a * b, dtype=node.outputs[0].type.dtype)
+ out[0] = np.asarray(a * b, dtype=node.outputs[0].type.dtype)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense

def c_code(self, node, name, inputs, outputs, sub):
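Aside (not part of the commit): a hedged sketch of the sparse-times-dense product in the perform method above, with made-up inputs; SciPy returns a dense result here, and np.asarray pins it to the dtype the graph expects:

import numpy as np
import scipy.sparse as sp

a = sp.csc_matrix(np.eye(3, dtype="float32"))
b = np.ones((3, 2), dtype="float32")

out = np.asarray(a @ b, dtype="float32")  # same product as `a * b` above
assert isinstance(out, np.ndarray)        # dense, as the assert above checks
print(out.shape, out.dtype)               # (3, 2) float32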
7 changes: 3 additions & 4 deletions pytensor/tensor/basic.py
@@ -32,7 +32,6 @@
from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
- from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise, assert_op
from pytensor.scalar import int32
@@ -512,7 +511,7 @@ def get_underlying_scalar_constant_value(
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
# MakeVector can cast implicitly its input in some case.
- return _asarray(ret, dtype=v.type.dtype)
+ return np.asarray(ret, dtype=v.type.dtype)

# This is needed when we take the grad as the Shape op
# are not already changed into MakeVector
@@ -1834,7 +1833,7 @@ def perform(self, node, inputs, out_):
(out,) = out_
# not calling pytensor._asarray as optimization
if (out[0] is None) or (out[0].size != len(inputs)):
- out[0] = _asarray(inputs, dtype=node.outputs[0].dtype)
+ out[0] = np.asarray(inputs, dtype=node.outputs[0].dtype)
else:
# assume that out has correct dtype. there is no cheap way to check
out[0][...] = inputs
@@ -2537,7 +2536,7 @@ def perform(self, node, axis_and_tensors, out_):
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
)

- out[0] = _asarray(
+ out[0] = np.asarray(
np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
)

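Aside (not part of the commit): a small sketch of the Join pattern above, assuming the wrapping np.asarray exists to pin the concatenated result to the node's declared output dtype (np.concatenate may upcast):

import numpy as np

tens = [np.array([1, 2], dtype="int8"), np.array([3], dtype="int16")]
joined = np.concatenate(tens, axis=0)    # upcasts to int16
out = np.asarray(joined, dtype="int16")  # dtype the graph declared
print(out, out.dtype)                    # [1 2 3] int16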
3 changes: 1 addition & 2 deletions pytensor/tensor/elemwise.py
@@ -17,7 +17,6 @@
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
- from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
@@ -1412,7 +1411,7 @@ def perform(self, node, inp, out):

out = self.ufunc.reduce(input, axis=axis, dtype=acc_dtype)

- output[0] = _asarray(out, dtype=out_dtype)
+ output[0] = np.asarray(out, dtype=out_dtype)

def infer_shape(self, fgraph, node, shapes):
(ishape,) = shapes
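Aside (not part of the commit): a minimal sketch of the reduce-then-cast pattern in perform above, with invented dtypes; the reduction accumulates in a wider acc_dtype, then the result is cast to the declared out_dtype:

import numpy as np

x = np.array([200, 100, 60], dtype="uint8")
acc = np.add.reduce(x, axis=0, dtype="int64")  # wide accumulator, no overflow
out = np.asarray(acc, dtype="uint8")           # 360 wraps to 104 on the cast
print(acc, out, out.dtype)                     # 360 104 uint8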
5 changes: 2 additions & 3 deletions pytensor/tensor/extra_ops.py
@@ -17,7 +17,6 @@
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import EnumList, Generic
- from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
@@ -1307,7 +1306,7 @@ def perform(self, node, inp, out):
res = np.unravel_index(indices, dims, order=self.order)
assert len(res) == len(out)
for i in range(len(out)):
- ret = _asarray(res[i], node.outputs[0].dtype)
+ ret = np.asarray(res[i], node.outputs[0].dtype)
if ret.base is not None:
# NumPy will return a view when it can.
# But we don't want that.
@@ -1382,7 +1381,7 @@ def infer_shape(self, fgraph, node, input_shapes):
def perform(self, node, inp, out):
multi_index, dims = inp[:-1], inp[-1]
res = np.ravel_multi_index(multi_index, dims, mode=self.mode, order=self.order)
- out[0][0] = _asarray(res, node.outputs[0].dtype)
+ out[0][0] = np.asarray(res, node.outputs[0].dtype)


def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
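Aside (not part of the commit): a short sketch of the ravel_multi_index call in perform above, with made-up indices:

import numpy as np

res = np.ravel_multi_index(([0, 1], [1, 2]), (2, 3), mode="raise", order="C")
print(np.asarray(res, dtype="int64"))  # [1 5], i.e. row * 3 + col per pair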
31 changes: 15 additions & 16 deletions pytensor/tensor/math.py
@@ -14,7 +14,6 @@
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
- from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import pprint
from pytensor.raise_op import Assert
from pytensor.scalar.basic import BinaryScalarOp
@@ -202,7 +201,7 @@ def perform(self, node, inp, outs):
new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
reshaped_x = transposed_x.reshape(new_shape)

- max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
+ max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")

def c_code(self, node, name, inp, out, sub):
(x,) = inp
@@ -730,32 +729,32 @@ def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
--------
>>> import pytensor
>>> import numpy as np
>>> a = _asarray([1e10, 1e-7], dtype="float64")
>>> b = _asarray([1.00001e10, 1e-8], dtype="float64")
>>> a = np.array([1e10, 1e-7], dtype="float64")
>>> b = np.array([1.00001e10, 1e-8], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([ True, False])
>>> a = _asarray([1e10, 1e-8], dtype="float64")
>>> b = _asarray([1.00001e10, 1e-9], dtype="float64")
>>> a = np.array([1e10, 1e-8], dtype="float64")
>>> b = np.array([1.00001e10, 1e-9], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([ True, True])
>>> a = _asarray([1e10, 1e-8], dtype="float64")
>>> b = _asarray([1.0001e10, 1e-9], dtype="float64")
>>> a = np.array([1e10, 1e-8], dtype="float64")
>>> b = np.array([1.0001e10, 1e-9], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([False, True])
>>> a = _asarray([1.0, np.nan], dtype="float64")
>>> b = _asarray([1.0, np.nan], dtype="float64")
>>> a = np.array([1.0, np.nan], dtype="float64")
>>> b = np.array([1.0, np.nan], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([ True, False])
>>> a = _asarray([1.0, np.nan], dtype="float64")
>>> b = _asarray([1.0, np.nan], dtype="float64")
>>> a = np.array([1.0, np.nan], dtype="float64")
>>> b = np.array([1.0, np.nan], dtype="float64")
>>> pytensor.tensor.isclose(a, b, equal_nan=True).eval()
array([ True, True])
>>> a = _asarray([1.0, np.inf], dtype="float64")
>>> b = _asarray([1.0, -np.inf], dtype="float64")
>>> a = np.array([1.0, np.inf], dtype="float64")
>>> b = np.array([1.0, -np.inf], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([ True, False])
>>> a = _asarray([1.0, np.inf], dtype="float64")
>>> b = _asarray([1.0, np.inf], dtype="float64")
>>> a = np.array([1.0, np.inf], dtype="float64")
>>> b = np.array([1.0, np.inf], dtype="float64")
>>> pytensor.tensor.isclose(a, b).eval()
array([ True, True])
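Aside (not part of the commit): a sketch of the flatten-then-argmax trick in Argmax.perform above; the reduced axes are transposed to the end, merged into one, and a single argmax runs over the final axis. The input here is invented:

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
axis = (1, 2)  # axes to reduce over
kept = tuple(i for i in range(x.ndim) if i not in axis)
transposed_x = x.transpose(*kept, *axis)
reshaped_x = transposed_x.reshape(*(x.shape[i] for i in kept), -1)
max_idx = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
print(max_idx)  # [11 11]: flat offset of the max inside each 3x4 block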
9 changes: 4 additions & 5 deletions pytensor/tensor/rewriting/math.py
@@ -19,7 +19,6 @@
node_rewriter,
)
from pytensor.graph.rewriting.utils import get_clients_at_depth
- from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import assert_op
from pytensor.tensor.basic import (
Alloc,
@@ -1205,7 +1204,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None):
out_dtype = ps.upcast(*[v.dtype for v in (num + denum)])
else:
out_dtype = out_type.dtype
- one = _asarray(1, dtype=out_dtype)
+ one = np.asarray(1, dtype=out_dtype)

v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one)
if aslist:
@@ -1878,7 +1877,7 @@ def local_mul_zero(fgraph, node):
# print 'MUL by value', value, node.inputs
if value == 0:
# print '... returning zeros'
- return [broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0]]
+ return [broadcast_arrays(np.asarray(0, dtype=otype.dtype), *node.inputs)[0]]


# TODO: Add this to the canonicalization to reduce redundancy.
@@ -2353,8 +2352,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
if out_type is None:
zero = 0.0
else:
- zero = _asarray(0, dtype=out_type.dtype)
- # zero = 0.0 if out_type is None else _asarray(0,
+ zero = np.asarray(0, dtype=out_type.dtype)
+ # zero = 0.0 if out_type is None else np.asarray(0,
# dtype=out_type.dtype)
if out_type and out_type.dtype == "bool":
if len(denum) == 0:
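Aside (not part of the commit): a sketch of the typed-identity trick in mul_calculate above; folding with np.asarray(1, dtype=out_dtype) keeps the result dtype correct even when num or denum is empty:

from functools import reduce

import numpy as np

num = [np.float32(2.0), np.float32(3.0)]
denum = []  # empty denominator: reduce just returns the identity
one = np.asarray(1, dtype="float32")  # stand-in for the out_dtype above
v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one)
print(v, v.dtype)  # 6.0 float32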