Skip to content

Commit

Permalink
feat[next]: Allow type inference without domain argument to `as_field…
Browse files Browse the repository at this point in the history
…op` (#1689)

In case we don't have a domain argument to `as_fieldop` we can not infer
the exact result type. In order to still allow some passes which don't
need this information to run before the domain inference, we continue
with a dummy domain. One example is the CollapseTuple pass which only
needs information about the structure, e.g. how many tuple elements does
this node have, but not the dimensions of a field.

Note that it might appear as if using the TraceShift pass would allow us
to deduce the return type of `as_fieldop` without a domain, but this is
not the case, since we don't have information on the ordering of
dimensions. In this example
```
as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)
```
it is unclear if the result has dimension I, J or J, I.
  • Loading branch information
tehrengruber authored Oct 16, 2024
1 parent 5ce0fb8 commit 3f7fcee
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 6 deletions.
4 changes: 3 additions & 1 deletion src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType:
def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType:
domain = self.visit(node.domain, ctx=ctx)
assert isinstance(domain, it_ts.DomainType)
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above
node.dtype,
)

def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class NamedRangeType(ts.TypeSpec):

@dataclasses.dataclass(frozen=True)
class DomainType(ts.DataType):
dims: list[common.Dimension]
dims: list[common.Dimension] | Literal["unknown"]


@dataclasses.dataclass(frozen=True)
Expand Down
27 changes: 23 additions & 4 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,17 +271,36 @@ def _convert_as_fieldop_input_to_iterator(

@_register_builtin_type_synthesizer
def as_fieldop(
stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider
stencil: TypeSynthesizer,
domain: Optional[it_ts.DomainType] = None,
*,
offset_provider: common.OffsetProvider,
) -> TypeSynthesizer:
# In case we don't have a domain argument to `as_fieldop` we can not infer the exact result
# type. In order to still allow some passes which don't need this information to run before the
# domain inference, we continue with a dummy domain. One example is the CollapseTuple pass
# which only needs information about the structure, e.g. how many tuple elements does this node
# have, but not the dimensions of a field.
# Note that it might appear as if using the TraceShift pass would allow us to deduce the return
# type of `as_fieldop` without a domain, but this is not the case, since we don't have
# information on the ordering of dimensions. In this example
# `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)`
# it is unclear if the result has dimension I, J or J, I.
if domain is None:
domain = it_ts.DomainType(dims="unknown")

@TypeSynthesizer
def applied_as_fieldop(*fields) -> ts.FieldType:
def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType:
stencil_return = stencil(
*(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields),
offset_provider=offset_provider,
)
assert isinstance(stencil_return, ts.DataType)
return type_info.apply_to_primitive_constituents(
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return
lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type)
if domain.dims != "unknown"
else ts.DeferredType(constraint=ts.FieldType),
stencil_return,
)

return applied_as_fieldop
Expand Down Expand Up @@ -329,7 +348,7 @@ def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider


@_register_builtin_type_synthesizer
def shift(*offset_literals, offset_provider) -> TypeSynthesizer:
def shift(*offset_literals, offset_provider: common.OffsetProvider) -> TypeSynthesizer:
@TypeSynthesizer
def apply_shift(
it: it_ts.IteratorType | ts.DeferredType,
Expand Down
13 changes: 13 additions & 0 deletions tests/next_tests/unit_tests/iterator_tests/test_type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,16 @@ def test_if_stmt():
result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True)
assert result.cond.type == bool_type
assert result.true_branch[0].expr.type == float_i_field


def test_as_fieldop_without_domain():
testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))(
im.ref("inp", float_i_field)
)
result = itir_type_inference.infer(
testee, offset_provider={"IOff": IDim}, allow_undeclared_symbols=True
)
assert result.type == ts.DeferredType(constraint=ts.FieldType)
assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType(
position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype
)

0 comments on commit 3f7fcee

Please sign in to comment.