diff --git a/updown/data/datasets.py b/updown/data/datasets.py index 0b92014..073c99b 100644 --- a/updown/data/datasets.py +++ b/updown/data/datasets.py @@ -222,6 +222,7 @@ def __init__( hierarchy_jsonpath: str, nms_threshold: float = 0.85, max_given_constraints: int = 3, + max_words_per_constraint: int = 3, in_memory: bool = True, ): super().__init__(image_features_h5path, in_memory=in_memory) @@ -234,7 +235,7 @@ def __init__( self._constraint_filter = ConstraintFilter( hierarchy_jsonpath, nms_threshold, max_given_constraints ) - self._fsm_builder = FiniteStateMachineBuilder(vocabulary, wordforms_tsvpath) + self._fsm_builder = FiniteStateMachineBuilder(vocabulary, wordforms_tsvpath, max_given_constraints, max_words_per_constraint) @classmethod def from_config(cls, config: Config, **kwargs): @@ -247,6 +248,8 @@ def from_config(cls, config: Config, **kwargs): boxes_jsonpath=_C.DATA.CBS.INFER_BOXES, wordforms_tsvpath=_C.DATA.CBS.WORDFORMS, hierarchy_jsonpath=_C.DATA.CBS.CLASS_HIERARCHY, + max_given_constraints=_C.DATA.CBS.MAX_GIVEN_CONSTRAINTS, + max_words_per_constraint=_C.DATA.CBS.MAX_WORDS_PER_CONSTRAINT, in_memory=kwargs.pop("in_memory"), )