diff --git a/src/core/Context.scala b/src/core/Context.scala
new file mode 100644
index 0000000..32380d4
--- /dev/null
+++ b/src/core/Context.scala
@@ -0,0 +1,75 @@
+package mutatus
+
+import com.google.cloud.datastore.{DatastoreReader, DatastoreWriter, Key, FullEntity, DatastoreException}
+
+sealed trait Context {
+  implicit val service: Service
+}
+
+object Context {
+  implicit def default(implicit svc: Service): Default = Default(svc)
+
+  sealed trait ReadApi {
+    self: Context =>
+    val read: DatastoreReader
+  }
+  sealed trait WriteApi {
+    self: Context =>
+    val write: DatastoreWriter
+    def saveAll(entities: Traversable[FullEntity[_]]): Result[Unit]
+    def deleteAll(keys: Traversable[Key]): Result[Unit]
+  }
+
+  /**
+   * Default context used to perform non-batched operations.
+   */
+  case class Default(service: Service)
+      extends Context
+      with ReadApi
+      with WriteApi {
+    val read: DatastoreReader = service.datastore
+    val write: DatastoreWriter = service.datastore
+    def deleteAll(keys: Traversable[Key]): Result[Unit] = Result {
+      val batch = service.datastore.newBatch()
+      batch.delete(keys.toList: _*)
+      batch.submit()
+    }
+    def saveAll(entities: Traversable[FullEntity[_]]): Result[Unit] = Result {
+      val batch = service.datastore.newBatch()
+      batch.put(entities.toList: _*)
+      batch.submit()
+    }
+  }
+
+  /**
+   * Context used to perform batched operations using the Datastore Transactions API.
+   */
+  private[mutatus] case class Transaction(service: Service)
+      extends Context
+      with ReadApi
+      with WriteApi {
+    val tx = service.datastore.newTransaction()
+    val read: DatastoreReader = tx
+    val write: DatastoreWriter = tx
+    def deleteAll(keys: Traversable[Key]): Result[Unit] =
+      Result(write.delete(keys.toList: _*))
+    def saveAll(entities: Traversable[FullEntity[_]]): Result[Unit] =
+      Result(write.put(entities.toList: _*))
+  }
+
+  /**
+   * Context used for batched operations using the Datastore Batch API. It enables only write operations.
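+   *
+   * A minimal usage sketch (`FooBar` is an illustrative entity type, not part of this change):
+   * {{{
+   * Dao.batch { implicit batch =>
+   *   List(FooBar(1), FooBar(2)).saveAll()
+   * }
+   * }}}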
+   */
+  private[mutatus] case class Batch(service: Service)
+      extends Context
+      with WriteApi {
+    val batch = service.datastore.newBatch()
+    val write: DatastoreWriter = batch
+
+    def deleteAll(keys: Traversable[Key]): Result[Unit] =
+      Result(write.delete(keys.toList: _*))
+    def saveAll(entities: Traversable[FullEntity[_]]): Result[Unit] =
+      Result(write.put(entities.toList: _*))
+  }
+}
diff --git a/src/core/MutatusException.scala b/src/core/MutatusException.scala
index 7112133..2402be1 100644
--- a/src/core/MutatusException.scala
+++ b/src/core/MutatusException.scala
@@ -8,3 +8,4 @@
 case class NotSavedException(kind: String)
     extends MutatusException(
       s"Entity of type $kind cannot be deleted becasue it has not been saved"
     )
+case class InactiveTransactionException(tx: com.google.cloud.datastore.Transaction) extends MutatusException(s"Transaction ${tx.getTransactionId()} was already committed or timed out")
\ No newline at end of file
diff --git a/src/core/QueryBuilder.scala b/src/core/QueryBuilder.scala
index 80bf531..9987edb 100644
--- a/src/core/QueryBuilder.scala
+++ b/src/core/QueryBuilder.scala
@@ -57,10 +57,10 @@ case class QueryBuilder[T] private[mutatus] (

   /** Materializes query and returns Stream of entities for GCP Storage */
   def run()(
-      implicit svc: Service = Service.default,
+      implicit ctx: Context with Context.ReadApi,
       namespace: Namespace,
       decoder: Decoder[T]
-  ): mutatus.Result[Stream[mutatus.Result[T]]] = {
+  ): Stream[mutatus.Result[T]] = {
     val baseQuery = namespace.option.foldLeft(
       Query.newEntityQueryBuilder().setKind(kind)
     )(_.setNamespace(_))
@@ -72,15 +72,18 @@ case class QueryBuilder[T] private[mutatus] (
     val withOffset = offset.foldLeft(limited)(_.setOffset(_))
     val query = withOffset.build()

-    for {
-      results <- mutatus.Result(svc.read.run(query))
-      entities = new Iterator[Entity] {
+    Result {
+      val results = ctx.read.run(query)
+      new Iterator[Entity] {
         def next(): Entity = results.next()
         def hasNext: Boolean = results.hasNext
-      }.toStream
-    } yield entities.map(decoder.decode)
-  }.extenuate {
-    case exc: DatastoreException => DatabaseException(exc)
+      }.toStream.map(decoder.decode)
+    }.extenuate {
+      case exc: DatastoreException => DatabaseException(exc)
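+      // The match below re-wraps failures as a single-element stream so that
+      // run() can keep the plain Stream[mutatus.Result[T]] return type.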
+    } match {
+      case Answer(entities) => entities
+      case Error(error)     => Stream(Error(error))
+      case Surprise(error)  => Stream(Surprise(error))
+    }
   }
 }
diff --git a/src/core/mutatus.scala b/src/core/mutatus.scala
index c0d3b57..329b742 100644
--- a/src/core/mutatus.scala
+++ b/src/core/mutatus.scala
@@ -23,9 +23,9 @@ import mercator._
 import scala.annotation.StaticAnnotation
 import scala.collection.JavaConverters._
 import scala.collection.generic.CanBuildFrom
-import scala.language.experimental.macros
-import io.opencensus.trace.Status.CanonicalCode
 import com.google.rpc.Code
+import scala.reflect.macros.blackbox
+import language.experimental.macros

 /** Mutatus package object */
 object `package` extends Domain[MutatusException] {
@@ -35,49 +35,42 @@ object `package` extends Domain[MutatusException] {

     /** saves the all case class as a Datastore entity in batch mode */
     def saveAll()(
-        implicit svc: Service,
+        implicit ctx: Context with Context.WriteApi,
         encoder: Encoder[T],
         dao: Dao[T],
         idField: IdField[T]
     ): mutatus.Result[Map[T, Ref[T]]] = {
-      val (batch, refs) =
-        values.foldLeft(svc.readWrite.newBatch() -> Map.empty[T, Ref[T]]) {
-          case ((b, entityRefs), value) =>
-            val entity = value.buildEntity()
-            b.put(value.buildEntity())
-            b -> entityRefs.updated(value, new Ref[T](entity.getKey))
-        }
-
-      Result(batch.submit())
-        .extenuate {
-          case ex: DatastoreException => DatabaseException(ex)
-        }
-        .map(_ => refs) //batch may supply only those keys which were generated by Datastore, but in mutatus generates keys deterministically based on entity content
+      for {
+        encodingResult <- Result {
+          for {
+            value <- values
+            entity = value.buildEntity()
+            ref = new Ref[T](entity.getKey())
+          } yield entity -> (value -> ref)
+        }
+        (encodedEntities, valueRefs) = encodingResult.unzip
+        _ <- ctx.saveAll(encodedEntities.toList)
+      } yield valueRefs.toMap
+    }.extenuate {
+      case ex: DatastoreException => DatabaseException(ex)
     }

     /** deletes the Datastore entities in batch mode*/
     def deleteAll()(
-        implicit svc: Service,
+        implicit ctx: Context with Context.WriteApi,
         dao: Dao[T],
         idField: IdField[T]
     ): mutatus.Result[Unit] = Result {
-      values
-        .foldLeft(svc.readWrite.newBatch()) {
-          case (b, value) =>
-            b.delete(idField.idKey(idField.key(value)).newKey(dao.keyFactory))
-            b
-        }
-        .submit()
-      ()
-    }.extenuate {
-      case ex: DatastoreException =>
-        import google.rpc._
-        ex.getCode match {
-          case Code.NOT_FOUND_VALUE => NotSavedException(dao.kind)
-          case _ => DatabaseException(ex)
-        }
-    }
+      for {
+        value <- values
+        idKey = idField.idKey(idField.key(value))
+      } yield idKey.newKey(dao.keyFactory)
+    }.flatMap(ctx.deleteAll)
+      .extenuate {
+        case ex: DatastoreException => DatabaseException(ex)
+      }
   }

   /** provides `save` and `delete` methods on case class instances */
@@ -99,14 +92,14 @@ object `package` extends Domain[MutatusException] {

     /** saves the case class as a Datastore entity */
     def save()(
-        implicit svc: Service,
+        implicit ctx: Context with Context.WriteApi,
         encoder: Encoder[T],
         dao: Dao[T],
         idField: IdField[T]
     ): mutatus.Result[Ref[T]] = Result {
       new Ref[T](
-        svc.readWrite.put(buildEntity()).getKey
+        ctx.write.put(buildEntity()).getKey
       )
     }.extenuate {
       case exc: DatastoreException => DatabaseException(exc)
@@ -114,12 +107,12 @@
     }

     /** deletes the Datastore entity with this ID */
     def delete()(
-        implicit svc: Service,
+        implicit ctx: Context with Context.WriteApi,
         dao: Dao[T],
         idField: IdField[T]
     ): mutatus.Result[Unit] = Result {
-      svc.readWrite.delete(
+      ctx.write.delete(
         idField.idKey(idField.key(value)).newKey(dao.keyFactory)
       )
     }.extenuate {
@@ -154,8 +147,11 @@ final class id() extends StaticAnnotation
 case class Ref[T](ref: Key) {

   /** resolves the reference and returns a case class instance */
-  def apply()(implicit svc: Service, decoder: Decoder[T]): Result[T] =
-    decoder.decode(svc.read.get(ref))
+  def apply()(
+      implicit ctx: Context with Context.ReadApi,
+      decoder: Decoder[T]
+  ): Result[T] =
+    decoder.decode(ctx.read.get(ref))

   override def toString: String = s"$Ref[${ref.getKind}]($key)"

   /** a `String` version of the key contained by this reference */
@@ -181,9 +177,7 @@ case class Geo(lat: Double, lng: Double) {
 }

 /** a representation of the GCP Datastore service */
-case class Service(readWrite: Datastore) {
-  def read: DatastoreReader = readWrite
-}
+case class Service(datastore: Datastore)

 object Service {

@@ -291,13 +285,13 @@ case class LongId(id: Long) extends IdKey {

 /** a data access object for a particular type */
 case class Dao[T](kind: String)(
-    implicit svc: Service,
+    implicit ctx: Context,
     namespace: Namespace,
     decoder: Decoder[T]
 ) {

   private[mutatus] lazy val keyFactory = {
-    val baseFactory = svc.readWrite.newKeyFactory().setKind(kind)
+    val baseFactory = ctx.service.datastore.newKeyFactory().setKind(kind)
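+    // Apply the implicit namespace, if present, so that every key produced by
+    // this factory is created in the correct Datastore namespace.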
     namespace.option.foldLeft(baseFactory)(_.setNamespace(_))
   }

@@ -306,9 +300,9 @@ case class Dao[T](kind: String)(

   def unapply[R](id: R)(implicit idField: IdField[T] {
     type Return = R
-  }): Option[Result[T]] = {
+  }, ctx: Context with Context.ReadApi): Option[Result[T]] = {
     val key = idField.idKey(id).newKey(keyFactory)
-    Option(svc.read.get(key))
+    Option(ctx.read.get(key))
       .map(decoder.decode)
   }
 }

@@ -319,8 +313,90 @@ object Dao {
     implicit metadata: TypeMetadata[T],
     decoder: Decoder[T],
     namespace: Namespace,
-    service: Service
+    ctx: Context
   ): Dao[T] = Dao(metadata.typeName)
+
+  /**
+   * Executes operations from the tx body using the Datastore Transaction API.
+   * If all operations succeed, the transaction is automatically committed;
+   * otherwise it is rolled back.
+   *
+   * @param tx Body of the transaction to be performed. The input parameter must
+   *           be used as an implicit value in order to work.
+   *
+   * Example of usage (`FooBar` is an illustrative entity type):
+   * {{{
+   * Dao.transaction { implicit tx =>
+   *   val foos = Dao[FooBar].all.filter(_.bar == "bar").run()
+   *   foos.collect { case Answer(foo) => foo.copy(bar = "foo") }.saveAll()
+   * }
+   * }}}
+   */
+  def transaction[T](
+      tx: Context.Transaction => Result[T]
+  )(implicit svc: Service): Result[T] = {
+    val ctx = Context.Transaction(svc)
+    tx(ctx)
+      .flatMap { result =>
+        Result(ctx.tx.commit())
+          .map(_ => result)
+          .extenuate {
+            case ex: DatastoreException => DatabaseException(ex)
+          }
+      }
+      .extenuate {
+        case ex: DatastoreException =>
+          if (ctx.tx.isActive) {
+            ctx.tx.rollback()
+          }
+          DatabaseException(ex)
+      }
+  }
+
+  private object DaoMacro {
+
+    def batchImpl[T: c.WeakTypeTag](
+        c: blackbox.Context
+    )(batch: c.Tree)(svc: c.Tree): c.Tree = {
+      import c.universe._
+      val q"($arg) => $fnBody" = batch
+      // Usage of the external (default) context proves that the provided Context.Batch
+      // was not sufficient to perform the operation. Such a construct could be considered
+      // a dirty hack, but I have not found a better way to detect usage of the default context.
+      val usesExternalContext = fnBody.exists(showCode(_).contains("Context.default"))
+      if (usesExternalContext)
+        c.abort(
+          c.enclosingPosition,
+          "mutatus: Read operations within Batch are prohibited"
+        )
+
+      q"""{
+        val ctx = _root_.mutatus.Context.Batch($svc)
+        $batch(ctx).flatMap { result =>
+          _root_.mutatus.Result(ctx.batch.submit()).map(_ => result)
+        }.extenuate {
+          case ex: _root_.com.google.cloud.datastore.DatastoreException => _root_.mutatus.DatabaseException(ex)
+        }
+      }"""
+    }
+  }
+
+  /**
+   * Executes write-only operations using the Datastore Batch API.
+   * If none of the operations inside the batch function fail, the batch is
+   * automatically submitted.
+   *
+   * @param batch Body of the batch to be performed. The input parameter must be
+   *              used as an implicit value in order to work.
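+   *              Read operations inside the body are rejected at compile time:
+   *              `DaoMacro.batchImpl` aborts compilation when the body falls back
+   *              to the default, read-capable context.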
+   *
+   * Example of usage (`FooBar` is an illustrative entity type):
+   * {{{
+   * Dao.batch { implicit batch =>
+   *   for {
+   *     x  <- FooBar(1).save()
+   *     xs <- List(FooBar(2), FooBar(3)).saveAll()
+   *     y  <- FooBar(0).delete()
+   *   } yield xs
+   * }
+   * }}}
+   */
+  def batch[T](batch: Context.Batch => Result[T])(
+      implicit svc: Service
+  ): Result[T] = macro DaoMacro.batchImpl[T]
 }

 /** companion object for `Decoder`, including Magnolia generic derivation */
diff --git a/src/tests/EndToEndSpec.scala b/src/tests/EndToEndSpec.scala
index d66b113..ce458aa 100644
--- a/src/tests/EndToEndSpec.scala
+++ b/src/tests/EndToEndSpec.scala
@@ -114,24 +114,19 @@ case class EndToEndSpec()(implicit runner: Runner) {

   test("fetch entities - simple")(
     Dao[TestSimpleEntity].all.run()
-  ).assert {
-    case Answer(value) =>
-      value.toVector == (simpleEntities ++ batchedSimpleEntities)
-        .sortBy(_.id)
-        .map(mutatus.Answer(_))
-    case _ => false
-
-  }
+  ).assert(
+    _.toVector == (simpleEntities ++ batchedSimpleEntities)
+      .sortBy(_.id)
+      .map(mutatus.Answer(_))
+  )

   test("fetch entities - complex - long id")(
     Dao[TestComplexLongId].all.run()
-  ).assert {
-    case Answer(value) =>
-      value.toVector == longIdComplexEntities.toVector
-        .sortBy(_.id)
-        .map(Answer.apply)
-    case _ => false
-  }
+  ).assert(
+    _.toVector == longIdComplexEntities.toVector
+      .sortBy(_.id)
+      .map(Answer.apply)
+  )

   simpleEntities.take(3).foreach { e =>
     test("fetch entity by id - simple")(Dao[TestSimpleEntity].unapply(e.id))
@@ -168,20 +163,18 @@ case class EndToEndSpec()(implicit runner: Runner) {
       .drop(1)
       .take(2)
       .run()
-      .map(_.toList)
+      .toList
   }.assert(
-    _ == Answer(
-      longIdComplexEntities
-        .filter(_.innerOpt.exists(_.int >= 2))
-        .filter(_.innerOpt.exists(_.int <= 8))
-        .sortBy(_.innerOpt.map(_.int).getOrElse(-1))(
-          implicitly[Ordering[Int]].reverse
-        )
-        .drop(1)
-        .take(2)
-        .toList
-        .map(Answer(_))
-    )
+    _ == longIdComplexEntities
+      .filter(_.innerOpt.exists(_.int >= 2))
+      .filter(_.innerOpt.exists(_.int <= 8))
+      .sortBy(_.innerOpt.map(_.int).getOrElse(-1))(
+        implicitly[Ordering[Int]].reverse
+      )
+      .drop(1)
+      .take(2)
+      .toList
+      .map(Answer(_))
   )

   simpleEntities.take(5).foreach { entity =>
@@ -196,6 +189,135 @@ case class EndToEndSpec()(implicit runner: Runner) {
     }.assert(_.contains(Answer(updated)))
   }

+  {
+    def fetchEntities = Dao[TestSimpleEntity].all.run()
+    val simpleEntitiesBefore = fetchEntities
+    val error = Error(SerializationException("Mocked up problem"))
+
+    test("rolls back transactions in case of failure") {
+      Dao.transaction { implicit tx =>
+        for {
+          updated <- fetchEntities
+            .collect {
+              case Answer(value) => value.copy(longParam = value.longParam + 1L)
+            }
+            .saveAll()
+          _ <- error
+        } yield updated
+      }
+    }.assert(result => {
+      val stateAfter = fetchEntities
+      (simpleEntitiesBefore, result, stateAfter) match {
+        case (before, `error`, after) =>
+          before.toList == after.toList
+        case _ => false
+      }
+    })
+  }
+
+  {
+    def fetchEntities = Dao[TestSimpleEntity].all.run()
+    test("commits transactions in case of success") {
+      Dao.transaction { implicit tx =>
+        for {
+          updated <- fetchEntities
+            .collect {
+              case Answer(value) => value.copy(longParam = value.longParam + 2L)
+            }
+            .saveAll()
+        } yield updated
+      }
+    }.assert(result => {
+      val stateAfter = fetchEntities
+        .collect {
+          case Answer(value) => value
+        }
+        .toList
+        .sortBy(_.id)
+
+      (stateAfter, result.map(_.keys.toList.sortBy(_.id))) match {
+        case (expected, Answer(actual)) =>
+          val passed = expected == actual
+          if (!passed) {
+            println(s"""
+            Expected: ${expected}
+            Actual: ${actual}
+            """)
+          }
+          passed
+        case (expected, actual) =>
+          println(s"""
+          Expected: ${expected}
+          Actual: ${actual}
+          """)
+          false
+      }
+    })
+  }
+
+  def fetchSimpleEntities() = Dao[TestSimpleEntity].all.run()
+
+  {
+    test("submits batches") {
+      val stateBefore = fetchSimpleEntities()
+      Dao.batch { implicit batch =>
+        for {
+          x <- Result {
+            stateBefore
+              .take(5)
+              .collect {
+                case Answer(value) =>
+                  val updated =
+                    value.copy(decimalParam = value.decimalParam + 1.0)
+                  updated -> updated.save()
+              }
+              .toMap
+          }
+          y <- stateBefore
+            .drop(5)
+            .collect {
+              case Answer(value) =>
+                value.copy(decimalParam = value.decimalParam - 1.0)
+            }
+            .saveAll()
+        } yield x ++ y
+      }
+    }.assert(result => {
+      val actualState = fetchSimpleEntities().collect {
+        case Answer(value) => value
+      }
+
+      (result, actualState) match {
+        case (Answer(expected), actual) =>
+          expected.keySet.toList.sortBy(_.id) == actual.toList
+        case _ => false
+      }
+    })
+
+    val stateBefore = fetchSimpleEntities()
+    val error = Error(SerializationException("Mock up"))
+    test("does not submit erroneous batches") {
+      Dao.batch { implicit batch =>
+        for {
+          updated <- Result {
+            stateBefore.collect {
+              case Answer(value) =>
+                value.copy(decimalParam = value.decimalParam + Math.PI)
+            }
+          }
+          _ <- updated.saveAll()
+          _ <- error
+        } yield updated
+      }
+    }.assert(result =>
+      (result, stateBefore, fetchSimpleEntities()) match {
+        case (`error`, stateBefore, actualState) =>
+          stateBefore.nonEmpty && stateBefore.toList == actualState.toList
+        case _ => false
+      }
+    )
+  }
+
   simpleEntities.foreach { e =>
     test("removes entities") {
       e.delete()
@@ -237,10 +359,7 @@ case class EndToEndSpec()(implicit runner: Runner) {
   test("removes entities in batch mode") {
     batchedSimpleEntities.deleteAll()
     Dao[TestSimpleEntity].all.run()
-  }.assert {
-    case Answer(result) => result.isEmpty
-    case _ => false
-  }
+  }.assert(_.isEmpty)

   test("removed everything") {
     List(
       Dao[TestComplexLongId].all.run(),
       Dao[TestComplexStringId].all.run(),
       Dao[TestComplexGuid].all.run()
     )
-  }.assert(_.forall {
-    case Answer(result) => result.isEmpty
-    case other =>
-      println(other)
-      false
-  })
-
+  }.assert(_.forall(_.isEmpty))
 }

 object EndToEndSpec {