Skip to content

Commit

Permalink
feat: new execute_commands
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Jan 19, 2024
1 parent 1d5cd99 commit 9bd5a5e
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 26 deletions.
114 changes: 88 additions & 26 deletions tabled/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,31 @@ def dataframes(tables: DataFrames) -> Iterable[pd.DataFrame]:
# Combined datasets
# See https://github.com/i2mint/tabled/discussions/3
from dataclasses import dataclass
from collections import ChainMap


def execute_commands(
    commands: Iterable,
    scope: Mapping,
    interpreter_map: Mapping,
    *,
    extra_scope=None,
):
    """
    Lazily carry out `commands`, yielding the result of each one in order.

    Every command is dispatched on its type: the executor registered in
    `interpreter_map` for the command's class is called as
    ``executor(scope, command)``. If the exact class is not registered, the
    command's base classes are tried in method-resolution order, so subclasses
    of a registered command type are handled by their parent's executor.

    This is a generator function: nothing is executed (and no error is raised)
    until the returned generator is iterated.

    :param commands: An iterable of command objects to carry out.
    :param scope: A mapping of names to the objects the executors operate on
        (e.g. table names to tables).
    :param interpreter_map: A mapping from command type to a callable taking
        ``(scope, command)``.
    :param extra_scope: Optional mutable mapping layered on top of `scope`.
        Keys the executors set on the scope they are given are stored here,
        not in `scope`. Defaults to a new empty dict.
    :return: A generator of whatever each executor returns.
    :raises TypeError: (during iteration) if no executor is registered for a
        command's type or any of its base classes.
    """
    if extra_scope is None:
        extra_scope = {}
    # Lookups fall through extra_scope -> scope; writes only touch extra_scope.
    _scope = ChainMap(extra_scope, scope)

    for command in commands:
        command_executor = None
        # The exact type comes first in the MRO, so it always takes precedence.
        for command_type in type(command).__mro__:
            command_executor = interpreter_map.get(command_type, None)
            if command_executor is not None:
                break
        if command_executor is None:
            raise TypeError(f'Unknown command type: {type(command)}')
        yield command_executor(_scope, command)


# Define the JoinWith dataclass
Expand All @@ -43,6 +68,10 @@ def dataframes(tables: DataFrames) -> Iterable[pd.DataFrame]:
# remove: list = None


@dataclass
class Load:
    """Command holding the `key` of the scope entry to load."""

    key: str

@dataclass
class Join:
table_key: str
Expand All @@ -53,33 +82,66 @@ class Remove:
fields: Union[str, Iterable[str]]


def execute_commands(
    commands: Iterable, tables: Mapping[str, pd.DataFrame]
) -> pd.DataFrame:
    """
    Carries `commands` operations out with tables taken from `tables`.

    The first command must be a `Join`; the table it names initializes the
    accumulator. Each following `Join` inner-merges the named table into the
    accumulator, and each `Remove` drops the given columns from it.

    :param commands: A non-empty iterable of `Join` and `Remove` commands,
        starting with a `Join`.
    :param tables: A mapping of table names to tables (pd.DataFrame)
    :return: The accumulated table once all commands have been applied.
    :raises ValueError: If `commands` is empty.
    :raises TypeError: If the first command is not a `Join`, or if any command
        is neither a `Join` nor a `Remove`.
    """
    commands = iter(commands)
    try:
        first_command = next(commands)
    except StopIteration:
        # Don't let a bare StopIteration leak out to the caller.
        raise ValueError('commands must contain at least one command') from None
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(first_command, Join):
        raise TypeError(
            f'The first command must be a Join, not: {type(first_command)}'
        )
    cumul = tables[first_command.table_key]  # initialize the accumulator
    for command in commands:
        if isinstance(command, Join):
            table = tables[command.table_key]
            cumul = cumul.merge(table, how='inner')
        elif isinstance(command, Remove):
            cumul = cumul.drop(columns=command.fields)
        else:
            raise TypeError(f'Unknown command type: {type(command)}')
    return cumul

def set_scope_value(scope, key, value):
    """Store `value` under `key` in `scope` (in place); returns ``None``."""
    scope[key] = value
    return None

def load_func(scope, command):
    """Make the scope entry named by `command.key` the current ``'cumul'`` value.

    Returns ``None``; the effect is the ``'cumul'`` entry written to `scope`.
    """
    loaded = scope[command.key]
    scope['cumul'] = loaded
    return None

def join_func(scope, command):
    """Inner-merge the table named by `command.table_key` into ``scope['cumul']``.

    The merged result replaces the ``'cumul'`` entry of `scope`; nothing is
    returned.
    """
    right = scope[command.table_key]
    left = scope['cumul']
    merged = left.merge(right, how='inner')
    scope['cumul'] = merged

def remove_func(scope, command):
    """Drop the columns `command.fields` from ``scope['cumul']``.

    The reduced table replaces the ``'cumul'`` entry of `scope`; nothing is
    returned.
    """
    current = scope['cumul']
    scope['cumul'] = current.drop(columns=command.fields)


