Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] [Tests] Test utils update to fix IT tests for serverless #2898

Open
wants to merge 1 commit into
base: 2.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT;
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.schema;
import static org.opensearch.sql.util.MatcherUtils.verify;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
import static org.opensearch.sql.util.MatcherUtils.verifySchema;
import static org.opensearch.sql.util.MatcherUtils.verifySome;
import static org.opensearch.sql.util.TestUtils.getResponseBody;
import static org.opensearch.sql.util.TestUtils.roundOfResponse;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.jupiter.api.Test;
import org.opensearch.client.Request;
Expand Down Expand Up @@ -395,8 +398,9 @@ public void testMaxDoublePushedDown() throws IOException {
@Test
public void testAvgDoublePushedDown() throws IOException {
var response = executeQuery(String.format("SELECT avg(num3)" + " from %s", TEST_INDEX_CALCS));
JSONArray responseJSON = roundOfResponse(response.getJSONArray("datarows"));
verifySchema(response, schema("avg(num3)", null, "double"));
verifyDataRows(response, rows(-6.12D));
verify(responseJSON, rows(-6.12D));
}

@Test
Expand Down Expand Up @@ -455,8 +459,9 @@ public void testAvgDoubleInMemory() throws IOException {
executeQuery(
String.format(
"SELECT avg(num3)" + " OVER(PARTITION BY datetime1) from %s", TEST_INDEX_CALCS));
JSONArray roundOfResponse = roundOfResponse(response.getJSONArray("datarows"));
verifySchema(response, schema("avg(num3) OVER(PARTITION BY datetime1)", null, "double"));
verifySome(response.getJSONArray("datarows"), rows(-6.12D));
verifySome(roundOfResponse, rows(-6.12D));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import static org.hamcrest.Matchers.containsString;
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.schema;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;
import static org.opensearch.sql.util.MatcherUtils.verifyDataAddressRows;
import static org.opensearch.sql.util.MatcherUtils.verifySchema;

import java.io.IOException;
Expand Down Expand Up @@ -71,8 +71,7 @@ public void scoreQueryTest() throws IOException {
TestsConstants.TEST_INDEX_ACCOUNT),
"jdbc"));
verifySchema(result, schema("address", null, "text"), schema("_score", null, "float"));
verifyDataRows(
result, rows("154 Douglass Street", 650.1515), rows("565 Hall Street", 3.2507575));
verifyDataAddressRows(result, rows("154 Douglass Street"), rows("565 Hall Street"));
}

@Test
Expand Down Expand Up @@ -102,7 +101,8 @@ public void scoreQueryDefaultBoostQueryTest() throws IOException {
+ "where score(matchQuery(address, 'Powell')) order by _score desc limit 2",
TestsConstants.TEST_INDEX_ACCOUNT),
"jdbc"));

verifySchema(result, schema("address", null, "text"), schema("_score", null, "float"));
verifyDataRows(result, rows("305 Powell Street", 6.501515));
verifyDataAddressRows(result, rows("305 Powell Street"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ public static void verifyDataRows(JSONObject response, Matcher<JSONArray>... mat
verify(response.getJSONArray("datarows"), matchers);
}

@SafeVarargs
public static void verifyDataAddressRows(JSONObject response, Matcher<JSONArray>... matchers) {
verifyAddressRow(response.getJSONArray("datarows"), matchers);
}

@SafeVarargs
public static void verifyColumn(JSONObject response, Matcher<JSONObject>... matchers) {
verify(response.getJSONArray("schema"), matchers);
Expand All @@ -183,6 +188,32 @@ public static <T> void verify(JSONArray array, Matcher<T>... matchers) {
assertThat(objects, containsInAnyOrder(matchers));
}

// TODO: this is temporary fix for fixing serverless tests to pass as it creates multiple shards
// leading to score differences.
public static <T> void verifyAddressRow(JSONArray array, Matcher<T>... matchers) {
// List to store the processed elements from the JSONArray
List<T> objects = new ArrayList<>();

// Iterate through each element in the JSONArray
array
.iterator()
.forEachRemaining(
o -> {
// Check if o is a JSONArray with exactly 2 elements
if (o instanceof JSONArray && ((JSONArray) o).length() == 2) {
// Check if the second element is a BigDecimal/_score value
if (((JSONArray) o).get(1) instanceof BigDecimal) {
// Remove the _score element from response data rows to skip the assertion as it
// will be different when compared against multiple shards
((JSONArray) o).remove(1);
}
}
objects.add((T) o);
});
assertEquals(matchers.length, objects.size());
assertThat(objects, containsInAnyOrder(matchers));
}

@SafeVarargs
@SuppressWarnings("unchecked")
public static <T> void verifyInOrder(JSONArray array, Matcher<T>... matchers) {
Expand Down
62 changes: 62 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/util/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
Expand All @@ -27,13 +29,15 @@
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.json.JSONArray;
import org.json.JSONObject;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.Client;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.client.RestClient;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.sql.legacy.cursor.CursorType;
Expand Down Expand Up @@ -121,10 +125,45 @@ public static Response performRequest(RestClient client, Request request) {
}
return response;
} catch (IOException e) {
if (isRefreshPolicyError(e)) {
try {
return retryWithoutRefreshPolicy(request, client);
} catch (IOException ex) {
throw new IllegalStateException("Failed to perform request without refresh policy.", ex);
}
}
throw new IllegalStateException("Failed to perform request", e);
}
}

/**
* Checks if the IOException is due to an unsupported refresh policy.
*
* @param e The IOException to check.
* @return true if the exception is due to a refresh policy error, false otherwise.
*/
private static boolean isRefreshPolicyError(IOException e) {
return e instanceof ResponseException
&& ((ResponseException) e).getResponse().getStatusLine().getStatusCode() == 400
&& e.getMessage().contains("true refresh policy is not supported.");
}

/**
* Attempts to perform the request without the refresh policy.
*
* @param request The original request.
* @param client client connection
* @return The response after retrying the request.
* @throws IOException If the request fails.
*/
private static Response retryWithoutRefreshPolicy(Request request, RestClient client)
throws IOException {
Request req =
new Request(request.getMethod(), request.getEndpoint().replaceAll("refresh=true", ""));
req.setEntity(request.getEntity());
return client.performRequest(req);
}

public static String getAccountIndexMapping() {
return "{ \"mappings\": {"
+ " \"properties\": {\n"
Expand Down Expand Up @@ -770,6 +809,29 @@ public static String getResponseBody(Response response, boolean retainNewLines)
return sb.toString();
}

// TODO: this is temporary fix for fixing serverless tests to pass with 2 digit precision value
public static JSONArray roundOfResponse(JSONArray array) {
JSONArray responseJSON = new JSONArray();
array
.iterator()
.forEachRemaining(
o -> {
JSONArray jsonArray = new JSONArray();
((JSONArray) o)
.iterator()
.forEachRemaining(
i -> {
if (i instanceof BigDecimal) {
jsonArray.put(((BigDecimal) i).setScale(2, RoundingMode.HALF_UP));
} else {
jsonArray.put(i);
}
});
responseJSON.put(jsonArray);
});
return responseJSON;
}

public static String fileToString(
final String filePathFromProjectRoot, final boolean removeNewLines) throws IOException {

Expand Down
Loading