Skip to content

Commit

Permalink
Update to MLX 0.21.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Dec 7, 2024
1 parent bde7972 commit ce1e918
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 65 deletions.
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 55 files
+1 −1 CMakeLists.txt
+172 −45 benchmarks/python/sdpa_bench.py
+1 −1 docs/src/install.rst
+0 −1 docs/src/python/fast.rst
+2 −0 docs/src/python/nn/layers.rst
+1 −1 examples/extensions/requirements.txt
+11 −0 mlx/backend/common/primitives.cpp
+190 −133 mlx/backend/common/quantized.cpp
+145 −69 mlx/backend/common/reduce.cpp
+23 −23 mlx/backend/metal/jit_kernels.cpp
+4 −3 mlx/backend/metal/kernels.h
+19 −3 mlx/backend/metal/kernels/CMakeLists.txt
+338 −133 mlx/backend/metal/kernels/quantized.h
+2 −1 mlx/backend/metal/kernels/quantized.metal
+126 −156 mlx/backend/metal/kernels/reduce.metal
+10 −5 mlx/backend/metal/kernels/reduction/reduce_all.h
+55 −41 mlx/backend/metal/kernels/reduction/reduce_col.h
+22 −18 mlx/backend/metal/kernels/reduction/reduce_row.h
+0 −919 mlx/backend/metal/kernels/scaled_dot_product_attention.metal
+0 −42 mlx/backend/metal/kernels/scaled_dot_product_attention_params.h
+296 −0 mlx/backend/metal/kernels/steel/attn/attn.h
+349 −0 mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h
+31 −0 mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal
+264 −0 mlx/backend/metal/kernels/steel/attn/loader.h
+726 −0 mlx/backend/metal/kernels/steel/attn/mma.h
+36 −0 mlx/backend/metal/kernels/steel/attn/params.h
+71 −0 mlx/backend/metal/kernels/steel/attn/transforms.h
+2 −2 mlx/backend/metal/kernels/steel/gemm/mma.h
+60 −30 mlx/backend/metal/kernels/utils.h
+5 −3 mlx/backend/metal/nojit_kernels.cpp
+11 −0 mlx/backend/metal/primitives.cpp
+7 −13 mlx/backend/metal/quantized.cpp
+247 −64 mlx/backend/metal/reduce.cpp
+87 −101 mlx/backend/metal/scaled_dot_product_attention.cpp
+6 −2 mlx/backend/metal/utils.cpp
+1 −0 mlx/backend/metal/utils.h
+1 −0 mlx/backend/no_cpu/primitives.cpp
+1 −0 mlx/backend/no_metal/primitives.cpp
+68 −73 mlx/fast.cpp
+0 −8 mlx/fast.h
+1 −10 mlx/io/load.cpp
+3 −15 mlx/io/safetensors.cpp
+29 −3 mlx/ops.cpp
+6 −0 mlx/ops.h
+26 −0 mlx/primitives.cpp
+19 −0 mlx/primitives.h
+46 −107 python/mlx/nn/layers/pooling.py
+0 −43 python/src/fast.cpp
+0 −12 python/tests/test_fast.py
+24 −20 python/tests/test_nn.py
+4 −4 python/tests/test_quantized.py
+22 −0 python/tests/test_reduce.py
+1 −1 setup.py
+3 −3 tests/compile_tests.cpp
+14 −0 tests/ops_tests.cpp
66 changes: 18 additions & 48 deletions lib/nn/layers/pooling.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,8 @@ class Pool3d extends Pool {
*
* @remarks
*
* Assuming an input of shape `(N, L, C)` and `kernelSize` is `k`, the output is
* a tensor of shape `(N, L_out, C)`, given by:
*
* `out(N_i, t, C_j) = max_{m=0,...,k-1} input(N_i, stride * t + m, C_j)`
*
* where `L_out = floor((L + 2 * padding - kernelSize) / stride) + 1`.
* Spatially downsamples the input by taking the maximum of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* @param kernelSize - The size of the pooling window kernel.
* @param stride - The stride of the pooling window. Default: `kernelSize`.
Expand All @@ -122,12 +118,8 @@ export class MaxPool1d extends Pool1d {
*
* @remarks
*
* Assuming an input of shape `(N, L, C)` and `kernelSize` is `k`, the output is
* a tensor of shape `(N, L_out, C)`, given by:
*
* `out(N_i, t, C_j) = 1/k * sum_{m=0,...,k-1} input(N_i, stride * t + m, C_j)`
*
* where `L_out = floor((L + 2 * padding - kernelSize) / stride) + 1`.
* Spatially downsamples the input by taking the average of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* @param kernelSize - The size of the pooling window kernel.
* @param stride - The stride of the pooling window. Default: `kernelSize`.
Expand All @@ -147,15 +139,11 @@ export class AvgPool1d extends Pool1d {
*
* @remarks
*
* Assuming an input of shape `(N, H, W, C)` and `kernelSize` is `(k_H, k_W)`,
* the output is a tensor of shape `(N, H_out, W_out, C)`, given by:
* Spatially downsamples the input by taking the maximum of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* `out(N_i, h, w, C_j) = max_{m=0,...,k_H-1} max_{n=0,...,k_W-1} input(N_i, stride[0] * h + m, stride[1] * w + n, C_j)`
* The parameters `kernelSize`, `stride` and `padding` can either be:
*
* where `H_out = floor((H + 2 * padding[0] - kernelSize[0]) / stride[0]) + 1`
* `W_out = floor((W + 2 * padding[1] - kernelSize[1]) / stride[1]) + 1`
*
* The parameters `kernelSize`, `stride`, `padding`, can either be:
* - a single `number` -- in which case the same value is used for both the
* height and width axis;
* - a `tuple` of two `number`s -- in which case, the first `number` is used
Expand All @@ -179,16 +167,10 @@ export class MaxPool2d extends Pool2d {
*
* @remarks
*
* Assuming an input of shape `(N, H, W, C)` and `kernelSize` is `(kH, kW)`,
* the output is a tensor of shape `(N, H_out, W_out, C)`, given by:
*
* `out(N_i, h, w, C_j) = 1/(kH*kW) * sum_{m=0,...,kH-1} sum_{n=0,...,kW-1}
* input(N_i, stride[0] * h + m, stride[1] * w + n, C_j)`
* Spatially downsamples the input by taking the average of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* where `H_out = floor((H + 2 * padding[0] - kernelSize[0]) / stride[0]) + 1`,
* `W_out = floor((W + 2 * padding[1] - kernelSize[1]) / stride[1]) + 1`.
*
* The parameters `kernelSize`, `stride`, `padding`, can either be:
* The parameters `kernelSize`, `stride` and `padding` can either be:
*
* - a single `number` -- in which case the same value is used for both the
* height and width axis
Expand All @@ -213,22 +195,16 @@ export class AvgPool2d extends Pool2d {
*
* @remarks
*
* Assuming an input of shape `(N, D, H, W, C)` and `kernelSize` is `(k_D, k_H, k_W)`,
* the output is a tensor of shape `(N, D_out, H_out, W_out, C)`, given by:
*
* `out(N_i, d, h, w, C_j) = max_{l=0,...,k_D-1} max_{m=0,...,k_H-1} max_{n=0,...,k_W-1}
* input(N_i, stride[0] * d + l, stride[1] * h + m, stride[2] * w + n, C_j)`
* Spatially downsamples the input by taking the maximum of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* where `D_out = floor((D + 2 * padding[0] - kernelSize[0]) / stride[0]) + 1`
* `H_out = floor((H + 2 * padding[1] - kernelSize[1]) / stride[1]) + 1`
* `W_out = floor((W + 2 * padding[2] - kernelSize[2]) / stride[2]) + 1`
* The parameters `kernelSize`, `stride` and `padding` can either be:
*
* The parameters `kernelSize`, `stride`, `padding`, can either be:
* - a single `number` -- in which case the same value is used for the depth,
* height and width axis;
* - a `tuple` of three `numbers`s -- in which case, the first `number` is used
* for the depth axis, the second `number` for the height axis, and the third
* `number` for the width axis.
* - a `tuple` of three `number`s -- in which case, the first `number` is
* used for the depth axis, the second `number` for the height axis, and the
* third `number` for the width axis.
*
* @param kernelSize - The size of the pooling window.
* @param stride - The stride of the pooling window. Default: `kernelSize`.
Expand All @@ -248,14 +224,8 @@ export class MaxPool3d extends Pool3d {
*
* @remarks
*
* Assuming an input of shape `(N, D, H, W, C)` and `kernelSize` is `(k_D, k_H, k_W)`,
* the output is a tensor of shape `(N, D_out, H_out, W_out, C)`, given by:
*
* `out(N_i, d, h, w, C_j) = (1 / (k_D * k_H * k_W)) * sum_{l=0,...,k_D-1} sum_{m=0,...,k_H-1} sum_{n=0,...,k_W-1} input(N_i, stride[0] * d + l, stride[1] * h + m, stride[2] * w + n, C_j)`
*
* where `D_out = floor((D + 2 * padding[0] - kernelSize[0]) / stride[0]) + 1`
* `H_out = floor((H + 2 * padding[1] - kernelSize[1]) / stride[1]) + 1`
* `W_out = floor((W + 2 * padding[2] - kernelSize[2]) / stride[2]) + 1`
* Spatially downsamples the input by taking the average of a sliding window
* of size `kernelSize` and sliding stride `stride`.
*
* The parameters `kernelSize`, `stride` and `padding` can either be:
*
Expand Down
17 changes: 1 addition & 16 deletions src/fast.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,6 @@
#include "src/array.h"
#include "src/stream.h"

namespace fast_ops {

mx::array AffineQuantize(const mx::array& w,
const mx::array& scales,
const mx::array& biases,
std::optional<int> group_size,
std::optional<int> bits,
mx::StreamOrDevice s) {
return mx::fast::affine_quantize(w, scales, biases, group_size.value_or(64),
bits.value_or(4));
}

} // namespace fast_ops

void InitFast(napi_env env, napi_value exports) {
napi_value fast = ki::CreateObject(env);
ki::Set(env, exports, "fast", fast);
Expand All @@ -23,6 +9,5 @@ void InitFast(napi_env env, napi_value exports) {
"rmsNorm", &mx::fast::rms_norm,
"layerNorm", &mx::fast::layer_norm,
"rope", &mx::fast::rope,
"scaledDotProductAttention", &mx::fast::scaled_dot_product_attention,
"affineQuantize", &fast_ops::AffineQuantize);
"scaledDotProductAttention", &mx::fast::scaled_dot_product_attention);
}

0 comments on commit ce1e918

Please sign in to comment.