Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert cert default, add require_certificate_validation Behavior Flag #447

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20241120-191809.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Revert cert default to False. Add require_certificate_validation Behavior Flag
time: 2024-11-20T19:18:09.725288+01:00
custom:
Author: damian3031
Issue: ""
PR: "447"
29 changes: 21 additions & 8 deletions dbt/adapters/trino/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from datetime import date, datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

import sqlparse
import trino
Expand Down Expand Up @@ -99,7 +99,7 @@ class TrinoNoneCredentials(TrinoCredentials):
user: str
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_scheme: HttpScheme = HttpScheme.HTTP
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
Expand All @@ -124,7 +124,7 @@ class TrinoCertificateCredentials(TrinoCredentials):
user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
Expand Down Expand Up @@ -154,7 +154,7 @@ class TrinoLdapCredentials(TrinoCredentials):
impersonation_user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
Expand Down Expand Up @@ -185,7 +185,7 @@ class TrinoKerberosCredentials(TrinoCredentials):
krb5_config: Optional[str] = None
service_name: Optional[str] = "trino"
mutual_authentication: Optional[bool] = False
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
force_preemptive: Optional[bool] = False
hostname_override: Optional[str] = None
Expand Down Expand Up @@ -227,7 +227,7 @@ class TrinoJwtCredentials(TrinoCredentials):
user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
Expand All @@ -253,7 +253,7 @@ class TrinoOauthCredentials(TrinoCredentials):
user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
Expand Down Expand Up @@ -282,7 +282,7 @@ class TrinoOauthConsoleCredentials(TrinoCredentials):
user: Optional[str] = None
client_tags: Optional[List[str]] = None
roles: Optional[Dict[str, str]] = None
cert: Optional[Union[str, bool]] = True
cert: Optional[str] = None
http_headers: Optional[Dict[str, str]] = None
session_properties: Dict[str, Any] = field(default_factory=dict)
prepared_statements_enabled: bool = PREPARED_STATEMENTS_ENABLED_DEFAULT
Expand Down Expand Up @@ -423,6 +423,12 @@ class TrinoAdapterResponse(AdapterResponse):

class TrinoConnectionManager(SQLConnectionManager):
TYPE = "trino"
behavior_flags = None

def __init__(self, profile, mp_context, behavior_flags=None) -> None:
super().__init__(profile, mp_context)

TrinoConnectionManager.behavior_flags = behavior_flags

@contextmanager
def exception_handler(self, sql):
Expand Down Expand Up @@ -465,6 +471,13 @@ def open(cls, connection):

credentials = connection.credentials

# set default `cert` value, according to
# require_certificate_validation behavior flag
if credentials.cert is None:
req_cert_val_flag = cls.behavior_flags.require_certificate_validation.setting
if req_cert_val_flag:
credentials.cert = True

# it's impossible for trino to fail here as 'connections' are actually
# just cursor factories.
trino_conn = trino.dbapi.connect(
Expand Down
21 changes: 21 additions & 0 deletions dbt/adapters/trino/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Support,
)
from dbt.adapters.sql import SQLAdapter
from dbt_common.behavior_flags import BehaviorFlag
from dbt_common.contracts.constraints import ConstraintType
from dbt_common.exceptions import DbtDatabaseError

Expand Down Expand Up @@ -47,6 +48,26 @@ class TrinoAdapter(SQLAdapter):
}
)

def __init__(self, config, mp_context) -> None:
super().__init__(config, mp_context)
self.connections = self.ConnectionManager(config, mp_context, self.behavior)

@property
def _behavior_flags(self) -> list[BehaviorFlag]:
return [
{ # type: ignore
"name": "require_certificate_validation",
"default": False,
"description": (
"SSL certificate validation is disabled by default. "
"It is legacy behavior which will be changed in future releases. "
"It is strongly advised to enable `require_certificate_validation` flag "
"or explicitly set `cert` configuration to `True` for security reasons. "
"You may receive an error after that if your SSL setup is incorrect."
),
}
]

@classmethod
def date_function(cls):
return "datenow()"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import warnings

import pytest
from dbt.tests.util import run_dbt, run_dbt_and_capture
from urllib3.exceptions import InsecureRequestWarning


class TestRequireCertificateValidationDefault:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is None

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if any InsecureRequestWarning was raised
assert any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"


class TestRequireCertificateValidationFalse:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"require_certificate_validation": False}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is None

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if any InsecureRequestWarning was raised
assert any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"


class TestRequireCertificateValidationTrue:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"require_certificate_validation": True}}

def test_cert_default_value(self, project):
assert project.adapter.connections.profile.credentials.cert is True

def test_require_certificate_validation_logs(self, project):
dbt_args = ["show", "--inline", "select 1"]
_, logs = run_dbt_and_capture(dbt_args)
assert "It is strongly advised to enable `require_certificate_validation` flag" not in logs

@pytest.mark.skip_profile("trino_starburst")
def test_require_certificate_validation_insecure_request_warning(self, project):
with warnings.catch_warnings(record=True) as w:
dbt_args = ["show", "--inline", "select 1"]
run_dbt(dbt_args)

# Check if not any InsecureRequestWarning was raised
assert not any(
issubclass(warning.category, InsecureRequestWarning) for warning in w
), "InsecureRequestWarning was not raised"
18 changes: 0 additions & 18 deletions tests/functional/adapter/test_insecure_warnings.py

This file was deleted.

Loading