diff --git a/pyproject.toml b/pyproject.toml index 419209b..a6f6083 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "NEEMQuery" -version = "1.1.1" +version = "1.1.2" description = "Querying the NEEMs (Narrative Enabled Episodic Memories) datasbase" readme = "README.md" authors = [{ name = "Abdelrhman Bassiouny", email = "abassiou@uni-bremen.de" }] diff --git a/src/neem_query/neem_query.py b/src/neem_query/neem_query.py index 214bbc1..7f500f9 100644 --- a/src/neem_query/neem_query.py +++ b/src/neem_query/neem_query.py @@ -1,4 +1,5 @@ import logging +from copy import copy import pandas as pd from sqlalchemy import (create_engine, Engine, between, and_, func, Table, BinaryExpression, select, @@ -36,6 +37,7 @@ class NeemQuery: def __init__(self, sql_uri: str): self._select_neem_id: bool = False + self.sql_uri = sql_uri self.engine = create_engine(sql_uri) self.session = sessionmaker(bind=self.engine)() self.query: Optional[Select] = None @@ -1108,7 +1110,7 @@ def _create_entity_tf_view(self, """ Create a view of the TF data. """ - nq = NeemQuery(self.engine.url.__str__()) + nq = NeemQuery(self.sql_uri) subquery = (nq._select_entity_tf_columns(entity_tf, entity_tf_header) ._select_entity_tf_transform_columns(entity_tf_translation, entity_tf_rotation) .select(entity_tf.neem_id) @@ -1118,6 +1120,21 @@ def _create_entity_tf_view(self, ).construct_subquery(name) return self.create_table_from_subquery(subquery) + def as_subquery_table(self, name: Optional[str] = None, return_query: Optional[bool] = False) \ + -> Tuple[NamedFromClause, Optional[Subquery]]: + """ + Create a subquery table from the query. + :param name: the name of the table. + :param return_query: whether to return the neem_query object or not. + :return: the table, and the query object if return_query is True. + """ + subquery = self.construct_subquery(name) + table = self.create_table_from_subquery(subquery) + if return_query: + return table, self.__copy__() + else: + return table + def _join_entity_tf_header_on_tf(self, entity_tf_header: Type[TfHeader], entity_tf: Type[Tf]) -> 'NeemQuery': """ Join the entity_tf_header on the entity_tf table using the header column in the entity_tf table. @@ -1815,6 +1832,27 @@ def reset(self): def __eq__(self, other): return self.construct_query() == other.construct_query() + def __copy__(self): + nq = NeemQuery(self.sql_uri) + if self.query is not None: + nq.query = copy(self.query) + else: + nq.query = self.query + nq.selected_columns = self.selected_columns.copy() + nq.joins = self.joins.copy() + nq.in_filters = self.in_filters.copy() + nq.remove_filters = self.remove_filters.copy() + nq.outer_joins = self.outer_joins.copy() + nq.filters = self.filters.copy() + nq._limit = self._limit + nq._order_by = self._order_by + nq.select_from_tables = self.select_from_tables.copy() + nq.latest_executed_query = self.latest_executed_query + nq.latest_result = self.latest_result + nq._distinct = self._distinct + nq._select_neem_id = self._select_neem_id + return nq + @property def query_changed(self): if self.latest_executed_query is None: diff --git a/test/test_neem_query.py b/test/test_neem_query.py index dfe1cbb..c03f8e8 100644 --- a/test/test_neem_query.py +++ b/test/test_neem_query.py @@ -1,7 +1,8 @@ +import time from unittest import TestCase import pandas as pd -from sqlalchemy import select +from sqlalchemy import select, and_, func from neem_query import NeemQuery from neem_query.enums import ColumnLabel as CL, ParticipantBaseLinkName, ParticipantBaseLink, PerformerBaseLinkName, \ @@ -161,7 +162,7 @@ def test_join_participant_tf_on_time_interval(self): self.assertTrue( all(df[CL.participant_base_link.value][i].split(':')[-1] == df[CL.participant_child_frame_id.value][i] for i in range(len(df))) - ) + ) self.assertTrue(all( df[CL.time_interval_begin.value][i] <= df[CL.participant_stamp.value][i] <= df[CL.time_interval_end.value][ i] @@ -215,3 +216,36 @@ def test_performer_and_participant(self): ) for c in query.get_result_in_chunks(100): self.assertTrue(len(c) > 0) + + def test_filter_links(self): + pr2_links = self.get_pr2_links_query().get_result().df + self.assertTrue(len(pr2_links) > 0) + self.assertTrue(all(pr2_links["rdf_type_s"].str.contains("pr2"))) + + def test_filter_tf_by_pr2_links(self): + start = time.time() + pr2_links = self.get_pr2_links_query().get_result().df["rdf_type_s"].str.split(':').str[-1] + self.nq.reset() + filtered_tf = self.nq.select(Tf.child_frame_id).filter(Tf.child_frame_id.in_(pr2_links)).get_result().df + self.assertTrue(len(filtered_tf) > 0) + self.assertTrue(all(filtered_tf["child_frame_id"].isin(pr2_links))) + print(filtered_tf) + + def test_filter_tf_by_pr2_links_using_join(self): + start = time.time() + pr2_links_table, pr2_links_query = self.get_pr2_links_query().as_subquery_table("pr2_links", + return_query=True) + self.nq.reset() + filtered_tf = (self.nq.select(Tf.child_frame_id). + join(pr2_links_table, Tf.child_frame_id == func.substring_index(pr2_links_table.rdf_type_s, + ':', -1))).get_result().df + print(time.time() - start) + self.assertTrue(len(filtered_tf) > 0) + pr2_links = pr2_links_query.get_result().df["rdf_type_s"].str.split(':').str[-1] + self.assertTrue(all(filtered_tf["child_frame_id"].isin(pr2_links)) + ) + print(filtered_tf) + + def get_pr2_links_query(self): + return (self.nq.select(RdfType.s).filter_by_type(RdfType, ["urdf:link"]) + .filter(RdfType.s.like("%pr2%")).distinct())