Skip to content

Commit

Permalink
Allow str stack_dim in meshgrid()
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Feb 3, 2025
1 parent 5ea1326 commit 57ed15d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ def fftfreq(resolution: Shape, dx: Union[Tensor, float] = 1, dtype: DType = None
return to_float(k) if dtype is None else cast(k, dtype)


def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim=channel('vector'), **dimensions: Union[int, Tensor, tuple, list, Any]) -> Tensor:
def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim: Union[Shape, str, None] = channel('vector'), **dimensions: Union[int, Tensor, tuple, list, Any]) -> Tensor:
"""
Generate a mesh-grid `Tensor` from keyword dimensions.
Expand All @@ -684,6 +684,8 @@ def meshgrid(dims: Union[Callable, Shape] = spatial, stack_dim=channel('vector')
(0, 1) along xˢ
"""
assert 'dim_type' not in dimensions, f"dim_type has been renamed to dims"
if isinstance(stack_dim, str):
stack_dim = auto(stack_dim, channel)
assert not stack_dim or stack_dim.name not in dimensions
if isinstance(dims, SHAPE_TYPES):
assert not dimensions, f"When passing a Shape to meshgrid(), no kwargs are allowed"
Expand Down

0 comments on commit 57ed15d

Please sign in to comment.