-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ported and added to test LogicT monad transformer
- Loading branch information
Showing
4 changed files
with
474 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
package logic | ||
|
||
import scala.util.* | ||
|
||
import cps.* | ||
|
||
trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] { | ||
|
||
def mzero: M[Nothing] | ||
|
||
def mplus[A](a: M[A], b: M[A]): M[A] | ||
|
||
def msplit[A](c: M[A]): M[Option[(Try[A], M[A])]] | ||
|
||
def interleave[A](a: M[A], b: M[A]): M[A] = { | ||
flatMap(msplit(a)) { sa => | ||
sa match | ||
case None => b | ||
case Some((ta, sa1)) => | ||
mplus(fromTry(ta), interleave(b, sa1)) | ||
} | ||
} | ||
|
||
// >>- in haskell LogicT | ||
def fairFlatMap[A, B](ma: M[A], mb: A => M[B]): M[B] = { | ||
flatMap(msplit(ma)) { sa => | ||
sa match | ||
case None => mzero | ||
case Some((ta, sa1)) => | ||
ta match | ||
case Success(a) => | ||
interleave(mb(a), fairFlatMap(sa1, mb)) | ||
case Failure(ex) => | ||
error(ex) | ||
} | ||
} | ||
|
||
def ifte[A, B](a: M[A], thenp: A => M[B], elsep: M[B]): M[B] = { | ||
flatMap(msplit(a)) { sc => | ||
sc match | ||
case None => elsep | ||
case Some((ta, sa)) => | ||
ta match | ||
case Success(a) => | ||
mplus(thenp(a), flatMap(sa)(thenp)) | ||
case Failure(ex) => | ||
error(ex) | ||
} | ||
} | ||
|
||
def once[A](a: M[A]): M[A] = { | ||
flatMap(msplit(a)) { sc => | ||
sc match | ||
case None => mzero | ||
case Some((ta, sa)) => | ||
fromTry(ta) | ||
} | ||
} | ||
|
||
type Observer[A] | ||
|
||
def observerCpsMonad: CpsMonad[Observer] | ||
|
||
def mObserveOne[A](ma:M[A]): Observer[Option[A]] | ||
|
||
def mObserveN[A](ma: M[A], n: Int): Observer[Seq[A]] = | ||
mFoldLeftN(ma, observerCpsMonad.pure(IndexedSeq.empty[A]), n) { (observer, fa) => | ||
observerCpsMonad.flatMap(observer)(seq => observerCpsMonad.map(fa)(a => seq :+ a)) | ||
} | ||
|
||
def mFoldLeft[A,B](ma:M[A], zero:Observer[B], op: (Observer[B],Observer[A])=>Observer[B]): Observer[B] | ||
|
||
def mFoldLeftN[A,B](ma: M[A], zero: Observer[B], n: Int)(op: (Observer[B],Observer[A])=>Observer[B]): Observer[B] | ||
|
||
} | ||
|
||
object CpsLogicMonad { | ||
|
||
type Aux[M[+_],F[_]] = CpsLogicMonad[M] { | ||
type Observer[T] = F[T] | ||
} | ||
|
||
} | ||
|
||
|
||
trait CpsLogicMonadContext[M[+_]] extends CpsTryMonadContext[M] { | ||
|
||
override def monad: CpsLogicMonad[M] | ||
|
||
} | ||
|
||
class CpsLogicMonadInstanceContextBody[M[+_]](m:CpsLogicMonad[M]) extends CpsLogicMonadContext[M] { | ||
override def monad: CpsLogicMonad[M] = m | ||
} | ||
|
||
trait CpsLogicMonadInstanceContext[M[+_]] extends CpsLogicMonad[M] { | ||
|
||
override type Context = CpsLogicMonadInstanceContextBody[M] | ||
|
||
override def apply[T](op: CpsLogicMonadInstanceContextBody[M] => M[T]): M[T] = { | ||
op(new CpsLogicMonadInstanceContextBody[M](this)) | ||
} | ||
|
||
} | ||
|
||
def all[M[+_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] = | ||
def allIt(it: Iterator[A]): M[A] = | ||
if (it.hasNext) { | ||
m.mplus(m.pure(it.next()), allIt(it)) | ||
} else { | ||
m.mzero | ||
} | ||
allIt(collection.iterator) | ||
|
||
|
||
extension [M[+_],A](ma: M[A])(using m:CpsLogicMonad[M]) | ||
|
||
def filter(p: A => Boolean): M[A] = | ||
m.flatMap(m.msplit(ma)) { sc => | ||
sc match | ||
case None => m.mzero | ||
case Some((ta, sa)) => | ||
ta match | ||
case Success(a) => | ||
if (p(a)) { | ||
m.mplus(m.pure(a), sa.filter(p)) | ||
} else { | ||
sa.filter(p) | ||
} | ||
case Failure(ex) => | ||
m.error(ex) | ||
} | ||
|
||
def observeN(n: Int): m.Observer[Seq[A]] = | ||
m.mObserveN(ma,n) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package logic | ||
|
||
import cps.* | ||
import logic.logict.{*,given} | ||
import org.junit.{Ignore, Test} | ||
|
||
|
||
|
||
class QueensTest { | ||
|
||
@Test | ||
def testLogicM() = { | ||
import QueensTest.* | ||
val r = queens[LogicM](8).observeN(2) | ||
// observer monad is identity monad here. | ||
|
||
assert(r.size == 2) | ||
assert(QueensTest.isCorrect(r(0))) | ||
assert(QueensTest.isCorrect(r(1))) | ||
} | ||
|
||
} | ||
|
||
object QueensTest { | ||
|
||
case class Pos(x:Int, y:Int) | ||
|
||
def isBeat(p1:Pos, p2:Pos):Boolean = | ||
(p1.x == p2.x) || (p1.y == p2.y) || (p1.x - p1.y == p2.x - p2.y) || (p1.x + p1.y == p2.x + p2.y) | ||
|
||
def isFree(p:Pos, prefix:IndexedSeq[Pos]):Boolean = | ||
prefix.forall(pp => !isBeat(p, pp)) | ||
|
||
def queens[M[+_]:CpsLogicMonad](n:Int, prefix:IndexedSeq[Pos]=IndexedSeq.empty): M[IndexedSeq[Pos]] = reify[M] { | ||
if (prefix.length >= n) then | ||
prefix | ||
else | ||
val nextPos = (1 to n).map(Pos(prefix.length+1,_)).filter(pos => isFree(pos, prefix)) | ||
reflect(queens(n, prefix :+ reflect(all(nextPos)))) | ||
} | ||
|
||
def isCorrect(queens:IndexedSeq[Pos]):Boolean = { | ||
queens.forall(p1 => queens.forall(p2 => | ||
(p1 == p2) || !isBeat(p1,p2) | ||
)) | ||
} | ||
|
||
|
||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
package logic.backtrm | ||
|
||
//Term implementation of backtracking monad transformer | ||
// from Deriving Backtracking Monad Transformers by Ralf Hinze | ||
// https://dl.acm.org/doi/10.1145/351240.351258 | ||
// | ||
// changed to support flatMapTry | ||
|
||
import cps.* | ||
import cps.syntax.* | ||
import scala.util.* | ||
import scala.util.control.NonFatal | ||
|
||
|
||
|
||
sealed trait BacktrT[M[_],A] { | ||
|
||
def map[B](f: A=>B): BacktrT[M,B] | ||
def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B] | ||
def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B] | ||
def append(next: =>BacktrT[M,A]): BacktrT[M,A] | ||
|
||
|
||
} | ||
|
||
object BacktrT{ | ||
|
||
case class Zero[M[_],A]() extends BacktrT[M,A] { | ||
override def map[B](f: A=>B): Zero[M,B] = Zero() | ||
override def flatMap[B](f: A=>BacktrT[M,B]): Zero[M,B] = Zero() | ||
override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): Zero[M,B] = Zero() | ||
override def append(next: =>BacktrT[M, A]): BacktrT[M, A] = next | ||
|
||
} | ||
|
||
case class Cons[M[_]:CpsTryMonad,A](h: A, t: () => BacktrT[M,A]) extends BacktrT[M,A] { | ||
|
||
def map[B](f: A=>B): Cons[M,B] = Cons(f(h), () => t().map(f)) | ||
|
||
def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B] = | ||
f(h).append(t().flatMap(f)) | ||
|
||
override def append(next: =>BacktrT[M, A]): BacktrT[M, A] = | ||
Cons(h, () => t().append(next)) | ||
|
||
override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B] = | ||
try { | ||
f(Success(h)).append(t().flatMapTry(f)) | ||
} catch { | ||
case NonFatal(ex) => | ||
PromoteBind(summon[CpsTryMonad[M]].error(ex), { | ||
case Success(_) => Success(Zero()) // impossible | ||
case Failure(ex) => Failure(ex) | ||
}) | ||
} | ||
|
||
} | ||
|
||
|
||
case class PromoteBind[M[_],A,B](ma: M[A], transform: Try[A]=>Try[BacktrT[M,B]]) extends BacktrT[M,B] { | ||
|
||
def map[C](f: B=>C): PromoteBind[M,A,C] = PromoteBind(ma, | ||
{ case Success(a) => transform(Success(a)).map(_.map(f)) | ||
case Failure(ex) => Failure(ex) | ||
} | ||
) | ||
|
||
def flatMap[C](g: B=>BacktrT[M,C]): BacktrT[M,C] = PromoteBind(ma, { | ||
case Success(a) => transform(Success(a)).map(_.flatMap(g)) | ||
case Failure(ex) => Failure(ex) | ||
}) | ||
|
||
override def append(next: => BacktrT[M, B]): BacktrT[M, B] = | ||
PromoteBind(ma, { | ||
case Success(a) => transform(Success(a)).map(_.append(next)) | ||
case Failure(ex) => Failure(ex) | ||
}) | ||
|
||
|
||
override def flatMapTry[C](g: Try[B]=>BacktrT[M,C]): BacktrT[M,C] = | ||
PromoteBind( | ||
ma, { x => | ||
transform(x) match | ||
case Success(bm) => Success(bm.flatMapTry(g)) | ||
case Failure(ex) => Success(g(Failure(ex))) | ||
} | ||
) | ||
|
||
|
||
} | ||
|
||
|
||
|
||
class CpsBacktrTryMonad[M[_]:CpsTryMonad] extends CpsTryMonad[[X]=>>BacktrT[M,X]] with CpsTryMonadInstanceContext[[X]=>>BacktrT[M,X]] { | ||
|
||
override def pure[A](a: A): BacktrT[M, A] = Cons(a,()=>Zero()) | ||
|
||
override def map[A,B](fa: BacktrT[M,A])(f: A=>B): BacktrT[M,B] = | ||
fa.map(f) | ||
|
||
override def flatMap[A,B](fa: BacktrT[M,A])(f: A=>BacktrT[M,B]): BacktrT[M,B] = | ||
fa.flatMap(f) | ||
|
||
override def error[A](e: Throwable): BacktrT[M, A] = | ||
PromoteBind(summon[CpsTryMonad[M]].error(e), { | ||
case Success(_) => Failure(e) // impossible | ||
case Failure(ex) => Failure(ex) | ||
}) | ||
|
||
override def flatMapTry[A,B](fa: BacktrT[M,A])(f: Try[A]=> BacktrT[M,B]): BacktrT[M,B] = | ||
fa.flatMapTry(f) | ||
|
||
|
||
|
||
} | ||
|
||
} |
Oops, something went wrong.