diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 891e827839..bb08c02e3f 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -1358,6 +1358,25 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location return v.Split(invocation.Interpreter, invocation.LocationRange, separator.Str) }, ) + + case sema.StringTypeReplaceAllFunctionName: + return NewHostFunctionValue( + interpreter, + sema.StringTypeReplaceAllFunctionType, + func(invocation Invocation) Value { + of, ok := invocation.Arguments[0].(*StringValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + with, ok := invocation.Arguments[1].(*StringValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + return v.ReplaceAll(invocation.Interpreter, invocation.LocationRange, of.Str, with.Str) + }, + ) } return nil @@ -1441,6 +1460,23 @@ func (v *StringValue) Split(inter *Interpreter, locationRange LocationRange, sep ) } +func (v *StringValue) ReplaceAll(inter *Interpreter, locationRange LocationRange, of string, with string) *StringValue { + // Over-estimate the resulting string length. + // In the worst case, `of` can be empty in which case, `with` will be added at every index. + // e.g. `of` = "", `v` = "ABC", `with` = "1": result = "1A1B1C1". + lengthOverEstimate := (2*len(v.Str) + 1) * len(with) + + memoryUsage := common.NewStringMemoryUsage(lengthOverEstimate) + + return NewStringValue( + inter, + memoryUsage, + func() string { + return strings.ReplaceAll(v.Str, of, with) + }, + ) +} + func (v *StringValue) Storable(storage atree.SlabStorage, address atree.Address, maxInlineSize uint64) (atree.Storable, error) { return maybeLargeImmutableStorable(v, storage, address, maxInlineSize) } diff --git a/runtime/sema/string_type.go b/runtime/sema/string_type.go index 74b9baaccf..bcb562c6d6 100644 --- a/runtime/sema/string_type.go +++ b/runtime/sema/string_type.go @@ -116,6 +116,12 @@ func init() { StringTypeSplitFunctionType, StringTypeSplitFunctionDocString, ), + NewUnmeteredPublicFunctionMember( + t, + StringTypeReplaceAllFunctionName, + StringTypeReplaceAllFunctionType, + StringTypeReplaceAllFunctionDocString, + ), }) } } @@ -163,6 +169,13 @@ It does not modify the original string. If either of the parameters are out of the bounds of the string, or the indices are invalid (` + "`from > upTo`" + `), then the function will fail ` +const StringTypeReplaceAllFunctionName = "replaceAll" +const StringTypeReplaceAllFunctionDocString = ` +Returns a new string after replacing all the occurrences of parameter ` + "`of` with the parameter `with`" + `. + +If ` + "`with`" + ` is empty, it matches at the beginning of the string and after each UTF-8 sequence, yielding k+1 replacements for a string of length k. +` + // ByteArrayType represents the type [UInt8] var ByteArrayType = &VariableSizedType{ Type: UInt8Type, @@ -361,3 +374,18 @@ var StringTypeSplitFunctionType = NewSimpleFunctionType( }, ), ) + +var StringTypeReplaceAllFunctionType = NewSimpleFunctionType( + FunctionPurityView, + []Parameter{ + { + Identifier: "of", + TypeAnnotation: StringTypeAnnotation, + }, + { + Identifier: "with", + TypeAnnotation: StringTypeAnnotation, + }, + }, + StringTypeAnnotation, +) diff --git a/runtime/tests/checker/string_test.go b/runtime/tests/checker/string_test.go index 5d94c0c077..828cb73276 100644 --- a/runtime/tests/checker/string_test.go +++ b/runtime/tests/checker/string_test.go @@ -451,3 +451,86 @@ func TestCheckStringSplitTypeMissingArgumentLabelSeparator(t *testing.T) { assert.IsType(t, &sema.MissingArgumentLabelError{}, errs[0]) } + +func TestCheckStringReplaceAll(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let s = "👪.❤️.Abc".replaceAll(of: "❤️", with: "|") + `) + require.NoError(t, err) + + assert.Equal(t, + sema.StringType, + RequireGlobalValue(t, checker.Elaboration, "s"), + ) +} + +func TestCheckStringReplaceAllTypeMismatchOf(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "Abc:1".replaceAll(of: 1234, with: "/") + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) +} + +func TestCheckStringReplaceAllTypeMismatchWith(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "Abc:1".replaceAll(of: "1", with: true) + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) +} + +func TestCheckStringReplaceAllTypeMismatchCharacters(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let a: Character = "x" + let b: Character = "y" + let s = "Abc:1".replaceAll(of: a, with: b) + `) + + errs := RequireCheckerErrors(t, err, 2) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) + assert.IsType(t, &sema.TypeMismatchError{}, errs[1]) +} + +func TestCheckStringReplaceAllTypeMissingArgumentLabelOf(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "👪Abc".replaceAll("/", with: "abc") + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.MissingArgumentLabelError{}, errs[0]) +} + +func TestCheckStringReplaceAllTypeMissingArgumentLabelWith(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "👪Abc".replaceAll(of: "/", "abc") + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.MissingArgumentLabelError{}, errs[0]) +} diff --git a/runtime/tests/interpreter/string_test.go b/runtime/tests/interpreter/string_test.go index 37bba83186..176927311c 100644 --- a/runtime/tests/interpreter/string_test.go +++ b/runtime/tests/interpreter/string_test.go @@ -596,3 +596,50 @@ func TestInterpretStringSplit(t *testing.T) { ), ) } + +func TestInterpretStringReplaceAll(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun replaceAll(): String { + return "👪////❤️".replaceAll(of: "////", with: "||") + } + fun replaceAllSpaceWithDoubleSpace(): String { + return "👪 ❤️ Abc6 ;123".replaceAll(of: " ", with: " ") + } + fun replaceAllWithUnicodeEquivalence(): String { + return "Caf\u{65}\u{301}ABc".replaceAll(of: "\u{e9}", with: "X") + } + fun testEmptyString(): String { + return "".replaceAll(of: "//", with: "abc") + } + fun testEmptyOf(): String { + return "abc".replaceAll(of: "", with: "1") + } + fun testNoMatch(): String { + return "pqrS;asdf".replaceAll(of: ";;", with: "does_not_matter") + } + `) + + testCase := func(t *testing.T, funcName string, expected *interpreter.StringValue) { + t.Run(funcName, func(t *testing.T) { + result, err := inter.Invoke(funcName) + require.NoError(t, err) + + RequireValuesEqual( + t, + inter, + expected, + result, + ) + }) + } + + testCase(t, "replaceAll", interpreter.NewUnmeteredStringValue("👪||❤️")) + testCase(t, "replaceAllSpaceWithDoubleSpace", interpreter.NewUnmeteredStringValue("👪 ❤️ Abc6 ;123")) + testCase(t, "replaceAllWithUnicodeEquivalence", interpreter.NewUnmeteredStringValue("CafXABc")) + testCase(t, "testEmptyString", interpreter.NewUnmeteredStringValue("")) + testCase(t, "testEmptyOf", interpreter.NewUnmeteredStringValue("1a1b1c1")) + testCase(t, "testNoMatch", interpreter.NewUnmeteredStringValue("pqrS;asdf")) +}