Skip to content

Commit

Permalink
Add helper plugin to fetch models
Browse files Browse the repository at this point in the history
  • Loading branch information
manojlds committed Nov 11, 2018
1 parent 416f496 commit 2c8df51
Show file tree
Hide file tree
Showing 27 changed files with 658 additions and 2 deletions.
56 changes: 56 additions & 0 deletions fetch/src/main/java/com/indix/mlflow_gocd/FetchConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.indix.mlflow_gocd;


import java.util.Map;

public class FetchConfig {

private final String repo;
private final String pkg;
private final String artifactPattern;
private final String destination;
private final String destinationFile;

public String getRepo() {
return escapeEnvironmentVariable(repo);
}

public String getPkg() {
return escapeEnvironmentVariable(pkg);
}

public String getArtifactPattern() {
return artifactPattern;
}

public String getDestination() {
return destination;
}

public String getDestinationFile() {
return destinationFile;
}

public FetchConfig(Map config) {
repo = getValue(config, MLFLowFetchArtifactTask.REPO);
pkg = getValue(config, MLFLowFetchArtifactTask.PACKAGE);
artifactPattern = getValue(config, MLFLowFetchArtifactTask.ARTIFACT_PATTERN);
destination = getValue(config, MLFLowFetchArtifactTask.DESTINATION);
destinationFile = getValue(config, MLFLowFetchArtifactTask.DESTINATION_FILE);
}

private String escapeEnvironmentVariable(String value) {
if (value == null) {
return "";
}
return value.replaceAll("[^A-Za-z0-9_]", "_").toUpperCase();
}

private String getValue(Map config, String property) {
Map propertyMap = ((Map) config.get(property));
if (propertyMap != null) {
return (String) propertyMap.get("value");
}
return null;
}
}
74 changes: 74 additions & 0 deletions fetch/src/main/java/com/indix/mlflow_gocd/GoEnvironment.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.indix.mlflow_gocd;

import org.apache.commons.lang3.BooleanUtils;

import java.util.*;
import java.util.regex.Pattern;

import static org.apache.commons.lang3.StringUtils.isNotEmpty;

public class GoEnvironment {
public static final String AWS_USE_IAM_ROLE = "AWS_USE_IAM_ROLE";
public static final String AWS_REGION = "AWS_REGION";
public static final String AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY";
public static final String AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID";

private Pattern envPat = Pattern.compile("\\$\\{(\\w+)\\}");
private Map<String, String> environment = new HashMap<>();

public GoEnvironment() {
this.environment.putAll(System.getenv());
}

public GoEnvironment(Map<String, String> defaultEnvironment) {
this();
this.environment.putAll(defaultEnvironment);
}

public GoEnvironment putAll(Map<String, String> existing) {
environment.putAll(existing);
return this;
}

public Map<String,String> asMap() { return environment; }

public String get(String name) {
return environment.get(name);
}

public String getOrElse(String name, String defaultValue) {
if(has(name)) return get(name);
else return defaultValue;
}

public boolean has(String name) {
return environment.containsKey(name) && isNotEmpty(get(name));
}

public boolean isAbsent(String name) {
return !has(name);
}

private static final List<String> validUseIamRoleValues = new ArrayList<String>(Arrays.asList("true", "false", "yes", "no", "on", "off"));
public boolean hasAWSUseIamRole() {
if (!has(AWS_USE_IAM_ROLE)) {
return false;
}

String useIamRoleValue = get(AWS_USE_IAM_ROLE);
Boolean result = BooleanUtils.toBooleanObject(useIamRoleValue);
if (result == null) {
throw new IllegalArgumentException(getEnvInvalidFormatMessage(AWS_USE_IAM_ROLE,
useIamRoleValue, validUseIamRoleValues.toString()));
}
else {
return result.booleanValue();
}
}

private String getEnvInvalidFormatMessage(String environmentVariable, String value, String expected){
return String.format(
"Unexpected value in %s environment variable; was %s, but expected one of the following %s",
environmentVariable, value, expected);
}
}
241 changes: 241 additions & 0 deletions fetch/src/main/java/com/indix/mlflow_gocd/MLFLowFetchArtifactTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package com.indix.mlflow_gocd;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.InstanceProfileCredentialsProvider;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.ListObjectsRequest;
import com.amazonaws.services.s3.model.ObjectListing;
import com.amazonaws.services.s3.model.S3ObjectSummary;
import com.google.gson.GsonBuilder;
import com.indix.mlflow_gocd.models.TaskExecutionResult;
import com.thoughtworks.go.plugin.api.GoApplicationAccessor;
import com.thoughtworks.go.plugin.api.GoPlugin;
import com.thoughtworks.go.plugin.api.GoPluginIdentifier;
import com.thoughtworks.go.plugin.api.annotation.Extension;
import com.thoughtworks.go.plugin.api.exceptions.UnhandledRequestTypeException;
import com.thoughtworks.go.plugin.api.logging.Logger;
import com.thoughtworks.go.plugin.api.request.GoPluginApiRequest;
import com.thoughtworks.go.plugin.api.response.DefaultGoPluginApiResponse;
import com.thoughtworks.go.plugin.api.response.GoPluginApiResponse;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static com.indix.mlflow_gocd.GoEnvironment.AWS_ACCESS_KEY_ID;
import static com.indix.mlflow_gocd.GoEnvironment.AWS_REGION;
import static com.indix.mlflow_gocd.GoEnvironment.AWS_SECRET_ACCESS_KEY;

