Skip to content

Commit

Permalink
Added and updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Danidite committed Oct 31, 2024
1 parent f2c54be commit 0b0862c
Show file tree
Hide file tree
Showing 12 changed files with 1,102 additions and 189 deletions.
191 changes: 187 additions & 4 deletions caribou/tests/common/models/remote_client/test_aws_remote_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest.mock import call

from caribou.common.constants import (
REMOTE_CARIBOU_CLI_FUNCTION_NAME,
SYNC_MESSAGES_TABLE,
CARIBOU_WORKFLOW_IMAGES_TABLE,
GLOBAL_TIME_ZONE,
Expand Down Expand Up @@ -431,12 +432,25 @@ def test_set_value_in_table(self, mock_client):
table_name = "test_table"
key = "test_key"
value = "test_value"

# Test without convert_to_bytes
self.aws_client.set_value_in_table(table_name, key, value)
mock_client.assert_called_with("dynamodb")
mock_client.return_value.put_item.assert_called_once_with(
TableName=table_name, Item={"key": {"S": key}, "value": {"S": value}}
)

# Test with convert_to_bytes
mock_client.reset_mock()
with patch(
"caribou.common.models.remote_client.aws_remote_client.compress_json_str", return_value=b"compressed_value"
):
self.aws_client.set_value_in_table(table_name, key, value, convert_to_bytes=True)
mock_client.assert_called_with("dynamodb")
mock_client.return_value.put_item.assert_called_once_with(
TableName=table_name, Item={"key": {"S": key}, "value": {"B": b"compressed_value"}}
)

@patch.object(AWSRemoteClient, "_client")
def test_set_value_in_table_column(self, mock_client):
table_name = "test_table"
Expand All @@ -450,9 +464,45 @@ def test_set_value_in_table_column(self, mock_client):
def test_get_value_from_table(self, mock_client):
table_name = "test_table"
key = "test_key"
mock_client.return_value.get_item.return_value = {"Item": {"key": {"S": key}, "value": {"S": "test_value"}}}
result, _ = self.aws_client.get_value_from_table(table_name, key)

# Scenario 1: Item exists and convert_from_bytes is False
mock_client.return_value.get_item.return_value = {
"Item": {"key": {"S": key}, "value": {"S": "test_value"}},
"ConsumedCapacity": {"CapacityUnits": 1.0},
}
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key)
self.assertEqual(result, "test_value")
self.assertEqual(consumed_capacity, 1.0)

# Scenario 2: Item exists and convert_from_bytes is True
mock_client.return_value.get_item.return_value = {
"Item": {"key": {"S": key}, "value": {"B": b"compressed_value"}},
"ConsumedCapacity": {"CapacityUnits": 1.0},
}
with patch(
"caribou.common.models.remote_client.aws_remote_client.decompress_json_str",
return_value="decompressed_value",
):
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key, convert_from_bytes=True)
self.assertEqual(result, "decompressed_value")
self.assertEqual(consumed_capacity, 1.0)

# Scenario 3: Item does not exist
mock_client.return_value.get_item.return_value = {
"ConsumedCapacity": {"CapacityUnits": 1.0},
}
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key)
self.assertEqual(result, "")
self.assertEqual(consumed_capacity, 1.0)

# Scenario 4: Item exists but no value field
mock_client.return_value.get_item.return_value = {
"Item": {"key": {"S": key}},
"ConsumedCapacity": {"CapacityUnits": 1.0},
}
result, consumed_capacity = self.aws_client.get_value_from_table(table_name, key)
self.assertEqual(result, "")
self.assertEqual(consumed_capacity, 1.0)

@patch.object(AWSRemoteClient, "_client")
def test_remove_value_from_table(self, mock_client):
Expand All @@ -465,11 +515,40 @@ def test_remove_value_from_table(self, mock_client):
@patch.object(AWSRemoteClient, "_client")
def test_get_all_values_from_table(self, mock_client):
table_name = "test_table"

# Scenario 1: Items exist and convert_from_bytes is False
mock_client.return_value.scan.return_value = {
"Items": [
{"key": {"S": "key1"}, "value": {"S": "value1"}},
{"key": {"S": "key2"}, "value": {"S": "value2"}},
]
}
result = self.aws_client.get_all_values_from_table(table_name)
self.assertEqual(result, {"key1": "value1", "key2": "value2"})

# Scenario 2: Items exist and convert_from_bytes is True
mock_client.return_value.scan.return_value = {
"Items": [{"key": {"S": "key1"}, "value": {"S": json.dumps("value1")}}]
"Items": [
{"key": {"S": "key1"}, "value": {"B": b"compressed_value1"}},
{"key": {"S": "key2"}, "value": {"B": b"compressed_value2"}},
]
}
with patch(
"caribou.common.models.remote_client.aws_remote_client.decompress_json_str",
side_effect=["decompressed_value1", "decompressed_value2"],
):
result = self.aws_client.get_all_values_from_table(table_name, convert_from_bytes=True)
self.assertEqual(result, {"key1": "decompressed_value1", "key2": "decompressed_value2"})

