Skip to content

Commit

Permalink
Add airflow_kpo_in_cluster label to KPO pods (#24658)
Browse files Browse the repository at this point in the history
This allows one to determine if the pod was created with in_cluster config
or not, both on the k8s side and in pod_mutation_hooks.
  • Loading branch information
jedcunningham authored Jun 28, 2022
1 parent 8f638bb commit 5326da4
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 12 deletions.
15 changes: 15 additions & 0 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def __init__(
self.disable_verify_ssl = disable_verify_ssl
self.disable_tcp_keepalive = disable_tcp_keepalive

self._is_in_cluster: Optional[bool] = None

# these params used for transition in KPO to K8s hook
# for a deprecation period we will continue to consider k8s settings from airflow.cfg
self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None
Expand Down Expand Up @@ -232,11 +234,13 @@ def get_conn(self) -> Any:

if in_cluster:
self.log.debug("loading kube_config from: in_cluster configuration")
self._is_in_cluster = True
config.load_incluster_config()
return client.ApiClient()

if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
self._is_in_cluster = False
config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
Expand All @@ -249,6 +253,7 @@ def get_conn(self) -> Any:
self.log.debug("loading kube_config from: connection kube_config")
temp_config.write(kubeconfig.encode())
temp_config.flush()
self._is_in_cluster = False
config.load_kube_config(
config_file=temp_config.name,
client_configuration=self.client_configuration,
Expand All @@ -265,14 +270,24 @@ def _get_default_client(self, *, cluster_context=None):
# in the default location
try:
config.load_incluster_config(client_configuration=self.client_configuration)
self._is_in_cluster = True
except ConfigException:
self.log.debug("loading kube_config from: default file")
self._is_in_cluster = False
config.load_kube_config(
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()

@property
def is_in_cluster(self):
    """
    Report whether this hook was configured through ``load_incluster_config``.

    ``_is_in_cluster`` starts out as ``None`` and is filled in while the
    Kubernetes configuration is loaded. When it has not been filled in yet,
    ``api_client`` is accessed first and whatever value is stored afterwards
    is returned (this can still be ``None`` if nothing recorded it).
    """
    if self._is_in_cluster is None:
        # Accessed only for its side effect: building the client is expected
        # to load the configuration, which records the in_cluster flag.
        self.api_client
    return self._is_in_cluster

@cached_property
def api_client(self) -> Any:
"""Cached Kubernetes API client"""
Expand Down
9 changes: 7 additions & 2 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client)

def get_hook(self):
    """
    Return ``self.hook``; deprecated alias kept for backward compatibility.

    Emits a ``DeprecationWarning`` pointing callers at the ``hook`` attribute.
    """
    warnings.warn(
        "get_hook is deprecated. Please use hook instead.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller of get_hook().
        stacklevel=2,
    )
    return self.hook

@cached_property
def hook(self) -> KubernetesHook:
hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
Expand All @@ -353,8 +358,7 @@ def get_hook(self):

@cached_property
def client(self) -> CoreV1Api:
    """
    Return the core-v1 API client exposed by this operator's hook.

    The hook is read from ``self.hook`` directly. Going through
    ``get_hook()`` would emit a ``DeprecationWarning`` for purely internal
    use, since that method only remains as a deprecated alias for ``hook``.
    """
    return self.hook.core_v1_client

def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
"""Returns an already-running pod for this task instance if one exists."""
Expand Down Expand Up @@ -580,6 +584,7 @@ def build_pod_request_obj(self, context=None):
pod.metadata.labels.update(
{
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': str(self.hook.is_in_cluster),
}
)
pod_mutation_hook(pod)
Expand Down
15 changes: 11 additions & 4 deletions kubernetes_tests/test_kubernetes_pod_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def setUp(self):
'foo': 'bar',
'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': 'False',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'dag_id': 'dag',
'task_id': ANY,
Expand Down Expand Up @@ -734,6 +735,7 @@ def test_pod_template_file_with_overrides_system(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -773,6 +775,7 @@ def test_pod_template_file_with_full_pod_spec(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -815,6 +818,7 @@ def test_full_pod_spec(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -882,9 +886,10 @@ def test_init_container(self):
@mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom")
@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS, new=MagicMock)
def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
@mock.patch(HOOK_CLASS)
def test_pod_template_file(self, hook_mock, await_pod_completion_mock, extract_xcom_mock):
# todo: This isn't really a system test
hook_mock.return_value.is_in_cluster = False
extract_xcom_mock.return_value = '{}'
path = sys.path[0] + '/tests/kubernetes/pod.yaml'
k = KubernetesPodOperator(
Expand Down Expand Up @@ -920,6 +925,7 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):
'metadata': {
'annotations': {},
'labels': {
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down Expand Up @@ -968,13 +974,14 @@ def test_pod_template_file(self, await_pod_completion_mock, extract_xcom_mock):

@mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion")
@mock.patch(f"{POD_MANAGER_CLASS}.create_pod", new=MagicMock)
@mock.patch(HOOK_CLASS, new=MagicMock)
def test_pod_priority_class_name(self, await_pod_completion_mock):
@mock.patch(HOOK_CLASS)
def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock):
"""
Test ability to assign priorityClassName to pod
todo: This isn't really a system test
"""
hook_mock.return_value.is_in_cluster = False

priority_class_name = "medium-test"
k = KubernetesPodOperator(
Expand Down
2 changes: 2 additions & 0 deletions kubernetes_tests/test_kubernetes_pod_operator_backcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def setUp(self):
'foo': 'bar',
'kubernetes_pod_operator': 'True',
'airflow_version': airflow_version.replace('+', '-'),
'airflow_kpo_in_cluster': 'False',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'dag_id': 'dag',
'task_id': 'task',
Expand Down Expand Up @@ -571,6 +572,7 @@ def test_pod_template_file_with_overrides_system(self):
'fizz': 'buzz',
'foo': 'bar',
'airflow_version': mock.ANY,
'airflow_kpo_in_cluster': 'False',
'dag_id': 'dag',
'run_id': 'manual__2016-01-01T0100000100-da4d1ce7b',
'kubernetes_pod_operator': 'True',
Expand Down
7 changes: 7 additions & 0 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def test_in_cluster_connection(
else:
mock_get_default_client.assert_called()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
if mock_get_default_client.called:
# get_default_client sets it, but it's mocked
assert kubernetes_hook.is_in_cluster is None
else:
assert kubernetes_hook.is_in_cluster is in_cluster_called

@pytest.mark.parametrize('in_cluster_fails', [True, False])
@patch("kubernetes.config.kube_config.KubeConfigLoader")
Expand All @@ -130,10 +135,12 @@ def test_get_default_client(
mock_incluster.assert_called_once()
mock_merger.assert_called_once_with(KUBE_CONFIG_PATH)
mock_loader.assert_called_once()
assert kubernetes_hook.is_in_cluster is False
else:
mock_incluster.assert_called_once()
mock_merger.assert_not_called()
mock_loader.assert_not_called()
assert kubernetes_hook.is_in_cluster is True
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
Expand Down
20 changes: 14 additions & 6 deletions tests/providers/cncf/kubernetes/operators/test_kubernetes_pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def run_pod(self, operator: KubernetesPodOperator, map_index: int = -1) -> k8s.V
remote_pod_mock = MagicMock()
remote_pod_mock.status.phase = 'Succeeded'
self.await_pod_mock.return_value = remote_pod_mock
if not isinstance(self.hook_mock.return_value.is_in_cluster, bool):
self.hook_mock.return_value.is_in_cluster = True
operator.execute(context=context)
return self.await_start_mock.call_args[1]['pod']

Expand Down Expand Up @@ -170,15 +172,17 @@ def test_envs_from_configmaps(
pod = self.run_pod(k)
assert pod.spec.containers[0].env_from == env_from

def test_labels(self):
@pytest.mark.parametrize(("in_cluster",), ([True], [False]))
def test_labels(self, in_cluster):
self.hook_mock.return_value.is_in_cluster = in_cluster
k = KubernetesPodOperator(
namespace="default",
image="ubuntu:16.04",
cmds=["bash", "-cx"],
labels={"foo": "bar"},
name="test",
task_id="task",
in_cluster=False,
in_cluster=in_cluster,
do_xcom_push=False,
)
pod = self.run_pod(k)
Expand All @@ -190,6 +194,7 @@ def test_labels(self):
"try_number": "1",
"airflow_version": mock.ANY,
"run_id": "test",
"airflow_kpo_in_cluster": str(in_cluster),
}

def test_labels_mapped(self):
Expand All @@ -209,6 +214,7 @@ def test_labels_mapped(self):
"airflow_version": mock.ANY,
"run_id": "test",
"map_index": "10",
"airflow_kpo_in_cluster": "True",
}

def test_find_pod_labels(self):
Expand Down Expand Up @@ -391,6 +397,7 @@ def test_full_pod_spec(self, randomize_name, pod_spec):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -429,6 +436,7 @@ def test_full_pod_spec_kwargs(self, randomize_name, pod_spec):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -499,6 +507,7 @@ def test_pod_template_file(self, randomize_name, pod_template_file):
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}
assert pod.metadata.namespace == "mynamespace"
Expand Down Expand Up @@ -568,6 +577,7 @@ def test_pod_template_file_kwargs_override(self, randomize_name, pod_template_fi
"task_id": "task",
"try_number": "1",
"airflow_version": mock.ANY,
"airflow_kpo_in_cluster": "True",
"run_id": "test",
}

Expand Down Expand Up @@ -877,13 +887,11 @@ def test_patch_core_settings(self, key, value, attr, patched_value):
# the hook attr should be None
op = KubernetesPodOperator(task_id='abc', name='hi')
self.hook_patch.stop()
hook = op.get_hook()
assert getattr(hook, attr) is None
assert getattr(op.hook, attr) is None
# now check behavior with a non-default value
with conf_vars({('kubernetes', key): value}):
op = KubernetesPodOperator(task_id='abc', name='hi')
hook = op.get_hook()
assert getattr(hook, attr) == patched_value
assert getattr(op.hook, attr) == patched_value


def test__suppress():
Expand Down

0 comments on commit 5326da4

Please sign in to comment.