diff --git a/Cargo.toml b/Cargo.toml index 7de6214..96d461b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stitch_bindings" -version = "0.1.23" +version = "0.1.24" edition = "2021" [lib] @@ -12,7 +12,7 @@ name = "stitch_core" [dependencies] # stitch_core = { path = "../stitch"} -stitch_core = { git = "https://github.com/mlb2251/stitch", rev = "f96ba0e"} +stitch_core = { git = "https://github.com/mlb2251/stitch", rev = "323da8c"} pyo3 = {version = "0.17.3", features = ["extension-module"] } clap = { version = "3.1.0" } diff --git a/src/lib.rs b/src/lib.rs index 4061fcb..9e117fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,13 +7,16 @@ use clap::Parser; #[pyfunction( programs, tasks, + weights, name_mapping, + panic_loud, args )] fn compress_backend( py: Python, programs: Vec, tasks: Option>, + weights: Option>, name_mapping: Option>, panic_loud: bool, args: String, @@ -31,7 +34,7 @@ fn compress_backend( // release the GIL and call compression let (_step_results, json_res) = py.allow_threads(|| - multistep_compression(&programs, tasks, name_mapping, None, &cfg) + multistep_compression(&programs, tasks, weights, name_mapping, None, &cfg) ); // return as something you could json.loads(out) from in python @@ -42,6 +45,7 @@ fn compress_backend( #[pyfunction( programs, abstractions, + panic_loud, args )] fn rewrite_backend( diff --git a/stitch_core/__init__.py b/stitch_core/__init__.py index dd86ff5..fe884d8 100644 --- a/stitch_core/__init__.py +++ b/stitch_core/__init__.py @@ -289,6 +289,7 @@ def compress( """ tasks = kwargs.pop("tasks", None) + weights = kwargs.pop("weights", None) name_mapping = kwargs.pop("name_mapping", None) panic_loud = kwargs.pop('panic_loud',False) @@ -305,6 +306,7 @@ def compress( res = compress_backend( programs, tasks, + weights, name_mapping, panic_loud, args) diff --git a/tests/test.py b/tests/test.py index 32c62fd..3954075 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,5 +1,6 @@ from stitch_core import compress, rewrite, StitchException, from_dreamcoder, Abstraction, name_mapping_stitch, stitch_to_dreamcoder import json +import math # simple test programs = ["(a a a)", "(b b b)"] @@ -68,4 +69,22 @@ # print(e) pass +# 1x (default) weighting vs 2x weighting vs weighting the "g" programs more +programs = ["(f a a)", "(f b b)", "(f c c)", "(g d d)", "(g e e)"] +res = compress(programs, iterations=1) +res2x = compress(programs, iterations=1, weights=[2. for _ in programs]) +res_uneven = compress(programs, iterations=1, weights=[1., 1., 1., 2., 2.]) + +assert res.json["original_cost"] *2 == res2x.json["original_cost"] +assert res.json["final_cost"] *2 == res2x.json["final_cost"] +assert res.abstractions[0].body == res2x.abstractions[0].body == "(f #0 #0)" +assert res_uneven.abstractions[0].body == "(g #0 #0)" + +# make sure compression ratio is as expected +assert math.fabs(res_uneven.json["original_cost"]/res_uneven.json["final_cost"] - res_uneven.json["compression_ratio"]) < 0.00001 + +# assert res.rewritten == ['(fn_0 a)', '(fn_0 b)'] +# assert res.abstractions[0].body == '(#0 #0 #0)' + + print("Passed all tests") \ No newline at end of file