Skip to content

Commit

Permalink
Auto convert byte to str
Browse files Browse the repository at this point in the history
  • Loading branch information
Danidite committed Nov 12, 2024
1 parent 5610860 commit 021b305
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 41 deletions.
16 changes: 9 additions & 7 deletions caribou/common/models/remote_client/aws_remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,9 +658,7 @@ def set_value_in_table_column(
UpdateExpression=update_expression,
)

def get_value_from_table(
self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False
) -> tuple[str, float]:
def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]:
client = self._client("dynamodb")
response = client.get_item(
TableName=table_name,
Expand All @@ -677,7 +675,8 @@ def get_value_from_table(

item = response.get("Item")
if item is not None and "value" in item:
if convert_from_bytes:
# Detect if the value is compressed (in bytes) and decompress it
if "B" in item["value"]:
return decompress_json_str(item["value"]["B"]), consumed_read_capacity

return item["value"]["S"], consumed_read_capacity
Expand All @@ -688,15 +687,18 @@ def remove_value_from_table(self, table_name: str, key: str) -> None:
client = self._client("dynamodb")
client.delete_item(TableName=table_name, Key={"key": {"S": key}})

def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict[str, Any]:
def get_all_values_from_table(self, table_name: str) -> dict[str, Any]:
client = self._client("dynamodb")
response = client.scan(TableName=table_name)
if "Items" not in response:
return {}
items = response.get("Items")
if items is not None:
if convert_from_bytes:
return {item["key"]["S"]: decompress_json_str(item["value"]["B"]) for item in items}
for item in items:
# Detect if the value is compressed (in bytes) and decompress it
if "value" in item and "B" in item["value"]:
item["value"]["S"] = decompress_json_str(item["value"]["B"])
del item["value"]["B"]

return {item["key"]["S"]: item["value"]["S"] for item in items}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,7 @@ def set_value_in_table(self, table_name: str, key: str, value: str, convert_to_b
conn.commit()
conn.close()

def get_value_from_table(
self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False
) -> tuple[str, float]:
def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]:
conn = self._db_connection()
cursor = conn.cursor()
cursor.execute(f"SELECT value FROM {table_name} WHERE key=?", (key,))
Expand Down Expand Up @@ -267,7 +265,7 @@ def set_predecessor_reached(
conn.close()
return [bool(res) for res in result], 0.0, 0.0

def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict:
def get_all_values_from_table(self, table_name: str) -> dict:
conn = self._db_connection()
cursor = conn.cursor()
cursor.execute(f"SELECT key, value FROM {table_name}")
Expand Down
4 changes: 2 additions & 2 deletions caribou/common/models/remote_client/mock_remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def upload_predecessor_data_at_sync_node(self, function_name, workflow_instance_
def set_value_in_table(self, table_name, key, value, convert_to_bytes: bool = False):
pass

def get_value_from_table(self, table_name, key, consistent_read: bool = True, convert_from_bytes: bool = False):
def get_value_from_table(self, table_name, key, consistent_read: bool = True):
pass

def upload_resource(self, key, resource):
Expand Down Expand Up @@ -74,7 +74,7 @@ def set_predecessor_reached(
) -> list[bool]:
pass

def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict:
def get_all_values_from_table(self, table_name: str) -> dict:
pass

def set_value_in_table_column(
Expand Down
6 changes: 2 additions & 4 deletions caribou/common/models/remote_client/remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,15 @@ def set_value_in_table_column(
raise NotImplementedError()

@abstractmethod
def get_value_from_table(
self, table_name: str, key: str, consistent_read: bool = True, convert_from_bytes: bool = False
) -> tuple[str, float]:
def get_value_from_table(self, table_name: str, key: str, consistent_read: bool = True) -> tuple[str, float]:
raise NotImplementedError()

@abstractmethod
def remove_value_from_table(self, table_name: str, key: str) -> None:
raise NotImplementedError()

@abstractmethod
def get_all_values_from_table(self, table_name: str, convert_from_bytes: bool = False) -> dict[str, Any]:
def get_all_values_from_table(self, table_name: str) -> dict[str, Any]:
raise NotImplementedError()

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def retrieve_all_workflow_ids(self) -> set[str]:

def retrieve_workflow_summary(self, workflow_unique_id: str) -> dict[str, Any]:
# Load the summarized logs from the workflow summary table
workflow_summarized, _ = self._client.get_value_from_table(
self._workflow_summary_table, workflow_unique_id, convert_from_bytes=True
)
workflow_summarized, _ = self._client.get_value_from_table(self._workflow_summary_table, workflow_unique_id)

# Consolidate all the timestamps together to one summary and return the result
return self._transform_workflow_summary(workflow_summarized)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def _retrieve_region_data(self, available_regions: set[str]) -> dict[str, Any]:

return all_data

def _retrieve_data(self, table_name: str, data_key: str, convert_from_bytes: bool = False) -> dict[str, Any]:
value, _ = self._client.get_value_from_table(table_name, data_key, convert_from_bytes=convert_from_bytes)
def _retrieve_data(self, table_name: str, data_key: str) -> dict[str, Any]:
value, _ = self._client.get_value_from_table(table_name, data_key)

loaded_data: dict[str, Any] = {}
if value is not None and value != "":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def get_architecture(self, instance_name: str, provider_name: str) -> str:
) # Default to x86_64

def _retrieve_workflow_data(self, workflow_id: str) -> dict[str, Any]:
return self._retrieve_data(self._primary_table, workflow_id, convert_from_bytes=True)
return self._retrieve_data(self._primary_table, workflow_id)

def _round_to_kb(self, number: float, round_to: int = 10, round_up: bool = True) -> float:
"""
Expand Down
4 changes: 1 addition & 3 deletions caribou/monitors/deployment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def check_workflow(self, workflow_id: str) -> None:

workflow_config = self._get_workflow_config(workflow_id)

workflow_summary_raw, _ = data_collector_client.get_value_from_table(
WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True
)
workflow_summary_raw, _ = data_collector_client.get_value_from_table(WORKFLOW_INSTANCE_TABLE, workflow_id)

workflow_summary = json.loads(workflow_summary_raw)

Expand Down
4 changes: 1 addition & 3 deletions caribou/syncers/log_syncer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def sync_workflow(self, workflow_id: str) -> None:
DEPLOYMENT_RESOURCES_TABLE, workflow_id
)

previous_data_str, _ = self._workflow_summary_client.get_value_from_table(
WORKFLOW_SUMMARY_TABLE, workflow_id, convert_from_bytes=True
)
previous_data_str, _ = self._workflow_summary_client.get_value_from_table(WORKFLOW_SUMMARY_TABLE, workflow_id)
previous_data = json.loads(previous_data_str) if previous_data_str else {}

last_sync_time: Optional[str] = previous_data.get(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def test_get_value_from_table(self, mock_client):
table_name = "test_table"
key = "test_key"

# Scenario 1: Item exists and convert_from_bytes is False
# Scenario 1: Item exists and its value is a plain string (not bytes)
mock_client.return_value.get_item.return_value = {
"Item": {"key": {"S": key}, "value": {"S": "test_value"}},
"ConsumedCapacity": {"CapacityUnits": 1.0},
Expand All @@ -474,7 +474,7 @@ def test_get_value_from_table(self, mock_client):
self.assertEqual(result, "test_value")
self.assertEqual(consumed_capacity, 1.0)

# Scenario 2: Item exists and convert_from_bytes is True
# Scenario 2: Item exists and its value is stored as bytes (compressed)
mock_client.return_value.get_item.return_value = {
"Item": {"key": {"S": key}, "value": {"B": b"compressed_value"}},
"ConsumedCapacity": {"CapacityUnits": 1.0},
Expand All @@ -483,7 +483,7 @@ def test_get_value_from_table(self, mock_client):
"caribou.common.models.remote_client.aws_remote_client.decompress_json_str",
return_value="decompressed_value",
):
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key, convert_from_bytes=True)
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key)
self.assertEqual(result, "decompressed_value")
self.assertEqual(consumed_capacity, 1.0)

Expand Down Expand Up @@ -516,7 +516,7 @@ def test_remove_value_from_table(self, mock_client):
def test_get_all_values_from_table(self, mock_client):
table_name = "test_table"

# Scenario 1: Items exist and convert_from_bytes is False
# Scenario 1: Items exist and their values are plain strings (not bytes)
mock_client.return_value.scan.return_value = {
"Items": [
{"key": {"S": "key1"}, "value": {"S": "value1"}},
Expand All @@ -526,7 +526,7 @@ def test_get_all_values_from_table(self, mock_client):
result = self.aws_client.get_all_values_from_table(table_name)
self.assertEqual(result, {"key1": "value1", "key2": "value2"})

# Scenario 2: Items exist and convert_from_bytes is True
# Scenario 2: Items exist and their values are stored as bytes (compressed)
mock_client.return_value.scan.return_value = {
"Items": [
{"key": {"S": "key1"}, "value": {"B": b"compressed_value1"}},
Expand All @@ -537,7 +537,7 @@ def test_get_all_values_from_table(self, mock_client):
"caribou.common.models.remote_client.aws_remote_client.decompress_json_str",
side_effect=["decompressed_value1", "decompressed_value2"],
):
result = self.aws_client.get_all_values_from_table(table_name, convert_from_bytes=True)
result = self.aws_client.get_all_values_from_table(table_name)
self.assertEqual(result, {"key1": "decompressed_value1", "key2": "decompressed_value2"})

# Scenario 3: No items in response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def test_retrieve_workflow_summary(self):

# Assertions
self.assertEqual(result, {"transformed": "data"})
self.mock_client.get_value_from_table.assert_called_once_with(
WORKFLOW_SUMMARY_TABLE, "workflow_id", convert_from_bytes=True
)
self.mock_client.get_value_from_table.assert_called_once_with(WORKFLOW_SUMMARY_TABLE, "workflow_id")
mock_transform.assert_called_once_with(json.dumps({"logs": []}))

@patch.object(WorkflowRetriever, "_reorganize_instance_summary")
Expand Down
6 changes: 3 additions & 3 deletions caribou/tests/monitors/test_deployment_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def test_check_workflow(
mock_client.get_value_from_table.assert_has_calls(
[
call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id),
call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True),
call(WORKFLOW_INSTANCE_TABLE, workflow_id),
]
)
mock_get_last_solved.assert_called_once()
Expand Down Expand Up @@ -503,7 +503,7 @@ def test_check_workflow_not_enough_tokens(
mock_client.get_value_from_table.assert_has_calls(
[
call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id),
call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True),
call(WORKFLOW_INSTANCE_TABLE, workflow_id),
]
)
mock_get_last_solved.assert_called_once()
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_check_workflow_deployed_remotely(
mock_client.get_value_from_table.assert_has_calls(
[
call(DEPLOYMENT_MANAGER_WORKFLOW_INFO_TABLE, workflow_id),
call(WORKFLOW_INSTANCE_TABLE, workflow_id, convert_from_bytes=True),
call(WORKFLOW_INSTANCE_TABLE, workflow_id),
]
)
mock_get_last_solved.assert_called_once()
Expand Down

0 comments on commit 021b305

Please sign in to comment.