diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..f5c99a7 --- /dev/null +++ b/.travis.yml @@ -0,0 +1 @@ +language: java \ No newline at end of file diff --git a/README.md b/README.md index d904797..6f1ba37 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ Java client stub for a [PASE](https://github.com/aminfa/Pase) server. ## Code Example Take a look at the code example section of the [PASE repository](https://github.com/aminfa/Pase). -The same operations can execute using a `PaseInstance`: +The same operations can be executed using a `PaseInstance`: ```java PaseInstance instance = new PaseInstance("localhost:5000"); // specify host diff --git a/pom.xml b/pom.xml index 9b08727..a5eecc6 100644 --- a/pom.xml +++ b/pom.xml @@ -12,6 +12,8 @@ UTF-8 + 1.8 + 1.8 diff --git a/src/main/java/de/upb/pasestub/PaseInstance.java b/src/main/java/de/upb/pasestub/PaseInstance.java index 78595b6..9b984a4 100644 --- a/src/main/java/de/upb/pasestub/PaseInstance.java +++ b/src/main/java/de/upb/pasestub/PaseInstance.java @@ -1,13 +1,13 @@ package de.upb.pasestub; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; - import java.io.IOException; import java.util.HashMap; import java.util.Map; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; @@ -17,7 +17,7 @@ /** * Mutable PaseInterface Implementation. */ -public final class PaseInstance implements PaseInterface{ +public final class PaseInstance implements PaseInterface { /** * Flag that indicated that the create function has beed called successfully. */ @@ -39,86 +39,96 @@ public final class PaseInstance implements PaseInterface{ /** * PaseInstance defined with the given http host. (Don't include http:// in host) */ - public PaseInstance(String host){ + public PaseInstance(String host) { this.host = host; } /** * PaseInstance defined to access the Pase with standard port running on the same machine. */ - public PaseInstance(){ + public PaseInstance() { this("localhost:5000"); } + /** + * Constructor is used by 'copy' method to initialize a 'created' PaseInstance object. + */ + private PaseInstance(String host, String className, String id){ + this(host); + this.className = className; + this.id = id; + this.creationFlag = true; + } + // GETTERS: /** * Returns the pase instance id. */ - public String getId(){ + public String getId() { checkCreated(); // may throw Exception. return id; } + /** * Returns class Name that was assigned */ - public String getClassName(){ + public String getClassName() { checkCreated(); // may throw Exception. return className; } + /** * Returns the host url that this object was assigned to use. */ - public String getHost(){ + public String getHost() { return host; } /** * Returns the instance url that is used to access this instance on the pase server. */ - String getInstanceUrl(){ + String getInstanceUrl() { checkCreated(); // may throw Exception. return getHost() + "/" + getClassName() + "/" + getId(); } // INTERFACE: @Override - public boolean create(String constructor, Map parameters) - throws JsonProcessingException, IOException { - if(isCreated()){ + public boolean create(String constructor, Map parameters) + throws JsonProcessingException, IOException { + if (isCreated()) { // create was already called. Stop create - throw createAlreadyCalled(); + throw createAlreadyCalled(); } - if(constructor == null || constructor.trim().isEmpty() || parameters == null){ + if (constructor == null || constructor.trim().isEmpty() || parameters == null) { throw new NullPointerException(); } String jsonString = serialize(parameters); Response serverResponse = httpPost(host + "/" + constructor, jsonString); - if(serverResponse.code() != 200){ + if (serverResponse.code() != 200) { return false; } Map returnValues = deserializeMap(serverResponse.body().string()); - if(returnValues.containsKey("id") && returnValues.containsKey("class")){ + if (returnValues.containsKey("id") && returnValues.containsKey("class")) { id = returnValues.get("id").toString(); className = returnValues.get("class").toString(); creationFlag = true; return true; - } - else{ + } else { return false; } } - + @Override - public Object getAttribute(String attributeName) - throws IOException, JsonProcessingException{ - checkCreated(); - if(attributeName == null || attributeName.trim().isEmpty()){ + public Object getAttribute(String attributeName) throws IOException, JsonProcessingException { + checkCreated(); + if (attributeName == null || attributeName.trim().isEmpty()) { throw new NullPointerException(); } Response serverResponse = httpGet(getInstanceUrl() + "/" + attributeName); - if(serverResponse.code() != 200){ + if (serverResponse.code() != 200) { throw responseErrorCode(serverResponse); } Object pojo = deserializeObject(serverResponse.body().string()); @@ -126,36 +136,44 @@ public Object getAttribute(String attributeName) } @Override - public Object callFunction(String functionName, Map parameters) - throws JsonProcessingException, IOException{ - checkCreated(); - if(functionName == null || functionName.trim().isEmpty() || parameters == null){ + public Object callFunction(String functionName, Map parameters) + throws JsonProcessingException, IOException { + checkCreated(); + if (functionName == null || functionName.trim().isEmpty() || parameters == null) { throw new NullPointerException(); } String jsonString = serialize(parameters); Response serverResponse = httpPost(getInstanceUrl() + "/" + functionName, jsonString); - if(serverResponse.code() != 200){ + if (serverResponse.code() != 200) { throw responseErrorCode(serverResponse); } Object pojo = deserializeObject(serverResponse.body().string()); return pojo; } - + @Override + public PaseInterface cloneObject() throws JsonProcessingException, IOException { + checkCreated(); + Response serverResponse = httpGet(host + "/" + getClassName() + "/copy/" + getId()); + if (serverResponse.code() != 200) { + throw responseErrorCode(serverResponse); + } + Map returnMap = deserializeMap(serverResponse.body().string()); + String newClassName = returnMap.get("class").toString(); + String newId = returnMap.get("id").toString(); + return new PaseInstance(host, newClassName, newId); + } // HELPER FUNCTIONS: /** * Handles basic http post using OkHttp. */ - private Response httpPost(String url, String bodyString) throws IOException{ + private Response httpPost(String url, String bodyString) throws IOException { OkHttpClient client = new OkHttpClient(); MediaType mediaType = MediaType.parse("application/json"); RequestBody body = RequestBody.create(mediaType, bodyString); - Request request = new Request.Builder() - .url("http://" + url) - .post(body) - .addHeader("content-type", "application/json") - .build(); + Request request = new Request.Builder().url("http://" + url).post(body) + .addHeader("content-type", "application/json").build(); Response response = client.newCall(request).execute(); return response; } @@ -163,27 +181,23 @@ private Response httpPost(String url, String bodyString) throws IOException{ /** * Handles basic http get using OkHttp. */ - private Response httpGet(String url) throws IOException{ + private Response httpGet(String url) throws IOException { OkHttpClient client = new OkHttpClient(); - Request request = new Request.Builder() - .url("http://" + url) - .addHeader("content-type", "application/json") - .build(); + Request request = new Request.Builder().url("http://" + url).addHeader("content-type", "application/json") + .build(); Response response = client.newCall(request).execute(); return response; } - //TODO: custom json parser if there are any problems parsing objects. /** * JSON-Serializes the given map. */ - private String serialize(Map map) throws JsonProcessingException{ + private String serialize(Map map) throws JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); - String jsonResult = mapper.writerWithDefaultPrettyPrinter() - .writeValueAsString(map); + String jsonResult = mapper.writerWithDefaultPrettyPrinter().writeValueAsString(map); return jsonResult; } @@ -192,18 +206,19 @@ private String serialize(Map map) throws JsonProcessingException */ private Map deserializeMap(String jsonString) throws IOException, JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); - TypeReference> typeRef - = new TypeReference>() {}; + TypeReference> typeRef = new TypeReference>() { + }; Map map = mapper.readValue(jsonString, typeRef); return map; } + /** * JSON-Deserializes the given json string to a pojo. */ private Object deserializeObject(String jsonString) throws IOException, JsonProcessingException { ObjectMapper mapper = new ObjectMapper(); - TypeReference typeRef - = new TypeReference() {}; + TypeReference typeRef = new TypeReference() { + }; Object pojo = mapper.readValue(jsonString, typeRef); return pojo; } @@ -211,15 +226,15 @@ private Object deserializeObject(String jsonString) throws IOException, JsonProc /** * Returns true, if the create method has beed called successfully. */ - boolean isCreated(){ + boolean isCreated() { return creationFlag; } /** * Throws RuntimeException if create hasn't been called before. */ - private void checkCreated(){ - if(!isCreated()){ + private void checkCreated() { + if (!isCreated()) { throw createNotCalled(); } } @@ -228,7 +243,7 @@ private void checkCreated(){ /** * Creates a IllegalArgumentException for Server responses that aren't 200. */ - private IllegalArgumentException responseErrorCode(Response serverResponse) throws IOException{ + private IllegalArgumentException responseErrorCode(Response serverResponse) throws IOException { return new IllegalArgumentException(serverResponse.code() + "\n: " + serverResponse.body().string()); } @@ -236,14 +251,16 @@ private IllegalArgumentException responseErrorCode(Response serverResponse) thro * Creates a IllegalStateException that indicates that the create function has not been called before. * Used in functions where create-state is mandatory. */ - private IllegalStateException createNotCalled(){ + private IllegalStateException createNotCalled() { return new IllegalStateException("create function was not called."); } + /** * Creates a IllegalStateException that indicates that the create function has already been called. * Used in the create function to avoid calling it twice. */ - private IllegalStateException createAlreadyCalled(){ + private IllegalStateException createAlreadyCalled() { return new IllegalStateException("create function has already been called."); } + } \ No newline at end of file diff --git a/src/main/java/de/upb/pasestub/PaseInterface.java b/src/main/java/de/upb/pasestub/PaseInterface.java index 0b8baf2..defb1eb 100644 --- a/src/main/java/de/upb/pasestub/PaseInterface.java +++ b/src/main/java/de/upb/pasestub/PaseInterface.java @@ -1,14 +1,15 @@ package de.upb.pasestub; +import java.io.IOException; import java.util.Map; + import com.fasterxml.jackson.core.JsonProcessingException; -import java.io.IOException; /** * Defines java client stub interface to connect to a pase server. * Objects of PaseInterface must call the create function before getAttribute or callFunction, else an Illegal State Exception will be thrown. */ -public interface PaseInterface{ +public interface PaseInterface { /** * Creates this interface through a http request to a pase server. This is the first function to be called. @@ -19,7 +20,7 @@ public interface PaseInterface{ * @throws IOException when there are problems connecting to the pase server. */ public boolean create(String constructor, Map parameters) - throws JsonProcessingException, IOException; + throws JsonProcessingException, IOException; /** * Retrieves the value of the attribute with the given name from the server. @@ -30,9 +31,8 @@ public boolean create(String constructor, Map parameters) * @throws IOException when there are problems connecting to the pase server. * */ - public Object getAttribute(String attributeName) - throws IOException, JsonProcessingException; - + public Object getAttribute(String attributeName) throws IOException, JsonProcessingException; + /** * Calls the function with the given functionName on the server and returns it's value. * @param functionName: function name to be called. Will be used in the http request to the pase server. @@ -43,7 +43,16 @@ public Object getAttribute(String attributeName) * @throws IOException when there are problems connecting to the pase server. * */ - public Object callFunction(String functionName, Map parameters) - throws JsonProcessingException, IOException; + public Object callFunction(String functionName, Map parameters) + throws JsonProcessingException, IOException; + + /** + * Copies this object by making a 'copy' request to the Pase Server. + * + * @return A new PaseInterface object with the same class name but different id. + * + */ + public PaseInterface cloneObject() throws JsonProcessingException, IOException; + } \ No newline at end of file diff --git a/src/test/java/de/upb/pasestub/AllTestsSuite.java b/src/test/java/de/upb/pasestub/AllTestsSuite.java index c371208..31e261f 100644 --- a/src/test/java/de/upb/pasestub/AllTestsSuite.java +++ b/src/test/java/de/upb/pasestub/AllTestsSuite.java @@ -5,6 +5,6 @@ @RunWith(Suite.class) @Suite.SuiteClasses({ de.upb.pasestub.PaseInstanceTest.class - ,de.upb.pasestub.DeployTest.class // Uncomment if a pase server is running on port 5000 + //,de.upb.pasestub.DeployTest.class // Uncomment if a pase server is running on port 5000 }) public final class AllTestsSuite {} \ No newline at end of file diff --git a/src/test/java/de/upb/pasestub/DeployTest.java b/src/test/java/de/upb/pasestub/DeployTest.java index 3f71517..342c508 100644 --- a/src/test/java/de/upb/pasestub/DeployTest.java +++ b/src/test/java/de/upb/pasestub/DeployTest.java @@ -1,71 +1,112 @@ package de.upb.pasestub; - import java.util.ArrayList; -import java.util.List; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.junit.Assert; import org.junit.Test; -import static org.hamcrest.CoreMatchers.*; -import static org.junit.Assert.*; +/** + * Tests that are executed if there is a Pase server running in the background. + */ +public class DeployTest { -public class DeployTest{ + private static final double PRECISION = 0.01; - /** - * Tests basic functionality. - */ - @Test - public void deployTest1() throws Exception{ - PaseInstance instance = new PaseInstance("localhost:5000"); - Map parameters = new HashMap(); - parameters.put("a", 5); - parameters.put("b", 20); - boolean success = instance.create("plainlib.package1.b.B", parameters); + /** + * Tests basic functionality. + */ + @Test + public void deployTest1() throws Exception { + PaseInstance instance = new PaseInstance("localhost:5000"); + Map parameters = new HashMap(); + parameters.put("a", 5); + parameters.put("b", 20); + boolean success = instance.create("plainlib.package1.b.B", parameters); - Assert.assertTrue(success); + Assert.assertTrue(success); - System.out.println(instance.getInstanceUrl()); + // System.out.println(instance.getInstanceUrl()); - int a = (Integer) instance.getAttribute("a"); - Assert.assertEquals(a, 5); + int a = (Integer) instance.getAttribute("a"); + Assert.assertEquals(a, 5); - parameters = new HashMap(); - parameters.put("c", 2); - int result = (Integer) instance.callFunction("calc", parameters); + parameters = new HashMap(); + parameters.put("c", 2); + int result = (Integer) instance.callFunction("calc", parameters); Assert.assertEquals(result, 45); - } + + PaseInstance instance2 = (PaseInstance) instance.cloneObject(); + Assert.assertEquals(instance.getClassName(), instance2.getClassName()); + Assert.assertNotEquals(instance.getId(), instance2.getId()); + + } + + @Test + public void deployTest_LinearRegression() throws Exception { + PaseInstance instance = new PaseInstance("localhost:5000"); + Map parameters = new HashMap(); + parameters.put("normalize", true); + boolean success = instance.create("sklearn.linear_model.LinearRegression", parameters); + Assert.assertTrue(success && instance.isCreated()); + + parameters.clear(); + double[][] X = { { 0, 0 }, { 1, 1 }, { 2, 2 } }; + parameters.put("X", X); + double[] y = { 0, 1, 2 }; + parameters.put("y", y); + instance.callFunction("fit", parameters); + + // You will have to know the structure of the return value: + ArrayList coef_ = (ArrayList) instance.getAttribute("coef_"); + Assert.assertEquals(0.5, (double) coef_.get(0), PRECISION); + + parameters.clear(); + double[][] X2 = { { 0.5, 1 }, { 1, 0.5 } }; + parameters.put("X", X2); + ArrayList predictions = (ArrayList) instance.callFunction("predict", parameters); + List expected = Arrays.asList(0.75, 0.75); + Assert.assertEquals(expected.get(0), (double) predictions.get(0), PRECISION); + + } + + /** + * Tests this peace of python code: + * >>> from sklearn import linear_model + * >>> reg = linear_model.Ridge (alpha = .5) + * >>> reg.fit ([[0, 0], [0, 0], [1, 1]], [0, .1, 1]) + * Ridge(alpha=0.5, copy_X=True, fit_intercept=True, max_iter=None, + * normalize=False, random_state=None, solver='auto', tol=0.001) + * >>> reg.predict([[1,2],[10,20],[100,200]]) + * array([ 1.17272727, 10.5 , 103.77272727]) + */ @Test - public void deployTest_LinearRegression() throws Exception { - PaseInstance instance = new PaseInstance("localhost:5000"); - Map parameters = new HashMap(); - parameters.put("normalize", true); - boolean success = instance.create("sklearn.linear_model.LinearRegression", parameters); + public void deployTest_Ridge() throws Exception { + PaseInstance ridge = new PaseInstance(); + Map params = new HashMap(); + params.put("alpha", 0.5); + ridge.create("sklearn.linear_model.Ridge", params); + Assert.assertTrue(ridge.isCreated()); - Assert.assertTrue(success && instance.isCreated()); - - parameters.clear(); - double[][] X = {{0,0}, {1,1}, {2,2}}; - parameters.put("X", X); - double [] y = {0,1,2}; - parameters.put("y", y); - instance.callFunction("fit", parameters); - - // You will have to know the structure of the return value: - Map returnedMap = (Map) instance.getAttribute("coef_"); - ArrayList coef_ = (ArrayList) returnedMap.get("values"); - Assert.assertEquals(0.5, (double) coef_.get(0), 0.01); - - parameters.clear(); - double[][] X2 = {{0.5, 1}, {1, 0.5}}; - parameters.put("X", X2); - Map returnedMap2 = (Map) instance.callFunction("predict", parameters); - ArrayList predictions = (ArrayList) returnedMap2.get("values"); - List expected = Arrays.asList(0.75, 0.75); - assertThat(predictions, is(expected)); + params.clear(); + float[][] X_fit = {{0,0},{0,0},{1,1}}; + float[] y_fit = {0,0.1f,1}; + params.put("X", X_fit); + params.put("y", y_fit); + ridge.callFunction("fit", params); + + params.clear(); + float[][] X_predict = {{1,2},{10,20},{100,200}}; + params.put("X", X_predict); + ArrayList predictionResults = (ArrayList) ridge.callFunction("predict", params); + + List expected = Arrays.asList(1.17272727, 10.5 , 103.77272727); + for(int index = 0, size = predictionResults.size(); index()); } + @Test + public void cloneTest() throws Exception{ + + stubFor(post(urlEqualTo("/construct")) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"id\": \"7B495ECC9C\", \"class\": \"class_name\"}"))); + + stubFor(get(urlEqualTo("/class_name/copy/7B495ECC9C")) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"id\": \"7B495ECC99\", \"class\": \"class_name\"}"))); + + PaseInstance instance = new PaseInstance(host); + instance.create("construct", new HashMap<>()); + PaseInstance instanceClone = (PaseInstance) instance.cloneObject(); + + Assert.assertEquals(instance.getClassName(), instanceClone.getClassName()); + Assert.assertNotEquals(instance.getId(), instanceClone.getId()); + } }