diff --git a/pyproject.toml b/pyproject.toml index 8c24913e0..2f631c7d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dev = [ "pytest-rerunfailures", "pytest-timeout", "ruff", + "scanspec==0.7.2", "sphinx<7.4.0", # https://github.com/bluesky/ophyd-async/issues/459 "sphinx-autobuild", "autodoc-pydantic", diff --git a/src/ophyd_async/epics/motor.py b/src/ophyd_async/epics/motor.py index a89024956..ca3609e85 100644 --- a/src/ophyd_async/epics/motor.py +++ b/src/ophyd_async/epics/motor.py @@ -78,6 +78,7 @@ def __init__(self, prefix: str, name="") -> None: self.high_limit_travel = epics_signal_rw(float, prefix + ".HLM") self.motor_stop = epics_signal_x(prefix + ".STOP") + self.encoder_res = epics_signal_rw(float, prefix + ".ERES") # Whether set() should complete successfully or not self._set_success = True diff --git a/src/ophyd_async/fastcs/panda/__init__.py b/src/ophyd_async/fastcs/panda/__init__.py index 29b27d557..6b0edf031 100644 --- a/src/ophyd_async/fastcs/panda/__init__.py +++ b/src/ophyd_async/fastcs/panda/__init__.py @@ -19,6 +19,8 @@ ) from ._trigger import ( PcompInfo, + ScanSpecInfo, + ScanSpecSeqTableTriggerLogic, SeqTableInfo, StaticPcompTriggerLogic, StaticSeqTableTriggerLogic, @@ -34,6 +36,8 @@ "PcompBlock", "PcompDirection", "PulseBlock", + "ScanSpecInfo", + "ScanSpecSeqTableTriggerLogic", "SeqBlock", "TimeUnits", "HDFPanda", diff --git a/src/ophyd_async/fastcs/panda/_trigger.py b/src/ophyd_async/fastcs/panda/_trigger.py index 0aa363376..52518ba2d 100644 --- a/src/ophyd_async/fastcs/panda/_trigger.py +++ b/src/ophyd_async/fastcs/panda/_trigger.py @@ -1,11 +1,15 @@ import asyncio +from typing import Literal +import numpy as np from pydantic import BaseModel, Field +from scanspec.specs import Frames, Path, Spec from ophyd_async.core import FlyerController, wait_for_value +from ophyd_async.epics import motor from ._block import BitMux, PcompBlock, PcompDirection, SeqBlock, TimeUnits -from ._table import SeqTable +from ._table import SeqTable, SeqTrigger class SeqTableInfo(BaseModel): @@ -14,6 +18,11 @@ class SeqTableInfo(BaseModel): prescale_as_us: float = Field(default=1, ge=0) # microseconds +class ScanSpecInfo(BaseModel): + spec: Spec[motor.Motor | Literal["DURATION"]] = Field(default=None) + deadtime: float = Field() + + class StaticSeqTableTriggerLogic(FlyerController[SeqTableInfo]): def __init__(self, seq: SeqBlock) -> None: self.seq = seq @@ -41,6 +50,91 @@ async def stop(self): await wait_for_value(self.seq.active, False, timeout=1) +class ScanSpecSeqTableTriggerLogic(FlyerController[ScanSpecInfo]): + def __init__(self, seq: SeqBlock, name="") -> None: + self.seq = seq + self.name = name + + async def prepare(self, value: ScanSpecInfo): + await self.seq.enable.set(BitMux.zero) + path = Path(value.spec.calculate()) + chunk = path.consume() + gaps = self._calculate_gaps(chunk) + if gaps[0] == 0: + gaps = np.delete(gaps, 0) + scan_size = len(chunk) + + gaps = np.append(gaps, scan_size) + fast_axis = chunk.axes()[len(chunk.axes()) - 2] + # Get the resolution from the PandA Encoder? + resolution = await fast_axis.encoder_res.get_value() + start = 0 + # Wait for GPIO to go low + rows = SeqTable.row(trigger=SeqTrigger.BITA_0) + for gap in gaps: + # Wait for GPIO to go high + rows += SeqTable.row(trigger=SeqTrigger.BITA_1) + # Wait for position + if ( + chunk.midpoints[fast_axis][gap - 1] * resolution + > chunk.midpoints[fast_axis][start] * resolution + ): + trig = SeqTrigger.POSA_GT + dir = False if resolution > 0 else True + + else: + trig = SeqTrigger.POSA_LT + dir = True if resolution > 0 else False + rows += SeqTable.row( + trigger=trig, + position=int( + chunk.midpoints[fast_axis][start] + / await fast_axis.encoder_res.get_value() + ), + ) + + # Time based triggers + rows += SeqTable.row( + repeats=gap - start, + trigger=SeqTrigger.IMMEDIATE, + time1=(chunk.midpoints["DURATION"][0] - value.deadtime) * 10**6, + time2=int(value.deadtime * 10**6), + outa1=True, + outb1=dir, + outa2=False, + outb2=dir, + ) + + # Wait for GPIO to go low + rows += SeqTable.row(trigger=SeqTrigger.BITA_0) + + start = gap + await asyncio.gather( + self.seq.prescale.set(1.0), + self.seq.prescale_units.set(TimeUnits.us), + self.seq.repeats.set(1), + self.seq.table.set(rows), + ) + + async def kickoff(self) -> None: + await self.seq.enable.set(BitMux.one) + await wait_for_value(self.seq.active, True, timeout=1) + + async def complete(self) -> None: + await wait_for_value(self.seq.active, False, timeout=None) + + async def stop(self): + await self.seq.enable.set(BitMux.zero) + await wait_for_value(self.seq.active, False, timeout=1) + + def _calculate_gaps(self, chunk: Frames[motor.Motor]): + inds = np.argwhere(chunk.gap) + if len(inds) == 0: + return [len(chunk)] + else: + return inds + + class PcompInfo(BaseModel): start_postion: int = Field(description="start position in counts") pulse_width: int = Field(description="width of a single pulse in counts", gt=0) diff --git a/tests/fastcs/panda/test_trigger.py b/tests/fastcs/panda/test_trigger.py index 66d6e1c5f..5170d1bed 100644 --- a/tests/fastcs/panda/test_trigger.py +++ b/tests/fastcs/panda/test_trigger.py @@ -3,13 +3,17 @@ import numpy as np import pytest from pydantic import ValidationError +from scanspec.specs import Line, fly from ophyd_async.core import DeviceCollector, set_mock_value +from ophyd_async.epics import motor from ophyd_async.fastcs.core import fastcs_connector from ophyd_async.fastcs.panda import ( CommonPandaBlocks, PcompDirection, PcompInfo, + ScanSpecInfo, + ScanSpecSeqTableTriggerLogic, SeqTable, SeqTableInfo, SeqTrigger, @@ -69,6 +73,53 @@ async def set_active(value: bool): await asyncio.gather(trigger_logic.complete(), set_active(False)) +@pytest.fixture +async def sim_x_motor(): + async with DeviceCollector(mock=True): + sim_motor = motor.Motor("BLxxI-MO-STAGE-01:X", name="sim_x_motor") + + set_mock_value(sim_motor.encoder_res, 0.02) + + yield sim_motor + + +@pytest.fixture +async def sim_y_motor(): + async with DeviceCollector(mock=True): + sim_motor = motor.Motor("BLxxI-MO-STAGE-01:Y", name="sim_x_motor") + + set_mock_value(sim_motor.encoder_res, 0.2) + + yield sim_motor + + +async def test_seq_scanspec_trigger_logic(mock_panda, sim_x_motor, sim_y_motor) -> None: + spec = fly(Line(sim_y_motor, 1, 2, 3) * ~Line(sim_x_motor, 1, 5, 5), 1) + info = ScanSpecInfo(spec=spec, deadtime=0.1) + trigger_logic = ScanSpecSeqTableTriggerLogic(mock_panda.seq[1]) + await trigger_logic.prepare(info) + out = await trigger_logic.seq.table.get_value() + assert (out.repeats == [1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1, 5, 1]).all() + assert out.trigger == [ + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_GT, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_LT, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, + SeqTrigger.BITA_1, + SeqTrigger.POSA_GT, + SeqTrigger.IMMEDIATE, + SeqTrigger.BITA_0, + ] + assert (out.position == [0, 0, 50, 0, 0, 0, 250, 0, 0, 0, 50, 0, 0]).all() + assert (out.time1 == [0, 0, 0, 900000, 0, 0, 0, 900000, 0, 0, 0, 900000, 0]).all() + assert (out.time2 == [0, 0, 0, 100000, 0, 0, 0, 100000, 0, 0, 0, 100000, 0]).all() + + @pytest.mark.parametrize( ["kwargs", "error_msg"], [