Skip to content

Commit

Permalink
Merge pull request #8 from BritishGeologicalSurvey/params-from-env
Browse files Browse the repository at this point in the history
Params from env
  • Loading branch information
dvalters authored Oct 11, 2019
2 parents 59b13b2 + 6526c7f commit 76dd2a1
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ test:
-e TEST_MSSQL_PORT="${TEST_MSSQL_PORT}" \
-e TEST_MSSQL_DBNAME="${TEST_MSSQL_DBNAME}" \
-e TEST_MSSQL_PASSWORD="${TEST_MSSQL_PASSWORD}" \
"$CI_REGISTRY_IMAGE:test-runner" pytest --cov=etlhelper -vs test/
"$CI_REGISTRY_IMAGE:test-runner" pytest -rsx --cov=etlhelper -vs test/
package:
tags:
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,12 @@ from etlhelper import DbParams
ORACLEDB = DbParams(host="localhost", port=1521,
database="mydata",
username="oracle_user")
user="oracle_user")
```

DbParams objects can also be created from environment variables using the
`from_environment()` function.

#### Get rows

Connections are created by `connect` function.
Expand Down
2 changes: 1 addition & 1 deletion bin/run_tests_for_developer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ docker run \
--net=host \
--name=etlhelper-test-runner \
etlhelper-test-runner \
pytest -vs --cov=etlhelper --cov-report html --cov-report term test/
pytest -vs -rsx --cov=etlhelper --cov-report html --cov-report term test/

# Copy coverage files out of container to local if tests passed
if [ $? -eq 0 ]; then
Expand Down
6 changes: 3 additions & 3 deletions etlhelper/db_helpers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.sql_exceptions = (pyodbc.DatabaseError)
self._connect_func = pyodbc.connect
self.connect_exceptions = (pyodbc.DatabaseError, pyodbc.InterfaceError)
self.required_params = {'host', 'port', 'dbname', 'username', 'odbc_driver'}
self.required_params = {'host', 'port', 'dbname', 'user', 'odbc_driver'}
except ImportError:
print("The pyodc Python package could not be found.\n"
"Run: python -m pip install pyodbc")
Expand All @@ -30,14 +30,14 @@ def get_connection_string(self, db_params, password_variable):
# Prepare connection string
password = self.get_password(password_variable)
return (f'DRIVER={db_params.odbc_driver};SERVER=tcp:{db_params.host};PORT={db_params.port};'
f'DATABASE={db_params.dbname};UID={db_params.username};PWD={password}')
f'DATABASE={db_params.dbname};UID={db_params.user};PWD={password}')

def get_sqlalchemy_connection_string(self, db_params, password_variable):
"""
Returns connection string for sql alchemy type
"""
password = self.get_password(password_variable)
driver = db_params.odbc_driver.replace(" ", "+")
return (f'mssql+pyodbc://{db_params.username}:{password}@'
return (f'mssql+pyodbc://{db_params.user}:{password}@'
f'{db_params.host}:{db_params.port}/{db_params.dbname}?'
f'driver={driver}')
6 changes: 3 additions & 3 deletions etlhelper/db_helpers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.sql_exceptions = (cx_Oracle.DatabaseError)
self._connect_func = cx_Oracle.connect
self.connect_exceptions = (cx_Oracle.DatabaseError)
self.required_params = {'host', 'port', 'dbname', 'username'}
self.required_params = {'host', 'port', 'dbname', 'user'}
except ImportError:
print("The cxOracle drivers were not found. See setup guide for more information.")

Expand All @@ -28,7 +28,7 @@ def get_connection_string(self, db_params, password_variable):
"""
# Prepare connection string
password = self.get_password(password_variable)
return (f'{db_params.username}/{password}@'
return (f'{db_params.user}/{password}@'
f'{db_params.host}:{db_params.port}/{db_params.dbname}')

def get_sqlalchemy_connection_string(self, db_params, password_variable):
Expand All @@ -37,5 +37,5 @@ def get_sqlalchemy_connection_string(self, db_params, password_variable):
"""
password = self.get_password(password_variable)

return (f'oracle://{db_params.username}:{password}@'
return (f'oracle://{db_params.user}:{password}@'
f'{db_params.host}:{db_params.port}/{db_params.dbname}')
6 changes: 3 additions & 3 deletions etlhelper/db_helpers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self):
self.sql_exceptions = (psycopg2.ProgrammingError)
self._connect_func = psycopg2.connect
self.connect_exceptions = (psycopg2.OperationalError)
self.required_params = {'host', 'port', 'dbname', 'username'}
self.required_params = {'host', 'port', 'dbname', 'user'}
except ImportError:
print("The PostgreSQL python libraries could not be found.\n"
"Run: python -m pip install psycopg2-binary")
Expand All @@ -31,14 +31,14 @@ def get_connection_string(self, db_params, password_variable):
password = self.get_password(password_variable)
return (f'host={db_params.host} port={db_params.port} '
f'dbname={db_params.dbname} '
f'user={db_params.username} password={password}')
f'user={db_params.user} password={password}')

def get_sqlalchemy_connection_string(self, db_params, password_variable):
"""
Returns connection string for sql alchemy
"""
password = self.get_password(password_variable)
return (f'postgresql://{db_params.username}:{password}@'
return (f'postgresql://{db_params.user}:{password}@'
f'{db_params.host}:{db_params.port}/{db_params.dbname}')

@staticmethod
Expand Down
28 changes: 16 additions & 12 deletions etlhelper/db_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DbParams(dict):
here: https://amir.rachum.com/blog/2016/10/05/python-dynamic-attributes/
"""

def __init__(self, dbtype=None, **kwargs):
def __init__(self, dbtype='dbtype not set', **kwargs):
kwargs.update(dbtype=dbtype.upper())
super().__init__(kwargs)
self.validate_params()
Expand Down Expand Up @@ -45,25 +45,29 @@ def validate_params(self):
msg = f'{self.dbtype} not in valid types ({DB_HELPER_FACTORY.helpers.keys()})'
raise ETLHelperDbParamsError(msg)

if (given ^ required_params) & required_params:
msg = f'Parameter not set. Required parameters are {required_params}'
unset_params = (given ^ required_params) & required_params
if unset_params:
msg = f'{unset_params} not set. Required parameters are {required_params}'
raise ETLHelperDbParamsError(msg)

@classmethod
def from_environment(cls, prefix='ETLHelper_'):
"""
Create DbParams object from parameters specified by environment
variables e.g. ETLHelper_DBTYPE, ETLHelper_HOST, ETLHelper_PORT, etc.
variables e.g. ETLHelper_dbtype, ETLHelper_host, ETLHelper_port, etc.
:param prefix: str, prefix to environment variable names
"""
return cls(
dbtype=os.getenv(f'{prefix}DBTYPE'),
odbc_driver=os.getenv(f'{prefix}DBDRIVER'),
host=os.getenv(f'{prefix}HOST'),
port=os.getenv(f'{prefix}PORT'),
dbname=os.getenv(f'{prefix}DBNAME'),
username=os.getenv(f'{prefix}USER'),
)
dbparams_keys = [key for key in os.environ if key.startswith(prefix)]
dbparams_from_env = {key.replace(prefix, '').lower(): os.environ[key]
for key in dbparams_keys}

# Ensure dbtype has been set
dbtype_var = f'{prefix}dbtype'
if 'dbtype' not in dbparams_from_env:
msg = f"{dbtype_var} environment variable is not set"
raise ETLHelperDbParamsError(msg)

return cls(**dbparams_from_env)

def __repr__(self):
key_val_str = ", ".join([f"{key}='{self[key]}'" for key in self.keys()])
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
host='localhost',
port=5432,
dbname='etlhelper',
username='etlhelper_user')
user='etlhelper_user')


@pytest.fixture('module')
Expand Down
10 changes: 3 additions & 7 deletions test/integration/db/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
import pytest

from etlhelper import connect, get_rows, copy_rows, DbParams
from etlhelper.exceptions import ETLHelperError, ETLHelperConnectionError
from etlhelper.exceptions import ETLHelperConnectionError
from test.conftest import db_is_unreachable

# Skip these tests if database is unreachable
try:
ORADB = DbParams.from_environment(prefix='TEST_ORACLE_')
if db_is_unreachable(ORADB.host, ORADB.port):
raise ETLHelperConnectionError()
except (ETLHelperError, TypeError):
# TypeError thrown if host not set, others subclass ETLHelperError
ORADB = DbParams.from_environment(prefix='TEST_ORACLE_')
if db_is_unreachable(ORADB.host, ORADB.port):
pytest.skip('Oracle test database is unreachable', allow_module_level=True)


Expand Down
14 changes: 7 additions & 7 deletions test/unit/test_db_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@pytest.fixture()
def params():
return DbParams(dbtype='ORACLE', odbc_driver='test driver', host='testhost',
port=1521, dbname='testdb', username='testuser')
port=1521, dbname='testdb', user='testuser')


def test_oracle_sql_exceptions():
Expand All @@ -34,7 +34,7 @@ def test_oracle_connect(monkeypatch):
# TODO: Fix DbParams class to take driver as init input.
db_params = DbParams(dbtype='ORACLE',
host='server', port='1521', dbname='testdb',
username='testuser')
user='testuser')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
expected_conn_str = 'testuser/mypassword@server:1521/testdb'

Expand All @@ -52,7 +52,7 @@ def test_oracle_connect(monkeypatch):
def test_sqlserver_connect(monkeypatch):
db_params = DbParams(dbtype='MSSQL',
host='server', port='1521', dbname='testdb',
username='testuser', odbc_driver='test driver')
user='testuser', odbc_driver='test driver')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
expected_conn_str = ('DRIVER=test driver;SERVER=tcp:server;PORT=1521;'
'DATABASE=testdb;UID=testuser;PWD=mypassword')
Expand All @@ -71,7 +71,7 @@ def test_sqlserver_connect(monkeypatch):
def test_postgres_connect(monkeypatch):
db_params = DbParams(dbtype='PG',
host='server', port='1521', dbname='testdb',
username='testuser', odbc_driver='test driver')
user='testuser', odbc_driver='test driver')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
expected_conn_str = 'host=server port=1521 dbname=testdb user=testuser password=mypassword'
mock_connect = Mock()
Expand All @@ -89,7 +89,7 @@ def test_postgres_connect(monkeypatch):
def test_oracle_sqlalchemy_conn_string(monkeypatch):
db_params = DbParams(dbtype='ORACLE',
host='server', port='1521', dbname='testdb',
username='testuser')
user='testuser')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
helper = OracleDbHelper()
conn_str = helper.get_sqlalchemy_connection_string(db_params, 'DB_PASSWORD')
Expand All @@ -101,7 +101,7 @@ def test_oracle_sqlalchemy_conn_string(monkeypatch):
def test_sqlserver_sqlalchemy_connect(monkeypatch):
db_params = DbParams(dbtype='MSSQL',
host='server', port='1521', dbname='testdb',
username='testuser', odbc_driver='test driver')
user='testuser', odbc_driver='test driver')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
helper = SqlServerDbHelper()
conn_str = helper.get_sqlalchemy_connection_string(db_params, 'DB_PASSWORD')
Expand All @@ -113,7 +113,7 @@ def test_sqlserver_sqlalchemy_connect(monkeypatch):
def test_postgres_sqlalchemy_connect(monkeypatch):
db_params = DbParams(dbtype='PG',
host='server', port='1521', dbname='testdb',
username='testuser')
user='testuser')
monkeypatch.setenv('DB_PASSWORD', 'mypassword')
helper = PostgresDbHelper()
conn_str = helper.get_sqlalchemy_connection_string(db_params, 'DB_PASSWORD')
Expand Down
32 changes: 23 additions & 9 deletions test/unit/test_db_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ def test_db_params_repr():
host='localhost',
port=5432,
dbname='etlhelper',
username='etlhelper_user')
user='etlhelper_user')
result = str(test_params)
expected = ("DbParams(host='localhost', "
"port='5432', dbname='etlhelper', username='etlhelper_user', dbtype='PG')")
"port='5432', dbname='etlhelper', "
"user='etlhelper_user', dbtype='PG')")
assert result == expected


Expand All @@ -31,18 +32,31 @@ def test_db_params_from_environment(monkeypatch):
Test capturing db params from environment settings.
"""
# Arrange
monkeypatch.setenv('TEST_DBTYPE', 'ORACLE')
monkeypatch.setenv('TEST_HOST', 'test.host')
monkeypatch.setenv('TEST_PORT', '1234')
monkeypatch.setenv('TEST_DBNAME', 'testdb')
monkeypatch.setenv('TEST_USER', 'testuser')
monkeypatch.setenv('TEST_DB_PARAMS_ENV_DBTYPE', 'ORACLE')
monkeypatch.setenv('TEST_DB_PARAMS_ENV_HOST', 'test.host')
monkeypatch.setenv('TEST_DB_PARAMS_ENV_PORT', '1234')
monkeypatch.setenv('TEST_DB_PARAMS_ENV_DBNAME', 'testdb')
monkeypatch.setenv('TEST_DB_PARAMS_ENV_USER', 'testuser')

# Act
db_params = DbParams.from_environment(prefix='TEST_')
db_params = DbParams.from_environment(prefix='TEST_DB_PARAMS_ENV_')

# Assert
db_params.dbtype = 'ORACLE'
db_params.host = 'test.host'
db_params.port = '1234'
db_params.dbname = 'testdb'
db_params.username = 'testuser'
db_params.user = 'testuser'


def test_db_params_from_environment_not_set(monkeypatch):
"""
Test missing db params from environment settings.
"""
# Arrange
monkeypatch.delenv('TEST_DBTYPE', raising=False)

# Act
with pytest.raises(ETLHelperDbParamsError,
match=r".*environment variable is not set.*"):
DbParams.from_environment(prefix='TEST_')

0 comments on commit 76dd2a1

Please sign in to comment.