Issue with jax.grad #22

Open
SNMS95 opened this issue Dec 19, 2024 · 2 comments

Comments

@SNMS95

SNMS95 commented Dec 19, 2024

Hey,

Thanks a lot for this nice package!
I was trying to draw the graph for the gradient operation, but it only shows a single box:

```python
from jax import make_jaxpr
import jax
import jpviz

def func(x):
    return x**2

print(make_jaxpr(jax.grad(func))(1.0)) # This works and shows a series of ops
dot_graph = jpviz.draw(jax.grad(func), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph)  # This shows only a single box!
```
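For comparison, the ops that `make_jaxpr` prints can also be listed programmatically. This is a minimal sketch that uses only public JAX calls (`jax.make_jaxpr`, `jax.grad`). The helper `top_level_primitives` is hypothetical and is not part of jpviz. It walks the top-level equations of the traced jaxpr and collects each primitive's name. For the un-jitted gradient, the result is a flat sequence of several ops. The exact primitive names may vary between JAX versions.

```python
import jax


def func(x):
    return x**2


def top_level_primitives(fn, *args):
    """Hypothetical helper: names of the primitives at the top level of fn's jaxpr."""
    closed = jax.make_jaxpr(fn)(*args)  # ClosedJaxpr for fn traced at args
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


# The un-jitted gradient traces to several primitive equations at the top level
grad_ops = top_level_primitives(jax.grad(func), 1.0)
print(grad_ops)
```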
@zombie-einstein
Owner

> Thanks a lot for this nice package!

Thanks, glad you are making use of it!

> I was trying to draw the graph for the gradient operation, but it only shows a single box

Thanks for the report and example, I'll take a look at this.

@zombie-einstein
Owner

I think the issue here is that the function isn't being jit compiled, so

```python
import jax
import jpviz

def func(x):
    return x**2

dot_graph = jpviz.draw(jax.jit(jax.grad(func)), collapse_primitives=False)(1.0)
jpviz.view_pydot(dot_graph)
```

should work?
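Wrapping the gradient in `jax.jit` changes the shape of the traced jaxpr. With `jax.jit`, the top level holds a single call equation, and the gradient's ops sit in a nested jaxpr stored in that equation's params. This sketch uses plain JAX and does not touch jpviz's internals. The helper `leaf_primitives` is hypothetical. It duck-types nested jaxprs in `eqn.params`, and that duck-typing is an assumption. The code counts the top-level equations in both cases, then recurses one level down. Whether this nesting is exactly what jpviz relies on when drawing is also an assumption, not something confirmed in this thread.

```python
import jax


def func(x):
    return x**2


def _as_jaxpr(obj):
    # Assumption: nested jaxprs in eqn.params expose `.eqns` directly,
    # or wrap an object that does in a `.jaxpr` attribute (ClosedJaxpr).
    if hasattr(obj, "eqns"):
        return obj
    inner = getattr(obj, "jaxpr", None)
    if inner is not None and hasattr(inner, "eqns"):
        return inner
    return None


def leaf_primitives(jaxpr):
    """Hypothetical helper: primitive names, descending into nested jaxprs."""
    names = []
    for eqn in jaxpr.eqns:
        subs = []
        for value in eqn.params.values():
            # Some params hold tuples (e.g. of branches), so check each item
            items = value if isinstance(value, (tuple, list)) else (value,)
            subs.extend(s for s in map(_as_jaxpr, items) if s is not None)
        if subs:
            # A call-like equation: report the ops inside it instead of the call
            for sub in subs:
                names.extend(leaf_primitives(sub))
        else:
            names.append(eqn.primitive.name)
    return names


flat = jax.make_jaxpr(jax.grad(func))(1.0)
nested = jax.make_jaxpr(jax.jit(jax.grad(func)))(1.0)

print(len(flat.jaxpr.eqns), len(nested.jaxpr.eqns))  # several vs. one call equation
print(leaf_primitives(nested.jaxpr))  # the gradient ops, found one level down
```

Under these assumptions, a visualizer that draws the contents of call equations would find the gradient ops only in the jitted version.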

I've also just released an update that fixed some old node labelling issues.
