diff --git a/genjax-localization-tutorial/probcomp-localization-tutorial.py b/genjax-localization-tutorial/probcomp-localization-tutorial.py index 1ce4d86..36b64cd 100644 --- a/genjax-localization-tutorial/probcomp-localization-tutorial.py +++ b/genjax-localization-tutorial/probcomp-localization-tutorial.py @@ -39,7 +39,6 @@ import jax import jax.numpy as jnp import genjax -from urllib.request import urlopen from genjax import SelectionBuilder as S from genjax import ChoiceMapBuilder as C from genjax.typing import Array, FloatArray, PRNGKey, IntArray @@ -197,27 +196,42 @@ def make_world(wall_verts, clutters_vec, start, controls): def load_world(file_name): """ - Loads the world configuration from a specified file and constructs the world. + Loads the world configuration from local JSON files and constructs the world. Args: - - file_name (str): The name of the file containing the world configuration. + - file_name (str): Not used, kept for backwards compatibility Returns: - tuple: A tuple containing the world configuration, the initial state, and the total number of control steps. """ - with urlopen( - "https://raw.githubusercontent.com/probcomp/gen-localization/main/resources/example_20_program.json" - ) as url: - data = json.load(url) + # Try to find the JSON files relative to different possible working directories + possible_paths = [ + '.', # Current directory + '..', # Parent directory + os.path.dirname(__file__), # Directory containing this file + os.path.dirname(os.path.dirname(__file__)), # Parent of directory containing this file + ] + + def find_json(filename): + for path in possible_paths: + full_path = os.path.join(path, filename) + if os.path.exists(full_path): + with open(full_path, 'r') as f: + return json.load(f) + raise FileNotFoundError(f"Could not find {filename} in any of: {possible_paths}") + + # Load both required JSON files + world_data = find_json('world.json') + program_data = find_json('robot_program.json') - walls_vec = jnp.array(data["wall_verts"]) - clutters_vec = jnp.array(data["clutter_vert_groups"]) + walls_vec = jnp.array(world_data["wall_verts"]) + clutters_vec = jnp.array(world_data["clutter_vert_groups"]) start = Pose( - jnp.array(data["start_pose"]["p"], dtype=float), - jnp.array(data["start_pose"]["hd"], dtype=float), + jnp.array(program_data["start_pose"]["p"], dtype=float), + jnp.array(program_data["start_pose"]["hd"], dtype=float), ) - cs = jnp.array([[c["ds"], c["dhd"]] for c in data["program_controls"]]) + cs = jnp.array([[c["ds"], c["dhd"]] for c in program_data["program_controls"]]) controls = Control(cs[:, 0], cs[:, 1]) return make_world(walls_vec, clutters_vec, start, controls) @@ -426,7 +440,7 @@ def pose_plot(p, fill: str | Any = "black", **opts): strokeWidth=2, stroke="#ccc", ), - {"margin": 0, "inset": 50, "width": 500, "axis": None, "aspectRatio": 1}, + {"margin": 0, "inset": 50, "maxWidth": 500, "axis": None, "aspectRatio": 1}, Plot.domain([0, 20]), ) # Plot the world with walls only @@ -807,7 +821,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): + Plot.color_map({"some pose": "green", "with heading modified": "red"}) + Plot.title("Modifying a heading") ) - | html("span.tc", f"score ratio: {rotated_trace_weight_diff}") + | html(["span.tc", f"score ratio: {rotated_trace_weight_diff}"]) ) # %% [markdown] @@ -834,7 +848,7 @@ def animate_path_with_confidence(path: Pose, motion_settings: dict): for pose in path_from_trace(trace) ] + Plot.color_map({"some path": "green", "with heading modified": "red"}) -) | html("span.tc", f"score ratio: {rotated_first_step_weight_diff}") +) | html(["span.tc", f"score ratio: {rotated_first_step_weight_diff}"]) # %% [markdown] # ### Ideal sensors @@ -1127,7 +1141,7 @@ def plt(readings): return Plot.new( plot_base or Plot.domain([0, 20]), plot_sensors(pose, readings), - {"width": 400, "height": 400}, + {"maxWidth": 400, "aspectRatio": 1}, ) return plt(readings1) & plt(readings2) @@ -1169,13 +1183,13 @@ def plt(readings): sample, log_weight = model_importance( sub_key, constraints_low_deviation, (motion_settings_low_deviation,) ) -animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") +animate_full_trace(sample) | html(["span.tc", f"log_weight: {log_weight}"]) # %% key, sub_key = jax.random.split(key) sample, log_weight = model_importance( sub_key, constraints_high_deviation, (motion_settings_high_deviation,) ) -animate_full_trace(sample) | html("span.tc", f"log_weight: {log_weight}") +animate_full_trace(sample) | html(["span.tc", f"log_weight: {log_weight}"]) # %% [markdown] # A trace resulting from a call to `importance` is structurally indistinguishable from one drawn from `simulate`. But there is a key situational difference: while `get_score` always returns the frequency with which `simulate` stochastically produces the trace, this value is **no longer equal to** the frequency with which the trace is stochastically produced by `importance`. This is both true in an obvious and less relevant sense, as well as true in a more subtle and extremely germane sense. # @@ -1256,9 +1270,9 @@ def constraint_from_path(path): Plot.Row( *[ ( - html("div.f3.b.tc", title) + html(["div.f3.b.tc", title]) | animate_full_trace(trace, frame_key="frame") - | html("span.tc", f"score: {score:,.2f}") + | html(["span.tc", f"score: {score:,.2f}"]) ) for (title, trace, motion_settings, score) in [ [ @@ -1521,8 +1535,8 @@ class SequentialImportanceSampling(Generic[StateT, ControlT]): Given: - a functional wrapper for the importance method of a generative function - an initial state of type StateT, which should be a PyTree $z_0$ - - a vector of control inputs, also a PyTree $u_i, of shape $(T, \ldots)$ - - an array of observations $y_i$, also of shape $(T, \ldots)$ + - a vector of control inputs, also a PyTree $u_i, of shape $(T, \\ldots)$ + - an array of observations $y_i$, also of shape $(T, \\ldots)$ perform the inference technique known as Sequential Importance Sampling. The signature of the GFI importance method is diff --git a/poetry.lock b/poetry.lock index cfcb83f..8bae674 100644 --- a/poetry.lock +++ b/poetry.lock @@ -546,6 +546,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -676,6 +683,7 @@ description = "Python AST that abstracts the underlying Python version" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" files = [ + {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"}, {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, ] @@ -711,13 +719,13 @@ reference = "gcp" [[package]] name = "genstudio" -version = "2024.10.1" +version = "2025.1.12" description = "" optional = false -python-versions = ">=3.10,<3.13" +python-versions = "<3.13,>=3.11" files = [ - {file = "genstudio-2024.10.1-py3-none-any.whl", hash = "sha256:c95cffb1e3d9ca8d9424a535ba227c3e8ecbdc95673e907f9da78e89d6c77b3c"}, - {file = "genstudio-2024.10.1.tar.gz", hash = "sha256:279d461dbec2c6d58f27c99216d9199f40f233a7add506f1c909cf48e9aff8e7"}, + {file = "genstudio-2025.1.12-py3-none-any.whl", hash = "sha256:7b821a715fa191f55d83294ce22fce34b882709d5fc5e5ae8909f73d15ca0d2a"}, + {file = "genstudio-2025.1.12.tar.gz", hash = "sha256:18009d80da979350dd72a792caf84482316d407db5dd0479afe4efe1af3099a0"}, ] [package.dependencies] @@ -727,11 +735,6 @@ orjson = ">=3.10.6,<4.0.0" pillow = ">=10.4.0,<11.0.0" traitlets = ">=5.14.3,<6.0.0" -[package.source] -type = "legacy" -url = "https://us-west1-python.pkg.dev/probcomp-caliban/probcomp/simple" -reference = "gcp" - [[package]] name = "html2image" version = "2.0.5" @@ -2665,4 +2668,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "42925b13162b17dff6f9ac5a701991d53a8de8a7f653847c1b18ae6d3a480d1b" +content-hash = "9cdd74762b93a818ec1ec1884b8683d86e7778b070b84a7cdbda098f046cbffb" diff --git a/pyproject.toml b/pyproject.toml index d53ba76..6f3138e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ package-mode = false python = ">=3.11,<3.13" jupytext = "^1.16.1" genjax = {version = "0.7.0.post4.dev0+eacb241e" , source = "gcp" } -genstudio = {version = "2024.10.1", source = "gcp"} +genstudio = "2025.1.12" ipykernel = "^6.29.3" matplotlib = "^3.8.3" anywidget = "^0.9.7"