Skip to content

Commit

Permalink
NODE service get_table_data task is only allowed for published tables.
Browse files Browse the repository at this point in the history
https://team-1617704806227.atlassian.net/browse/MIP-825
The 'get_table_data' task uses the db public user to fetch data so it only works for published data.
Added 'use_public_user' parameter in the monetdb_facade method, so that the public user is used for the queries.
Added 'get_table_data_from_db' method in the standalone tests (conftest.py) and changed all tests to use this method instead of calling the NODE task.
Refactored many standalone tests to use ONLY the task they are testing, instead of using multiple NODE tasks to ensure the proper result. Instead direct DB queries are made.
Added on/off switch named 'protect_local_data'.
When the switch is on, the 'get_table_data' will use the public user, otherwise not.
The GLOBALNODE should always have this variable switched off.
In the LOCALNODE this variable should be off ONLY in testing scenarios.
  • Loading branch information
ThanKarab committed Sep 21, 2023
1 parent 746aab1 commit 47a0001
Show file tree
Hide file tree
Showing 31 changed files with 363 additions and 574 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
[privacy]
minimum_row_count = 10
protect_local_data = false
[cleanup]
nodes_cleanup_interval=10
Expand Down
7 changes: 6 additions & 1 deletion exareme2/node/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ NODE_IDENTIFIER=globalnode
NODE_ROLE=GLOBALNODE
LOG_LEVEL=INFO
FRAMEWORK_LOG_LEVEL=INFO
PROTECT_LOCAL_DATA=false
RABBITMQ_IP=172.17.0.1
RABBITMQ_PORT=5670
MONETDB_IP=172.17.0.1
MONETDB_PORT=50000
MONETDB_PASSWORD=executor
MONETDB_LOCAL_USERNAME=executor
MONETDB_LOCAL_PASSWORD=executor
MONETDB_PUBLIC_USERNAME=guest
MONETDB_PUBLIC_PASSWORD=guest
SMPC_ENABLED=false
```

Then start the container with:
Expand Down
1 change: 1 addition & 0 deletions exareme2/node/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ framework_log_level = "$FRAMEWORK_LOG_LEVEL"

[privacy]
minimum_row_count = 10
protect_local_data = "$PROTECT_LOCAL_DATA"

[celery]
worker_concurrency = 16
Expand Down
23 changes: 12 additions & 11 deletions exareme2/node/monetdb_interface/common_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from exareme2 import DType
from exareme2.exceptions import TablesNotFound
from exareme2.node import config as node_config
from exareme2.node.monetdb_interface.guard import is_datamodel
from exareme2.node.monetdb_interface.guard import sql_injection_guard
from exareme2.node.monetdb_interface.monet_db_facade import db_execute_and_fetchall
from exareme2.node.monetdb_interface.monet_db_facade import db_execute_query
from exareme2.node_info_DTOs import NodeRole
from exareme2.node_tasks_DTOs import ColumnInfo
from exareme2.node_tasks_DTOs import CommonDataElement
from exareme2.node_tasks_DTOs import CommonDataElements
Expand Down Expand Up @@ -170,33 +172,32 @@ def get_table_names(table_type: TableType, context_id: str) -> List[str]:
return [table[0] for table in table_names]


@sql_injection_guard(table_name=str.isidentifier)
def get_table_data(table_name: str) -> List[ColumnData]:
@sql_injection_guard(
table_name=str.isidentifier,
use_public_user=None,
)
def get_table_data(table_name: str, use_public_user: bool = True) -> List[ColumnData]:
"""
Returns a list of columns data which will contain name, type and the data of the specific column.
Parameters
----------
table_name : str
The name of the table
use_public_user : bool
Will the public or local user be used to access the data?
Returns
------
List[ColumnData]
A list of column data
"""

schema = get_table_schema(table_name)
# TODO: blocked by https://team-1617704806227.atlassian.net/browse/MIP-133 .
# Retrieving the data should be a simple select.
# row_stored_data = db_execute_and_fetchall(f"SELECT * FROM {table_name}")

local_username = node_config.monetdb.local_username
row_stored_data = db_execute_and_fetchall(
f"""
SELECT {table_name}.*
FROM {table_name}
INNER JOIN tables ON tables.name = '{table_name}'
WHERE tables.system=false
"""
f"SELECT * FROM {local_username}.{table_name}", use_public_user=use_public_user
)

column_stored_data = list(zip(*row_stored_data))
Expand Down
41 changes: 30 additions & 11 deletions exareme2/node/monetdb_interface/monet_db_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,34 @@
class _DBExecutionDTO(BaseModel):
query: str
parameters: Optional[List[Any]]
use_public_user: bool = False
timeout: Optional[int]

class Config:
allow_mutation = False


def db_execute_and_fetchall(query: str, parameters=None) -> List:
def db_execute_and_fetchall(
query: str, parameters=None, use_public_user: bool = False
) -> List:
query_execution_timeout = node_config.celery.tasks_timeout
db_execution_dto = _DBExecutionDTO(
query=query, parameters=parameters, timeout=query_execution_timeout
query=query,
parameters=parameters,
use_public_user=use_public_user,
timeout=query_execution_timeout,
)
return _execute_and_fetchall(db_execution_dto=db_execution_dto)


def db_execute_query(query: str, parameters=None):
def db_execute_query(query: str, parameters=None, use_public_user: bool = False):
query_execution_timeout = node_config.celery.tasks_timeout
query = convert_to_idempotent(query)
db_execution_dto = _DBExecutionDTO(
query=query, parameters=parameters, timeout=query_execution_timeout
query=query,
parameters=parameters,
use_public_user=use_public_user,
timeout=query_execution_timeout,
)
_execute(db_execution_dto=db_execution_dto, lock=query_execution_lock)

Expand All @@ -62,21 +71,28 @@ def db_execute_udf(query: str, parameters=None):

# Connection Pool disabled due to bugs in maintaining connections
@contextmanager
def _connection():
def _connection(use_public_user: bool):
if use_public_user:
username = node_config.monetdb.public_username
password = node_config.monetdb.public_password
else:
username = node_config.monetdb.local_username
password = node_config.monetdb.local_password

conn = pymonetdb.connect(
hostname=node_config.monetdb.ip,
port=node_config.monetdb.port,
username=node_config.monetdb.local_username,
password=node_config.monetdb.local_password,
username=username,
password=password,
database=node_config.monetdb.database,
)
yield conn
conn.close()


@contextmanager
def _cursor(commit=False):
with _connection() as conn:
def _cursor(use_public_user: bool, commit: bool = False):
with _connection(use_public_user) as conn:
cur = conn.cursor()
yield cur
cur.close()
Expand Down Expand Up @@ -163,7 +179,7 @@ def _execute_and_fetchall(db_execution_dto) -> List:
Used to execute only select queries that return a result.
'parameters' option to provide the functionality of bind-parameters.
"""
with _cursor() as cur:
with _cursor(use_public_user=db_execution_dto.use_public_user) as cur:
cur.execute(db_execution_dto.query, db_execution_dto.parameters)
result = cur.fetchall()
return result
Expand Down Expand Up @@ -249,7 +265,10 @@ def _execute(db_execution_dto: _DBExecutionDTO, lock):

try:
with _lock(lock, db_execution_dto.timeout):
with _cursor(commit=True) as cur:
with _cursor(
use_public_user=db_execution_dto.use_public_user,
commit=True,
) as cur:
cur.execute(db_execution_dto.query, db_execution_dto.parameters)
except TimeoutError:
error_msg = f"""
Expand Down
7 changes: 6 additions & 1 deletion exareme2/node/tasks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def get_table_data(request_id: str, table_name: str) -> str:
str(TableData)
An object of TableData in a jsonified format
"""
columns = common_actions.get_table_data(table_name)
# If the public user is used, its ensured that the table won't hold private data.
# Tables are published to the public DB user when they are meant for sending to other nodes.
# The "protect_local_data" config allows for turning this logic off in testing scenarios.
use_public_user = True if node_config.privacy.protect_local_data else False

columns = common_actions.get_table_data(table_name, use_public_user)

return TableData(name=table_name, columns=columns).json()

Expand Down
6 changes: 4 additions & 2 deletions exareme2/node/tasks/smpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def validate_smpc_templates_match(
Nothing, only throws exception if they don't match.
"""

templates = _get_smpc_values_from_table_data(get_table_data(table_name))
templates = _get_smpc_values_from_table_data(get_table_data(table_name, False))
first_template, *_ = templates
for template in templates[1:]:
if template != first_template:
Expand Down Expand Up @@ -73,7 +73,9 @@ def load_data_to_smpc_client(request_id: str, table_name: str, jobid: str) -> st
"load_data_to_smpc_client is allowed only for a LOCALNODE."
)

smpc_values, *_ = _get_smpc_values_from_table_data(get_table_data(table_name))
smpc_values, *_ = _get_smpc_values_from_table_data(
get_table_data(table_name, False)
)

smpc_cluster.load_data_to_smpc_client(
node_config.smpc.client_address, jobid, smpc_values
Expand Down
2 changes: 2 additions & 0 deletions kubernetes/templates/mipengine-globalnode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ spec:
value: {{ .Values.log_level }}
- name: FRAMEWORK_LOG_LEVEL
value: {{ .Values.framework_log_level }}
- name: PROTECT_LOCAL_DATA
value: "false" # The GLOBALNODE does not need to secure its data, since they are not private.
- name: CELERY_TASKS_TIMEOUT
value: {{ quote .Values.controller.celery_tasks_timeout }}
- name: RABBITMQ_IP
Expand Down
2 changes: 2 additions & 0 deletions kubernetes/templates/mipengine-localnode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ spec:
value: {{ .Values.log_level }}
- name: FRAMEWORK_LOG_LEVEL
value: {{ .Values.framework_log_level }}
- name: PROTECT_LOCAL_DATA
value: "true"
- name: CELERY_TASKS_TIMEOUT
value: {{ quote .Values.controller.celery_tasks_timeout }}
- name: RABBITMQ_IP
Expand Down
10 changes: 10 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ def create_configs(c):
node_config["privacy"]["minimum_row_count"] = deployment_config["privacy"][
"minimum_row_count"
]
if node["role"] == "GLOBALNODE":
node_config["privacy"]["protect_local_data"] = False
else:
node_config["privacy"]["protect_local_data"] = deployment_config["privacy"][
"protect_local_data"
]

node_config["smpc"]["enabled"] = deployment_config["smpc"]["enabled"]
if node_config["smpc"]["enabled"]:
Expand All @@ -163,11 +169,15 @@ def create_configs(c):
node_config["smpc"][
"coordinator_address"
] = f"http://{deployment_config['ip']}:{SMPC_COORDINATOR_PORT}"
node_config["privacy"]["protect_local_data"] = False
else:
node_config["smpc"]["client_id"] = node["id"]
node_config["smpc"][
"client_address"
] = f"http://{deployment_config['ip']}:{node['smpc_client_port']}"
node_config["privacy"]["protect_local_data"] = deployment_config[
"privacy"
]["protect_local_data"]

node_config_file = NODES_CONFIG_DIR / f"{node['id']}.toml"
with open(node_config_file, "w+") as fp:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ celery_run_udf_task_timeout = 300

[privacy]
minimum_row_count = 1
protect_local_data = false

[cleanup]
nodes_cleanup_interval=30
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ celery_run_udf_task_timeout = 120

[privacy]
minimum_row_count = 1
protect_local_data = false

[cleanup]
nodes_cleanup_interval=30
Expand Down
7 changes: 7 additions & 0 deletions tests/standalone_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,13 @@ def insert_data_to_db(
db_cursor.execute(sql_clause, list(chain(*table_values)))


def get_table_data_from_db(
db_cursor,
table_name: str,
):
return db_cursor.execute(f"SELECT * FROM {table_name};").fetchall()


def _clean_db(cursor):
class TableType(enum.Enum):
NORMAL = 0
Expand Down
1 change: 0 additions & 1 deletion tests/standalone_tests/test_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def run_udf_on_local_nodes(self, func, keyword_args, *args, **kwargs):
run_udf_on_global_node = run_udf_on_local_nodes


@pytest.mark.slow
class TestLinearRegression:
@pytest.mark.parametrize("nrows", range(10, 100, 10))
@pytest.mark.parametrize("ncols", range(1, 20))
Expand Down
1 change: 0 additions & 1 deletion tests/standalone_tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
np.random.seed(1)


@pytest.mark.slow
class TestLogisticRegression:
@pytest.mark.parametrize("nrows", range(10, 100, 10))
@pytest.mark.parametrize("ncols", range(1, 20))
Expand Down
40 changes: 19 additions & 21 deletions tests/standalone_tests/test_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from exareme2.exceptions import IncompatibleSchemasMergeException
from exareme2.exceptions import TablesNotFound
from exareme2.node_tasks_DTOs import ColumnInfo
from exareme2.node_tasks_DTOs import TableData
from exareme2.node_tasks_DTOs import TableInfo
from exareme2.node_tasks_DTOs import TableSchema
from exareme2.node_tasks_DTOs import TableType
Expand All @@ -15,14 +14,14 @@
from tests.standalone_tests.conftest import MONETDB_LOCALNODE2_PORT
from tests.standalone_tests.conftest import TASKS_TIMEOUT
from tests.standalone_tests.conftest import create_table_in_db
from tests.standalone_tests.conftest import get_table_data_from_db
from tests.standalone_tests.conftest import insert_data_to_db
from tests.standalone_tests.nodes_communication_helper import get_celery_task_signature
from tests.standalone_tests.std_output_logger import StdOutputLogger

create_remote_task_signature = get_celery_task_signature("create_remote_table")
create_merge_table_task_signature = get_celery_task_signature("create_merge_table")
get_merge_tables_task_signature = get_celery_task_signature("get_merge_tables")
get_table_data_task_signature = get_celery_task_signature("get_table_data")


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -210,6 +209,7 @@ def test_create_merge_table_on_top_of_remote_tables(
use_localnode2_database,
globalnode_node_service,
globalnode_celery_app,
globalnode_db_cursor,
use_globalnode_database,
):
"""
Expand All @@ -224,7 +224,7 @@ def test_create_merge_table_on_top_of_remote_tables(
ColumnInfo(name="col3", dtype=DType.STR),
]
)
table_values = [[1, 0.1, "test1"], [2, 0.2, "test2"], [3, 0.3, "test3"]]
initial_table_values = [[1, 0.1, "test1"], [2, 0.2, "test2"], [3, 0.3, "test3"]]
localnode1_tableinfo = TableInfo(
name=f"normal_testlocalnode1_{context_id}",
schema_=table_schema,
Expand All @@ -247,8 +247,12 @@ def test_create_merge_table_on_top_of_remote_tables(
localnode2_tableinfo.schema_,
True,
)
insert_data_to_db(localnode1_tableinfo.name, table_values, localnode1_db_cursor)
insert_data_to_db(localnode2_tableinfo.name, table_values, localnode2_db_cursor)
insert_data_to_db(
localnode1_tableinfo.name, initial_table_values, localnode1_db_cursor
)
insert_data_to_db(
localnode2_tableinfo.name, initial_table_values, localnode2_db_cursor
)

# Create remote tables
local_node_1_monetdb_sock_address = f"{str(COMMON_IP)}:{MONETDB_LOCALNODE1_PORT}"
Expand Down Expand Up @@ -305,21 +309,15 @@ def test_create_merge_table_on_top_of_remote_tables(
)
)

# Validate merge table row count
async_result = globalnode_celery_app.queue_task(
task_signature=get_table_data_task_signature,
logger=StdOutputLogger(),
request_id=request_id,
table_name=merge_table_info.name,
# Validate merge tables contains both remote tables' values
merge_table_values = get_table_data_from_db(
globalnode_db_cursor, merge_table_info.name
)

table_data_json = globalnode_celery_app.get_result(
async_result=async_result,
logger=StdOutputLogger(),
timeout=TASKS_TIMEOUT,
)
table_data = TableData.parse_raw(table_data_json)
column_count = len(table_data.columns)
assert column_count == len(table_values)
row_count = len(table_data.columns[0].data)
assert row_count == len(table_values[0] * 2)
column_count = len(initial_table_values[0])
assert column_count == len(merge_table_values[0])

row_count = len(initial_table_values)
assert row_count * 2 == len(
merge_table_values
) # The rows are doubled since we have 2 localnodes with N rows each.
Loading

0 comments on commit 47a0001

Please sign in to comment.