Skip to content

Commit

Permalink
Dependencies: Update to sqlalchemy~=2.0
Browse files Browse the repository at this point in the history
A number of minor changes were required for the update:

* Queries that use `order_by` now need to include the property that is
  being ordered on in the list of projections.
* The `Session.bulk_update_mappings` and `Session.bulk_insert_mappings`
  are replaced by using `Session.execute` with the `update` and `insert`
  methods.
* The `sqlalchemy-utils` dependency is no longer used as well as the
  `tests/storage/psql_dos/test_utils.py` file that used it.
* The `future=True` is removed from the engine creation. This was a
  temporary flag to enable v2.0 compatibility with v1.4.
* Test of schema equivalence for export archives needed to be updated
  since the casting of `UUID` columns for PostgreSQL changed.
* Remove the `sphinx-sqlalchemy` dependency since it is not compatible
  with `sqlalchemy~=2.0`. The documentation that relied on it to show
  the database models is temporarily commented out.
  • Loading branch information
sphuber committed Nov 8, 2023
1 parent 6123f52 commit 99960c6
Show file tree
Hide file tree
Showing 33 changed files with 235 additions and 277 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ jobs:
- name: Run test suite
env:
AIIDA_WARN_v3: 1
SQLALCHEMY_WARN_20: 1
run:
.github/workflows/tests.sh

Expand Down
1 change: 0 additions & 1 deletion .github/workflows/test-install.yml
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ jobs:
- name: Run test suite
env:
AIIDA_WARN_v3: 1
SQLALCHEMY_WARN_20: 1
run:
.github/workflows/tests.sh

Expand Down
15 changes: 14 additions & 1 deletion aiida/cmdline/commands/cmd_data/cmd_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def query(datatype, project, past_days, group_pks, all_users):
n_days_ago = now - datetime.timedelta(days=past_days)
data_filters.update({'ctime': {'>=': n_days_ago}})

# Since the query results are sorted on ``ctime`` it has to be projected on. If it doesn't exist, append it to the
# projections, but make sure to pop it again from the final results since it wasn't part of the original projections
if 'ctime' in project:
pop_ctime = False
else:
project.append('ctime')
pop_ctime = True

qbl.append(datatype, tag='data', with_user='creator', filters=data_filters, project=project)

# If there is a group restriction
Expand All @@ -63,7 +71,12 @@ def query(datatype, project, past_days, group_pks, all_users):
qbl.order_by({datatype: {'ctime': 'asc'}})

object_list = qbl.distinct()
return object_list.all()
results = object_list.all()

if pop_ctime:
return [element[:-1] for element in results]

return results


# pylint: disable=unused-argument,too-many-arguments
Expand Down
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/array/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,14 +1840,14 @@ def get_bands_and_parents_structure(args, backend=None):
tag='sdata',
with_descendants='bdata',
# We don't care about the creator of StructureData
project=['id', 'attributes.kinds', 'attributes.sites']
project=['id', 'attributes.kinds', 'attributes.sites', 'ctime']
)

q_build.order_by({orm.StructureData: {'ctime': 'desc'}})

structure_dict = {}
list_data = q_build.distinct().all()
for bid, _, _, _, akinds, asites in list_data:
for bid, _, _, _, akinds, asites, _ in list_data:
structure_dict[bid] = (akinds, asites)

entry_list = []
Expand Down
7 changes: 4 additions & 3 deletions aiida/storage/psql_dos/alembic_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
"""Simple wrapper around the alembic command line tool that first loads an AiiDA profile."""
from __future__ import annotations

import contextlib

import alembic
import click
from sqlalchemy.util.compat import nullcontext

from aiida.cmdline import is_verbose
from aiida.cmdline.groups.verdi import VerdiCommandGroup
Expand All @@ -38,8 +39,8 @@ def execute_alembic_command(self, command_name, connect=True, **kwargs):
raise click.ClickException('No profile specified')
migrator = PsqlDosMigrator(self.profile)

context = migrator._alembic_connect() if connect else nullcontext(migrator._alembic_config()) # pylint: disable=protected-access
with context as config: # type: ignore[attr-defined]
context = migrator._alembic_connect() if connect else contextlib.nullcontext(migrator._alembic_config()) # pylint: disable=protected-access
with context as config:
command = getattr(alembic.command, command_name)
config.stdout = click.get_text_stream('stdout')
command(config, **kwargs)
Expand Down
30 changes: 18 additions & 12 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pathlib
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
Expand Down Expand Up @@ -188,8 +189,9 @@ def _clear(self) -> None:

