numpy backend requires jax? #18840
Comments
It seems the numpy backend has a JAX dependency. I can find it in backend/numpy/nn.py as well.
Yes, that's intentional. We found that when implementing nn ops like conv or pooling, a low-level pure numpy implementation was vastly less performant than jax, so we just use jax for these ops.
I feel like this is seriously misleading. numpy and JAX present drastically different compilation, testing, and integration challenges. Installing JAX and installing numpy can be quite different in resource-limited environments. I think that:
for example, I would expect a backend of
As is, users cannot use the numpy backend at all without JAX. In my mind, it just shouldn't be called the "numpy backend". I am not saying this in a hostile way, just as a packager who tried to test things with numpy only and was quite surprised to need a heavy dependency like JAX to get any operators to work. If my tone is negative, it is because, in my mind, you hastily closed this issue without giving anyone a chance to respond.
I second @hmaarrfk's comment. I think it would be useful to be able to have a lightweight keras installation powered by the numpy backend, in order to run small inference tasks in limited environments.
That's a reasonable view. There's definitely a huge slowdown, however. If you'd like to open a PR to revert some ops to a pure numpy implementation, we can consider it.
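For anyone weighing such a PR, here is a minimal sketch of what a pure numpy conv could look like. It is not the Keras implementation: `conv2d_valid` is a hypothetical helper, and it assumes a single unbatched channels-last input, stride 1, and no padding.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv2d_valid(x, kernel):
    """Hypothetical helper, not part of Keras.

    Cross-correlates x of shape (H, W, C_in) with kernel of shape
    (kh, kw, C_in, C_out), using stride 1 and no padding.
    """
    kh, kw = kernel.shape[:2]
    # windows has shape (H - kh + 1, W - kw + 1, C_in, kh, kw)
    windows = sliding_window_view(x, (kh, kw), axis=(0, 1))
    # Sum over the window positions (i, j) and the input channels (c).
    return np.einsum("hwcij,ijco->hwo", windows, kernel)


x = np.ones((4, 4, 1))
k = np.ones((2, 2, 1, 1))
out = conv2d_valid(x, k)
print(out.shape)
```

A real backend op would also need batching, strides, padding, dilation, and data format handling, which this sketch leaves out.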
Thanks François for your answer. I see that the functions involved are
The fft ops should be a trivial change since they're already available in numpy.
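As a rough illustration of why the fft case is simple, the sketch below wraps `np.fft.fft` directly. The name `fft_numpy` and the convention of passing and returning separate real and imaginary arrays are assumptions made for this example, not taken from the Keras source.

```python
import numpy as np


def fft_numpy(real, imag):
    """Hypothetical helper: FFT over the last axis using numpy only.

    Takes and returns separate real and imaginary parts (an assumed
    convention for this sketch).
    """
    x = np.asarray(real) + 1j * np.asarray(imag)
    out = np.fft.fft(x)
    return np.real(out), np.imag(out)


# A unit impulse transforms to a flat spectrum.
re, im = fft_numpy([1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0])
print(re, im)
```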
It seems that the numpy backend imports jax
keras/keras/backend/numpy/image.py
Line 1 in 9c675a9
and seems to use it below:
Is that intentional?