Skip to content

Commit

Permalink
fix: refactor VALUE_MAPPING and ADD_LABELS operators
Browse files Browse the repository at this point in the history
Signed-off-by: seolmin <[email protected]>
  • Loading branch information
stat-kwon committed Dec 10, 2024
1 parent ba4bd69 commit 7b2f534
Showing 1 changed file with 97 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -435,31 +435,14 @@ def add_labels_data_table(
data_table_vo = self.data_table_vos[0]
label_keys = list(data_table_vo.labels_info.keys())
data_keys = list(data_table_vo.data_info.keys())

labels = self.options.get("labels")

if not labels:
raise ERROR_REQUIRED_PARAMETER(key="options.ADD_LABELS.labels")

df = self._get_data_table(data_table_vo, granularity, start, end, vars)

for label_key in labels.keys():
if label_key in df.columns:
raise ERROR_INVALID_PARAMETER(
key="options.ADD_LABELS.labels",
reason=f"Duplicated key: {label_key}, columns={list(df.columns)}",
)
self.validate_labels(labels, df)

self.add_labels_to_dataframe(df, labels, label_keys, data_keys)

for key, value in labels.items():
df[key] = value
if isinstance(value, str):
label_keys.append(key)
elif isinstance(value, (int, float)):
data_keys.append(key)
else:
raise ERROR_INVALID_PARAMETER_TYPE(
key="options.ADD_LABELS.labels", type=type(value)
)
df = df.reindex(columns=label_keys + data_keys)

self.label_keys = label_keys
Expand All @@ -469,7 +452,7 @@ def add_labels_data_table(

def value_mapping_data_table(
self,
granularity: GRANULARITY = "MONTHLY",
granularity: str = "MONTHLY",
start: str = None,
end: str = None,
vars: dict = None,
Expand All @@ -479,45 +462,14 @@ def value_mapping_data_table(

self.label_keys = list(data_table_vo.labels_info.keys())
self.data_keys = list(data_table_vo.data_info.keys())

name = self.options["name"]
field_type = self.options.get("field_type", "LABEL")
else_value = self.options.get("else", None)

if condition := self.options.get("condition"):
if self.is_jinja_expression(condition):
condition, gv_type_map = self.change_global_variables(condition, vars)
condition = self.remove_jinja_braces(condition)
condition = self.change_expression_data_type(condition, gv_type_map)

filtered_df = df.query(condition).copy()
else:
filtered_df = df.copy()

filtered_df.loc[:, name] = else_value

if cases := self.options.get("cases", []):
for case in cases:
self._validate_case(case)
key = case["key"]
operator = case["operator"]
value = case["value"]
match = case["match"]

if operator == "eq":
filtered_df.loc[filtered_df[key] == match, name] = value
elif operator == "regex":
filtered_df.loc[filtered_df[key].str.contains(match), name] = value
filtered_df = self.filter_data(df, vars)
filtered_df = self.apply_cases(filtered_df)

df.loc[filtered_df.index, name] = filtered_df[name]

unfiltered_index = df.index.difference(filtered_df.index)
if field_type == "LABEL":
df.loc[unfiltered_index, name] = ""
self.data_keys.append(name)
elif field_type == "DATA":
df.loc[unfiltered_index, name] = 0
self.data_keys.append(name)
self.handle_unfiltered_data(df, filtered_df, name, field_type)

self.df = df

Expand Down Expand Up @@ -843,6 +795,96 @@ def _sort_and_filter_pivot_table(self, pivot_table: pd.DataFrame) -> pd.DataFram

return pivot_table

@staticmethod
def validate_labels(labels: dict, df: pd.DataFrame) -> None:
    """Check the ADD_LABELS ``labels`` option against an existing table.

    Args:
        labels: Mapping whose keys are the column names about to be added.
        df: Data table the new columns will be written into.

    Raises:
        ERROR_REQUIRED_PARAMETER: ``labels`` is missing or empty.
        ERROR_INVALID_PARAMETER: a key of ``labels`` is already a column
            of ``df``.
    """
    if not labels:
        raise ERROR_REQUIRED_PARAMETER(key="options.ADD_LABELS.labels")

    for label_key in labels.keys():
        if label_key not in df.columns:
            continue

        # The first clashing key aborts validation; the message lists the
        # current columns so the caller can see what is already taken.
        raise ERROR_INVALID_PARAMETER(
            key="options.ADD_LABELS.labels",
            reason=f"Duplicated key: {label_key}, columns={list(df.columns)}",
        )

@staticmethod
def update_keys(
    key: str,
    value,
    label_keys: list,
    data_keys: list,
) -> None:
    """Register ``key`` as a label or data column based on ``value``'s type.

    String values register the key in ``label_keys``; ``int``/``float``
    values register it in ``data_keys``. Both lists are mutated in place.

    Args:
        key: Name of the column being added.
        value: Constant assigned to that column.
        label_keys: List of label column names, appended to for strings.
        data_keys: List of data column names, appended to for numbers.

    Raises:
        ERROR_INVALID_PARAMETER_TYPE: ``value`` is neither ``str`` nor
            ``int``/``float``.
    """
    if isinstance(value, str):
        target = label_keys
    elif isinstance(value, (int, float)):
        target = data_keys
    else:
        raise ERROR_INVALID_PARAMETER_TYPE(
            key="options.ADD_LABELS.labels", type=type(value)
        )

    target.append(key)

def add_labels_to_dataframe(
    self,
    df: pd.DataFrame,
    labels: dict,
    label_keys: list,
    data_keys: list,
) -> None:
    """Add every entry of ``labels`` to ``df`` as a constant-valued column.

    Each key is also registered in ``label_keys`` or ``data_keys`` through
    ``self.update_keys``. ``df`` and both lists are mutated in place.

    Args:
        df: Data table that receives the new columns.
        labels: Mapping of new column name to the constant value to assign.
        label_keys: List of label column names, extended in place.
        data_keys: List of data column names, extended in place.

    Raises:
        Whatever ``self.update_keys`` raises for an unsupported value type.
    """
    for key, value in labels.items():
        # Classify (and thereby type-check) the value BEFORE touching the
        # frame: an unsupported value must be rejected by update_keys rather
        # than first being written into df or tripping an unrelated pandas
        # error during assignment.
        self.update_keys(key, value, label_keys, data_keys)
        df[key] = value

def filter_data(self, df: pd.DataFrame, vars: dict) -> pd.DataFrame:
    """Return a copy of ``df`` restricted by the ``condition`` option.

    When ``self.options`` has no (or an empty) ``condition``, the whole
    frame is copied. Otherwise the condition is passed to
    ``DataFrame.query``; a Jinja-style condition is first rewritten by the
    global-variable helpers using ``vars``.

    Args:
        df: Source data table; it is not modified.
        vars: Global variables handed to ``change_global_variables``.

    Returns:
        A new DataFrame holding the matching rows.
    """
    expression = self.options.get("condition")

    if expression:
        if self.is_jinja_expression(expression):
            expression, gv_type_map = self.change_global_variables(
                expression, vars
            )
            expression = self.remove_jinja_braces(expression)
            expression = self.change_expression_data_type(
                expression, gv_type_map
            )

        # NOTE(review): the expression is evaluated by DataFrame.query;
        # confirm that options.condition only comes from trusted input.
        return df.query(expression).copy()

    return df.copy()

def apply_cases(self, filtered_df: pd.DataFrame) -> pd.DataFrame:
    """Fill the mapped column of ``filtered_df`` according to the cases.

    The column named by ``options["name"]`` is first set to
    ``options["else"]`` (``None`` when absent) for every row. Each case is
    then applied in order, so a later case overwrites an earlier one on
    rows that both match.

    A case is a dict with ``key`` (column to test), ``operator`` (``"eq"``
    or ``"regex"``), ``match`` (value or pattern) and ``value`` (what to
    write on matching rows).

    Args:
        filtered_df: Rows to map; modified in place.

    Returns:
        The same ``filtered_df`` object, with the mapped column filled.
    """
    name = self.options["name"]
    else_value = self.options.get("else", None)
    # ``or []`` also covers an explicit ``"cases": None`` in the options,
    # which ``.get("cases", [])`` alone would hand to the loop below.
    cases = self.options.get("cases") or []

    filtered_df.loc[:, name] = else_value

    for case in cases:
        self._validate_case(case)
        key = case["key"]
        operator = case["operator"]
        value = case["value"]
        match = case["match"]

        if operator == "eq":
            filtered_df.loc[filtered_df[key] == match, name] = value
        elif operator == "regex":
            # na=False: rows whose cell is missing never match the pattern.
            filtered_df.loc[
                filtered_df[key].str.contains(match, na=False), name
            ] = value
        # Any other operator is skipped here; presumably _validate_case
        # rejects it beforehand -- TODO confirm.

    return filtered_df

def handle_unfiltered_data(
    self,
    df: pd.DataFrame,
    filtered_df: pd.DataFrame,
    name: str,
    field_type: str,
):
    """Fill the mapped column for rows the condition filtered out.

    Rows of ``df`` whose index labels are absent from ``filtered_df``
    receive a neutral value in column ``name``, and ``name`` is recorded
    as a label or data key on ``self``:

    * ``"LABEL"``: fill with ``""`` and append to ``self.label_keys``.
    * ``"DATA"``: fill with ``0`` and append to ``self.data_keys``.

    Any other ``field_type`` leaves ``df`` and both key lists untouched.

    Args:
        df: Full data table; modified in place.
        filtered_df: Subset of ``df`` that was handled by the cases.
        name: Column being populated.
        field_type: ``"LABEL"`` or ``"DATA"``.
    """
    remaining = df.index.difference(filtered_df.index)

    if field_type == "LABEL":
        filler, registry = "", self.label_keys
    elif field_type == "DATA":
        filler, registry = 0, self.data_keys
    else:
        return

    df.loc[remaining, name] = filler
    registry.append(name)

def _apply_row_sorting(
self,
pivot_table: pd.DataFrame,
Expand Down

0 comments on commit 7b2f534

Please sign in to comment.