Define validate shape of descriptor #788

Merged · 4 commits · Dec 14, 2023
26 changes: 23 additions & 3 deletions databroker/mongo_normalized.py
@@ -514,6 +514,7 @@ def __init__(
event_collection,
root_map,
sub_dict,
validate_shape,
):
self._run = run
self._stream_name = stream_name
@@ -522,6 +523,7 @@ def __init__(
self._event_collection = event_collection
self._sub_dict = sub_dict
self.root_map = root_map
self.validate_shape = validate_shape

# metadata should look like
# {
@@ -851,7 +853,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
if expected_shape and (not is_external):
validated_column = list(
map(
- lambda item: _validate_shape(
+ lambda item: self.validate_shape(
key, numpy.asarray(item), expected_shape
),
result[key],
@@ -936,7 +938,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
last_datum_id=None,
)
filled_data = filled_mock_event["data"][key]
- validated_filled_data = _validate_shape(
+ validated_filled_data = self.validate_shape(
key, filled_data, expected_shape
)
filled_column.append(validated_filled_data)
@@ -1047,6 +1049,7 @@ def from_uri(
access_policy=None,
cache_ttl_complete=60, # seconds
cache_ttl_partial=2, # seconds
validate_shape=None
):
"""
Create a MongoAdapter from MongoDB with the "normalized" (original) layout.
@@ -1094,6 +1097,9 @@ def from_uri(
cache_ttl_complete : float
Time (in seconds) to cache a *complete* BlueskyRun before checking
the database for updates. Default 60.
validate_shape : callable, optional
Function used to validate that the shape of the data matches the
shape given in the descriptor document.
"""
metadatastore_db = _get_database(uri)
if asset_registry_uri is None:
@@ -1122,6 +1128,7 @@ def from_uri(
cache_of_partial_bluesky_runs=cache_of_partial_bluesky_runs,
metadata=metadata,
access_policy=access_policy,
validate_shape=validate_shape,
)

@classmethod
@@ -1135,6 +1142,7 @@ def from_mongomock(
access_policy=None,
cache_ttl_complete=60, # seconds
cache_ttl_partial=2, # seconds
validate_shape=None
):
"""
Create a transient MongoAdapter backed by "mongomock".
@@ -1178,6 +1186,9 @@ def from_mongomock(
cache_ttl_complete : float
Time (in seconds) to cache a *complete* BlueskyRun before checking
the database for updates. Default 60.
validate_shape : callable, optional
Function used to validate that the shape of the data matches the
shape given in the descriptor document.
"""
import mongomock

@@ -1205,6 +1216,7 @@ def from_mongomock(
cache_of_partial_bluesky_runs=cache_of_partial_bluesky_runs,
metadata=metadata,
access_policy=access_policy,
validate_shape=validate_shape,
)

def __init__(
@@ -1220,6 +1232,7 @@ def __init__(
queries=None,
sorting=None,
access_policy=None,
validate_shape=None,
):
"This is not user-facing. Use MongoAdapter.from_uri."
self._run_start_collection = metadatastore_db.get_collection("run_start")
@@ -1249,6 +1262,11 @@ def __init__(
self._sorting = sorting
self.access_policy = access_policy
self._serializer = None
if validate_shape is None:
validate_shape = default_validate_shape
elif isinstance(validate_shape, str):
validate_shape = import_object(validate_shape)
self.validate_shape = validate_shape
super().__init__()

@property
@@ -1441,6 +1459,7 @@ def _build_event_stream(self, *, run_start_uid, stream_name, is_complete):
event_collection=self._event_collection,
root_map=self.root_map,
sub_dict="data",
validate_shape=self.validate_shape,
),
"timestamps": lambda: DatasetFromDocuments(
run=run,
@@ -1450,6 +1469,7 @@ def _build_event_stream(self, *, run_start_uid, stream_name, is_complete):
event_collection=self._event_collection,
root_map=self.root_map,
sub_dict="timestamps",
validate_shape=self.validate_shape,
),
"config": lambda: Config(
OneShotCachedMap(
@@ -2095,7 +2115,7 @@ class BadShapeMetadata(Exception):
pass


- def _validate_shape(key, data, expected_shape):
+ def default_validate_shape(key, data, expected_shape):
"""
Check that data.shape == expected.shape.

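For context on how callers use the new hook: validate_shape may be passed as a callable or, per MongoAdapter.__init__ above, as a dotted-path string resolved with import_object; when omitted it falls back to default_validate_shape. The sketch below is illustrative only (the lenient validator and its warning behavior are hypothetical, not part of this PR); it shows the call signature the adapter uses, validate_shape(key, data, expected_shape) -> data:

import warnings

import numpy

from databroker.mongo_normalized import MongoAdapter


def lenient_validate_shape(key, data, expected_shape):
    # Hypothetical validator with the same signature the adapter calls
    # internally: validate_shape(key, numpy.asarray(item), expected_shape).
    # It warns on a mismatch and returns the data unchanged, whereas the
    # stock default_validate_shape checks data.shape against expected_shape
    # (and can reject the data with BadShapeMetadata).
    arr = numpy.asarray(data)
    if tuple(arr.shape) != tuple(expected_shape):
        warnings.warn(
            f"{key}: data shape {arr.shape} != descriptor shape {tuple(expected_shape)}"
        )
    return arr


# Pass the callable directly...
adapter = MongoAdapter.from_mongomock(validate_shape=lenient_validate_shape)

# ...or as a string to be resolved by import_object (string format assumed
# here; use whatever form import_object accepts):
# adapter = MongoAdapter.from_uri(
#     "mongodb://localhost:27017/metadatastore",
#     validate_shape="my_package.validators:lenient_validate_shape",
# )

Whatever the validator returns is what populate_columns places in the column, so a custom hook can coerce or pad the data rather than merely checking it.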
31 changes: 31 additions & 0 deletions databroker/tests/test_validate_shape.py
@@ -0,0 +1,31 @@
from bluesky import RunEngine
from bluesky.plans import count
from ophyd.sim import img
from tiled.client import Context, from_context
from tiled.server.app import build_app

from ..mongo_normalized import MongoAdapter


def test_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
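Note the ordering of the assertions: the custom validator is not invoked when the documents are posted, only once the array data is actually read via client[uid]["primary"]["data"]["img"][:], which is when populate_columns applies self.validate_shape.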