Skip to content

Commit

Permalink
ported and added to test LogicT monad transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed Dec 25, 2023
1 parent 81a5ae1 commit c9e0f33
Show file tree
Hide file tree
Showing 4 changed files with 474 additions and 1 deletion.
135 changes: 135 additions & 0 deletions shared/src/test/scala/logic/CpsLogicMonad.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package logic

import scala.util.*

import cps.*

trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] {

def mzero: M[Nothing]

def mplus[A](a: M[A], b: M[A]): M[A]

def msplit[A](c: M[A]): M[Option[(Try[A], M[A])]]

def interleave[A](a: M[A], b: M[A]): M[A] = {
flatMap(msplit(a)) { sa =>
sa match
case None => b
case Some((ta, sa1)) =>
mplus(fromTry(ta), interleave(b, sa1))
}
}

// >>- in haskell LogicT
def fairFlatMap[A, B](ma: M[A], mb: A => M[B]): M[B] = {
flatMap(msplit(ma)) { sa =>
sa match
case None => mzero
case Some((ta, sa1)) =>
ta match
case Success(a) =>
interleave(mb(a), fairFlatMap(sa1, mb))
case Failure(ex) =>
error(ex)
}
}

def ifte[A, B](a: M[A], thenp: A => M[B], elsep: M[B]): M[B] = {
flatMap(msplit(a)) { sc =>
sc match
case None => elsep
case Some((ta, sa)) =>
ta match
case Success(a) =>
mplus(thenp(a), flatMap(sa)(thenp))
case Failure(ex) =>
error(ex)
}
}

def once[A](a: M[A]): M[A] = {
flatMap(msplit(a)) { sc =>
sc match
case None => mzero
case Some((ta, sa)) =>
fromTry(ta)
}
}

type Observer[A]

def observerCpsMonad: CpsMonad[Observer]

def mObserveOne[A](ma:M[A]): Observer[Option[A]]

def mObserveN[A](ma: M[A], n: Int): Observer[Seq[A]] =
mFoldLeftN(ma, observerCpsMonad.pure(IndexedSeq.empty[A]), n) { (observer, fa) =>
observerCpsMonad.flatMap(observer)(seq => observerCpsMonad.map(fa)(a => seq :+ a))
}

def mFoldLeft[A,B](ma:M[A], zero:Observer[B], op: (Observer[B],Observer[A])=>Observer[B]): Observer[B]

def mFoldLeftN[A,B](ma: M[A], zero: Observer[B], n: Int)(op: (Observer[B],Observer[A])=>Observer[B]): Observer[B]

}

object CpsLogicMonad {

type Aux[M[+_],F[_]] = CpsLogicMonad[M] {
type Observer[T] = F[T]
}

}


trait CpsLogicMonadContext[M[+_]] extends CpsTryMonadContext[M] {

override def monad: CpsLogicMonad[M]

}

class CpsLogicMonadInstanceContextBody[M[+_]](m:CpsLogicMonad[M]) extends CpsLogicMonadContext[M] {
override def monad: CpsLogicMonad[M] = m
}

trait CpsLogicMonadInstanceContext[M[+_]] extends CpsLogicMonad[M] {

override type Context = CpsLogicMonadInstanceContextBody[M]

override def apply[T](op: CpsLogicMonadInstanceContextBody[M] => M[T]): M[T] = {
op(new CpsLogicMonadInstanceContextBody[M](this))
}

}

