Skip to content

Commit

Permalink
fix extract_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Dec 3, 2024
1 parent 1471750 commit 86960cb
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
3 changes: 3 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
Expand Down Expand Up @@ -421,12 +422,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
min_value, _ = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(min_value), dtype)
return self._make_reduction_expr(node, "maximum", init_expr, **kwargs)

def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
_, max_value = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(max_value), dtype)
return self._make_reduction_expr(node, "minimum", init_expr, **kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
Expand Down Expand Up @@ -441,12 +442,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
min_value, _ = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(min_value), dtype)
return self._make_reduction_expr(node, "maximum", init_expr, **kwargs)

def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
_, max_value = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(max_value), dtype)
return self._make_reduction_expr(node, "minimum", init_expr, **kwargs)
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def _transform_by_pattern(
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = (
type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: (
ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]
) = type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)

# allocate temporary for all tuple elements
Expand Down
17 changes: 12 additions & 5 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def apply_to_primitive_constituents(
return fun(*symbol_types)


def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType:
def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType:
"""
Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`.
Expand All @@ -213,7 +213,6 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType:
"""
match symbol_type:
case ts.FieldType(dtype=dtype):
assert isinstance(dtype, ts.ScalarType)
return dtype
case ts.ScalarType() as dtype:
return dtype
Expand All @@ -235,7 +234,10 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool:
>>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)))
True
"""
return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64]
return isinstance(symbol_type, ts.ScalarType) and extract_dtype(symbol_type).kind in [ # type: ignore[union-attr] # checked is `ScalarType`
ts.ScalarKind.FLOAT32,
ts.ScalarKind.FLOAT64,
]


def is_integer(symbol_type: ts.TypeSpec) -> bool:
Expand Down Expand Up @@ -296,7 +298,10 @@ def is_number(symbol_type: ts.TypeSpec) -> bool:


def is_logical(symbol_type: ts.TypeSpec) -> bool:
return extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL
return (
isinstance(symbol_type, ts.ScalarType)
and extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL # type: ignore[union-attr] # checked is `ScalarType`
)


def is_arithmetic(symbol_type: ts.TypeSpec) -> bool:
Expand Down Expand Up @@ -502,7 +507,9 @@ def promote(
return types[0]
elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types):
dims = common.promote_dims(*(extract_dims(type_) for type_ in types))
dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types)))
extracted_dtypes = [extract_dtype(type_) for type_ in types]
assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes)
dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType`

return ts.FieldType(dims=dims, dtype=dtype)
raise TypeError("Expected a 'FieldType' or 'ScalarType'.")
Expand Down

0 comments on commit 86960cb

Please sign in to comment.