From 45d57ed2d6531a3a0b4792821a70538bd6a4ecd6 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 10 Jan 2025 23:27:41 +0000 Subject: [PATCH] remove use of numpy --- docs/marimo/autotune.py | 37 ++++++++++--------------------------- pyproject.toml | 1 - uv.lock | 4 +--- 3 files changed, 11 insertions(+), 31 deletions(-) diff --git a/docs/marimo/autotune.py b/docs/marimo/autotune.py index b2b65ceeda..60ff787920 100644 --- a/docs/marimo/autotune.py +++ b/docs/marimo/autotune.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.10.9" +__generated_with = "0.10.10" app = marimo.App(width="full") @@ -553,17 +553,8 @@ def _( @app.cell def _(): - import numpy as np - from numpy.random import Generator as RandomGenerator - return RandomGenerator, np - - -@app.cell -def _(np): - rng = np.random.default_rng() - - type(rng) - return (rng,) + from random import Random + return (Random,) @app.cell @@ -573,28 +564,20 @@ def _(): @app.cell -def _( - Attribute, - DenseIntOrFPElementsAttr, - RandomGenerator, - f64, - np, - to_dtype, -): +def _(Attribute, DenseIntOrFPElementsAttr, Random, f64): from xdsl.dialects.builtin import FloatAttr, ShapedType, ContainerType, TensorType - def random_attr_of_type(t: Attribute, rng: RandomGenerator) -> Attribute | None: + def random_attr_of_type(t: Attribute, rng: Random) -> Attribute | None: # if isinstance(type, ShapedType): if t == f64: - return FloatAttr(rng.random(1, to_dtype(t))[0], f64) + return FloatAttr(rng.random(), f64) elif isinstance(t, ShapedType) and isinstance(t, ContainerType): - values = rng.random(t.element_count(), to_dtype(t.get_element_type())) return DenseIntOrFPElementsAttr.from_list( - t, values + t, tuple(rng.random() for _ in range(t.element_count())) ) - _rng = np.random.default_rng() + _rng = Random("autotune") random_attr_of_type(f64, _rng), random_attr_of_type(TensorType(f64, (2, 3)), _rng) return ( ContainerType, @@ -612,12 +595,12 @@ def _( MemrefStreamUnrollAndJamPass, ModuleOp, ModulePass, + Random, SnitchCycleCostModel, func, memref_stream, memref_stream_module, msg_factors, - np, random_attr_of_type, riscv_passes, uaj_passes, @@ -652,7 +635,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: arg_types = func_op.function_type.inputs - rng = np.random.default_rng() + rng = Random("autotune") attrs = tuple(random_attr_of_type(t, rng) for t in arg_types) cost_model = LensCostModel(SnitchCycleCostModel(func_op_name, attrs), riscv_passes.passes) diff --git a/pyproject.toml b/pyproject.toml index 1c3361df2d..67c4b73584 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,6 @@ dev = [ "textual-dev==1.7.0", "pytest-asyncio==0.25.2", "pyright==1.1.391", - "numpy==2.2.1", ] gui = ["textual==1.0.0", "pyclip==0.7"] jax = ["jax==0.4.38", "numpy==2.2.1"] diff --git a/uv.lock b/uv.lock index 50895bf95e..aa29fe77dd 100644 --- a/uv.lock +++ b/uv.lock @@ -2544,7 +2544,6 @@ dev = [ { name = "marimo" }, { name = "nbconvert" }, { name = "nbval" }, - { name = "numpy" }, { name = "pip" }, { name = "pre-commit" }, { name = "pyright" }, @@ -2586,7 +2585,6 @@ requires-dist = [ { name = "marimo", marker = "extra == 'dev'", specifier = "==0.10.10" }, { name = "nbconvert", marker = "extra == 'dev'", specifier = ">=7.7.2,<8.0.0" }, { name = "nbval", marker = "extra == 'dev'", specifier = "<0.12" }, - { name = "numpy", marker = "extra == 'dev'", specifier = "==2.2.1" }, { name = "numpy", marker = "extra == 'jax'", specifier = "==2.2.1" }, { name = "numpy", marker = "extra == 'onnx'", specifier = "==2.2.1" }, { name = "onnx", marker = "extra == 'onnx'", specifier = "==1.17.0" }, @@ -2599,7 +2597,7 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==0.25.2" }, { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "riscemu", marker = "extra == 'riscv'", specifier = "==2.2.7" }, - { name = "ruff", marker = "extra == 'dev'", specifier = "==0.9" }, + { name = "ruff", marker = "extra == 'dev'", specifier = "==0.9.0" }, { name = "textual", marker = "extra == 'gui'", specifier = "==1.0.0" }, { name = "textual-dev", marker = "extra == 'dev'", specifier = "==1.7.0" }, { name = "toml", marker = "extra == 'dev'", specifier = "<0.11" },