From 71a50c1181a4efcea4595784b6296b0a0417c772 Mon Sep 17 00:00:00 2001 From: Tian Luan Date: Sat, 17 Feb 2024 03:15:04 -0500 Subject: [PATCH] Implement labeling function library --- .../labeling_function_lib.py | 68 +++++++++++++++++++ tests/labeling_function_lib/.gitkeep | 0 .../test_labeling_function_lib.py | 46 +++++++++++++ 3 files changed, 114 insertions(+) create mode 100644 src/labeling_function_lib/labeling_function_lib.py delete mode 100644 tests/labeling_function_lib/.gitkeep create mode 100644 tests/labeling_function_lib/test_labeling_function_lib.py diff --git a/src/labeling_function_lib/labeling_function_lib.py b/src/labeling_function_lib/labeling_function_lib.py new file mode 100644 index 0000000..ac65874 --- /dev/null +++ b/src/labeling_function_lib/labeling_function_lib.py @@ -0,0 +1,68 @@ +from typing import Callable, List, Any + +class LabelingFunctionLib: + def __init__(self) -> None: + """Initialize an empty labeling function library.""" + self.label_functions = {} + + def register(self, name: str, func: Callable[[Any], Any]) -> None: + """ + Register a labeling function with a given name. + + Args: + name (str): The name of the labeling function. + func (Callable[[Any], Any]): The labeling function to be registered. + + Raises: + TypeError: If the provided argument is not a function. + ValueError: If a label function with the same name has already been registered. + """ + if not callable(func): + raise TypeError("The provided argument is not a function.") + elif name in self.label_functions: + raise ValueError(f"A label function named '{name}' has already been registered.") + else: + self.label_functions[name] = func + + def get(self, name: str) -> Callable[[Any], Any]: + """ + Retrieve a label function by name. + + Args: + name (str): The name of the label function to retrieve. + + Returns: + Callable[[Any], Any]: The label function. + + Raises: + LookupError: If the label function with the given name does not exist. + """ + try: + return self.label_functions[name] + except KeyError as e: + raise LookupError(f"No such label function: {e}") from e + + def get_all(self) -> List[Callable[[Any], Any]]: + """ + Returns a list of all labeling functions. + + Returns: + List[Callable[[Any], Any]]: A list of all labeling functions. + """ + return list(self.label_functions.values()) + + def unregister(self, name: str) -> None: + """ + Unregisters a labeling function by removing it from the label_functions dictionary. + + Args: + name (str): The name of the labeling function to unregister. + + Returns: + None + """ + del self.label_functions[name] + + def clear(self) -> None: + """Remove all label functions from the library.""" + self.label_functions = {} \ No newline at end of file diff --git a/tests/labeling_function_lib/.gitkeep b/tests/labeling_function_lib/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/tests/labeling_function_lib/test_labeling_function_lib.py b/tests/labeling_function_lib/test_labeling_function_lib.py new file mode 100644 index 0000000..7d2924c --- /dev/null +++ b/tests/labeling_function_lib/test_labeling_function_lib.py @@ -0,0 +1,46 @@ +import pytest +from labeling_function_lib.labeling_function_lib import LabelingFunctionLib + +def dummy_label_func(item): + return "dummy_label" + +def test_register_and_get(): + lib = LabelingFunctionLib() + lib.register("dummy", dummy_label_func) + assert lib.get("dummy") == dummy_label_func, "Failed to register or get the label function" + +def test_register_existing_name(): + lib = LabelingFunctionLib() + lib.register("dummy", dummy_label_func) + with pytest.raises(ValueError): + lib.register("dummy", dummy_label_func) + +def test_register_not_callable(): + lib = LabelingFunctionLib() + with pytest.raises(TypeError): + lib.register("not_callable", "this is not a function") + +def test_get_non_existent(): + lib = LabelingFunctionLib() + with pytest.raises(LookupError): + lib.get("non_existent") + +def test_get_all(): + lib = LabelingFunctionLib() + lib.register("dummy1", dummy_label_func) + lib.register("dummy2", dummy_label_func) + assert len(lib.get_all()) == 2, "Failed to retrieve all label functions" + +def test_unregister(): + lib = LabelingFunctionLib() + lib.register("dummy", dummy_label_func) + lib.unregister("dummy") + with pytest.raises(LookupError): + lib.get("dummy") + +def test_clear(): + lib = LabelingFunctionLib() + lib.register("dummy1", dummy_label_func) + lib.register("dummy2", dummy_label_func) + lib.clear() + assert len(lib.get_all()) == 0, "Failed to clear all label functions"