diff --git a/shared/src/test/scala/logic/CpsLogicMonad.scala b/shared/src/test/scala/logic/CpsLogicMonad.scala new file mode 100644 index 00000000..1fe7dcf3 --- /dev/null +++ b/shared/src/test/scala/logic/CpsLogicMonad.scala @@ -0,0 +1,135 @@ +package logic + +import scala.util.* + +import cps.* + +trait CpsLogicMonad[M[+_]] extends CpsTryMonad[M] { + + def mzero: M[Nothing] + + def mplus[A](a: M[A], b: M[A]): M[A] + + def msplit[A](c: M[A]): M[Option[(Try[A], M[A])]] + + def interleave[A](a: M[A], b: M[A]): M[A] = { + flatMap(msplit(a)) { sa => + sa match + case None => b + case Some((ta, sa1)) => + mplus(fromTry(ta), interleave(b, sa1)) + } + } + + // >>- in haskell LogicT + def fairFlatMap[A, B](ma: M[A], mb: A => M[B]): M[B] = { + flatMap(msplit(ma)) { sa => + sa match + case None => mzero + case Some((ta, sa1)) => + ta match + case Success(a) => + interleave(mb(a), fairFlatMap(sa1, mb)) + case Failure(ex) => + error(ex) + } + } + + def ifte[A, B](a: M[A], thenp: A => M[B], elsep: M[B]): M[B] = { + flatMap(msplit(a)) { sc => + sc match + case None => elsep + case Some((ta, sa)) => + ta match + case Success(a) => + mplus(thenp(a), flatMap(sa)(thenp)) + case Failure(ex) => + error(ex) + } + } + + def once[A](a: M[A]): M[A] = { + flatMap(msplit(a)) { sc => + sc match + case None => mzero + case Some((ta, sa)) => + fromTry(ta) + } + } + + type Observer[A] + + def observerCpsMonad: CpsMonad[Observer] + + def mObserveOne[A](ma:M[A]): Observer[Option[A]] + + def mObserveN[A](ma: M[A], n: Int): Observer[Seq[A]] = + mFoldLeftN(ma, observerCpsMonad.pure(IndexedSeq.empty[A]), n) { (observer, fa) => + observerCpsMonad.flatMap(observer)(seq => observerCpsMonad.map(fa)(a => seq :+ a)) + } + + def mFoldLeft[A,B](ma:M[A], zero:Observer[B], op: (Observer[B],Observer[A])=>Observer[B]): Observer[B] + + def mFoldLeftN[A,B](ma: M[A], zero: Observer[B], n: Int)(op: (Observer[B],Observer[A])=>Observer[B]): Observer[B] + +} + +object CpsLogicMonad { + + type Aux[M[+_],F[_]] = CpsLogicMonad[M] { + type Observer[T] = F[T] + } + +} + + +trait CpsLogicMonadContext[M[+_]] extends CpsTryMonadContext[M] { + + override def monad: CpsLogicMonad[M] + +} + +class CpsLogicMonadInstanceContextBody[M[+_]](m:CpsLogicMonad[M]) extends CpsLogicMonadContext[M] { + override def monad: CpsLogicMonad[M] = m +} + +trait CpsLogicMonadInstanceContext[M[+_]] extends CpsLogicMonad[M] { + + override type Context = CpsLogicMonadInstanceContextBody[M] + + override def apply[T](op: CpsLogicMonadInstanceContextBody[M] => M[T]): M[T] = { + op(new CpsLogicMonadInstanceContextBody[M](this)) + } + +} + +def all[M[+_],A](collection: IterableOnce[A])(using m:CpsLogicMonad[M]): M[A] = + def allIt(it: Iterator[A]): M[A] = + if (it.hasNext) { + m.mplus(m.pure(it.next()), allIt(it)) + } else { + m.mzero + } + allIt(collection.iterator) + + +extension [M[+_],A](ma: M[A])(using m:CpsLogicMonad[M]) + + def filter(p: A => Boolean): M[A] = + m.flatMap(m.msplit(ma)) { sc => + sc match + case None => m.mzero + case Some((ta, sa)) => + ta match + case Success(a) => + if (p(a)) { + m.mplus(m.pure(a), sa.filter(p)) + } else { + sa.filter(p) + } + case Failure(ex) => + m.error(ex) + } + + def observeN(n: Int): m.Observer[Seq[A]] = + m.mObserveN(ma,n) diff --git a/shared/src/test/scala/logic/QueensTest.scala b/shared/src/test/scala/logic/QueensTest.scala new file mode 100644 index 00000000..736dbcd8 --- /dev/null +++ b/shared/src/test/scala/logic/QueensTest.scala @@ -0,0 +1,51 @@ +package logic + +import cps.* +import logic.logict.{*,given} +import org.junit.{Ignore, Test} + + + +class QueensTest { + + @Test + def testLogicM() = { + import QueensTest.* + val r = queens[LogicM](8).observeN(2) + // observer monad is identity monad here. + + assert(r.size == 2) + assert(QueensTest.isCorrect(r(0))) + assert(QueensTest.isCorrect(r(1))) + } + +} + +object QueensTest { + + case class Pos(x:Int, y:Int) + + def isBeat(p1:Pos, p2:Pos):Boolean = + (p1.x == p2.x) || (p1.y == p2.y) || (p1.x - p1.y == p2.x - p2.y) || (p1.x + p1.y == p2.x + p2.y) + + def isFree(p:Pos, prefix:IndexedSeq[Pos]):Boolean = + prefix.forall(pp => !isBeat(p, pp)) + + def queens[M[+_]:CpsLogicMonad](n:Int, prefix:IndexedSeq[Pos]=IndexedSeq.empty): M[IndexedSeq[Pos]] = reify[M] { + if (prefix.length >= n) then + prefix + else + val nextPos = (1 to n).map(Pos(prefix.length+1,_)).filter(pos => isFree(pos, prefix)) + reflect(queens(n, prefix :+ reflect(all(nextPos)))) + } + + def isCorrect(queens:IndexedSeq[Pos]):Boolean = { + queens.forall(p1 => queens.forall(p2 => + (p1 == p2) || !isBeat(p1,p2) + )) + } + + + + +} \ No newline at end of file diff --git a/shared/src/test/scala/logic/backtrm/BacktrT.scala b/shared/src/test/scala/logic/backtrm/BacktrT.scala new file mode 100644 index 00000000..455ae0e7 --- /dev/null +++ b/shared/src/test/scala/logic/backtrm/BacktrT.scala @@ -0,0 +1,117 @@ +package logic.backtrm + +//Term implementation of backtracking monad transformer +// from Deriving Backtracking Monad Transformers by Ralf Hinze +// https://dl.acm.org/doi/10.1145/351240.351258 +// +// changed to support flatMapTry + +import cps.* +import cps.syntax.* +import scala.util.* +import scala.util.control.NonFatal + + + +sealed trait BacktrT[M[_],A] { + + def map[B](f: A=>B): BacktrT[M,B] + def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B] + def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B] + def append(next: =>BacktrT[M,A]): BacktrT[M,A] + + +} + +object BacktrT{ + + case class Zero[M[_],A]() extends BacktrT[M,A] { + override def map[B](f: A=>B): Zero[M,B] = Zero() + override def flatMap[B](f: A=>BacktrT[M,B]): Zero[M,B] = Zero() + override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): Zero[M,B] = Zero() + override def append(next: =>BacktrT[M, A]): BacktrT[M, A] = next + + } + + case class Cons[M[_]:CpsTryMonad,A](h: A, t: () => BacktrT[M,A]) extends BacktrT[M,A] { + + def map[B](f: A=>B): Cons[M,B] = Cons(f(h), () => t().map(f)) + + def flatMap[B](f: A=>BacktrT[M,B]): BacktrT[M,B] = + f(h).append(t().flatMap(f)) + + override def append(next: =>BacktrT[M, A]): BacktrT[M, A] = + Cons(h, () => t().append(next)) + + override def flatMapTry[B](f: Try[A]=>BacktrT[M,B]): BacktrT[M,B] = + try { + f(Success(h)).append(t().flatMapTry(f)) + } catch { + case NonFatal(ex) => + PromoteBind(summon[CpsTryMonad[M]].error(ex), { + case Success(_) => Success(Zero()) // impossible + case Failure(ex) => Failure(ex) + }) + } + + } + + + case class PromoteBind[M[_],A,B](ma: M[A], transform: Try[A]=>Try[BacktrT[M,B]]) extends BacktrT[M,B] { + + def map[C](f: B=>C): PromoteBind[M,A,C] = PromoteBind(ma, + { case Success(a) => transform(Success(a)).map(_.map(f)) + case Failure(ex) => Failure(ex) + } + ) + + def flatMap[C](g: B=>BacktrT[M,C]): BacktrT[M,C] = PromoteBind(ma, { + case Success(a) => transform(Success(a)).map(_.flatMap(g)) + case Failure(ex) => Failure(ex) + }) + + override def append(next: => BacktrT[M, B]): BacktrT[M, B] = + PromoteBind(ma, { + case Success(a) => transform(Success(a)).map(_.append(next)) + case Failure(ex) => Failure(ex) + }) + + + override def flatMapTry[C](g: Try[B]=>BacktrT[M,C]): BacktrT[M,C] = + PromoteBind( + ma, { x => + transform(x) match + case Success(bm) => Success(bm.flatMapTry(g)) + case Failure(ex) => Success(g(Failure(ex))) + } + ) + + + } + + + + class CpsBacktrTryMonad[M[_]:CpsTryMonad] extends CpsTryMonad[[X]=>>BacktrT[M,X]] with CpsTryMonadInstanceContext[[X]=>>BacktrT[M,X]] { + + override def pure[A](a: A): BacktrT[M, A] = Cons(a,()=>Zero()) + + override def map[A,B](fa: BacktrT[M,A])(f: A=>B): BacktrT[M,B] = + fa.map(f) + + override def flatMap[A,B](fa: BacktrT[M,A])(f: A=>BacktrT[M,B]): BacktrT[M,B] = + fa.flatMap(f) + + override def error[A](e: Throwable): BacktrT[M, A] = + PromoteBind(summon[CpsTryMonad[M]].error(e), { + case Success(_) => Failure(e) // impossible + case Failure(ex) => Failure(ex) + }) + + override def flatMapTry[A,B](fa: BacktrT[M,A])(f: Try[A]=> BacktrT[M,B]): BacktrT[M,B] = + fa.flatMapTry(f) + + + + } + +} \ No newline at end of file diff --git a/shared/src/test/scala/logic/logict/LogicT.scala b/shared/src/test/scala/logic/logict/LogicT.scala index 7236b9d0..0b069c81 100644 --- a/shared/src/test/scala/logic/logict/LogicT.scala +++ b/shared/src/test/scala/logic/logict/LogicT.scala @@ -1,10 +1,180 @@ package logic.logict +import cps.* +import cps.monads.{*, given} + +import scala.util.* +import logic.* + +import scala.annotation.tailrec + +type LogicM[A] = LogicT[CpsIdentity,A] + + +given logicMonadM: LogicT.LogicSKFKMonad[CpsIdentity] = LogicT.logicMonadT[CpsIdentity] + + /** * Dependency-less port of haskell's LogicT monad transformer. */ -trait LogicT[F[_],A] { +type LogicT[F[_],+A] = LogicT.LogicCallbackAcceptor[F,A] + +object LogicT { + + + type SuccessContinuation[F[_],A,R] = Try[A]=>F[R]=>F[R] + + + trait LogicCallbackAcceptor[F[_],+A] { + /** + * + * @param sk - success continuation (like k wih continuation monad and s for succee) + * @param fk - failure continuation. + * @tparam A - type of value on the time of performing continuation + * @return + */ + def apply[R](sk: SuccessContinuation[F,A,R]): F[R] => F[R] + } + + + + class LogicSKFKMonad[F[_]:CpsTryMonad] extends CpsLogicMonad[[X]=>>LogicCallbackAcceptor[F,X]] + with CpsLogicMonadInstanceContext[[X]=>>LogicCallbackAcceptor[F,X]] { + + override type Observer[A] = F[A] + + override def pure[A](a:A): LogicCallbackAcceptor[F,A] = new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): F[R] => F[R] = + sk(Success(a)) + } + + override def error[A](ex: Throwable): LogicCallbackAcceptor[F,A] = new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): F[R] => F[R] = + sk(Failure(ex)) + } + + override def map[A, B](fa: LogicCallbackAcceptor[F, A])(f: A => B): LogicCallbackAcceptor[F, B] = + new LogicCallbackAcceptor[F,B] { + override def apply[R](sk: SuccessContinuation[F,B,R]): F[R] => F[R] = + fa.apply[R]{ + case Success(a) => sk(Success(f(a))) + case Failure(ex) => sk(Failure(ex)) + } + } + + override def flatMap[A, B](fa: LogicCallbackAcceptor[F, A])(f: A => LogicCallbackAcceptor[F, B]): LogicCallbackAcceptor[F, B] = { + flatMapTry(fa) { + case Success(a) => f(a) + case Failure(ex) => error(ex) + } + } + + override def flatMapTry[A,B](fa: LogicCallbackAcceptor[F,A])(f: Try[A] => LogicCallbackAcceptor[F,B]): LogicCallbackAcceptor[F,B] = + new LogicCallbackAcceptor[F,B] { + override def apply[R](sk: SuccessContinuation[F,B,R]): F[R] => F[R] = + fa.apply[R]( + { + case Success(a) => f(Success(a)).apply(sk) + case Failure(ex) => sk(Failure(ex)) + } + ) + } + + override def mzero: LogicCallbackAcceptor[F, Nothing] = + new LogicCallbackAcceptor[F,Nothing] { + override def apply[R](sk: SuccessContinuation[F,Nothing,R]): (F[R] => F[R]) = + identity[F[R]] + } + + override def mplus[A](a: LogicCallbackAcceptor[F, A], b: LogicCallbackAcceptor[F, A]): LogicCallbackAcceptor[F, A] = + new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): (F[R] => F[R]) = + fk => a.apply(sk)(b.apply(sk)(fk)) + } + + def lift[A](fa:F[A]): LogicCallbackAcceptor[F,A] = + new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): F[R] => F[R] = + fk => summon[CpsMonad[F]].flatMapTry(fa)(ta => sk(ta)(fk)) + } + + + + override def msplit[A](c: LogicCallbackAcceptor[F, A]): LogicCallbackAcceptor[F, Option[(Try[A], LogicCallbackAcceptor[F, A])]] = { + val fcr = c.apply[Option[(Try[A],LogicCallbackAcceptor[F,A])]]( + { c => cfk => + val next = flatMap(lift(cfk)) { + case None => (mzero: LogicCallbackAcceptor[F,A]) + case Some((ta,sa)) => new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): (F[R] => F[R]) = + fk => sk(ta)(sa.apply(sk)(fk)) + } + } + summon[CpsMonad[F]].pure(Some((c,next))) + } + )( summon[CpsMonad[F]].pure(None) ) + lift(fcr) + } + + def prepend[A](a: A, c: LogicCallbackAcceptor[F, A]): LogicCallbackAcceptor[F, A] = + new LogicCallbackAcceptor[F,A] { + override def apply[R](sk: SuccessContinuation[F,A,R]): (F[R] => F[R]) = + fk => sk(Success(a))(c.apply(sk)(fk)) + } + + override def observerCpsMonad: CpsTryMonad[F] = summon[CpsTryMonad[F]] + + override def mObserveOne[A](m:LogicCallbackAcceptor[F,A]): F[Option[A]] = { + msplit(m).apply[Option[A]] { + case Success(v) => v match + case None => fk => fk + case Some((ta,sa)) => fk => + ta match + case Success(a) => observerCpsMonad.pure(Some(a)) + case Failure(ex) => observerCpsMonad.error(ex) + case Failure(ex) => fk => summon[CpsTryMonad[F]].error(ex) + }(summon[CpsMonad[F]].pure(None)) + } + + override def mFoldLeftN[A, B](ma: LogicCallbackAcceptor[F, A], zero: F[B], n: Int)(op: (F[B], F[A]) => F[B]): F[B] = { + if (n<=0) then + zero + else + msplit(ma).apply[B] { + case Success(v) => v match + case None => fk => fk + case Some((ta,sa)) => + ta match + case Success(a) => + fk => mFoldLeftN(sa, op(fk,observerCpsMonad.pure(a)), n-1)(op) + case Failure(ex) => + fk => observerCpsMonad.error(ex) + }(zero) + } + + override def mFoldLeft[A, B](ma: LogicCallbackAcceptor[F, A], zero: F[B], op: (F[B], F[A]) => F[B]): F[B] = { + msplit(ma).apply[B] { + case Success(v) => v match + case None => fk => fk + case Some((ta,sa)) => fk => + ta match + case Success(a) => + mFoldLeft(sa, op(fk,observerCpsMonad.pure(a)), op) + case Failure(ex) => + observerCpsMonad.error(ex) + case Failure(ex) => fk => observerCpsMonad.error(ex) + }(zero) + } + + } + + + + //given logicMonadSKF[F[_]:CpsTryMonad]: CpsLogicMonad[[X]=>>LogicCallbackAcceptor[F,X]] = + // LogicSKFKMonad[F]() + given logicMonadT[F[_] : CpsTryMonad]: LogicSKFKMonad[F] = + LogicSKFKMonad[F]() }