diff --git a/pyproject.toml b/pyproject.toml index 0397155..3e69e26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ dependencies = [ [project.optional-dependencies] dev = [ + "factory_boy<4", + "pytest-factoryboy<3", "pytest", "pytest-cov", "sphinx-autobuild<=2024.5", diff --git a/snowexsql/api.py b/snowexsql/api.py index 3aa2fe9..e27c040 100644 --- a/snowexsql/api.py +++ b/snowexsql/api.py @@ -1,4 +1,5 @@ import logging +import os from contextlib import contextmanager import geoalchemy2.functions as gfunc @@ -10,9 +11,8 @@ from snowexsql.conversions import query_to_geopandas, raster_to_rasterio from snowexsql.db import get_db -from snowexsql.tables import ImageData, LayerData, PointData, Instrument, \ - Observer, Site, Campaign, MeasurementType, DOI - +from snowexsql.tables import Campaign, DOI, ImageData, Instrument, LayerData, \ + MeasurementType, Observer, PointData, PointObservation, Site LOG = logging.getLogger(__name__) DB_NAME = 'snow:hackweek@db.snowexdata.org/snowex' @@ -102,6 +102,22 @@ def _filter_observers(cls, qry, v): ).filter(Observer.name == v) return qry + @classmethod + def _filter_instrument(cls, qry, value): + return qry.filter( + cls.MODEL.instrument.has(name=value) + ) + + @classmethod + def _filter_measurement_type(cls, qry, value): + return qry.join( + cls.MODEL.measurement_type + ).filter(MeasurementType.name == value) + + @classmethod + def _filter_doi(cls, qry, value): + return qry.join(cls.MODEL.doi).filter(DOI.doi == value) + @classmethod def extend_qry(cls, qry, check_size=True, **kwargs): if cls.MODEL is None: @@ -111,12 +127,15 @@ def extend_qry(cls, qry, check_size=True, **kwargs): for k, v in kwargs.items(): # Handle special operations if k in cls.ALLOWED_QRY_KWARGS: + + qry_model = cls.MODEL # Logic for filtering on date with LayerData if "date" in k and cls.MODEL == LayerData: qry = qry.join(LayerData.site) qry_model = Site - else: - qry_model = cls.MODEL + elif cls.MODEL == PointData: + qry = qry.join(PointData.observation) + # standard filtering using qry.filter if isinstance(v, list): filter_col = getattr(qry_model, k) @@ -137,17 +156,17 @@ def extend_qry(cls, qry, check_size=True, **kwargs): # Filter boundary if "_greater_equal" in k: key = k.split("_greater_equal")[0] - filter_col = getattr(qry_model, key) - qry = qry.filter(filter_col >= v) + qry = qry.filter( + getattr(qry_model, key) >= v + ) elif "_less_equal" in k: key = k.split("_less_equal")[0] - filter_col = getattr(qry_model, key) - qry = qry.filter(filter_col <= v) - # Filter linked columns - elif k == "instrument": qry = qry.filter( - qry_model.instrument.has(name=v) + getattr(qry_model, key) <= v ) + # Filter linked columns + elif k == "instrument": + qry = cls._filter_instrument(qry, v) elif k == "campaign": qry = cls._filter_campaign(qry, v) elif k == "site_id": @@ -157,17 +176,14 @@ def extend_qry(cls, qry, check_size=True, **kwargs): elif k == "observer": qry = cls._filter_observers(qry, v) elif k == "doi": - qry = qry.join( - qry_model.doi - ).filter(DOI.doi == v) + qry = cls._filter_doi(qry, v) elif k == "type": - qry = qry.join( - qry_model.measurement - ).filter(MeasurementType.name == v) + qry = cls._filter_measurement_type(qry, v) # Filter to exact value else: - filter_col = getattr(qry_model, k) - qry = qry.filter(filter_col == v) + qry = qry.filter( + getattr(qry_model, k) == v + ) LOG.debug( f"Filtering {k} to list {v}" ) @@ -207,6 +223,103 @@ def from_unique_entries(cls, columns_to_search, **kwargs): return results + @classmethod + def from_filter(cls, **kwargs): + """ + Get data for the class by filtering by allowed arguments. The allowed + filters are cls.ALLOWED_QRY_KWARGS. + """ + with db_session(cls.DB_NAME) as (session, engine): + try: + qry = session.query(cls.MODEL) + qry = cls.extend_qry(qry, **kwargs) + + # For debugging in the test suite and not recommended + # in production + # https://docs.sqlalchemy.org/en/20/faq/sqlexpressions.html#rendering-postcompile-parameters-as-bound-parameters ## noqa + if 'DEBUG_QUERY' in os.environ: + full_sql_query = qry.statement.compile( + compile_kwargs={"literal_binds": True} + ) + print("\n ** SQL query **") + print(full_sql_query) + + df = query_to_geopandas(qry, engine) + except Exception as e: + session.close() + LOG.error("Failed query for PointData") + raise e + + return df + + @classmethod + def from_area(cls, shp=None, pt=None, buffer=None, crs=26912, **kwargs): + """ + Get data for the class within a specific shapefile or + within a point and a known buffer + Args: + shp: shapely geometry in which to filter + pt: shapely point that will have a buffer applied in order + to find search area + buffer: in same units as point + crs: integer crs to use + kwargs: for more filtering or limiting (cls.ALLOWED_QRY_KWARGS) + Returns: Geopandas dataframe of results + + """ + if shp is None and pt is None: + raise ValueError( + "Inputs must be a shape description or a point and buffer" + ) + if (pt is not None and buffer is None) or \ + (buffer is not None and pt is None): + raise ValueError("pt and buffer must be given together") + with db_session(cls.DB_NAME) as (session, engine): + try: + if shp is not None: + qry = session.query(cls.MODEL) + # Filter geometry based on Site for LayerData + if cls.MODEL == LayerData: + qry = qry.join(cls.MODEL.site).filter( + func.ST_Within( + Site.geom, from_shape(shp, srid=crs) + ) + ) + else: + qry = qry.filter( + func.ST_Within( + cls.MODEL.geom, from_shape(shp, srid=crs) + ) + ) + qry = cls.extend_qry(qry, check_size=True, **kwargs) + df = query_to_geopandas(qry, engine) + else: + qry_pt = from_shape(pt) + qry = session.query( + gfunc.ST_SetSRID( + func.ST_Buffer(qry_pt, buffer), crs + ) + ) + + buffered_pt = qry.all()[0][0] + qry = session.query(cls.MODEL) + # Filter geometry based on Site for LayerData + if cls.MODEL == LayerData: + qry = qry.join(cls.MODEL.site).filter( + func.ST_Within(Site.geom, buffered_pt) + ) + else: + qry = qry.filter( + func.ST_Within(cls.MODEL.geom, buffered_pt) + ) + qry = cls.extend_qry(qry, check_size=True, **kwargs) + df = query_to_geopandas(qry, engine) + except Exception as e: + session.close() + raise e + + return df + @property def all_site_names(self): """ @@ -287,98 +400,78 @@ class PointMeasurements(BaseDataset): MODEL = PointData @classmethod - def from_filter(cls, **kwargs): - """ - Get data for the class by filtering by allowed arguments. The allowed - filters are cls.ALLOWED_QRY_KWARGS. - """ - with db_session(cls.DB_NAME) as (session, engine): - try: - qry = session.query(cls.MODEL) - qry = cls.extend_qry(qry, **kwargs) - df = query_to_geopandas(qry, engine) - except Exception as e: - session.close() - LOG.error("Failed query for PointData") - raise e + def _filter_campaign(cls, qry, value): + return qry.join( + cls.MODEL.observation + ).join( + PointObservation.campaign + ).filter( + Campaign.name == value + ) - return df + @classmethod + def _filter_measurement_type(cls, qry, value): + return qry.join( + cls.MODEL.observation + ).join( + PointObservation.measurement_type + ).filter( + MeasurementType.name == value + ) @classmethod - def from_area(cls, shp=None, pt=None, buffer=None, crs=26912, **kwargs): - """ - Get data for the class within a specific shapefile or - within a point and a known buffer - Args: - shp: shapely geometry in which to filter - pt: shapely point that will have a buffer applied in order - to find search area - buffer: in same units as point - crs: integer crs to use - kwargs: for more filtering or limiting (cls.ALLOWED_QRY_KWARGS) - Returns: Geopandas dataframe of results + def _filter_instrument(cls, qry, value): + return qry.join( + cls.MODEL.observation + ).join( + PointObservation.instrument + ).filter( + Instrument.name == value + ) - """ - if shp is None and pt is None: - raise ValueError( - "Inputs must be a shape description or a point and buffer" - ) - if (pt is not None and buffer is None) or \ - (buffer is not None and pt is None): - raise ValueError("pt and buffer must be given together") - with db_session(cls.DB_NAME) as (session, engine): - try: - if shp is not None: - qry = session.query(cls.MODEL) - # Filter geometry based on Site for LayerData - if cls.MODEL == LayerData: - qry = qry.join(cls.MODEL.site).filter( - func.ST_Within( - Site.geom, from_shape(shp, srid=crs) - ) - ) - else: - qry = qry.filter( - func.ST_Within( - cls.MODEL.geom, from_shape(shp, srid=crs) - ) - ) - qry = cls.extend_qry(qry, check_size=True, **kwargs) - df = query_to_geopandas(qry, engine) - else: - qry_pt = from_shape(pt) - qry = session.query( - gfunc.ST_SetSRID( - func.ST_Buffer(qry_pt, buffer), crs - ) - ) + @classmethod + def _filter_doi(cls, qry, value): + return qry.join( + cls.MODEL.observation + ).join( + PointObservation.doi + ).filter( + DOI.doi == value + ) - buffered_pt = qry.all()[0][0] - qry = session.query(cls.MODEL) - # Filter geometry based on Site for LayerData - if cls.MODEL == LayerData: - qry = qry.join(cls.MODEL.site).filter( - func.ST_Within(Site.geom, buffered_pt) - ) - else: - qry = qry.filter( - func.ST_Within(cls.MODEL.geom, buffered_pt) - ) - qry = cls.extend_qry(qry, check_size=True, **kwargs) - df = query_to_geopandas(qry, engine) - except Exception as e: - session.close() - raise e + @classmethod + def _filter_observers(cls, qry, value): + return qry.join( + cls.MODEL.observation + ).join( + PointObservation.observer + ).filter( + Observer.name == value + ) - return df + @property + def all_instruments(self): + """ + Return all distinct instruments in the data + """ + with db_session(self.DB_NAME) as (session, engine): + result = session.query(Instrument.name).filter( + Instrument.id.in_( + session.query(PointObservation.instrument_id).distinct() + ) + ).all() + return self.retrieve_single_value_result(result) class TooManyRastersException(Exception): - """ Exceptiont to report to users that their query will produce too many rasters""" + """ + Exception to report to users that their query will produce too many + rasters + """ pass -class LayerMeasurements(PointMeasurements): +class LayerMeasurements(BaseDataset): """ API class for access to LayerData """ diff --git a/snowexsql/conversions.py b/snowexsql/conversions.py index 18dfa64..0de0598 100644 --- a/snowexsql/conversions.py +++ b/snowexsql/conversions.py @@ -3,7 +3,6 @@ filetypes, datatypes, etc. Many tools here will be useful for most end users of the database. """ - import geopandas as gpd import pandas as pd from geoalchemy2.shape import to_shape @@ -54,8 +53,10 @@ def query_to_geopandas(query, engine, **kwargs): # Fill out the variables in the query sql = query.statement.compile(dialect=postgresql.dialect()) - # Get dataframe from geopandas using the query and engine - df = gpd.GeoDataFrame.from_postgis(sql, engine, **kwargs) + # Get dataframe from geopandas using the query and the DB connection. + # By passing in the actual connection, we maintain ownership of it and + # keep it alive until we close it. + df = gpd.read_postgis(sql, engine.connect(), **kwargs) return df diff --git a/snowexsql/db.py b/snowexsql/db.py index 5d5a4d8..13333ad 100644 --- a/snowexsql/db.py +++ b/snowexsql/db.py @@ -10,6 +10,11 @@ from snowexsql.tables.base import Base +# This library requires a postgres dialect and the psycopg2 driver +DB_CONNECTION_PROTOCOL = 'postgresql+psycopg2://' +# Always create a Session in UTC time +DB_CONNECTION_OPTIONS = {"options": "-c timezone=UTC"} + def initialize(engine): """ @@ -22,55 +27,73 @@ def initialize(engine): meta.create_all(bind=engine) -def get_db(db_str, credentials=None, return_metadata=False): +def load_credentials(credentials_path): + """ + Load username and password from a user supplied credential file + + Args: + credentials_path (string): Full path to credentials file + """ + with open(credentials_path) as fp: + creds = json.load(fp) + return creds['username'], creds['password'] + + +def db_connection_string(db_name, credentials_path=None): + """ + Construct a connection info string for SQLAlchemy database + + Args: + db_name: The name of the database to connect to + credentials_path: Path to a json file containing username and password + for the database + + Returns: + String - DB connection + """ + db = DB_CONNECTION_PROTOCOL + + if credentials_path is not None: + username, password = load_credentials(credentials_path) + db += f"{username}:{password}@{db_name}" + else: + db += f"{db_name}" + + return db + + +def get_db(db_name, credentials=None, return_metadata=False): """ Returns the DB engine, MetaData, and session object Args: - db_str: Just the name of the database - credentials: Path to a json file containing username and password for the database + db_name: The name of the database to connect to + credentials: Path to a json file containing username and password for + the database return_metadata: Boolean indicating whether the metadata object is being returned, useful only for developers Returns: tuple: **engine** - sqlalchemy Engine object for directly sending - querys to the DB + queries to the DB **session** - sqlalchemy Session Object for using object relational mapping (ORM) **metadata** (optional) - sqlalchemy MetaData object for modifying the database """ + db_connection = db_connection_string(db_name, credentials) - # This library requires a postgres dialect and the psycopg2 driver - prefix = f'postgresql+psycopg2://' - - if credentials is not None: - # Read in the credentials - with open(credentials) as fp: - creds = json.load(fp) - username = creds['username'] - password = creds['password'] - - db = f"{prefix}{username}:{password}@{db_str}" - else: - db = f"{prefix}{db_str}" - - # Always create a Session in UTC time engine = create_engine( - db, echo=False, connect_args={ - "options": "-c timezone=UTC"}) + db_connection, echo=False, connect_args=DB_CONNECTION_OPTIONS + ) - Session = sessionmaker(bind=engine) - metadata = MetaData() - session = Session(expire_on_commit=False) + session = sessionmaker(bind=engine) + session = session(expire_on_commit=False) if return_metadata: - result = (engine, session, metadata) - + return engine, session, MetaData() else: - result = (engine, session) - - return result + return engine, session def get_table_attributes(DataCls): diff --git a/snowexsql/tables/__init__.py b/snowexsql/tables/__init__.py index 55f6f1e..9639a1a 100644 --- a/snowexsql/tables/__init__.py +++ b/snowexsql/tables/__init__.py @@ -1,21 +1,25 @@ +from .campaign import Campaign +from .doi import DOI from .image_data import ImageData +from .image_observation import ImageObservation +from .instrument import Instrument from .layer_data import LayerData -from .point_data import PointData +from .measurement_type import MeasurementType from .observers import Observer -from .instrument import Instrument -from .campaign import Campaign +from .point_data import PointData +from .point_observation import PointObservation from .site import Site -from .doi import DOI -from .measurement_type import MeasurementType __all__ = [ "Campaign", "DOI", "ImageData", + "ImageObservation", "Instrument", "LayerData", "MeasurementType", "Observer", "PointData", + "PointObservation", "Site", ] diff --git a/snowexsql/tables/base.py b/snowexsql/tables/base.py index f664e5d..f3af319 100644 --- a/snowexsql/tables/base.py +++ b/snowexsql/tables/base.py @@ -1,12 +1,4 @@ -""" -Module contains all the data models for the database. Classes here actually -represent tables where columns are mapped as attributed. Any class inheriting -from Base is a real table in the database. This is called Object Relational -Mapping in the sqlalchemy or ORM. -""" - -from geoalchemy2 import Geometry -from sqlalchemy import Column, Float, Integer, Time, Date +from sqlalchemy import Column, Date, Integer from sqlalchemy.orm import DeclarativeBase @@ -18,15 +10,3 @@ class Base(DeclarativeBase): __table_args__ = {"schema": "public"} # Primary Key id = Column(Integer, primary_key=True) - - -class SingleLocationData: - """ - Base class for points and profiles - """ - elevation = Column(Float) - geom = Column(Geometry("POINT")) - time = Column(Time(timezone=True)) - - - diff --git a/snowexsql/tables/campaign_observation.py b/snowexsql/tables/campaign_observation.py new file mode 100644 index 0000000..123f276 --- /dev/null +++ b/snowexsql/tables/campaign_observation.py @@ -0,0 +1,42 @@ +from sqlalchemy import Column, Date, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from .base import Base +from .campaign import InCampaign +from .doi import HasDOI +from .instrument import HasInstrument +from .measurement_type import HasMeasurementType +from .observers import HasObserver + + +class CampaignObservation( + Base, HasDOI, HasInstrument, HasMeasurementType, HasObserver, InCampaign +): + """ + A campaign observation holds additional information for a point or image. + This is a parent table that has a 'type' column to use for single table + inheritance. The PointObservation and ImageObservation tables use this. + """ + __tablename__ = 'campaign_observations' + + # Data columns + name = Column(Text) + description = Column(Text) + date = Column(Date, nullable=False) + + # Single Table Inheritance column + type = Column(String, nullable=False) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + +class HasObservation: + """ + Class to inherit when adding a observation relationship to a table + """ + + observation_id: Mapped[int] = mapped_column( + ForeignKey("public.campaign_observations.id"), index=True + ) diff --git a/snowexsql/tables/image_data.py b/snowexsql/tables/image_data.py index 33472c2..706ec80 100644 --- a/snowexsql/tables/image_data.py +++ b/snowexsql/tables/image_data.py @@ -1,20 +1,13 @@ from geoalchemy2 import Raster -from sqlalchemy import Column, String, Date +from sqlalchemy import Column from .base import Base -from .campaign import InCampaign -from .instrument import HasInstrument -from .measurement_type import HasMeasurementType -from .doi import HasDOI +from .image_observation import HasImageObservation -class ImageData(Base, HasMeasurementType, HasInstrument, HasDOI, InCampaign): +class ImageData(Base, HasImageObservation): """ Class representing the images table. This table holds all images/rasters """ __tablename__ = 'images' - # Date of the measurement - date = Column(Date) raster = Column(Raster) - description = Column(String()) - units = Column(String(50)) diff --git a/snowexsql/tables/image_observation.py b/snowexsql/tables/image_observation.py new file mode 100644 index 0000000..105d6df --- /dev/null +++ b/snowexsql/tables/image_observation.py @@ -0,0 +1,24 @@ +from sqlalchemy.orm import Mapped, declared_attr, relationship + +from .campaign_observation import CampaignObservation, HasObservation + + +class ImageObservation(CampaignObservation): + """ + Class to hold specific methods to query image observations from + the campaign_observations table + """ + # Single Table Inheritance identifier + __mapper_args__ = { + 'polymorphic_identity': 'ImageObservation', + 'polymorphic_load': 'inline', + } + + +class HasImageObservation(HasObservation): + """ + Class to inherit when adding a observation relationship to a table + """ + @declared_attr + def observation(self) -> Mapped[ImageObservation]: + return relationship("ImageObservation") diff --git a/snowexsql/tables/measurement_type.py b/snowexsql/tables/measurement_type.py index e1ea189..7f1eebd 100644 --- a/snowexsql/tables/measurement_type.py +++ b/snowexsql/tables/measurement_type.py @@ -1,5 +1,5 @@ -from sqlalchemy import Boolean, Column, Text, Integer, ForeignKey -from sqlalchemy.orm import relationship, declared_attr +from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text +from sqlalchemy.orm import declared_attr, relationship from .base import Base @@ -15,17 +15,17 @@ class MeasurementType(Base): derived = Column(Boolean, default=False) - class HasMeasurementType: """ Class to extend when including a measurement type """ @declared_attr - def measurement_type_id(cls): - return Column(Integer, ForeignKey('public.measurement_type.id'), - index=True) + def measurement_type_id(self): + return Column( + Integer, ForeignKey('public.measurement_type.id'), index=True + ) @declared_attr - def measurement(cls): + def measurement_type(self): return relationship('MeasurementType') diff --git a/snowexsql/tables/observers.py b/snowexsql/tables/observers.py index 07f1e0e..42ad9a6 100644 --- a/snowexsql/tables/observers.py +++ b/snowexsql/tables/observers.py @@ -1,5 +1,5 @@ -from sqlalchemy.orm import mapped_column, Mapped -from sqlalchemy import Column, String +from sqlalchemy import Column, ForeignKey, String +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship from .base import Base @@ -10,3 +10,17 @@ class Observer(Base): id: Mapped[int] = mapped_column(primary_key=True) # Name of the observer name = Column(String()) + + +class HasObserver: + """ + Class to inherit when adding a observer relationship to a table + """ + + observers_id: Mapped[int] = mapped_column( + ForeignKey("public.observers.id"), index=True + ) + + @declared_attr + def observer(self) -> Mapped[Observer]: + return relationship('Observer') diff --git a/snowexsql/tables/point_data.py b/snowexsql/tables/point_data.py index d5a6bbf..6701d4e 100644 --- a/snowexsql/tables/point_data.py +++ b/snowexsql/tables/point_data.py @@ -1,29 +1,12 @@ -from sqlalchemy import Column, Float, Integer, String, ForeignKey, Date -from sqlalchemy.orm import Mapped, relationship, mapped_column -from typing import List +from sqlalchemy import Column, Date, Float, Integer, String +from sqlalchemy.ext.hybrid import hybrid_property -from .base import Base, SingleLocationData -from .campaign import InCampaign -from .doi import HasDOI -from .measurement_type import HasMeasurementType -from .observers import Observer -from .instrument import HasInstrument +from .base import Base +from .point_observation import HasPointObservation +from .single_location import SingleLocationData -class PointObservers(Base): - """ - Link table - """ - __tablename__ = 'point_observers' - - point_id = Column(Integer, ForeignKey('public.points.id')) - observer_id = Column(Integer, ForeignKey("public.observers.id")) - - -class PointData( - SingleLocationData, HasMeasurementType, HasInstrument, Base, HasDOI, - InCampaign -): +class PointData(Base, SingleLocationData, HasPointObservation): """ Class representing the points table. This table holds all point data. Here a single data entry is a single coordinate pair with a single value @@ -31,9 +14,6 @@ class PointData( """ __tablename__ = 'points' - # Date of the measurement - date = Column(Date) - version_number = Column(Integer) equipment = Column(String()) value = Column(Float) @@ -41,8 +21,16 @@ class PointData( # bring these in instead of Measurement units = Column(String()) - # id is a mapped column for many-to-many with observers - id: Mapped[int] = mapped_column(primary_key=True) - observers: Mapped[List[Observer]] = relationship( - secondary=PointObservers.__table__ - ) + @hybrid_property + def date(self): + """ + Helper attribute to only query for dates of measurements + """ + return self.datetime.date() + + @date.expression + def date(cls): + """ + Helper attribute to only query for dates of measurements + """ + return cls.datetime.cast(Date) diff --git a/snowexsql/tables/point_observation.py b/snowexsql/tables/point_observation.py new file mode 100644 index 0000000..bd2f9a1 --- /dev/null +++ b/snowexsql/tables/point_observation.py @@ -0,0 +1,24 @@ +from sqlalchemy.orm import Mapped, declared_attr, relationship + +from .campaign_observation import CampaignObservation, HasObservation + + +class PointObservation(CampaignObservation): + """ + Class to hold specific methods to query points observations from + the campaign_observations table + """ + # Single Table Inheritance identifier + __mapper_args__ = { + 'polymorphic_identity': 'PointObservation', + 'polymorphic_load': 'inline', + } + + +class HasPointObservation(HasObservation): + """ + Class to inherit when adding a observation relationship to a table + """ + @declared_attr + def observation(self) -> Mapped[PointObservation]: + return relationship("PointObservation") diff --git a/snowexsql/tables/single_location.py b/snowexsql/tables/single_location.py new file mode 100644 index 0000000..92052c9 --- /dev/null +++ b/snowexsql/tables/single_location.py @@ -0,0 +1,12 @@ +from geoalchemy2 import Geometry +from sqlalchemy import Column, DateTime, Float + + +class SingleLocationData: + """ + Base class for point and layer data + """ + # Date of the measurement with time + datetime = Column(DateTime(timezone=True)) + elevation = Column(Float) + geom = Column(Geometry("POINT")) diff --git a/snowexsql/tables/site.py b/snowexsql/tables/site.py index ac96d3f..b8bf665 100644 --- a/snowexsql/tables/site.py +++ b/snowexsql/tables/site.py @@ -1,18 +1,14 @@ -# -*- coding: utf-8 -*- -""" -Created on Thu Aug 22 11:56:34 2024 - -@author: jtmaz -""" - -from sqlalchemy import Column, String, Date, Float, Integer, ForeignKey -from sqlalchemy.orm import Mapped, relationship, mapped_column from typing import List -from .observers import Observer -from .base import Base, SingleLocationData +from sqlalchemy import Column, Date, Float, ForeignKey, Integer, String +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .base import Base from .campaign import InCampaign from .doi import HasDOI +from .observers import Observer +from .single_location import SingleLocationData class SiteObservers(Base): @@ -34,8 +30,6 @@ class Site(SingleLocationData, Base, InCampaign, HasDOI): name = Column(String()) # This can be pit_id description = Column(String()) - # Date of the measurement - date = Column(Date) # Link the observer # id is a mapped column for many-to-many with observers @@ -59,3 +53,17 @@ class Site(SingleLocationData, Base, InCampaign, HasDOI): vegetation_height = Column(String()) tree_canopy = Column(String()) site_notes = Column(String()) + + @hybrid_property + def date(self): + """ + Helper attribute to only query for dates of measurements + """ + return self.datetime.date() + + @date.expression + def date(cls): + """ + Helper attribute to only query for dates of measurements + """ + return cls.datetime.cast(Date) diff --git a/tests/api/test_layer_measurements.py b/tests/api/test_layer_measurements.py index a761476..5ed8328 100644 --- a/tests/api/test_layer_measurements.py +++ b/tests/api/test_layer_measurements.py @@ -16,7 +16,7 @@ class TestLayerMeasurements(DBConnection): def test_all_types(self, clz): result = clz().all_types - assert result == ["depth", "density"] + assert result == ["density"] def test_all_site_names(self, clz): result = clz().all_site_names diff --git a/tests/api/test_point_measurements.py b/tests/api/test_point_measurements.py index 8ddbc94..52fd14e 100644 --- a/tests/api/test_point_measurements.py +++ b/tests/api/test_point_measurements.py @@ -1,78 +1,168 @@ -from datetime import date +""" +Test the Point Measurement class +""" +from datetime import date, timedelta import geopandas as gpd -import numpy as np import pytest +from geoalchemy2.shape import to_shape from snowexsql.api import PointMeasurements -from tests import DBConnection +from snowexsql.tables import PointData -class TestPointMeasurements(DBConnection): - """ - Test the Point Measurement class - """ - CLZ = PointMeasurements +@pytest.fixture +def point_data_x_y(point_data_factory): + return to_shape(point_data_factory.build().geom) - def test_all_types(self, clz): - result = clz().all_types - assert result == ['depth', 'density'] - def test_all_site_names(self, clz): - result = clz().all_site_names - assert result == ['Grand Mesa'] +@pytest.fixture +def point_data_srid(point_data_factory): + return point_data_factory.build().geom.srid - def test_all_dates(self, clz): - result = clz().all_dates - assert len(result) == 1 - def test_all_observers(self, clz): - result = clz().all_observers - assert result == ['TEST'] +@pytest.fixture +def point_data(point_data_factory, db_session): + point_data_factory.create() + return db_session.query(PointData).all() - def test_all_instruments(self, clz): - result = clz().all_instruments - assert result == ["magnaprobe"] - def test_all_dois(self, clz): - result = clz().all_dois - assert result == ['fake_doi', 'fake_doi2'] +@pytest.mark.usefixtures("db_test_session") +@pytest.mark.usefixtures("db_test_connection") +@pytest.mark.usefixtures("point_data") +class TestPointMeasurements: + @pytest.fixture(autouse=True) + def setup_method(self, point_data): + self.subject = PointMeasurements() + self.db_data = point_data - @pytest.mark.parametrize( - "kwargs, expected_length, mean_value", [ - ({ - "date": date(2020, 5, 28), - "instrument": 'camera' - }, 0, np.nan), - ({"instrument": "magnaprobe", "limit": 10}, 1, 94.0), - # limit works - ({ - "date": date(2020, 5, 28), - "instrument": 'pit ruler' - }, 0, np.nan), - ({ - "date_less_equal": date(2019, 10, 1), - }, 0, np.nan), - ({ - "date_greater_equal": date(2020, 6, 7), - }, 0, np.nan), - ({ - "doi": "fake_doi", - }, 1, 94.0), - ({ - "type": 'depth', - }, 1, 94.0), - ({ - "observer": 'TEST', - "campaign": 'Grand Mesa' - }, 1, 94.0), + def test_all_types(self): + result = self.subject.all_types + assert result == [ + record.observation.measurement_type.name + for record in self.db_data + ] + + def test_all_site_names(self): + result = self.subject.all_site_names + assert result == [ + record.observation.campaign.name + for record in self.db_data + ] + + def test_all_dates(self): + result = self.subject.all_dates + assert result == [ + record.date + for record in self.db_data + ] + + def test_all_observers(self): + result = self.subject.all_observers + assert result == [ + record.observation.observer.name + for record in self.db_data + ] + + def test_all_instruments(self): + result = self.subject.all_instruments + assert result == [ + record.observation.instrument.name + for record in self.db_data + ] + + def test_all_dois(self): + result = self.subject.all_dois + assert result == [ + record.observation.doi.doi + for record in self.db_data ] - ) - def test_from_filter(self, clz, kwargs, expected_length, mean_value): - result = clz.from_filter(**kwargs) - assert len(result) == expected_length - if expected_length > 0: - assert pytest.approx(result["value"].mean()) == mean_value + + +@pytest.mark.usefixtures("db_test_session") +@pytest.mark.usefixtures("db_test_connection") +@pytest.mark.usefixtures("point_data") +class TestPointMeasurementFilter: + @pytest.fixture(autouse=True) + def setup_method(self, point_data): + self.subject = PointMeasurements() + # Pick the first record for this test case + self.db_data = point_data[0] + + def test_date_and_instrument(self): + result = self.subject.from_filter( + date=self.db_data.datetime.date(), + instrument=self.db_data.observation.instrument.name, + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value + + def test_instrument_and_limit(self, point_data_factory): + # Create 10 more records, but only fetch five + point_data_factory.create_batch(10) + + result = self.subject.from_filter( + instrument=self.db_data.observation.instrument.name, + limit=5 + ) + assert len(result) == 5 + assert pytest.approx(result["value"].mean()) == self.db_data.value + + def test_no_instrument_on_date(self): + result = self.subject.from_filter( + date=self.db_data.datetime.date() + timedelta(days=1), + instrument=self.db_data.observation.instrument.name, + ) + assert len(result) == 0 + + def test_unknown_instrument(self): + result = self.subject.from_filter( + instrument='Does not exist', + ) + assert len(result) == 0 + + def test_date_and_measurement_type(self): + result = self.subject.from_filter( + date=self.db_data.datetime.date(), + type=self.db_data.observation.measurement_type.name, + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value + + def test_doi(self): + result = self.subject.from_filter( + doi=self.db_data.observation.doi.doi, + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value + + def test_observer_in_campaign(self): + result = self.subject.from_filter( + observer=self.db_data.observation.observer.name, + campaign=self.db_data.observation.campaign.name, + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value + + def test_date_less_equal(self, point_data_factory): + greater_date = self.db_data.datetime.date() + timedelta(days=1) + point_data_factory.create(datetime=greater_date) + + result = self.subject.from_filter( + date_less_equal=self.db_data.datetime.date(), + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value + + def test_date_greater_equal(self, point_data_factory): + greater_date = self.db_data.datetime.date() - timedelta(days=1) + point_data_factory.create(datetime=greater_date) + + result = self.subject.from_filter( + date_greater_equal=self.db_data.datetime.date(), + ) + assert len(result) == 1 + assert result.loc[0].value == self.db_data.value @pytest.mark.parametrize( "kwargs, expected_error", [ @@ -81,28 +171,27 @@ def test_from_filter(self, clz, kwargs, expected_length, mean_value): ({"date": [date(2020, 5, 28), date(2019, 10, 3)]}, ValueError), ] ) - def test_from_filter_fails(self, clz, kwargs, expected_error): + def test_from_filter_fails(self, kwargs, expected_error): """ Test failure on not-allowed key and too many returns """ with pytest.raises(expected_error): - clz.from_filter(**kwargs) + self.subject.from_filter(**kwargs) - def test_from_area(self, clz): + def test_from_area(self, point_data_x_y, point_data_srid): shp = gpd.points_from_xy( - [743766.4794971556], [4321444.154620216], crs="epsg:26912" + [point_data_x_y.x], + [point_data_x_y.y], + crs=f"epsg:{point_data_srid}" ).buffer(10)[0] - result = clz.from_area( - shp=shp, - date=date(2019, 10, 30) - ) - assert len(result) == 0 + result = self.subject.from_area(shp=shp) + assert len(result) == 1 - def test_from_area_point(self, clz): - pts = gpd.points_from_xy([743766.4794971556], [4321444.154620216]) - crs = "26912" - result = clz.from_area( - pt=pts[0], buffer=10, crs=crs, - date=date(2019, 10, 30) + def test_from_area_point(self, point_data_x_y, point_data_srid): + pts = gpd.points_from_xy( + [point_data_x_y.x], + [point_data_x_y.y], ) - assert len(result) == 0 + crs = point_data_srid + result = self.subject.from_area(pt=pts[0], buffer=10, crs=crs) + assert len(result) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index e69de29..a55d362 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -0,0 +1,108 @@ +import os +from contextlib import contextmanager + +import pytest +from pytest_factoryboy import register +from sqlalchemy import create_engine + +import snowexsql +from snowexsql.db import ( + DB_CONNECTION_OPTIONS, db_connection_string, initialize +) +from tests.factories import (CampaignFactory, DOIFactory, InstrumentFactory, + LayerDataFactory, MeasurementTypeFactory, + ObserverFactory, PointDataFactory, + PointObservationFactory, SiteFactory) +from .db_setup import CREDENTIAL_FILE, DB_INFO, SESSION + +# Make factories available to tests +register(CampaignFactory) +register(DOIFactory) +register(InstrumentFactory) +register(LayerDataFactory) +register(MeasurementTypeFactory) +register(ObserverFactory) +register(PointDataFactory) +register(PointObservationFactory) +register(SiteFactory) + + +# Add this factory to a test if you would like to debug the SQL statement +# It will print the query from the BaseDataset.from_filter() method +@pytest.fixture(scope='session') +def _debug_sql_query(): + os.environ['DEBUG_QUERY'] = '1' + + +@pytest.fixture(scope='function') +def db_test_session(monkeypatch, sqlalchemy_engine): + """ + Ensure that we are using the same connection across the test suite and in + the API when initiating a session. + """ + @contextmanager + def db_session(*args, **kwargs): + yield SESSION(), sqlalchemy_engine + + monkeypatch.setattr(snowexsql.api, "db_session", db_session) + + +@pytest.fixture(scope='function') +def db_test_connection(monkeypatch, sqlalchemy_engine, connection): + def test_connection(): + return connection + + monkeypatch.setattr(sqlalchemy_engine, 'connect', test_connection) + + +@pytest.fixture(scope='session') +def test_db_info(): + database_name = DB_INFO["address"] + "/" + DB_INFO["db_name"] + return db_connection_string(database_name, CREDENTIAL_FILE) + + +@pytest.fixture(scope='session') +def sqlalchemy_engine(test_db_info): + engine = create_engine( + test_db_info, + pool_pre_ping=True, + connect_args={ + 'connect_timeout': 10, + **DB_CONNECTION_OPTIONS + } + ) + initialize(engine) + + yield engine + + engine.dispose() + + +@pytest.fixture(scope="session") +def connection(sqlalchemy_engine): + with sqlalchemy_engine.connect() as connection: + # Configure session + SESSION.configure( + bind=connection, join_transaction_mode="create_savepoint" + ) + + yield connection + + +@pytest.fixture(scope="function", autouse=True) +def db_session(connection): + # Based on: + # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites ## noqa + + transaction = connection.begin() + + # Create a new session + session = SESSION() + + yield session + + # rollback + # Everything that happened with the Session above + # (including calls to commit()) are rolled back. + session.close() + transaction.rollback() diff --git a/tests/db_connection.py b/tests/db_connection.py index 7efe8e9..4c6f29a 100644 --- a/tests/db_connection.py +++ b/tests/db_connection.py @@ -1,4 +1,4 @@ -from datetime import date, time +from datetime import datetime import pytest from geoalchemy2.elements import WKTElement @@ -7,7 +7,7 @@ PointMeasurements, db_session ) from snowexsql.tables import DOI, Instrument, LayerData, MeasurementType, \ - Observer, PointData, Site + Observer, Site from snowexsql.tables.campaign import Campaign from .db_setup import DBSetup @@ -77,9 +77,8 @@ def _add_entry( session, Site, dict(name=site_name), object_kwargs=dict( name=site_name, campaign=campaign, - date=kwargs.pop("date"), + datetime=kwargs.pop("datetime"), geom=kwargs.pop("geom"), - time=kwargs.pop("time"), elevation=kwargs.pop("elevation"), observers=observer_list, ) @@ -98,7 +97,9 @@ def _add_entry( ) object_kwargs = dict( - instrument=instrument, doi=doi, measurement=measurement_obj, + instrument=instrument, + doi=doi, + measurement_type=measurement_obj, **kwargs ) # Add site if given @@ -113,33 +114,11 @@ def _add_entry( session.add(new_entry) session.commit() - @pytest.fixture(scope="class") - def populated_points(self, db): - # Add made up data at the initialization of the class - row = { - 'date': date(2020, 1, 28), - 'time': time(18, 48), - 'elevation': 3148.2, - 'equipment': 'CRREL_B', - 'version_number': 1, - 'geom': WKTElement( - "POINT(747987.6190615438 4324061.7062127385)", srid=26912 - ), - 'value': 94 - } - self._add_entry( - db.url, PointData, 'magnaprobe', ["TEST"], - 'Grand Mesa', None, - "fake_doi", "depth", - **row - ) - @pytest.fixture(scope="class") def populated_layer(self, db): # Fake data to implement row = { - 'date': date(2020, 1, 28), - 'time': time(18, 48), + 'datetime': datetime(2020, 1, 28, 18, 48), 'elevation': 3148.2, 'geom': WKTElement( "POINT(747987.6190615438 4324061.7062127385)", srid=26912 @@ -154,7 +133,7 @@ def populated_layer(self, db): ) @pytest.fixture(scope="class") - def clz(self, populated_points, populated_layer): + def clz(self, populated_layer): """ Extend the class and overwrite the database name """ diff --git a/tests/db_setup.py b/tests/db_setup.py index 0a63557..47e43cc 100644 --- a/tests/db_setup.py +++ b/tests/db_setup.py @@ -1,7 +1,17 @@ import json from os.path import dirname, join +from sqlalchemy import orm + from snowexsql.db import get_db, initialize +from snowexsql.tables import (Campaign, DOI, Instrument, LayerData, + MeasurementType, Observer, Site) +from snowexsql.tables.site import SiteObservers + +# DB Configuration and Session +CREDENTIAL_FILE = join(dirname(__file__), 'credentials.json') +DB_INFO = json.load(open(CREDENTIAL_FILE)) +SESSION = orm.scoped_session(orm.sessionmaker()) class DBSetup: @@ -9,8 +19,8 @@ class DBSetup: Base class for all our tests. Ensures that we clean up after every class that's run """ - CREDENTIAL_FILE = join(dirname(__file__), 'credentials.json') - DB_INFO = json.load(open(CREDENTIAL_FILE)) + CREDENTIAL_FILE = CREDENTIAL_FILE + DB_INFO = DB_INFO @classmethod def database_name(cls): @@ -37,6 +47,16 @@ def teardown_class(cls): NOTE: Not dropping the DB since this is done at every test class initialization """ - cls.session.flush() - cls.session.rollback() + # TODO - Hack to make different data loading methods co-exist + # Remove this once we switch all to using factory boy + cls.session.query(LayerData).delete() + cls.session.query(SiteObservers).delete() + cls.session.query(Observer).delete() + cls.session.query(Site).delete() + cls.session.query(Instrument).delete() + cls.session.query(Campaign).delete() + cls.session.query(DOI).delete() + cls.session.query(MeasurementType).delete() + cls.session.commit() + cls.session.close() diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py new file mode 100644 index 0000000..e4a9dad --- /dev/null +++ b/tests/factories/__init__.py @@ -0,0 +1,21 @@ +from .campaign import CampaignFactory +from .doi import DOIFactory +from .instrument import InstrumentFactory +from .layer_data import LayerDataFactory +from .measurement_type import MeasurementTypeFactory +from .observer import ObserverFactory +from .point_data import PointDataFactory +from .point_observation import PointObservationFactory +from .site import SiteFactory + +__all__ = [ + "CampaignFactory", + "DOIFactory", + "InstrumentFactory", + "LayerDataFactory", + "MeasurementTypeFactory", + "ObserverFactory", + "PointDataFactory", + "PointObservationFactory", + "SiteFactory", +] diff --git a/tests/factories/base_factory.py b/tests/factories/base_factory.py new file mode 100644 index 0000000..6c02882 --- /dev/null +++ b/tests/factories/base_factory.py @@ -0,0 +1,9 @@ +import factory.alchemy as factory_alchemy + +from tests.db_setup import SESSION + + +class BaseFactory(factory_alchemy.SQLAlchemyModelFactory): + class Meta: + sqlalchemy_session = SESSION + sqlalchemy_session_persistence = 'commit' diff --git a/tests/factories/campaign.py b/tests/factories/campaign.py new file mode 100644 index 0000000..9515847 --- /dev/null +++ b/tests/factories/campaign.py @@ -0,0 +1,11 @@ +from snowexsql.tables import Campaign + +from .base_factory import BaseFactory + + +class CampaignFactory(BaseFactory): + class Meta: + model = Campaign + + name = 'Snow Campaign' + description = 'Snow Campaign Description' diff --git a/tests/factories/doi.py b/tests/factories/doi.py new file mode 100644 index 0000000..aa84131 --- /dev/null +++ b/tests/factories/doi.py @@ -0,0 +1,14 @@ +import datetime + +import factory + +from snowexsql.tables import DOI +from .base_factory import BaseFactory + + +class DOIFactory(BaseFactory): + class Meta: + model = DOI + + doi = '111-222' + date_accessed = factory.LazyFunction(datetime.date.today) diff --git a/tests/factories/instrument.py b/tests/factories/instrument.py new file mode 100644 index 0000000..7e04dba --- /dev/null +++ b/tests/factories/instrument.py @@ -0,0 +1,11 @@ +from snowexsql.tables import Instrument +from .base_factory import BaseFactory + + +class InstrumentFactory(BaseFactory): + class Meta: + model = Instrument + + name = 'SWE Instrument' + model = 'Instrument Model' + specifications = 'Measures SWE well' diff --git a/tests/factories/layer_data.py b/tests/factories/layer_data.py new file mode 100644 index 0000000..5d51a7c --- /dev/null +++ b/tests/factories/layer_data.py @@ -0,0 +1,28 @@ +import factory + +from snowexsql.tables import LayerData +from .base_factory import BaseFactory +from .doi import DOIFactory +from .instrument import InstrumentFactory +from .measurement_type import MeasurementTypeFactory +from .site import SiteFactory + + +class LayerDataFactory(BaseFactory): + class Meta: + model = LayerData + + depth = 100.0 + pit_id = 'Pit 123' + bottom_depth = 90.0 + comments = 'Layer comment' + value = '40' + flags = 'Sample' + sample_a = '42' + + measurement_type = factory.SubFactory( + MeasurementTypeFactory, name='Density', units='kg/m3' + ) + instrument = factory.SubFactory(InstrumentFactory, name='Density Cutter') + doi = factory.SubFactory(DOIFactory) + site = factory.SubFactory(SiteFactory) diff --git a/tests/factories/measurement_type.py b/tests/factories/measurement_type.py new file mode 100644 index 0000000..7954552 --- /dev/null +++ b/tests/factories/measurement_type.py @@ -0,0 +1,10 @@ +from snowexsql.tables import MeasurementType +from .base_factory import BaseFactory + + +class MeasurementTypeFactory(BaseFactory): + class Meta: + model = MeasurementType + + name = 'SWE' + units = 'mm' diff --git a/tests/factories/observer.py b/tests/factories/observer.py new file mode 100644 index 0000000..5e53f8d --- /dev/null +++ b/tests/factories/observer.py @@ -0,0 +1,11 @@ +import factory + +from snowexsql.tables import Observer +from .base_factory import BaseFactory + + +class ObserverFactory(BaseFactory): + class Meta: + model = Observer + + name = factory.Faker('name') diff --git a/tests/factories/point_data.py b/tests/factories/point_data.py new file mode 100644 index 0000000..9a944d7 --- /dev/null +++ b/tests/factories/point_data.py @@ -0,0 +1,23 @@ +from datetime import datetime, timezone + +import factory +from geoalchemy2 import WKTElement + +from snowexsql.tables import PointData +from .base_factory import BaseFactory +from .point_observation import PointObservationFactory + + +class PointDataFactory(BaseFactory): + class Meta: + model = PointData + + value = 10 + datetime = factory.LazyFunction(lambda: datetime.now(timezone.utc)) + + geom = WKTElement( + "POINT(747987.6190615438 4324061.7062127385)", srid=26912 + ) + elevation = 3148.2 + + observation = factory.SubFactory(PointObservationFactory) diff --git a/tests/factories/point_observation.py b/tests/factories/point_observation.py new file mode 100644 index 0000000..4eaf56a --- /dev/null +++ b/tests/factories/point_observation.py @@ -0,0 +1,26 @@ +import datetime + +import factory + +from snowexsql.tables.point_observation import PointObservation +from .base_factory import BaseFactory +from .campaign import CampaignFactory +from .doi import DOIFactory +from .instrument import InstrumentFactory +from .measurement_type import MeasurementTypeFactory +from .observer import ObserverFactory + + +class PointObservationFactory(BaseFactory): + class Meta: + model = PointObservation + + name = 'Point Observation' + description = 'Point Description' + date = factory.LazyFunction(datetime.date.today) + + campaign = factory.SubFactory(CampaignFactory) + doi = factory.SubFactory(DOIFactory) + instrument = factory.SubFactory(InstrumentFactory) + measurement_type = factory.SubFactory(MeasurementTypeFactory) + observer = factory.SubFactory(ObserverFactory) diff --git a/tests/factories/site.py b/tests/factories/site.py new file mode 100644 index 0000000..f70e291 --- /dev/null +++ b/tests/factories/site.py @@ -0,0 +1,51 @@ +from datetime import datetime, timezone + +import factory +from geoalchemy2 import WKTElement + +from snowexsql.tables.site import Site +from .base_factory import BaseFactory +from .campaign import CampaignFactory +from .doi import DOIFactory + + +class SiteFactory(BaseFactory): + class Meta: + model = Site + + name = 'Site Name' + description = 'Site Description' + datetime = factory.LazyFunction(lambda: datetime.now(timezone.utc)) + + slope_angle = 0.0 + aspect = 0.0 + air_temp = -5.0 + total_depth = 100.5 + weather_description = "Weather Conditions" + precip = "None" + sky_cover = "Clear" + wind = "Light" + ground_condition = "Frozen" + ground_roughness = "Smooth" + ground_vegetation = "Bare" + vegetation_height = "None" + tree_canopy = "Open" + site_notes = "Site Notes" + + # Single Location data + geom = WKTElement( + "POINT(747987.6190615438 4324061.7062127385)", srid=26912 + ) + elevation = 3148.2 + + campaign = factory.SubFactory(CampaignFactory, name="Snow Campaign 2") + doi = factory.SubFactory(DOIFactory, doi='222-333') + + @factory.post_generation + def observers(self, create, extracted, **kwargs): + if not create or not extracted: + # Simple build, or nothing to add, do nothing. + return + + # Add the iterable of groups using bulk addition + self.observers.append(extracted) diff --git a/tests/tables/test_campaign.py b/tests/tables/test_campaign.py new file mode 100644 index 0000000..c093a37 --- /dev/null +++ b/tests/tables/test_campaign.py @@ -0,0 +1,21 @@ +import pytest + +from snowexsql.tables import Campaign + + +@pytest.fixture +def campaign_record(campaign_factory, db_session): + campaign_factory.create() + return db_session.query(Campaign).first() + + +class TestCampaign: + @pytest.fixture(autouse=True) + def setup_method(self, campaign_record): + self.subject = campaign_record + + def test_campaign_name(self, campaign_factory): + assert self.subject.name == campaign_factory.name + + def test_campaign_description(self, campaign_factory): + assert self.subject.description == campaign_factory.description diff --git a/tests/tables/test_doi.py b/tests/tables/test_doi.py new file mode 100644 index 0000000..405ff9c --- /dev/null +++ b/tests/tables/test_doi.py @@ -0,0 +1,23 @@ +from datetime import date + +import pytest + +from snowexsql.tables import DOI + + +@pytest.fixture +def doi_record(doi_factory, db_session): + doi_factory.create() + return db_session.query(DOI).first() + + +class TestDOI: + @pytest.fixture(autouse=True) + def setup_method(self, doi_record): + self.subject = doi_record + + def test_doi(self, doi_factory): + assert self.subject.doi == doi_factory.doi + + def test_date_accessed(self): + assert type(self.subject.date_accessed) is date diff --git a/tests/tables/test_instrument.py b/tests/tables/test_instrument.py new file mode 100644 index 0000000..3f62846 --- /dev/null +++ b/tests/tables/test_instrument.py @@ -0,0 +1,26 @@ +import pytest + +from snowexsql.tables import Instrument + + +@pytest.fixture +def instrument_record(instrument_factory, db_session): + instrument_factory.create() + return db_session.query(Instrument).first() + + +class TestInstrument: + @pytest.fixture(autouse=True) + def setup_method(self, instrument_record): + self.subject = instrument_record + + def test_instrument_name(self, instrument_factory): + assert self.subject.name == instrument_factory.name + + def test_instrument_model(self, instrument_factory): + assert self.subject.model == instrument_factory.model + + def test_instrument_specification(self, instrument_factory): + assert ( + self.subject.specifications == instrument_factory.specifications + ) diff --git a/tests/tables/test_layer_data.py b/tests/tables/test_layer_data.py new file mode 100644 index 0000000..920f357 --- /dev/null +++ b/tests/tables/test_layer_data.py @@ -0,0 +1,63 @@ +import pytest + +from snowexsql.tables import DOI, Instrument, LayerData, MeasurementType, Site + + +@pytest.fixture +def layer_data_attributes(layer_data_factory): + return layer_data_factory.build() + + +@pytest.fixture +def layer_data_record(layer_data_factory, db_session): + layer_data_factory.create() + return db_session.query(LayerData).first() + + +class TestLayerData: + @pytest.fixture(autouse=True) + def setup_method(self, layer_data_record, layer_data_attributes): + self.subject = layer_data_record + self.attributes = layer_data_attributes + + def test_pit_id(self): + assert self.subject.pit_id == self.attributes.pit_id + + def test_depth_attribute(self): + assert type(self.subject.depth) is float + assert self.subject.depth == self.attributes.depth + + def test_bottom_depth_attribute(self): + assert type(self.subject.bottom_depth) is float + assert self.subject.bottom_depth == self.attributes.bottom_depth + + def test_comments_attribute(self): + assert self.subject.comments == self.attributes.comments + + def test_value_attribute(self): + assert self.subject.value == self.attributes.value + + def test_sample_a_attribute(self): + assert self.subject.sample_a == self.attributes.sample_a + + def test_flags_attribute(self): + assert self.subject.flags == self.attributes.flags + + def test_has_site(self): + assert isinstance(self.subject.site, Site) + assert self.subject.site.name == self.attributes.site.name + + def test_has_measurement_type(self): + assert isinstance(self.subject.measurement_type, MeasurementType) + assert ( + self.subject.measurement_type.name == + self.attributes.measurement_type.name + ) + + def test_has_instrument(self): + assert isinstance(self.subject.instrument, Instrument) + assert self.subject.instrument.name == self.attributes.instrument.name + + def test_has_doi(self): + assert isinstance(self.subject.doi, DOI) + assert self.subject.doi.doi == self.attributes.doi.doi diff --git a/tests/tables/test_measurment_type.py b/tests/tables/test_measurment_type.py new file mode 100644 index 0000000..439eae4 --- /dev/null +++ b/tests/tables/test_measurment_type.py @@ -0,0 +1,21 @@ +import pytest + +from snowexsql.tables import MeasurementType + + +@pytest.fixture +def measurement_type_record(measurement_type_factory, db_session): + measurement_type_factory.create() + return db_session.query(MeasurementType).first() + + +class TestMeasurementType: + @pytest.fixture(autouse=True) + def setup_method(self, measurement_type_record): + self.subject = measurement_type_record + + def test_measurement_type_name(self, measurement_type_factory): + assert self.subject.name == measurement_type_factory.name + + def test_measurement_type_unit(self, measurement_type_factory): + assert self.subject.units == measurement_type_factory.units diff --git a/tests/tables/test_observer.py b/tests/tables/test_observer.py new file mode 100644 index 0000000..6cf3c6b --- /dev/null +++ b/tests/tables/test_observer.py @@ -0,0 +1,18 @@ +import pytest + +from snowexsql.tables import Observer + + +@pytest.fixture +def observer_record(observer_factory, db_session): + observer_factory.create() + return db_session.query(Observer).first() + + +class TestObserver: + @pytest.fixture(autouse=True) + def setup_method(self, observer_record): + self.subject = observer_record + + def test_observer_name(self): + assert type(self.subject.name) is str diff --git a/tests/tables/test_point_data.py b/tests/tables/test_point_data.py new file mode 100644 index 0000000..407ab31 --- /dev/null +++ b/tests/tables/test_point_data.py @@ -0,0 +1,50 @@ +from datetime import date, datetime + +import pytest +from geoalchemy2 import WKBElement + +from snowexsql.tables import PointData + + +@pytest.fixture +def point_data_attributes(site_factory): + return site_factory.build() + + +@pytest.fixture +def point_entry_record(point_data_factory, db_session): + point_data_factory.create() + return db_session.query(PointData).first() + + +class TestPointData: + @pytest.fixture(autouse=True) + def setup_method(self, point_entry_record, point_data_attributes): + self.subject = point_entry_record + self.attributes = point_data_attributes + + def test_record_id(self): + assert self.subject.id is not None + + def test_value_attribute(self): + assert type(self.subject.value) is float + + def test_datetime_attribute(self): + assert type(self.subject.datetime) is datetime + # The microseconds won't be the same between the site_attribute + # and site_record fixture. Hence only testing the difference being + # small. Important to subtract the later from the earlier time as + # the timedelta object is incorrect otherwise + assert ( + self.attributes.datetime - self.subject.datetime + ).seconds == pytest.approx(0, rel=0.1) + + def test_date_attribute(self): + assert type(self.subject.date) is date + assert self.subject.date == self.attributes.date + + def test_elevation_attribute(self): + assert self.subject.elevation == self.attributes.elevation + + def test_geom_attribute(self): + assert isinstance(self.subject.geom, WKBElement) diff --git a/tests/tables/test_point_observation.py b/tests/tables/test_point_observation.py new file mode 100644 index 0000000..d5ef760 --- /dev/null +++ b/tests/tables/test_point_observation.py @@ -0,0 +1,50 @@ +import datetime + +import pytest + +from snowexsql.tables import ( + Campaign, DOI, Instrument, MeasurementType, Observer, PointObservation +) + + +@pytest.fixture +def point_observation_record(point_observation_factory, db_session): + point_observation_factory.create() + return db_session.query(PointObservation).first() + + +class TestPointObservation: + @pytest.fixture(autouse=True) + def setup_method(self, point_observation_record): + self.subject = point_observation_record + + def test_name_attribute(self, point_observation_factory): + assert self.subject.name == point_observation_factory.name + + def test_description_attribute(self, point_observation_factory): + assert ( + self.subject.description == point_observation_factory.description + ) + + def test_date_attribute(self): + assert type(self.subject.date) is datetime.date + + def test_in_campaign(self): + assert self.subject.campaign is not None + assert isinstance(self.subject.campaign, Campaign) + + def test_has_doi(self): + assert self.subject.doi is not None + assert isinstance(self.subject.doi, DOI) + + def test_has_measurement_type(self): + assert self.subject.measurement_type is not None + assert isinstance(self.subject.measurement_type, MeasurementType) + + def test_has_instrument(self): + assert self.subject.instrument is not None + assert isinstance(self.subject.instrument, Instrument) + + def test_has_observer(self): + assert self.subject.observer is not None + assert isinstance(self.subject.observer, Observer) diff --git a/tests/tables/test_site.py b/tests/tables/test_site.py new file mode 100644 index 0000000..8ce24fe --- /dev/null +++ b/tests/tables/test_site.py @@ -0,0 +1,123 @@ +import datetime + +import pytest +from geoalchemy2 import WKBElement + +from snowexsql.tables import Campaign, DOI, Observer, Site + + +@pytest.fixture +def site_attributes(site_factory): + return site_factory.build() + + +@pytest.fixture +def site_record(site_factory, observer_factory, db_session): + site_factory.create(observers=(observer_factory.create())) + return db_session.query(Site).first() + + +class TestSite: + @pytest.fixture(autouse=True) + def setup_method(self, site_record, site_attributes): + self.subject = site_record + self.attributes = site_attributes + + def test_site_name(self, site_factory): + assert self.subject.name == site_factory.name + + def test_datetime_attribute(self): + assert type(self.subject.datetime) is datetime.datetime + # The microseconds won't be the same between the site_attribute + # and site_record fixture. Hence only testing the difference being + # small. Important to subtract the later from the earlier time as + # the timedelta object is incorrect otherwise + assert ( + self.attributes.datetime - self.subject.datetime + ).seconds == pytest.approx(0, rel=0.1) + + def test_date_attribute(self): + assert type(self.subject.date) is datetime.date + assert self.subject.date == self.attributes.date + + def test_description_attribute(self): + assert self.subject.description == self.attributes.description + + def test_slope_angle_attribute(self): + assert type(self.subject.slope_angle) is float + assert self.subject.slope_angle == self.attributes.slope_angle + + def test_aspect_attribute(self): + assert type(self.subject.aspect) is float + assert self.subject.aspect == self.attributes.aspect + + def test_air_temp_attribute(self): + assert type(self.subject.air_temp) is float + assert self.subject.air_temp == self.attributes.air_temp + + def test_total_depth_attribute(self): + assert type(self.subject.total_depth) is float + assert self.subject.total_depth == self.attributes.total_depth + + def test_weather_description_attribute(self): + assert ( + self.subject.weather_description == + self.attributes.weather_description + ) + + def test_precip_attribute(self): + assert self.subject.precip == self.attributes.precip + + def test_sky_cover_attribute(self): + assert self.subject.sky_cover == self.attributes.sky_cover + + def test_wind_attribute(self): + assert self.subject.wind == self.attributes.wind + + def test_ground_condition_attribute(self): + assert ( + self.subject.ground_condition == self.attributes.ground_condition + ) + + def test_ground_roughness_attribute(self): + assert ( + self.subject.ground_roughness == self.attributes.ground_roughness + ) + + def test_ground_vegetation_attribute(self): + assert ( + self.subject.ground_vegetation == self.attributes.ground_vegetation + ) + + def test_vegetation_height_attribute(self): + assert ( + self.subject.vegetation_height == + self.attributes.vegetation_height + ) + + def test_tree_canopy_attribute(self): + assert self.subject.tree_canopy == self.attributes.tree_canopy + + def test_site_notes_attribute(self): + assert self.subject.site_notes == self.attributes.site_notes + + def test_elevation_attribute(self, point_data_factory): + assert self.subject.elevation == point_data_factory.elevation + + def test_geom_attribute(self): + assert isinstance(self.subject.geom, WKBElement) + + def test_in_campaign(self): + assert self.subject.campaign is not None + assert isinstance(self.subject.campaign, Campaign) + assert self.subject.campaign.name == self.attributes.campaign.name + + def test_has_doi(self): + assert self.subject.doi is not None + assert isinstance(self.subject.doi, DOI) + assert self.subject.doi.doi == self.attributes.doi.doi + + def test_has_observers(self): + assert self.subject.observers is not None + assert isinstance(self.subject.observers, list) + assert type(self.subject.observers[0]) == Observer diff --git a/tests/test_db.py b/tests/test_db.py index f8b9bd8..29cc333 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,98 +1,86 @@ import pytest -from sqlalchemy import Table +from sqlalchemy import Engine, MetaData +from sqlalchemy.orm import Session -from snowexsql.db import get_db, get_table_attributes -from snowexsql.tables import ImageData, LayerData, PointData, Site, \ - MeasurementType, DOI -from .db_setup import DBSetup +import snowexsql +from snowexsql.db import ( + DB_CONNECTION_PROTOCOL, db_connection_string, get_db, load_credentials +) +from .db_setup import DBSetup, DB_INFO -class TestDB(DBSetup): - base_atts = ['date', 'site_id'] - single_loc_atts = ['elevation', 'geom', 'time'] +@pytest.fixture(scope='function') +def db_connection_string_patch(monkeypatch, test_db_info): + def db_string(_name, _credentials): + return test_db_info - meas_atts = ['measurement_type_id'] - - site_atts = single_loc_atts + \ - ['slope_angle', 'aspect', 'air_temp', 'total_depth', - 'weather_description', 'precip', 'sky_cover', 'wind', - 'ground_condition', 'ground_roughness', - 'ground_vegetation', 'vegetation_height', - 'tree_canopy', 'site_notes'] - - point_atts = single_loc_atts + meas_atts + \ - ['version_number', 'equipment', 'value', 'instrument_id'] - - layer_atts = meas_atts + \ - ['depth', 'value', 'bottom_depth', 'comments', 'sample_a', - 'sample_b', 'sample_c'] - raster_atts = meas_atts + ['raster', 'description'] - measurement_types_attributes = ['name', 'units','derived'] - DOI_attributes = ['doi', 'date_accessed'] - - def setup_class(self): - """ - Setup the database one time for testing - """ - super().setup_class() - # only reflect the tables we will use - self.metadata.reflect(self.engine, only=['points', 'layers']) - - def test_point_structure(self): - """ - Tests our tables are in the database - """ - t = Table("points", self.metadata, autoload=True) - columns = [m.key for m in t.columns] + monkeypatch.setattr( + snowexsql.db, + 'db_connection_string', + db_string + ) - for c in self.point_atts: - assert c in columns - def test_layer_structure(self): +class TestDBConnectionInfo: + def test_load_credentials(self): + user, password = load_credentials(DBSetup.CREDENTIAL_FILE) + assert user == DB_INFO['username'] + assert password == DB_INFO['password'] + + def test_db_connection_string(self): + db_string = db_connection_string( + DBSetup.database_name(), DBSetup.CREDENTIAL_FILE + ) + assert db_string.startswith(DB_CONNECTION_PROTOCOL) + + def test_db_connection_string_credentials(self): + db_string = db_connection_string( + DBSetup.database_name(), DBSetup.CREDENTIAL_FILE + ) + user, password = load_credentials(DBSetup.CREDENTIAL_FILE) + + assert user in db_string + assert password in db_string + + def test_db_connection_string_has_db_and_host(self): + db_string = db_connection_string( + DBSetup.database_name(), DBSetup.CREDENTIAL_FILE + ) + + assert DB_INFO['address'] in db_string + assert DB_INFO['db_name'] in db_string + + def test_db_connection_string_no_credentials(self): + db_string = db_connection_string(DBSetup.database_name()) + user, password = load_credentials(DBSetup.CREDENTIAL_FILE) + + assert user not in db_string + assert password not in db_string + + @pytest.mark.usefixtures('db_connection_string_patch') + def test_returns_engine(self, monkeypatch, test_db_info): + assert isinstance(get_db(DBSetup.database_name())[0], Engine) + + @pytest.mark.usefixtures('db_connection_string_patch') + def test_returns_session(self): + assert isinstance(get_db(DBSetup.database_name())[1], Session) + + @pytest.mark.usefixtures('db_connection_string_patch') + def test_returns_metadata(self): + assert isinstance( + get_db(DBSetup.database_name(), return_metadata=True)[2], + MetaData + ) + + @pytest.mark.usefixtures('db_connection_string_patch') + @pytest.mark.parametrize("return_metadata, expected_objs", [ + (False, 2), + (True, 3)]) + def test_get_metadata(self, return_metadata, expected_objs): """ - Tests our tables are in the database + Test we can receive a connection and opt out of getting the metadata """ - t = Table("layers", self.metadata, autoload=True) - columns = [m.key for m in t.columns] - - for c in self.layer_atts: - assert c in columns - - @pytest.mark.parametrize( - "DataCls,attributes", - [ - (Site, site_atts), - (PointData, point_atts), - (LayerData, layer_atts), - (ImageData, raster_atts), - (MeasurementType, measurement_types_attributes), - (DOI, DOI_attributes) - ] - ) - def test_get_table_attributes(self, DataCls, attributes): - """ - Test we return a correct list of table columns from db.py - """ - atts = get_table_attributes(DataCls) - - for c in attributes: - assert c in atts - - -# Independent Tests -@pytest.mark.parametrize("return_metadata, expected_objs", [ - (False, 2), - (True, 3)]) -def test_getting_db(return_metadata, expected_objs): - """ - Test we can receive a connection and opt out of getting the metadata - """ - - db_name = ( - DBSetup.DB_INFO["username"] + ":" + - DBSetup.DB_INFO["password"] + "@" + - DBSetup.database_name() - ) - - result = get_db(db_name, return_metadata=return_metadata) - assert len(result) == expected_objs + result = get_db( + DBSetup.database_name(), return_metadata=return_metadata + ) + assert len(result) == expected_objs