# Scenario 3: No items in response
mock_client.return_value.scan.return_value = {}
result = self.aws_client.get_all_values_from_table(table_name)
self.assertEqual(result, {"key1": '"value1"'})
self.assertEqual(result, {})

# Scenario 4: Items key is None
mock_client.return_value.scan.return_value = {"Items": None}
result = self.aws_client.get_all_values_from_table(table_name)
self.assertEqual(result, {})

@patch.object(AWSRemoteClient, "_client")
def test_get_key_present_in_table(self, mock_client):
Expand Down Expand Up @@ -1459,6 +1538,110 @@ def test_create_timer_rule_permission_exists(
],
)

@patch.object(AWSRemoteClient, "_client")
@patch("json.dumps")
def test_invoke_remote_framework_with_payload(self, mock_json_dumps, mock_client):
# Mocking the scenario where the Lambda function is invoked successfully
mock_lambda_client = MagicMock()
mock_client.return_value = mock_lambda_client

client = AWSRemoteClient("region1")

# Define the input
payload = {"key": "value"}
invocation_type = "Event"

# Mock the return value of json.dumps
mock_json_dumps.return_value = '{"key": "value"}'

# Call the method with test values
client.invoke_remote_framework_with_payload(payload, invocation_type)

# Check that the _client method was called with the correct arguments
mock_client.assert_called_once_with("lambda")

# Check that the invoke method was called with the correct arguments
mock_lambda_client.invoke.assert_called_once_with(
FunctionName=REMOTE_CARIBOU_CLI_FUNCTION_NAME,
InvocationType=invocation_type,
Payload='{"key": "value"}',
)

@patch.object(AWSRemoteClient, "_client")
@patch("json.dumps")
def test_invoke_remote_framework_with_payload_default_invocation_type(self, mock_json_dumps, mock_client):
# Mocking the scenario where the Lambda function is invoked successfully with default invocation type
mock_lambda_client = MagicMock()
mock_client.return_value = mock_lambda_client

client = AWSRemoteClient("region1")

# Define the input
payload = {"key": "value"}

# Mock the return value of json.dumps
mock_json_dumps.return_value = '{"key": "value"}'

# Call the method with test values
client.invoke_remote_framework_with_payload(payload)

# Check that the _client method was called with the correct arguments
mock_client.assert_called_once_with("lambda")

# Check that the invoke method was called with the correct arguments
mock_lambda_client.invoke.assert_called_once_with(
FunctionName=REMOTE_CARIBOU_CLI_FUNCTION_NAME,
InvocationType="Event",
Payload='{"key": "value"}',
)

@patch.object(AWSRemoteClient, "invoke_remote_framework_with_payload")
def test_invoke_remote_framework_internal_action(self, mock_invoke_remote_framework_with_payload):
action_type = "test_action_type"
action_events = {"key": "value"}

self.aws_client.invoke_remote_framework_internal_action(action_type, action_events)

expected_payload = {
"action": "internal_action",
"type": action_type,
"event": action_events,
}

mock_invoke_remote_framework_with_payload.assert_called_once_with(expected_payload, invocation_type="Event")

@patch.object(AWSRemoteClient, "_client")
def test_update_value_in_table(self, mock_client):
table_name = "test_table"
key = "test_key"
value = "test_value"

# Test without convert_to_bytes
self.aws_client.update_value_in_table(table_name, key, value)
mock_client.assert_called_with("dynamodb")
mock_client.return_value.update_item.assert_called_once_with(
TableName=table_name,
Key={"key": {"S": key}},
UpdateExpression="SET #v = :value",
ExpressionAttributeNames={"#v": "value"},
ExpressionAttributeValues={":value": {"S": value}},
)

# Test with convert_to_bytes
mock_client.reset_mock()
with patch(
"caribou.common.models.remote_client.aws_remote_client.compress_json_str", return_value=b"compressed_value"
):
self.aws_client.update_value_in_table(table_name, key, value, convert_to_bytes=True)
mock_client.assert_called_with("dynamodb")
mock_client.return_value.update_item.assert_called_once_with(
TableName=table_name,
Key={"key": {"S": key}},
UpdateExpression="SET #v = :value",
ExpressionAttributeNames={"#v": "value"},
ExpressionAttributeValues={":value": {"B": b"compressed_value"}},
)


if __name__ == "__main__":
unittest.main()
22 changes: 18 additions & 4 deletions caribou/tests/common/models/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import unittest
from unittest import mock
from unittest.mock import patch, MagicMock
from caribou.common.models.remote_client.aws_remote_client import AWSRemoteClient
from caribou.common.models.remote_client.remote_client import RemoteClient
from caribou.common.models.remote_client.remote_client_factory import RemoteClientFactory
from caribou.common.provider import Provider
from caribou.common.constants import GLOBAL_SYSTEM_REGION, INTEGRATION_TEST_SYSTEM_REGION
from caribou.common.models.endpoints import Endpoints # Adjust the import as needed


