Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added dot kron rewrite #1090

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working rewrite and test
  • Loading branch information
tanish1729 committed Nov 15, 2024
commit 30198d0a1d27e199c7312868e87828da6e90ee49
4 changes: 2 additions & 2 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,13 +996,13 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
@node_rewriter([Dot])
def rewrite_dot_kron(fgraph, node):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a docstring and typehints

potential_kron = node.inputs[0].owner
if not (isinstance(potential_kron.op, KroneckerProduct)):
if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)):
return False

c = node.inputs[1]
[a, b] = potential_kron.inputs

m, n = a.type.shape
p, q = b.type.shape
out_clever = (b @ c.reshape(shape=(n, q)).T @ a.T).ravel()
out_clever = pt.expand_dims((b @ c.reshape(shape=(n, q)).T @ a.T).T.ravel(), 1)
tanish1729 marked this conversation as resolved.
Show resolved Hide resolved
return [out_clever]
5 changes: 2 additions & 3 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pytensor.tensor import swapaxes
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, _allclose, dot, matmul
from pytensor.tensor.math import _allclose, dot, matmul
from pytensor.tensor.nlinalg import (
SVD,
Det,
Expand Down Expand Up @@ -918,8 +918,7 @@ def test_dot_kron_rewrite():
# REWRITE TEST
f_direct_rewritten = function([a, b, c], out_direct, mode="FAST_RUN")
nodes = f_direct_rewritten.maker.fgraph.apply_nodes
print(nodes)
assert not any(isinstance(node.op.core_op, Dot) for node in nodes)
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# NUMERIC VALUE TEST
a_test = np.random.rand(m, n)
Expand Down
Loading