def all[M[+_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] =
def allIt(it: Iterator[A]): M[A] =
if (it.hasNext) {
m.mplus(m.pure(it.next()), allIt(it))
} else {
m.mzero
}
allIt(collection.iterator)


extension [M[+_],A](ma: M[A])(using m:CpsLogicMonad[M])

def filter(p: A => Boolean): M[A] =
m.flatMap(m.msplit(ma)) { sc =>
sc match
case None => m.mzero
case Some((ta, sa)) =>
ta match
case Success(a) =>
if (p(a)) {
m.mplus(m.pure(a), sa.filter(p))
} else {
sa.filter(p)
}
case Failure(ex) =>
m.error(ex)
}

def observeN(n: Int): m.Observer[Seq[A]] =
m.mObserveN(ma,n)
51 changes: 51 additions & 0 deletions shared/src/test/scala/logic/QueensTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package logic

import cps.*
import logic.logict.{*,given}
import org.junit.{Ignore, Test}



class QueensTest {

@Test
def testLogicM() = {
import QueensTest.*
val r = queens[LogicM](8).observeN(2)
// observer monad is identity monad here.

assert(r.size == 2)
assert(QueensTest.isCorrect(r(0)))
assert(QueensTest.isCorrect(r(1)))
}

}

object QueensTest {

case class Pos(x:Int, y:Int)

def isBeat(p1:Pos, p2:Pos):Boolean =
(p1.x == p2.x) || (p1.y == p2.y) || (p1.x - p1.y == p2.x - p2.y) || (p1.x + p1.y == p2.x + p2.y)

def isFree(p:Pos, prefix:IndexedSeq[Pos]):Boolean =
prefix.forall(pp => !isBeat(p, pp))

def queens[M[+_]:CpsLogicMonad](n:Int, prefix:IndexedSeq[Pos]=IndexedSeq.empty): M[IndexedSeq[Pos]] = reify[M] {
if (prefix.length >= n) then
prefix
else
val nextPos = (1 to n).map(Pos(prefix.length+1,_)).filter(pos => isFree(pos, prefix))
reflect(queens(n, prefix :+ reflect(all(nextPos))))
}

def isCorrect(queens:IndexedSeq[Pos]):Boolean = {
queens.forall(p1 => queens.forall(p2 =>
(p1 == p2) || !isBeat(p1,p2)
))
}




}
117 changes: 117 additions & 0 deletions shared/src/test/scala/logic/backtrm/BacktrT.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package logic.backtrm

//Term implementation of backtracking monad transformer
// from Deriving Backtracking Monad Transformers by Ralf Hinze
// https://dl.acm.org/doi/10.1145/351240.351258
//
// changed to support flatMapTry

import cps.*
import cps.syntax.*
import scala.util.*
import scala.util.control.NonFatal



sealed trait BacktrT[M[_],A] {

def map[B](f: A=>B): BacktrT[M,B]
def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B]
def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B]
def append(next: =>BacktrT[M,A]): BacktrT[M,A]


}

object BacktrT{

case class Zero[M[_],A]() extends BacktrT[M,A] {
override def map[B](f: A=>B): Zero[M,B] = Zero()
override def flatMap[B](f: A=>BacktrT[M,B]): Zero[M,B] = Zero()
override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): Zero[M,B] = Zero()
override def append(next: =>BacktrT[M, A]): BacktrT[M, A] = next

}

case class Cons[M[_]:CpsTryMonad,A](h: A, t: () => BacktrT[M,A]) extends BacktrT[M,A] {

def map[B](f: A=>B): Cons[M,B] = Cons(f(h), () => t().map(f))

def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B] =
f(h).append(t().flatMap(f))

override def append(next: =>BacktrT[M, A]): BacktrT[M, A] =
Cons(h, () => t().append(next))

override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B] =
try {
f(Success(h)).append(t().flatMapTry(f))
} catch {
case NonFatal(ex) =>
PromoteBind(summon[CpsTryMonad[M]].error(ex), {
case Success(_) => Success(Zero()) // impossible
case Failure(ex) => Failure(ex)
})
}

}


case class PromoteBind[M[_],A,B](ma: M[A], transform: Try[A]=>Try[BacktrT[M,B]]) extends BacktrT[M,B] {

def map[C](f: B=>C): PromoteBind[M,A,C] = PromoteBind(ma,
{ case Success(a) => transform(Success(a)).map(_.map(f))
case Failure(ex) => Failure(ex)
}
)

def flatMap[C](g: B=>BacktrT[M,C]): BacktrT[M,C] = PromoteBind(ma, {
case Success(a) => transform(Success(a)).map(_.flatMap(g))
case Failure(ex) => Failure(ex)
})

override def append(next: => BacktrT[M, B]): BacktrT[M, B] =
PromoteBind(ma, {
case Success(a) => transform(Success(a)).map(_.append(next))
case Failure(ex) => Failure(ex)
})


override def flatMapTry[C](g: Try[B]=>BacktrT[M,C]): BacktrT[M,C] =
PromoteBind(
ma, { x =>
transform(x) match
case Success(bm) => Success(bm.flatMapTry(g))
case Failure(ex) => Success(g(Failure(ex)))
}
)


}



class CpsBacktrTryMonad[M[_]:CpsTryMonad] extends CpsTryMonad[[X]=>>BacktrT[M,X]] with CpsTryMonadInstanceContext[[X]=>>BacktrT[M,X]] {

override def pure[A](a: A): BacktrT[M, A] = Cons(a,()=>Zero())

override def map[A,B](fa: BacktrT[M,A])(f: A=>B): BacktrT[M,B] =
fa.map(f)

override def flatMap[A,B](fa: BacktrT[M,A])(f: A=>BacktrT[M,B]): BacktrT[M,B] =
fa.flatMap(f)

override def error[A](e: Throwable): BacktrT[M, A] =
PromoteBind(summon[CpsTryMonad[M]].error(e), {
case Success(_) => Failure(e) // impossible
case Failure(ex) => Failure(ex)
})

override def flatMapTry[A,B](fa: BacktrT[M,A])(f: Try[A]=> BacktrT[M,B]): BacktrT[M,B] =
fa.flatMapTry(f)



}

}
Loading

0 comments on commit c9e0f33

Please sign in to comment.