Skip to content

Commit

Permalink
update metric manager invocation (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
datageek00 authored Aug 14, 2023
1 parent 50d1e08 commit e0fc398
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 216 deletions.
42 changes: 18 additions & 24 deletions datachecks/core/metric/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,29 @@ def __init__(
name: str,
data_source: DataSource,
metric_type: MetricsType,
table_name: Optional[str] = None,
index_name: Optional[str] = None,
filters: Dict = None,
metric_logger: MetricLogger = None,
**kwargs,
):
if index_name is not None and table_name is not None:
raise ValueError(
"Please give a value for table_name or index_name (but not both)"
)
if index_name is None and table_name is None:
if "index_name" in kwargs and "table_name" in kwargs:
if kwargs["index_name"] is not None and kwargs["table_name"] is not None:
raise ValueError(
"Please give a value for table_name or index_name (but not both)"
)
if "index_name" not in kwargs and "table_name" not in kwargs:
raise ValueError("Please give a value for table_name or index_name")

self.index_name, self.table_name = None, None
if index_name:
self.index_name = index_name
if table_name:
self.table_name = table_name
if "index_name" in kwargs:
self.index_name = kwargs["index_name"]
if "table_name" in kwargs:
self.table_name = kwargs["table_name"]

self.name: str = name
self.data_source = data_source
self.metric_type = metric_type
self.filter_query = None
if filters is not None:
if "filters" in kwargs and kwargs["filters"] is not None:
filters = kwargs["filters"]
if ("search_query" in filters and filters["search_query"] is not None) and (
"where_clause" in filters and filters["where_clause"] is not None
):
Expand All @@ -109,7 +109,6 @@ def get_metric_identity(self):
data_source=self.data_source,
)

@abstractmethod
def _generate_metric_value(self) -> float:
pass

Expand Down Expand Up @@ -143,25 +142,20 @@ class FieldMetrics(Metric, ABC):
def __init__(
self,
name: str,
metric_type: MetricsType,
data_source: DataSource,
field_name: str,
table_name: Optional[str] = None,
index_name: Optional[str] = None,
filters: Dict = None,
metric_type: MetricsType,
metric_logger: MetricLogger = None,
**kwargs,
):
super().__init__(
name=name,
data_source=data_source,
table_name=table_name,
index_name=index_name,
metric_type=metric_type,
filters=filters,
metric_logger=metric_logger,
**kwargs,
)

self.field_name = field_name
if "field_name" in kwargs:
self.field_name = kwargs["field_name"]

@property
def get_field_name(self):
Expand Down
108 changes: 41 additions & 67 deletions datachecks/core/metric/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,25 @@
# limitations under the License.

from dataclasses import asdict
from typing import Dict, List
from typing import List

from loguru import logger

from datachecks.core.configuration.configuration import MetricConfiguration
from datachecks.core.datasource.manager import DataSourceManager
from datachecks.core.logger.base import MetricLogger
from datachecks.core.metric.base import Metric, MetricsType
from datachecks.core.metric.freshness_metric import FreshnessValueMetric
from datachecks.core.metric.numeric_metric import (DocumentCountMetric,
MaxMetric, RowCountMetric)
from datachecks.core.metric.numeric_metric import *
from datachecks.core.metric.reliability_metric import *


class MetricManager:
METRIC_CLASS_MAPPING = {
MetricsType.DOCUMENT_COUNT.value: "DocumentCountMetric",
MetricsType.ROW_COUNT.value: "RowCountMetric",
MetricsType.MAX.value: "MaxMetric",
MetricsType.FRESHNESS.value: "FreshnessValueMetric",
}

