From ca723c833d6d65df852d6e6ab547fd4515337a58 Mon Sep 17 00:00:00 2001 From: jsotobroad Date: Fri, 15 Nov 2024 11:57:42 -0500 Subject: [PATCH] TSPS-360 Add steps to calculate quota consumed (#158) Co-authored-by: Jose Soto --- .../internal/WdlPipelineConfiguration.java | 12 + .../pipelines/common/utils/FlightBeanBag.java | 6 +- .../FetchQuotaConsumedFromDataTableStep.java | 93 +++++++ ...PollQuotaConsumedSubmissionStatusStep.java | 68 +++++ .../SubmitQuotaConsumedSubmissionStep.java | 131 ++++++++++ .../imputation/ImputationJobMapKeys.java | 1 + .../imputation/RunImputationGcpJobFlight.java | 22 +- .../gcp/FetchOutputsFromDataTableStep.java | 2 + .../gcp/PollCromwellSubmissionStatusStep.java | 52 +--- .../gcp/SubmitCromwellSubmissionStep.java | 67 +---- .../utils/RawlsSubmissionStepHelper.java | 142 ++++++++++ service/src/main/resources/application.yml | 3 + .../common/utils/FlightBeanBagTest.java | 6 +- .../RunImputationGcpFlightTest.java | 3 + ...tchQuotaConsumedFromDataTableStepTest.java | 126 +++++++++ ...QuotaConsumedSubmissionStatusStepTest.java | 166 ++++++++++++ ...SubmitQuotaConsumedSubmissionStepTest.java | 243 ++++++++++++++++++ .../PollCromwellSubmissionStatusStepTest.java | 10 +- .../gcp/SubmitCromwellSubmissionStepTest.java | 7 +- .../src/test/resources/application-test.yml | 3 + 20 files changed, 1052 insertions(+), 111 deletions(-) create mode 100644 service/src/main/java/bio/terra/pipelines/app/configuration/internal/WdlPipelineConfiguration.java create mode 100644 service/src/main/java/bio/terra/pipelines/stairway/FetchQuotaConsumedFromDataTableStep.java create mode 100644 service/src/main/java/bio/terra/pipelines/stairway/PollQuotaConsumedSubmissionStatusStep.java create mode 100644 service/src/main/java/bio/terra/pipelines/stairway/SubmitQuotaConsumedSubmissionStep.java create mode 100644 service/src/main/java/bio/terra/pipelines/stairway/utils/RawlsSubmissionStepHelper.java create mode 100644 service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/FetchQuotaConsumedFromDataTableStepTest.java create mode 100644 service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/PollQuotaConsumedSubmissionStatusStepTest.java create mode 100644 service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/SubmitQuotaConsumedSubmissionStepTest.java diff --git a/service/src/main/java/bio/terra/pipelines/app/configuration/internal/WdlPipelineConfiguration.java b/service/src/main/java/bio/terra/pipelines/app/configuration/internal/WdlPipelineConfiguration.java new file mode 100644 index 00000000..7cdfd24c --- /dev/null +++ b/service/src/main/java/bio/terra/pipelines/app/configuration/internal/WdlPipelineConfiguration.java @@ -0,0 +1,12 @@ +package bio.terra.pipelines.app.configuration.internal; + +import lombok.Getter; +import lombok.Setter; +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties(prefix = "pipelines.wdl") +@Getter +@Setter +public class WdlPipelineConfiguration { + private Long quotaConsumedPollingIntervalSeconds; +} diff --git a/service/src/main/java/bio/terra/pipelines/common/utils/FlightBeanBag.java b/service/src/main/java/bio/terra/pipelines/common/utils/FlightBeanBag.java index 59a4ddc9..7b857b57 100644 --- a/service/src/main/java/bio/terra/pipelines/common/utils/FlightBeanBag.java +++ b/service/src/main/java/bio/terra/pipelines/common/utils/FlightBeanBag.java @@ -2,6 +2,7 @@ import bio.terra.pipelines.app.configuration.external.CbasConfiguration; import bio.terra.pipelines.app.configuration.internal.ImputationConfiguration; +import bio.terra.pipelines.app.configuration.internal.WdlPipelineConfiguration; import bio.terra.pipelines.dependencies.cbas.CbasService; import bio.terra.pipelines.dependencies.leonardo.LeonardoService; import bio.terra.pipelines.dependencies.rawls.RawlsService; @@ -37,6 +38,7 @@ public class FlightBeanBag { private final RawlsService rawlsService; private final ImputationConfiguration imputationConfiguration; private final CbasConfiguration cbasConfiguration; + private final WdlPipelineConfiguration wdlPipelineConfiguration; @Lazy @Autowired @@ -51,7 +53,8 @@ public FlightBeanBag( RawlsService rawlsService, WorkspaceManagerService workspaceManagerService, ImputationConfiguration imputationConfiguration, - CbasConfiguration cbasConfiguration) { + CbasConfiguration cbasConfiguration, + WdlPipelineConfiguration wdlPipelineConfiguration) { this.pipelinesService = pipelinesService; this.pipelineRunsService = pipelineRunsService; this.pipelineInputsOutputsService = pipelineInputsOutputsService; @@ -63,6 +66,7 @@ public FlightBeanBag( this.rawlsService = rawlsService; this.imputationConfiguration = imputationConfiguration; this.cbasConfiguration = cbasConfiguration; + this.wdlPipelineConfiguration = wdlPipelineConfiguration; } public static FlightBeanBag getFromObject(Object object) { diff --git a/service/src/main/java/bio/terra/pipelines/stairway/FetchQuotaConsumedFromDataTableStep.java b/service/src/main/java/bio/terra/pipelines/stairway/FetchQuotaConsumedFromDataTableStep.java new file mode 100644 index 00000000..e996ebe2 --- /dev/null +++ b/service/src/main/java/bio/terra/pipelines/stairway/FetchQuotaConsumedFromDataTableStep.java @@ -0,0 +1,93 @@ +package bio.terra.pipelines.stairway; + +import bio.terra.common.exception.InternalServerErrorException; +import bio.terra.pipelines.common.utils.FlightUtils; +import bio.terra.pipelines.common.utils.PipelinesEnum; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.dependencies.stairway.JobMapKeys; +import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; +import bio.terra.rawls.model.Entity; +import bio.terra.stairway.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This step calls Rawls to fetch outputs from a data table row for a given quota consumed job. It + * specifically fetches the quota consumed value from the data table row using the quota_consumed + * key + * + *

This step expects nothing from the working map + */ +public class FetchQuotaConsumedFromDataTableStep implements Step { + + private final RawlsService rawlsService; + private final SamService samService; + private final Logger logger = LoggerFactory.getLogger(FetchQuotaConsumedFromDataTableStep.class); + + public FetchQuotaConsumedFromDataTableStep(RawlsService rawlsService, SamService samService) { + this.rawlsService = rawlsService; + this.samService = samService; + } + + @Override + @SuppressWarnings( + "java:S2259") // suppress warning for possible NPE when calling pipelineName.getValue(), + // since we do validate that pipelineName is not null in `validateRequiredEntries` + public StepResult doStep(FlightContext flightContext) { + String jobId = flightContext.getFlightId(); + + // validate and extract parameters from input map + var inputParameters = flightContext.getInputParameters(); + FlightUtils.validateRequiredEntries( + inputParameters, + JobMapKeys.PIPELINE_NAME, + ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT, + ImputationJobMapKeys.CONTROL_WORKSPACE_NAME); + + String controlWorkspaceBillingProject = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT, String.class); + String controlWorkspaceName = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_NAME, String.class); + PipelinesEnum pipelineName = inputParameters.get(JobMapKeys.PIPELINE_NAME, PipelinesEnum.class); + + Entity entity; + try { + entity = + rawlsService.getDataTableEntity( + samService.getTeaspoonsServiceAccountToken(), + controlWorkspaceBillingProject, + controlWorkspaceName, + pipelineName.getValue(), + jobId); + } catch (RawlsServiceApiException e) { + return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); + } + + // extract quota_consumed from entity + int quotaConsumed; + try { + quotaConsumed = (int) entity.getAttributes().get("quota_consumed"); + if (quotaConsumed <= 0) { + return new StepResult( + StepStatus.STEP_RESULT_FAILURE_FATAL, + new InternalServerErrorException("Quota consumed is unexpectedly not greater than 0")); + } + } catch (NullPointerException e) { + return new StepResult( + StepStatus.STEP_RESULT_FAILURE_FATAL, + new InternalServerErrorException("Quota consumed is unexpectedly null")); + } + + logger.info("Quota consumed: {}", quotaConsumed); + + return StepResult.getStepResultSuccess(); + } + + @Override + public StepResult undoStep(FlightContext flightContext) { + // nothing to undo + return StepResult.getStepResultSuccess(); + } +} diff --git a/service/src/main/java/bio/terra/pipelines/stairway/PollQuotaConsumedSubmissionStatusStep.java b/service/src/main/java/bio/terra/pipelines/stairway/PollQuotaConsumedSubmissionStatusStep.java new file mode 100644 index 00000000..242ece4a --- /dev/null +++ b/service/src/main/java/bio/terra/pipelines/stairway/PollQuotaConsumedSubmissionStatusStep.java @@ -0,0 +1,68 @@ +package bio.terra.pipelines.stairway; + +import bio.terra.pipelines.app.configuration.internal.WdlPipelineConfiguration; +import bio.terra.pipelines.common.utils.FlightUtils; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; +import bio.terra.pipelines.stairway.utils.RawlsSubmissionStepHelper; +import bio.terra.stairway.*; +import java.util.UUID; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This step polls rawls for a submission until all runs are in a finalized state. If submission is + * not in a final state, this will step will poll again after an interval of time. Once the + * submission is finalized then it will see if the workflows are all successful and if so will + * succeed otherwise will fail. + * + *

this step expects quota submission id to be provided in the working map + */ +public class PollQuotaConsumedSubmissionStatusStep implements Step { + private final RawlsService rawlsService; + private final SamService samService; + private final WdlPipelineConfiguration wdlPipelineConfiguration; + private final Logger logger = + LoggerFactory.getLogger(PollQuotaConsumedSubmissionStatusStep.class); + + public PollQuotaConsumedSubmissionStatusStep( + RawlsService rawlsService, + SamService samService, + WdlPipelineConfiguration wdlPipelineConfiguration) { + this.samService = samService; + this.rawlsService = rawlsService; + this.wdlPipelineConfiguration = wdlPipelineConfiguration; + } + + @Override + public StepResult doStep(FlightContext flightContext) throws InterruptedException { + // validate and extract parameters from input map + FlightMap inputParameters = flightContext.getInputParameters(); + FlightUtils.validateRequiredEntries( + inputParameters, + ImputationJobMapKeys.CONTROL_WORKSPACE_NAME, + ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT); + String controlWorkspaceName = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_NAME, String.class); + String controlWorkspaceProject = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT, String.class); + // validate and extract parameters from working map + FlightMap workingMap = flightContext.getWorkingMap(); + FlightUtils.validateRequiredEntries(workingMap, ImputationJobMapKeys.QUOTA_SUBMISSION_ID); + + UUID quotaSubmissionId = workingMap.get(ImputationJobMapKeys.QUOTA_SUBMISSION_ID, UUID.class); + + RawlsSubmissionStepHelper rawlsSubmissionStepHelper = + new RawlsSubmissionStepHelper( + rawlsService, samService, controlWorkspaceProject, controlWorkspaceName, logger); + return rawlsSubmissionStepHelper.pollRawlsSubmissionHelper( + quotaSubmissionId, wdlPipelineConfiguration.getQuotaConsumedPollingIntervalSeconds()); + } + + @Override + public StepResult undoStep(FlightContext context) { + // nothing to undo; there's nothing to undo about polling a cromwell submission + return StepResult.getStepResultSuccess(); + } +} diff --git a/service/src/main/java/bio/terra/pipelines/stairway/SubmitQuotaConsumedSubmissionStep.java b/service/src/main/java/bio/terra/pipelines/stairway/SubmitQuotaConsumedSubmissionStep.java new file mode 100644 index 00000000..10930c8d --- /dev/null +++ b/service/src/main/java/bio/terra/pipelines/stairway/SubmitQuotaConsumedSubmissionStep.java @@ -0,0 +1,131 @@ +package bio.terra.pipelines.stairway; + +import bio.terra.pipelines.common.utils.FlightUtils; +import bio.terra.pipelines.common.utils.PipelineVariableTypesEnum; +import bio.terra.pipelines.common.utils.PipelinesEnum; +import bio.terra.pipelines.db.entities.PipelineInputDefinition; +import bio.terra.pipelines.db.entities.PipelineOutputDefinition; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.dependencies.stairway.JobMapKeys; +import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; +import bio.terra.pipelines.stairway.utils.RawlsSubmissionStepHelper; +import bio.terra.rawls.model.SubmissionReport; +import bio.terra.rawls.model.SubmissionRequest; +import bio.terra.stairway.*; +import com.fasterxml.jackson.core.type.TypeReference; +import java.util.List; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This step submits a quota consumed wdl to cromwell using the rawls submission endpoint. The quota + * consumed wdl that is run depends on the workspace name and billing project provided to the step. + * + *

this step expects nothing from the working map + * + *

this step writes quota_submission_id to the working map + */ +public class SubmitQuotaConsumedSubmissionStep implements Step { + private final SamService samService; + private final RawlsService rawlsService; + + public static final String QUOTA_CONSUMED_METHOD_NAME = "QuotaConsumed"; + public static final List QUOTA_CONSUMED_OUTPUT_DEFINITION_LIST = + List.of( + new PipelineOutputDefinition( + null, "quotaConsumed", "quota_consumed", PipelineVariableTypesEnum.INTEGER)); + + private final Logger logger = LoggerFactory.getLogger(SubmitQuotaConsumedSubmissionStep.class); + + public SubmitQuotaConsumedSubmissionStep(RawlsService rawlsService, SamService samService) { + this.samService = samService; + this.rawlsService = rawlsService; + } + + @Override + @SuppressWarnings( + "java:S2259") // suppress warning for possible NPE when calling pipelineName.getValue(), + // since we do validate that pipelineName is not null in `validateRequiredEntries` + public StepResult doStep(FlightContext flightContext) { + // validate and extract parameters from input map + FlightMap inputParameters = flightContext.getInputParameters(); + FlightUtils.validateRequiredEntries( + inputParameters, + JobMapKeys.PIPELINE_NAME, + ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT, + ImputationJobMapKeys.CONTROL_WORKSPACE_NAME, + ImputationJobMapKeys.WDL_METHOD_VERSION, + ImputationJobMapKeys.PIPELINE_INPUT_DEFINITIONS); + + PipelinesEnum pipelineName = inputParameters.get(JobMapKeys.PIPELINE_NAME, PipelinesEnum.class); + String controlWorkspaceName = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_NAME, String.class); + String controlWorkspaceProject = + inputParameters.get(ImputationJobMapKeys.CONTROL_WORKSPACE_BILLING_PROJECT, String.class); + String wdlMethodVersion = + inputParameters.get(ImputationJobMapKeys.WDL_METHOD_VERSION, String.class); + List inputDefinitions = + inputParameters.get( + ImputationJobMapKeys.PIPELINE_INPUT_DEFINITIONS, new TypeReference<>() {}); + + // validate and extract parameters from working map + FlightMap workingMap = flightContext.getWorkingMap(); + + RawlsSubmissionStepHelper rawlsSubmissionStepHelper = + new RawlsSubmissionStepHelper( + rawlsService, samService, controlWorkspaceProject, controlWorkspaceName, logger); + + Optional validationResponse = + rawlsSubmissionStepHelper.validateRawlsSubmissionMethodHelper( + QUOTA_CONSUMED_METHOD_NAME, + wdlMethodVersion, + inputDefinitions, + QUOTA_CONSUMED_OUTPUT_DEFINITION_LIST, + pipelineName); + + // if there is a validation response that means the validation failed so return it + if (validationResponse.isPresent()) { + return validationResponse.get(); + } + + // create submission request + SubmissionRequest submissionRequest = + new SubmissionRequest() + .entityName(flightContext.getFlightId()) + .entityType(pipelineName.getValue()) + .useCallCache(true) + .deleteIntermediateOutputFiles(true) + .useReferenceDisks(false) + .userComment( + "%s - getting quota consumed for flight id: %s" + .formatted(pipelineName, flightContext.getFlightId())) + .methodConfigurationNamespace(controlWorkspaceProject) + .methodConfigurationName(QUOTA_CONSUMED_METHOD_NAME); + + // submit workflow to rawls + SubmissionReport submissionReport; + try { + submissionReport = + rawlsService.submitWorkflow( + samService.getTeaspoonsServiceAccountToken(), + submissionRequest, + controlWorkspaceProject, + controlWorkspaceName); + } catch (RawlsServiceApiException e) { + return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); + } + + // add submission id to working map to be used for polling in downstream step + workingMap.put(ImputationJobMapKeys.QUOTA_SUBMISSION_ID, submissionReport.getSubmissionId()); + return StepResult.getStepResultSuccess(); + } + + @Override + public StepResult undoStep(FlightContext context) { + // nothing to undo; there's nothing to undo about submitting a run set + return StepResult.getStepResultSuccess(); + } +} diff --git a/service/src/main/java/bio/terra/pipelines/stairway/imputation/ImputationJobMapKeys.java b/service/src/main/java/bio/terra/pipelines/stairway/imputation/ImputationJobMapKeys.java index ceb3c062..5e85b3b4 100644 --- a/service/src/main/java/bio/terra/pipelines/stairway/imputation/ImputationJobMapKeys.java +++ b/service/src/main/java/bio/terra/pipelines/stairway/imputation/ImputationJobMapKeys.java @@ -18,6 +18,7 @@ public class ImputationJobMapKeys { "control_workspace_billing_project"; public static final String CONTROL_WORKSPACE_NAME = "control_workspace_name"; public static final String SUBMISSION_ID = "submission_id"; + public static final String QUOTA_SUBMISSION_ID = "quota_submission_id"; // Azure specific keys public static final String CONTROL_WORKSPACE_ID = "control_workspace_id"; diff --git a/service/src/main/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpJobFlight.java b/service/src/main/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpJobFlight.java index ebe4fc81..f0b43bd6 100644 --- a/service/src/main/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpJobFlight.java +++ b/service/src/main/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpJobFlight.java @@ -5,6 +5,9 @@ import bio.terra.pipelines.common.utils.FlightUtils; import bio.terra.pipelines.common.utils.PipelinesEnum; import bio.terra.pipelines.dependencies.stairway.JobMapKeys; +import bio.terra.pipelines.stairway.FetchQuotaConsumedFromDataTableStep; +import bio.terra.pipelines.stairway.PollQuotaConsumedSubmissionStatusStep; +import bio.terra.pipelines.stairway.SubmitQuotaConsumedSubmissionStep; import bio.terra.pipelines.stairway.imputation.steps.CompletePipelineRunStep; import bio.terra.pipelines.stairway.imputation.steps.PrepareImputationInputsStep; import bio.terra.pipelines.stairway.imputation.steps.gcp.AddDataTableRowStep; @@ -75,6 +78,23 @@ public RunImputationGcpJobFlight(FlightMap inputParameters, Object beanBag) { new AddDataTableRowStep(flightBeanBag.getRawlsService(), flightBeanBag.getSamService()), externalServiceRetryRule); + addStep( + new SubmitQuotaConsumedSubmissionStep( + flightBeanBag.getRawlsService(), flightBeanBag.getSamService()), + externalServiceRetryRule); + + addStep( + new PollQuotaConsumedSubmissionStatusStep( + flightBeanBag.getRawlsService(), + flightBeanBag.getSamService(), + flightBeanBag.getWdlPipelineConfiguration()), + externalServiceRetryRule); + + addStep( + new FetchQuotaConsumedFromDataTableStep( + flightBeanBag.getRawlsService(), flightBeanBag.getSamService()), + externalServiceRetryRule); + addStep( new SubmitCromwellSubmissionStep( flightBeanBag.getRawlsService(), @@ -84,8 +104,8 @@ public RunImputationGcpJobFlight(FlightMap inputParameters, Object beanBag) { addStep( new PollCromwellSubmissionStatusStep( - flightBeanBag.getSamService(), flightBeanBag.getRawlsService(), + flightBeanBag.getSamService(), flightBeanBag.getImputationConfiguration()), externalServiceRetryRule); diff --git a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/FetchOutputsFromDataTableStep.java b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/FetchOutputsFromDataTableStep.java index 7ce2790c..ecf5b3db 100644 --- a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/FetchOutputsFromDataTableStep.java +++ b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/FetchOutputsFromDataTableStep.java @@ -23,6 +23,8 @@ * This step calls Rawls to fetch outputs from a data table row for a given job and stores them in * the flight's working map. These outputs are considered raw in that they are cloud paths and not * signed urls. + * + *

This step expects nothing from the working map */ public class FetchOutputsFromDataTableStep implements Step { diff --git a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStep.java b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStep.java index f5000ffa..0da948ba 100644 --- a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStep.java +++ b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStep.java @@ -1,19 +1,13 @@ package bio.terra.pipelines.stairway.imputation.steps.gcp; -import bio.terra.common.exception.InternalServerErrorException; import bio.terra.pipelines.app.configuration.internal.ImputationConfiguration; import bio.terra.pipelines.common.utils.FlightUtils; import bio.terra.pipelines.dependencies.rawls.RawlsService; -import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; import bio.terra.pipelines.dependencies.sam.SamService; import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; -import bio.terra.rawls.model.Submission; -import bio.terra.rawls.model.Workflow; -import bio.terra.rawls.model.WorkflowStatus; +import bio.terra.pipelines.stairway.utils.RawlsSubmissionStepHelper; import bio.terra.stairway.*; -import java.util.List; import java.util.UUID; -import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,8 +27,8 @@ public class PollCromwellSubmissionStatusStep implements Step { private final Logger logger = LoggerFactory.getLogger(PollCromwellSubmissionStatusStep.class); public PollCromwellSubmissionStatusStep( - SamService samService, RawlsService rawlsService, + SamService samService, ImputationConfiguration imputationConfiguration) { this.samService = samService; this.rawlsService = rawlsService; @@ -59,43 +53,11 @@ public StepResult doStep(FlightContext flightContext) throws InterruptedExceptio UUID submissionId = workingMap.get(ImputationJobMapKeys.SUBMISSION_ID, UUID.class); - // poll until all runs are in a finalized state - Submission submissionResponse = null; - boolean stillRunning = true; - try { - while (stillRunning) { - submissionResponse = - rawlsService.getSubmissionStatus( - samService.getTeaspoonsServiceAccountToken(), - controlWorkspaceProject, - controlWorkspaceName, - submissionId); - stillRunning = RawlsService.submissionIsRunning(submissionResponse); - if (stillRunning) { - logger.info( - "Polling Started, sleeping for {} seconds", - imputationConfiguration.getCromwellSubmissionPollingIntervalInSeconds()); - TimeUnit.SECONDS.sleep( - imputationConfiguration.getCromwellSubmissionPollingIntervalInSeconds()); - } - } - } catch (RawlsServiceApiException e) { - return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); - } - - // if there are any non-successful workflows, fatally fail the step - List failedRunLogs = - submissionResponse.getWorkflows().stream() - .filter(workflow -> !workflow.getStatus().equals(WorkflowStatus.SUCCEEDED)) - .toList(); - if (failedRunLogs.isEmpty()) { - return StepResult.getStepResultSuccess(); - } else { - return new StepResult( - StepStatus.STEP_RESULT_FAILURE_FATAL, - new InternalServerErrorException( - "Not all runs succeeded for submission: " + submissionId)); - } + RawlsSubmissionStepHelper rawlsSubmissionStepHelper = + new RawlsSubmissionStepHelper( + rawlsService, samService, controlWorkspaceProject, controlWorkspaceName, logger); + return rawlsSubmissionStepHelper.pollRawlsSubmissionHelper( + submissionId, imputationConfiguration.getCromwellSubmissionPollingIntervalInSeconds()); } @Override diff --git a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStep.java b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStep.java index f2c09beb..c535f6f1 100644 --- a/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStep.java +++ b/service/src/main/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStep.java @@ -10,12 +10,13 @@ import bio.terra.pipelines.dependencies.sam.SamService; import bio.terra.pipelines.dependencies.stairway.JobMapKeys; import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; -import bio.terra.rawls.model.MethodConfiguration; +import bio.terra.pipelines.stairway.utils.RawlsSubmissionStepHelper; import bio.terra.rawls.model.SubmissionReport; import bio.terra.rawls.model.SubmissionRequest; import bio.terra.stairway.*; import com.fasterxml.jackson.core.type.TypeReference; import java.util.List; +import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -81,68 +82,24 @@ public StepResult doStep(FlightContext flightContext) { // validate and extract parameters from working map FlightMap workingMap = flightContext.getWorkingMap(); - MethodConfiguration methodConfiguration; - try { - // grab current method config and validate it - methodConfiguration = - rawlsService.getCurrentMethodConfigForMethod( - samService.getTeaspoonsServiceAccountToken(), - controlWorkspaceProject, - controlWorkspaceName, - wdlMethodName); - } catch (RawlsServiceApiException e) { - // if we fail to grab the method config then retry - return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); - } - boolean validMethodConfig = - rawlsService.validateMethodConfig( - methodConfiguration, - pipelineName.getValue(), - wdlMethodName, - inputDefinitions, - outputDefinitions, - wdlMethodVersion); + RawlsSubmissionStepHelper rawlsSubmissionStepHelper = + new RawlsSubmissionStepHelper( + rawlsService, samService, controlWorkspaceProject, controlWorkspaceName, logger); - // if not a valid method config, set the method config to what we think it should be. This - // shouldn't happen - if (!validMethodConfig) { - logger.warn( - "found method config that was not valid for billing project: {}, workspace: {}, method name: {}, methodConfigVersion: {}", - controlWorkspaceProject, - controlWorkspaceName, - wdlMethodName, - methodConfiguration.getMethodConfigVersion()); + Optional validationResponse = + rawlsSubmissionStepHelper.validateRawlsSubmissionMethodHelper( + wdlMethodName, wdlMethodVersion, inputDefinitions, outputDefinitions, pipelineName); - MethodConfiguration updatedMethodConfiguration = - rawlsService.updateMethodConfigToBeValid( - methodConfiguration, - pipelineName.getValue(), - wdlMethodName, - inputDefinitions, - outputDefinitions, - wdlMethodVersion); - try { - // update method config version, inputs, and outputs - rawlsService.setMethodConfigForMethod( - samService.getTeaspoonsServiceAccountToken(), - updatedMethodConfiguration, - controlWorkspaceProject, - controlWorkspaceName, - wdlMethodName); - } catch (RawlsServiceApiException e) { - // if we fail to update the method config then retry - return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); - } + // if there is a validation response that means the validation failed so return it + if (validationResponse.isPresent()) { + return validationResponse.get(); } // create submission request SubmissionRequest submissionRequest = new SubmissionRequest() .entityName(flightContext.getFlightId()) - .entityType( - pipelineName.getValue()) // this must match the configuration the method is set to - // launch with. Will be addressed in - // https://broadworkbench.atlassian.net/browse/TSPS-301 + .entityType(pipelineName.getValue()) .useCallCache(imputationConfiguration.isUseCallCaching()) .deleteIntermediateOutputFiles(imputationConfiguration.isDeleteIntermediateFiles()) .useReferenceDisks(imputationConfiguration.isUseReferenceDisk()) diff --git a/service/src/main/java/bio/terra/pipelines/stairway/utils/RawlsSubmissionStepHelper.java b/service/src/main/java/bio/terra/pipelines/stairway/utils/RawlsSubmissionStepHelper.java new file mode 100644 index 00000000..8b9a17f5 --- /dev/null +++ b/service/src/main/java/bio/terra/pipelines/stairway/utils/RawlsSubmissionStepHelper.java @@ -0,0 +1,142 @@ +package bio.terra.pipelines.stairway.utils; + +import bio.terra.common.exception.InternalServerErrorException; +import bio.terra.pipelines.common.utils.PipelinesEnum; +import bio.terra.pipelines.db.entities.PipelineInputDefinition; +import bio.terra.pipelines.db.entities.PipelineOutputDefinition; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.rawls.model.MethodConfiguration; +import bio.terra.rawls.model.Submission; +import bio.terra.rawls.model.Workflow; +import bio.terra.rawls.model.WorkflowStatus; +import bio.terra.stairway.StepResult; +import bio.terra.stairway.StepStatus; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; + +public class RawlsSubmissionStepHelper { + + private final SamService samService; + private final RawlsService rawlsService; + private final String controlWorkspaceProject; + private final String controlWorkspaceName; + private final Logger logger; + + public RawlsSubmissionStepHelper( + RawlsService rawlsService, + SamService samService, + String controlWorkspaceProject, + String controlWorkspaceName, + Logger logger) { + this.rawlsService = rawlsService; + this.samService = samService; + this.controlWorkspaceProject = controlWorkspaceProject; + this.controlWorkspaceName = controlWorkspaceName; + this.logger = logger; + } + + public StepResult pollRawlsSubmissionHelper(UUID submissionId, Long secondsToSleep) + throws InterruptedException { + // poll until all runs are in a finalized state + Submission submissionResponse = null; + boolean stillRunning = true; + try { + while (stillRunning) { + submissionResponse = + rawlsService.getSubmissionStatus( + samService.getTeaspoonsServiceAccountToken(), + controlWorkspaceProject, + controlWorkspaceName, + submissionId); + stillRunning = RawlsService.submissionIsRunning(submissionResponse); + if (stillRunning) { + logger.info("Polling Started, sleeping for {} seconds", secondsToSleep); + TimeUnit.SECONDS.sleep(secondsToSleep); + } + } + } catch (RawlsServiceApiException e) { + return new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e); + } + + // if there are any non-successful workflows, fatally fail the step + List failedRunLogs = + submissionResponse.getWorkflows().stream() + .filter(workflow -> !workflow.getStatus().equals(WorkflowStatus.SUCCEEDED)) + .toList(); + if (failedRunLogs.isEmpty()) { + return StepResult.getStepResultSuccess(); + } else { + return new StepResult( + StepStatus.STEP_RESULT_FAILURE_FATAL, + new InternalServerErrorException( + "Not all runs succeeded for submission: " + submissionId)); + } + } + + public Optional validateRawlsSubmissionMethodHelper( + String wdlMethodName, + String wdlMethodVersion, + List inputDefinitions, + List outputDefinitions, + PipelinesEnum pipelineName) { + MethodConfiguration methodConfiguration; + try { + // grab current method config and validate it + methodConfiguration = + rawlsService.getCurrentMethodConfigForMethod( + samService.getTeaspoonsServiceAccountToken(), + controlWorkspaceProject, + controlWorkspaceName, + wdlMethodName); + } catch (RawlsServiceApiException e) { + // if we fail to grab the method config then retry + return Optional.of(new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e)); + } + boolean validMethodConfig = + rawlsService.validateMethodConfig( + methodConfiguration, + pipelineName.getValue(), + wdlMethodName, + inputDefinitions, + outputDefinitions, + wdlMethodVersion); + + // if not a valid method config, set the method config to what we think it should be. This + // shouldn't happen + if (!validMethodConfig) { + logger.warn( + "found method config that was not valid for billing project: {}, workspace: {}, method name: {}, methodConfigVersion: {}", + controlWorkspaceProject, + controlWorkspaceName, + wdlMethodName, + methodConfiguration.getMethodConfigVersion()); + + MethodConfiguration updatedMethodConfiguration = + rawlsService.updateMethodConfigToBeValid( + methodConfiguration, + pipelineName.getValue(), + wdlMethodName, + inputDefinitions, + outputDefinitions, + wdlMethodVersion); + try { + // update method config version, inputs, and outputs + rawlsService.setMethodConfigForMethod( + samService.getTeaspoonsServiceAccountToken(), + updatedMethodConfiguration, + controlWorkspaceProject, + controlWorkspaceName, + wdlMethodName); + } catch (RawlsServiceApiException e) { + // if we fail to update the method config then retry + return Optional.of(new StepResult(StepStatus.STEP_RESULT_FAILURE_RETRY, e)); + } + } + return Optional.empty(); + } +} diff --git a/service/src/main/resources/application.yml b/service/src/main/resources/application.yml index 245e2c50..970c593f 100644 --- a/service/src/main/resources/application.yml +++ b/service/src/main/resources/application.yml @@ -151,6 +151,9 @@ pipelines: dsn: ${SENTRY_DSN:} environment: ${DEPLOY_ENV:} + wdl: + quotaConsumedPollingIntervalSeconds: 60 + terra.common: kubernetes: in-kubernetes: ${env.kubernetes.in-kubernetes} # whether to use a pubsub queue for Stairway; if false, use a local queue diff --git a/service/src/test/java/bio/terra/pipelines/common/utils/FlightBeanBagTest.java b/service/src/test/java/bio/terra/pipelines/common/utils/FlightBeanBagTest.java index b0efb1dd..a7c80d40 100644 --- a/service/src/test/java/bio/terra/pipelines/common/utils/FlightBeanBagTest.java +++ b/service/src/test/java/bio/terra/pipelines/common/utils/FlightBeanBagTest.java @@ -4,6 +4,7 @@ import bio.terra.pipelines.app.configuration.external.CbasConfiguration; import bio.terra.pipelines.app.configuration.internal.ImputationConfiguration; +import bio.terra.pipelines.app.configuration.internal.WdlPipelineConfiguration; import bio.terra.pipelines.dependencies.cbas.CbasService; import bio.terra.pipelines.dependencies.leonardo.LeonardoService; import bio.terra.pipelines.dependencies.rawls.RawlsService; @@ -30,6 +31,7 @@ class FlightBeanBagTest extends BaseEmbeddedDbTest { @Autowired private RawlsService rawlsService; @Autowired private ImputationConfiguration imputationConfiguration; @Autowired private CbasConfiguration cbasConfiguration; + @Autowired private WdlPipelineConfiguration wdlPipelineConfiguration; @Test void testFlightBeanBag() { @@ -45,7 +47,8 @@ void testFlightBeanBag() { rawlsService, workspaceManagerService, imputationConfiguration, - cbasConfiguration); + cbasConfiguration, + wdlPipelineConfiguration); assertEquals(pipelinesService, flightBeanBag.getPipelinesService()); assertEquals(pipelineRunsService, flightBeanBag.getPipelineRunsService()); assertEquals(pipelineInputsOutputsService, flightBeanBag.getPipelineInputsOutputsService()); @@ -57,5 +60,6 @@ void testFlightBeanBag() { assertEquals(rawlsService, flightBeanBag.getRawlsService()); assertEquals(imputationConfiguration, flightBeanBag.getImputationConfiguration()); assertEquals(cbasConfiguration, flightBeanBag.getCbasConfiguration()); + assertEquals(wdlPipelineConfiguration, flightBeanBag.getWdlPipelineConfiguration()); } } diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpFlightTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpFlightTest.java index a7b69de1..6ba8efea 100644 --- a/service/src/test/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpFlightTest.java +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/RunImputationGcpFlightTest.java @@ -32,6 +32,9 @@ class RunImputationGcpFlightTest extends BaseEmbeddedDbTest { List.of( "PrepareImputationInputsStep", "AddDataTableRowStep", + "SubmitQuotaConsumedSubmissionStep", + "PollQuotaConsumedSubmissionStatusStep", + "FetchQuotaConsumedFromDataTableStep", "SubmitCromwellSubmissionStep", "PollCromwellSubmissionStatusStep", "CompletePipelineRunStep", diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/FetchQuotaConsumedFromDataTableStepTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/FetchQuotaConsumedFromDataTableStepTest.java new file mode 100644 index 00000000..efb46173 --- /dev/null +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/FetchQuotaConsumedFromDataTableStepTest.java @@ -0,0 +1,126 @@ +package bio.terra.pipelines.stairway.imputation.steps; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +import bio.terra.common.exception.InternalServerErrorException; +import bio.terra.pipelines.common.utils.PipelinesEnum; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.stairway.FetchQuotaConsumedFromDataTableStep; +import bio.terra.pipelines.testutils.BaseEmbeddedDbTest; +import bio.terra.pipelines.testutils.StairwayTestUtils; +import bio.terra.pipelines.testutils.TestUtils; +import bio.terra.rawls.model.Entity; +import bio.terra.stairway.FlightContext; +import bio.terra.stairway.FlightMap; +import bio.terra.stairway.StepResult; +import bio.terra.stairway.StepStatus; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; + +class FetchQuotaConsumedFromDataTableStepTest extends BaseEmbeddedDbTest { + + @Mock RawlsService rawlsService; + @Mock SamService samService; + @Mock private FlightContext flightContext; + + @BeforeEach + void setup() { + var inputParameters = new FlightMap(); + var workingMap = new FlightMap(); + + when(flightContext.getInputParameters()).thenReturn(inputParameters); + when(flightContext.getWorkingMap()).thenReturn(workingMap); + when(samService.getTeaspoonsServiceAccountToken()).thenReturn("thisToken"); + } + + @Test + void doStepSuccess() throws RawlsServiceException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(TestUtils.TEST_NEW_UUID.toString()); + + // outputs to match the test output definitions + Map entityAttributes = new HashMap<>(Map.of("quota_consumed", 1)); + Entity entity = new Entity().attributes(entityAttributes); + + when(rawlsService.getDataTableEntity( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + TestUtils.TEST_NEW_UUID.toString())) + .thenReturn(entity); + + FetchQuotaConsumedFromDataTableStep fetchQuotaConsumedFromDataTableStep = + new FetchQuotaConsumedFromDataTableStep(rawlsService, samService); + StepResult result = fetchQuotaConsumedFromDataTableStep.doStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } + + @Test + void doStepRawlsFailureRetry() throws RawlsServiceException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(TestUtils.TEST_NEW_UUID.toString()); + + when(rawlsService.getDataTableEntity( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + TestUtils.TEST_NEW_UUID.toString())) + .thenThrow(new RawlsServiceApiException("Rawls Service Api Exception")); + + FetchQuotaConsumedFromDataTableStep fetchQuotaConsumedFromDataTableStep = + new FetchQuotaConsumedFromDataTableStep(rawlsService, samService); + StepResult result = fetchQuotaConsumedFromDataTableStep.doStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); + } + + @Test + void doStepOutputsFailureNoRetry() throws InternalServerErrorException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(TestUtils.TEST_NEW_UUID.toString()); + + // try with no quota_consumed attribute + Map entityAttributes = new HashMap<>(); + Entity entity = new Entity().attributes(entityAttributes); + + when(rawlsService.getDataTableEntity( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + TestUtils.TEST_NEW_UUID.toString())) + .thenReturn(entity); + + FetchQuotaConsumedFromDataTableStep fetchQuotaConsumedFromDataTableStep = + new FetchQuotaConsumedFromDataTableStep(rawlsService, samService); + StepResult result = fetchQuotaConsumedFromDataTableStep.doStep(flightContext); + assertEquals(StepStatus.STEP_RESULT_FAILURE_FATAL, result.getStepStatus()); + + // try with quota_consumed attribute as 0 + entityAttributes.put("quota_consumed", 0); + result = fetchQuotaConsumedFromDataTableStep.doStep(flightContext); + assertEquals(StepStatus.STEP_RESULT_FAILURE_FATAL, result.getStepStatus()); + } + + @Test + void undoStepSuccess() { + FetchQuotaConsumedFromDataTableStep fetchQuotaConsumedFromDataTableStep = + new FetchQuotaConsumedFromDataTableStep(rawlsService, samService); + StepResult result = fetchQuotaConsumedFromDataTableStep.undoStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } +} diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/PollQuotaConsumedSubmissionStatusStepTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/PollQuotaConsumedSubmissionStatusStepTest.java new file mode 100644 index 00000000..061cce1b --- /dev/null +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/PollQuotaConsumedSubmissionStatusStepTest.java @@ -0,0 +1,166 @@ +package bio.terra.pipelines.stairway.imputation.steps; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.when; + +import bio.terra.pipelines.app.configuration.internal.WdlPipelineConfiguration; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.stairway.PollQuotaConsumedSubmissionStatusStep; +import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; +import bio.terra.pipelines.testutils.BaseEmbeddedDbTest; +import bio.terra.pipelines.testutils.StairwayTestUtils; +import bio.terra.pipelines.testutils.TestUtils; +import bio.terra.rawls.model.Submission; +import bio.terra.rawls.model.SubmissionStatus; +import bio.terra.rawls.model.Workflow; +import bio.terra.rawls.model.WorkflowStatus; +import bio.terra.stairway.FlightContext; +import bio.terra.stairway.FlightMap; +import bio.terra.stairway.StepResult; +import bio.terra.stairway.StepStatus; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.springframework.beans.factory.annotation.Autowired; + +class PollQuotaConsumedSubmissionStatusStepTest extends BaseEmbeddedDbTest { + + @Mock private RawlsService rawlsService; + @Mock private SamService samService; + @Autowired WdlPipelineConfiguration wdlPipelineConfiguration; + @Mock private FlightContext flightContext; + + private final UUID testJobId = TestUtils.TEST_NEW_UUID; + private final UUID randomUUID = UUID.randomUUID(); + + @BeforeEach + void setup() { + FlightMap inputParameters = new FlightMap(); + FlightMap workingMap = new FlightMap(); + workingMap.put(ImputationJobMapKeys.QUOTA_SUBMISSION_ID, randomUUID); + + when(flightContext.getInputParameters()).thenReturn(inputParameters); + when(flightContext.getWorkingMap()).thenReturn(workingMap); + when(samService.getTeaspoonsServiceAccountToken()).thenReturn("thisToken"); + } + + @Test + void doStepSuccess() throws InterruptedException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + Submission response = + new Submission() + .status(SubmissionStatus.DONE) + .addWorkflowsItem(new Workflow().status(WorkflowStatus.SUCCEEDED)); + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + when(rawlsService.getSubmissionStatus( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + randomUUID)) + .thenReturn(response); + + // do the step + PollQuotaConsumedSubmissionStatusStep pollQuotaConsumedSubmissionStatusStep = + new PollQuotaConsumedSubmissionStatusStep( + rawlsService, samService, wdlPipelineConfiguration); + StepResult result = pollQuotaConsumedSubmissionStatusStep.doStep(flightContext); + + // make sure the step was a success + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } + + @Test + void doStepRunningThenComplete() throws InterruptedException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + Submission firstResponse = + new Submission() + .status(SubmissionStatus.SUBMITTED) + .addWorkflowsItem(new Workflow().status(WorkflowStatus.RUNNING)); + Submission secondResponse = + new Submission() + .status(SubmissionStatus.DONE) + .addWorkflowsItem(new Workflow().status(WorkflowStatus.SUCCEEDED)); + + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + when(rawlsService.getSubmissionStatus( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + randomUUID)) + .thenReturn(firstResponse) + .thenReturn(secondResponse); + + // do the step + PollQuotaConsumedSubmissionStatusStep pollQuotaConsumedSubmissionStatusStep = + new PollQuotaConsumedSubmissionStatusStep( + rawlsService, samService, wdlPipelineConfiguration); + StepResult result = pollQuotaConsumedSubmissionStatusStep.doStep(flightContext); + + // make sure the step was a success + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } + + @Test + void doStepNotAllSuccessfulRuns() throws InterruptedException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + Submission responseWithErrorRun = + new Submission() + .status(SubmissionStatus.DONE) + .addWorkflowsItem(new Workflow().status(WorkflowStatus.SUCCEEDED)) + .addWorkflowsItem(new Workflow().status(WorkflowStatus.FAILED)); + + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + when(rawlsService.getSubmissionStatus( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + randomUUID)) + .thenReturn(responseWithErrorRun); + + // do the step + PollQuotaConsumedSubmissionStatusStep pollQuotaConsumedSubmissionStatusStep = + new PollQuotaConsumedSubmissionStatusStep( + rawlsService, samService, wdlPipelineConfiguration); + StepResult result = pollQuotaConsumedSubmissionStatusStep.doStep(flightContext); + + // make sure the step fails + assertEquals(StepStatus.STEP_RESULT_FAILURE_FATAL, result.getStepStatus()); + } + + @Test + void doStepRawlsApiErrorRetry() throws InterruptedException { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + when(rawlsService.getSubmissionStatus( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + randomUUID)) + .thenThrow(new RawlsServiceApiException("this is the error message")); + + // do the step, expect a Retry status + PollQuotaConsumedSubmissionStatusStep pollQuotaConsumedSubmissionStatusStep = + new PollQuotaConsumedSubmissionStatusStep( + rawlsService, samService, wdlPipelineConfiguration); + StepResult result = pollQuotaConsumedSubmissionStatusStep.doStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); + } + + @Test + void undoStepSuccess() { + PollQuotaConsumedSubmissionStatusStep pollQuotaConsumedSubmissionStatusStep = + new PollQuotaConsumedSubmissionStatusStep( + rawlsService, samService, wdlPipelineConfiguration); + StepResult result = pollQuotaConsumedSubmissionStatusStep.undoStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } +} diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/SubmitQuotaConsumedSubmissionStepTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/SubmitQuotaConsumedSubmissionStepTest.java new file mode 100644 index 00000000..592e9b32 --- /dev/null +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/SubmitQuotaConsumedSubmissionStepTest.java @@ -0,0 +1,243 @@ +package bio.terra.pipelines.stairway.imputation.steps; + +import static bio.terra.pipelines.testutils.TestUtils.VALID_METHOD_CONFIGURATION; +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.when; + +import bio.terra.pipelines.common.utils.PipelinesEnum; +import bio.terra.pipelines.dependencies.rawls.RawlsService; +import bio.terra.pipelines.dependencies.rawls.RawlsServiceApiException; +import bio.terra.pipelines.dependencies.sam.SamService; +import bio.terra.pipelines.stairway.SubmitQuotaConsumedSubmissionStep; +import bio.terra.pipelines.stairway.imputation.ImputationJobMapKeys; +import bio.terra.pipelines.testutils.BaseEmbeddedDbTest; +import bio.terra.pipelines.testutils.StairwayTestUtils; +import bio.terra.pipelines.testutils.TestUtils; +import bio.terra.rawls.model.MethodConfiguration; +import bio.terra.rawls.model.MethodRepoMethod; +import bio.terra.rawls.model.SubmissionReport; +import bio.terra.rawls.model.SubmissionRequest; +import bio.terra.stairway.FlightContext; +import bio.terra.stairway.FlightMap; +import bio.terra.stairway.StepResult; +import bio.terra.stairway.StepStatus; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; + +class SubmitQuotaConsumedSubmissionStepTest extends BaseEmbeddedDbTest { + @Mock private RawlsService rawlsService; + @Captor private ArgumentCaptor submissionRequestCaptor; + @Captor private ArgumentCaptor updateMethodConfigCaptor; + @Captor private ArgumentCaptor setMethodConfigCaptor; + @Mock private SamService samService; + @Mock private FlightContext flightContext; + + private final UUID testJobId = TestUtils.TEST_NEW_UUID; + private final UUID randomUUID = UUID.randomUUID(); + + @BeforeEach + void setup() { + FlightMap inputParameters = new FlightMap(); + FlightMap workingMap = new FlightMap(); + + when(flightContext.getInputParameters()).thenReturn(inputParameters); + when(flightContext.getWorkingMap()).thenReturn(workingMap); + when(samService.getTeaspoonsServiceAccountToken()).thenReturn("thisToken"); + } + + @Test + void doStepSuccess() { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + MethodConfiguration returnedMethodConfiguration = new MethodConfiguration(); + when(rawlsService.getCurrentMethodConfigForMethod( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME)) + .thenReturn(returnedMethodConfiguration); + when(rawlsService.validateMethodConfig( + returnedMethodConfiguration, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME, + TestUtils.TEST_PIPELINE_INPUTS_DEFINITION_LIST, + TestUtils.TEST_PIPELINE_OUTPUTS_DEFINITION_LIST, + TestUtils.TEST_WDL_METHOD_VERSION_1)) + .thenReturn(true); + when(rawlsService.submitWorkflow( + eq("thisToken"), + submissionRequestCaptor.capture(), + eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), + eq(TestUtils.CONTROL_WORKSPACE_NAME))) + .thenReturn(new SubmissionReport().submissionId(randomUUID.toString())); + + // do the step + SubmitQuotaConsumedSubmissionStep submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + StepResult result = submitQuotaConsumedSubmissionStep.doStep(flightContext); + + // extract the captured RunSetRequest and validate + SubmissionRequest submissionRequest = submissionRequestCaptor.getValue(); + assertTrue(submissionRequest.isDeleteIntermediateOutputFiles()); + assertFalse(submissionRequest.isUseReferenceDisks()); + assertTrue(submissionRequest.isUseCallCache()); + assertEquals(testJobId.toString(), submissionRequest.getEntityName()); + + // make sure the step was a success + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + assertEquals( + randomUUID, + flightContext.getWorkingMap().get(ImputationJobMapKeys.QUOTA_SUBMISSION_ID, UUID.class)); + } + + @Test + void doStepWithInvalidMethodConfig() { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + // set up "current" method config to be version 1.1.1 with the corresponding method uri + MethodConfiguration returnedMethodConfiguration = + new MethodConfiguration() + .methodRepoMethod( + new MethodRepoMethod().methodUri("http/path/to/wdl/1.1.1").methodVersion("1.1.1")); + when(rawlsService.getCurrentMethodConfigForMethod( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME)) + .thenReturn(returnedMethodConfiguration); + when(rawlsService.validateMethodConfig( + returnedMethodConfiguration, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + TestUtils.TEST_WDL_METHOD_NAME_1, + TestUtils.TEST_PIPELINE_INPUTS_DEFINITION_LIST, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_OUTPUT_DEFINITION_LIST, + TestUtils.TEST_WDL_METHOD_VERSION_1)) + .thenReturn(false); + when(rawlsService.updateMethodConfigToBeValid( + updateMethodConfigCaptor.capture(), + eq(PipelinesEnum.ARRAY_IMPUTATION.getValue()), + eq(SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME), + eq(TestUtils.TEST_PIPELINE_INPUTS_DEFINITION_LIST), + eq(SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_OUTPUT_DEFINITION_LIST), + eq(TestUtils.TEST_WDL_METHOD_VERSION_1))) + .thenReturn(VALID_METHOD_CONFIGURATION); + when(rawlsService.setMethodConfigForMethod( + eq("thisToken"), + setMethodConfigCaptor.capture(), + eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), + eq(TestUtils.CONTROL_WORKSPACE_NAME), + eq(SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME))) + .thenReturn(null); + when(rawlsService.submitWorkflow( + eq("thisToken"), + any(SubmissionRequest.class), + eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), + eq(TestUtils.CONTROL_WORKSPACE_NAME))) + .thenReturn(new SubmissionReport().submissionId(randomUUID.toString())); + + // do the step + SubmitQuotaConsumedSubmissionStep submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + submitQuotaConsumedSubmissionStep.doStep(flightContext); + + // extract the captured updateMethodConfig input and setMethodConfig input and validate + MethodConfiguration updatedMethodConfigInput = updateMethodConfigCaptor.getValue(); + assertEquals("1.1.1", updatedMethodConfigInput.getMethodRepoMethod().getMethodVersion()); + assertEquals( + "http/path/to/wdl/1.1.1", updatedMethodConfigInput.getMethodRepoMethod().getMethodUri()); + + MethodConfiguration setMethodConfigInput = setMethodConfigCaptor.getValue(); + assertEquals(VALID_METHOD_CONFIGURATION, setMethodConfigInput); + } + + @Test + void doStepRawlsErrorRetry() { + // setup + StairwayTestUtils.constructCreateJobInputs(flightContext.getInputParameters()); + when(flightContext.getFlightId()).thenReturn(testJobId.toString()); + MethodConfiguration returnedMethodConfiguration = new MethodConfiguration(); + when(rawlsService.getCurrentMethodConfigForMethod( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME)) + .thenReturn(returnedMethodConfiguration); + when(rawlsService.validateMethodConfig( + returnedMethodConfiguration, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME, + TestUtils.TEST_PIPELINE_INPUTS_DEFINITION_LIST, + TestUtils.TEST_PIPELINE_OUTPUTS_DEFINITION_LIST, + TestUtils.TEST_WDL_METHOD_VERSION_1)) + .thenReturn(true); + + // throw exception on submitting workflow + when(rawlsService.submitWorkflow( + eq("thisToken"), + any(SubmissionRequest.class), + eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), + eq(TestUtils.CONTROL_WORKSPACE_NAME))) + .thenThrow(new RawlsServiceApiException("rawls is bad")); + // do the step + SubmitQuotaConsumedSubmissionStep submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + StepResult result = submitQuotaConsumedSubmissionStep.doStep(flightContext); + // assert step is marked as retryable + assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); + + // throw exception on setting method config to be valid + when(rawlsService.validateMethodConfig( + returnedMethodConfiguration, + PipelinesEnum.ARRAY_IMPUTATION.getValue(), + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME, + TestUtils.TEST_PIPELINE_INPUTS_DEFINITION_LIST, + TestUtils.TEST_PIPELINE_OUTPUTS_DEFINITION_LIST, + TestUtils.TEST_WDL_METHOD_VERSION_1)) + .thenReturn(false); + when(rawlsService.setMethodConfigForMethod( + "thisToken", + null, + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME)) + .thenThrow(new RawlsServiceApiException("rawls is bad")); + // do the step + submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + result = submitQuotaConsumedSubmissionStep.doStep(flightContext); + // assert step is marked as retryable + assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); + + // throw exception getting method config + when(rawlsService.getCurrentMethodConfigForMethod( + "thisToken", + TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT, + TestUtils.CONTROL_WORKSPACE_NAME, + SubmitQuotaConsumedSubmissionStep.QUOTA_CONSUMED_METHOD_NAME)) + .thenThrow(new RawlsServiceApiException("rawls is bad")); + // do the step + submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + result = submitQuotaConsumedSubmissionStep.doStep(flightContext); + // assert step is marked as retryable + assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); + } + + @Test + void undoStepSuccess() { + SubmitQuotaConsumedSubmissionStep submitQuotaConsumedSubmissionStep = + new SubmitQuotaConsumedSubmissionStep(rawlsService, samService); + StepResult result = submitQuotaConsumedSubmissionStep.undoStep(flightContext); + + assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); + } +} diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStepTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStepTest.java index b6b5167f..f7245104 100644 --- a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStepTest.java +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/PollCromwellSubmissionStatusStepTest.java @@ -63,7 +63,7 @@ void doStepSuccess() throws InterruptedException { // do the step PollCromwellSubmissionStatusStep pollCromwellSubmissionStatusStep = - new PollCromwellSubmissionStatusStep(samService, rawlsService, imputationConfiguration); + new PollCromwellSubmissionStatusStep(rawlsService, samService, imputationConfiguration); StepResult result = pollCromwellSubmissionStatusStep.doStep(flightContext); // make sure the step was a success @@ -94,7 +94,7 @@ void doStepRunningThenComplete() throws InterruptedException { // do the step PollCromwellSubmissionStatusStep pollCromwellSubmissionStatusStep = - new PollCromwellSubmissionStatusStep(samService, rawlsService, imputationConfiguration); + new PollCromwellSubmissionStatusStep(rawlsService, samService, imputationConfiguration); StepResult result = pollCromwellSubmissionStatusStep.doStep(flightContext); // make sure the step was a success @@ -121,7 +121,7 @@ void doStepNotAllSuccessfulRuns() throws InterruptedException { // do the step PollCromwellSubmissionStatusStep pollCromwellSubmissionStatusStep = - new PollCromwellSubmissionStatusStep(samService, rawlsService, imputationConfiguration); + new PollCromwellSubmissionStatusStep(rawlsService, samService, imputationConfiguration); StepResult result = pollCromwellSubmissionStatusStep.doStep(flightContext); // make sure the step fails @@ -142,7 +142,7 @@ void doStepRawlsApiErrorRetry() throws InterruptedException { // do the step, expect a Retry status PollCromwellSubmissionStatusStep pollCromwellSubmissionStatusStep = - new PollCromwellSubmissionStatusStep(samService, rawlsService, imputationConfiguration); + new PollCromwellSubmissionStatusStep(rawlsService, samService, imputationConfiguration); StepResult result = pollCromwellSubmissionStatusStep.doStep(flightContext); assertEquals(StepStatus.STEP_RESULT_FAILURE_RETRY, result.getStepStatus()); @@ -151,7 +151,7 @@ void doStepRawlsApiErrorRetry() throws InterruptedException { @Test void undoStepSuccess() { PollCromwellSubmissionStatusStep pollCromwellSubmissionStatusStep = - new PollCromwellSubmissionStatusStep(samService, rawlsService, imputationConfiguration); + new PollCromwellSubmissionStatusStep(rawlsService, samService, imputationConfiguration); StepResult result = pollCromwellSubmissionStatusStep.undoStep(flightContext); assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); diff --git a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStepTest.java b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStepTest.java index 47fa4e80..af8b3ca0 100644 --- a/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStepTest.java +++ b/service/src/test/java/bio/terra/pipelines/stairway/imputation/steps/gcp/SubmitCromwellSubmissionStepTest.java @@ -38,6 +38,7 @@ class SubmitCromwellSubmissionStepTest extends BaseEmbeddedDbTest { @Autowired private ImputationConfiguration imputationConfiguration; private final UUID testJobId = TestUtils.TEST_NEW_UUID; + private final UUID randomUUID = UUID.randomUUID(); @BeforeEach void setup() { @@ -74,7 +75,7 @@ void doStepSuccess() { submissionRequestCaptor.capture(), eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), eq(TestUtils.CONTROL_WORKSPACE_NAME))) - .thenReturn(new SubmissionReport().submissionId(testJobId.toString())); + .thenReturn(new SubmissionReport().submissionId(randomUUID.toString())); // do the step SubmitCromwellSubmissionStep submitCromwellSubmissionStep = @@ -91,7 +92,7 @@ void doStepSuccess() { // make sure the step was a success assertEquals(StepStatus.STEP_RESULT_SUCCESS, result.getStepStatus()); assertEquals( - testJobId, + randomUUID, flightContext.getWorkingMap().get(ImputationJobMapKeys.SUBMISSION_ID, UUID.class)); } @@ -139,7 +140,7 @@ void doStepWithInvalidMethodConfig() { any(SubmissionRequest.class), eq(TestUtils.CONTROL_WORKSPACE_BILLING_PROJECT), eq(TestUtils.CONTROL_WORKSPACE_NAME))) - .thenReturn(new SubmissionReport().submissionId(testJobId.toString())); + .thenReturn(new SubmissionReport().submissionId(randomUUID.toString())); // do the step SubmitCromwellSubmissionStep submitCromwellSubmissionStep = diff --git a/service/src/test/resources/application-test.yml b/service/src/test/resources/application-test.yml index 6fa74b99..e45ca530 100644 --- a/service/src/test/resources/application-test.yml +++ b/service/src/test/resources/application-test.yml @@ -70,3 +70,6 @@ pipelines: sentry: dsn: https://public@sentry.example.com/1 env: doesntmatter + + wdl: + quotaConsumedPollingIntervalSeconds: 1