diff --git a/src/main/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepository.scala b/src/main/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepository.scala index baf9ee95d..24c73a160 100644 --- a/src/main/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepository.scala +++ b/src/main/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepository.scala @@ -88,11 +88,11 @@ class ConsignmentRepository(db: Database, timeSource: TimeSource) { db.run(query.result) } - def updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput: ConsignmentFields.UpdateConsignmentSeriesIdInput): Future[Int] = { + def updateSeriesOfConsignment(updateConsignmentSeriesIdInput: ConsignmentFields.UpdateConsignmentSeriesIdInput, seriesName: Option[String]): Future[Int] = { val update = Consignment .filter(_.consignmentid === updateConsignmentSeriesIdInput.consignmentId) - .map(t => t.seriesid) - .update(Some(updateConsignmentSeriesIdInput.seriesId)) + .map(t => (t.seriesid, t.seriesname)) + .update(Some(updateConsignmentSeriesIdInput.seriesId), seriesName) db.run(update) } diff --git a/src/main/scala/uk/gov/nationalarchives/tdr/api/graphql/fields/ConsignmentFields.scala b/src/main/scala/uk/gov/nationalarchives/tdr/api/graphql/fields/ConsignmentFields.scala index 6d0f14796..d8fc0c95e 100644 --- a/src/main/scala/uk/gov/nationalarchives/tdr/api/graphql/fields/ConsignmentFields.scala +++ b/src/main/scala/uk/gov/nationalarchives/tdr/api/graphql/fields/ConsignmentFields.scala @@ -309,7 +309,7 @@ object ConsignmentFields { "updateConsignmentSeriesId", OptionType(IntType), arguments = UpdateConsignmentSeriesIdArg :: Nil, - resolve = ctx => ctx.ctx.consignmentService.updateSeriesIdOfConsignment(ctx.arg(UpdateConsignmentSeriesIdArg)), + resolve = ctx => ctx.ctx.consignmentService.updateSeriesOfConsignment(ctx.arg(UpdateConsignmentSeriesIdArg)), tags = List(ValidateUserHasAccessToConsignment(UpdateConsignmentSeriesIdArg), ValidateUpdateConsignmentSeriesId) ) ) diff --git a/src/main/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentService.scala b/src/main/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentService.scala index 857989575..ddead0a91 100644 --- a/src/main/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentService.scala +++ b/src/main/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentService.scala @@ -73,8 +73,7 @@ class ConsignmentService( for { sequence <- consignmentRepository.getNextConsignmentSequence body <- transferringBodyService.getBodyByCode(userBody) - series <- if (seriesId.isDefined) seriesRepository.getSeries(seriesId.get) else Future(Seq()) - seriesName = if (series.nonEmpty) Some(series.head.name) else None + seriesName <- getSeriesName(seriesId) consignmentRef = ConsignmentReference.createConsignmentReference(yearNow, sequence) consignmentId = uuidSource.uuid consignmentRow = ConsignmentRow( @@ -123,9 +122,10 @@ class ConsignmentService( consignment.map(rows => rows.headOption.map(series => Series(series.seriesid, series.bodyid, series.name, series.code, series.description))) } - def updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput: UpdateConsignmentSeriesIdInput): Future[Int] = { + def updateSeriesOfConsignment(updateConsignmentSeriesIdInput: UpdateConsignmentSeriesIdInput): Future[Int] = { for { - result <- consignmentRepository.updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput) + seriesName <- getSeriesName(Some(updateConsignmentSeriesIdInput.seriesId)) + result <- consignmentRepository.updateSeriesOfConsignment(updateConsignmentSeriesIdInput, seriesName) seriesStatus = if (result == 1) Completed else Failed _ <- consignmentStatusRepository.updateConsignmentStatus(updateConsignmentSeriesIdInput.consignmentId, "Series", seriesStatus, Timestamp.from(timeSource.now)) } yield result @@ -173,6 +173,14 @@ class ConsignmentService( .map(cr => convertRowToConsignment(cr)) .map(c => ConsignmentEdge(c, c.consignmentReference)) } + + private def getSeriesName(seriesId: Option[UUID]): Future[Option[String]] = { + if (seriesId.isDefined) { + seriesRepository.getSeries(seriesId.get).map(_.headOption.map(_.name)) + } else { + Future(None) + } + } } case class PaginatedConsignments(lastCursor: Option[String], consignmentEdges: Seq[ConsignmentEdge]) diff --git a/src/test/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepositorySpec.scala b/src/test/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepositorySpec.scala index 368b5b99f..eee6b67c6 100644 --- a/src/test/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepositorySpec.scala +++ b/src/test/scala/uk/gov/nationalarchives/tdr/api/db/repository/ConsignmentRepositorySpec.scala @@ -5,6 +5,7 @@ import com.dimafeng.testcontainers.PostgreSQLContainer import org.scalatest.concurrent.ScalaFutures import org.scalatest.matchers.should.Matchers import uk.gov.nationalarchives.Tables.ConsignmentstatusRow +import uk.gov.nationalarchives.tdr.api.graphql.fields.ConsignmentFields import uk.gov.nationalarchives.tdr.api.graphql.fields.ConsignmentFields.{ConsignmentFilters, StartUploadInput, UpdateExportDataInput} import uk.gov.nationalarchives.tdr.api.service.CurrentTimeSource import uk.gov.nationalarchives.tdr.api.service.FileStatusService.{InProgress, Upload} @@ -297,6 +298,28 @@ class ConsignmentRepositorySpec extends TestContainerUtils with ScalaFutures wit consignmentReferences should equal(List("TDR-2021-B", "TDR-2021-A")) } + "updateSeriesOfConsignment" should "update id and name of the consignment" in withContainers { case container: PostgreSQLContainer => + val db = container.database + val consignmentRepository = new ConsignmentRepository(db, new CurrentTimeSource) + val utils = TestUtils(db) + val seriesId: UUID = UUID.fromString("20e88b3c-d063-4a6e-8b61-187d8c51d11d") + val seriesName: String = "Mock1" + val bodyId: UUID = UUID.fromString("8a72cc59-7f2f-4e55-a263-4a4cb9f677f5") + + utils.createConsignment(consignmentIdOne, userId, consignmentRef = "TDR-2021-A") + utils.addTransferringBody(bodyId, "MOCK Department", "Code123") + utils.addSeries(seriesId, bodyId, "TDR-2020-XYZ", seriesName) + + val input = ConsignmentFields.UpdateConsignmentSeriesIdInput(consignmentId = consignmentIdOne, seriesId = seriesId) + + val response = consignmentRepository.updateSeriesOfConsignment(input, seriesName.some).futureValue + + response should be(1) + val consignment = consignmentRepository.getConsignment(consignmentIdOne).futureValue.head + consignment.seriesid should be(seriesId.some) + consignment.seriesname should be(seriesName.some) + } + "totalConsignments" should "return total number of consignments" in withContainers { case container: PostgreSQLContainer => val db = container.database val consignmentRepository = new ConsignmentRepository(db, new CurrentTimeSource) diff --git a/src/test/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentServiceSpec.scala b/src/test/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentServiceSpec.scala index 1ace3e4c1..5977e3a6e 100644 --- a/src/test/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentServiceSpec.scala +++ b/src/test/scala/uk/gov/nationalarchives/tdr/api/service/ConsignmentServiceSpec.scala @@ -311,38 +311,40 @@ class ConsignmentServiceSpec extends AnyFlatSpec with MockitoSugar with ResetMoc series.description shouldBe mockSeries.head.description } - "updateSeriesIdOfConsignment" should "update the seriesId and status for a given consignment" in { + "updateSeriesOfConsignment" should "update the seriesId, seriesName and status for a given consignment" in { val updateConsignmentSeriesIdInput = UpdateConsignmentSeriesIdInput(consignmentId, seriesId) val statusType = "Series" val expectedSeriesStatus = Completed val expectedResult = 1 - when(consignmentRepoMock.updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput)) + when(consignmentRepoMock.updateSeriesOfConsignment(updateConsignmentSeriesIdInput, Some(seriesName))) .thenReturn(Future.successful(1)) when(consignmentStatusRepoMock.updateConsignmentStatus(consignmentId, statusType, Completed, Timestamp.from(fixedTimeSource))) .thenReturn(Future.successful(1)) + when(seriesRepositoryMock.getSeries(updateConsignmentSeriesIdInput.seriesId)).thenReturn(Future.successful(Seq(mockSeries))) - val result = consignmentService.updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput).futureValue + val result = consignmentService.updateSeriesOfConsignment(updateConsignmentSeriesIdInput).futureValue - verify(consignmentRepoMock).updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput) + verify(consignmentRepoMock).updateSeriesOfConsignment(updateConsignmentSeriesIdInput, Some(seriesName)) verify(consignmentStatusRepoMock) .updateConsignmentStatus(updateConsignmentSeriesIdInput.consignmentId, statusType, expectedSeriesStatus, Timestamp.from(fixedTimeSource)) result should equal(expectedResult) } - "updateSeriesIdOfConsignment" should "update the status with 'Failed' if seriesId update fails for a given consignment" in { + "updateSeriesOfConsignment" should "update the status with 'Failed' if seriesId update fails for a given consignment" in { val updateConsignmentSeriesIdInput = UpdateConsignmentSeriesIdInput(consignmentId, seriesId) val statusType = "Series" val expectedSeriesStatus = Failed val expectedResult = 0 - when(consignmentRepoMock.updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput)) + when(consignmentRepoMock.updateSeriesOfConsignment(updateConsignmentSeriesIdInput, Some(seriesName))) .thenReturn(Future.successful(0)) when(consignmentStatusRepoMock.updateConsignmentStatus(consignmentId, statusType, Failed, Timestamp.from(fixedTimeSource))) .thenReturn(Future.successful(1)) + when(seriesRepositoryMock.getSeries(updateConsignmentSeriesIdInput.seriesId)).thenReturn(Future.successful(Seq(mockSeries))) - val result = consignmentService.updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput).futureValue + val result = consignmentService.updateSeriesOfConsignment(updateConsignmentSeriesIdInput).futureValue - verify(consignmentRepoMock).updateSeriesIdOfConsignment(updateConsignmentSeriesIdInput) + verify(consignmentRepoMock).updateSeriesOfConsignment(updateConsignmentSeriesIdInput, Some(seriesName)) verify(consignmentStatusRepoMock) .updateConsignmentStatus(updateConsignmentSeriesIdInput.consignmentId, statusType, expectedSeriesStatus, Timestamp.from(fixedTimeSource))