Skip to content

Commit

Permalink
Implement labeling function library
Browse files Browse the repository at this point in the history
  • Loading branch information
ti1uan committed Feb 17, 2024
1 parent deac63a commit 71a50c1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/labeling_function_lib/labeling_function_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Callable, List, Any

class LabelingFunctionLib:
def __init__(self) -> None:
"""Initialize an empty labeling function library."""
self.label_functions = {}

def register(self, name: str, func: Callable[[Any], Any]) -> None:
"""
Register a labeling function with a given name.
Args:
name (str): The name of the labeling function.
func (Callable[[Any], Any]): The labeling function to be registered.
Raises:
TypeError: If the provided argument is not a function.
ValueError: If a label function with the same name has already been registered.
"""
if not callable(func):
raise TypeError("The provided argument is not a function.")
elif name in self.label_functions:
raise ValueError(f"A label function named '{name}' has already been registered.")
else:
self.label_functions[name] = func

def get(self, name: str) -> Callable[[Any], Any]:
"""
Retrieve a label function by name.
Args:
name (str): The name of the label function to retrieve.
Returns:
Callable[[Any], Any]: The label function.
Raises:
LookupError: If the label function with the given name does not exist.
"""
try:
return self.label_functions[name]
except KeyError as e:
raise LookupError(f"No such label function: {e}") from e

def get_all(self) -> List[Callable[[Any], Any]]:
"""
Returns a list of all labeling functions.
Returns:
List[Callable[[Any], Any]]: A list of all labeling functions.
"""
return list(self.label_functions.values())

def unregister(self, name: str) -> None:
"""
Unregisters a labeling function by removing it from the label_functions dictionary.
Args:
name (str): The name of the labeling function to unregister.
Returns:
None
"""
del self.label_functions[name]

def clear(self) -> None:
"""Remove all label functions from the library."""
self.label_functions = {}
Empty file.
46 changes: 46 additions & 0 deletions tests/labeling_function_lib/test_labeling_function_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
from labeling_function_lib.labeling_function_lib import LabelingFunctionLib

def dummy_label_func(item):
return "dummy_label"

def test_register_and_get():
lib = LabelingFunctionLib()
lib.register("dummy", dummy_label_func)
assert lib.get("dummy") == dummy_label_func, "Failed to register or get the label function"

def test_register_existing_name():
lib = LabelingFunctionLib()
lib.register("dummy", dummy_label_func)
with pytest.raises(ValueError):
lib.register("dummy", dummy_label_func)

def test_register_not_callable():
lib = LabelingFunctionLib()
with pytest.raises(TypeError):
lib.register("not_callable", "this is not a function")

def test_get_non_existent():
lib = LabelingFunctionLib()
with pytest.raises(LookupError):
lib.get("non_existent")

def test_get_all():
lib = LabelingFunctionLib()
lib.register("dummy1", dummy_label_func)
lib.register("dummy2", dummy_label_func)
assert len(lib.get_all()) == 2, "Failed to retrieve all label functions"

def test_unregister():
lib = LabelingFunctionLib()
lib.register("dummy", dummy_label_func)
lib.unregister("dummy")
with pytest.raises(LookupError):
lib.get("dummy")

def test_clear():
lib = LabelingFunctionLib()
lib.register("dummy1", dummy_label_func)
lib.register("dummy2", dummy_label_func)
lib.clear()
assert len(lib.get_all()) == 0, "Failed to clear all label functions"

0 comments on commit 71a50c1

Please sign in to comment.