Skip to content

Commit

Permalink
Merge pull request #652 from liyier90/feat-threshold-checker-interval…
Browse files Browse the repository at this point in the history
…-string

Feat: threshold checker interval string
  • Loading branch information
ongtw authored May 26, 2022
2 parents ac10c34 + 24dbbc9 commit 8cf1be8
Show file tree
Hide file tree
Showing 28 changed files with 317 additions and 398 deletions.
1 change: 1 addition & 0 deletions lint_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ click == 7.1.2
colorama == 0.4.4
numpy == 1.17.3
opencv-contrib-python >= 4.5.2.54
protobuf <= 3.20.1
pyyaml >= 5.3
requests == 2.24.0
tensorflow == 2.2.0
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/augment/brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
super().__init__(config, node_path=__name__, **kwargs)

self.check_bounds("beta", (-100, 100), "within")
self.check_bounds("beta", "[-100, 100]")

def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Adjusts the brightness of an image frame.
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/augment/contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Node(ThresholdCheckerMixin, AbstractNode):
def __init__(self, config: Dict[str, Any] = None, **kwargs: Any) -> None:
super().__init__(config, node_path=__name__, **kwargs)

self.check_bounds("alpha", (0, 3), "within")
self.check_bounds("alpha", "[0, 3]")

