From 11f8977ecaf9e30eb09a809f3feb2494e3fca0df Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 9 Jan 2024 18:31:43 +0100 Subject: [PATCH] forward table columns in collection object --- tests/routes/test_item.py | 2 +- tipg/collections.py | 55 ++++++++++++++++++--------------------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/tests/routes/test_item.py b/tests/routes/test_item.py index 1e05d52c..44124535 100644 --- a/tests/routes/test_item.py +++ b/tests/routes/test_item.py @@ -49,5 +49,5 @@ def test_item_with_property_config(app_public_table): assert body["type"] == "Feature" assert body["id"] == 1 assert body["links"] - print(body["properties"]) + assert list(body["properties"]) == ["pr"] Item.model_validate(body) diff --git a/tipg/collections.py b/tipg/collections.py index fc2b1515..cdb26524 100644 --- a/tipg/collections.py +++ b/tipg/collections.py @@ -163,8 +163,9 @@ class Collection(BaseModel): dbschema: str = Field(..., alias="schema") title: Optional[str] = None description: Optional[str] = None + table_columns: List[Column] = [] properties: List[Column] = [] - id_column: Optional[str] = None + id_column: Optional[Column] = None geometry_column: Optional[Column] = None datetime_column: Optional[Column] = None parameters: List[Parameter] = [] @@ -237,12 +238,12 @@ def crs(self): @property def geometry_columns(self) -> List[Column]: """Return geometry columns.""" - return [c for c in self.properties if c.is_geometry] + return [c for c in self.table_columns if c.is_geometry] @property def datetime_columns(self) -> List[Column]: """Return datetime columns.""" - return [c for c in self.properties if c.is_datetime] + return [c for c in self.table_columns if c.is_datetime] def get_geometry_column(self, name: Optional[str] = None) -> Optional[Column]: """Return the name of the first geometry column.""" @@ -272,13 +273,6 @@ def get_datetime_column(self, name: Optional[str] = None) -> Optional[Column]: return None - @property - def id_column_info(self) -> Column: # type: ignore - """Return Column for a unique identifier.""" - for col in self.properties: - if col.name == self.id_column: - return col - def columns(self, properties: Optional[List[str]] = None) -> List[str]: """Return table columns optionally filtered to only include columns from properties.""" if properties in [[], [""]]: @@ -311,7 +305,7 @@ def _select_no_geo(self, properties: Optional[List[str]], addid: bool = True): if addid: if self.id_column: - id_clause = logic.V(self.id_column).as_("tipg_id") + id_clause = logic.V(self.id_column.name).as_("tipg_id") else: id_clause = raw(" ROW_NUMBER () OVER () AS tipg_id ") if nocomma: @@ -480,18 +474,14 @@ def _where( # noqa: C901 if ids is not None: if len(ids) == 1: wheres.append( - logic.V(self.id_column) - == pg_funcs.cast( - pg_funcs.cast(ids[0], "text"), self.id_column_info.type - ) + logic.V(self.id_column.name) + == pg_funcs.cast(pg_funcs.cast(ids[0], "text"), self.id_column.type) ) else: w = [ - logic.V(self.id_column) + logic.V(self.id_column.name) == logic.S( - pg_funcs.cast( - pg_funcs.cast(i, "text"), self.id_column_info.type - ) + pg_funcs.cast(pg_funcs.cast(i, "text"), self.id_column.type) ) for i in ids ] @@ -626,7 +616,7 @@ def _sortby(self, sortby: Optional[str]): else: if self.id_column is not None: - sorts.append(logic.V(self.id_column)) + sorts.append(logic.V(self.id_column.name)) else: sorts.append(logic.V(self.properties[0].name)) @@ -961,23 +951,27 @@ async def get_collection_index( # noqa: C901 table_conf = table_confs.get(confid, TableConfig()) # Make sure that any properties set in conf exist in table - properties = sorted(table.get("properties", []), key=lambda d: d["name"]) - properties_setting = table_conf.properties or [] - if properties_setting: - properties = [p for p in properties if p["name"] in properties_setting] + columns = sorted(table.get("properties", []), key=lambda d: d["name"]) + properties_setting = table_conf.properties or [c["name"] for c in columns] # ID Column - id_column = table_conf.pk or table.get("pk") - if not id_column and fallback_key_names: - for p in properties: + id_column = None + if id_name := table_conf.pk or table.get("pk"): + for p in columns: + if id_name == p["name"]: + id_column = p + break + + if id_column is None and fallback_key_names: + for p in columns: if p["name"] in fallback_key_names: - id_column = p["name"] + id_column = p break datetime_column = None geometry_column = None - for c in properties: + for c in columns: if c.get("type") in ("timestamp", "timestamptz", "date"): if table_conf.datetimecol == c["name"] or datetime_column is None: datetime_column = c @@ -992,8 +986,9 @@ async def get_collection_index( # noqa: C901 table=table["name"], schema=table["schema"], description=table.get("description", None), + table_columns=columns, + properties=[p for p in columns if p["name"] in properties_setting], id_column=id_column, - properties=properties, datetime_column=datetime_column, geometry_column=geometry_column, parameters=table.get("parameters") or [],