def __init__(
self,
metric_config: Dict[str, List[MetricConfiguration]],
Expand All @@ -39,68 +46,35 @@ def __init__(
def _build_metrics(self, config: Dict[str, List[MetricConfiguration]]):
for data_source, metric_list in config.items():
for metric_config in metric_list:
if metric_config.metric_type == MetricsType.DOCUMENT_COUNT:
metric = DocumentCountMetric(
name=metric_config.name,
data_source=self.data_source_manager.get_data_source(
data_source
),
filters=asdict(metric_config.filters)
if metric_config.filters
else None,
index_name=metric_config.index,
metric_type=MetricsType.DOCUMENT_COUNT,
metric_logger=self.metric_logger,
)
self.metrics[metric.get_metric_identity()] = metric
elif metric_config.metric_type == MetricsType.ROW_COUNT:
metric = RowCountMetric(
name=metric_config.name,
data_source=self.data_source_manager.get_data_source(
data_source
),
filters=asdict(metric_config.filters)
if metric_config.filters
else None,
table_name=metric_config.table,
metric_type=MetricsType.ROW_COUNT,
metric_logger=self.metric_logger,
)
self.metrics[metric.get_metric_identity()] = metric
elif metric_config.metric_type == MetricsType.MAX:
metric = MaxMetric(
name=metric_config.name,
data_source=self.data_source_manager.get_data_source(
data_source
),
filters=asdict(metric_config.filters)
if metric_config.filters
else None,
table_name=metric_config.table,
index_name=metric_config.index,
metric_type=MetricsType.MAX,
field_name=metric_config.field,
metric_logger=self.metric_logger,
)
self.metrics[metric.get_metric_identity()] = metric
elif metric_config.metric_type == MetricsType.FRESHNESS:
metric = FreshnessValueMetric(
name=metric_config.name,
data_source=self.data_source_manager.get_data_source(
data_source
),
filters=asdict(metric_config.filters)
if metric_config.filters
else None,
table_name=metric_config.table,
index_name=metric_config.index,
metric_type=MetricsType.FRESHNESS,
field_name=metric_config.field,
metric_logger=self.metric_logger,
)
self.metrics[metric.get_metric_identity()] = metric
else:
raise ValueError("Invalid metric type")
params = {
"filters": asdict(metric_config.filters)
if metric_config.filters
else None,
}
if metric_config.index:
params["index_name"] = metric_config.index
if metric_config.table:
params["table_name"] = metric_config.table
if metric_config.field:
params["field_name"] = metric_config.field

logger.info(f"==============metric_config: {self.METRIC_CLASS_MAPPING}")
logger.info(
f"==============metric_config.metric_type: {metric_config.metric_type}"
)
# logger.info(globals())
metric: Metric = globals()[
self.METRIC_CLASS_MAPPING[metric_config.metric_type]
](
metric_config.name,
self.data_source_manager.get_data_source(data_source),
MetricsType(metric_config.metric_type.lower()),
self.metric_logger,
**params,
)

logger.info(metric.__dict__)
self.metrics[metric.get_metric_identity()] = metric

@property
def get_metrics(self):
Expand Down
62 changes: 5 additions & 57 deletions datachecks/core/metric/numeric_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,62 +16,10 @@

from datachecks.core.datasource.search_datasource import SearchIndexDataSource
from datachecks.core.datasource.sql_datasource import SQLDatasource
from datachecks.core.metric.base import (FieldMetrics, Metric, MetricIdentity,
from datachecks.core.metric.base import (FieldMetrics, MetricIdentity,
MetricsType)


class DocumentCountMetric(Metric):
"""
DocumentCountMetrics is a class that represents a metric that is generated by a data source.
"""

def validate_data_source(self):
return isinstance(self.data_source, SearchIndexDataSource)

def get_metric_identity(self):
return MetricIdentity.generate_identity(
metric_type=MetricsType.DOCUMENT_COUNT,
metric_name=self.name,
data_source=self.data_source,
index_name=self.index_name,
)

def _generate_metric_value(self):
if isinstance(self.data_source, SearchIndexDataSource):
return self.data_source.query_get_document_count(
index_name=self.index_name,
filters=self.filter_query if self.filter_query else None,
)
else:
raise ValueError("Invalid data source type")


class RowCountMetric(Metric):

"""
RowCountMetrics is a class that represents a metric that is generated by a data source.
"""

def get_metric_identity(self):
return MetricIdentity.generate_identity(
metric_type=MetricsType.ROW_COUNT,
metric_name=self.name,
data_source=self.data_source,
table_name=self.table_name,
)

def validate_data_source(self):
return isinstance(self.data_source, SQLDatasource)

def _generate_metric_value(self):
if isinstance(self.data_source, SQLDatasource):
return self.data_source.query_get_row_count(
table=self.table_name,
filters=self.filter_query if self.filter_query else None,
)
else:
raise ValueError("Invalid data source type")

class MinMetric(FieldMetrics):

"""
Expand All @@ -87,7 +35,7 @@ def get_metric_identity(self):
table_name=self.table_name if self.table_name else None,
index_name=self.index_name if self.index_name else None,
)

def _generate_metric_value(self):
if isinstance(self.data_source, SQLDatasource):
return self.data_source.query_get_min(
Expand All @@ -105,7 +53,6 @@ def _generate_metric_value(self):
raise ValueError("Invalid data source type")



class MaxMetric(FieldMetrics):

"""
Expand Down Expand Up @@ -138,6 +85,7 @@ def _generate_metric_value(self):
else:
raise ValueError("Invalid data source type")


class AvgMetric(FieldMetrics):

"""
Expand All @@ -153,7 +101,7 @@ def get_metric_identity(self):
table_name=self.table_name if self.table_name else None,
index_name=self.index_name if self.index_name else None,
)

def _generate_metric_value(self):
if isinstance(self.data_source, SQLDatasource):
return self.data_source.query_get_avg(
Expand All @@ -168,4 +116,4 @@ def _generate_metric_value(self):
filters=self.filter_query if self.filter_query else None,
)
else:
raise ValueError("Invalid data source type")
raise ValueError("Invalid data source type")
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,60 @@
from datachecks.core.datasource.search_datasource import SearchIndexDataSource
from datachecks.core.datasource.sql_datasource import SQLDatasource
from datachecks.core.metric.base import (FieldMetrics, MetricIdentity,
MetricsType)
MetricsType, Metric)


class DocumentCountMetric(Metric):
"""
DocumentCountMetrics is a class that represents a metric that is generated by a data source.
"""

def validate_data_source(self):
return isinstance(self.data_source, SearchIndexDataSource)

def get_metric_identity(self):
return MetricIdentity.generate_identity(
metric_type=MetricsType.DOCUMENT_COUNT,
metric_name=self.name,
data_source=self.data_source,
index_name=self.index_name,
)

def _generate_metric_value(self):
if isinstance(self.data_source, SearchIndexDataSource):
return self.data_source.query_get_document_count(
index_name=self.index_name,
filters=self.filter_query if self.filter_query else None,
)
else:
raise ValueError("Invalid data source type")


class RowCountMetric(Metric):

"""
RowCountMetrics is a class that represents a metric that is generated by a data source.
"""

def get_metric_identity(self):
return MetricIdentity.generate_identity(
metric_type=MetricsType.ROW_COUNT,
metric_name=self.name,
data_source=self.data_source,
table_name=self.table_name,
)

def validate_data_source(self):
return isinstance(self.data_source, SQLDatasource)

def _generate_metric_value(self):
if isinstance(self.data_source, SQLDatasource):
return self.data_source.query_get_row_count(
table=self.table_name,
filters=self.filter_query if self.filter_query else None,
)
else:
raise ValueError("Invalid data source type")


class FreshnessValueMetric(FieldMetrics):
Expand Down
2 changes: 1 addition & 1 deletion tests/core/metric/test_metric_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from datachecks.core.datasource.manager import DataSourceManager
from datachecks.core.metric.base import MetricsType
from datachecks.core.metric.manager import MetricManager
from datachecks.core.metric.numeric_metric import DocumentCountMetric
from datachecks.core.metric.reliability_metric import DocumentCountMetric

OPEN_SEARCH_DATA_SOURCE_NAME = "test_open_search_data_source"
POSTGRES_DATA_SOURCE_NAME = "test_postgres_data_source"
Expand Down
Loading

0 comments on commit e0fc398

Please sign in to comment.