Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework IorRaise impl to use EmptyValue, and add tests #3338

Merged
merged 5 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions arrow-libs/core/arrow-core/api/arrow-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ public final class arrow/core/EitherKt {
public final class arrow/core/EmptyValue {
public static final field INSTANCE Larrow/core/EmptyValue;
public final fun combine (Ljava/lang/Object;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
public final fun fold (Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public final fun unbox (Ljava/lang/Object;)Ljava/lang/Object;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@ public inline fun <A> identity(a: A): A = a
*/
@PublishedApi
internal object EmptyValue {
@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
public inline fun <A> unbox(value: Any?): A =
if (value === this) null as A else value as A
@Suppress("UNCHECKED_CAST")
inline fun <A> unbox(value: Any?): A =
fold(value, { null as A }, ::identity)

public inline fun <T> combine(first: Any?, second: T, combine: (T, T) -> T): T =
if (first === EmptyValue) second else combine(first as T, second)
inline fun <T> combine(first: Any?, second: T, combine: (T, T) -> T): T =
fold(first, { second }, { t: T -> combine(t, second) })

@Suppress("UNCHECKED_CAST")
inline fun <T, R> fold(value: Any?, ifEmpty: () -> R, ifNotEmpty: (T) -> R): R =
if (value === EmptyValue) ifEmpty() else ifNotEmpty(value as T)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import arrow.core.Option
import arrow.core.Some
import arrow.core.getOrElse
import arrow.core.identity
import arrow.core.EmptyValue
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
import kotlin.experimental.ExperimentalTypeInference
Expand Down Expand Up @@ -89,12 +90,11 @@ public inline fun <A> option(block: OptionRaise.() -> A): Option<A> =
* [Arrow docs](https://arrow-kt.io/learn/typed-errors/working-with-typed-errors/#running-and-inspecting-results).
*/
public inline fun <Error, A> ior(noinline combineError: (Error, Error) -> Error, @BuilderInference block: IorRaise<Error>.() -> A): Ior<Error, A> {
val state: Atomic<Option<Error>> = Atomic(None)
return fold<Error, A, Ior<Error, A>>(
val state: Atomic<Any?> = Atomic(EmptyValue)
return fold(
{ block(IorRaise(combineError, state, this)) },
{ e -> throw e },
{ e -> Ior.Left(state.get().getOrElse { e }) },
{ a -> state.get().fold({ Ior.Right(a) }, { Ior.Both(it, a) }) }
{ e -> Ior.Left(EmptyValue.combine(state.get(), e, combineError)) },
{ a -> EmptyValue.fold(state.get(), { Ior.Right(a) }, { e: Error -> Ior.Both(e, a) }) }
)
}

Expand All @@ -114,15 +114,8 @@ public inline fun <Error, A> ior(noinline combineError: (Error, Error) -> Error,
* Read more about running a [Raise] computation in the
* [Arrow docs](https://arrow-kt.io/learn/typed-errors/working-with-typed-errors/#running-and-inspecting-results).
*/
public inline fun <Error, A> iorNel(noinline combineError: (NonEmptyList<Error>, NonEmptyList<Error>) -> NonEmptyList<Error> = { a, b -> a + b }, @BuilderInference block: IorRaise<NonEmptyList<Error>>.() -> A): IorNel<Error, A> {
val state: Atomic<Option<NonEmptyList<Error>>> = Atomic(None)
return fold<NonEmptyList<Error>, A, Ior<NonEmptyList<Error>, A>>(
{ block(IorRaise(combineError, state, this)) },
{ e -> throw e },
{ e -> Ior.Left(state.get().getOrElse { e }) },
{ a -> state.get().fold({ Ior.Right(a) }, { Ior.Both(it, a) }) }
)
}
public inline fun <Error, A> iorNel(noinline combineError: (NonEmptyList<Error>, NonEmptyList<Error>) -> NonEmptyList<Error> = { a, b -> a + b }, @BuilderInference block: IorRaise<NonEmptyList<Error>>.() -> A): IorNel<Error, A> =
ior(combineError, block)

/**
* Implementation of [Raise] used by `ignoreErrors`.
Expand Down Expand Up @@ -306,12 +299,12 @@ public class OptionRaise(private val raise: Raise<None>) : Raise<None> by raise
*/
public class IorRaise<Error> @PublishedApi internal constructor(
@PublishedApi internal val combineError: (Error, Error) -> Error,
private val state: Atomic<Option<Error>>,
private val state: Atomic<Any?>,
private val raise: Raise<Error>,
) : Raise<Error> {

@RaiseDSL
override fun raise(r: Error): Nothing = raise.raise(combine(r))
) : Raise<Error> by raise {
Comment on lines -313 to +304
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we can remove the override for raise, right? Don't we otherwise combine the error when doing raise now, are we? 🤔

I think a test is missing here:

ior(String::plus) {
  Ior.Both("Hello", Unit).bind()
  raise("World")
} shouldBe Ior.Left("Hello World")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the final combine call happening inside the ior builder itself. This is so that try catch can recover from the error, and also so that any races in calling raise is only won by the call that actually reaches the builder, which is the behaviour of the rest of the builders.
That test passes (after adding a missing space) so no issues there.

removing the raise override is also good because it makes transitioning to context receivers much easier.

@Suppress("UNCHECKED_CAST")
@PublishedApi
internal fun combine(e: Error): Error = state.updateAndGet { EmptyValue.combine(it, e, combineError) } as Error

@RaiseDSL
@JvmName("bindAllIor")
Expand Down Expand Up @@ -343,23 +336,22 @@ public class IorRaise<Error> @PublishedApi internal constructor(
public fun <K, V> Map<K, Ior<Error, V>>.bindAll(): Map<K, V> =
mapValues { (_, v) -> v.bind() }

@PublishedApi
internal fun combine(other: Error): Error =
state.updateAndGet { prev ->
Some(prev.map { combineError(it, other) }.getOrElse { other })
}.getOrElse { other }

@RaiseDSL
public inline fun <A> recover(
@BuilderInference block: IorRaise<Error>.() -> A,
recover: (error: Error) -> A,
): A = when (val ior = ior(combineError, block)) {
is Ior.Both -> {
combine(ior.leftValue)
ior.rightValue
}

is Ior.Left -> recover(ior.value)
is Ior.Right -> ior.value
): A {
val state: Atomic<Any?> = Atomic(EmptyValue)
return recover<Error, A>({
try {
block(IorRaise(combineError, state, this))
} finally {
val accumulated = state.get()
if (accumulated != EmptyValue) {
@Suppress("UNCHECKED_CAST")
combine(accumulated as Error)
}
}
}, recover)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ class IorSpec : StringSpec({
} shouldBe Ior.Left("Hello, World!")
}

"Accumulates and short-circuits with raise" {
ior(String::plus) {
Ior.Both("Hello", Unit).bind()
raise(" World")
} shouldBe Ior.Left("Hello World")
}

"Ior rethrows exception" {
val boom = RuntimeException("Boom!")
shouldThrow<RuntimeException> {
Expand All @@ -78,10 +85,63 @@ class IorSpec : StringSpec({

"Recover works as expected" {
ior(String::plus) {
val one = recover({ Ior.Left("Hello").bind() }) { 1 }
val one = recover({
Ior.Both("Hi", Unit).bind()
Ior.Left("Hello").bind()
}) {
it shouldBe "Hello"
1
}
val two = Ior.Right(2).bind()
val three = Ior.Both(", World", 3).bind()
one + two + three
} shouldBe Ior.Both("Hi, World", 6)
}

"recover with throw" {
ior(String::plus) {
val one = try {
recover({
Ior.Both("Hi", Unit).bind()
throw RuntimeException("Hello")
}) {
unreachable()
}
} catch (e: RuntimeException) {
1
}
val two = Ior.Right(2).bind()
val three = Ior.Both(", World", 3).bind()
one + two + three
} shouldBe Ior.Both("Hi, World", 6)
}

"recover with raise is a no-op" {
ior(String::plus) {
val one: Int =
recover({
Ior.Both("Hi", Unit).bind()
Ior.Left(", Hello").bind()
}) {
raise(it)
}
val two = Ior.Right(2).bind()
val three = Ior.Both(", World", 3).bind()
one + two + three
} shouldBe Ior.Left("Hi, Hello")
}

"try catch can recover from raise" {
ior(String::plus) {
val one = try {
Ior.Both("Hi", Unit).bind()
Ior.Left("Hello").bind()
} catch (e: Throwable) {
1
}
val two = Ior.Right(2).bind()
val three = Ior.Both(", World", 3).bind()
one + two + three
} shouldBe Ior.Both(", World", 6)
} shouldBe Ior.Both("Hi, World", 6)
}
})