Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TraverseFilter.mapAccumulateFilter #4561

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions core/src/main/scala/cats/TraverseFilter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ trait TraverseFilter[F[_]] extends FunctorFilter[F] {
override def mapFilter[A, B](fa: F[A])(f: A => Option[B]): F[B] =
traverseFilter[Id, A, B](fa)(f)

def mapAccumulateFilter[S, A, B](init: S, fa: F[A])(f: (S, A) => (S, Option[B])): (S, F[B]) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that a scaladoc comment (perhaps, with a short usage example) wouldn't hurt and could come handy here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost done, I should fix the code violations. Is the language correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks!

traverseFilter(fa)(a => State(s => f(s, a))).run(init).value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we override this for some of the built in collections? State is rather slow so we should avoid it for List, Vector, Chain, NonEmptyList, NonEmptyVector, ...

Copy link
Contributor Author

@Masynchin Masynchin Feb 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can copy StaticMethods.mapAccumulateFromStrictFunctor for mapAccumulateFilter, this will cover for some collections


/**
* Removes duplicate elements from a list, keeping only the first occurrence.
*/
Expand Down Expand Up @@ -184,6 +187,8 @@ object TraverseFilter {
typeClassInstance.filterA[G, A](self)(f)(G)
def traverseEither[G[_], B, C](f: A => G[Either[C, B]])(g: (A, C) => G[Unit])(implicit G: Monad[G]): G[F[B]] =
typeClassInstance.traverseEither[G, A, B, C](self)(f)(g)(G)
def mapAccumulateFilter[S, B](init: S)(f: (S, A) => (S, Option[B])): (S, F[B]) =
typeClassInstance.mapAccumulateFilter[S, A, B](init, self)(f)
def ordDistinct(implicit O: Order[A]): F[A] = typeClassInstance.ordDistinct(self)
def hashDistinct(implicit H: Hash[A]): F[A] = typeClassInstance.hashDistinct(self)
}
Expand Down
13 changes: 13 additions & 0 deletions tests/shared/src/test/scala/cats/tests/TraverseFilterSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ abstract class TraverseFilterSuite[F[_]: TraverseFilter](name: String)(implicit

implicit def T: Traverse[F] = implicitly[TraverseFilter[F]].traverse

test(s"TraverseFilter[$name].mapAccumulateFilter") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it does not check the default implementation (based on State), does it?
Except maybe one for Stream, but I wouldn't count on it.

For testing default implementations we usually use ListWrapper from testkit:
ListWrapper.scala

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had straightforwardly copy-pasted and updated tests from mapAccumulate here:

test(s"Traverse[$name].mapAccumulate") {
forAll { (init: Int, fa: F[Int], fn: ((Int, Int)) => (Int, Int)) =>
val lhs = fa.mapAccumulate(init)((s, a) => fn((s, a)))
val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (s2, b) = fn((s1, a))
(s2, b :: acc)
}
assert(lhs.map(_.toList) === rhs.map(_.reverse))
}
}

If that tests doesn't test default implementation too, I can update mapAccumulateFilter tests. Can you provide an example of how to pass ListWrapper?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For testing default implementations we usually use ListWrapper from testkit: ListWrapper.scala

I am failing to understand how to use ListWrapper to test mapAccumulateFilter. I was looking for examples in other Suites, but it is either suites for data and not typeclasses (OptionT, Try, etc.), or it refers to <Typeclass>Tests[ListWrapper].<methodToTest>, like in the ApplicativeSuite:

implicit val listwrapperApplicative: Applicative[ListWrapper] = ListWrapper.applicative
implicit val listwrapperCoflatMap: CoflatMap[ListWrapper] = Applicative.coflatMap[ListWrapper]
checkAll("Applicative[ListWrapper].coflatMap", CoflatMapTests[ListWrapper].coflatMap[String, String, String])

which I can not apply here, because we don't have TraverseFilter[?].mapAccumulateFilter. Am I missing something?

Copy link
Contributor

@satorg satorg Apr 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My apologies for the delay – I was snowed under a bit. Actually, there's TraverseFilter for ListWrapper:

val traverseFilter: TraverseFilter[ListWrapper] = {
val F = TraverseFilter[List]
new TraverseFilter[ListWrapper] {
def traverse = ListWrapper.traverse
def traverseFilter[G[_], A, B](
fa: ListWrapper[A]
)(f: A => G[Option[B]])(implicit G: Applicative[G]): G[ListWrapper[B]] =
G.map(F.traverseFilter(fa.list)(f))(ListWrapper.apply)
}
}

To test the default implementation you can either call it directly:

ListWrapper.traverseFilter.mapAccumulateFilter(...)

or make it an implicit in the scope:

implicit val listWrapperTraverseFilter: TraverseFilter[ListWrapper] = ListWrapper.traverseFilter

And then you can work with TraverseFilter for ListWrapper as usual.

forAll { (init: Int, fa: F[Int], fn: ((Int, Int)) => (Int, Option[Int])) =>
val lhs = fa.mapAccumulateFilter(init)((s, a) => fn((s, a)))

val rhs = fa.foldLeft((init, List.empty[Int])) { case ((s1, acc), a) =>
val (s2, b) = fn((s1, a))
(s2, b.fold(acc)(_ :: acc))
}

assert(lhs.map(_.toList) === rhs.map(_.reverse))
}
}

test(s"TraverseFilter[$name].ordDistinct") {
forAll { (fa: F[Int]) =>
fa.ordDistinct.toList === fa.toList.distinct
Expand Down