Skip to content

Commit

Permalink
add missing at functions
Browse files Browse the repository at this point in the history
- fixes ml-explore#188
- also fixes bug with array[idx] += 3 that would _not_ assign back to array
  • Loading branch information
davidkoski committed Feb 13, 2025
1 parent b990c58 commit 8eded82
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 10 deletions.
123 changes: 123 additions & 0 deletions Source/MLX/ArrayAt.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright © 2025 Apple Inc.

import Cmlx
import Foundation

public struct ArrayAt {

let array: MLXArray

public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default)
-> ArrayAtIndices
{
get {
ArrayAtIndices(
array: array,
indexOperations: indices.map { $0.mlxArrayIndexOperation },
stream: stream)
}
}

public subscript(indices: [MLXArrayIndex], stream stream: StreamOrDevice = .default)
-> ArrayAtIndices
{
get {
ArrayAtIndices(
array: array,
indexOperations: indices.map { $0.mlxArrayIndexOperation },
stream: stream)
}
}
}

public struct ArrayAtIndices {

let array: MLXArray
let indexOperations: [MLXArrayIndexOperation]
let stream: StreamOrDevice

public func add(_ values: ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)

if !indices.isEmpty {
let indices_vector = new_mlx_vector_array(indices)
defer { mlx_vector_array_free(indices_vector) }

var result = mlx_array_new()
mlx_scatter_add(
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)

return MLXArray(result)
} else {
return array + update
}
}

public func subtract(_ values: ScalarOrArray) -> MLXArray {
add(-values.asMLXArray(dtype: array.dtype))
}

public func multiply(_ values: ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)

if !indices.isEmpty {
let indices_vector = new_mlx_vector_array(indices)
defer { mlx_vector_array_free(indices_vector) }

var result = mlx_array_new()
mlx_scatter_prod(
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)

return MLXArray(result)
} else {
return array * update
}
}

public func divide(_ values: ScalarOrArray) -> MLXArray {
multiply(1 / values.asMLXArray(dtype: array.dtype))
}

public func minimum(_ values: ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)

if !indices.isEmpty {
let indices_vector = new_mlx_vector_array(indices)
defer { mlx_vector_array_free(indices_vector) }

var result = mlx_array_new()
mlx_scatter_min(
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)

return MLXArray(result)
} else {
return MLX.minimum(array, update)
}
}