def run(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Adjusts the contrast of an image frame.
Expand Down
237 changes: 76 additions & 161 deletions peekingduck/pipeline/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,136 +17,93 @@
import hashlib
import operator
import os
import re
import sys
import zipfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union

import requests
from tqdm import tqdm

BASE_URL = "https://storage.googleapis.com/peekingduck/models"
PEEKINGDUCK_WEIGHTS_SUBDIR = "peekingduck_weights"

Number = Union[float, int]


class ThresholdCheckerMixin:
"""Mixin class providing utility methods for checking validity of config
values, typically thresholds.
"""

def check_bounds(
self,
key: Union[str, List[str]],
value: Union[Number, Tuple[Number, Number]],
method: str,
include: Optional[str] = "both",
) -> None:
"""Checks if the configuration value(s) specified by `key` satisties
interval_pattern = re.compile(
r"^[\[\(]\s*[-+]?(inf|\d*\.?\d+)\s*,\s*[-+]?(inf|\d*\.?\d+)\s*[\]\)]$"
)

def check_bounds(self, key: Union[str, List[str]], interval: str) -> None:
"""Checks if the configuration value(s) specified by `key` satisfies
the specified bounds.
Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Union[Number, Tuple[Number, Number]]): Either a single
number to specify the upper or lower bound or a tuple of
numbers to specify both the upper and lower bounds.
method (str): The bounds checking methods, one of
{"above", "below", "both"}. If "above", checks if the
configuration value is above the specified `value`. If "below",
checks if the configuration value is below the specified
`value`. If "both", checks if the configuration value is above
`value[0]` and below `value[1]`.
include (Optional[str]): Indicates if the `value` itself should be
included in the bound, one of {"lower", "upper", "both", None}.
Please see Technotes for details.
interval (str): An mathematical interval representing the range of
valid values. The syntax of the `interval` string is:
<value> = <number> | "-inf" | "+inf"
<left_bracket> = "(" | "["
<right_bracket> = ")" | "]"
<interval> = <left_bracket> <value> "," <value> <right_bracket>
See Technotes for more details.
Raises:
TypeError: `key` type is not in (List[str], str).
TypeError: If `value` is not a tuple of only float/int.
TypeError: If `value` is not a tuple with 2 elements.
TypeError: If `value` is not a float, int, or tuple.
TypeError: If `value` type is not a tuple when `method` is
"within".
TypeError: If `value` type is a tuple when `method` is
"above"/"below".
ValueError: If `method` is not one of {"above", "below", "within"}.
ValueError: If `interval` does not match the specified format.
ValueError: If the lower bound is larger than the upper bound.
ValueError: If the configuration value fails the bounds comparison.
Technotes:
The behavior of `include` depends on the specified `method`. The
table below shows the comparison done for various argument
combinations.
+-----------+---------+-------------------------------------+
| method | include | comparison |
+===========+=========+=====================================+
| | "lower" | config[key] >= value |
+ +---------+-------------------------------------+
| | "upper" | config[key] > value |
+ +---------+-------------------------------------+
| | "both" | config[key] >= value |
+ +---------+-------------------------------------+
| "above" | None | config[key] > value |
+-----------+---------+-------------------------------------+
| | "lower" | config[key] < value |
+ +---------+-------------------------------------+
| | "upper" | config[key] <= value |
+ +---------+-------------------------------------+
| | "both" | config[key] <= value |
+ +---------+-------------------------------------+
| "below" | None | config[key] < value |
+-----------+---------+-------------------------------------+
| | "lower" | value[0] <= config[key] < value[1] |
+ +---------+-------------------------------------+
| | "upper" | value[0] < config[key] <= value[1] |
+ +---------+-------------------------------------+
| | "both" | value[0] <= config[key] <= value[1] |
+ +---------+-------------------------------------+
| "within" | None | value[0] < config[key] < value[1] |
+-----------+---------+-------------------------------------+
The table below shows the comparison done for various interval
expressions.
+---------------------+-------------------------------------+
| interval | comparison |
+=====================+=====================================+
| [lower, +inf] | |
+---------------------+ |
| [lower, +inf) | config[key] >= lower |
+---------------------+-------------------------------------+
| (lower, +inf] | |
+---------------------+ |
| (lower, +inf) | config[key] > lower |
+---------------------+-------------------------------------+
| [-inf, upper] | |
+---------------------+ |
| (-inf, upper] | config[key] <= upper |
+---------------------+-------------------------------------+
| [-inf, upper) | |
+---------------------+ |
| (-inf, upper) | config[key] < upper |
+---------------------+-------------------------------------+
| [lower, upper] | lower <= config[key] <= upper |
+---------------------+-------------------------------------+
| (lower, upper] | lower < config[key] <= upper |
+---------------------+-------------------------------------+
| [lower, upper) | lower <= config[key] < upper |
+---------------------+-------------------------------------+
| (lower, upper) | lower < config[key] < upper |
+---------------------+-------------------------------------+
"""
# available checking methods
methods = {"above", "below", "within"}
# available options of lower/upper bound inclusion
lower_includes = {"lower", "both"}
upper_includes = {"upper", "both"}

if method not in methods:
raise ValueError(f"`method` must be one of {methods}")

if isinstance(value, tuple):
if not all(isinstance(val, (float, int)) for val in value):
raise TypeError(
"When using tuple for `value`, it must be a tuple of float/int"
)
if len(value) != 2:
raise ValueError(
"When using tuple for `value`, it must contain only 2 elements"
)
elif isinstance(value, (float, int)):
pass
else:
raise TypeError(
"`value` must be a float/int or tuple, but you passed a "
f"{type(value).__name__}"
)

if method == "within":
if not isinstance(value, tuple):
raise TypeError("`value` must be a tuple when `method` is 'within'")
self._check_within_bounds(
key, value, (include in lower_includes, include in upper_includes)
)
else:
if isinstance(value, tuple):
raise TypeError(
"`value` must be a float/int when `method` is 'above'/'below'"
)
if method == "above":
self._check_above_value(key, value, include in lower_includes)
elif method == "below":
self._check_below_value(key, value, include in upper_includes)
if self.interval_pattern.match(interval) is None:
raise ValueError("Badly formatted interval")

left_bracket = interval[0]
right_bracket = interval[-1]
lower, upper = [float(value.strip()) for value in interval[1:-1].split(",")]

if lower > upper:
raise ValueError("Lower bound cannot be larger than upper bound")

self._check_within_bounds(key, lower, upper, left_bracket, right_bracket)

def check_valid_choice(
self, key: str, choices: Set[Union[int, float, str]]
Expand All @@ -167,78 +124,36 @@ def check_valid_choice(
if self.config[key] not in choices:
raise ValueError(f"{key} must be one of {choices}")

def _check_above_value(
self, key: Union[str, List[str]], value: Number, inclusive: bool
) -> None:
"""Checks that configuration values specified by `key` is more than
(or equal to) the specified `value`.
Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Number): The specified value.
inclusive (bool): If `True`, compares `config[key] >= value`. If
`False`, compares `config[key] > value`.
Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is less than (or equal to)
`value`.
"""
method = operator.ge if inclusive else operator.gt
extra_reason = " or equal to" if inclusive else ""
self._compare(key, value, method, reason=f"more than{extra_reason} {value}")

def _check_below_value(
self, key: Union[str, List[str]], value: Number, inclusive: bool
) -> None:
"""Checks that configuration values specified by `key` is more than
(or equal to) the specified `value`.
Args:
key (Union[str, List[str]]): The specified key or list of keys.
value (Number): The specified value.
inclusive (bool): If `True`, compares `config[key] <= value`. If
`False`, compares `config[key] < value`.
Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is less than (or equal to)
`value`.
"""
method = operator.le if inclusive else operator.lt
extra_reason = " or equal to" if inclusive else ""
self._compare(key, value, method, reason=f"less than{extra_reason} {value}")

def _check_within_bounds(
def _check_within_bounds( # pylint: disable=too-many-arguments
self,
key: Union[str, List[str]],
bounds: Tuple[Number, Number],
includes: Tuple[bool, bool],
lower: float,
upper: float,
left_bracket: str,
right_bracket: str,
) -> None:
"""Checks that configuration values specified by `key` is within the
specified bounds between `lower` and `upper`.
Args:
key (Union[str, List[str]]): The specified key or list of keys.
(Union[float, int]): The lower bound.
bounds (Tuple[Number, Number]): The lower and upper bounds.
includes (Tuple[bool, bool]): If `True`, compares `config[key] >= value`.
If `False`, compares `config[key] > value`.
inclusive_upper (bool): If `True`, compares `config[key] <= value`.
If `False`, compares `config[key] < value`.
lower (float): The lower bound.
upper (float): The upper bound.
left_bracket (str): Either a "(" for an open lower bound or a "["
for a closed lower bound.
right_bracket (str): Either a ")" for an open upper bound or a "]"
for a closed upper bound.
Raises:
TypeError: `key` type is not in (List[str], str).
ValueError: If the configuration value is not between `lower` and
`upper`.
"""
method_lower = operator.ge if includes[0] else operator.gt
method_upper = operator.le if includes[1] else operator.lt
reason_lower = "[" if includes[0] else "("
reason_upper = "]" if includes[1] else ")"
reason = f"between {reason_lower}{bounds[0]}, {bounds[1]}{reason_upper}"
self._compare(key, bounds[0], method_lower, reason)
self._compare(key, bounds[1], method_upper, reason)
method_lower = operator.ge if left_bracket == "[" else operator.gt
method_upper = operator.le if right_bracket == "]" else operator.lt
reason = f"between {left_bracket}{lower}, {upper}{right_bracket}"
self._compare(key, lower, method_lower, reason)
self._compare(key, upper, method_upper, reason)

def _compare(
self,
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/csrnetv1/csrnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("width", 0, "above", include=None)
self.check_bounds("width", "(0, +inf]")

model_dir = self.download_weights()
self.predictor = Predictor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.logger = logging.getLogger(__name__)

self.check_valid_choice("model_type", {0, 1, 2, 3, 4})
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
classes_path = model_dir / self.weights["classes_file"]
Expand Down
6 changes: 2 additions & 4 deletions peekingduck/pipeline/nodes/model/fairmotv1/fairmot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds(
["K", "min_box_area", "track_buffer"], 0, "above", include=None
)
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds(["K", "min_box_area", "track_buffer"], "(0, +inf]")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.tracker = Tracker(
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/hrnetv1/hrnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.detector = Detector(
Expand Down
2 changes: 1 addition & 1 deletion peekingduck/pipeline/nodes/model/jdev1/jde_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, config: Dict[str, Any], frame_rate: float) -> None:
self.logger = logging.getLogger(__name__)

self.check_bounds(
["iou_threshold", "nms_threshold", "score_threshold"], (0, 1), "within"
["iou_threshold", "nms_threshold", "score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
{"singlepose_lightning", "singlepose_thunder", "multipose_lightning"},
)
self.check_bounds(
["bbox_score_threshold", "keypoint_score_threshold"], (0, 1), "within"
["bbox_score_threshold", "keypoint_score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
4 changes: 2 additions & 2 deletions peekingduck/pipeline/nodes/model/mtcnnv1/mtcnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.logger = logging.getLogger(__name__)

self.check_bounds("min_size", 0, "above", include=None)
self.check_bounds("min_size", "(0, +inf]")
self.check_bounds(
["network_thresholds", "scale_factor", "score_threshold"], (0, 1), "within"
["network_thresholds", "scale_factor", "score_threshold"], "[0, 1]"
)

model_dir = self.download_weights()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config: Dict[str, Any]) -> None:
self.logger = logging.getLogger(__name__)

self.check_valid_choice("model_type", {50, 75, 100, "resnet"})
self.check_bounds("score_threshold", (0, 1), "within")
self.check_bounds("score_threshold", "[0, 1]")

model_dir = self.download_weights()
self.predictor = Predictor(
Expand Down
Loading

0 comments on commit 8cf1be8

Please sign in to comment.