with self.transaction() as session:
session.execute(
DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY
).values(val=repository_uuid)
DbSetting.__table__.update().where(
DbSetting.key == REPOSITORY_UUID_KEY # type: ignore[attr-defined]
).values(val=repository_uuid)
)

def get_repository(self) -> 'DiskObjectStoreRepositoryBackend':
Expand Down Expand Up @@ -305,8 +307,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults
# by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
Expand All @@ -319,7 +321,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_update_mappings(mapper, rows)
session.execute(update(mapper), rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
# pylint: disable=no-value-for-parameter
Expand All @@ -331,14 +333,17 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:

session = self.get_session()
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) # type: ignore[attr-defined]
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch') # type: ignore[attr-defined]

def get_backend_entity(self, model: base.Base) -> BackendEntity:
"""
Expand All @@ -356,9 +361,10 @@ def set_global_variable(

session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
if session.query(DbSetting).filter(DbSetting.key == key).count():
if session.query(DbSetting).filter(DbSetting.key == key).count(): # type: ignore[attr-defined]
if overwrite:
session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value))
session.query(DbSetting).filter(DbSetting.key == key
).update(dict(val=value)) # type: ignore[attr-defined]
else:
raise ValueError(f'The setting {key} already exists')
else:
Expand All @@ -369,7 +375,7 @@ def get_global_variable(self, key: str) -> Union[None, str, int, float]:

session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()
setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none() # type: ignore[attr-defined]
if setting is None:
raise KeyError(f'No setting found with key {key}')
return setting.val
Expand Down
6 changes: 3 additions & 3 deletions aiida/storage/psql_dos/migrations/utils/legacy_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def export_workflow_data(connection, profile):
DbWorkflowData = table('db_dbworkflowdata')
DbWorkflowStep = table('db_dbworkflowstep')

count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar()
count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar()
count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar()
count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() # pylint: disable=not-callable
count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() # pylint: disable=not-callable
count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() # pylint: disable=not-callable

# Nothing to do if all tables are empty
if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def migrate_repository(connection, profile):
column('repository_metadata', JSONB),
)

node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar()
node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() # pylint: disable=not-callable
missing_repo_folder = []
shard_count = 256

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ def migrate_infer_calculation_entry_point(alembic_op):
fallback_cases.append([uuid, type_string, entry_point_string])

connection.execute(
DbNode.update().where(DbNode.c.type == alembic_op.inline_literal(type_string)
).values(process_type=alembic_op.inline_literal(entry_point_string))
DbNode.update().where(
DbNode.c.type == alembic_op.inline_literal(type_string) # type: ignore[attr-defined]
).values(process_type=alembic_op.inline_literal(entry_point_string))
)

if fallback_cases:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def upgrade():
column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
select(DbNode.c.id,
DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
).fetchall()

for pk, uuid in nodes:
symbols = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist()
connection.execute(
DbNode.update().where(DbNode.c.id == pk).values(
DbNode.update().where(DbNode.c.id == pk).values( # type: ignore[attr-defined]
attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB))
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def upgrade():
column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
select(DbNode.c.id,
DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
).fetchall()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def upgrade():
sa.column('type', sa.String),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
sa.select(node_model.c.id, node_model.c.uuid).where(
node_model.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def upgrade():
# sa.column('attributes', JSONB),
)

nodes = connection.execute(
nodes = connection.execute( # type: ignore[var-annotated]
sa.select(node_tbl.c.id, node_tbl.c.uuid).where(
node_tbl.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,22 @@ def upgrade():
op.add_column('db_dbnode', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=True))

# transition attributes and extras to node
node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar()
node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar() # pylint: disable=not-callable
if node_count:
with get_progress_reporter()(total=node_count, desc='Updating attributes and extras') as progress:
for node in conn.execute(select(node_tbl)).all():
attr_list = conn.execute(select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)).all()
attr_list = conn.execute( # type: ignore[var-annotated]
select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)
).all()
attributes, _ = attributes_to_dict(sorted(attr_list, key=lambda a: a.key))
extra_list = conn.execute(select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)).all()
extra_list = conn.execute( # type: ignore[var-annotated]
select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)
).all()
extras, _ = attributes_to_dict(sorted(extra_list, key=lambda a: a.key))
conn.execute(
node_tbl.update().where(node_tbl.c.id == node.id).values(attributes=attributes, extras=extras)
node_tbl.update().where( # type: ignore[attr-defined]
node_tbl.c.id == node.id
).values(attributes=attributes, extras=extras)
)
progress.update()

Expand All @@ -107,7 +113,7 @@ def upgrade():
op.add_column('db_dbsetting', sa.Column('val', postgresql.JSONB(astext_type=sa.Text()), nullable=True))

# transition settings
setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar()
setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar() # pylint: disable=not-callable
if setting_count:
with get_progress_reporter()(total=setting_count, desc='Updating settings') as progress:
for setting in conn.execute(select(setting_tbl)).all():
Expand All @@ -129,8 +135,9 @@ def upgrade():
else:
val = setting.dval
conn.execute(
setting_tbl.update().where(setting_tbl.c.id == setting.id
).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
setting_tbl.update().where( # type: ignore[attr-defined]
setting_tbl.c.id == setting.id
).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
)
progress.update()

Expand Down
10 changes: 5 additions & 5 deletions aiida/storage/psql_dos/orm/querybuilder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,10 @@ def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, An
retdict: Dict[Any, Any] = {}

total_query = session.query(self.Node)
types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member
types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member,not-callable
stat_query = session.query(
sa_func.date_trunc('day', self.Node.ctime).label('cday'), # pylint: disable=no-member
sa_func.count(self.Node.id), # pylint: disable=no-member
sa_func.count(self.Node.id), # pylint: disable=no-member,not-callable
)

if user_pk is not None:
Expand Down Expand Up @@ -1088,11 +1088,11 @@ def _get_projection(
if func is None:
pass
elif func == 'max':
entity_to_project = sa_func.max(entity_to_project)
entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable
elif func == 'min':
entity_to_project = sa_func.max(entity_to_project)
entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable
elif func == 'count':
entity_to_project = sa_func.count(entity_to_project)
entity_to_project = sa_func.count(entity_to_project) # pylint: disable=not-callable
else:
raise ValueError(f'\nInvalid function specification {func}')

Expand Down
6 changes: 2 additions & 4 deletions aiida/storage/psql_dos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,20 @@ def create_sqlalchemy_engine(config: PsqlConfig):
password=config['database_password'],
hostname=hostname,
port=config['database_port'],
name=config['database_name']
name=config['database_name'],
)
return create_engine(
engine_url,
json_serializer=json.dumps,
json_deserializer=json.loads,
future=True,
encoding='utf-8',
**config.get('engine_kwargs', {}),
)


def create_scoped_session_factory(engine, **kwargs):
"""Create scoped SQLAlchemy session factory"""
from sqlalchemy.orm import scoped_session, sessionmaker
return scoped_session(sessionmaker(bind=engine, future=True, **kwargs))
return scoped_session(sessionmaker(bind=engine, **kwargs))


def flag_modified(instance, key):
Expand Down
7 changes: 4 additions & 3 deletions aiida/storage/sqlite_temp/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import shutil
from typing import Any, BinaryIO, Iterator, Sequence

from sqlalchemy import column, insert, update
from sqlalchemy.orm import Session

from aiida.common.exceptions import ClosedStorage, IntegrityError
Expand Down Expand Up @@ -260,8 +261,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True)
return [row['id'] for row in rows]
result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall()
return [row.id for row in result]

def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
mapper, keys = self._get_mapper_from_entity(entity_type, True)
Expand All @@ -274,7 +275,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> None:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
session = self.get_session()
with (nullcontext() if self.in_transaction else self.transaction()):
session.bulk_update_mappings(mapper, rows)
session.execute(update(mapper), rows)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions aiida/storage/sqlite_zip/migrations/legacy_to_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _json_to_sqlite( # pylint: disable=too-many-branches,too-many-locals

# get mapping of node IDs to node UUIDs
node_uuid_map = { # pylint: disable=unnecessary-comprehension
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id))
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id)) # pylint: disable=not-an-iterable
}

# links
Expand Down Expand Up @@ -211,7 +211,7 @@ def _transform_link(link_row):
if data['groups_uuid']:
# get mapping of node IDs to node UUIDs
group_uuid_map = { # pylint: disable=unnecessary-comprehension
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id))
uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id)) # pylint: disable=not-an-iterable
}
length = sum(len(uuids) for uuids in data['groups_uuid'].values())
unknown_nodes: Dict[str, set] = {}
Expand Down
Loading

0 comments on commit 99960c6

Please sign in to comment.