# Default command-type -> executor dispatch table for table commands.
# Each executor is called as ``executor(scope, command)``; this mapping is the
# default `interpreter_map` of `execute_table_commands`.
dflt_tables_interpreter_map = {
Load: load_func,
Join: join_func,
Remove: remove_func,
}

def execute_table_commands(
    commands: Iterable,
    tables: Mapping[str, pd.DataFrame],
    interpreter_map: Mapping = dflt_tables_interpreter_map,
    *,
    extra_scope=None,
):
    """
    Run `commands` over `tables` by delegating to `execute_commands`.

    :param commands: An iterable of command objects to carry out.
    :param tables: A mapping of table names to tables (pd.DataFrame), passed
        to `execute_commands` as the scope.
    :param interpreter_map: A mapping from command type to its executor;
        defaults to `dflt_tables_interpreter_map`.
    :param extra_scope: Forwarded unchanged to `execute_commands`.
    :return: Whatever `execute_commands` returns for these arguments.
    """
    results = execute_commands(
        commands, tables, interpreter_map, extra_scope=extra_scope
    )
    return results


# def execute_commands(
# commands: Iterable, tables: Mapping[str, pd.DataFrame]
# ) -> pd.DataFrame:
# """
# Carries `commands` operations out with tables taken from `tables`.

# :param commands: An iterable of join operations to carry out.
# Each join operation is either a table name (str) or a JoinWith object.
# If it's a JoinWith object, it's assumed that the table has already been joined
# and the fields to remove are in the `remove` attribute of the object.
# :param tables: A mapping of table names to tables (pd.DataFrame)
# """
# # join_ops = map(ensure_join_op, resolution_sequence)
# commands = iter(commands)
# first_command = next(commands)
# assert isinstance(first_command, Join)
# table_key = first_command.table_key
# cumul = tables[table_key] # initialize my accumulator
# for command in commands:
# if isinstance(command, Join):
# table = tables[command.table_key]
# cumul = cumul.merge(table, how='inner')
# elif isinstance(command, Remove):
# cumul = cumul.drop(columns=command.fields)
# else:
# raise TypeError(f'Unknown command type: {type(command)}')
# return cumul


# --------------------------------------------------------------------------------------
Expand Down
65 changes: 65 additions & 0 deletions tabled/tests/join_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,68 @@ def test_compute_join_resolution(tables):
expected_result = pd.DataFrame({'b': [1, 2], 'g': [4, 5], 'j': [4, 5]})
result = compute_join_resolution(join_sequence, tables)
pd.testing.assert_frame_equal(result, expected_result)


from tabled.multi import execute_commands


def test_execute_commands_simply():
    """Check `execute_commands` dispatch, then step through real table commands."""
    from tabled.multi import Join, Remove, Load, execute_table_commands

    # ---------------------------------------------
    # First silly test: executors that only describe the command they receive

    silly_interpreter_map = {
        Join: lambda scope, command: f'Joining {command.table_key}',
        Remove: lambda scope, command: f'Removing {command.fields}',
    }

    g = execute_commands(
        [Join('asdf'), Remove('apple')],
        scope={},
        interpreter_map=silly_interpreter_map,
    )
    assert list(g) == ['Joining asdf', 'Removing apple']

    # ---------------------------------------------
    # The real case

    table1 = pd.DataFrame({'ID': [1, 2, 3], 'Name': ['Alice', 'Bob', 'Charlie']})
    table2 = pd.DataFrame({'ID': [2, 3, 4], 'Age': [25, 30, 22]})
    table3 = pd.DataFrame(
        {'ID': [1, 2, 3, 4], 'Salary': [50000, 60000, 70000, 55000]}
    )

    tables = {
        'table1': table1,
        'table2': table2,
        'table3': table3,
    }

    commands = [
        Load('table1'),
        Remove(['Name']),
        Join('table3'),
    ]

    # The accumulated table is written to extra_scope, so it can be inspected
    # after each step.
    extra_scope = dict()

    it = execute_table_commands(commands, tables, extra_scope=extra_scope)

    # Load: the accumulator is table1 as-is.
    next(it)
    pd.testing.assert_frame_equal(extra_scope['cumul'], table1)

    # Remove: the 'Name' column is dropped.
    next(it)
    pd.testing.assert_frame_equal(
        extra_scope['cumul'], pd.DataFrame({'ID': [1, 2, 3]})
    )

    # Join: inner merge with table3 on the shared 'ID' column.
    next(it)
    assert list(extra_scope['cumul']) == ['ID', 'Salary']
    pd.testing.assert_frame_equal(
        extra_scope['cumul'],
        pd.DataFrame({'ID': [1, 2, 3], 'Salary': [50000, 60000, 70000]}),
    )

    # The accumulator must not leak into the caller's tables mapping.
    assert 'cumul' not in tables

0 comments on commit 9bd5a5e

Please sign in to comment.