Skip to content

Commit

Permalink
Merge pull request #294 from mgreminger/matmult-fix
Browse files Browse the repository at this point in the history
fix: fix issues with matrix multiplication
  • Loading branch information
mgreminger authored Nov 3, 2024
2 parents 29034ca + fdd6df4 commit fc356b1
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 43 deletions.
61 changes: 23 additions & 38 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# must be at least 131 to load sympy, cpython is 3000 by default
setrecursionlimit(1000)

from functools import lru_cache, partial
from functools import lru_cache, partial, reduce
import traceback
from importlib import import_module

Expand Down Expand Up @@ -1017,21 +1017,23 @@ def custom_matmul(exp1: Expr, exp2: Expr):
((exp1.rows == 1 and exp1.cols == 3) and (exp2.rows == 1 and exp2.cols == 3))):
return exp1.cross(exp2)
else:
return MatMul(exp1, exp2)
return Mul(exp1, exp2)

def custom_matmul_dims(exp1: Expr, exp2: Expr):
if is_matrix(exp1) and is_matrix(exp2) and \
(((exp1.rows == 3 and exp1.cols == 1) and (exp2.rows == 3 and exp2.cols == 1)) or \
((exp1.rows == 1 and exp1.cols == 3) and (exp2.rows == 1 and exp2.cols == 3))):
result = Matrix([Add(Mul(exp1[1],exp2[2]),Mul(exp1[2],exp2[1])),
Add(Mul(exp1[2],exp2[0]),Mul(exp1[0],exp2[2])),
Add(Mul(exp1[0],exp2[1]),Mul(exp1[1],exp2[0]))])
if exp1.rows == 3:
def custom_matmul_dims(*args: Expr):
if len(args) == 2 and is_matrix(args[0]) and is_matrix(args[1]) and \
(((args[0].rows == 3 and args[0].cols == 1) and (args[1].rows == 3 and args[1].cols == 1)) or \
((args[0].rows == 1 and args[0].cols == 3) and (args[1].rows == 1 and args[1].cols == 3))):

result = Matrix([Add(Mul(args[0][1],args[1][2]),Mul(args[0][2],args[1][1])),
Add(Mul(args[0][2],args[1][0]),Mul(args[0][0],args[1][2])),
Add(Mul(args[0][0],args[1][1]),Mul(args[0][1],args[1][0]))])

if args[0].rows == 3:
return result
else:
return result.T
else:
return MatMul(exp1, exp2)
return Mul(*args)

def custom_min(*args: Expr):
if len(args) == 1 and is_matrix(args[0]):
Expand Down Expand Up @@ -1479,7 +1481,8 @@ def get_next_id(self):
cast(Function, Function('_Inverse')) : {"dim_func": ensure_inverse_dims, "sympy_func": UniversalInverse},
cast(Function, Function('_Transpose')) : {"dim_func": custom_transpose, "sympy_func": custom_transpose},
cast(Function, Function('_Determinant')) : {"dim_func": custom_determinant, "sympy_func": custom_determinant},
cast(Function, Function('_MatMul')) : {"dim_func": custom_matmul_dims, "sympy_func": custom_matmul},
cast(Function, Function('_mat_multiply')) : {"dim_func": custom_matmul_dims, "sympy_func": custom_matmul},
cast(Function, Function('_multiply')) : {"dim_func": Mul, "sympy_func": Mul},
cast(Function, Function('_IndexMatrix')) : {"dim_func": IndexMatrix, "sympy_func": IndexMatrix},
cast(Function, Function('_Eq')) : {"dim_func": Eq, "sympy_func": Eq},
cast(Function, Function('_norm')) : {"dim_func": custom_norm, "sympy_func": custom_norm},
Expand All @@ -1495,6 +1498,7 @@ def get_next_id(self):

global_placeholder_set = set(global_placeholder_map.keys())
dummy_var_placeholder_set = (Function('_Derivative'), Function('_Integral'))
multiply_placeholder_set = (Function('_multiply'), Function('_mat_multiply'))
placeholder_inverse_map = { value["sympy_func"]: key for key, value in reversed(global_placeholder_map.items()) }
placeholder_inverse_set = set(placeholder_inverse_map.keys())

Expand All @@ -1507,24 +1511,6 @@ def replace_sympy_funcs_with_placeholder_funcs(expression: Expr) -> Expr:

return expression


