Skip to content

Commit

Permalink
Fixed set function bug
Browse files Browse the repository at this point in the history
  • Loading branch information
argenisleon committed Nov 10, 2019
1 parent d32000a commit 9d1111c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
12 changes: 9 additions & 3 deletions optimus/dataframe/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
from optimus.profiler.functions import fill_missing_var_types, parse_profiler_dtypes

ENGINE = "spark"
# Because the monkey patching and the need to call set a function we need to rename the standard python set.
# This is awful but the best option for the user.
python_set = set


def cols(self):
Expand Down Expand Up @@ -261,6 +264,7 @@ def apply_by_dtypes(columns, func, func_return_type, args=None, func_type=None,
when=fbdt(col_name, data_type))
return df

# TODO: Maybe we could merge this with apply()
@add_attr(cols)
def set(output_col, value=None):
"""
Expand All @@ -278,12 +282,14 @@ def set(output_col, value=None):
expr = F.array([F.lit(x) for x in value])
elif is_numeric(value):
expr = F.lit(value)
elif value:
elif is_str(value):
expr = F.expr(value)
else:
RaiseIt.value_error(value, ["numeric", "list", "hive expression"])

return df.withColumn(output_col, expr)
df = df.withColumn(output_col, expr)
df = df.preserve_meta(self, Actions.SET.value, columns)
return df

# TODO: Check if we must use * to select all the columns
@add_attr(cols)
Expand Down Expand Up @@ -1725,7 +1731,7 @@ def str_to_array(_value):
return str_to_data_type(_value, (list, tuple))

def str_to_object(_value):
return str_to_data_type(_value, (dict, set))
return str_to_data_type(_value, (dict, python_set))

def str_to_data_type(_value, _dtypes):
"""
Expand Down
1 change: 1 addition & 0 deletions optimus/helpers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class Actions(Enum):
UNNEST = "unnest"
DROP_ROW = "drop_row"
VALUES_TO_COLS = "values_to_cols"
SET = "set"

@staticmethod
def list():
Expand Down

0 comments on commit 9d1111c

Please sign in to comment.