@Extension
public class MLFLowFetchArtifactTask implements GoPlugin {

public static final String REPO = "Repo";
public static final String PACKAGE = "Package";
public static final String ARTIFACT_PATTERN = "ArtifactPattern";
public static final String DESTINATION = "Destination";
public static final String DESTINATION_FILE = "DestinationFile";

private static Logger logger = Logger.getLoggerFor(MLFLowFetchArtifactTask.class);

@Override
public void initializeGoApplicationAccessor(GoApplicationAccessor goApplicationAccessor) {

}

@Override
public GoPluginApiResponse handle(GoPluginApiRequest request) throws UnhandledRequestTypeException {
if ("configuration".equals(request.requestName())) {
return handleGetConfigRequest();
} else if ("validate".equals(request.requestName())) {
return handleValidation(request);
} else if ("execute".equals(request.requestName())) {
return handleTaskExecution(request);
} else if ("view".equals(request.requestName())) {
return handleTaskView();
}
throw new UnhandledRequestTypeException(request.requestName());
}

private GoPluginApiResponse handleTaskView() {
int responseCode = DefaultGoPluginApiResponse.SUCCESS_RESPONSE_CODE;
Map view = new HashMap();
view.put("displayValue", "Fetch artifacts from MLFlow run");
try {
view.put("template", IOUtils.toString(getClass().getResourceAsStream("/views/task.template.html"), "UTF-8"));
} catch (Exception e) {
responseCode = DefaultGoPluginApiResponse.INTERNAL_ERROR;
String errorMessage = "Failed to find template: " + e.getMessage();
view.put("exception", errorMessage);
logger.error(errorMessage, e);
}
return createResponse(responseCode, view);
}

private GoPluginApiResponse handleTaskExecution(GoPluginApiRequest request) {
TaskExecutionResult result;
try {
Map executionRequest = (Map) new GsonBuilder().create().fromJson(request.requestBody(), Object.class);
Map configMap = (Map) executionRequest.get("config");
TaskContext context = new TaskContext((Map) executionRequest.get("context"));
FetchConfig config = new FetchConfig(configMap);
final GoEnvironment env = new GoEnvironment(context.getEnvironmentVariables());
final AmazonS3 client = getS3client(env);
final String artifactsUri = env.get(String.format("GO_PACKAGE_%s_%s_ARTIFACT_URI", config.getRepo(), config.getPkg()));
context.printMessage(String.format("Artifacts uri for %s - %s is %s", config.getRepo(), config.getPkg(), artifactsUri));
final String bucketName = artifactsUri.split("//|/")[1];
final String prefix = artifactsUri.replace(String.format("s3://%s/", bucketName), "");
context.printMessage(String.format("Looking for artifact with pattern %s in prefix %s", config.getArtifactPattern(), prefix));
final String s3Prefix = getPrefixS3(client, bucketName, prefix, config.getArtifactPattern());
if (s3Prefix != null) {
String destination = String.format("%s/%s", context.getWorkingDir(), config.getDestination());
if (StringUtils.isNotBlank(config.getDestination())) {
setupDestinationDirectory(destination);
}
if (StringUtils.isNotBlank(config.getDestinationFile())) {
destination = String.format("%s/%s", destination.replaceFirst("/$", ""), config.getDestinationFile());
}
context.printMessage(String.format("Getting artifacts from s3://%s/%s to %s", bucketName, s3Prefix, destination));
getS3(client, bucketName, s3Prefix, destination);
result = new TaskExecutionResult(true, "Fetched all artifacts");
} else {
result = new TaskExecutionResult(false, "Specified artifacts not found");
}


} catch(Exception ex) {
String message = String.format("Failure while downloading artifacts - %s", ex.getMessage());
logger.error(message, ex);
result = new TaskExecutionResult(false, message, ex);
}

return createResponse(result.responseCode(), result.toMap());

}

private GoPluginApiResponse handleValidation(GoPluginApiRequest request) {
Map configMap = (Map) new GsonBuilder().create().fromJson(request.requestBody(), Object.class);
FetchConfig config = new FetchConfig(configMap);

Map<String, String> errors = new HashMap<>();
if (StringUtils.isBlank(config.getRepo())) {
errors.put(REPO, "This field is required");
}
if (StringUtils.isBlank(config.getPkg())) {
errors.put(PACKAGE, "This field is required");
}

if (StringUtils.isBlank(config.getArtifactPattern())) {
errors.put(ARTIFACT_PATTERN, "This field is required");
}
final HashMap validationResult = new HashMap();
if (!errors.isEmpty()) {
validationResult.put("errors", errors);
}

return createResponse(DefaultGoPluginApiResponse.SUCCESS_RESPONSE_CODE, validationResult);
}

private GoPluginApiResponse handleGetConfigRequest() {
HashMap config = new HashMap();

HashMap repo = new HashMap();
repo.put("default-value", "");
repo.put("required", true);
config.put(REPO, repo);

HashMap pkg = new HashMap();
pkg.put("default-value", "");
pkg.put("required", true);
config.put(PACKAGE, pkg);

HashMap sourcePrefix = new HashMap();
sourcePrefix.put("default-value", "");
sourcePrefix.put("required", true);
config.put(ARTIFACT_PATTERN, sourcePrefix);

HashMap destination = new HashMap();
destination.put("default-value", "");
destination.put("required", false);
config.put(DESTINATION, destination);

HashMap destinationFile = new HashMap();
destinationFile.put("default-value", "");
destinationFile.put("required", false);
config.put(DESTINATION_FILE, destinationFile);

return createResponse(DefaultGoPluginApiResponse.SUCCESS_RESPONSE_CODE, config);
}

@Override
public GoPluginIdentifier pluginIdentifier() {
return new GoPluginIdentifier("task", Arrays.asList("1.0"));
}

private GoPluginApiResponse createResponse(int responseCode, Map body) {
final DefaultGoPluginApiResponse response = new DefaultGoPluginApiResponse(responseCode);
response.setResponseBody(new GsonBuilder().serializeNulls().create().toJson(body));
return response;
}

private void setupDestinationDirectory(String destination) {
File destinationDirectory = new File(destination);
try {
if(!destinationDirectory.exists()) {
FileUtils.forceMkdir(destinationDirectory);
}
} catch (IOException ioe) {
logger.error(String.format("Error while setting up destination - %s", ioe.getMessage()), ioe);
}
}

private static AmazonS3 getS3client(GoEnvironment env) {
AmazonS3ClientBuilder amazonS3ClientBuilder = AmazonS3ClientBuilder.standard();

if (env.has(AWS_REGION)) {
amazonS3ClientBuilder.withRegion(env.get(AWS_REGION));
}
if (env.hasAWSUseIamRole()) {
amazonS3ClientBuilder.withCredentials(new InstanceProfileCredentialsProvider(false));
} else if (env.has(AWS_ACCESS_KEY_ID) && env.has(AWS_SECRET_ACCESS_KEY)) {
BasicAWSCredentials basicCreds = new BasicAWSCredentials(env.get(AWS_ACCESS_KEY_ID), env.get(AWS_SECRET_ACCESS_KEY));
amazonS3ClientBuilder.withCredentials(new AWSStaticCredentialsProvider(basicCreds));
}

return amazonS3ClientBuilder.build();
}

public String getPrefixS3(AmazonS3 client, String bucket, String prefix, String artifactPattern) {
ListObjectsRequest listObjectsRequest = new ListObjectsRequest()
.withBucketName(bucket)
.withPrefix(prefix);

ObjectListing objectListing;
do {
objectListing = client.listObjects(listObjectsRequest);
for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) {
if (objectSummary.getSize() > 0 && objectSummary.getKey().matches(String.format("%s/%s", prefix, artifactPattern))) {
return objectSummary.getKey();
}
}
listObjectsRequest.setMarker(objectListing.getNextMarker());
} while (objectListing.isTruncated());

return null;
}

public void getS3(AmazonS3 client, String bucket, String from, String to) {
GetObjectRequest getObjectRequest = new GetObjectRequest(bucket, from);
File destinationFile = new File(to);
destinationFile.getParentFile().mkdirs();
client.getObject(getObjectRequest, destinationFile);
}

}
Loading

0 comments on commit 2c8df51

Please sign in to comment.