From 4ac8b6121a069de6f520da9793663e28cd0f079a Mon Sep 17 00:00:00 2001 From: MalvinaNikandrou Date: Tue, 5 Dec 2023 02:34:34 +0000 Subject: [PATCH] Rename NLU to CR --- docker/docker-compose.yaml | 2 +- .../api/clients/simbot/__init__.py | 2 +- .../simbot/{nlu_intent.py => cr_intent.py} | 4 +- .../api/controllers/simbot/clients.py | 8 ++-- .../api/controllers/simbot/pipelines.py | 8 ++-- .../common/settings/simbot.py | 4 +- .../datamodels/simbot/__init__.py | 2 +- .../datamodels/simbot/enums/__init__.py | 2 +- .../datamodels/simbot/enums/intents.py | 2 +- .../instruction_handler.py | 38 +++++++++---------- .../parsers/simbot/__init__.py | 2 +- .../simbot/{nlu_output.py => cr_output.py} | 14 +++---- .../simbot/agent_intent_selection.py | 12 +++--- tests/fixtures/clients.py | 14 +++---- tests/parsers/simbot/test_nlu_parser.py | 28 +++++++------- 15 files changed, 71 insertions(+), 71 deletions(-) rename src/emma_experience_hub/api/clients/simbot/{nlu_intent.py => cr_intent.py} (88%) rename src/emma_experience_hub/parsers/simbot/{nlu_output.py => cr_output.py} (72%) diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 442381b7..5c286edb 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -47,7 +47,7 @@ services: <<: *healthcheck-defaults volumes: *model-volume entrypoint: python - command: src/emma_policy/commands/run_simbot_nlu.py + command: src/emma_policy/commands/run_simbot_cr.py instruction_predictor: container_name: instruction_predictor diff --git a/src/emma_experience_hub/api/clients/simbot/__init__.py b/src/emma_experience_hub/api/clients/simbot/__init__.py index a6770962..1cd05664 100644 --- a/src/emma_experience_hub/api/clients/simbot/__init__.py +++ b/src/emma_experience_hub/api/clients/simbot/__init__.py @@ -5,7 +5,7 @@ SimBotExtractedFeaturesClient, SimBotPydanticCacheClient, ) +from emma_experience_hub.api.clients.simbot.cr_intent import SimBotCRIntentClient from emma_experience_hub.api.clients.simbot.features import SimBotFeaturesClient -from emma_experience_hub.api.clients.simbot.nlu_intent import SimBotNLUIntentClient from emma_experience_hub.api.clients.simbot.placeholder_vision import SimBotPlaceholderVisionClient from emma_experience_hub.api.clients.simbot.session_db import SimBotSessionDbClient diff --git a/src/emma_experience_hub/api/clients/simbot/nlu_intent.py b/src/emma_experience_hub/api/clients/simbot/cr_intent.py similarity index 88% rename from src/emma_experience_hub/api/clients/simbot/nlu_intent.py rename to src/emma_experience_hub/api/clients/simbot/cr_intent.py index 5792e415..ef5b1eb6 100644 --- a/src/emma_experience_hub/api/clients/simbot/nlu_intent.py +++ b/src/emma_experience_hub/api/clients/simbot/cr_intent.py @@ -4,8 +4,8 @@ from emma_experience_hub.api.clients.emma_policy import EmmaPolicyClient -class SimBotNLUIntentClient(EmmaPolicyClient): - """API Client for SimBot NLU.""" +class SimBotCRIntentClient(EmmaPolicyClient): + """API Client for SimBot CR.""" def generate( self, diff --git a/src/emma_experience_hub/api/controllers/simbot/clients.py b/src/emma_experience_hub/api/controllers/simbot/clients.py index 655b00ac..d720b50f 100644 --- a/src/emma_experience_hub/api/controllers/simbot/clients.py +++ b/src/emma_experience_hub/api/controllers/simbot/clients.py @@ -11,9 +11,9 @@ from emma_experience_hub.api.clients.simbot import ( SimbotActionPredictionClient, SimBotAuxiliaryMetadataClient, + SimBotCRIntentClient, SimBotExtractedFeaturesClient, SimBotFeaturesClient, - SimBotNLUIntentClient, SimBotPlaceholderVisionClient, SimBotSessionDbClient, ) @@ -26,7 +26,7 @@ class SimBotControllerClients(BaseModel, arbitrary_types_allowed=True): _exit = Event() features: SimBotFeaturesClient - nlu_intent: SimBotNLUIntentClient + cr_intent: SimBotCRIntentClient action_predictor: SimbotActionPredictionClient session_db: SimBotSessionDbClient @@ -53,8 +53,8 @@ def from_simbot_settings(cls, simbot_settings: SimBotSettings) -> "SimBotControl session_db=SimBotSessionDbClient( db_file=Path(simbot_settings.session_local_db_file), ), - nlu_intent=SimBotNLUIntentClient( - endpoint=simbot_settings.nlu_predictor_url, + cr_intent=SimBotCRIntentClient( + endpoint=simbot_settings.cr_predictor_url, timeout=simbot_settings.client_timeout, ), action_predictor=SimbotActionPredictionClient( diff --git a/src/emma_experience_hub/api/controllers/simbot/pipelines.py b/src/emma_experience_hub/api/controllers/simbot/pipelines.py index 45c6f1e3..897cb8cd 100644 --- a/src/emma_experience_hub/api/controllers/simbot/pipelines.py +++ b/src/emma_experience_hub/api/controllers/simbot/pipelines.py @@ -4,7 +4,7 @@ from emma_experience_hub.common.settings import SimBotSettings from emma_experience_hub.parsers.simbot import ( SimBotActionPredictorOutputParser, - SimBotNLUOutputParser, + SimBotCROutputParser, SimBotPreviousActionParser, SimBotVisualGroundingOutputParser, ) @@ -54,9 +54,9 @@ def from_clients( environment_intent_extractor=SimBotEnvironmentIntentExtractionPipeline(), agent_intent_selector=SimBotAgentIntentSelectionPipeline( features_client=clients.features, - nlu_intent_client=clients.nlu_intent, - nlu_intent_parser=SimBotNLUOutputParser( - intent_type_delimiter=simbot_settings.nlu_predictor_intent_type_delimiter + cr_intent_client=clients.cr_intent, + cr_intent_parser=SimBotCROutputParser( + intent_type_delimiter=simbot_settings.cr_predictor_intent_type_delimiter ), environment_error_pipeline=SimBotEnvironmentErrorCatchingPipeline(), action_predictor_client=clients.action_predictor, diff --git a/src/emma_experience_hub/common/settings/simbot.py b/src/emma_experience_hub/common/settings/simbot.py index 17a47ab5..8e2b314b 100644 --- a/src/emma_experience_hub/common/settings/simbot.py +++ b/src/emma_experience_hub/common/settings/simbot.py @@ -66,8 +66,8 @@ class SimBotSettings(BaseSettings): feature_extractor_url: AnyHttpUrl = AnyHttpUrl(url=f"{scheme}://0.0.0.0:5500", scheme=scheme) - nlu_predictor_url: AnyHttpUrl = AnyHttpUrl(url=f"{scheme}://0.0.0.0:5501", scheme=scheme) - nlu_predictor_intent_type_delimiter: str = " " + cr_predictor_url: AnyHttpUrl = AnyHttpUrl(url=f"{scheme}://0.0.0.0:5501", scheme=scheme) + cr_predictor_intent_type_delimiter: str = " " action_predictor_url: AnyHttpUrl = AnyHttpUrl(url=f"{scheme}://0.0.0.0:5502", scheme=scheme) diff --git a/src/emma_experience_hub/datamodels/simbot/__init__.py b/src/emma_experience_hub/datamodels/simbot/__init__.py index 78d2dc05..e5242754 100644 --- a/src/emma_experience_hub/datamodels/simbot/__init__.py +++ b/src/emma_experience_hub/datamodels/simbot/__init__.py @@ -6,10 +6,10 @@ from emma_experience_hub.datamodels.simbot.enums import ( SimBotActionStatusType, SimBotActionType, + SimBotCRIntentType, SimBotDummyRawActions, SimBotEnvironmentIntentType, SimBotIntentType, - SimBotNLUIntentType, SimBotPhysicalInteractionIntentType, SimBotUserIntentType, SimBotVerbalInteractionIntentType, diff --git a/src/emma_experience_hub/datamodels/simbot/enums/__init__.py b/src/emma_experience_hub/datamodels/simbot/enums/__init__.py index 47bc3248..6c6c362f 100644 --- a/src/emma_experience_hub/datamodels/simbot/enums/__init__.py +++ b/src/emma_experience_hub/datamodels/simbot/enums/__init__.py @@ -4,9 +4,9 @@ SimBotDummyRawActions, ) from emma_experience_hub.datamodels.simbot.enums.intents import ( + SimBotCRIntentType, SimBotEnvironmentIntentType, SimBotIntentType, - SimBotNLUIntentType, SimBotPhysicalInteractionIntentType, SimBotUserIntentType, SimBotVerbalInteractionIntentType, diff --git a/src/emma_experience_hub/datamodels/simbot/enums/intents.py b/src/emma_experience_hub/datamodels/simbot/enums/intents.py index a1f7f727..fb232b17 100644 --- a/src/emma_experience_hub/datamodels/simbot/enums/intents.py +++ b/src/emma_experience_hub/datamodels/simbot/enums/intents.py @@ -165,7 +165,7 @@ def is_verbal_interaction_intent_type( # noqa: WPS602 SimBotIntentType.generic_success, ] -SimBotNLUIntentType = Literal[ +SimBotCRIntentType = Literal[ SimBotIntentType.act_no_match, SimBotIntentType.act_missing_inventory, SimBotIntentType.act_too_many_matches, diff --git a/src/emma_experience_hub/functions/simbot/agent_intent_selection/instruction_handler.py b/src/emma_experience_hub/functions/simbot/agent_intent_selection/instruction_handler.py index 6fc90629..e11e34a9 100644 --- a/src/emma_experience_hub/functions/simbot/agent_intent_selection/instruction_handler.py +++ b/src/emma_experience_hub/functions/simbot/agent_intent_selection/instruction_handler.py @@ -5,14 +5,14 @@ from emma_common.datamodels import EnvironmentStateTurn, SpeakerRole from emma_experience_hub.api.clients.simbot import ( SimbotActionPredictionClient, + SimBotCRIntentClient, SimBotFeaturesClient, - SimBotNLUIntentClient, ) from emma_experience_hub.datamodels.simbot import ( SimBotAgentIntents, + SimBotCRIntentType, SimBotIntent, SimBotIntentType, - SimBotNLUIntentType, SimBotSession, SimBotUserSpeech, SimBotUtterance, @@ -27,8 +27,8 @@ class SimBotActHandler: def __init__( self, features_client: SimBotFeaturesClient, - nlu_intent_client: SimBotNLUIntentClient, - nlu_intent_parser: NeuralParser[SimBotIntent[SimBotNLUIntentType]], + cr_intent_client: SimBotCRIntentClient, + cr_intent_parser: NeuralParser[SimBotIntent[SimBotCRIntentType]], action_predictor_client: SimbotActionPredictionClient, _enable_clarification_questions: bool = True, _enable_search_actions: bool = True, @@ -38,8 +38,8 @@ def __init__( ) -> None: self._features_client = features_client - self._nlu_intent_client = nlu_intent_client - self._nlu_intent_parser = nlu_intent_parser + self._cr_intent_client = cr_intent_client + self._cr_intent_parser = cr_intent_parser self._action_predictor_client = action_predictor_client self._enable_clarification_questions = _enable_clarification_questions @@ -50,15 +50,15 @@ def __init__( def run(self, session: SimBotSession) -> Optional[SimBotAgentIntents]: """Get the agent intent.""" - # Check if the utterance has already been processed by the NLU - if self._utterance_has_been_processed_by_nlu(session): + # Check if the utterance has already been processed by the CR + if self._utterance_has_been_processed_by_cr(session): logger.debug("Executing utterance that triggered the search.") return SimBotAgentIntents( physical_interaction=SimBotIntent(type=SimBotIntentType.act_one_match) ) - # Otherwise, use the NLU to detect it - intents = self._process_utterance_with_nlu(session) + # Otherwise, use the CR to detect it + intents = self._process_utterance_with_cr(session) if self._should_search_target_object(session, intents): intents = self._handle_act_no_match_intent(session=session, intents=intents) elif self._should_search_missing_inventory(session, intents): @@ -66,8 +66,8 @@ def run(self, session: SimBotSession) -> Optional[SimBotAgentIntents]: return self._handle_search_holding_object(session=session, intents=intents) - def _utterance_has_been_processed_by_nlu(self, session: SimBotSession) -> bool: - """Determine if the utterance has already been processed by the NLU. + def _utterance_has_been_processed_by_cr(self, session: SimBotSession) -> bool: + """Determine if the utterance has already been processed by the CR. This happens when the intent was act no_match and search was completed. """ @@ -78,15 +78,15 @@ def _utterance_has_been_processed_by_nlu(self, session: SimBotSession) -> bool: and session.previous_turn.actions.is_successful ) - def _process_utterance_with_nlu(self, session: SimBotSession) -> SimBotAgentIntents: - """Perform NLU on the utterance to determine what the agent should do next. + def _process_utterance_with_cr(self, session: SimBotSession) -> SimBotAgentIntents: + """Perform CR on the utterance to determine what the agent should do next. This is primarily used to determine whether the agent should act or ask for more information. """ extracted_features = self._features_client.get_features(session.current_turn) - intent = self._nlu_intent_parser( - self._nlu_intent_client.generate( + intent = self._cr_intent_parser( + self._cr_intent_client.generate( dialogue_history=session.current_turn.utterances, environment_state_history=[EnvironmentStateTurn(features=extracted_features)], inventory_entity=session.current_state.inventory.entity, @@ -141,13 +141,13 @@ def _process_utterance_with_nlu(self, session: SimBotSession) -> SimBotAgentInte ) raise NotImplementedError( - "All NLU intents are not accounted for. This means that NLU has returned an intent which does not map to either an interaction intent, or a response intent." + "All CR intents are not accounted for. This means that CR has returned an intent which does not map to either an interaction intent, or a response intent." ) def _handle_act_no_match_intent( self, session: SimBotSession, intents: SimBotAgentIntents ) -> SimBotAgentIntents: - """Update the session based on the NLU output. + """Update the session based on the CR output. For `act_no_match`, push the current utterance in the utterance queue, and set the verbal interaction intent to `confirm_before_search`. @@ -301,7 +301,7 @@ def _should_search_missing_inventory( return should_search_before_executing_instruction def _should_replace_missing_inventory_intent( - self, intent: SimBotIntent[SimBotNLUIntentType] + self, intent: SimBotIntent[SimBotCRIntentType] ) -> bool: """Is the missing inventory intent disabled?""" return ( diff --git a/src/emma_experience_hub/parsers/simbot/__init__.py b/src/emma_experience_hub/parsers/simbot/__init__.py index e83e30fb..8237ebbe 100644 --- a/src/emma_experience_hub/parsers/simbot/__init__.py +++ b/src/emma_experience_hub/parsers/simbot/__init__.py @@ -1,10 +1,10 @@ from emma_experience_hub.parsers.simbot.action_predictor_output import ( SimBotActionPredictorOutputParser, ) +from emma_experience_hub.parsers.simbot.cr_output import SimBotCROutputParser from emma_experience_hub.parsers.simbot.intent_from_action_status import ( SimBotIntentFromActionStatusParser, ) -from emma_experience_hub.parsers.simbot.nlu_output import SimBotNLUOutputParser from emma_experience_hub.parsers.simbot.previous_action import SimBotPreviousActionParser from emma_experience_hub.parsers.simbot.visual_grounding_output import ( SimBotVisualGroundingOutputParser, diff --git a/src/emma_experience_hub/parsers/simbot/nlu_output.py b/src/emma_experience_hub/parsers/simbot/cr_output.py similarity index 72% rename from src/emma_experience_hub/parsers/simbot/nlu_output.py rename to src/emma_experience_hub/parsers/simbot/cr_output.py index 580cf4dc..0be5a8c7 100644 --- a/src/emma_experience_hub/parsers/simbot/nlu_output.py +++ b/src/emma_experience_hub/parsers/simbot/cr_output.py @@ -3,21 +3,21 @@ from loguru import logger from emma_experience_hub.datamodels.simbot import ( + SimBotCRIntentType, SimBotIntent, SimBotIntentType, - SimBotNLUIntentType, ) from emma_experience_hub.parsers.parser import NeuralParser -class SimBotNLUOutputParser(NeuralParser[SimBotIntent[SimBotNLUIntentType]]): - """Convert the output from the SimBot NLU module to a SimBot intent.""" +class SimBotCROutputParser(NeuralParser[SimBotIntent[SimBotCRIntentType]]): + """Convert the output from the SimBot CR module to a SimBot intent.""" def __init__(self, intent_type_delimiter: str) -> None: self._intent_type_delimiter = intent_type_delimiter - def __call__(self, output_text: str) -> SimBotIntent[SimBotNLUIntentType]: - """Parses the intent generated by the NLU component. + def __call__(self, output_text: str) -> SimBotIntent[SimBotCRIntentType]: + """Parses the intent generated by the CR component. The model is trained with the following templates: - @@ -26,7 +26,7 @@ def __call__(self, output_text: str) -> SimBotIntent[SimBotNLUIntentType]: - object_name - """ - logger.debug(f"NLU output text: `{output_text}`") + logger.debug(f"CR output text: `{output_text}`") # Split the raw output text by the given delimiter. We assume it's a " " separating the # special tokens and the object_name. @@ -34,7 +34,7 @@ def __call__(self, output_text: str) -> SimBotIntent[SimBotNLUIntentType]: # Get the intent type from the left-side of the template. intent_type = SimBotIntentType(split_parts[0]) - intent_type = cast(SimBotNLUIntentType, intent_type) + intent_type = cast(SimBotCRIntentType, intent_type) # If it exists, get the object name from the right-side of the template object_name = " ".join(split_parts[1:]) if len(split_parts) > 1 else None diff --git a/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py b/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py index f0b68059..fe3b1eb8 100644 --- a/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py +++ b/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py @@ -4,14 +4,14 @@ from emma_experience_hub.api.clients.simbot import ( SimbotActionPredictionClient, + SimBotCRIntentClient, SimBotFeaturesClient, - SimBotNLUIntentClient, ) from emma_experience_hub.datamodels.simbot import ( SimBotAgentIntents, + SimBotCRIntentType, SimBotIntent, SimBotIntentType, - SimBotNLUIntentType, SimBotSession, SimBotUserIntentType, ) @@ -58,8 +58,8 @@ class SimBotAgentIntentSelectionPipeline: def __init__( self, features_client: SimBotFeaturesClient, - nlu_intent_client: SimBotNLUIntentClient, - nlu_intent_parser: NeuralParser[SimBotIntent[SimBotNLUIntentType]], + cr_intent_client: SimBotCRIntentClient, + cr_intent_parser: NeuralParser[SimBotIntent[SimBotCRIntentType]], action_predictor_client: SimbotActionPredictionClient, environment_error_pipeline: SimBotEnvironmentErrorCatchingPipeline, _enable_clarification_questions: bool = True, @@ -72,8 +72,8 @@ def __init__( self._features_client = features_client self.act_handler = SimBotActHandler( features_client=features_client, - nlu_intent_client=nlu_intent_client, - nlu_intent_parser=nlu_intent_parser, + cr_intent_client=cr_intent_client, + cr_intent_parser=cr_intent_parser, action_predictor_client=action_predictor_client, _enable_clarification_questions=_enable_clarification_questions, _enable_search_actions=_enable_search_actions, diff --git a/tests/fixtures/clients.py b/tests/fixtures/clients.py index 772f90ff..61be4b31 100644 --- a/tests/fixtures/clients.py +++ b/tests/fixtures/clients.py @@ -4,8 +4,8 @@ from emma_experience_hub.api.clients.simbot import ( SimbotActionPredictionClient, + SimBotCRIntentClient, SimBotFeaturesClient, - SimBotNLUIntentClient, ) from emma_experience_hub.datamodels import EmmaExtractedFeatures from tests.fixtures.simbot_arena_constants import create_placeholder_features_frames @@ -26,13 +26,13 @@ def mock_features(*args: Any, **kwargs: Any) -> list[EmmaExtractedFeatures]: # def mock_policy_response_goto_room(monkeypatch: MonkeyPatch) -> None: """Mock the responses of EMMA policy.""" - def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + def get_cr(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "" def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "goto breakroom." - monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimBotCRIntentClient, "generate", get_cr) monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action) @@ -41,13 +41,13 @@ def mock_policy_response_toggle_computer(monkeypatch: MonkeyPatch) -> None: """Mock the responses of EMMA policy when the input instruction is about turning on the computer.""" - def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + def get_cr(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "" def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "toggle computer ." - monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimBotCRIntentClient, "generate", get_cr) monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action) @@ -56,7 +56,7 @@ def mock_policy_response_search(monkeypatch: MonkeyPatch) -> None: """Mock the responses of EMMA policy when the input instruction is about searching an object.""" - def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + def get_cr(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "" def get_object(*args: Any, **kwargs: Any) -> list[str]: # noqa: WPS430 @@ -65,6 +65,6 @@ def get_object(*args: Any, **kwargs: Any) -> list[str]: # noqa: WPS430 def get_target(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 return "goto object ." - monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimBotCRIntentClient, "generate", get_cr) monkeypatch.setattr(SimbotActionPredictionClient, "find_object_in_scene", get_object) monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_target) diff --git a/tests/parsers/simbot/test_nlu_parser.py b/tests/parsers/simbot/test_nlu_parser.py index fe3cf189..ad2e9e15 100644 --- a/tests/parsers/simbot/test_nlu_parser.py +++ b/tests/parsers/simbot/test_nlu_parser.py @@ -2,11 +2,11 @@ from emma_experience_hub.datamodels.simbot import ( SimBotAction, + SimBotCRIntentType, SimBotIntent, SimBotIntentType, - SimBotNLUIntentType, ) -from emma_experience_hub.parsers.simbot import SimBotNLUOutputParser +from emma_experience_hub.parsers.simbot import SimBotCROutputParser @fixture(scope="session") @@ -19,27 +19,27 @@ def intent_type_delimiter() -> str: ) -class DecodedNLUOutputs: +class DecodedCROutputs: """Various cases to ensure the various intents are parsed correctly.""" _entity: str = "mug" - def case_act(self) -> tuple[str, SimBotIntent[SimBotNLUIntentType]]: - return "", SimBotIntent[SimBotNLUIntentType]( + def case_act(self) -> tuple[str, SimBotIntent[SimBotCRIntentType]]: + return "", SimBotIntent[SimBotCRIntentType]( type=SimBotIntentType.act_one_match ) - def case_search(self) -> tuple[str, SimBotIntent[SimBotNLUIntentType]]: - return "", SimBotIntent[SimBotNLUIntentType](type=SimBotIntentType.search) + def case_search(self) -> tuple[str, SimBotIntent[SimBotCRIntentType]]: + return "", SimBotIntent[SimBotCRIntentType](type=SimBotIntentType.search) def case_act_too_many_matches( self, should_include_entity: bool - ) -> tuple[str, SimBotIntent[SimBotNLUIntentType]]: + ) -> tuple[str, SimBotIntent[SimBotCRIntentType]]: output = "" if should_include_entity: output = f"{output} {self._entity}" - intent = SimBotIntent[SimBotNLUIntentType]( + intent = SimBotIntent[SimBotCRIntentType]( type=SimBotIntentType.act_too_many_matches, entity=self._entity if should_include_entity else None, ) @@ -48,12 +48,12 @@ def case_act_too_many_matches( def case_act_no_match( self, should_include_entity: bool - ) -> tuple[str, SimBotIntent[SimBotNLUIntentType]]: + ) -> tuple[str, SimBotIntent[SimBotCRIntentType]]: output = "" if should_include_entity: output = f"{output} {self._entity}" - intent = SimBotIntent[SimBotNLUIntentType]( + intent = SimBotIntent[SimBotCRIntentType]( type=SimBotIntentType.act_no_match, entity=self._entity if should_include_entity else None, ) @@ -61,11 +61,11 @@ def case_act_no_match( return output, intent -@parametrize_with_cases("decoded_actions,expected_output", cases=DecodedNLUOutputs) -def test_parser_decodes_nlu_output( +@parametrize_with_cases("decoded_actions,expected_output", cases=DecodedCROutputs) +def test_parser_decodes_cr_output( decoded_actions: str, expected_output: SimBotAction, intent_type_delimiter: str ) -> None: - trajectory_parser = SimBotNLUOutputParser(intent_type_delimiter=intent_type_delimiter) + trajectory_parser = SimBotCROutputParser(intent_type_delimiter=intent_type_delimiter) parsed_trajectory = trajectory_parser(decoded_actions) assert parsed_trajectory == expected_output