diff --git a/datachecks/core/metric/base.py b/datachecks/core/metric/base.py
index 50a68bd2..a120613d 100644
--- a/datachecks/core/metric/base.py
+++ b/datachecks/core/metric/base.py
@@ -66,29 +66,29 @@ def __init__(
         name: str,
         data_source: DataSource,
         metric_type: MetricsType,
-        table_name: Optional[str] = None,
-        index_name: Optional[str] = None,
-        filters: Dict = None,
         metric_logger: MetricLogger = None,
+        **kwargs,
     ):
-        if index_name is not None and table_name is not None:
-            raise ValueError(
-                "Please give a value for table_name or index_name (but not both)"
-            )
-        if index_name is None and table_name is None:
+        if "index_name" in kwargs and "table_name" in kwargs:
+            if kwargs["index_name"] is not None and kwargs["table_name"] is not None:
+                raise ValueError(
+                    "Please give a value for table_name or index_name (but not both)"
+                )
+        if "index_name" not in kwargs and "table_name" not in kwargs:
             raise ValueError("Please give a value for table_name or index_name")
         self.index_name, self.table_name = None, None
-        if index_name:
-            self.index_name = index_name
-        if table_name:
-            self.table_name = table_name
+        if "index_name" in kwargs:
+            self.index_name = kwargs["index_name"]
+        if "table_name" in kwargs:
+            self.table_name = kwargs["table_name"]
         self.name: str = name
         self.data_source = data_source
         self.metric_type = metric_type
         self.filter_query = None
-        if filters is not None:
+        if "filters" in kwargs and kwargs["filters"] is not None:
+            filters = kwargs["filters"]
             if ("search_query" in filters and filters["search_query"] is not None) and (
                 "where_clause" in filters and filters["where_clause"] is not None
             ):
@@ -109,7 +109,6 @@ def get_metric_identity(self):
             data_source=self.data_source,
         )
 
-    @abstractmethod
     def _generate_metric_value(self) -> float:
         pass
 
@@ -143,25 +142,20 @@ class FieldMetrics(Metric, ABC):
     def __init__(
         self,
         name: str,
-        metric_type: MetricsType,
         data_source: DataSource,
-        field_name: str,
-        table_name: Optional[str] = None,
-        index_name: Optional[str] = None,
-        filters: Dict = None,
+        metric_type: MetricsType,
         metric_logger: MetricLogger = None,
+        **kwargs,
     ):
         super().__init__(
             name=name,
             data_source=data_source,
-            table_name=table_name,
-            index_name=index_name,
             metric_type=metric_type,
-            filters=filters,
             metric_logger=metric_logger,
+            **kwargs,
         )
-
-        self.field_name = field_name
+        if "field_name" in kwargs:
+            self.field_name = kwargs["field_name"]
 
     @property
     def get_field_name(self):
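For reference, here is a minimal standalone sketch of the argument checks that `Metric.__init__` now applies to `**kwargs`. It is an illustration only, not the library code; the helper name `check_target` is made up for this sketch.

```python
# Standalone sketch of the validation rules base.Metric.__init__ now applies to **kwargs.
# check_target is a hypothetical helper used only to illustrate the contract.
def check_target(**kwargs):
    if "index_name" in kwargs and "table_name" in kwargs:
        if kwargs["index_name"] is not None and kwargs["table_name"] is not None:
            raise ValueError(
                "Please give a value for table_name or index_name (but not both)"
            )
    if "index_name" not in kwargs and "table_name" not in kwargs:
        raise ValueError("Please give a value for table_name or index_name")


check_target(table_name="orders")        # ok: SQL-style target
check_target(index_name="orders_index")  # ok: search-index target
# check_target()                               -> ValueError (neither given)
# check_target(table_name="t", index_name="i") -> ValueError (both given)
```

One observable difference from the old positional checks: explicitly passing `table_name=None` with no `index_name` no longer trips the "neither given" error, because the rewritten check tests key presence rather than value.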
diff --git a/datachecks/core/metric/manager.py b/datachecks/core/metric/manager.py
index 5c40463f..9af9412e 100644
--- a/datachecks/core/metric/manager.py
+++ b/datachecks/core/metric/manager.py
@@ -13,18 +13,25 @@
 # limitations under the License.
 
 from dataclasses import asdict
-from typing import Dict, List
+from typing import List
+
+from loguru import logger
 
 from datachecks.core.configuration.configuration import MetricConfiguration
 from datachecks.core.datasource.manager import DataSourceManager
 from datachecks.core.logger.base import MetricLogger
-from datachecks.core.metric.base import Metric, MetricsType
-from datachecks.core.metric.freshness_metric import FreshnessValueMetric
-from datachecks.core.metric.numeric_metric import (DocumentCountMetric,
-                                                   MaxMetric, RowCountMetric)
+from datachecks.core.metric.numeric_metric import *
+from datachecks.core.metric.reliability_metric import *
 
 
 class MetricManager:
+    METRIC_CLASS_MAPPING = {
+        MetricsType.DOCUMENT_COUNT.value: "DocumentCountMetric",
+        MetricsType.ROW_COUNT.value: "RowCountMetric",
+        MetricsType.MAX.value: "MaxMetric",
+        MetricsType.FRESHNESS.value: "FreshnessValueMetric",
+    }
+
     def __init__(
         self,
         metric_config: Dict[str, List[MetricConfiguration]],
@@ -39,68 +46,35 @@ def __init__(
     def _build_metrics(self, config: Dict[str, List[MetricConfiguration]]):
         for data_source, metric_list in config.items():
             for metric_config in metric_list:
-                if metric_config.metric_type == MetricsType.DOCUMENT_COUNT:
-                    metric = DocumentCountMetric(
-                        name=metric_config.name,
-                        data_source=self.data_source_manager.get_data_source(
-                            data_source
-                        ),
-                        filters=asdict(metric_config.filters)
-                        if metric_config.filters
-                        else None,
-                        index_name=metric_config.index,
-                        metric_type=MetricsType.DOCUMENT_COUNT,
-                        metric_logger=self.metric_logger,
-                    )
-                    self.metrics[metric.get_metric_identity()] = metric
-                elif metric_config.metric_type == MetricsType.ROW_COUNT:
-                    metric = RowCountMetric(
-                        name=metric_config.name,
-                        data_source=self.data_source_manager.get_data_source(
-                            data_source
-                        ),
-                        filters=asdict(metric_config.filters)
-                        if metric_config.filters
-                        else None,
-                        table_name=metric_config.table,
-                        metric_type=MetricsType.ROW_COUNT,
-                        metric_logger=self.metric_logger,
-                    )
-                    self.metrics[metric.get_metric_identity()] = metric
-                elif metric_config.metric_type == MetricsType.MAX:
-                    metric = MaxMetric(
-                        name=metric_config.name,
-                        data_source=self.data_source_manager.get_data_source(
-                            data_source
-                        ),
-                        filters=asdict(metric_config.filters)
-                        if metric_config.filters
-                        else None,
-                        table_name=metric_config.table,
-                        index_name=metric_config.index,
-                        metric_type=MetricsType.MAX,
-                        field_name=metric_config.field,
-                        metric_logger=self.metric_logger,
-                    )
-                    self.metrics[metric.get_metric_identity()] = metric
-                elif metric_config.metric_type == MetricsType.FRESHNESS:
-                    metric = FreshnessValueMetric(
-                        name=metric_config.name,
-                        data_source=self.data_source_manager.get_data_source(
-                            data_source
-                        ),
-                        filters=asdict(metric_config.filters)
-                        if metric_config.filters
-                        else None,
-                        table_name=metric_config.table,
-                        index_name=metric_config.index,
-                        metric_type=MetricsType.FRESHNESS,
-                        field_name=metric_config.field,
-                        metric_logger=self.metric_logger,
-                    )
-                    self.metrics[metric.get_metric_identity()] = metric
-                else:
-                    raise ValueError("Invalid metric type")
+                params = {
+                    "filters": asdict(metric_config.filters)
+                    if metric_config.filters
+                    else None,
+                }
+                if metric_config.index:
+                    params["index_name"] = metric_config.index
+                if metric_config.table:
+                    params["table_name"] = metric_config.table
+                if metric_config.field:
+                    params["field_name"] = metric_config.field
+
+                logger.info(f"==============metric_config: {self.METRIC_CLASS_MAPPING}")
+                logger.info(
+                    f"==============metric_config.metric_type: {metric_config.metric_type}"
+                )
+                # logger.info(globals())
+                metric: Metric = globals()[
+                    self.METRIC_CLASS_MAPPING[metric_config.metric_type]
+                ](
+                    metric_config.name,
+                    self.data_source_manager.get_data_source(data_source),
+                    MetricsType(metric_config.metric_type.lower()),
+                    self.metric_logger,
+                    **params,
+                )
+
+                logger.info(metric.__dict__)
+                self.metrics[metric.get_metric_identity()] = metric
 
     @property
     def get_metrics(self):
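The `_build_metrics` rewrite replaces the if/elif chain with a name-based lookup: the metric type resolves to a class name via `METRIC_CLASS_MAPPING`, and `globals()` resolves that name to a class, which is why the wildcard imports are needed to pull every metric class into this module's namespace. Below is a self-contained sketch of the same pattern; the class and the type string are stand-ins for illustration, not the real datachecks objects.

```python
# Self-contained sketch of lookup-based dispatch: type string -> class name -> class object.
# RowCountMetric here is a stand-in; in manager.py the real classes arrive via the
# wildcard imports, so globals() can resolve the mapped names at build time.
class RowCountMetric:
    def __init__(self, name, data_source, metric_type, metric_logger=None, **kwargs):
        self.name = name
        self.data_source = data_source
        self.metric_type = metric_type
        self.kwargs = kwargs


METRIC_CLASS_MAPPING = {"row_count": "RowCountMetric"}


def build_metric(metric_type: str, name: str, data_source, **params):
    # Look up the class by name in this module's namespace, then construct it
    # positionally, forwarding the remaining parameters as **kwargs.
    metric_cls = globals()[METRIC_CLASS_MAPPING[metric_type]]
    return metric_cls(name, data_source, metric_type, None, **params)


metric = build_metric("row_count", "orders_row_count", "postgres_ds", table_name="orders")
print(type(metric).__name__, metric.kwargs)  # RowCountMetric {'table_name': 'orders'}
```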
{metric_config.metric_type}" + ) + # logger.info(globals()) + metric: Metric = globals()[ + self.METRIC_CLASS_MAPPING[metric_config.metric_type] + ]( + metric_config.name, + self.data_source_manager.get_data_source(data_source), + MetricsType(metric_config.metric_type.lower()), + self.metric_logger, + **params, + ) + + logger.info(metric.__dict__) + self.metrics[metric.get_metric_identity()] = metric @property def get_metrics(self): diff --git a/datachecks/core/metric/numeric_metric.py b/datachecks/core/metric/numeric_metric.py index fba39cf5..e1f66ab4 100644 --- a/datachecks/core/metric/numeric_metric.py +++ b/datachecks/core/metric/numeric_metric.py @@ -16,62 +16,10 @@ from datachecks.core.datasource.search_datasource import SearchIndexDataSource from datachecks.core.datasource.sql_datasource import SQLDatasource -from datachecks.core.metric.base import (FieldMetrics, Metric, MetricIdentity, +from datachecks.core.metric.base import (FieldMetrics, MetricIdentity, MetricsType) -class DocumentCountMetric(Metric): - """ - DocumentCountMetrics is a class that represents a metric that is generated by a data source. - """ - - def validate_data_source(self): - return isinstance(self.data_source, SearchIndexDataSource) - - def get_metric_identity(self): - return MetricIdentity.generate_identity( - metric_type=MetricsType.DOCUMENT_COUNT, - metric_name=self.name, - data_source=self.data_source, - index_name=self.index_name, - ) - - def _generate_metric_value(self): - if isinstance(self.data_source, SearchIndexDataSource): - return self.data_source.query_get_document_count( - index_name=self.index_name, - filters=self.filter_query if self.filter_query else None, - ) - else: - raise ValueError("Invalid data source type") - - -class RowCountMetric(Metric): - - """ - RowCountMetrics is a class that represents a metric that is generated by a data source. 
- """ - - def get_metric_identity(self): - return MetricIdentity.generate_identity( - metric_type=MetricsType.ROW_COUNT, - metric_name=self.name, - data_source=self.data_source, - table_name=self.table_name, - ) - - def validate_data_source(self): - return isinstance(self.data_source, SQLDatasource) - - def _generate_metric_value(self): - if isinstance(self.data_source, SQLDatasource): - return self.data_source.query_get_row_count( - table=self.table_name, - filters=self.filter_query if self.filter_query else None, - ) - else: - raise ValueError("Invalid data source type") - class MinMetric(FieldMetrics): """ @@ -87,7 +35,7 @@ def get_metric_identity(self): table_name=self.table_name if self.table_name else None, index_name=self.index_name if self.index_name else None, ) - + def _generate_metric_value(self): if isinstance(self.data_source, SQLDatasource): return self.data_source.query_get_min( @@ -105,7 +53,6 @@ def _generate_metric_value(self): raise ValueError("Invalid data source type") - class MaxMetric(FieldMetrics): """ @@ -138,6 +85,7 @@ def _generate_metric_value(self): else: raise ValueError("Invalid data source type") + class AvgMetric(FieldMetrics): """ @@ -153,7 +101,7 @@ def get_metric_identity(self): table_name=self.table_name if self.table_name else None, index_name=self.index_name if self.index_name else None, ) - + def _generate_metric_value(self): if isinstance(self.data_source, SQLDatasource): return self.data_source.query_get_avg( @@ -168,4 +116,4 @@ def _generate_metric_value(self): filters=self.filter_query if self.filter_query else None, ) else: - raise ValueError("Invalid data source type") \ No newline at end of file + raise ValueError("Invalid data source type") diff --git a/datachecks/core/metric/freshness_metric.py b/datachecks/core/metric/reliability_metric.py similarity index 50% rename from datachecks/core/metric/freshness_metric.py rename to datachecks/core/metric/reliability_metric.py index 62d2cb11..6877ebfc 100644 --- a/datachecks/core/metric/freshness_metric.py +++ b/datachecks/core/metric/reliability_metric.py @@ -14,7 +14,60 @@ from datachecks.core.datasource.search_datasource import SearchIndexDataSource from datachecks.core.datasource.sql_datasource import SQLDatasource from datachecks.core.metric.base import (FieldMetrics, MetricIdentity, - MetricsType) + MetricsType, Metric) + + +class DocumentCountMetric(Metric): + """ + DocumentCountMetrics is a class that represents a metric that is generated by a data source. + """ + + def validate_data_source(self): + return isinstance(self.data_source, SearchIndexDataSource) + + def get_metric_identity(self): + return MetricIdentity.generate_identity( + metric_type=MetricsType.DOCUMENT_COUNT, + metric_name=self.name, + data_source=self.data_source, + index_name=self.index_name, + ) + + def _generate_metric_value(self): + if isinstance(self.data_source, SearchIndexDataSource): + return self.data_source.query_get_document_count( + index_name=self.index_name, + filters=self.filter_query if self.filter_query else None, + ) + else: + raise ValueError("Invalid data source type") + + +class RowCountMetric(Metric): + + """ + RowCountMetrics is a class that represents a metric that is generated by a data source. 
+ """ + + def get_metric_identity(self): + return MetricIdentity.generate_identity( + metric_type=MetricsType.ROW_COUNT, + metric_name=self.name, + data_source=self.data_source, + table_name=self.table_name, + ) + + def validate_data_source(self): + return isinstance(self.data_source, SQLDatasource) + + def _generate_metric_value(self): + if isinstance(self.data_source, SQLDatasource): + return self.data_source.query_get_row_count( + table=self.table_name, + filters=self.filter_query if self.filter_query else None, + ) + else: + raise ValueError("Invalid data source type") class FreshnessValueMetric(FieldMetrics): diff --git a/tests/core/metric/test_metric_manager.py b/tests/core/metric/test_metric_manager.py index 07a2726d..21687c87 100644 --- a/tests/core/metric/test_metric_manager.py +++ b/tests/core/metric/test_metric_manager.py @@ -20,7 +20,7 @@ from datachecks.core.datasource.manager import DataSourceManager from datachecks.core.metric.base import MetricsType from datachecks.core.metric.manager import MetricManager -from datachecks.core.metric.numeric_metric import DocumentCountMetric +from datachecks.core.metric.reliability_metric import DocumentCountMetric OPEN_SEARCH_DATA_SOURCE_NAME = "test_open_search_data_source" POSTGRES_DATA_SOURCE_NAME = "test_postgres_data_source" diff --git a/tests/core/metric/test_numeric_metric.py b/tests/core/metric/test_numeric_metric.py index 27783744..e3833593 100644 --- a/tests/core/metric/test_numeric_metric.py +++ b/tests/core/metric/test_numeric_metric.py @@ -22,8 +22,8 @@ OpenSearchSearchIndexDataSource from datachecks.core.datasource.postgres import PostgresSQLDatasource from datachecks.core.metric.base import MetricsType -from datachecks.core.metric.numeric_metric import (DocumentCountMetric, MinMetric, - MaxMetric, AvgMetric, RowCountMetric) +from datachecks.core.metric.numeric_metric import (MinMetric, + MaxMetric, AvgMetric) from tests.utils import create_opensearch_client, create_postgres_connection @@ -98,61 +98,6 @@ def setup_data( postgresql_connection.close() -@pytest.mark.usefixtures("setup_data", "opensearch_datasource") -class TestDocumentCountMetric: - def test_should_return_document_count_metric_without_filter( - self, opensearch_datasource: OpenSearchSearchIndexDataSource - ): - doc = DocumentCountMetric( - name="document_count_metric_test", - data_source=opensearch_datasource, - index_name="numeric_metric_test", - metric_type=MetricsType.DOCUMENT_COUNT, - ) - doc_value = doc.get_value() - assert doc_value["value"] == 5 - - def test_should_return_document_count_metric_with_filter( - self, opensearch_datasource: OpenSearchSearchIndexDataSource - ): - doc = DocumentCountMetric( - name="document_count_metric_test_1", - data_source=opensearch_datasource, - index_name="numeric_metric_test", - metric_type=MetricsType.DOCUMENT_COUNT, - filters={"search_query": '{"range": {"age": {"gte": 30, "lte": 40}}}'}, - ) - doc_value = doc.get_value() - assert doc_value["value"] == 3 - - -@pytest.mark.usefixtures("setup_data", "postgres_datasource") -class TestRowCountMetric: - def test_should_return_row_count_metric_without_filter( - self, postgres_datasource: PostgresSQLDatasource - ): - row = RowCountMetric( - name="row_count_metric_test", - data_source=postgres_datasource, - table_name="numeric_metric_test", - metric_type=MetricsType.ROW_COUNT, - ) - row_value = row.get_value() - assert row_value["value"] == 5 - - def test_should_return_row_count_metric_with_filter( - self, postgres_datasource: PostgresSQLDatasource - ): - row = RowCountMetric( 
- name="row_count_metric_test_1", - data_source=postgres_datasource, - table_name="numeric_metric_test", - metric_type=MetricsType.ROW_COUNT, - filters={"where_clause": "age >= 30 AND age <= 40"}, - ) - row_value = row.get_value() - assert row_value["value"] == 3 - @pytest.mark.usefixtures("setup_data", "postgres_datasource", "opensearch_datasource") class TestMinColumnValueMetric: def test_should_return_min_column_value_postgres_without_filter( diff --git a/tests/core/metric/test_freshness_metric.py b/tests/core/metric/test_reliability_metric.py similarity index 65% rename from tests/core/metric/test_freshness_metric.py rename to tests/core/metric/test_reliability_metric.py index 79bf92f5..8aa0667d 100644 --- a/tests/core/metric/test_freshness_metric.py +++ b/tests/core/metric/test_reliability_metric.py @@ -20,12 +20,14 @@ from datachecks.core.configuration.configuration import \ DataSourceConnectionConfiguration +from datachecks.core.datasource.opensearch import OpenSearchSearchIndexDataSource +from datachecks.core.datasource.postgres import PostgresSQLDatasource from datachecks.core.metric.base import MetricsType -from datachecks.core.metric.freshness_metric import FreshnessValueMetric +from datachecks.core.metric.reliability_metric import FreshnessValueMetric, RowCountMetric, DocumentCountMetric from tests.utils import create_opensearch_client, create_postgres_connection -INDEX_NAME = "freshness_metric_test" -TABLE_NAME = "freshness_metric_test_table" +INDEX_NAME = "reliability_metric_test" +TABLE_NAME = "reliability_metric_test_table" @pytest.fixture(scope="class") @@ -65,6 +67,7 @@ def populate_opensearch_datasource(opensearch_client: OpenSearch): index=INDEX_NAME, body={ "name": "thor", + "age": 1500, "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=10), }, ) @@ -72,6 +75,7 @@ def populate_opensearch_datasource(opensearch_client: OpenSearch): index=INDEX_NAME, body={ "name": "captain america", + "age": 100, "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=3), }, ) @@ -79,6 +83,7 @@ def populate_opensearch_datasource(opensearch_client: OpenSearch): index=INDEX_NAME, body={ "name": "iron man", + "age": 50, "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=4), }, ) @@ -86,6 +91,7 @@ def populate_opensearch_datasource(opensearch_client: OpenSearch): index=INDEX_NAME, body={ "name": "hawk eye", + "age": 40, "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=5), }, ) @@ -93,6 +99,7 @@ def populate_opensearch_datasource(opensearch_client: OpenSearch): index=INDEX_NAME, body={ "name": "black widow", + "age": 35, "last_fight": datetime.datetime.utcnow() - datetime.timedelta(days=6), }, ) @@ -106,7 +113,7 @@ def populate_postgres_datasource(postgresql_connection: Connection): postgresql_connection.execute( text( f""" - CREATE TABLE IF NOT EXISTS {TABLE_NAME} (name VARCHAR(50), last_fight timestamp) + CREATE TABLE IF NOT EXISTS {TABLE_NAME} (name VARCHAR(50), last_fight timestamp, age INTEGER) """ ) ) @@ -114,11 +121,11 @@ def populate_postgres_datasource(postgresql_connection: Connection): insert_query = f""" INSERT INTO {TABLE_NAME} VALUES - ('thor', '{(datetime.datetime.utcnow() - datetime.timedelta(days=10)).strftime("%Y-%m-%d")}'), - ('captain america', '{(datetime.datetime.utcnow() - datetime.timedelta(days=3)).strftime("%Y-%m-%d")}'), - ('iron man', '{(datetime.datetime.utcnow() - datetime.timedelta(days=4)).strftime("%Y-%m-%d")}'), - ('hawk eye', '{(datetime.datetime.utcnow() - 
datetime.timedelta(days=5)).strftime("%Y-%m-%d")}'), - ('black widow', '{(datetime.datetime.utcnow() - datetime.timedelta(days=6)).strftime("%Y-%m-%d")}') + ('thor', '{(datetime.datetime.utcnow() - datetime.timedelta(days=10)).strftime("%Y-%m-%d")}', 1500), + ('captain america', '{(datetime.datetime.utcnow() - datetime.timedelta(days=3)).strftime("%Y-%m-%d")}', 90), + ('iron man', '{(datetime.datetime.utcnow() - datetime.timedelta(days=4)).strftime("%Y-%m-%d")}', 50), + ('hawk eye', '{(datetime.datetime.utcnow() - datetime.timedelta(days=5)).strftime("%Y-%m-%d")}', 40), + ('black widow', '{(datetime.datetime.utcnow() - datetime.timedelta(days=6)).strftime("%Y-%m-%d")}', 35) """ postgresql_connection.execute(text(insert_query)) @@ -127,6 +134,62 @@ def populate_postgres_datasource(postgresql_connection: Connection): print(e) +@pytest.mark.usefixtures("setup_data", "opensearch_datasource") +class TestDocumentCountMetric: + def test_should_return_document_count_metric_without_filter( + self, opensearch_datasource: OpenSearchSearchIndexDataSource + ): + doc = DocumentCountMetric( + name="document_count_metric_test", + data_source=opensearch_datasource, + index_name=INDEX_NAME, + metric_type=MetricsType.DOCUMENT_COUNT, + ) + doc_value = doc.get_value() + assert doc_value["value"] == 5 + + def test_should_return_document_count_metric_with_filter( + self, opensearch_datasource: OpenSearchSearchIndexDataSource + ): + doc = DocumentCountMetric( + name="document_count_metric_test_1", + data_source=opensearch_datasource, + index_name=INDEX_NAME, + metric_type=MetricsType.DOCUMENT_COUNT, + filters={"search_query": '{"range": {"age": {"gte": 30, "lte": 40}}}'}, + ) + doc_value = doc.get_value() + assert doc_value["value"] == 2 + + +@pytest.mark.usefixtures("setup_data", "postgres_datasource") +class TestRowCountMetric: + def test_should_return_row_count_metric_without_filter( + self, postgres_datasource: PostgresSQLDatasource + ): + row = RowCountMetric( + name="row_count_metric_test", + data_source=postgres_datasource, + table_name=TABLE_NAME, + metric_type=MetricsType.ROW_COUNT, + ) + row_value = row.get_value() + assert row_value["value"] == 5 + + def test_should_return_row_count_metric_with_filter( + self, postgres_datasource: PostgresSQLDatasource + ): + row = RowCountMetric( + name="row_count_metric_test_1", + data_source=postgres_datasource, + table_name=TABLE_NAME, + metric_type=MetricsType.ROW_COUNT, + filters={"where_clause": "age >= 30 AND age <= 40"}, + ) + row_value = row.get_value() + assert row_value["value"] == 2 + + @pytest.mark.usefixtures("setup_data", "postgres_datasource", "opensearch_datasource") class TestFreshnessValueMetric: def test_should_get_freshness_value_from_opensearch(self, opensearch_datasource):
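For the record, the filtered-count expectations in the moved tests follow directly from the seeded ages: OpenSearch receives [1500, 100, 50, 40, 35] and Postgres receives [1500, 90, 50, 40, 35], so the 30-to-40 range matches exactly two rows in each store (hawk eye and black widow). A quick standalone check:

```python
# Sanity check of the expected counts in the filtered DocumentCount/RowCount tests above.
opensearch_ages = [1500, 100, 50, 40, 35]  # seeded into the reliability_metric_test index
postgres_ages = [1500, 90, 50, 40, 35]     # seeded into reliability_metric_test_table
assert sum(30 <= age <= 40 for age in opensearch_ages) == 2
assert sum(30 <= age <= 40 for age in postgres_ages) == 2
```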