Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial stab at DerivedSignal #525 #661

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions src/ophyd_async/core/_derived_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import asyncio
import dataclasses
from collections.abc import Iterator
from typing import Generic, TypeVar, get_type_hints

import numpy as np

from ._device import Device
from ._protocol import AsyncMovable
from ._signal import SignalR, SignalRW, soft_signal_rw
from ._signal_backend import Array1D, SignalBackend
from ._status import AsyncStatus
from ._utils import T, get_origin_class

RawSignalsT = TypeVar("RawSignalsT")
ParametersSignalsT = TypeVar("ParametersSignalsT")
RawT = TypeVar("RawT")
DerivedT = TypeVar("DerivedT")
ParametersT = TypeVar("ParametersT")


class Transform(Generic[RawT, DerivedT, ParametersT]):
def forward(self, raw: RawT, parameters: ParametersT) -> DerivedT: ...
def inverse(self, derived: DerivedT, parameters: ParametersT) -> RawT: ...


F_contra = TypeVar("F_contra", bound=float | Array1D[np.float64], contravariant=True)


# TODO: should this be a TypedDict?
@dataclasses.dataclass
class SlitsRaw(Generic[F_contra]):
top: F_contra
bottom: F_contra


@dataclasses.dataclass
class SlitsDerived(Generic[F_contra]):
gap: F_contra
centre: F_contra


@dataclasses.dataclass
class SlitsParameters:
gap_offset: float


class SlitsTransform(Transform[SlitsRaw, SlitsDerived, SlitsParameters]):
def forward(
self, raw: SlitsRaw[F_contra], parameters: SlitsParameters
) -> SlitsDerived[F_contra]:
return SlitsDerived(
gap=raw.top - raw.bottom + parameters.gap_offset,
centre=(raw.top + raw.bottom) / 2,
)

def inverse(
self, derived: SlitsDerived[F_contra], parameters: SlitsParameters
) -> SlitsRaw[F_contra]:
half_gap = (derived.gap - parameters.gap_offset) / 2
return SlitsRaw(
top=derived.centre + half_gap,
bottom=derived.centre - half_gap,
)


def _get_dataclass_args(method) -> Iterator[type]:
for k, v in get_type_hints(method):
cls = get_origin_class(v)
if k != "return" and dataclasses.is_dataclass(cls):
yield cls


async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T:
coros = {}
for field in dataclasses.fields(cls):
sig = getattr(device, field.name)
assert isinstance(
sig, SignalR
), f"{device.name}.{field.name} is {sig}, not a Signal"
coros[field.name] = sig.get_value()
results = await asyncio.gather(*coros.values())
kwargs = dict(zip(coros, results, strict=True))
return cls(**kwargs)


class DerivedBackend(Generic[RawT, DerivedT, ParametersT]):
def __init__(
self,
device: Device,
transform: Transform[RawT, DerivedT, ParametersT],
):
self._device = device
self._transform = transform
self._raw_cls, self._param_cls = _get_dataclass_args(self._transform.forward)

async def get_parameters(self) -> ParametersT:
return await _get_dataclass_from_signals(self._param_cls, self._device)

async def get_raw_values(self) -> RawT:
return await _get_dataclass_from_signals(self._raw_cls, self._device)

async def get_derived_values(self) -> DerivedT:
raw, parameters = await asyncio.gather(
self.get_raw_values(), self.get_parameters()
)
return self._transform.forward(raw, parameters)

async def set_derived_values(self, derived: DerivedT):
assert isinstance(self._device, AsyncMovable)
await self._device.set(derived)

async def calculate_raw_values(self, derived: DerivedT) -> RawT:
return self._transform.inverse(derived, await self.get_parameters())

def derived_signal(self, variable: str) -> SignalRW[float]:
return SignalRW(DerivedSignalBackend(self, variable))


class DerivedSignalBackend(SignalBackend[float]):
def __init__(self, backend: DerivedBackend, variable: str):
self._backend = backend
self._variable = variable
super().__init__(float)

async def get_value(self) -> float:
derived = await self._backend.get_derived_values()
return getattr(derived, self._variable)

async def put(self, value: float | None, wait: bool):
derived = await self._backend.get_derived_values()
# TODO: we should be calling locate on these as we want to move relative to the
# setpoint, not readback
setattr(derived, self._variable, value)
await self._backend.set_derived_values(derived)


class Slits(Device):
def __init__(self, name=""):
self._backend = DerivedBackend(self, SlitsTransform())
# Raw signals
self.top = soft_signal_rw(float)
self.bottom = soft_signal_rw(float)
# Parameter
self.gap_offset = soft_signal_rw(float)
# Derived signals
self.gap = self._backend.derived_signal("gap")
self.centre = self._backend.derived_signal("centre")
super().__init__(name=name)

@AsyncStatus.wrap
async def set(self, derived: SlitsDerived[float]) -> None:
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived)
await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom))
8 changes: 8 additions & 0 deletions src/ophyd_async/core/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,14 @@ def unstage(self) -> AsyncStatus:
"""


@runtime_checkable
class AsyncMovable(Protocol):
@abstractmethod
def set(self, value) -> AsyncStatus:
"""Return a ``Status`` that is marked done when the device is done moving."""
...


C = TypeVar("C", contravariant=True)


Expand Down
Loading