diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index da5d995306..2d91760f67 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -112,7 +112,6 @@ jobs: - name: Run test suite env: AIIDA_WARN_v3: 1 - SQLALCHEMY_WARN_20: 1 run: .github/workflows/tests.sh diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index dbda28e153..d25c6b9c24 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -253,7 +253,6 @@ jobs: - name: Run test suite env: AIIDA_WARN_v3: 1 - SQLALCHEMY_WARN_20: 1 run: .github/workflows/tests.sh diff --git a/aiida/cmdline/commands/cmd_data/cmd_list.py b/aiida/cmdline/commands/cmd_data/cmd_list.py index 4d687396cb..c518fce366 100644 --- a/aiida/cmdline/commands/cmd_data/cmd_list.py +++ b/aiida/cmdline/commands/cmd_data/cmd_list.py @@ -52,6 +52,14 @@ def query(datatype, project, past_days, group_pks, all_users): n_days_ago = now - datetime.timedelta(days=past_days) data_filters.update({'ctime': {'>=': n_days_ago}}) + # Since the query results are sorted on ``ctime`` it has to be projected on. If it doesn't exist, append it to the + # projections, but make sure to pop it again from the final results since it wasn't part of the original projections + if 'ctime' in project: + pop_ctime = False + else: + project.append('ctime') + pop_ctime = True + qbl.append(datatype, tag='data', with_user='creator', filters=data_filters, project=project) # If there is a group restriction @@ -63,7 +71,12 @@ def query(datatype, project, past_days, group_pks, all_users): qbl.order_by({datatype: {'ctime': 'asc'}}) object_list = qbl.distinct() - return object_list.all() + results = object_list.all() + + if pop_ctime: + return [element[:-1] for element in results] + + return results # pylint: disable=unused-argument,too-many-arguments diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 7f731ebb28..6bf54eaafd 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -1840,14 +1840,14 @@ def get_bands_and_parents_structure(args, backend=None): tag='sdata', with_descendants='bdata', # We don't care about the creator of StructureData - project=['id', 'attributes.kinds', 'attributes.sites'] + project=['id', 'attributes.kinds', 'attributes.sites', 'ctime'] ) q_build.order_by({orm.StructureData: {'ctime': 'desc'}}) structure_dict = {} list_data = q_build.distinct().all() - for bid, _, _, _, akinds, asites in list_data: + for bid, _, _, _, akinds, asites, _ in list_data: structure_dict[bid] = (akinds, asites) entry_list = [] diff --git a/aiida/storage/psql_dos/alembic_cli.py b/aiida/storage/psql_dos/alembic_cli.py index a288ce1aac..9aca87bad1 100755 --- a/aiida/storage/psql_dos/alembic_cli.py +++ b/aiida/storage/psql_dos/alembic_cli.py @@ -11,9 +11,10 @@ """Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" from __future__ import annotations +import contextlib + import alembic import click -from sqlalchemy.util.compat import nullcontext from aiida.cmdline import is_verbose from aiida.cmdline.groups.verdi import VerdiCommandGroup @@ -38,8 +39,8 @@ def execute_alembic_command(self, command_name, connect=True, **kwargs): raise click.ClickException('No profile specified') migrator = PsqlDosMigrator(self.profile) - context = migrator._alembic_connect() if connect else nullcontext(migrator._alembic_config()) # pylint: disable=protected-access - with context as config: # type: 
ignore[attr-defined] + context = migrator._alembic_connect() if connect else contextlib.nullcontext(migrator._alembic_config()) # pylint: disable=protected-access + with context as config: command = getattr(alembic.command, command_name) config.stdout = click.get_text_stream('stdout') command(config, **kwargs) diff --git a/aiida/storage/psql_dos/backend.py b/aiida/storage/psql_dos/backend.py index 3ca9e0075c..c425fcfe7e 100644 --- a/aiida/storage/psql_dos/backend.py +++ b/aiida/storage/psql_dos/backend.py @@ -15,6 +15,7 @@ import pathlib from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union +from sqlalchemy import column, insert, update from sqlalchemy.orm import Session, scoped_session, sessionmaker from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError @@ -188,8 +189,9 @@ def _clear(self) -> None: with self.transaction() as session: session.execute( - DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY - ).values(val=repository_uuid) + DbSetting.__table__.update().where( + DbSetting.key == REPOSITORY_UUID_KEY # type: ignore[attr-defined] + ).values(val=repository_uuid) ) def get_repository(self) -> 'DiskObjectStoreRepositoryBackend': @@ -305,8 +307,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) - return [row['id'] for row in rows] + result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall() + return [row.id for row in result] def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: mapper, keys = self._get_mapper_from_entity(entity_type, True) @@ -319,7 +321,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - session.bulk_update_mappings(mapper, rows) + session.execute(update(mapper), rows) def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-value-for-parameter @@ -331,14 +333,17 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: session = self.get_session() # Delete the membership of these nodes to groups. - session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) + session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) # type: ignore[attr-defined] ).delete(synchronize_session='fetch') # Delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') # type: ignore[attr-defined] # Delete the links pointing to the nodes marked for deletion. 
- session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') # type: ignore[attr-defined] # Delete the actual nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') # type: ignore[attr-defined] def get_backend_entity(self, model: base.Base) -> BackendEntity: """ @@ -356,9 +361,10 @@ def set_global_variable( session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - if session.query(DbSetting).filter(DbSetting.key == key).count(): + if session.query(DbSetting).filter(DbSetting.key == key).count(): # type: ignore[attr-defined] if overwrite: - session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value)) + session.query(DbSetting).filter(DbSetting.key == key + ).update(dict(val=value)) # type: ignore[attr-defined] else: raise ValueError(f'The setting {key} already exists') else: @@ -369,7 +375,7 @@ def get_global_variable(self, key: str) -> Union[None, str, int, float]: session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none() + setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none() # type: ignore[attr-defined] if setting is None: raise KeyError(f'No setting found with key {key}') return setting.val diff --git a/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py b/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py index 53dd8cdd6d..39234beffb 100644 --- a/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py +++ b/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py @@ -41,9 +41,9 @@ def export_workflow_data(connection, profile): DbWorkflowData = table('db_dbworkflowdata') DbWorkflowStep = table('db_dbworkflowstep') - count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() - count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() - count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() + count_workflow = connection.execute(select(func.count()).select_from(DbWorkflow)).scalar() # pylint: disable=not-callable + count_workflow_data = connection.execute(select(func.count()).select_from(DbWorkflowData)).scalar() # pylint: disable=not-callable + count_workflow_step = connection.execute(select(func.count()).select_from(DbWorkflowStep)).scalar() # pylint: disable=not-callable # Nothing to do if all tables are empty if count_workflow == 0 and count_workflow_data == 0 and count_workflow_step == 0: diff --git a/aiida/storage/psql_dos/migrations/utils/migrate_repository.py b/aiida/storage/psql_dos/migrations/utils/migrate_repository.py index 0b17fa2ba5..00fcadef6f 100644 --- a/aiida/storage/psql_dos/migrations/utils/migrate_repository.py +++ b/aiida/storage/psql_dos/migrations/utils/migrate_repository.py @@ -33,7 +33,7 @@ def migrate_repository(connection, profile): column('repository_metadata', JSONB), ) - node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() + node_count = connection.execute(select(func.count()).select_from(DbNode)).scalar() # pylint: disable=not-callable missing_repo_folder = [] shard_count = 
256 diff --git a/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py b/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py index 899e5a43ab..8d128675f6 100644 --- a/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py +++ b/aiida/storage/psql_dos/migrations/utils/provenance_redesign.py @@ -126,8 +126,9 @@ def migrate_infer_calculation_entry_point(alembic_op): fallback_cases.append([uuid, type_string, entry_point_string]) connection.execute( - DbNode.update().where(DbNode.c.type == alembic_op.inline_literal(type_string) - ).values(process_type=alembic_op.inline_literal(entry_point_string)) + DbNode.update().where( + DbNode.c.type == alembic_op.inline_literal(type_string) # type: ignore[attr-defined] + ).values(process_type=alembic_op.inline_literal(entry_point_string)) ) if fallback_cases: diff --git a/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py b/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py index 61185e72ee..c798b026d2 100644 --- a/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py +++ b/aiida/storage/psql_dos/migrations/versions/12536798d4d3_trajectory_symbols_to_attribute.py @@ -50,7 +50,7 @@ def upgrade(): column('attributes', JSONB), ) - nodes = connection.execute( + nodes = connection.execute( # type: ignore[var-annotated] select(DbNode.c.id, DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')) ).fetchall() @@ -58,7 +58,7 @@ def upgrade(): for pk, uuid in nodes: symbols = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist() connection.execute( - DbNode.update().where(DbNode.c.id == pk).values( + DbNode.update().where(DbNode.c.id == pk).values( # type: ignore[attr-defined] attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB)) ) ) diff --git a/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py b/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py index 60e465319d..f5ea6f634b 100644 --- a/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py +++ b/aiida/storage/psql_dos/migrations/versions/ce56d84bcc35_delete_trajectory_symbols_array.py @@ -45,7 +45,7 @@ def upgrade(): column('attributes', JSONB), ) - nodes = connection.execute( + nodes = connection.execute( # type: ignore[var-annotated] select(DbNode.c.id, DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')) ).fetchall() diff --git a/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py b/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py index b4817b18b0..0613473a49 100644 --- a/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0026_trajectory_symbols_to_attribute.py @@ -42,7 +42,7 @@ def upgrade(): sa.column('type', sa.String), ) - nodes = connection.execute( + nodes = connection.execute( # type: ignore[var-annotated] sa.select(node_model.c.id, node_model.c.uuid).where( node_model.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.') ) diff --git a/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py 
b/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py index e79b0463b2..ba3276453d 100644 --- a/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0027_delete_trajectory_symbols_array.py @@ -44,7 +44,7 @@ def upgrade(): # sa.column('attributes', JSONB), ) - nodes = connection.execute( + nodes = connection.execute( # type: ignore[var-annotated] sa.select(node_tbl.c.id, node_tbl.c.uuid).where( node_tbl.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.') ) diff --git a/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py b/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py index cdbe50ae39..1f7b325f33 100644 --- a/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py +++ b/aiida/storage/psql_dos/migrations/versions/django_0037_attributes_extras_settings_json.py @@ -88,16 +88,22 @@ def upgrade(): op.add_column('db_dbnode', sa.Column('extras', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) # transition attributes and extras to node - node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar() + node_count = conn.execute(select(func.count()).select_from(node_tbl)).scalar() # pylint: disable=not-callable if node_count: with get_progress_reporter()(total=node_count, desc='Updating attributes and extras') as progress: for node in conn.execute(select(node_tbl)).all(): - attr_list = conn.execute(select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)).all() + attr_list = conn.execute( # type: ignore[var-annotated] + select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id) + ).all() attributes, _ = attributes_to_dict(sorted(attr_list, key=lambda a: a.key)) - extra_list = conn.execute(select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)).all() + extra_list = conn.execute( # type: ignore[var-annotated] + select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id) + ).all() extras, _ = attributes_to_dict(sorted(extra_list, key=lambda a: a.key)) conn.execute( - node_tbl.update().where(node_tbl.c.id == node.id).values(attributes=attributes, extras=extras) + node_tbl.update().where( # type: ignore[attr-defined] + node_tbl.c.id == node.id + ).values(attributes=attributes, extras=extras) ) progress.update() @@ -107,7 +113,7 @@ def upgrade(): op.add_column('db_dbsetting', sa.Column('val', postgresql.JSONB(astext_type=sa.Text()), nullable=True)) # transition settings - setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar() + setting_count = conn.execute(select(func.count()).select_from(setting_tbl)).scalar() # pylint: disable=not-callable if setting_count: with get_progress_reporter()(total=setting_count, desc='Updating settings') as progress: for setting in conn.execute(select(setting_tbl)).all(): @@ -129,8 +135,9 @@ def upgrade(): else: val = setting.dval conn.execute( - setting_tbl.update().where(setting_tbl.c.id == setting.id - ).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text()))) + setting_tbl.update().where( # type: ignore[attr-defined] + setting_tbl.c.id == setting.id + ).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text()))) ) progress.update() diff --git a/aiida/storage/psql_dos/orm/querybuilder/main.py b/aiida/storage/psql_dos/orm/querybuilder/main.py index 23aa3e591b..2408576012 100644 --- a/aiida/storage/psql_dos/orm/querybuilder/main.py +++ 
b/aiida/storage/psql_dos/orm/querybuilder/main.py @@ -760,10 +760,10 @@ def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, An retdict: Dict[Any, Any] = {} total_query = session.query(self.Node) - types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member + types_query = session.query(self.Node.node_type.label('typestring'), sa_func.count(self.Node.id)) # pylint: disable=no-member,not-callable stat_query = session.query( sa_func.date_trunc('day', self.Node.ctime).label('cday'), # pylint: disable=no-member - sa_func.count(self.Node.id), # pylint: disable=no-member + sa_func.count(self.Node.id), # pylint: disable=no-member,not-callable ) if user_pk is not None: @@ -1088,11 +1088,11 @@ def _get_projection( if func is None: pass elif func == 'max': - entity_to_project = sa_func.max(entity_to_project) + entity_to_project = sa_func.max(entity_to_project) # pylint: disable=not-callable elif func == 'min': - entity_to_project = sa_func.max(entity_to_project) + entity_to_project = sa_func.min(entity_to_project) # pylint: disable=not-callable elif func == 'count': - entity_to_project = sa_func.count(entity_to_project) + entity_to_project = sa_func.count(entity_to_project) # pylint: disable=not-callable else: raise ValueError(f'\nInvalid function specification {func}') diff --git a/aiida/storage/psql_dos/utils.py b/aiida/storage/psql_dos/utils.py index 4d6be94335..4d66ae8fbb 100644 --- a/aiida/storage/psql_dos/utils.py +++ b/aiida/storage/psql_dos/utils.py @@ -45,14 +45,12 @@ def create_sqlalchemy_engine(config: PsqlConfig): password=config['database_password'], hostname=hostname, port=config['database_port'], - name=config['database_name'] + name=config['database_name'], ) return create_engine( engine_url, json_serializer=json.dumps, json_deserializer=json.loads, - future=True, - encoding='utf-8', **config.get('engine_kwargs', {}), ) @@ -60,7 +58,7 @@ def create_scoped_session_factory(engine, **kwargs): """Create scoped SQLAlchemy session factory""" from sqlalchemy.orm import scoped_session, sessionmaker - return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) + return scoped_session(sessionmaker(bind=engine, **kwargs)) def flag_modified(instance, key): diff --git a/aiida/storage/sqlite_temp/backend.py b/aiida/storage/sqlite_temp/backend.py index dd85d82f89..c25a9d7777 100644 --- a/aiida/storage/sqlite_temp/backend.py +++ b/aiida/storage/sqlite_temp/backend.py @@ -19,6 +19,7 @@ import shutil from typing import Any, BinaryIO, Iterator, Sequence +from sqlalchemy import column, insert, update from sqlalchemy.orm import Session from aiida.common.exceptions import ClosedStorage, IntegrityError @@ -260,8 +261,8 @@ def bulk_insert(self, entity_type: EntityTypes, rows: list[dict], allow_defaults raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) - return [row['id'] for row in rows] + result = session.execute(insert(mapper).returning(mapper, column('id')), rows).fetchall() + return [row.id for row in result] def bulk_update(self, entity_type: EntityTypes, rows: list[dict]) -> 
None: raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') session = self.get_session() with (nullcontext() if self.in_transaction else self.transaction()): - session.bulk_update_mappings(mapper, rows) + session.execute(update(mapper), rows) def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): raise NotImplementedError diff --git a/aiida/storage/sqlite_zip/migrations/legacy_to_main.py b/aiida/storage/sqlite_zip/migrations/legacy_to_main.py index d23d5a5b18..1afd705bc5 100644 --- a/aiida/storage/sqlite_zip/migrations/legacy_to_main.py +++ b/aiida/storage/sqlite_zip/migrations/legacy_to_main.py @@ -180,7 +180,7 @@ def _json_to_sqlite( # pylint: disable=too-many-branches,too-many-locals # get mapping of node IDs to node UUIDs node_uuid_map = { # pylint: disable=unnecessary-comprehension - uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id)) + uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbNode.uuid, v1_schema.DbNode.id)) # pylint: disable=not-an-iterable } # links @@ -211,7 +211,7 @@ def _transform_link(link_row): if data['groups_uuid']: # get mapping of node IDs to node UUIDs group_uuid_map = { # pylint: disable=unnecessary-comprehension - uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id)) + uuid: pk for uuid, pk in connection.execute(select(v1_schema.DbGroup.uuid, v1_schema.DbGroup.id)) # pylint: disable=not-an-iterable } length = sum(len(uuids) for uuids in data['groups_uuid'].values()) unknown_nodes: Dict[str, set] = {} diff --git a/aiida/storage/sqlite_zip/migrations/v1_db_schema.py b/aiida/storage/sqlite_zip/migrations/v1_db_schema.py index bad4f14ac0..97d31a37ce 100644 --- a/aiida/storage/sqlite_zip/migrations/v1_db_schema.py +++ b/aiida/storage/sqlite_zip/migrations/v1_db_schema.py @@ -56,8 +56,8 @@ class DbAuthInfo(ArchiveV1Base): nullable=True, index=True ) - _metadata = Column('metadata', JSON, default=dict, nullable=True) - auth_params = Column(JSON, default=dict, nullable=True) + _metadata = Column('metadata', JSON, default=dict, nullable=True) # type: ignore[var-annotated] + auth_params = Column(JSON, default=dict, nullable=True) # type: ignore[misc] enabled = Column(Boolean, default=True, nullable=True) @@ -96,7 +96,7 @@ class DbComputer(ArchiveV1Base): description = Column(Text, default='', nullable=True) scheduler_type = Column(String(255), default='', nullable=True) transport_type = Column(String(255), default='', nullable=True) - _metadata = Column('metadata', JSON, default=dict, nullable=True) + _metadata = Column('metadata', JSON, default=dict, nullable=True) # type: ignore[var-annotated] class DbGroupNodes(ArchiveV1Base): @@ -126,7 +126,7 @@ class DbGroup(ArchiveV1Base): type_string = Column(String(255), default='', nullable=True, index=True) time = Column(DateTime(timezone=True), default=timezone.now, nullable=True) description = Column(Text, default='', nullable=True) - extras = Column(JSON, default=dict, nullable=False) + extras = Column(JSON, default=dict, nullable=False) # type: ignore[misc] user_id = Column( Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), @@ -152,7 +152,7 @@ class DbLog(ArchiveV1Base): index=True ) message = Column(Text(), default='', nullable=True) - _metadata = Column('metadata', JSON, default=dict, nullable=True) + _metadata = Column('metadata', JSON, default=dict, nullable=True) # type: ignore[var-annotated] class 
DbNode(ArchiveV1Base): @@ -168,9 +168,9 @@ class DbNode(ArchiveV1Base): description = Column(Text(), default='', nullable=True) ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True) mtime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True) - attributes = Column(JSON) - extras = Column(JSON) - repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}') + attributes = Column(JSON) # type: ignore[var-annotated] + extras = Column(JSON) # type: ignore[var-annotated] + repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}') # type: ignore[var-annotated] dbcomputer_id = Column( Integer, ForeignKey('db_dbcomputer.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'), diff --git a/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py b/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py index 7d5fa87463..ab1b4c5d68 100644 --- a/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py +++ b/aiida/storage/sqlite_zip/migrations/versions/main_0000a_replace_nulls.py @@ -43,12 +43,24 @@ def upgrade(): # pylint: disable=too-many-statements ) # remove rows with null values, which may have previously resulted from deletion of a user or computer - op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.aiidauser_id.is_(None))) # type: ignore[arg-type] - op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.dbcomputer_id.is_(None))) # type: ignore[arg-type] + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.aiidauser_id.is_(None))) + op.execute(db_dbauthinfo.delete().where(db_dbauthinfo.c.dbcomputer_id.is_(None))) - op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.enabled.is_(None)).values(enabled=True)) - op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.auth_params.is_(None)).values(auth_params={})) - op.execute(db_dbauthinfo.update().where(db_dbauthinfo.c.metadata.is_(None)).values(metadata={})) + op.execute( + db_dbauthinfo.update().where( # type: ignore[attr-defined] + db_dbauthinfo.c.enabled.is_(None) + ).values(enabled=True) + ) + op.execute( + db_dbauthinfo.update().where( # type: ignore[attr-defined] + db_dbauthinfo.c.auth_params.is_(None) + ).values(auth_params={}) + ) + op.execute( + db_dbauthinfo.update().where( # type: ignore[attr-defined] + db_dbauthinfo.c.metadata.is_(None) + ).values(metadata={}) + ) db_dbcomment = sa.sql.table( 'db_dbcomment', @@ -61,12 +73,22 @@ def upgrade(): # pylint: disable=too-many-statements ) # remove rows with null values, which may have previously resulted from deletion of a node or user - op.execute(db_dbcomment.delete().where(db_dbcomment.c.dbnode_id.is_(None))) # type: ignore[arg-type] - op.execute(db_dbcomment.delete().where(db_dbcomment.c.user_id.is_(None))) # type: ignore[arg-type] + op.execute(db_dbcomment.delete().where(db_dbcomment.c.dbnode_id.is_(None))) + op.execute(db_dbcomment.delete().where(db_dbcomment.c.user_id.is_(None))) - op.execute(db_dbcomment.update().where(db_dbcomment.c.content.is_(None)).values(content='')) - op.execute(db_dbcomment.update().where(db_dbcomment.c.ctime.is_(None)).values(ctime=timezone.now())) - op.execute(db_dbcomment.update().where(db_dbcomment.c.mtime.is_(None)).values(mtime=timezone.now())) + op.execute( + db_dbcomment.update().where(db_dbcomment.c.content.is_(None)).values(content='') # type: ignore[attr-defined] + ) + op.execute( + db_dbcomment.update().where( # type: ignore[attr-defined] + 
db_dbcomment.c.ctime.is_(None) + ).values(ctime=timezone.now()) + ) + op.execute( + db_dbcomment.update().where( # type: ignore[attr-defined] + db_dbcomment.c.mtime.is_(None) + ).values(mtime=timezone.now()) + ) db_dbcomputer = sa.sql.table( 'db_dbcomputer', @@ -78,11 +100,31 @@ def upgrade(): # pylint: disable=too-many-statements sa.Column('uuid', sa.CHAR(32)), ) - op.execute(db_dbcomputer.update().where(db_dbcomputer.c.description.is_(None)).values(description='')) - op.execute(db_dbcomputer.update().where(db_dbcomputer.c.hostname.is_(None)).values(hostname='')) - op.execute(db_dbcomputer.update().where(db_dbcomputer.c.metadata.is_(None)).values(metadata={})) - op.execute(db_dbcomputer.update().where(db_dbcomputer.c.scheduler_type.is_(None)).values(scheduler_type='')) - op.execute(db_dbcomputer.update().where(db_dbcomputer.c.transport_type.is_(None)).values(transport_type='')) + op.execute( + db_dbcomputer.update().where( # type: ignore[attr-defined] + db_dbcomputer.c.description.is_(None) + ).values(description='') + ) + op.execute( + db_dbcomputer.update().where( # type: ignore[attr-defined] + db_dbcomputer.c.hostname.is_(None) + ).values(hostname='') + ) + op.execute( + db_dbcomputer.update().where( # type: ignore[attr-defined] + db_dbcomputer.c.metadata.is_(None) + ).values(metadata={}) + ) + op.execute( + db_dbcomputer.update().where( # type: ignore[attr-defined] + db_dbcomputer.c.scheduler_type.is_(None) + ).values(scheduler_type='') + ) + op.execute( + db_dbcomputer.update().where( # type: ignore[attr-defined] + db_dbcomputer.c.transport_type.is_(None) + ).values(transport_type='') + ) db_dbgroup = sa.sql.table( 'db_dbgroup', @@ -93,9 +135,21 @@ def upgrade(): # pylint: disable=too-many-statements sa.Column('uuid', sa.CHAR(32)), ) - op.execute(db_dbgroup.update().where(db_dbgroup.c.description.is_(None)).values(description='')) - op.execute(db_dbgroup.update().where(db_dbgroup.c.time.is_(None)).values(time=timezone.now())) - op.execute(db_dbgroup.update().where(db_dbgroup.c.type_string.is_(None)).values(type_string='core')) + op.execute( + db_dbgroup.update().where( # type: ignore[attr-defined] + db_dbgroup.c.description.is_(None) + ).values(description='') + ) + op.execute( + db_dbgroup.update().where( # type: ignore[attr-defined] + db_dbgroup.c.time.is_(None) + ).values(time=timezone.now()) + ) + op.execute( + db_dbgroup.update().where( # type: ignore[attr-defined] + db_dbgroup.c.type_string.is_(None) + ).values(type_string='core') + ) db_dblog = sa.sql.table( 'db_dblog', @@ -107,11 +161,17 @@ def upgrade(): # pylint: disable=too-many-statements sa.Column('uuid', sa.CHAR(32)), ) - op.execute(db_dblog.update().where(db_dblog.c.levelname.is_(None)).values(levelname='')) - op.execute(db_dblog.update().where(db_dblog.c.loggername.is_(None)).values(loggername='')) - op.execute(db_dblog.update().where(db_dblog.c.message.is_(None)).values(message='')) - op.execute(db_dblog.update().where(db_dblog.c.metadata.is_(None)).values(metadata={})) - op.execute(db_dblog.update().where(db_dblog.c.time.is_(None)).values(time=timezone.now())) + op.execute( + db_dblog.update().where(db_dblog.c.levelname.is_(None)).values(levelname='') # type: ignore[attr-defined] + ) + op.execute( + db_dblog.update().where(db_dblog.c.loggername.is_(None)).values(loggername='') # type: ignore[attr-defined] + ) + op.execute(db_dblog.update().where(db_dblog.c.message.is_(None)).values(message='')) # type: ignore[attr-defined] + op.execute(db_dblog.update().where(db_dblog.c.metadata.is_(None)).values(metadata={})) # 
type: ignore[attr-defined] + op.execute( + db_dblog.update().where(db_dblog.c.time.is_(None)).values(time=timezone.now()) # type: ignore[attr-defined] + ) db_dbnode = sa.sql.table( 'db_dbnode', @@ -123,10 +183,16 @@ def upgrade(): # pylint: disable=too-many-statements sa.Column('uuid', sa.CHAR(32)), ) - op.execute(db_dbnode.update().where(db_dbnode.c.ctime.is_(None)).values(ctime=timezone.now())) - op.execute(db_dbnode.update().where(db_dbnode.c.description.is_(None)).values(description='')) - op.execute(db_dbnode.update().where(db_dbnode.c.label.is_(None)).values(label='')) - op.execute(db_dbnode.update().where(db_dbnode.c.mtime.is_(None)).values(mtime=timezone.now())) + op.execute( + db_dbnode.update().where(db_dbnode.c.ctime.is_(None)).values(ctime=timezone.now()) # type: ignore[attr-defined] + ) + op.execute( + db_dbnode.update().where(db_dbnode.c.description.is_(None)).values(description='') # type: ignore[attr-defined] + ) + op.execute(db_dbnode.update().where(db_dbnode.c.label.is_(None)).values(label='')) # type: ignore[attr-defined] + op.execute( + db_dbnode.update().where(db_dbnode.c.mtime.is_(None)).values(mtime=timezone.now()) # type: ignore[attr-defined] + ) db_dbuser = sa.sql.table( 'db_dbuser', @@ -136,9 +202,15 @@ def upgrade(): # pylint: disable=too-many-statements sa.Column('institution', sa.String(254)), ) - op.execute(db_dbuser.update().where(db_dbuser.c.first_name.is_(None)).values(first_name='')) - op.execute(db_dbuser.update().where(db_dbuser.c.last_name.is_(None)).values(last_name='')) - op.execute(db_dbuser.update().where(db_dbuser.c.institution.is_(None)).values(institution='')) + op.execute( + db_dbuser.update().where(db_dbuser.c.first_name.is_(None)).values(first_name='') # type: ignore[attr-defined] + ) + op.execute( + db_dbuser.update().where(db_dbuser.c.last_name.is_(None)).values(last_name='') # type: ignore[attr-defined] + ) + op.execute( + db_dbuser.update().where(db_dbuser.c.institution.is_(None)).values(institution='') # type: ignore[attr-defined] + ) def downgrade(): diff --git a/aiida/storage/sqlite_zip/utils.py b/aiida/storage/sqlite_zip/utils.py index 12e3b28493..048b461f4b 100644 --- a/aiida/storage/sqlite_zip/utils.py +++ b/aiida/storage/sqlite_zip/utils.py @@ -52,14 +52,7 @@ def sqlite_case_sensitive_like(dbapi_connection, _): def create_sqla_engine(path: Union[str, Path], *, enforce_foreign_keys: bool = True, **kwargs) -> Engine: """Create a new engine instance.""" - engine = create_engine( - f'sqlite:///{path}', - json_serializer=json.dumps, - json_deserializer=json.loads, - encoding='utf-8', - future=True, - **kwargs - ) + engine = create_engine(f'sqlite:///{path}', json_serializer=json.dumps, json_deserializer=json.loads, **kwargs) event.listen(engine, 'connect', sqlite_case_sensitive_like) if enforce_foreign_keys: event.listen(engine, 'connect', sqlite_enforce_foreign_keys) diff --git a/docs/source/conf.py b/docs/source/conf.py index 74c5f6596e..d1ac9e1461 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -62,7 +62,7 @@ 'myst_nb', 'sphinx.ext.intersphinx', 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode', 'sphinx.ext.coverage', 'sphinx.ext.mathjax', 'sphinx.ext.ifconfig', 'sphinx.ext.todo', 'IPython.sphinxext.ipython_console_highlighting', 'IPython.sphinxext.ipython_directive', 'aiida.sphinxext', - 'sphinx_design', 'sphinx_copybutton', 'sphinxext.rediraffe', 'notfound.extension', 'sphinx_sqlalchemy' + 'sphinx_design', 'sphinx_copybutton', 'sphinxext.rediraffe', 'notfound.extension' ] intersphinx_mapping = { diff 
--git a/docs/source/internals/storage/psql_dos.rst b/docs/source/internals/storage/psql_dos.rst index bc061cab73..4fb9675f07 100644 --- a/docs/source/internals/storage/psql_dos.rst +++ b/docs/source/internals/storage/psql_dos.rst @@ -36,26 +36,17 @@ This is a 32-position hexadecimal sequence that is stored as a string with some - ``→`` indicate foreign keys - ``?`` indicate value types that are nullable. -.. sqla-model:: ~aiida.storage.psql_dos.models.user.DbUser - -.. sqla-model:: ~aiida.storage.psql_dos.models.node.DbNode - -.. sqla-model:: ~aiida.storage.psql_dos.models.node.DbLink - -.. sqla-model:: ~aiida.storage.psql_dos.models.group.DbGroup - -.. sqla-model:: ~aiida.storage.psql_dos.models.group.DbGroupNode - -.. sqla-model:: ~aiida.storage.psql_dos.models.computer.DbComputer - -.. sqla-model:: ~aiida.storage.psql_dos.models.authinfo.DbAuthInfo - -.. sqla-model:: ~aiida.storage.psql_dos.models.comment.DbComment - -.. sqla-model:: ~aiida.storage.psql_dos.models.log.DbLog - -.. sqla-model:: ~aiida.storage.psql_dos.models.settings.DbSetting - +.. + .. sqla-model:: ~aiida.storage.psql_dos.models.user.DbUser + .. sqla-model:: ~aiida.storage.psql_dos.models.node.DbNode + .. sqla-model:: ~aiida.storage.psql_dos.models.node.DbLink + .. sqla-model:: ~aiida.storage.psql_dos.models.group.DbGroup + .. sqla-model:: ~aiida.storage.psql_dos.models.group.DbGroupNode + .. sqla-model:: ~aiida.storage.psql_dos.models.computer.DbComputer + .. sqla-model:: ~aiida.storage.psql_dos.models.authinfo.DbAuthInfo + .. sqla-model:: ~aiida.storage.psql_dos.models.comment.DbComment + .. sqla-model:: ~aiida.storage.psql_dos.models.log.DbLog + .. sqla-model:: ~aiida.storage.psql_dos.models.settings.DbSetting The many-to-one relationship ---------------------------- diff --git a/docs/source/internals/storage/sqlite_zip.rst b/docs/source/internals/storage/sqlite_zip.rst index 6ba6670ef0..8d4e1e27a2 100644 --- a/docs/source/internals/storage/sqlite_zip.rst +++ b/docs/source/internals/storage/sqlite_zip.rst @@ -83,23 +83,15 @@ The only differences are in the handling of certain data types by SQLite versus Also, `varchar_pattern_ops` indexes are not possible in SQLite. -Tables -...... - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbUser - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbNode - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbLink - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbGroup - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbGroupNodes - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbComputer - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbAuthInfo - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbComment - -.. sqla-model:: ~aiida.storage.sqlite_zip.models.DbLog +.. + Tables + ...... + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbUser + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbNode + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbLink + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbGroup + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbGroupNodes + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbComputer + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbAuthInfo + .. sqla-model:: ~aiida.storage.sqlite_zip.models.DbComment + .. 
sqla-model:: ~aiida.storage.sqlite_zip.models.DbLog diff --git a/environment.yml b/environment.yml index 78374edb9e..7a8bd8a795 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - pytz~=2021.1 - pyyaml~=6.0 - requests~=2.0 -- sqlalchemy~=1.4.22 +- sqlalchemy~=2.0 - tabulate~=0.8.5 - tqdm~=4.45 - upf_to_json~=0.9.2 diff --git a/pyproject.toml b/pyproject.toml index 0d9fd964e0..3ead519b85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ "pytz~=2021.1", "pyyaml~=6.0", "requests~=2.0", - "sqlalchemy~=1.4.22", + "sqlalchemy~=2.0", "tabulate~=0.8.5", "tqdm~=4.45", "upf_to_json~=0.9.2", @@ -84,7 +84,7 @@ docs = [ "sphinx-design~=0.0.13", "sphinx-notfound-page~=0.5", "sphinxext-rediraffe~=0.2.4", - "sphinx-sqlalchemy~=0.1.1", + # "sphinx-sqlalchemy~=0.1.1", "sphinx-intl~=2.1.0", "myst-nb~=0.17.0", ] @@ -108,7 +108,7 @@ pre-commit = [ "pre-commit~=2.2", "pylint~=2.17.4", "pylint-aiida~=0.1.1", - "sqlalchemy[mypy]~=1.4.29", + "sqlalchemy[mypy]~=2.0", "tomli", "types-PyYAML", ] @@ -125,7 +125,6 @@ tests = [ "pytest-regressions~=2.2", "pympler~=0.9", "coverage~=6.0", - "sqlalchemy-utils~=0.37.2", "sphinx~=4.0", "docutils==0.16", ] @@ -407,7 +406,6 @@ module = [ 'bs4.*', 'CifFile.*', 'circus.*', - 'click_config_file.*', 'click_spinner.*', 'docutils.*', 'flask_cors.*', @@ -438,7 +436,6 @@ module = [ 'seekpath.*', 'spglib.*', 'sphinxcontrib.details.*', - 'sqlalchemy_utils.*', 'tabulate.*', 'tqdm.*', 'trogon.*', @@ -467,8 +464,7 @@ deps = passenv = PYTHONASYNCIODEBUG setenv = - AIIDA_WARN_v3 = 1 - SQLALCHEMY_WARN_20 = 1 + AIIDA_WARN_v3 = commands = pytest {posargs} [testenv:py{39,310,311}-verdi] diff --git a/requirements/requirements-py-3.10.txt b/requirements/requirements-py-3.10.txt index f10c1b7242..e7231a59c3 100644 --- a/requirements/requirements-py-3.10.txt +++ b/requirements/requirements-py-3.10.txt @@ -178,7 +178,6 @@ sphinx-copybutton==0.5.2 sphinx-design==0.0.13 sphinx-intl==2.1.0 sphinx-notfound-page==0.8.3 -sphinx-sqlalchemy==0.1.1 sphinxcontrib-applehelp==1.0.4 sphinxcontrib-details-directive==0.1.0 sphinxcontrib-devhelp==1.0.2 @@ -187,8 +186,7 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sphinxext-rediraffe==0.2.7 -sqlalchemy==1.4.48 -sqlalchemy-utils==0.37.9 +sqlalchemy==2.0.23 stack-data==0.6.2 sympy==1.12 tabulate==0.8.10 diff --git a/requirements/requirements-py-3.11.txt b/requirements/requirements-py-3.11.txt index 1948ce8008..bd0fe3749e 100644 --- a/requirements/requirements-py-3.11.txt +++ b/requirements/requirements-py-3.11.txt @@ -177,7 +177,6 @@ sphinx-copybutton==0.5.2 sphinx-design==0.0.13 sphinx-intl==2.1.0 sphinx-notfound-page==0.8.3 -sphinx-sqlalchemy==0.1.1 sphinxcontrib-applehelp==1.0.4 sphinxcontrib-details-directive==0.1.0 sphinxcontrib-devhelp==1.0.2 @@ -186,8 +185,7 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sphinxext-rediraffe==0.2.7 -sqlalchemy==1.4.48 -sqlalchemy-utils==0.37.9 +sqlalchemy==2.0.23 stack-data==0.6.2 sympy==1.12 tabulate==0.8.10 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index f9996690bf..123b833a5e 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -180,7 +180,6 @@ sphinx-copybutton==0.5.2 sphinx-design==0.0.13 sphinx-intl==2.1.0 sphinx-notfound-page==0.8.3 -sphinx-sqlalchemy==0.1.1 sphinxcontrib-applehelp==1.0.4 sphinxcontrib-details-directive==0.1.0 sphinxcontrib-devhelp==1.0.2 @@ -189,8 
+188,7 @@ sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sphinxext-rediraffe==0.2.7 -sqlalchemy==1.4.48 -sqlalchemy-utils==0.37.9 +sqlalchemy==2.0.23 stack-data==0.6.2 sympy==1.12 tabulate==0.8.10 diff --git a/tests/orm/test_querybuilder/test_as_sql_literal_quote.txt b/tests/orm/test_querybuilder/test_as_sql_literal_quote.txt index cbea0bfe33..2c9cf2d512 100644 --- a/tests/orm/test_querybuilder/test_as_sql_literal_quote.txt +++ b/tests/orm/test_querybuilder/test_as_sql_literal_quote.txt @@ -1,3 +1,3 @@ SELECT db_dbnode_1.uuid FROM db_dbnode AS db_dbnode_1 -WHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE 'data.core.structure.%%' AND CAST((db_dbnode_1.extras #> '{elements}') AS JSONB) @> '["Si"]' +WHERE CAST(db_dbnode_1.node_type AS VARCHAR) LIKE 'data.core.structure.%%' AND (CAST((db_dbnode_1.extras #> '{elements}') AS JSONB) @> '["Si"]') diff --git a/tests/storage/psql_dos/migrations/django_branch/test_0037_attributes_extras_settings_json.py b/tests/storage/psql_dos/migrations/django_branch/test_0037_attributes_extras_settings_json.py index 842f97d043..ddb46abc4e 100644 --- a/tests/storage/psql_dos/migrations/django_branch/test_0037_attributes_extras_settings_json.py +++ b/tests/storage/psql_dos/migrations/django_branch/test_0037_attributes_extras_settings_json.py @@ -141,7 +141,7 @@ def test_settings_migration(perform_migrations: PsqlDosMigrator): setting_model = perform_migrations.get_current_table('db_dbsetting') with perform_migrations.session() as session: - settings = { + settings = { # type: ignore[var-annotated] row[0]: row[1] for row in session.execute(select(setting_model.key, setting_model.val).order_by(setting_model.key)).all() } diff --git a/tests/storage/psql_dos/test_utils.py b/tests/storage/psql_dos/test_utils.py deleted file mode 100644 index 957b843600..0000000000 --- a/tests/storage/psql_dos/test_utils.py +++ /dev/null @@ -1,110 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=import-error,no-name-in-module -"""In this file various data management functions, needed for the SQLA test, -are added. They are "heavily inspired" by the sqlalchemy_utils.functions.database -(SQLAlchemy-Utils package). - -However, they were corrected to work properly with a SQlAlchemy and PostgreSQL. -The main problem of the SQLAlchemy-Utils that were rewritten was that they -were not properly disposing the (SQLA) engine, resulting to error messages -from PostgreSQL.""" - -from sqlalchemy_utils.functions.database import drop_database - - -def new_database(uri): - """Drop the database at ``uri`` and create a brand new one.""" - destroy_database(uri) - create_database(uri) - - -def destroy_database(uri): - """Destroy the database at ``uri``, if it exists.""" - if database_exists(uri): - drop_database(uri) - - -def database_exists(url): - """Check if a database exists. - - This is a modification of sqlalchemy_utils.functions.database.database_exists - since the latter one did not correctly work with SQLAlchemy and PostgreSQL. - - :param url: A SQLAlchemy engine URL. 
- - Performs backend-specific testing to quickly determine if a database - exists on the server.""" - - from copy import copy - - import sqlalchemy as sa - from sqlalchemy.engine.url import make_url - - url = copy(make_url(url)) - database = url.database - if url.drivername.startswith('postgresql'): - url = url.set(database='template1') - else: - url = url.set(database=None) - - engine = sa.create_engine(url) # pylint: disable=no-member - - try: - if engine.dialect.name == 'postgresql': - text = sa.text(f"SELECT 1 FROM pg_database WHERE datname='{database}'") - return bool(engine.connect().execute(text).scalar()) - raise RuntimeError('Only PostgreSQL is supported.') - - finally: - engine.dispose() - - -def create_database(url, encoding='utf8'): - """Issue the appropriate CREATE DATABASE statement. - - This is a modification of sqlalchemy_utils.functions.database.create_database - since the latter one did not correctly work with SQLAlchemy and PostgreSQL. - - :param url: A SQLAlchemy engine URL. - :param encoding: The encoding to create the database as. - - - It currently supports only PostgreSQL and the psycopg2 driver. - """ - - from copy import copy - - import sqlalchemy as sa - from sqlalchemy.engine.url import make_url - from sqlalchemy_utils.functions.orm import quote - - url = copy(make_url(url)) - - database = url.database - - # A default PostgreSQL database to connect - url = url.set(database='template1') - - engine = sa.create_engine(url) # pylint: disable=no-member - - try: - if engine.dialect.name == 'postgresql' and engine.driver == 'psycopg2': - from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT - engine.raw_connection().set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - - text = sa.text(f"CREATE DATABASE {quote(engine, database)} ENCODING '{encoding}'") - with engine.begin() as connection: - connection.execute(text) - - else: - raise RuntimeError('Only PostgreSQL with the psycopg2 driver is supported.') - finally: - engine.dispose() diff --git a/tests/tools/archive/test_schema.py b/tests/tools/archive/test_schema.py index a0b9337970..1981495f6a 100644 --- a/tests/tools/archive/test_schema.py +++ b/tests/tools/archive/test_schema.py @@ -60,7 +60,7 @@ def test_psql_sync_migrate(tmp_path): raise AssertionError(f'Schema is not in-sync with the psql backend:\n{yaml.safe_dump(diffs)}') -def diff_schemas(psql_insp: Inspector, sqlite_insp: Inspector): # pylint: disable=too-many-branches,too-many-statements +def diff_schemas(psql_insp: Inspector, sqlite_insp: Inspector): # pylint: disable=too-many-branches,too-many-statements,too-many-locals """Compare the reflected schemas of the two databases.""" diffs: dict = {} @@ -85,20 +85,25 @@ def diff_schemas(psql_insp: Inspector, sqlite_insp: Inspector): # pylint: disab # check type psql_type = psql_columns[column_name]['type'] sqlite_type = sqlite_columns[column_name]['type'] + # standardise types + # Since sqlalchemy v2.0 the ``UUID.as_generic()`` for PostgreSQL is converted to ``CHAR(32)`` which causes + # a discrepancy between the field for sqlite which is defined as ``VARCHAR(32)``. Therefore ``UUID`` is + # converted to string manually before calling ``.as_generic()``. 
+ if isinstance(psql_type, UUID): + psql_type = String(length=32) + with suppress(NotImplementedError): psql_type = psql_type.as_generic() with suppress(NotImplementedError): sqlite_type = sqlite_type.as_generic() - if isinstance(psql_type, UUID): - psql_type = String(length=32) + if not isinstance(sqlite_type, type(psql_type)): diffs.setdefault(table_name, {}).setdefault(column_name, {})['type'] = f'{sqlite_type} != {psql_type}' elif isinstance(psql_type, String): - if psql_type.length != sqlite_type.length: - diffs.setdefault(table_name, - {}).setdefault(column_name, - {})['length'] = f'{sqlite_type.length} != {psql_type.length}' + if psql_type.length != sqlite_type.length: # type: ignore[attr-defined] + string = f'{sqlite_type.length} != {psql_type.length}' # type: ignore[attr-defined] + diffs.setdefault(table_name, {}).setdefault(column_name, {})['length'] = string # check nullability psql_nullable = psql_columns[column_name]['nullable'] sqlite_nullable = sqlite_columns[column_name]['nullable'] @@ -130,7 +135,7 @@ def diff_schemas(psql_insp: Inspector, sqlite_insp: Inspector): # pylint: disab psql_indexes = [ idx['name'] for idx in psql_insp.get_indexes(table_name) - if not idx['unique'] and not idx['name'].startswith('ix_pat_') + if not idx['unique'] and not (idx['name'] is not None and idx['name'].startswith('ix_pat_')) ] sqlite_indexes = [idx['name'] for idx in sqlite_insp.get_indexes(table_name) if not idx['unique']] for index in psql_indexes:
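
The ``bulk_insert`` and ``bulk_update`` hunks above (in both the ``psql_dos`` and ``sqlite_temp`` backends) replace the legacy ``Session.bulk_insert_mappings``/``bulk_update_mappings`` calls with SQLAlchemy 2.0 statement-style bulk operations: ``insert(...).returning(...)`` hands back the generated primary keys directly, instead of ``return_defaults=True`` writing them into the input dictionaries, and executing ``update(mapper)`` with a list of parameter dictionaries performs an UPDATE-by-primary-key per row. A minimal, self-contained sketch of that idiom follows; the ``DbThing`` model and the in-memory SQLite engine are illustrative stand-ins, not AiiDA's real mapped classes.

from sqlalchemy import Integer, String, create_engine, insert, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class DbThing(Base):
    """Hypothetical stand-in for one of the mapped database models."""
    __tablename__ = 'db_thing'
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    label: Mapped[str] = mapped_column(String(255), default='')


engine = create_engine('sqlite://')  # SQLAlchemy 2.0 engines: no ``future=True`` or ``encoding`` arguments
Base.metadata.create_all(engine)

with Session(engine) as session, session.begin():
    # Statement-based bulk insert; RETURNING requires SQLite >= 3.35 (PostgreSQL always supports it).
    rows = [{'label': 'first'}, {'label': 'second'}]
    result = session.execute(insert(DbThing).returning(DbThing.id), rows).fetchall()
    pks = [row.id for row in result]

    # Statement-based bulk update by primary key: each dictionary must carry the primary key value.
    session.execute(update(DbThing), [{'id': pks[0], 'label': 'renamed'}, {'id': pks[1], 'label': 'also renamed'}])

with Session(engine) as session:
    print(session.execute(select(DbThing.id, DbThing.label)).all())

This mirrors what the patched ``bulk_insert``/``bulk_update`` methods do with the entity mappers, minus AiiDA's transaction bookkeeping.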