diff --git a/datachecks/core/datasource/base.py b/datachecks/core/datasource/base.py index 59fbd49d..68506551 100644 --- a/datachecks/core/datasource/base.py +++ b/datachecks/core/datasource/base.py @@ -77,6 +77,16 @@ def query_get_max(self, index_name: str, field: str, filters: str = None) -> int :return: max value """ raise NotImplementedError("query_get_max method is not implemented") + + def query_get_avg(self, index_name: str, field: str, filters: str = None) -> int: + """ + Get the average value + :param index_name: name of the index + :param field: field name + :param filters: optional filter + :return: average value + """ + raise NotImplementedError("query_get_avg method is not implemented") def query_get_time_diff(self, index_name: str, field: str) -> int: """ @@ -134,6 +144,20 @@ def query_get_max(self, table: str, field: str, filters: str = None) -> int: query += " WHERE {}".format(filters) return self.connection.execute(text(query)).fetchone()[0] + + def query_get_avg(self, table: str, field: str, filters: str = None) -> int: + """ + Get the average value + :param table: table name + :param field: column name + :param filters: filter condition + :return: + """ + query = "SELECT ROUND(AVG({}), 2) FROM {}".format(field, table) + if filters: + query += " WHERE {}".format(filters) + + return self.connection.execute(text(query)).fetchone()[0] def query_get_time_diff(self, table: str, field: str) -> int: """ diff --git a/datachecks/core/datasource/opensearch.py b/datachecks/core/datasource/opensearch.py index d9a3d77b..7acf4f23 100644 --- a/datachecks/core/datasource/opensearch.py +++ b/datachecks/core/datasource/opensearch.py @@ -84,6 +84,21 @@ def query_get_max(self, index_name: str, field: str, filters: Dict = None) -> in response = self.client.search(index=index_name, body=query) return response["aggregations"]["max_value"]["value"] + + def query_get_avg(self, index_name: str, field: str, filters: Dict = None) -> int: + """ + Get the average value of a field + :param index_name: + :param field: + :param filters: + :return: + """ + query = {"aggs": {"avg_value": {"avg": {"field": field}}}} + if filters: + query["query"] = filters + + response = self.client.search(index=index_name, body=query) + return round(response["aggregations"]["avg_value"]["value"], 2) def query_get_time_diff(self, index_name: str, field: str) -> int: """ diff --git a/datachecks/core/metric/base.py b/datachecks/core/metric/base.py index 6109ab7e..3977cad5 100644 --- a/datachecks/core/metric/base.py +++ b/datachecks/core/metric/base.py @@ -27,6 +27,7 @@ class MetricsType(str, Enum): ROW_COUNT = "row_count" DOCUMENT_COUNT = "document_count" MAX = "max" + AVG = "avg" FRESHNESS = "freshness" diff --git a/datachecks/core/metric/numeric_metric.py b/datachecks/core/metric/numeric_metric.py index 14975354..ab4351a3 100644 --- a/datachecks/core/metric/numeric_metric.py +++ b/datachecks/core/metric/numeric_metric.py @@ -104,3 +104,35 @@ def _generate_metric_value(self): ) else: raise ValueError("Invalid data source type") + +class AvgMetric(FieldMetrics): + + """ + AvgMetric is a class that represents a metric that is generated by a data source. + """ + + def get_metric_identity(self): + return MetricIdentity.generate_identity( + metric_type=MetricsType.AVG, + metric_name=self.name, + data_source=self.data_source, + field_name=self.field_name, + table_name=self.table_name if self.table_name else None, + index_name=self.index_name if self.index_name else None, + ) + + def _generate_metric_value(self): + if isinstance(self.data_source, SQLDatasource): + return self.data_source.query_get_avg( + table=self.table_name, + field=self.field_name, + filters=self.filter_query if self.filter_query else None, + ) + elif isinstance(self.data_source, SearchIndexDataSource): + return self.data_source.query_get_avg( + index_name=self.index_name, + field=self.field_name, + filters=self.filter_query if self.filter_query else None, + ) + else: + raise ValueError("Invalid data source type") \ No newline at end of file diff --git a/tests/core/metric/test_numeric_metric.py b/tests/core/metric/test_numeric_metric.py index 3a4e4761..1e692b9b 100644 --- a/tests/core/metric/test_numeric_metric.py +++ b/tests/core/metric/test_numeric_metric.py @@ -23,7 +23,7 @@ from datachecks.core.datasource.postgres import PostgresSQLDatasource from datachecks.core.metric.base import MetricsType from datachecks.core.metric.numeric_metric import (DocumentCountMetric, - MaxMetric, RowCountMetric) + MaxMetric, AvgMetric, RowCountMetric) from tests.utils import create_opensearch_client, create_postgres_connection @@ -206,3 +206,59 @@ def test_should_return_max_column_value_opensearch_with_filter( ) row_value = row.get_value() assert row_value["value"] == 110 + +@pytest.mark.usefixtures("setup_data", "postgres_datasource", "opensearch_datasource") +class TestAvgColumnValueMetric: + def test_should_return_avg_column_value_postgres_without_filter( + self, postgres_datasource: PostgresSQLDatasource + ): + row = AvgMetric( + name="avg_metric_test", + data_source=postgres_datasource, + table_name="numeric_metric_test", + metric_type=MetricsType.AVG, + field_name="age", + ) + row_value = row.get_value() + assert float(row_value["value"]) == 141.40 + + def test_should_return_avg_column_value_postgres_with_filter( + self, postgres_datasource: PostgresSQLDatasource + ): + row = AvgMetric( + name="avg_metric_test_1", + data_source=postgres_datasource, + table_name="numeric_metric_test", + metric_type=MetricsType.AVG, + field_name="age", + filters={"where_clause": "age >= 30 AND age <= 200"}, + ) + row_value = row.get_value() + assert float(row_value["value"]) == 51.50 + + def test_should_return_avg_column_value_opensearch_without_filter( + self, opensearch_datasource: OpenSearchSearchIndexDataSource + ): + row = AvgMetric( + name="avg_metric_test", + data_source=opensearch_datasource, + index_name="numeric_metric_test", + metric_type=MetricsType.AVG, + field_name="age", + ) + row_value = row.get_value() + assert float(row_value["value"]) == 141.40 + + def test_should_return_avg_column_value_opensearch_with_filter( + self, opensearch_datasource: OpenSearchSearchIndexDataSource + ): + row = AvgMetric( + name="avg_metric_test_1", + data_source=opensearch_datasource, + index_name="numeric_metric_test", + metric_type=MetricsType.AVG, + field_name="age", + filters={"search_query": '{"range": {"age": {"gte": 30, "lte": 200}}}'}, + ) + row_value = row.get_value() + assert float(row_value["value"]) == 51.50