From 5326da4b83ed4405553e88d5d5464508256498d0 Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Tue, 28 Jun 2022 09:30:02 -0600 Subject: [PATCH] Add `airflow_kpo_in_cluster` label to KPO pods (#24658) This allows one to determine if the pod was created with in_cluster config or not, both on the k8s side and in pod_mutation_hooks. --- .../cncf/kubernetes/hooks/kubernetes.py | 15 ++++++++++++++ .../kubernetes/operators/kubernetes_pod.py | 9 +++++++-- .../test_kubernetes_pod_operator.py | 15 ++++++++++---- ...test_kubernetes_pod_operator_backcompat.py | 2 ++ .../cncf/kubernetes/hooks/test_kubernetes.py | 7 +++++++ .../operators/test_kubernetes_pod.py | 20 +++++++++++++------ 6 files changed, 56 insertions(+), 12 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index c4658ec8f3d10..ed794cd553454 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -127,6 +127,8 @@ def __init__( self.disable_verify_ssl = disable_verify_ssl self.disable_tcp_keepalive = disable_tcp_keepalive + self._is_in_cluster: Optional[bool] = None + # these params used for transition in KPO to K8s hook # for a deprecation period we will continue to consider k8s settings from airflow.cfg self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None @@ -232,11 +234,13 @@ def get_conn(self) -> Any: if in_cluster: self.log.debug("loading kube_config from: in_cluster configuration") + self._is_in_cluster = True config.load_incluster_config() return client.ApiClient() if kubeconfig_path is not None: self.log.debug("loading kube_config from: %s", kubeconfig_path) + self._is_in_cluster = False config.load_kube_config( config_file=kubeconfig_path, client_configuration=self.client_configuration, @@ -249,6 +253,7 @@ def get_conn(self) -> Any: self.log.debug("loading kube_config from: connection kube_config") temp_config.write(kubeconfig.encode()) temp_config.flush() + self._is_in_cluster = False config.load_kube_config( config_file=temp_config.name, client_configuration=self.client_configuration, @@ -265,14 +270,24 @@ def _get_default_client(self, *, cluster_context=None): # in the default location try: config.load_incluster_config(client_configuration=self.client_configuration) + self._is_in_cluster = True except ConfigException: self.log.debug("loading kube_config from: default file") + self._is_in_cluster = False config.load_kube_config( client_configuration=self.client_configuration, context=cluster_context, ) return client.ApiClient() + @property + def is_in_cluster(self): + """Expose whether the hook is configured with ``load_incluster_config`` or not""" + if self._is_in_cluster is not None: + return self._is_in_cluster + self.api_client # so we can determine if we are in_cluster or not + return self._is_in_cluster + @cached_property def api_client(self) -> Any: """Cached Kubernetes API client""" diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py index 1966b6bdb1349..09cad504fe396 100644 --- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py +++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py @@ -342,6 +342,11 @@ def pod_manager(self) -> PodManager: return PodManager(kube_client=self.client) def get_hook(self): + warnings.warn("get_hook is deprecated. Please use hook instead.", DeprecationWarning, stacklevel=2) + return self.hook + + @cached_property + def hook(self) -> KubernetesHook: hook = KubernetesHook( conn_id=self.kubernetes_conn_id, in_cluster=self.in_cluster, @@ -353,8 +358,7 @@ def get_hook(self): @cached_property def client(self) -> CoreV1Api: - hook = self.get_hook() - return hook.core_v1_client + return self.hook.core_v1_client def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]: """Returns an already-running pod for this task instance if one exists.""" @@ -580,6 +584,7 @@ def build_pod_request_obj(self, context=None): pod.metadata.labels.update( { 'airflow_version': airflow_version.replace('+', '-'), + 'airflow_kpo_in_cluster': str(self.hook.is_in_cluster), } ) pod_mutation_hook(pod) diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index fb661e46b08a5..50e5978de7de1 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -93,6 +93,7 @@ def setUp(self): 'foo': 'bar', 'kubernetes_pod_operator': 'True', 'airflow_version': airflow_version.replace('+', '-'), + 'airflow_kpo_in_cluster': 'False', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'dag_id': 'dag', 'task_id': ANY, @@ -734,6 +735,7 @@ def test_pod_template_file_with_overrides_system(self): 'fizz': 'buzz', 'foo': 'bar', 'airflow_version': mock.ANY, + 'airflow_kpo_in_cluster': 'False', 'dag_id': 'dag', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'kubernetes_pod_operator': 'True', @@ -773,6 +775,7 @@ def test_pod_template_file_with_full_pod_spec(self): 'fizz': 'buzz', 'foo': 'bar', 'airflow_version': mock.ANY, + 'airflow_kpo_in_cluster': 'False', 'dag_id': 'dag', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'kubernetes_pod_operator': 'True', @@ -815,6 +818,7 @@ def test_full_pod_spec(self): 'fizz': 'buzz', 'foo': 'bar', 'airflow_version': mock.ANY, + 'airflow_kpo_in_cluster': 'False', 'dag_id': 'dag', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'kubernetes_pod_operator': 'True', @@ -882,9 +886,10 @@ def test_init_container(self): @mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom") @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) - @mock.patch(HOOK_CLASS, new=MagicMock) - def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock): + @mock.patch(HOOK_CLASS) + def test_pod_template_file(self, hook_mock, await_pod_completion_mock, extract_xcom_mock): # todo: This isn't really a system test + hook_mock.return_value.is_in_cluster = False extract_xcom_mock.return_value = '{}' path = sys.path[0] + '/tests/kubernetes/pod.yaml' k = KubernetesPodOperator( @@ -920,6 +925,7 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock): 'metadata': { 'annotations': {}, 'labels': { + 'airflow_kpo_in_cluster': 'False', 'dag_id': 'dag', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'kubernetes_pod_operator': 'True', @@ -968,13 +974,14 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock): @mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion") @mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock) - @mock.patch(HOOK_CLASS, new=MagicMock) - def test_pod_priority_class_name(self, await_pod_completion_mock): + @mock.patch(HOOK_CLASS) + def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock): """ Test ability to assign priorityClassName to pod todo: This isn't really a system test """ + hook_mock.return_value.is_in_cluster = False priority_class_name = "medium-test" k = KubernetesPodOperator( diff --git a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py index f15400edeab2f..af2f0f38fe1fd 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py +++ b/kubernetes_tests/test_kubernetes_pod_operator_backcompat.py @@ -90,6 +90,7 @@ def setUp(self): 'foo': 'bar', 'kubernetes_pod_operator': 'True', 'airflow_version': airflow_version.replace('+', '-'), + 'airflow_kpo_in_cluster': 'False', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'dag_id': 'dag', 'task_id': 'task', @@ -571,6 +572,7 @@ def test_pod_template_file_with_overrides_system(self): 'fizz': 'buzz', 'foo': 'bar', 'airflow_version': mock.ANY, + 'airflow_kpo_in_cluster': 'False', 'dag_id': 'dag', 'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b', 'kubernetes_pod_operator': 'True', diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 572f6e2890d25..6bbe5926e909d 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -106,6 +106,11 @@ def test_in_cluster_connection( else: mock_get_default_client.assert_called() assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + if mock_get_default_client.called: + # get_default_client sets it, but it's mocked + assert kubernetes_hook.is_in_cluster is None + else: + assert kubernetes_hook.is_in_cluster is in_cluster_called @pytest.mark.parametrize('in_cluster_fails', [True, False]) @patch("kubernetes.config.kube_config.KubeConfigLoader") @@ -130,10 +135,12 @@ def test_get_default_client( mock_incluster.assert_called_once() mock_merger.assert_called_once_with(KUBE_CONFIG_PATH) mock_loader.assert_called_once() + assert kubernetes_hook.is_in_cluster is False else: mock_incluster.assert_called_once() mock_merger.assert_not_called() mock_loader.assert_not_called() + assert kubernetes_hook.is_in_cluster is True assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) @pytest.mark.parametrize( diff --git a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py index b771361d1a679..88e0eeb07e1c0 100644 --- a/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py @@ -100,6 +100,8 @@ def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V remote_pod_mock = MagicMock() remote_pod_mock.status.phase = 'Succeeded' self.await_pod_mock.return_value = remote_pod_mock + if not isinstance(self.hook_mock.return_value.is_in_cluster, bool): + self.hook_mock.return_value.is_in_cluster = True operator.execute(context=context) return self.await_start_mock.call_args[1]['pod'] @@ -170,7 +172,9 @@ def test_envs_from_configmaps( pod = self.run_pod(k) assert pod.spec.containers[0].env_from == env_from - def test_labels(self): + @pytest.mark.parametrize(("in_cluster",), ([True], [False])) + def test_labels(self, in_cluster): + self.hook_mock.return_value.is_in_cluster = in_cluster k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -178,7 +182,7 @@ def test_labels(self): labels={"foo": "bar"}, name="test", task_id="task", - in_cluster=False, + in_cluster=in_cluster, do_xcom_push=False, ) pod = self.run_pod(k) @@ -190,6 +194,7 @@ def test_labels(self): "try_number": "1", "airflow_version": mock.ANY, "run_id": "test", + "airflow_kpo_in_cluster": str(in_cluster), } def test_labels_mapped(self): @@ -209,6 +214,7 @@ def test_labels_mapped(self): "airflow_version": mock.ANY, "run_id": "test", "map_index": "10", + "airflow_kpo_in_cluster": "True", } def test_find_pod_labels(self): @@ -391,6 +397,7 @@ def test_full_pod_spec(self, randomize_name, pod_spec): "task_id": "task", "try_number": "1", "airflow_version": mock.ANY, + "airflow_kpo_in_cluster": "True", "run_id": "test", } @@ -429,6 +436,7 @@ def test_full_pod_spec_kwargs(self, randomize_name, pod_spec): "task_id": "task", "try_number": "1", "airflow_version": mock.ANY, + "airflow_kpo_in_cluster": "True", "run_id": "test", } @@ -499,6 +507,7 @@ def test_pod_template_file(self, randomize_name, pod_template_file): "task_id": "task", "try_number": "1", "airflow_version": mock.ANY, + "airflow_kpo_in_cluster": "True", "run_id": "test", } assert pod.metadata.namespace == "mynamespace" @@ -568,6 +577,7 @@ def test_pod_template_file_kwargs_override(self, randomize_name, pod_template_fi "task_id": "task", "try_number": "1", "airflow_version": mock.ANY, + "airflow_kpo_in_cluster": "True", "run_id": "test", } @@ -877,13 +887,11 @@ def test_patch_core_settings(self, key, value, attr, patched_value): # the hook attr should be None op = KubernetesPodOperator(task_id='abc', name='hi') self.hook_patch.stop() - hook = op.get_hook() - assert getattr(hook, attr) is None + assert getattr(op.hook, attr) is None # now check behavior with a non-default value with conf_vars({('kubernetes', key): value}): op = KubernetesPodOperator(task_id='abc', name='hi') - hook = op.get_hook() - assert getattr(hook, attr) == patched_value + assert getattr(op.hook, attr) == patched_value def test__suppress():