Skip to content

Commit

Permalink
Wrote some tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Amin F committed Dec 28, 2017
1 parent d474f52 commit fdf5181
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 45 deletions.
29 changes: 18 additions & 11 deletions src/main/java/de/upb/pasestub/PaseInstance.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ public boolean create(String constructor, Map<String, Object> parameters)
// create was already called. Stop create
throw createAlreadyCalled();
}
if(constructor == null || constructor.trim().isEmpty() || parameters == null){
throw new NullPointerException();
}
String jsonString = serialize(parameters);
Response serverResponse = httpPost(host + "/" + constructor, jsonString);
if(serverResponse.code() != 200){
Expand All @@ -111,6 +114,9 @@ public boolean create(String constructor, Map<String, Object> parameters)
public Object getAttribute(String attributeName)
throws IOException, JsonProcessingException{
checkCreated();
if(attributeName == null || attributeName.trim().isEmpty()){
throw new NullPointerException();
}
Response serverResponse = httpGet(getInstanceUrl() + "/" + attributeName);
if(serverResponse.code() != 200){
throw responseErrorCode(serverResponse);
Expand All @@ -123,6 +129,9 @@ public Object getAttribute(String attributeName)
public Object callFunction(String functionName, Map<String, Object> parameters)
throws JsonProcessingException, IOException{
checkCreated();
if(functionName == null || functionName.trim().isEmpty() || parameters == null){
throw new NullPointerException();
}
String jsonString = serialize(parameters);
Response serverResponse = httpPost(getInstanceUrl() + "/" + functionName, jsonString);
if(serverResponse.code() != 200){
Expand All @@ -138,7 +147,7 @@ public Object callFunction(String functionName, Map<String, Object> parameters)
/**
* Handles basic http post using OkHttp.
*/
Response httpPost(String url, String bodyString) throws IOException{
private Response httpPost(String url, String bodyString) throws IOException{
OkHttpClient client = new OkHttpClient();
MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, bodyString);
Expand All @@ -154,7 +163,7 @@ Response httpPost(String url, String bodyString) throws IOException{
/**
* Handles basic http get using OkHttp.
*/
Response httpGet(String url) throws IOException{
private Response httpGet(String url) throws IOException{
OkHttpClient client = new OkHttpClient();

Request request = new Request.Builder()
Expand All @@ -171,7 +180,7 @@ Response httpGet(String url) throws IOException{
/**
* JSON-Serializes the given map.
*/
String serialize(Map<String, Object> map) throws JsonProcessingException{
private String serialize(Map<String, Object> map) throws JsonProcessingException{
ObjectMapper mapper = new ObjectMapper();
String jsonResult = mapper.writerWithDefaultPrettyPrinter()
.writeValueAsString(map);
Expand All @@ -181,7 +190,7 @@ String serialize(Map<String, Object> map) throws JsonProcessingException{
/**
* JSON-Deserializes the given json string to a map.
*/
Map<String, Object> deserializeMap(String jsonString) throws IOException, JsonProcessingException {
private Map<String, Object> deserializeMap(String jsonString) throws IOException, JsonProcessingException {
ObjectMapper mapper = new ObjectMapper();
TypeReference<HashMap<String, Object>> typeRef
= new TypeReference<HashMap<String, Object>>() {};
Expand All @@ -191,7 +200,7 @@ Map<String, Object> deserializeMap(String jsonString) throws IOException, JsonPr
/**
* JSON-Deserializes the given json string to a pojo.
*/
Object deserializeObject(String jsonString) throws IOException, JsonProcessingException {
private Object deserializeObject(String jsonString) throws IOException, JsonProcessingException {
ObjectMapper mapper = new ObjectMapper();
TypeReference<Object> typeRef
= new TypeReference<Object>() {};
Expand All @@ -209,7 +218,7 @@ boolean isCreated(){
/**
* Throws RuntimeException if create hasn't been called before.
*/
void checkCreated(){
private void checkCreated(){
if(!isCreated()){
throw createNotCalled();
}
Expand All @@ -219,24 +228,22 @@ void checkCreated(){
/**
* Creates a IllegalArgumentException for Server responses that aren't 200.
*/
IllegalArgumentException responseErrorCode(Response serverResponse) throws IOException{
private IllegalArgumentException responseErrorCode(Response serverResponse) throws IOException{
return new IllegalArgumentException(serverResponse.code() + "\n: " + serverResponse.body().string());
}

/**
* Creates a IllegalStateException that indicates that the create function has not been called before.
* Used in functions where create-state is mandatory.
*/
IllegalStateException createNotCalled(){
private IllegalStateException createNotCalled(){
return new IllegalStateException("create function was not called.");
}
/**
* Creates a IllegalStateException that indicates that the create function has already been called.
* Used in the create function to avoid calling it twice.
*/
IllegalStateException createAlreadyCalled(){
private IllegalStateException createAlreadyCalled(){
return new IllegalStateException("create function has already been called.");
}


}
1 change: 1 addition & 0 deletions src/test/java/de/upb/pasestub/AllTestsSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
@RunWith(Suite.class)
@Suite.SuiteClasses({
de.upb.pasestub.PaseInstanceTest.class
,de.upb.pasestub.DeployTest.class // Uncomment if a pase server is running on port 5000
})
public final class AllTestsSuite {}
71 changes: 71 additions & 0 deletions src/test/java/de/upb/pasestub/DeployTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package de.upb.pasestub;


import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.junit.Assert;
import org.junit.Test;

import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;

public class DeployTest{

/**
* Tests basic functionality.
*/
@Test
public void deployTest1() throws Exception{
PaseInstance instance = new PaseInstance("localhost:5000");
Map<String, Object> parameters = new HashMap<String, Object>();
parameters.put("a", 5);
parameters.put("b", 20);
boolean success = instance.create("plainlib.package1.b.B", parameters);

Assert.assertTrue(success);

System.out.println(instance.getInstanceUrl());

int a = (Integer) instance.getAttribute("a");
Assert.assertEquals(a, 5);

parameters = new HashMap<String, Object>();
parameters.put("c", 2);
int result = (Integer) instance.callFunction("calc", parameters);
Assert.assertEquals(result, 45);
}

@Test
public void deployTest_LinearRegression() throws Exception {
PaseInstance instance = new PaseInstance("localhost:5000");
Map<String, Object> parameters = new HashMap<String, Object>();
parameters.put("normalize", true);
boolean success = instance.create("sklearn.linear_model.LinearRegression", parameters);

Assert.assertTrue(success && instance.isCreated());

parameters.clear();
double[][] X = {{0,0}, {1,1}, {2,2}};
parameters.put("X", X);
double [] y = {0,1,2};
parameters.put("y", y);
instance.callFunction("fit", parameters);

// You will have to know the structure of the return value:
Map<String, Object> returnedMap = (Map<String,Object>) instance.getAttribute("coef_");
ArrayList<Double> coef_ = (ArrayList<Double>) returnedMap.get("values");
Assert.assertEquals(0.5, (double) coef_.get(0), 0.01);

parameters.clear();
double[][] X2 = {{0.5, 1}, {1, 0.5}};
parameters.put("X", X2);
Map<String, Object> returnedMap2 = (Map<String,Object>) instance.callFunction("predict", parameters);
ArrayList<Double> predictions = (ArrayList<Double>) returnedMap2.get("values");
List<Double> expected = Arrays.asList(0.75, 0.75);
assertThat(predictions, is(expected));
}
}
103 changes: 69 additions & 34 deletions src/test/java/de/upb/pasestub/PaseInstanceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,55 +20,90 @@ public class PaseInstanceTest{

public PaseInstanceTest(){
}

private int port = 30000;
private String host = "localhost:" + port;
@Rule
public WireMockRule wireMockRule = new WireMockRule(30000); // No-args constructor defaults to port 8080
public WireMockRule wireMockRule = new WireMockRule(port);


@Test
public void test1() throws Exception{
PaseInstance instance = new PaseInstance("localhost:5000");
Map<String, Object> parameters = new HashMap<String, Object>();
parameters.put("a", 5);
parameters.put("b", 20);
boolean success = instance.create("plainlib.package1.b.B", parameters);
public void testCorrectCreate() throws IOException{
// Create mock
String constructor = "plainlib.package1.b.B";
stubFor(post(urlEqualTo("/" + constructor))
.withHeader("content-type", equalToIgnoreCase("application/json; charset=UTF-8"))
.withRequestBody(equalToJson("{\"a\" : 10, \"b\" : 20}"))
.willReturn(aResponse()
.withStatus(200)
.withHeader("Content-Type", "application/json")
.withBody("{\"id\": \"7B495ECC9C\", \"class\": \"plainlib.package1.b.B\"}")));

PaseInstance instance = new PaseInstance(host);
Map<String, Object> map = new HashMap<String, Object>();
map.put("a", 10);
map.put("b", 20);

Assert.assertFalse(instance.isCreated());
boolean success = instance.create(constructor, map);
Assert.assertTrue(success);
Assert.assertTrue(instance.isCreated());

System.out.println(instance.getInstanceUri());

int a = (Integer) instance.getAttribute("a");
Assert.assertEquals(a, 5);

parameters = new HashMap<String, Object>();
parameters.put("c", 2);
int result = (Integer) instance.callFunction("cablc", parameters);
Assert.assertEquals(result, 45);

Assert.assertEquals("7B495ECC9C", instance.getId());
Assert.assertEquals("plainlib.package1.b.B", instance.getClassName());

Assert.assertEquals(host + "/plainlib.package1.b.B/7B495ECC9C", instance.getInstanceUrl());

}
//@Test
public void exampleTest() throws IOException{
stubFor(post(urlEqualTo("/plainlib.package1.b.B"))
.withHeader("content-type", equalToIgnoreCase("application/json; charset=UTF-8"))
@Test
public void testCorrect1() throws IOException{
// Create mock
String constructor = "plainlib.package1.b.B";
stubFor(post(urlEqualTo("/" + constructor))
.withRequestBody(equalToJson("{\"a\" : 10, \"b\" : 20}"))
.willReturn(aResponse()
.withStatus(200)
.withHeader("Content-Type", "application/json")
.withBody("{\"id\": \"7B495ECC9C\", \"class\": \"plainlib.package1.b.B\"}")));

OkHttpClient client = new OkHttpClient();

MediaType mediaType = MediaType.parse("application/json");
RequestBody body = RequestBody.create(mediaType, "{\"a\" : 10, \"b\" : 20}");
Request request = new Request.Builder()
.url("http://localhost:30000/plainlib.package1.b.B")
.post(body)
.addHeader("content-type", "application/json")
.build();
Response response = client.newCall(request).execute();
System.out.println("\n\n" + response.body().string().toString());

stubFor(get(urlEqualTo("/plainlib.package1.b.B/7B495ECC9C/b")).willReturn(aResponse()
.withStatus(200)
.withHeader("Content-Type", "application/json")
.withBody("10")));

stubFor(post(urlEqualTo("/plainlib.package1.b.B/7B495ECC9C/calc"))
.withRequestBody(equalToJson("{\"c\" : 5}"))
.willReturn(aResponse()
.withStatus(200)
.withHeader("Content-Type", "application/json")
.withBody("110")));


PaseInstance instance = new PaseInstance(host);
Map<String, Object> map = new HashMap<String, Object>();
map.put("a", 10);
map.put("b", 20);
instance.create(constructor, map);

int b = (Integer) instance.getAttribute("b");
Assert.assertEquals(10, b);

map = new HashMap<String, Object>();
map.put("c", 5);
int result = (Integer) instance.callFunction("calc", map);
Assert.assertEquals(110, result);
}

@Test(expected = IOException.class)
public void testWrongHostCreate() throws IOException{

PaseInstance instance = new PaseInstance("localhost:10000"); // server shouldn't be accessible
instance.create("con", new HashMap<String, Object>()); // Should be throwing IOException
}

@Test(expected = IllegalStateException.class)
public void testNoCreateCall() throws IOException{
PaseInstance instance = new PaseInstance("localhost:10000"); // server shouldn't be accessible
instance.callFunction("func", new HashMap<String, Object>());
}

}

0 comments on commit fdf5181

Please sign in to comment.