-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added dot kron rewrite #1090
base: main
Are you sure you want to change the base?
Added dot kron rewrite #1090
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1090 +/- ##
=======================================
Coverage 82.11% 82.11%
=======================================
Files 183 183
Lines 47959 47972 +13
Branches 8635 8636 +1
=======================================
+ Hits 39381 39394 +13
Misses 6411 6411
Partials 2167 2167
|
@register_canonicalize | ||
@register_stabilize | ||
@node_rewriter([Dot]) | ||
def rewrite_dot_kron(fgraph, node): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs a docstring and typehints
@@ -906,3 +906,29 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): | |||
f_rewritten = function([x], z_cholesky, mode="FAST_RUN") | |||
nodes = f_rewritten.maker.fgraph.apply_nodes | |||
assert any(isinstance(node.op, Cholesky) for node in nodes) | |||
|
|||
|
|||
def test_dot_kron_rewrite(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add test with batch dims. I am worried the use of e.g. ravel
in the rewrite will cause problems.
The name I found for the identity we're exploiting here is vec-trick, so we can use that in the naming instead of "clever way", which doesn't say anything |
Description
Adds a rewrite for kron(a,b) @ c -> (b @ c.reshape @ a.T).ravel
Related Issue
dot(kron(a, b), c)
#1043Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1090.org.readthedocs.build/en/1090/