class TestEndpoints(unittest.TestCase):
@patch.object(RemoteClientFactory, "get_remote_client")
def test_initialization(self, mock_get_remote_client):
@patch.object(RemoteClientFactory, "get_framework_cli_remote_client")
def test_initialization(self, mock_get_framework_cli_remote_client, mock_get_remote_client):
# Setup environment for the test
mock_get_remote_client.return_value = MagicMock(spec=RemoteClient)
mock_get_framework_cli_remote_client.return_value = MagicMock(spec=AWSRemoteClient)

# Case 1: INTEGRATIONTEST_ON is False
with patch.dict("os.environ", {"INTEGRATIONTEST_ON": "False"}):
endpoints = Endpoints()

# Assertions for AWS provider
mock_get_remote_client.assert_any_call(Provider.AWS.value, mock.ANY)
mock_get_remote_client.assert_any_call(Provider.AWS.value, GLOBAL_SYSTEM_REGION)
self.assertEqual(endpoints.get_deployment_resources_client(), mock_get_remote_client.return_value)
self.assertEqual(endpoints.get_deployment_manager_client(), mock_get_remote_client.return_value)
self.assertEqual(
Expand All @@ -33,7 +36,9 @@ def test_initialization(self, mock_get_remote_client):
endpoints = Endpoints()

# Assertions for Integration Test provider
mock_get_remote_client.assert_any_call(Provider.INTEGRATION_TEST_PROVIDER.value, mock.ANY)
mock_get_remote_client.assert_any_call(
Provider.INTEGRATION_TEST_PROVIDER.value, INTEGRATION_TEST_SYSTEM_REGION
)
self.assertEqual(endpoints.get_deployment_resources_client(), mock_get_remote_client.return_value)
self.assertEqual(endpoints.get_deployment_manager_client(), mock_get_remote_client.return_value)
self.assertEqual(
Expand Down Expand Up @@ -81,6 +86,15 @@ def test_get_datastore_client(self):
endpoints = Endpoints()
self.assertEqual(endpoints.get_datastore_client(), mock_get_remote_client.return_value)

def test_get_framework_cli_remote_client(self):
with patch.object(
RemoteClientFactory, "get_framework_cli_remote_client", return_value=MagicMock(spec=AWSRemoteClient)
) as mock_get_framework_cli_remote_client:
endpoints = Endpoints()
self.assertEqual(
endpoints.get_framework_cli_remote_client(), mock_get_framework_cli_remote_client.return_value
)


if __name__ == "__main__":
unittest.main()
51 changes: 44 additions & 7 deletions caribou/tests/common/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import unittest
from typing import Callable, Any
import inspect
import importlib
import textwrap
import ast

from caribou.common.utils import get_function_source
from caribou.common.utils import decompress_json_str, get_function_source
from caribou.common.utils import compress_json_str
import zstandard as zstd


class TestGetFunctionSource(unittest.TestCase):
Expand Down Expand Up @@ -33,6 +29,47 @@ def test_function():

self.assertIn('print("Hello,world!")', source_code)

def test_compress_json_str(self):
json_str = '{"key": "value"}'
compressed_bytes = compress_json_str(json_str)

# Decompress to verify
dctx = zstd.ZstdDecompressor()
decompressed_bytes = dctx.decompress(compressed_bytes)
decompressed_str = decompressed_bytes.decode("utf-8")

self.assertEqual(json_str, decompressed_str)

def test_compress_json_str_with_different_compression_level(self):
json_str = '{"key": "value"}'
compressed_bytes = compress_json_str(json_str, compression_level=10)

# Decompress to verify
dctx = zstd.ZstdDecompressor()
decompressed_bytes = dctx.decompress(compressed_bytes)
decompressed_str = decompressed_bytes.decode("utf-8")

self.assertEqual(json_str, decompressed_str)

def test_decompress_json_str(self):
json_str = '{"key": "value"}'
compressed_bytes = compress_json_str(json_str)
decompressed_str = decompress_json_str(compressed_bytes)

self.assertEqual(json_str, decompressed_str)

def test_decompress_json_str_with_different_compression_level(self):
json_str = '{"key": "value"}'
compressed_bytes = compress_json_str(json_str, compression_level=10)
decompressed_str = decompress_json_str(compressed_bytes)

self.assertEqual(json_str, decompressed_str)

def test_decompress_json_str_with_invalid_data(self):
invalid_data = b"invalid compressed data"
with self.assertRaises(zstd.ZstdError):
decompress_json_str(invalid_data)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 0b0862c

Please sign in to comment.