diff --git a/tests/conftest.py b/tests/conftest.py index 3979efe5..5a38eef9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import datajoint as dj from packaging import version +from typing import Dict import os from os import environ, remove import minio @@ -56,12 +57,17 @@ def enable_filepath_feature(monkeypatch): @pytest.fixture(scope="session") -def connection_root_bare(): - connection = dj.Connection( - host=os.getenv("DJ_HOST"), - user=os.getenv("DJ_USER"), - password=os.getenv("DJ_PASS"), +def db_creds_root() -> Dict: + return dict( + host=os.getenv("DJ_HOST", "fakeservices.datajoint.io"), + user=os.getenv("DJ_USER", "root"), + password=os.getenv("DJ_PASS", "password"), ) + + +@pytest.fixture(scope="session") +def connection_root_bare(db_creds_root): + connection = dj.Connection(**db_creds_root) yield connection diff --git a/tests/schema_privileges.py b/tests/schema_privileges.py new file mode 100644 index 00000000..b53d6b26 --- /dev/null +++ b/tests/schema_privileges.py @@ -0,0 +1,34 @@ +import datajoint as dj +import inspect + + +class Parent(dj.Lookup): + definition = """ + id: int + """ + contents = [(1,)] + + +class Child(dj.Computed): + definition = """ + -> Parent + """ + + def make(self, key): + self.insert1(key) + + +class NoAccess(dj.Lookup): + definition = """ + string: varchar(10) + """ + + +class NoAccessAgain(dj.Manual): + definition = """ + -> NoAccess + """ + + +LOCALS_PRIV = {k: v for k, v in locals().items() if inspect.isclass(v)} +__all__ = list(LOCALS_PRIV) diff --git a/tests/test_privileges.py b/tests/test_privileges.py new file mode 100644 index 00000000..949dbc8a --- /dev/null +++ b/tests/test_privileges.py @@ -0,0 +1,118 @@ +import os +import pytest +import datajoint as dj +from . import schema, CONN_INFO_ROOT, PREFIX +from . import schema_privileges + +namespace = locals() + + +@pytest.fixture +def schema_priv(connection_test): + schema_priv = dj.Schema( + context=schema_privileges.LOCALS_PRIV, + connection=connection_test, + ) + schema_priv(schema_privileges.Parent) + schema_priv(schema_privileges.Child) + schema_priv(schema_privileges.NoAccess) + schema_priv(schema_privileges.NoAccessAgain) + yield schema_priv + if schema_priv.is_activated(): + schema_priv.drop() + + +@pytest.fixture +def connection_djsubset(connection_root, db_creds_root, schema_priv): + user = "djsubset" + conn = dj.conn(**db_creds_root, reset=True) + schema_priv.activate(f"{PREFIX}_schema_privileges") + conn.query( + f""" + CREATE USER IF NOT EXISTS '{user}'@'%%' + IDENTIFIED BY '{user}' + """ + ) + conn.query( + f""" + GRANT SELECT, INSERT, UPDATE, DELETE + ON `{PREFIX}_schema_privileges`.`#parent` + TO '{user}'@'%%' + """ + ) + conn.query( + f""" + GRANT SELECT, INSERT, UPDATE, DELETE + ON `{PREFIX}_schema_privileges`.`__child` + TO '{user}'@'%%' + """ + ) + conn_djsubset = dj.conn( + host=db_creds_root["host"], + user=user, + password=user, + reset=True, + ) + yield conn_djsubset + conn.query(f"DROP USER {user}") + conn.query(f"DROP DATABASE {PREFIX}_schema_privileges") + + +@pytest.fixture +def connection_djview(connection_root, db_creds_root): + """ + A connection with only SELECT privilege to djtest schemas. + Requires connection_root fixture so that `djview` user exists. + """ + connection = dj.conn( + host=db_creds_root["host"], + user="djview", + password="djview", + reset=True, + ) + yield connection + + +class TestUnprivileged: + def test_fail_create_schema(self, connection_djview): + """creating a schema with no CREATE privilege""" + with pytest.raises(dj.DataJointError): + return dj.Schema( + "forbidden_schema", namespace, connection=connection_djview + ) + + def test_insert_failure(self, connection_djview, schema_any): + unprivileged = dj.Schema( + schema_any.database, namespace, connection=connection_djview + ) + unprivileged.spawn_missing_classes() + assert issubclass(Language, dj.Lookup) and len(Language()) == len( + schema.Language() + ), "failed to spawn missing classes" + with pytest.raises(dj.DataJointError): + Language().insert1(("Socrates", "Greek")) + + def test_failure_to_create_table(self, connection_djview, schema_any): + unprivileged = dj.Schema( + schema_any.database, namespace, connection=connection_djview + ) + + @unprivileged + class Try(dj.Manual): + definition = """ # should not matter really + id : int + --- + value : float + """ + + with pytest.raises(dj.DataJointError): + Try().insert1((1, 1.5)) + + +class TestSubset: + def test_populate_activate(self, connection_djsubset, schema_priv): + schema_priv.activate( + f"{PREFIX}_schema_privileges", create_schema=True, create_tables=False + ) + schema_privileges.Child.populate() + assert schema_privileges.Child.progress(display=False)[0] == 0