From a2dc3024828e58560af7cfb0babe095c2d07da73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Sm=C3=B3=C5=82ka?= Date: Sat, 9 Nov 2024 01:25:19 +0100 Subject: [PATCH] Added reprojection error evaluation and conditional updates to `Buffer` --- .../core/transformation/buffer.py | 58 ++++++++++++++++++- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/child_lab_framework/core/transformation/buffer.py b/child_lab_framework/core/transformation/buffer.py index dc22611..384d6b7 100644 --- a/child_lab_framework/core/transformation/buffer.py +++ b/child_lab_framework/core/transformation/buffer.py @@ -7,6 +7,8 @@ from .. import serialization from . import Transformation +from .error import reprojection_error as _reprojection_error +from .interface import ProjectableAndTransformable # T appears both as a method argument and return type, therefore it must be invariant. @@ -74,9 +76,14 @@ def __setitem__( def __getitem__(self, from_to: tuple[T, T]) -> Transformation | None: connections = self.__connections - maybe_result: Transformation | None = connections.get_edge_data(*from_to).get( - 'transformation' - ) + match connections.get_edge_data(*from_to): + case {'transformation': transformation} if isinstance( + transformation, Transformation + ): + maybe_result = transformation + + case _: + maybe_result = None if maybe_result is not None: return maybe_result @@ -100,6 +107,51 @@ def __getitem__(self, from_to: tuple[T, T]) -> Transformation | None: return transformation + def reprojection_error[U: ProjectableAndTransformable]( + self, + evaluated_frame: T, + referential_frame: T, + evaluated_object: U, + referential_object: U, + ) -> float: + transformation = self.__getitem__((evaluated_frame, referential_frame)) + if transformation is None: + return float('inf') + + return _reprojection_error(evaluated_object, referential_object, transformation) + + def update_transformation_if_better[U: ProjectableAndTransformable]( + self, + from_frame: T, + to_frame: T, + from_object: U, + to_object: U, + transformation: Transformation, + ) -> None: + new_error = _reprojection_error(from_object, to_object, transformation) + current_error = self.reprojection_error( + from_frame, + to_frame, + from_object, + to_object, + ) + + if new_error < current_error: + self.__setitem__((from_frame, to_frame), transformation) + + inverse_transformation = transformation.inverse + + new_error = _reprojection_error(to_object, from_object, inverse_transformation) + current_error = self.reprojection_error( + to_frame, + from_frame, + to_object, + from_object, + ) + + if new_error < current_error: + self.__setitem__((to_frame, from_frame), inverse_transformation) + def serialize(self) -> dict[str, serialization.Value]: return { 'frames_of_reference': list(map(str, self.__frames_of_reference)),