diff --git a/src/odm2/base.py b/src/odm2/base.py index 0117d11a..4d080ee3 100644 --- a/src/odm2/base.py +++ b/src/odm2/base.py @@ -224,31 +224,40 @@ class Models: def __init__(self, base_model) -> None: self._base_model = base_model - self._process_schema(annotations) - self._process_schema(core) - self._process_schema(cv) - self._process_schema(dataquality) - self._process_schema(equipment) - self._process_schema(extensionproperties) - self._process_schema(externalidentifiers) - self._process_schema(labanalyses) - self._process_schema(provenance) - self._process_schema(results) - self._process_schema(samplingfeatures) - self._process_schema(simulation) - self._process_schema(auth) - - def _process_schema(self, schema: str) -> None: + # models that are declaratively mapped. + self.__add_model(_results.TimeSeriesResults), + self.__add_model(_results.TimeSeriesResultValues) + + self.__process_schema(annotations) + self.__process_schema(core) + self.__process_schema(cv) + self.__process_schema(dataquality) + self.__process_schema(equipment) + self.__process_schema(extensionproperties) + self.__process_schema(externalidentifiers) + self.__process_schema(labanalyses) + self.__process_schema(provenance) + self.__process_schema(results) + self.__process_schema(samplingfeatures) + self.__process_schema(simulation) + self.__process_schema(auth) + + def __process_schema(self, schema: str) -> None: classes = [c for c in dir(schema) if not c.startswith("__")] - base = tuple([self._base_model]) for class_name in classes: model = getattr(schema, class_name) # ignore modules for when a schema imports them if type(model) is not type: continue + self.__remap_model(model) + + def __remap_model(self, model): + base = tuple([self._base_model]) + extended_model = type(model.__name__, base, {}) + setattr(self, model.__name__, extended_model) - extended_model = type(class_name, base, {}) - setattr(self, class_name, extended_model) + def __add_model(self, model): + setattr(self, model.__name__, model) def _trim_dunders(self, dictionary: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in dictionary.items() if not k.startswith("__")} @@ -285,18 +294,6 @@ def _prepare_model_base(self): return automap_base(cls=AutoBase, metadata=metadata) def _prepare_automap_models(self): - # models that are declaratively mapped. - setattr( - self._model_base, - "TimeSeriesResults", - dict(_results.TimeSeriesResults.__dict__), - ) - setattr( - self._model_base, - "TimeSeriesResultValues", - dict(_results.TimeSeriesResultValues.__dict__), - ) - self._model_base.prepare(self._engine) if not self._cache_path: return diff --git a/src/odm2/models/results.py b/src/odm2/models/results.py index 23d65e28..56ab59f0 100644 --- a/src/odm2/models/results.py +++ b/src/odm2/models/results.py @@ -4,11 +4,17 @@ from sqlalchemy import orm import sqlalchemy as sqla from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.orm import declarative_base +Base = declarative_base() -class TimeSeriesResults: + +class TimeSeriesResults(Base): """http://odm2.github.io/ODM2/schemas/ODM2_Current/tables/ODM2Results_TimeSeriesResults.html""" + __tablename__ = "timeseriesresults" + __table_args__ = {"schema": "odm2"} + resultid: orm.Mapped[int] = sqla.Column("resultid", sqla.Integer, primary_key=True) xlocation: orm.Mapped[typing.Optional[float]] = sqla.Column( "xlocation", pg.DOUBLE_PRECISION @@ -42,9 +48,12 @@ class TimeSeriesResults: ) -class TimeSeriesResultValues: +class TimeSeriesResultValues(Base): """http://odm2.github.io/ODM2/schemas/ODM2_Current/tables/ODM2Results_TimeSeriesResultValues.html""" + __tablename__ = "timeseriesresultvalues" + __table_args__ = {"schema": "odm2"} + valueid: orm.Mapped[int] = sqla.Column("valueid", sqla.Integer, primary_key=True) resultid: orm.Mapped[int] = sqla.Column( "resultid", sqla.ForeignKey("results.resultid")