From 936ec197c737e45676ace89592f69b6500441c57 Mon Sep 17 00:00:00 2001 From: Supun Setunga Date: Thu, 25 Jul 2024 15:42:30 -0700 Subject: [PATCH] Add tests for nil-coalesce type inference --- runtime/sema/checker.go | 6 +- runtime/tests/checker/type_inference_test.go | 113 +++++++++++++++++++ 2 files changed, 116 insertions(+), 3 deletions(-) diff --git a/runtime/sema/checker.go b/runtime/sema/checker.go index a03655e0ac..51d5204b44 100644 --- a/runtime/sema/checker.go +++ b/runtime/sema/checker.go @@ -2768,9 +2768,9 @@ func (checker *Checker) checkNativeModifier(isNative bool, position ast.HasPosit } func (checker *Checker) leastCommonSuperType(pos ast.HasPosition, types ...Type) Type { - elementType := LeastCommonSuperType(types...) + superType := LeastCommonSuperType(types...) - if elementType == InvalidType { + if superType == InvalidType { checker.report( &TypeAnnotationRequiredError{ Cause: "cannot infer type:", @@ -2779,5 +2779,5 @@ func (checker *Checker) leastCommonSuperType(pos ast.HasPosition, types ...Type) ) } - return elementType + return superType } diff --git a/runtime/tests/checker/type_inference_test.go b/runtime/tests/checker/type_inference_test.go index cf191c2879..99bc14702c 100644 --- a/runtime/tests/checker/type_inference_test.go +++ b/runtime/tests/checker/type_inference_test.go @@ -1375,3 +1375,116 @@ func TestCheckDeploymentResultInference(t *testing.T) { assert.Equal(t, sema.DeploymentResultType, variableSizedType.Type) } + +func TestCheckNilCoalesceExpressionTypeInference(t *testing.T) { + + t.Parallel() + + t.Run("resource", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + fun f(): @R? { + return <-create R() + } + + let y <- f() ?? panic("no R") + ` + + checker, err := ParseAndCheckWithPanic(t, code) + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + require.IsType(t, &sema.CompositeType{}, yType) + compositeType := yType.(*sema.CompositeType) + + assert.Equal(t, "R", compositeType.Identifier) + }) + + t.Run("any resource", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + fun f(): @AnyResource? { + return <-create R() + } + + let y <- f() ?? panic("no R") + ` + + checker, err := ParseAndCheckWithPanic(t, code) + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + assert.Equal(t, sema.AnyResourceType, yType) + }) + + t.Run("struct", func(t *testing.T) { + t.Parallel() + + code := ` + struct S {} + + fun f(): S? { + return S() + } + + let y = f() ?? panic("no S") + ` + + checker, err := ParseAndCheckWithPanic(t, code) + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + require.IsType(t, &sema.CompositeType{}, yType) + compositeType := yType.(*sema.CompositeType) + + assert.Equal(t, "S", compositeType.Identifier) + }) + + t.Run("any struct", func(t *testing.T) { + t.Parallel() + + code := ` + struct S {} + + fun f(): AnyStruct? { + return S() + } + + let y = f() ?? panic("no S") + ` + + checker, err := ParseAndCheckWithPanic(t, code) + require.NoError(t, err) + + yType := RequireGlobalValue(t, checker.Elaboration, "y") + assert.Equal(t, sema.AnyStructType, yType) + }) + + t.Run("invalid type", func(t *testing.T) { + t.Parallel() + + code := ` + struct S {} + resource R {} + + fun f(): @R? { + return <-create R() + } + + let y <- f() ?? S() + ` + + _, err := ParseAndCheckWithPanic(t, code) + + errs := RequireCheckerErrors(t, err, 1) + + typeAnnotRequiredError := &sema.TypeAnnotationRequiredError{} + require.ErrorAs(t, errs[0], &typeAnnotRequiredError) + }) +}