def doit_for_dim_func(func):
def new_func(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
placeholder_set: set[Function],
data_table_subs: DataTableSubs | None) -> Expr:
result = func(expr, func_key, placeholder_map,
placeholder_set, data_table_subs)

if func_key == "dim_func":
return cast(Expr, result.doit())
else:
return result

return new_func

@doit_for_dim_func
def replace_placeholder_funcs(expr: Expr,
func_key: Literal["dim_func"] | Literal["sympy_func"],
placeholder_map: dict[Function, PlaceholderFunction],
Expand All @@ -1545,11 +1531,7 @@ def replace_placeholder_funcs(expr: Expr,
if len(expr.args) == 0:
return expr

if expr.func in dummy_var_placeholder_set and func_key == "dim_func":
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
elif expr.func in placeholder_set:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args)))
elif func_key == "dim_func" and (expr.func is Mul or expr.func is MatMul):
if func_key == "dim_func" and expr.func in multiply_placeholder_set:
processed_args = [replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args]
matrix_args = []
scalar_args = []
Expand All @@ -1571,9 +1553,13 @@ def replace_placeholder_funcs(expr: Expr,

matrix_args[0] = Matrix(new_rows)

return cast(Expr, expr.func(*matrix_args))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*matrix_args))
else:
return cast(Expr, expr.func(*processed_args))
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*processed_args))
elif expr.func in dummy_var_placeholder_set and func_key == "dim_func":
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) if index > 0 else arg for index, arg in enumerate(expr.args))))
elif expr.func in placeholder_set:
return cast(Expr, cast(Callable, placeholder_map[expr.func][func_key])(*(replace_placeholder_funcs(cast(Expr, arg), func_key, placeholder_map, placeholder_set, data_table_subs) for arg in expr.args)))
elif data_table_subs is not None and expr.func == data_table_calc_wrapper:
if len(expr.args[0].atoms(data_table_id_wrapper)) == 0:
return replace_placeholder_funcs(cast(Expr, expr.args[0]), func_key, placeholder_map, placeholder_set, data_table_subs)
Expand Down Expand Up @@ -2394,7 +2380,6 @@ def get_evaluated_expression(expression: Expr,
placeholder_map,
placeholder_set,
DataTableSubs())
expression = cast(Expr, expression.doit())
if not is_matrix(expression):
if simplify_symbolic_expressions:
try:
Expand Down
6 changes: 3 additions & 3 deletions src/App.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@
const apiUrl = window.location.origin;
const currentVersion = 20240827;
const tutorialHash = "hUts8q3sKUqJGFUwSdL5ZS";
const currentVersion = 20241102;
const tutorialHash = "moJCuTwjPi7dZeZn5QiuaP";
const termsVersion = 20240110;
let termsAccepted = 0;
Expand Down Expand Up @@ -132,7 +132,7 @@
title: "Equation Solving"
},
{
path: "/8pWM9yEqEPNntRBd6Jr9Sv",
path: "/V53SzSCEixmE9MQz6m66mk",
title: "Matrices and Vectors"
},
{
Expand Down
16 changes: 16 additions & 0 deletions src/Updates.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@
}
</style>

<em>November 2, 2024</em>
<h4>Matrix Multiplication Improvements</h4>
<p>
The two multiplication symbols (the dot symbol obtained using the * key and the x symbol
obtained using the @ key) can now be used interchangeably and will correctly apply
either scalar or matrix multiplication as appropriate for the situation.
Previously, the dot symbol would give unexpected results for some matrix
multiplication situations. The only difference between the two symbols is that the
x multiplication symbol will automatically perform a cross product when operating on
compatible vectors (both 3x1 or both 1x3). This cross product behavior is unchanged from
previous versions. This update also fixes some situations where the result of a
matrix multiplication passed to a function would cause an error.
</p>

<br>

<em>August 27, 2024</em>
<h4>PDF Export Improvements</h4>
<p>
Expand Down
4 changes: 2 additions & 2 deletions src/parser/LatexToSympy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1956,11 +1956,11 @@ export class LatexToSympy extends LatexParserVisitor<string | Statement | UnitBl
}

visitMultiply = (ctx: MultiplyContext) => {
return `${this.visit(ctx.expr(0))}*${this.visit(ctx.expr(1))}`;
return `_multiply(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`;
}

visitMatrixMultiply = (ctx: MatrixMultiplyContext) => {
return `_MatMul(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`;
return `_mat_multiply(${this.visit(ctx.expr(0))}, ${this.visit(ctx.expr(1))})`;
}

