Skip to content

Commit

Permalink
type: type-check operator group implementation against its protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 30, 2024
1 parent 4bca3a0 commit e781071
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
11 changes: 8 additions & 3 deletions elastica/modules/operator_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TypeVar, Generic, Iterator, Callable

from collections.abc import Iterable
from typing import TYPE_CHECKING, TypeVar, Generic, Callable, Any
from collections.abc import Iterable, Iterator

import itertools

Expand Down Expand Up @@ -80,3 +79,9 @@ def add_operators(self, feature: F, operators: list[T]) -> None:
def is_last(self, feature: F) -> bool:
"""Checks if the feature is the last feature in the FIFO."""
return id(feature) == self._operator_ids[-1]


if TYPE_CHECKING:
from elastica.typing import OperatorType

_: Iterable[OperatorType] = OperatorGroupFIFO[OperatorType, Any]()
75 changes: 75 additions & 0 deletions tests/test_modules/test_feature_grouping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from elastica.modules.operator_group import OperatorGroupFIFO
import functools


def test_add_ids():
Expand Down Expand Up @@ -65,3 +66,77 @@ def test_is_last():

assert group.is_last(1) == False
assert group.is_last(2) == True


class TestOperatorGroupingWithCallableModules:
class OperatorTypeA:
def __init__(self):
self.value = 0

def apply(self) -> None:
self.value += 1

class OperatorTypeB:
def __init__(self):
self.value2 = 0

def apply(self) -> None:
self.value2 -= 1

# def test_lambda(self):
# feature_group = OperatorGroupFIFO()

# op_a = self.OperatorTypeA()
# feature_group.append_id(op_a)
# op_b = self.OperatorTypeB()
# feature_group.append_id(op_b)

# for op in [op_a, op_b]:
# func = functools.partial(lambda t: op.apply())
# feature_group.add_operators(op, [func])

# for operator in feature_group:
# operator(t=0)

# assert op_a.value == 1
# assert op_b.value2 == -1

# def test_def(self):
# feature_group = OperatorGroupFIFO()

# op_a = self.OperatorTypeA()
# feature_group.append_id(op_a)
# op_b = self.OperatorTypeB()
# feature_group.append_id(op_b)

# for op in [op_a, op_b]:
# def func(t):
# op.apply()
# feature_group.add_operators(op, [func])

# for operator in feature_group:
# operator(t=0)

# assert op_a.value == 1
# assert op_b.value2 == -1

def test_partial(self):
feature_group = OperatorGroupFIFO()

op_a = self.OperatorTypeA()
feature_group.append_id(op_a)
op_b = self.OperatorTypeB()
feature_group.append_id(op_b)

def _func(t, op):
op.apply()

for op in [op_a, op_b]:
func = functools.partial(_func, op=op)
feature_group.add_operators(op, [func])

for operator in feature_group:
operator(t=0)

assert op_a.value == 1
assert op_b.value2 == -1

0 comments on commit e781071

Please sign in to comment.