diff --git a/caribou/common/models/remote_client/aws_remote_client.py b/caribou/common/models/remote_client/aws_remote_client.py index a5dd0810..1e24ed4e 100644 --- a/caribou/common/models/remote_client/aws_remote_client.py +++ b/caribou/common/models/remote_client/aws_remote_client.py @@ -658,9 +658,7 @@ def set_value_in_table_column( UpdateExpression=update_expression, ) - def get_value_from_table( - self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False - ) -> tuple[str, float]: + def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]: client = self._client("dynamodb") response = client.get_item( TableName=table_name, @@ -677,7 +675,8 @@ def get_value_from_table( item = response.get("Item") if item is not None and "value" in item: - if convert_from_bytes: + # Detect if the value is compressed (in bytes) and decompress it + if "B" in item["value"]: return decompress_json_str(item["value"]["B"]), consumed_read_capacity return item["value"]["S"], consumed_read_capacity @@ -688,15 +687,18 @@ def remove_value_from_table(self, table_name: str, key: str) -> None: client = self._client("dynamodb") client.delete_item(TableName=table_name, Key={"key": {"S": key}}) - def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict[str, Any]: + def get_all_values_from_table(self, table_name: str) -> dict[str, Any]: client = self._client("dynamodb") response = client.scan(TableName=table_name) if "Items" not in response: return {} items = response.get("Items") if items is not None: - if convert_from_bytes: - return {item["key"]["S"]: decompress_json_str(item["value"]["B"]) for item in items} + for item in items: + # Detect if the value is compressed (in bytes) and decompress it + if "value" in item and "B" in item["value"]: + item["value"]["S"] = decompress_json_str(item["value"]["B"]) + del item["value"]["B"] return {item["key"]["S"]: item["value"]["S"] for item in items} diff --git a/caribou/common/models/remote_client/integration_test_remote_client.py b/caribou/common/models/remote_client/integration_test_remote_client.py index e898dee5..32b1a275 100644 --- a/caribou/common/models/remote_client/integration_test_remote_client.py +++ b/caribou/common/models/remote_client/integration_test_remote_client.py @@ -215,9 +215,7 @@ def set_value_in_table(self, table_name: str, key: str, value: str, convert_to_b conn.commit() conn.close() - def get_value_from_table( - self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False - ) -> tuple[str, float]: + def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]: conn = self._db_connection() cursor = conn.cursor() cursor.execute(f"SELECT value FROM {table_name} WHERE key=?", (key,)) @@ -267,7 +265,7 @@ def set_predecessor_reached( conn.close() return [bool(res) for res in result], 0.0, 0.0 - def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict: + def get_all_values_from_table(self, table_name: str) -> dict: conn = self._db_connection() cursor = conn.cursor() cursor.execute(f"SELECT key, value FROM {table_name}") diff --git a/caribou/common/models/remote_client/mock_remote_client.py b/caribou/common/models/remote_client/mock_remote_client.py index 8d9a4cf2..5c9e4baa 100644 --- a/caribou/common/models/remote_client/mock_remote_client.py +++ b/caribou/common/models/remote_client/mock_remote_client.py @@ -43,7 +43,7 @@ def upload_predecessor_data_at_sync_node(self, function_name, workflow_instance_ def set_value_in_table(self, table_name, key, value, convert_to_bytes: bool = False): pass - def get_value_from_table(self, table_name, key, consistent_read: bool = True, convert_from_bytes: bool = False): + def get_value_from_table(self, table_name, key, consistent_read: bool = True): pass def upload_resource(self, key, resource): @@ -74,7 +74,7 @@ def set_predecessor_reached( ) -> list[bool]: pass - def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict: + def get_all_values_from_table(self, table_name: str) -> dict: pass def set_value_in_table_column( diff --git a/caribou/common/models/remote_client/remote_client.py b/caribou/common/models/remote_client/remote_client.py index 1e63f2ee..f47d5fde 100644 --- a/caribou/common/models/remote_client/remote_client.py +++ b/caribou/common/models/remote_client/remote_client.py @@ -160,9 +160,7 @@ def set_value_in_table_column( raise NotImplementedError() @abstractmethod - def get_value_from_table( - self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False - ) -> tuple[str, float]: + def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]: raise NotImplementedError() @abstractmethod @@ -170,7 +168,7 @@ def remove_value_from_table(self, table_name: str, key: str) -> None: raise NotImplementedError() @abstractmethod - def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict[str, Any]: + def get_all_values_from_table(self, table_name: str) -> dict[str, Any]: raise NotImplementedError() @abstractmethod diff --git a/caribou/data_collector/components/workflow/workflow_retriever.py b/caribou/data_collector/components/workflow/workflow_retriever.py index c4dbbe1a..00224900 100644 --- a/caribou/data_collector/components/workflow/workflow_retriever.py +++ b/caribou/data_collector/components/workflow/workflow_retriever.py @@ -21,9 +21,7 @@ def retrieve_all_workflow_ids(self) -> set[str]: def retrieve_workflow_summary(self, workflow_unique_id: str) -> dict[str, Any]: # Load the summarized logs from the workflow summary table - workflow_summarized, _ = self._client.get_value_from_table( - self._workflow_summary_table, workflow_unique_id, convert_from_bytes=True - ) + workflow_summarized, _ = self._client.get_value_from_table(self._workflow_summary_table, workflow_unique_id) # Consolidate all the timestamps together to one summary and return the result return self._transform_workflow_summary(workflow_summarized) diff --git a/caribou/deployment_solver/deployment_input/components/loader.py b/caribou/deployment_solver/deployment_input/components/loader.py index 447b49d7..6db4a49b 100644 --- a/caribou/deployment_solver/deployment_input/components/loader.py +++ b/caribou/deployment_solver/deployment_input/components/loader.py @@ -18,8 +18,8 @@ def _retrieve_region_data(self, available_regions: set[str]) -> dict[str, Any]: return all_data - def _retrieve_data(self, table_name: str, data_key: str, convert_from_bytes: bool = False) -> dict[str, Any]: - value, _ = self._client.get_value_from_table(table_name, data_key, convert_from_bytes=convert_from_bytes) + def _retrieve_data(self, table_name: str, data_key: str) -> dict[str, Any]: + value, _ = self._client.get_value_from_table(table_name, data_key) loaded_data: dict[str, Any] = {} if value is not None and value != "": diff --git a/caribou/deployment_solver/deployment_input/components/loaders/workflow_loader.py b/caribou/deployment_solver/deployment_input/components/loaders/workflow_loader.py index 5c955ae6..349294a0 100644 --- a/caribou/deployment_solver/deployment_input/components/loaders/workflow_loader.py +++ b/caribou/deployment_solver/deployment_input/components/loaders/workflow_loader.py @@ -389,7 +389,7 @@ def get_architecture(self, instance_name: str, provider_name: str) -> str: ) # Default to x86_64 def _retrieve_workflow_data(self, workflow_id: str) -> dict[str, Any]: - return self._retrieve_data(self._primary_table, workflow_id, convert_from_bytes=True) + return self._retrieve_data(self._primary_table, workflow_id) def _round_to_kb(self, number: float, round_to: int = 10, round_up: bool = True) -> float: """ diff --git a/caribou/monitors/deployment_manager.py b/caribou/monitors/deployment_manager.py index 2bdc9424..84536be0 100644 --- a/caribou/monitors/deployment_manager.py +++ b/caribou/monitors/deployment_manager.py @@ -124,9 +124,7 @@ def check_workflow(self, workflow_id: str) -> None: workflow_config = self._get_workflow_config(workflow_id) - workflow_summary_raw, _ = data_collector_client.get_value_from_table( - WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True - ) + workflow_summary_raw, _ = data_collector_client.get_value_from_table(WORKFLOW_INSTANCE_TABLE, workflow_id) workflow_summary = json.loads(workflow_summary_raw) diff --git a/caribou/syncers/log_syncer.py b/caribou/syncers/log_syncer.py index b3063c1c..c0233287 100644 --- a/caribou/syncers/log_syncer.py +++ b/caribou/syncers/log_syncer.py @@ -65,9 +65,7 @@ def sync_workflow(self, workflow_id: str) -> None: DEPLOYMENT_RESOURCES_TABLE, workflow_id ) - previous_data_str, _ = self._workflow_summary_client.get_value_from_table( - WORKFLOW_SUMMARY_TABLE, workflow_id, convert_from_bytes=True - ) + previous_data_str, _ = self._workflow_summary_client.get_value_from_table(WORKFLOW_SUMMARY_TABLE, workflow_id) previous_data = json.loads(previous_data_str) if previous_data_str else {} last_sync_time: Optional[str] = previous_data.get( diff --git a/caribou/tests/common/models/remote_client/test_aws_remote_client.py b/caribou/tests/common/models/remote_client/test_aws_remote_client.py index c07fb146..dbf303ec 100644 --- a/caribou/tests/common/models/remote_client/test_aws_remote_client.py +++ b/caribou/tests/common/models/remote_client/test_aws_remote_client.py @@ -465,7 +465,7 @@ def test_get_value_from_table(self, mock_client): table_name = "test_table" key = "test_key" - # Scenario 1: Item exists and convert_from_bytes is False + # Scenario 1: Item exists and is of type byte is False mock_client.return_value.get_item.return_value = { "Item": {"key": {"S": key}, "value": {"S": "test_value"}}, "ConsumedCapacity": {"CapacityUnits": 1.0}, @@ -474,7 +474,7 @@ def test_get_value_from_table(self, mock_client): self.assertEqual(result, "test_value") self.assertEqual(consumed_capacity, 1.0) - # Scenario 2: Item exists and convert_from_bytes is True + # Scenario 2: Item exists and is of type byte is True mock_client.return_value.get_item.return_value = { "Item": {"key": {"S": key}, "value": {"B": b"compressed_value"}}, "ConsumedCapacity": {"CapacityUnits": 1.0}, @@ -483,7 +483,7 @@ def test_get_value_from_table(self, mock_client): "caribou.common.models.remote_client.aws_remote_client.decompress_json_str", return_value="decompressed_value", ): - result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key, convert_from_bytes=True) + result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key) self.assertEqual(result, "decompressed_value") self.assertEqual(consumed_capacity, 1.0) @@ -516,7 +516,7 @@ def test_remove_value_from_table(self, mock_client): def test_get_all_values_from_table(self, mock_client): table_name = "test_table" - # Scenario 1: Items exist and convert_from_bytes is False + # Scenario 1: Items exist and is of type byte is False mock_client.return_value.scan.return_value = { "Items": [ {"key": {"S": "key1"}, "value": {"S": "value1"}}, @@ -526,7 +526,7 @@ def test_get_all_values_from_table(self, mock_client): result = self.aws_client.get_all_values_from_table(table_name) self.assertEqual(result, {"key1": "value1", "key2": "value2"}) - # Scenario 2: Items exist and convert_from_bytes is True + # Scenario 2: Items exist and is of type byte is True mock_client.return_value.scan.return_value = { "Items": [ {"key": {"S": "key1"}, "value": {"B": b"compressed_value1"}}, @@ -537,7 +537,7 @@ def test_get_all_values_from_table(self, mock_client): "caribou.common.models.remote_client.aws_remote_client.decompress_json_str", side_effect=["decompressed_value1", "decompressed_value2"], ): - result = self.aws_client.get_all_values_from_table(table_name, convert_from_bytes=True) + result = self.aws_client.get_all_values_from_table(table_name) self.assertEqual(result, {"key1": "decompressed_value1", "key2": "decompressed_value2"}) # Scenario 3: No items in response diff --git a/caribou/tests/data_collector/components/workflow/test_workflow_retriever.py b/caribou/tests/data_collector/components/workflow/test_workflow_retriever.py index 032b8a46..0aedc0fd 100644 --- a/caribou/tests/data_collector/components/workflow/test_workflow_retriever.py +++ b/caribou/tests/data_collector/components/workflow/test_workflow_retriever.py @@ -36,9 +36,7 @@ def test_retrieve_workflow_summary(self): # Assertions self.assertEqual(result, {"transformed": "data"}) - self.mock_client.get_value_from_table.assert_called_once_with( - WORKFLOW_SUMMARY_TABLE, "workflow_id", convert_from_bytes=True - ) + self.mock_client.get_value_from_table.assert_called_once_with(WORKFLOW_SUMMARY_TABLE, "workflow_id") mock_transform.assert_called_once_with(json.dumps({"logs": []})) @patch.object(WorkflowRetriever, "_reorganize_instance_summary") diff --git a/caribou/tests/monitors/test_deployment_manager.py b/caribou/tests/monitors/test_deployment_manager.py index 947ad13e..093e5e50 100644 --- a/caribou/tests/monitors/test_deployment_manager.py +++ b/caribou/tests/monitors/test_deployment_manager.py @@ -438,7 +438,7 @@ def test_check_workflow( mock_client.get_value_from_table.assert_has_calls( [ call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id), - call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True), + call(WORKFLOW_INSTANCE_TABLE, workflow_id), ] ) mock_get_last_solved.assert_called_once() @@ -503,7 +503,7 @@ def test_check_workflow_not_enough_tokens( mock_client.get_value_from_table.assert_has_calls( [ call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id), - call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True), + call(WORKFLOW_INSTANCE_TABLE, workflow_id), ] ) mock_get_last_solved.assert_called_once() @@ -571,7 +571,7 @@ def test_check_workflow_deployed_remotely( mock_client.get_value_from_table.assert_has_calls( [ call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id), - call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True), + call(WORKFLOW_INSTANCE_TABLE, workflow_id), ] ) mock_get_last_solved.assert_called_once()