visitUnitMultiply = (ctx: UnitMultiplyContext) => {
Expand Down
Binary file modified tests/images/webkit_reference.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 17 additions & 0 deletions tests/test_calc.spec.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,20 @@ test('Test derivative dimensional analysis bug', async () => {
expect(content).toBe('');
});

test('Test unitless nested integral in exponent', async () => {
await page.setLatex(0, String.raw`2^{\int_{-\frac{h}{2}}^{\frac{h}{2}}\left(\int_{-\frac{w}{2}}^{\frac{w}{2}}\left(y^2\right)\mathrm{d}\left(x\right)\right)\mathrm{d}\left(y\right)}=`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, 'h=2');

await page.locator('#add-math-cell').click();
await page.setLatex(2, 'w=3');

await page.waitForSelector('.status-footer', {state: 'detached'});

let content = await page.textContent('#result-value-0');
expect(parseLatexFloat(content)).toBeCloseTo(4, precision);
content = await page.textContent('#result-units-0');
expect(content).toBe('');
});

88 changes: 88 additions & 0 deletions tests/test_matrix_multiplication.spec.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -450,3 +450,91 @@ test('Dot product with column vectors and numeric entries with units', async ()
expect(content).toBe('m^2');
});

test('Matrix multiplication with cdot multiplication symbol case 1', async () => {
await page.setLatex(0, String.raw`A=\begin{bmatrix}1\\ 2\end{bmatrix},\:B=\begin{bmatrix}3 & 4\end{bmatrix}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`A\cdot B=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent('#result-value-1');
expect(content).toBe(String.raw`\begin{bmatrix} 3 & 4 \\ 6 & 8 \end{bmatrix}`);
});

test('Matrix multiplication with cdot multiplication symbol case 2', async () => {
await page.setLatex(0, String.raw`C=\begin{bmatrix}1 & 2\end{bmatrix},\:D=\begin{bmatrix}3\\ 4\end{bmatrix}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`C\cdot D=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent('#result-value-1');
expect(content).toBe(String.raw`\begin{bmatrix} 11 \end{bmatrix}`);
});

test('Matrix multiplication passed as argument to function', async () => {
await page.setLatex(0, String.raw`A=\begin{bmatrix}1\\ 2\end{bmatrix},\:B=\begin{bmatrix}3 & 4\end{bmatrix}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`\mathrm{count}\left(A\times B\right)=`);

await page.locator('#add-math-cell').click();
await page.setLatex(2, String.raw`\mathrm{sum}\left(A\times B\right)=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent(`#result-value-1`);
expect(parseLatexFloat(content)).toBeCloseTo(4, precision);
content = await page.textContent('#result-units-1');
expect(content).toBe('');

content = await page.textContent(`#result-value-2`);
expect(parseLatexFloat(content)).toBeCloseTo(21, precision);
content = await page.textContent('#result-units-2');
expect(content).toBe('');
});

test('Matrix multiplication with cdot symbol passed as argument to function', async () => {
await page.setLatex(0, String.raw`A=\begin{bmatrix}1\\ 2\end{bmatrix},\:B=\begin{bmatrix}3 & 4\end{bmatrix}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`\mathrm{count}\left(A\cdot B\right)=`);

await page.locator('#add-math-cell').click();
await page.setLatex(2, String.raw`\mathrm{sum}\left(A\cdot B\right)=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent(`#result-value-1`);
expect(parseLatexFloat(content)).toBeCloseTo(4, precision);
content = await page.textContent('#result-units-1');
expect(content).toBe('');

content = await page.textContent(`#result-value-2`);
expect(parseLatexFloat(content)).toBeCloseTo(21, precision);
content = await page.textContent('#result-units-2');
expect(content).toBe('');
});

test('Symbolic matrix multiplication', async () => {
await page.setLatex(0, String.raw`A\times B=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent('#result-value-0');
expect(content).toBe(String.raw`A \cdot B`);
});

test('Matrix multiplication with matrix and non-matrix symbol', async () => {
await page.setLatex(0, String.raw`A=\begin{bmatrix}a\\ b\end{bmatrix}`);

await page.locator('#add-math-cell').click();
await page.setLatex(1, String.raw`A\times B=`);

await page.waitForSelector('text=Updating...', {state: 'detached'});

let content = await page.textContent('#result-value-1');
expect(content).toBe(String.raw`\begin{bmatrix} B \cdot a \\ B \cdot b \end{bmatrix}`);
});

0 comments on commit fc356b1

Please sign in to comment.