-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Callable, List, Any | ||
|
||
class LabelingFunctionLib: | ||
def __init__(self) -> None: | ||
"""Initialize an empty labeling function library.""" | ||
self.label_functions = {} | ||
|
||
def register(self, name: str, func: Callable[[Any], Any]) -> None: | ||
""" | ||
Register a labeling function with a given name. | ||
Args: | ||
name (str): The name of the labeling function. | ||
func (Callable[[Any], Any]): The labeling function to be registered. | ||
Raises: | ||
TypeError: If the provided argument is not a function. | ||
ValueError: If a label function with the same name has already been registered. | ||
""" | ||
if not callable(func): | ||
raise TypeError("The provided argument is not a function.") | ||
elif name in self.label_functions: | ||
raise ValueError(f"A label function named '{name}' has already been registered.") | ||
else: | ||
self.label_functions[name] = func | ||
|
||
def get(self, name: str) -> Callable[[Any], Any]: | ||
""" | ||
Retrieve a label function by name. | ||
Args: | ||
name (str): The name of the label function to retrieve. | ||
Returns: | ||
Callable[[Any], Any]: The label function. | ||
Raises: | ||
LookupError: If the label function with the given name does not exist. | ||
""" | ||
try: | ||
return self.label_functions[name] | ||
except KeyError as e: | ||
raise LookupError(f"No such label function: {e}") from e | ||
|
||
def get_all(self) -> List[Callable[[Any], Any]]: | ||
""" | ||
Returns a list of all labeling functions. | ||
Returns: | ||
List[Callable[[Any], Any]]: A list of all labeling functions. | ||
""" | ||
return list(self.label_functions.values()) | ||
|
||
def unregister(self, name: str) -> None: | ||
""" | ||
Unregisters a labeling function by removing it from the label_functions dictionary. | ||
Args: | ||
name (str): The name of the labeling function to unregister. | ||
Returns: | ||
None | ||
""" | ||
del self.label_functions[name] | ||
|
||
def clear(self) -> None: | ||
"""Remove all label functions from the library.""" | ||
self.label_functions = {} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import pytest | ||
from labeling_function_lib.labeling_function_lib import LabelingFunctionLib | ||
|
||
def dummy_label_func(item): | ||
return "dummy_label" | ||
|
||
def test_register_and_get(): | ||
lib = LabelingFunctionLib() | ||
lib.register("dummy", dummy_label_func) | ||
assert lib.get("dummy") == dummy_label_func, "Failed to register or get the label function" | ||
|
||
def test_register_existing_name(): | ||
lib = LabelingFunctionLib() | ||
lib.register("dummy", dummy_label_func) | ||
with pytest.raises(ValueError): | ||
lib.register("dummy", dummy_label_func) | ||
|
||
def test_register_not_callable(): | ||
lib = LabelingFunctionLib() | ||
with pytest.raises(TypeError): | ||
lib.register("not_callable", "this is not a function") | ||
|
||
def test_get_non_existent(): | ||
lib = LabelingFunctionLib() | ||
with pytest.raises(LookupError): | ||
lib.get("non_existent") | ||
|
||
def test_get_all(): | ||
lib = LabelingFunctionLib() | ||
lib.register("dummy1", dummy_label_func) | ||
lib.register("dummy2", dummy_label_func) | ||
assert len(lib.get_all()) == 2, "Failed to retrieve all label functions" | ||
|
||
def test_unregister(): | ||
lib = LabelingFunctionLib() | ||
lib.register("dummy", dummy_label_func) | ||
lib.unregister("dummy") | ||
with pytest.raises(LookupError): | ||
lib.get("dummy") | ||
|
||
def test_clear(): | ||
lib = LabelingFunctionLib() | ||
lib.register("dummy1", dummy_label_func) | ||
lib.register("dummy2", dummy_label_func) | ||
lib.clear() | ||
assert len(lib.get_all()) == 0, "Failed to clear all label functions" |