diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py index 4a153f0c82624..d5cb666411e05 100644 --- a/providers/src/airflow/providers/microsoft/azure/hooks/asb.py +++ b/providers/src/airflow/providers/microsoft/azure/hooks/asb.py @@ -382,6 +382,39 @@ def create_subscription( return subscription + def update_subscription( + self, + topic_name: str, + subscription_name: str, + max_delivery_count: int | None = None, + dead_lettering_on_message_expiration: bool | None = None, + enable_batched_operations: bool | None = None, + ) -> None: + """ + Update an Azure ServiceBus Topic Subscription under a ServiceBus Namespace. + + :param topic_name: The topic that will own the to-be-created subscription. + :param subscription_name: Name of the subscription that need to be created. + :param max_delivery_count: The maximum delivery count. A message is automatically dead lettered + after this number of deliveries. Default value is 10. + :param dead_lettering_on_message_expiration: A value that indicates whether this subscription + has dead letter support when a message expires. + :param enable_batched_operations: Value that indicates whether server-side batched + operations are enabled. + """ + with self.get_conn() as service_mgmt_conn: + subscription_prop = service_mgmt_conn.get_subscription(topic_name, subscription_name) + if max_delivery_count: + subscription_prop.max_delivery_count = max_delivery_count + if dead_lettering_on_message_expiration is not None: + subscription_prop.dead_lettering_on_message_expiration = dead_lettering_on_message_expiration + if enable_batched_operations is not None: + subscription_prop.enable_batched_operations = enable_batched_operations + # update by updating the properties in the model + service_mgmt_conn.update_subscription(topic_name, subscription_prop) + updated_subscription = service_mgmt_conn.get_subscription(topic_name, subscription_name) + self.log.info("Subscription Updated successfully %s", updated_subscription.name) + def delete_subscription(self, subscription_name: str, topic_name: str) -> None: """ Delete a topic subscription entities under a ServiceBus Namespace. diff --git a/providers/src/airflow/providers/microsoft/azure/operators/asb.py b/providers/src/airflow/providers/microsoft/azure/operators/asb.py index ba3a3257b940d..aa8eecb3f2442 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/asb.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/asb.py @@ -489,18 +489,13 @@ def execute(self, context: Context) -> None: """Update Subscription properties, by connecting to Service Bus Admin client.""" hook = AdminClientHook(azure_service_bus_conn_id=self.azure_service_bus_conn_id) - with hook.get_conn() as service_mgmt_conn: - subscription_prop = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name) - if self.max_delivery_count: - subscription_prop.max_delivery_count = self.max_delivery_count - if self.dl_on_message_expiration is not None: - subscription_prop.dead_lettering_on_message_expiration = self.dl_on_message_expiration - if self.enable_batched_operations is not None: - subscription_prop.enable_batched_operations = self.enable_batched_operations - # update by updating the properties in the model - service_mgmt_conn.update_subscription(self.topic_name, subscription_prop) - updated_subscription = service_mgmt_conn.get_subscription(self.topic_name, self.subscription_name) - self.log.info("Subscription Updated successfully %s", updated_subscription) + hook.update_subscription( + topic_name=self.topic_name, + subscription_name=self.subscription_name, + max_delivery_count=self.max_delivery_count, + dead_lettering_on_message_expiration=self.dl_on_message_expiration, + enable_batched_operations=self.enable_batched_operations, + ) class ASBReceiveSubscriptionMessageOperator(BaseOperator): diff --git a/providers/tests/microsoft/azure/hooks/test_asb.py b/providers/tests/microsoft/azure/hooks/test_asb.py index 067e79bf5702d..6f9203bd0d1c2 100644 --- a/providers/tests/microsoft/azure/hooks/test_asb.py +++ b/providers/tests/microsoft/azure/hooks/test_asb.py @@ -184,6 +184,37 @@ def test_create_subscription_with_rule( assert mock_subscription_properties.name == subscription_name assert mock_rule_properties.name == mock_rule_name + @mock.patch("azure.servicebus.management.SubscriptionProperties") + @mock.patch(f"{MODULE}.AdminClientHook.get_conn") + def test_modify_subscription(self, mock_sb_admin_client, mock_subscription_properties): + """ + Test modify subscription functionality by ensuring correct data is copied into properties + and passed to update_subscription method of connection mocking the azure service bus function + `update_subscription` + """ + subscription_name = "test_subscription_name" + topic_name = "test_topic_name" + hook = AdminClientHook(azure_service_bus_conn_id=self.conn_id) + + mock_sb_admin_client.return_value.__enter__.return_value.get_subscription.return_value = ( + mock_subscription_properties + ) + + hook.update_subscription( + topic_name, + subscription_name, + max_delivery_count=3, + dead_lettering_on_message_expiration=True, + enable_batched_operations=True, + ) + + expected_calls = [ + mock.call().__enter__().get_subscription(topic_name, subscription_name), + mock.call().__enter__().update_subscription(topic_name, mock_subscription_properties), + mock.call().__enter__().get_subscription(topic_name, subscription_name), + ] + mock_sb_admin_client.assert_has_calls(expected_calls) + @mock.patch(f"{MODULE}.AdminClientHook.get_conn") def test_delete_subscription(self, mock_sb_admin_client): """ diff --git a/providers/tests/microsoft/azure/operators/test_asb.py b/providers/tests/microsoft/azure/operators/test_asb.py index 145d2e729407a..d887db2fff044 100644 --- a/providers/tests/microsoft/azure/operators/test_asb.py +++ b/providers/tests/microsoft/azure/operators/test_asb.py @@ -255,10 +255,9 @@ def test_init(self): @mock.patch("azure.servicebus.management.TopicProperties") def test_create_topic(self, mock_topic_properties, mock_get_conn): """ - Test AzureServiceBusSubscriptionCreateOperator passed with the subscription name, topic name - mocking the connection details, hook create_subscription function + Test AzureServiceBusTopicCreateOperator passed with the topic name + mocking the connection """ - print("Wazzup doc") asb_create_topic = AzureServiceBusTopicCreateOperator( task_id="asb_create_topic", topic_name=TOPIC_NAME, @@ -269,7 +268,7 @@ def test_create_topic(self, mock_topic_properties, mock_get_conn): created_topic_name = asb_create_topic.execute(None) # ensure the topic name is returned assert created_topic_name == TOPIC_NAME - # ensure create_subscription is called with the correct arguments on the connection + # ensure create_topic is called with the correct arguments on the connection mock_get_conn.return_value.__enter__.return_value.create_topic.assert_called_once_with( topic_name=TOPIC_NAME, default_message_time_to_live=None, @@ -287,10 +286,9 @@ def test_create_topic(self, mock_topic_properties, mock_get_conn): user_metadata=None, max_message_size_in_kilobytes=None, ) - print("Later Gator") @mock.patch("airflow.providers.microsoft.azure.hooks.asb.AdminClientHook") - def test_create_subscription_exception(self, mock_sb_admin_client): + def test_create_topic_exception(self, mock_sb_admin_client): """ Test `AzureServiceBusTopicCreateOperator` functionality to raise AirflowException, by passing topic name as None and pytest raise Airflow Exception @@ -428,9 +426,20 @@ def test_update_subscription(self, mock_get_conn, mock_subscription_properties): subscription_name=SUBSCRIPTION_NAME, max_delivery_count=20, ) - with mock.patch.object(asb_update_subscription.log, "info") as mock_log_info: - asb_update_subscription.execute(None) - mock_log_info.assert_called_with("Subscription Updated successfully %s", mock_subscription_properties) + + asb_update_subscription.execute(None) + + mock_get_conn.return_value.__enter__.return_value.get_subscription.assert_has_calls( + [ + mock.call(TOPIC_NAME, SUBSCRIPTION_NAME), # before update + mock.call(TOPIC_NAME, SUBSCRIPTION_NAME), # after update + ] + ) + + mock_get_conn.return_value.__enter__.return_value.update_subscription.assert_called_once_with( + TOPIC_NAME, + mock_subscription_properties, + ) class TestASBSubscriptionReceiveMessageOperator: