Skip to content

Commit

Permalink
[NEEMQuery] improved subquery interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AbdelrhmanBassiouny committed Jul 7, 2024
1 parent c1f0e5a commit 60d6eee
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "NEEMQuery"
version = "1.1.1"
version = "1.1.2"
description = "Querying the NEEMs (Narrative Enabled Episodic Memories) datasbase"
readme = "README.md"
authors = [{ name = "Abdelrhman Bassiouny", email = "[email protected]" }]
Expand Down
40 changes: 39 additions & 1 deletion src/neem_query/neem_query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from copy import copy

import pandas as pd
from sqlalchemy import (create_engine, Engine, between, and_, func, Table, BinaryExpression, select,
Expand Down Expand Up @@ -36,6 +37,7 @@ class NeemQuery:

def __init__(self, sql_uri: str):
self._select_neem_id: bool = False
self.sql_uri = sql_uri
self.engine = create_engine(sql_uri)
self.session = sessionmaker(bind=self.engine)()
self.query: Optional[Select] = None
Expand Down Expand Up @@ -1108,7 +1110,7 @@ def _create_entity_tf_view(self,
"""
Create a view of the TF data.
"""
nq = NeemQuery(self.engine.url.__str__())
nq = NeemQuery(self.sql_uri)
subquery = (nq._select_entity_tf_columns(entity_tf, entity_tf_header)
._select_entity_tf_transform_columns(entity_tf_translation, entity_tf_rotation)
.select(entity_tf.neem_id)
Expand All @@ -1118,6 +1120,21 @@ def _create_entity_tf_view(self,
).construct_subquery(name)
return self.create_table_from_subquery(subquery)

def as_subquery_table(self, name: Optional[str] = None, return_query: Optional[bool] = False) \
-> Tuple[NamedFromClause, Optional[Subquery]]:
"""
Create a subquery table from the query.
:param name: the name of the table.
:param return_query: whether to return the neem_query object or not.
:return: the table, and the query object if return_query is True.
"""
subquery = self.construct_subquery(name)
table = self.create_table_from_subquery(subquery)
if return_query:
return table, self.__copy__()
else:
return table

def _join_entity_tf_header_on_tf(self, entity_tf_header: Type[TfHeader], entity_tf: Type[Tf]) -> 'NeemQuery':
"""
Join the entity_tf_header on the entity_tf table using the header column in the entity_tf table.
Expand Down Expand Up @@ -1815,6 +1832,27 @@ def reset(self):
def __eq__(self, other):
return self.construct_query() == other.construct_query()

def __copy__(self):
nq = NeemQuery(self.sql_uri)
if self.query is not None:
nq.query = copy(self.query)
else:
nq.query = self.query
nq.selected_columns = self.selected_columns.copy()
nq.joins = self.joins.copy()
nq.in_filters = self.in_filters.copy()
nq.remove_filters = self.remove_filters.copy()
nq.outer_joins = self.outer_joins.copy()
nq.filters = self.filters.copy()
nq._limit = self._limit
nq._order_by = self._order_by
nq.select_from_tables = self.select_from_tables.copy()
nq.latest_executed_query = self.latest_executed_query
nq.latest_result = self.latest_result
nq._distinct = self._distinct
nq._select_neem_id = self._select_neem_id
return nq

@property
def query_changed(self):
if self.latest_executed_query is None:
Expand Down
38 changes: 36 additions & 2 deletions test/test_neem_query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import time
from unittest import TestCase

import pandas as pd
from sqlalchemy import select
from sqlalchemy import select, and_, func

from neem_query import NeemQuery
from neem_query.enums import ColumnLabel as CL, ParticipantBaseLinkName, ParticipantBaseLink, PerformerBaseLinkName, \
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_join_participant_tf_on_time_interval(self):
self.assertTrue(
all(df[CL.participant_base_link.value][i].split(':')[-1] == df[CL.participant_child_frame_id.value][i]
for i in range(len(df)))
)
)
self.assertTrue(all(
df[CL.time_interval_begin.value][i] <= df[CL.participant_stamp.value][i] <= df[CL.time_interval_end.value][
i]
Expand Down Expand Up @@ -215,3 +216,36 @@ def test_performer_and_participant(self):
)
for c in query.get_result_in_chunks(100):
self.assertTrue(len(c) > 0)

def test_filter_links(self):
pr2_links = self.get_pr2_links_query().get_result().df
self.assertTrue(len(pr2_links) > 0)
self.assertTrue(all(pr2_links["rdf_type_s"].str.contains("pr2")))

def test_filter_tf_by_pr2_links(self):
start = time.time()
pr2_links = self.get_pr2_links_query().get_result().df["rdf_type_s"].str.split(':').str[-1]
self.nq.reset()
filtered_tf = self.nq.select(Tf.child_frame_id).filter(Tf.child_frame_id.in_(pr2_links)).get_result().df
self.assertTrue(len(filtered_tf) > 0)
self.assertTrue(all(filtered_tf["child_frame_id"].isin(pr2_links)))
print(filtered_tf)

def test_filter_tf_by_pr2_links_using_join(self):
start = time.time()
pr2_links_table, pr2_links_query = self.get_pr2_links_query().as_subquery_table("pr2_links",
return_query=True)
self.nq.reset()
filtered_tf = (self.nq.select(Tf.child_frame_id).
join(pr2_links_table, Tf.child_frame_id == func.substring_index(pr2_links_table.rdf_type_s,
':', -1))).get_result().df
print(time.time() - start)
self.assertTrue(len(filtered_tf) > 0)
pr2_links = pr2_links_query.get_result().df["rdf_type_s"].str.split(':').str[-1]
self.assertTrue(all(filtered_tf["child_frame_id"].isin(pr2_links))
)
print(filtered_tf)

def get_pr2_links_query(self):
return (self.nq.select(RdfType.s).filter_by_type(RdfType, ["urdf:link"])
.filter(RdfType.s.like("%pr2%")).distinct())

0 comments on commit 60d6eee

Please sign in to comment.