Skip to content

Commit

Permalink
remove use of numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Jan 10, 2025
1 parent 3ed911d commit 45d57ed
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 31 deletions.
37 changes: 10 additions & 27 deletions docs/marimo/autotune.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import marimo

__generated_with = "0.10.9"
__generated_with = "0.10.10"
app = marimo.App(width="full")


Expand Down Expand Up @@ -553,17 +553,8 @@ def _(

@app.cell
def _():
import numpy as np
from numpy.random import Generator as RandomGenerator
return RandomGenerator, np


@app.cell
def _(np):
rng = np.random.default_rng()

type(rng)
return (rng,)
from random import Random
return (Random,)


@app.cell
Expand All @@ -573,28 +564,20 @@ def _():


@app.cell
def _(
Attribute,
DenseIntOrFPElementsAttr,
RandomGenerator,
f64,
np,
to_dtype,
):
def _(Attribute, DenseIntOrFPElementsAttr, Random, f64):
from xdsl.dialects.builtin import FloatAttr, ShapedType, ContainerType, TensorType


def random_attr_of_type(t: Attribute, rng: RandomGenerator) -> Attribute | None:
def random_attr_of_type(t: Attribute, rng: Random) -> Attribute | None:
# if isinstance(type, ShapedType):
if t == f64:
return FloatAttr(rng.random(1, to_dtype(t))[0], f64)
return FloatAttr(rng.random(), f64)
elif isinstance(t, ShapedType) and isinstance(t, ContainerType):
values = rng.random(t.element_count(), to_dtype(t.get_element_type()))
return DenseIntOrFPElementsAttr.from_list(
t, values
t, tuple(rng.random() for _ in range(t.element_count()))
)

_rng = np.random.default_rng()
_rng = Random("autotune")
random_attr_of_type(f64, _rng), random_attr_of_type(TensorType(f64, (2, 3)), _rng)
return (
ContainerType,
Expand All @@ -612,12 +595,12 @@ def _(
MemrefStreamUnrollAndJamPass,
ModuleOp,
ModulePass,
Random,
SnitchCycleCostModel,
func,
memref_stream,
memref_stream_module,
msg_factors,
np,
random_attr_of_type,
riscv_passes,
uaj_passes,
Expand Down Expand Up @@ -652,7 +635,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:

arg_types = func_op.function_type.inputs

rng = np.random.default_rng()
rng = Random("autotune")
attrs = tuple(random_attr_of_type(t, rng) for t in arg_types)

cost_model = LensCostModel(SnitchCycleCostModel(func_op_name, attrs), riscv_passes.passes)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dev = [
"textual-dev==1.7.0",
"pytest-asyncio==0.25.2",
"pyright==1.1.391",
"numpy==2.2.1",
]
gui = ["textual==1.0.0", "pyclip==0.7"]
jax = ["jax==0.4.38", "numpy==2.2.1"]
Expand Down
4 changes: 1 addition & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 45d57ed

Please sign in to comment.