public func maximum(_ values: ScalarOrArray) -> MLXArray {
let values = values.asMLXArray(dtype: array.dtype)
let (indices, update, axes) = scatterArguments(
src: array, operations: indexOperations, update: values, stream: stream)

if !indices.isEmpty {
let indices_vector = new_mlx_vector_array(indices)
defer { mlx_vector_array_free(indices_vector) }

var result = mlx_array_new()
mlx_scatter_max(
&result, array.ctx, indices_vector, update.ctx, axes, axes.count, stream.ctx)

return MLXArray(result)
} else {
return MLX.maximum(array, update)
}
}

}
1 change: 1 addition & 0 deletions Source/MLX/MLXArray+Indexing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ extension MLXArray {
/// - ``MLXArrayIndex/ellipsis``
/// - ``MLXArrayIndex/newAxis``
/// - ``MLXArrayIndex/stride(from:to:by:)``
/// - ``MLXArray/at``
public subscript(indices: MLXArrayIndex..., stream stream: StreamOrDevice = .default)
-> MLXArray
{
Expand Down
16 changes: 8 additions & 8 deletions Source/MLX/MLXArray+Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ extension MLXArray {
/// ### See Also
/// - <doc:arithmetic>
/// - ``add(_:_:stream:)``
public static func += (lhs: MLXArray, rhs: MLXArray) {
public static func += (lhs: inout MLXArray, rhs: MLXArray) {
lhs.update(lhs + rhs)
}

Expand All @@ -80,7 +80,7 @@ extension MLXArray {
/// ### See Also
/// - <doc:arithmetic>
/// - ``MLXArray/+(_:_:)-1rv98``
public static func += <T: ScalarOrArray>(lhs: MLXArray, rhs: T) {
public static func += <T: ScalarOrArray>(lhs: inout MLXArray, rhs: T) {
lhs += rhs.asMLXArray(dtype: lhs.dtype)
}

Expand Down Expand Up @@ -132,7 +132,7 @@ extension MLXArray {
/// ### See Also
/// - <doc:arithmetic>
/// - ``subtract(_:_:stream:)``
public static func -= (lhs: MLXArray, rhs: MLXArray) {
public static func -= (lhs: inout MLXArray, rhs: MLXArray) {
lhs.update(lhs - rhs)
}

Expand All @@ -148,7 +148,7 @@ extension MLXArray {
///
/// ### See Also
/// - <doc:arithmetic>
public static func -= <T: ScalarOrArray>(lhs: MLXArray, rhs: T) {
public static func -= <T: ScalarOrArray>(lhs: inout MLXArray, rhs: T) {
lhs -= rhs.asMLXArray(dtype: lhs.dtype)
}

Expand Down Expand Up @@ -223,7 +223,7 @@ extension MLXArray {
/// - ``multiply(_:_:stream:)``
/// - ``matmul(_:stream:)``
/// - ``matmul(_:_:stream:)``
public static func *= (lhs: MLXArray, rhs: MLXArray) {
public static func *= (lhs: inout MLXArray, rhs: MLXArray) {
lhs.update(lhs * rhs)
}

Expand All @@ -241,7 +241,7 @@ extension MLXArray {
/// ### See Also
/// - <doc:arithmetic>
/// - ``MLXArray/*(_:_:)-1z2ck``
public static func *= <T: ScalarOrArray>(lhs: MLXArray, rhs: T) {
public static func *= <T: ScalarOrArray>(lhs: inout MLXArray, rhs: T) {
lhs *= rhs.asMLXArray(dtype: lhs.dtype)
}

Expand Down Expand Up @@ -339,7 +339,7 @@ extension MLXArray {
/// - <doc:arithmetic>
/// - ``divide(_:_:stream:)``
/// - ``floorDivide(_:_:stream:)``
public static func /= (lhs: MLXArray, rhs: MLXArray) {
public static func /= (lhs: inout MLXArray, rhs: MLXArray) {
lhs.update(lhs / rhs)
}

Expand All @@ -355,7 +355,7 @@ extension MLXArray {
///
/// ### See Also
/// - <doc:arithmetic>
public static func /= <T: ScalarOrArray>(lhs: MLXArray, rhs: T) {
public static func /= <T: ScalarOrArray>(lhs: inout MLXArray, rhs: T) {
lhs /= rhs.asMLXArray(dtype: lhs.dtype)
}

Expand Down
32 changes: 32 additions & 0 deletions Source/MLX/MLXArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,38 @@ public final class MLXArray {
mlx_array_set(&new, self.ctx)
return MLXArray(new)
}

/// Used to apply update at given indices.
///
/// An assignment through indices `array[indicies]` will produce
/// a result where each index will only be updated once. For example:
///
/// ```swift
/// // this references each index twice
/// let idx = MLXArray([0, 1, 0, 1])
///
/// let a1 = MLXArray([0, 0])
/// a1[idx] += 1
/// assertEqual(a1, MLXArray([1, 1]))
///
/// // this will update 0 and 1 twice
/// var a2 = MLXArray([0, 0])
/// a2 = a2.at[idx].add(1)
/// assertEqual(a2, MLXArray([2, 2]))
/// ```
///
/// This is because the assignment through `array[indicies]` writes
/// a sub-array of `array` rather than performing the operation on each
/// resolved index.
///
/// The `at` property produces an intermediate value that can take a subscript
/// `[]` and produce an ``ArrayAtIndices`` that has several methods to
/// update values.
///
/// ### See Also
/// - ``subscript(indices:stream:)``
/// - ``ArrayAtIndices``
public var at: ArrayAt { ArrayAt(array: self) }
}

extension MLXArray: Updatable, Evaluatable {
Expand Down
41 changes: 41 additions & 0 deletions Tests/MLXTests/ArrayAtTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright © 2025 Apple Inc.

import Foundation
import MLX
import MLXLinalg
import XCTest

class ArrayAtTests: XCTestCase {

func testArrayAt() {
// from example at https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.array.at.html#mlx.core.array.at

// this references each index twice
let idx = MLXArray([0, 1, 0, 1])

// assign through index -- we can only observe the last assignment to a location
let a0 = MLXArray([0, 0])
a0[idx] = MLXArray(2)
assertEqual(a0, MLXArray([2, 2]))

// similar to above -- we can only observe one assignment, so we just get a +1
// note: there was a bug in the += operator where the lhs was not inout and
// this was producing [0, 0]
let a1 = MLXArray([0, 0])
a1[idx] += 1
assertEqual(a1, MLXArray([1, 1]))

// the bare add produces a value for each index including the duplicates
let a2 = MLXArray([0, 0])
assertEqual(a2[idx] + 1, MLXArray([1, 1, 1, 1]))

// but the assign back through the index will collapse the values down
// into the same location -- this is the same as a2[idx] += 1
a2[idx] = a2[idx] + 1
assertEqual(a2, MLXArray([1, 1]))

// this will update 0 and 1 twice
let a3 = MLXArray([0, 0])
assertEqual(a3.at[idx].add(1), MLXArray([2, 2]))
}
}
4 changes: 2 additions & 2 deletions Tests/MLXTests/MLXArray+OpsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class MLXArrayOpsTests: XCTestCase {
// MARK: - Operators

func testArithmeticSimple() {
let a = MLXArray([1, 2, 3])
let b = MLXArray(converting: [-5.0, 37.5, 4])
var a = MLXArray([1, 2, 3])
var b = MLXArray(converting: [-5.0, 37.5, 4])

// example of an expression -- the - 1 is using the 1 as ExpressibleByIntegerLiteral
let r = a + b - 1
Expand Down

0 comments on commit 8eded82

Please sign in to comment.