Skip to content

Commit

Permalink
fix: fix data table user function exponent bug
Browse files Browse the repository at this point in the history
Data table variable in exponent of user function would cause error. Test added.
  • Loading branch information
mgreminger committed Oct 3, 2024
1 parent 916cac0 commit ed03b37
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
29 changes: 14 additions & 15 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,7 @@ def new_func(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
data_table_subs: DataTableSubs) -> Expr:
data_table_subs: DataTableSubs | None) -> Expr:
result = func(expr, func_key, placeholder_map,
placeholder_set, data_table_subs)

Expand All @@ -1523,7 +1523,7 @@ def replace_placeholder_funcs(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
data_table_subs: DataTableSubs) -> Expr:
data_table_subs: DataTableSubs | None) -> Expr:
if is_matrix(expr):
rows = []
for i in range(expr.rows):
Expand Down Expand Up @@ -1568,7 +1568,7 @@ def replace_placeholder_funcs(expr: Expr,
return cast(Expr, expr.func(*matrix_args))
else:
return cast(Expr, expr.func(*processed_args))
elif expr.func == data_table_calc_wrapper:
elif data_table_subs is not None and expr.func == data_table_calc_wrapper:
if len(expr.args[0].atoms(data_table_id_wrapper)) == 0:
return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs)

Expand All @@ -1595,16 +1595,18 @@ def replace_placeholder_funcs(expr: Expr,
else:
return cast(Expr, Matrix([sub_expr,]*shortest_col))

elif expr.func == data_table_id_wrapper:
new_var = Symbol(f"_data_table_var_{data_table_subs.get_next_id()}")
elif data_table_subs is not None and expr.func == data_table_id_wrapper:
current_expr = replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs)
data_table_subs.subs_stack[-1][new_var] = current_expr
new_var = Symbol(f"_data_table_var_{data_table_subs.get_next_id()}")

if not is_matrix(current_expr):
raise EmptyColumnData(current_expr)

if len(data_table_subs.subs_stack) > 0:
data_table_subs.subs_stack[-1][new_var] = cast(Expr, current_expr)

if is_matrix(current_expr):
if data_table_subs.shortest_col_stack[-1] is None or current_expr.rows < data_table_subs.shortest_col_stack[-1]:
data_table_subs.shortest_col_stack[-1] = current_expr.rows
else:
raise EmptyColumnData(current_expr)

if func_key == "sympy_func":
return new_var
Expand Down Expand Up @@ -1852,8 +1854,7 @@ def solve_system(statements: list[EqualityStatement], variables: list[str],
{exponent["name"]:exponent["expression"] for exponent in cast(list[Exponent], statement["exponents"])})
equality = replace_placeholder_funcs(cast(Expr, equality),
"sympy_func",
placeholder_map, placeholder_set,
DataTableSubs())
placeholder_map, placeholder_set, None)

system.append(cast(Expr, equality.doit()))

Expand Down Expand Up @@ -1944,9 +1945,7 @@ def solve_system_numerical(statements: list[EqualityStatement], variables: list[
equality = equality.subs(parameter_subs)
equality = replace_placeholder_funcs(cast(Expr, equality),
"sympy_func",
placeholder_map,
placeholder_set,
DataTableSubs())
placeholder_map, placeholder_set, None)
system.append(cast(Expr, equality.doit()))
new_statements.extend(statement["equalityUnitsQueries"])

Expand Down Expand Up @@ -2626,7 +2625,7 @@ def evaluate_statements(statements: list[InputAndSystemStatement],
"sympy_func",
placeholder_map,
placeholder_set,
DataTableSubs())
None)

exponent_subs[symbols(exponent_name+current_function_name)] = final_expression

Expand Down
28 changes: 27 additions & 1 deletion tests/test_data_table.spec.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ test('Test greek character function name', async () => {
expect(content).toBe('');
});

test('Test interpolation and polfit with numerical solve', async () => {
test('Test interpolation and polyfit with numerical solve', async () => {
const modifierKey = (await page.evaluate('window.modifierKey') )=== "metaKey" ? "Meta" : "Control";

await page.setLatex(0, String.raw`t1=`);
Expand Down Expand Up @@ -1329,4 +1329,30 @@ test('Test factorial function in data table', async () => {
let content = await page.textContent(`#result-value-0`);
expect(content).toBe(String.raw`\begin{bmatrix} 1 \\ 2 \\ 6 \end{bmatrix}`);

});

test('Test data table user function exponent bug', async () => {
await page.setLatex(0, String.raw`y=2\left\lbrack m\right\rbrack^{\frac{x}{1\left\lbrack in\right\rbrack}}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`Col2=`);

await page.locator('#add-data-table-cell').click();

await expect(page.locator('#data-table-input-2-0-0')).toBeFocused();

await page.keyboard.type('1');
await page.keyboard.press('Enter');
await page.keyboard.type('2');
await page.keyboard.press('Enter');
await page.keyboard.type('3');

await page.locator('#parameter-units-2-0 >> math-field').type('[in]');

await page.setLatex(2, String.raw`Col2=y\left(x=Col1\right)`, 1);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent(`#result-value-1`);
expect(content).toBe(String.raw`\begin{bmatrix} 2\left\lbrack m\right\rbrack \\ 4\left\lbrack m\right\rbrack \\ 8\left\lbrack m\right\rbrack \end{bmatrix}`);
});

0 comments on commit ed03b37

Please sign in to comment.