forked from ml-explore/mlx-swift
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- fixes ml-explore#188 - also fixes bug with array[idx] += 3 that would _not_ assign back to array
- Loading branch information
1 parent
b990c58
commit 8eded82
Showing
6 changed files
with
207 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright © 2025 Apple Inc. | ||
|
||
import Cmlx | ||
import Foundation | ||
|
||
public struct ArrayAt { | ||
|
||
let array: MLXArray | ||
|
||
public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default) | ||
-> ArrayAtIndices | ||
{ | ||
get { | ||
ArrayAtIndices( | ||
array: array, | ||
indexOperations: indices.map { $0.mlxArrayIndexOperation }, | ||
stream: stream) | ||
} | ||
} | ||
|
||
public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default) | ||
-> ArrayAtIndices | ||
{ | ||
get { | ||
ArrayAtIndices( | ||
array: array, | ||
indexOperations: indices.map { $0.mlxArrayIndexOperation }, | ||
stream: stream) | ||
} | ||
} | ||
} | ||
|
||
public struct ArrayAtIndices { | ||
|
||
let array: MLXArray | ||
let indexOperations: [MLXArrayIndexOperation] | ||
let stream: StreamOrDevice | ||
|
||
public func add(_ values: ScalarOrArray) -> MLXArray { | ||
let values = values.asMLXArray(dtype: array.dtype) | ||
let (indices, update, axes) = scatterArguments( | ||
src: array, operations: indexOperations, update: values, stream: stream) | ||
|
||
if !indices.isEmpty { | ||
let indices_vector = new_mlx_vector_array(indices) | ||
defer { mlx_vector_array_free(indices_vector) } | ||
|
||
var result = mlx_array_new() | ||
mlx_scatter_add( | ||
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) | ||
|
||
return MLXArray(result) | ||
} else { | ||
return array + update | ||
} | ||
} | ||
|
||
public func subtract(_ values: ScalarOrArray) -> MLXArray { | ||
add(-values.asMLXArray(dtype: array.dtype)) | ||
} | ||
|
||
public func multiply(_ values: ScalarOrArray) -> MLXArray { | ||
let values = values.asMLXArray(dtype: array.dtype) | ||
let (indices, update, axes) = scatterArguments( | ||
src: array, operations: indexOperations, update: values, stream: stream) | ||
|
||
if !indices.isEmpty { | ||
let indices_vector = new_mlx_vector_array(indices) | ||
defer { mlx_vector_array_free(indices_vector) } | ||
|
||
var result = mlx_array_new() | ||
mlx_scatter_prod( | ||
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) | ||
|
||
return MLXArray(result) | ||
} else { | ||
return array * update | ||
} | ||
} | ||
|
||
public func divide(_ values: ScalarOrArray) -> MLXArray { | ||
multiply(1 / values.asMLXArray(dtype: array.dtype)) | ||
} | ||
|
||
public func minimum(_ values: ScalarOrArray) -> MLXArray { | ||
let values = values.asMLXArray(dtype: array.dtype) | ||
let (indices, update, axes) = scatterArguments( | ||
src: array, operations: indexOperations, update: values, stream: stream) | ||
|
||
if !indices.isEmpty { | ||
let indices_vector = new_mlx_vector_array(indices) | ||
defer { mlx_vector_array_free(indices_vector) } | ||
|
||
var result = mlx_array_new() | ||
mlx_scatter_min( | ||
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) | ||
|
||
return MLXArray(result) | ||
} else { | ||
return MLX.minimum(array, update) | ||
} | ||
} | ||
|
||
public func maximum(_ values: ScalarOrArray) -> MLXArray { | ||
let values = values.asMLXArray(dtype: array.dtype) | ||
let (indices, update, axes) = scatterArguments( | ||
src: array, operations: indexOperations, update: values, stream: stream) | ||
|
||
if !indices.isEmpty { | ||
let indices_vector = new_mlx_vector_array(indices) | ||
defer { mlx_vector_array_free(indices_vector) } | ||
|
||
var result = mlx_array_new() | ||
mlx_scatter_max( | ||
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx) | ||
|
||
return MLXArray(result) | ||
} else { | ||
return MLX.maximum(array, update) | ||
} | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Copyright © 2025 Apple Inc. | ||
|
||
import Foundation | ||
import MLX | ||
import MLXLinalg | ||
import XCTest | ||
|
||
class ArrayAtTests: XCTestCase { | ||
|
||
func testArrayAt() { | ||
// from example at https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.array.at.html#mlx.core.array.at | ||
|
||
// this references each index twice | ||
let idx = MLXArray([0, 1, 0, 1]) | ||
|
||
// assign through index -- we can only observe the last assignment to a location | ||
let a0 = MLXArray([0, 0]) | ||
a0[idx] = MLXArray(2) | ||
assertEqual(a0, MLXArray([2, 2])) | ||
|
||
// similar to above -- we can only observe one assignment, so we just get a +1 | ||
// note: there was a bug in the += operator where the lhs was not inout and | ||
// this was producing [0, 0] | ||
let a1 = MLXArray([0, 0]) | ||
a1[idx] += 1 | ||
assertEqual(a1, MLXArray([1, 1])) | ||
|
||
// the bare add produces a value for each index including the duplicates | ||
let a2 = MLXArray([0, 0]) | ||
assertEqual(a2[idx] + 1, MLXArray([1, 1, 1, 1])) | ||
|
||
// but the assign back through the index will collapse the values down | ||
// into the same location -- this is the same as a2[idx] += 1 | ||
a2[idx] = a2[idx] + 1 | ||
assertEqual(a2, MLXArray([1, 1])) | ||
|
||
// this will update 0 and 1 twice | ||
let a3 = MLXArray([0, 0]) | ||
assertEqual(a3.at[idx].add(1), MLXArray([2, 2])) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters