Skip to content

Commit

Permalink
make check outputs visually more appealing
Browse files Browse the repository at this point in the history
* implement AssertResult that allows formatting of assert
  results
* use answer_key as CheckableWidget name
* handle error during check more transparent to student by catching
  errors during fingerprint function and asserts and embedding them into
  the assert message
* implement input parameters for Checks
  - suppress_fingerprint_asserts: specifies if the assert messages that
    use the fingerprint function result are suppressed
  - stop_on_assert_error_raised: Specifies if running the asserts is
    stopped as soon as an error is raised in an assert
* format check results more structured
  • Loading branch information
DivyaSuman14 authored and agoscinski committed Dec 25, 2023
1 parent 6842420 commit 9687211
Show file tree
Hide file tree
Showing 8 changed files with 386 additions and 70 deletions.
48 changes: 44 additions & 4 deletions src/scwidgets/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,56 @@
import re

from termcolor import colored


class Printer:
# move to output
# TODO rename to Formatter
# remove print funcs
LINE_LENGTH = 120
INFO_COLOR = "blue"
ERROR_COLOR = "red"
SUCCESS_COLOR = "green"

@staticmethod
def format_title_message(message: str) -> str:
return message.center(Printer.LINE_LENGTH - len(message) // 2, "-")

@staticmethod
def break_lines(message: str) -> str:
return "\n ".join(re.findall(r".{1," + str(Printer.LINE_LENGTH) + "}", message))

@staticmethod
def color_error_message(message: str) -> str:
return colored(message, Printer.ERROR_COLOR, attrs=["bold"])

@staticmethod
def print_error_message(message: str):
print(colored(message, "red", attrs=["bold"]))
print(Printer.color_error_message(message))

@staticmethod
def color_success_message(message: str) -> str:
return colored(message, Printer.SUCCESS_COLOR, attrs=["bold"])

@staticmethod
def print_success_message(message: str):
print(colored(message, "green", attrs=["bold"]))
print(Printer.color_success_message(message))

@staticmethod
def color_info_message(message: str):
return colored(message, Printer.INFO_COLOR, attrs=["bold"])

@staticmethod
def print_info_message(message: str):
print(colored(message, "blue", attrs=["bold"]))
print(Printer.color_info_message(message))

@staticmethod
def color_assert_failed(message: str) -> str:
return colored(message, "light_" + Printer.ERROR_COLOR)

@staticmethod
def color_assert_info(message: str) -> str:
return colored(message, "light_" + Printer.INFO_COLOR)

@staticmethod
def color_assert_success(message: str) -> str:
return colored(message, "light_" + Printer.SUCCESS_COLOR)
3 changes: 2 additions & 1 deletion src/scwidgets/check/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
assert_shape,
assert_type,
)
from ._check import Check, ChecksLog
from ._check import AssertResult, Check, ChecksLog
from ._widget_check_registry import CheckableWidget, CheckRegistry

__all__ = [
"Check",
"ChecksLog",
"AssertResult",
"CheckRegistry",
"CheckableWidget",
"assert_shape",
Expand Down
96 changes: 75 additions & 21 deletions src/scwidgets/check/_asserts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import functools
from collections import abc
from typing import Iterable, TypeVar, Union
from typing import Iterable, Union

import numpy as np

from ._check import Check
from ._check import AssertResult, Check

AssertResultT = TypeVar("AssertResultT", bound="str")
AssertFunctionOutputT = Union[str, AssertResult]


def assert_shape(
output_parameters: Check.FunOutParamsT,
output_references: Check.FunOutParamsT,
parameters_to_check: Union[Iterable[int], str] = "auto",
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand All @@ -22,7 +22,7 @@ def assert_shape(
if isinstance(parameters_to_check, str):
if parameters_to_check == "auto":
parameter_indices = []
for i in range(len(output_parameters)):
for i in range(len(output_references)):
if hasattr(output_references[i], "shape"):
parameter_indices.append(i)
elif parameters_to_check == "all":
Expand All @@ -40,13 +40,25 @@ def assert_shape(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if output_parameters[i].shape != output_references[i].shape:
return (
f"For parameter {i} expected shape {output_references[i].shape} "
message = (
f"Expected shape {output_references[i].shape})"
f"but got {output_parameters[i].shape}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)

return AssertResult(
assert_name="assert_shape",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_numpy_allclose(
Expand All @@ -56,7 +68,7 @@ def assert_numpy_allclose(
rtol=1e-05,
atol=1e-08,
equal_nan=False,
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand Down Expand Up @@ -86,6 +98,9 @@ def assert_numpy_allclose(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
is_allclose = np.allclose(
output_parameters[i],
Expand All @@ -101,18 +116,28 @@ def assert_numpy_allclose(
)
abs_diff = np.sum(diff)
rel_diff = np.sum(diff / np.abs(output_references[i]))
return (
f"Output parameter {i} is not close to reference absolute difference "

message = (
f"Output is not close to reference absolute difference "
f"is {abs_diff}, relative difference is {rel_diff}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)

return AssertResult(
assert_name="assert_numpy_allclose",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_type(
output_parameters: Check.FunOutParamsT,
output_references: Check.FunOutParamsT,
parameters_to_check: Union[Iterable[int], str] = "all",
) -> str:
) -> AssertResult:
assert len(output_parameters) == len(
output_references
), "output_parameters and output_references have to have the same length"
Expand All @@ -134,20 +159,31 @@ def assert_type(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if not (isinstance(output_parameters[i], type(output_references[i]))):
return (
message = (
f"Expected type {type(output_references[i])} "
f"but got {type(output_parameters[i])}."
)
return ""
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
messages.append(message)
return AssertResult(
assert_name="assert_type",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


def assert_numpy_sub_dtype(
output_parameters: Union[Check.FunOutParamsT, tuple[Check.FingerprintT]],
numpy_type: Union[np.dtype, type],
parameters_to_check: Union[Iterable[int], str] = "all",
) -> str:
) -> AssertResult:
if parameters_to_check == "all":
parameter_indices = range(len(output_parameters))
elif isinstance(parameters_to_check, abc.Iterable):
Expand All @@ -158,23 +194,41 @@ def assert_numpy_sub_dtype(
f"but got type {type(parameters_to_check)}."
)

failed_parameter_indices = []
failed_parameter_values = []
messages = []
for i in parameter_indices:
if not (isinstance(output_parameters[i], np.ndarray)):
return (
f"Output parameter {i} expected to be numpy array "
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
message = (
f"Output expected to be numpy array "
f"but got {type(output_parameters[i])}."
)
messages.append(message)
if not (np.issubdtype(output_parameters[i].dtype, numpy_type)):
if isinstance(numpy_type, np.dtype):
type_name = numpy_type.type.__name__
else:
type_name = numpy_type.__name__
return (
f"Output parameter {i} expected to be sub dtype "
failed_parameter_indices.append(i)
failed_parameter_values.append(output_parameters[i])
message = (
f"Output expected to be sub dtype "
f"numpy.{type_name} but got "
f"numpy.{output_parameters[i].dtype.type.__name__}."
)
return ""
messages.append(message)
if isinstance(numpy_type, np.dtype):
type_name = numpy_type.type.__name__
else:
type_name = numpy_type.__name__
return AssertResult(
assert_name=f"assert_numpy_{type_name}_sub_dtype",
parameter_indices=failed_parameter_indices,
parameter_values=failed_parameter_values,
messages=messages,
)


assert_numpy_floating_sub_dtype = functools.partial(
Expand Down
Loading

0 comments on commit 9687211

Please sign in to comment.