Skip to content

Commit

Permalink
enforse lazy evaluation of infinite sequenc in logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Dec 27, 2023
1 parent 7dfb8f8 commit fe0ab06
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 191 deletions.
2 changes: 2 additions & 0 deletions shared/src/main/scala/cps/CpsMonadConversion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ object CpsMonadConversion:
given identityConversion[F[_]]: CpsMonadConversion[F,F] with
def apply[T](ft:F[T]): F[T] = ft



end CpsMonadConversion

45 changes: 35 additions & 10 deletions shared/src/test/scala/logic/CpsLogicMonad.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import scala.util.*

import cps.*

trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] {
trait CpsLogicMonad[M[_]] extends CpsTryMonad[M] {

def mzero: M[Nothing]
override type Context <: CpsLogicMonadContext[M]

def mplus[A](a: M[A], b: M[A]): M[A]
def mzero[A]: M[A]

def mplus[A](a: M[A], b: => M[A]): M[A]

def msplit[A](c: M[A]): M[Option[(Try[A], M[A])]]

Expand Down Expand Up @@ -59,7 +61,7 @@ trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] {

type Observer[A]

def observerCpsMonad: CpsMonad[Observer]
def observerCpsMonad: CpsTryMonad[Observer]

def mObserveOne[A](ma:M[A]): Observer[Option[A]]

Expand All @@ -68,7 +70,12 @@ trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] {
observerCpsMonad.flatMap(observer)(seq => observerCpsMonad.map(fa)(a => seq :+ a))
}

def mFoldLeft[A,B](ma:M[A], zero:Observer[B], op: (Observer[B],Observer[A])=>Observer[B]): Observer[B]
def mObserveAll[A](ma: M[A]): Observer[Seq[A]] =
mFoldLeft(ma, observerCpsMonad.pure(IndexedSeq.empty[A])) { (observer, fa) =>
observerCpsMonad.flatMap(observer)(seq => observerCpsMonad.map(fa)(a => seq :+ a))
}

def mFoldLeft[A,B](ma:M[A], zero:Observer[B])(op: (Observer[B],Observer[A])=>Observer[B]): Observer[B]

def mFoldLeftN[A,B](ma: M[A], zero: Observer[B], n: Int)(op: (Observer[B],Observer[A])=>Observer[B]): Observer[B]

Expand All @@ -83,17 +90,17 @@ object CpsLogicMonad {
}


trait CpsLogicMonadContext[M[+_]] extends CpsTryMonadContext[M] {
trait CpsLogicMonadContext[M[_]] extends CpsTryMonadContext[M] {

override def monad: CpsLogicMonad[M]

}

class CpsLogicMonadInstanceContextBody[M[+_]](m:CpsLogicMonad[M]) extends CpsLogicMonadContext[M] {
class CpsLogicMonadInstanceContextBody[M[_]](m:CpsLogicMonad[M]) extends CpsLogicMonadContext[M] {
override def monad: CpsLogicMonad[M] = m
}

trait CpsLogicMonadInstanceContext[M[+_]] extends CpsLogicMonad[M] {
trait CpsLogicMonadInstanceContext[M[_]] extends CpsLogicMonad[M] {

override type Context = CpsLogicMonadInstanceContextBody[M]

Expand All @@ -103,7 +110,7 @@ trait CpsLogicMonadInstanceContext[M[+_]] extends CpsLogicMonad[M] {

}

def all[M[+_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] =
def all[M[_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] =
def allIt(it: Iterator[A]): M[A] =
if (it.hasNext) {
m.mplus(m.pure(it.next()), allIt(it))
Expand All @@ -113,7 +120,10 @@ def all[M[+_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] =
allIt(collection.iterator)


extension [M[+_],A](ma: M[A])(using m:CpsLogicMonad[M])



extension [M[_],A](ma: M[A])(using m:CpsLogicMonad[M])

def filter(p: A => Boolean): M[A] =
m.flatMap(m.msplit(ma)) { sc =>
Expand All @@ -133,3 +143,18 @@ extension [M[+_],A](ma: M[A])(using m:CpsLogicMonad[M])

def observeN(n: Int): m.Observer[Seq[A]] =
m.mObserveN(ma,n)

def observeOne: m.Observer[Option[A]] =
m.mObserveOne(ma)

def observeAll: m.Observer[Seq[A]] =
m.mObserveAll(ma)

def |+|(mb: =>M[A]): M[A] =
m.mplus(ma,mb)


transparent inline def guard[M[_]:CpsLogicMonad](p: =>Boolean)(using mc:CpsLogicMonadContext[M]): Unit =
reflect{
if (p) mc.monad.pure(()) else mc.monad.mzero
}
46 changes: 46 additions & 0 deletions shared/src/test/scala/logic/GrandParentsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package logic

import cps.*
import cps.monads.{*, given}
import logict.{*,given}

import org.junit.Test


class GrandParentsTest {

@Test
def testLogicM() = {
import GrandParentsTest.*
val r = grandParent[LogicM]("Anne").observeAll
println(s"GrandParentTest: r=$r")
assert(r.size == 2)
assert(r.contains("Sarah"))
assert(r.contains("Arnold"))
}



}

object GrandParentsTest {

case class IsParentOf(parent:String, child:String)

def parents = Seq(
IsParentOf("Sarah", "John"),
IsParentOf("Arnold", "John"),
IsParentOf("John", "Anne"),
)

def grandParent[M[_]:CpsLogicMonad](name:String):M[String] = reify[M]{
// TODO: fill bug report in dotty.
// shopuld be search
import cps.CpsMonadConversion.given
val IsParentOf(p,c) = reflect(all(parents))
guard(c == name)
val IsParentOf(gp,_) = reflect(all(parents.filter(_.child == p)))
gp
}

}
41 changes: 41 additions & 0 deletions shared/src/test/scala/logic/LogicTTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package logic

import scala.concurrent.duration.*
import scala.util.*
import cps.*
import logict.*
import org.junit.Test

class LogicTTest {

import LogicTTest.*

@Test
def testNatCB():Unit = {

val cbNat = nats[[A] =>> LogicT[ComputationBound,A]]

val cb4 = cbNat.observeN(4)
assert(cb4.run(1.second) == Success(Seq(1,2,3,4)))

val cbOdd = odds[[A] =>> LogicT[ComputationBound,A]]



}


}

object LogicTTest {

def nats[M[_]](using m: CpsLogicMonad[M]): M[Int] =
m.pure(1) |+| reify[M] { 1 + reflect(nats) }

def odds[M[_]](using m: CpsLogicMonad[M]): M[Int] =
m.pure(1) |+| reify[M] { 2 + reflect(odds) }




}
2 changes: 1 addition & 1 deletion shared/src/test/scala/logic/QueensTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class QueensTest {
import QueensTest.*
val r = queens[LogicM](8).observeN(2)
// observer monad is identity monad here.

println(s"QueensTest:r=$r")
assert(r.size == 2)
assert(QueensTest.isCorrect(r(0)))
assert(QueensTest.isCorrect(r(1)))
Expand Down
Loading

0 comments on commit fe0ab06

Please sign in to comment.