diff --git a/.eclipseformat.xml b/.eclipseformat.xml new file mode 100644 index 00000000..6e93f9b2 --- /dev/null +++ b/.eclipseformat.xml @@ -0,0 +1,362 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/.gitignore b/.gitignore index c07e1002..cc9ec9b6 100644 --- a/.gitignore +++ b/.gitignore @@ -15,7 +15,7 @@ __pycache__/ RankLib.jar tmdb.json model.txt - +.DS_Store .idea/ .gradle/ diff --git a/licenses/httpclient5-5.3.1.jar.sha1 b/licenses/httpclient5-5.3.1.jar.sha1 new file mode 100644 index 00000000..c8f32c1e --- /dev/null +++ b/licenses/httpclient5-5.3.1.jar.sha1 @@ -0,0 +1 @@ +56b53c8f4bcdaada801d311cf2ff8a24d6d96883 \ No newline at end of file diff --git a/licenses/httpcore5-5.3.jar.sha1 b/licenses/httpcore5-5.3.jar.sha1 new file mode 100644 index 00000000..1721b75f --- /dev/null +++ b/licenses/httpcore5-5.3.jar.sha1 @@ -0,0 +1 @@ +a30e0b837732ac0e034c196dbdfcd8208d347a72 \ No newline at end of file diff --git a/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java b/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java index 3e4baba1..f60f6535 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java @@ -16,38 +16,41 @@ package com.o19s.es.ltr; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThan; + +import java.io.IOException; + +import org.apache.lucene.util.Accountable; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.unit.ByteSizeUnit; +import org.opensearch.core.common.unit.ByteSizeValue; + import com.o19s.es.ltr.action.BaseIntegrationTest; import com.o19s.es.ltr.feature.store.CompiledLtrModel; import com.o19s.es.ltr.feature.store.MemStore; import com.o19s.es.ltr.feature.store.index.CachedFeatureStore; import com.o19s.es.ltr.feature.store.index.Caches; import com.o19s.es.ltr.ranker.LtrRanker; -import org.apache.lucene.util.Accountable; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.common.unit.ByteSizeUnit; -import org.opensearch.core.common.unit.ByteSizeValue; - -import java.io.IOException; - -import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.lessThan; public class NodeSettingsIT extends BaseIntegrationTest { private final MemStore memStore = new MemStore(); private final int memSize = 1024; private final int expireAfterRead = 100; - private final int expireAfterWrite = expireAfterRead*4; + private final int expireAfterWrite = expireAfterRead * 4; @Override protected Settings nodeSettings() { Settings settings = super.nodeSettings(); - return Settings.builder().put(settings) - .put(Caches.LTR_CACHE_MEM_SETTING.getKey(), memSize + "kb") - .put(Caches.LTR_CACHE_EXPIRE_AFTER_READ.getKey(), expireAfterRead + "ms") - .put(Caches.LTR_CACHE_EXPIRE_AFTER_WRITE.getKey(), expireAfterWrite + "ms") - .build(); + return Settings + .builder() + .put(settings) + .put(Caches.LTR_CACHE_MEM_SETTING.getKey(), memSize + "kb") + .put(Caches.LTR_CACHE_EXPIRE_AFTER_READ.getKey(), expireAfterRead + "ms") + .put(Caches.LTR_CACHE_EXPIRE_AFTER_WRITE.getKey(), expireAfterWrite + "ms") + .build(); } public void testCacheSettings() throws IOException, InterruptedException { @@ -67,14 +70,14 @@ public void testCacheSettings() throws IOException, InterruptedException { } while (totalAdded < maxMemSize); assertThat(totalAdded, greaterThan(maxMemSize)); assertThat(caches.modelCache().weight(), greaterThan(0L)); - Thread.sleep(expireAfterWrite*2); + Thread.sleep(expireAfterWrite * 2); caches.modelCache().refresh(); assertEquals(0, caches.modelCache().weight()); cached.loadModel("test0"); // Second load for accessTime cached.loadModel("test0"); assertThat(caches.modelCache().weight(), greaterThan(0L)); - Thread.sleep(expireAfterRead*2); + Thread.sleep(expireAfterRead * 2); caches.modelCache().refresh(); assertEquals(0, caches.modelCache().weight()); } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/AddFeaturesToSetActionIT.java b/src/javaRestTest/java/com/o19s/es/ltr/action/AddFeaturesToSetActionIT.java index 83415554..4fa336e6 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/AddFeaturesToSetActionIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/AddFeaturesToSetActionIT.java @@ -16,22 +16,23 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequestBuilder; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.action.DocWriteResponse; +import static com.o19s.es.ltr.LtrTestUtils.randomFeature; +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.containsString; +import static org.opensearch.ExceptionsHelper.unwrap; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutionException; -import static com.o19s.es.ltr.LtrTestUtils.randomFeature; -import static java.util.Arrays.asList; -import static org.opensearch.ExceptionsHelper.unwrap; -import static org.hamcrest.CoreMatchers.containsString; +import org.opensearch.action.DocWriteResponse; + +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequestBuilder; +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class AddFeaturesToSetActionIT extends BaseIntegrationTest { public void testAddToSetWithQuery() throws Exception { @@ -100,7 +101,7 @@ public void testAddToSetWithList() throws Exception { assertEquals(2, resp.getResponse().getVersion()); assertEquals(DocWriteResponse.Result.UPDATED, resp.getResponse().getResult()); set = getElement(StoredFeatureSet.class, StoredFeatureSet.TYPE, "new_feature_set"); - assertEquals(features.size()+1, set.size()); + assertEquals(features.size() + 1, set.size()); assertTrue(set.hasFeature("another_feature")); } @@ -109,8 +110,7 @@ public void testFailuresWhenEmpty() throws Exception { builder.request().setFeatureSet("new_broken_set"); builder.request().setFeatureNameQuery("doesnotexist*"); builder.request().setStore(IndexFeatureStore.DEFAULT_STORE); - Throwable iae = unwrap(expectThrows(ExecutionException.class, () -> builder.execute().get()), - IllegalArgumentException.class); + Throwable iae = unwrap(expectThrows(ExecutionException.class, () -> builder.execute().get()), IllegalArgumentException.class); assertNotNull(iae); assertThat(iae.getMessage(), containsString("returned no features")); } @@ -127,14 +127,12 @@ public void testFailuresOnDuplicates() throws Exception { assertEquals(DocWriteResponse.Result.CREATED, resp.getResponse().getResult()); assertEquals(1, resp.getResponse().getVersion()); - - AddFeaturesToSetRequestBuilder builder2= new AddFeaturesToSetRequestBuilder(client()); + AddFeaturesToSetRequestBuilder builder2 = new AddFeaturesToSetRequestBuilder(client()); builder2.request().setFeatureSet("duplicated_set"); builder2.request().setFeatureNameQuery("duplicated"); builder2.request().setStore(IndexFeatureStore.DEFAULT_STORE); - Throwable iae = unwrap(expectThrows(ExecutionException.class, () -> builder2.execute().get()), - IllegalArgumentException.class); + Throwable iae = unwrap(expectThrows(ExecutionException.class, () -> builder2.execute().get()), IllegalArgumentException.class); assertNotNull(iae); assertThat(iae.getMessage(), containsString("defined twice in this set")); } @@ -143,7 +141,7 @@ public void testMergeWithQuery() throws Exception { addElement(randomFeature("duplicated")); addElement(randomFeature("new_feature")); - AddFeaturesToSetRequestBuilder builder= new AddFeaturesToSetRequestBuilder(client()); + AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("merged_set"); builder.request().setFeatureNameQuery("duplicated*"); builder.request().setStore(IndexFeatureStore.DEFAULT_STORE); @@ -153,7 +151,7 @@ public void testMergeWithQuery() throws Exception { assertEquals(DocWriteResponse.Result.CREATED, resp.getResponse().getResult()); assertEquals(1, resp.getResponse().getVersion()); - builder= new AddFeaturesToSetRequestBuilder(client()); + builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("merged_set"); builder.request().setFeatureNameQuery("*"); builder.request().setStore(IndexFeatureStore.DEFAULT_STORE); @@ -176,7 +174,7 @@ public void testMergeWithList() throws Exception { features.add(feat); } - AddFeaturesToSetRequestBuilder builder= new AddFeaturesToSetRequestBuilder(client()); + AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("new_feature_set"); builder.request().setFeatures(features); builder.request().setMerge(true); @@ -191,7 +189,7 @@ public void testMergeWithList() throws Exception { assertEquals(features.size(), set.size()); assertTrue(features.stream().map(StoredFeature::name).allMatch(set::hasFeature)); - builder= new AddFeaturesToSetRequestBuilder(client()); + builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("new_feature_set"); builder.request().setFeatures(asList(randomFeature("another_feature"), randomFeature("feature0"))); builder.request().setMerge(true); @@ -200,7 +198,7 @@ public void testMergeWithList() throws Exception { assertEquals(2, resp.getResponse().getVersion()); assertEquals(DocWriteResponse.Result.UPDATED, resp.getResponse().getResult()); set = getElement(StoredFeatureSet.class, StoredFeatureSet.TYPE, "new_feature_set"); - assertEquals(features.size()+1, set.size()); + assertEquals(features.size() + 1, set.size()); assertTrue(set.hasFeature("another_feature")); assertEquals(0, set.featureOrdinal("feature0")); } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java index 98326bb5..8debb653 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/BaseIntegrationTest.java @@ -16,15 +16,22 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.LtrQueryParserPlugin; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequestBuilder; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; +import static com.o19s.es.ltr.feature.store.ScriptFeature.EXTRA_LOGGING; +import static com.o19s.es.ltr.feature.store.ScriptFeature.FEATURE_VECTOR; + +import java.io.IOException; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.Supplier; + import org.apache.lucene.index.LeafReaderContext; +import org.junit.Before; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.admin.indices.create.CreateIndexAction; import org.opensearch.action.admin.indices.create.CreateIndexResponse; @@ -38,21 +45,15 @@ import org.opensearch.script.ScriptEngine; import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import org.junit.Before; -import java.io.IOException; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.ExecutionException; -import java.util.function.Supplier; - -import static com.o19s.es.ltr.feature.store.ScriptFeature.EXTRA_LOGGING; -import static com.o19s.es.ltr.feature.store.ScriptFeature.FEATURE_VECTOR; +import com.o19s.es.ltr.LtrQueryParserPlugin; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequestBuilder; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; +import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; public abstract class BaseIntegrationTest extends OpenSearchSingleNodeTestCase { @@ -61,8 +62,8 @@ public abstract class BaseIntegrationTest extends OpenSearchSingleNodeTestCase { @Override // TODO: Remove the TestGeoShapeFieldMapperPlugin once upstream has completed the migration. protected Collection> getPlugins() { - return Arrays.asList(LtrQueryParserPlugin.class, NativeScriptPlugin.class, InjectionScriptPlugin.class, - TestGeoShapeFieldMapperPlugin.class); + return Arrays + .asList(LtrQueryParserPlugin.class, NativeScriptPlugin.class, InjectionScriptPlugin.class, TestGeoShapeFieldMapperPlugin.class); } public void createStore(String name) throws Exception { @@ -89,8 +90,8 @@ public void createDefaultStore() throws Exception { createStore(IndexFeatureStore.DEFAULT_STORE); } - public FeatureStoreResponse addElement(StorableElement element, - FeatureValidation validation) throws ExecutionException, InterruptedException { + public FeatureStoreResponse addElement(StorableElement element, FeatureValidation validation) throws ExecutionException, + InterruptedException { return addElement(element, validation, IndexFeatureStore.DEFAULT_STORE); } @@ -114,11 +115,10 @@ protected LtrRankerParserFactory parserFactory() { return getInstanceFromNode(LtrRankerParserFactory.class); } - public FeatureStoreResponse addElement(StorableElement element, - @Nullable FeatureValidation validation, - String store) throws ExecutionException, InterruptedException { - FeatureStoreRequestBuilder builder = - new FeatureStoreRequestBuilder(client(), FeatureStoreAction.INSTANCE); + public FeatureStoreResponse addElement(StorableElement element, @Nullable FeatureValidation validation, String store) + throws ExecutionException, + InterruptedException { + FeatureStoreRequestBuilder builder = new FeatureStoreRequestBuilder(client(), FeatureStoreAction.INSTANCE); builder.request().setStorableElement(element); builder.request().setAction(FeatureStoreAction.FeatureStoreRequest.Action.CREATE); builder.request().setStore(store); @@ -160,38 +160,40 @@ public String getType() { */ @SuppressWarnings("unchecked") @Override - public FactoryType compile(String scriptName, String scriptSource, - ScriptContext context, Map params) { + public FactoryType compile( + String scriptName, + String scriptSource, + ScriptContext context, + Map params + ) { if (!context.equals(ScoreScript.CONTEXT) && (!context.equals(AGGS_CONTEXT))) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name - + "]"); + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } // we use the script "source" as the script identifier - ScoreScript.Factory factory = (p, lookup, searcher) -> - new ScoreScript.LeafFactory() { + ScoreScript.Factory factory = (p, lookup, searcher) -> new ScoreScript.LeafFactory() { + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return new ScoreScript(p, lookup, searcher, ctx) { @Override - public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new ScoreScript(p, lookup, searcher, ctx) { - @Override - public double execute(ExplanationHolder explainationHolder) { - // For testing purposes just look for the "terms" key and see if stats were injected - if(p.containsKey("termStats")) { - Supplier>> termStats = (Supplier>>) p.get("termStats"); - ArrayList dfStats = termStats.get().get("df"); - return dfStats.size() > 0 ? dfStats.get(0) : 0.0; - } else { - return 0.0; - } - } - }; - } - - @Override - public boolean needs_score() { - return false; + public double execute(ExplanationHolder explainationHolder) { + // For testing purposes just look for the "terms" key and see if stats were injected + if (p.containsKey("termStats")) { + Supplier>> termStats = + (Supplier>>) p.get("termStats"); + ArrayList dfStats = termStats.get().get("df"); + return dfStats.size() > 0 ? dfStats.get(0) : 0.0; + } else { + return 0.0; + } } }; + } + + @Override + public boolean needs_score() { + return false; + } + }; return context.factoryClazz.cast(factory); } @@ -204,7 +206,6 @@ public Set> getSupportedContexts() { } } - public static class NativeScriptPlugin extends Plugin implements ScriptPlugin { public static final String FEATURE_EXTRACTOR = "feature_extractor"; @@ -231,94 +232,94 @@ public String getType() { */ @SuppressWarnings("unchecked") @Override - public FactoryType compile(String scriptName, String scriptSource, - ScriptContext context, Map params) { + public FactoryType compile( + String scriptName, + String scriptSource, + ScriptContext context, + Map params + ) { if (!context.equals(ScoreScript.CONTEXT) && (!context.equals(AGGS_CONTEXT))) { - throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name - + "]"); + throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]"); } // we use the script "source" as the script identifier if (FEATURE_EXTRACTOR.equals(scriptSource)) { - ScoreScript.Factory factory = (p, lookup, searcher) -> - new ScoreScript.LeafFactory() { - final Map featureSupplier; - final String dependentFeature; - double extraMultiplier = 0.0d; - - public static final String DEPENDENT_FEATURE = "dependent_feature"; - public static final String EXTRA_SCRIPT_PARAM = "extra_multiplier"; - - { - if (!p.containsKey(FEATURE_VECTOR)) { - throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); - } - if (!p.containsKey(EXTRA_LOGGING)) { - throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); - } - if (!p.containsKey(DEPENDENT_FEATURE)) { - throw new IllegalArgumentException("Missing parameter [depdendent_feature ]"); - } - if (p.containsKey(EXTRA_SCRIPT_PARAM)) { - extraMultiplier = Double.valueOf(p.get(EXTRA_SCRIPT_PARAM).toString()); - } - featureSupplier = (Map) p.get(FEATURE_VECTOR); - dependentFeature = p.get(DEPENDENT_FEATURE).toString(); - } + ScoreScript.Factory factory = (p, lookup, searcher) -> new ScoreScript.LeafFactory() { + final Map featureSupplier; + final String dependentFeature; + double extraMultiplier = 0.0d; - @Override - public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new ScoreScript(p, lookup, searcher, ctx) { - @Override - public double execute(ExplanationHolder explainationHolder ) { - return extraMultiplier == 0.0d ? - featureSupplier.get(dependentFeature) * 10 : - featureSupplier.get(dependentFeature) * extraMultiplier; - } - }; - } + public static final String DEPENDENT_FEATURE = "dependent_feature"; + public static final String EXTRA_SCRIPT_PARAM = "extra_multiplier"; + { + if (!p.containsKey(FEATURE_VECTOR)) { + throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); + } + if (!p.containsKey(EXTRA_LOGGING)) { + throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); + } + if (!p.containsKey(DEPENDENT_FEATURE)) { + throw new IllegalArgumentException("Missing parameter [depdendent_feature ]"); + } + if (p.containsKey(EXTRA_SCRIPT_PARAM)) { + extraMultiplier = Double.valueOf(p.get(EXTRA_SCRIPT_PARAM).toString()); + } + featureSupplier = (Map) p.get(FEATURE_VECTOR); + dependentFeature = p.get(DEPENDENT_FEATURE).toString(); + } + + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return new ScoreScript(p, lookup, searcher, ctx) { @Override - public boolean needs_score() { - return false; + public double execute(ExplanationHolder explainationHolder) { + return extraMultiplier == 0.0d + ? featureSupplier.get(dependentFeature) * 10 + : featureSupplier.get(dependentFeature) * extraMultiplier; } }; + } + + @Override + public boolean needs_score() { + return false; + } + }; return context.factoryClazz.cast(factory); - } - else if (scriptSource.equals(FEATURE_EXTRACTOR + "_extra_logging")) { - ScoreScript.Factory factory = (p, lookup, searcher) -> - new ScoreScript.LeafFactory() { - { - if (!p.containsKey(FEATURE_VECTOR)) { - throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); - } - if (!p.containsKey(EXTRA_LOGGING)) { - throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); - } - } + } else if (scriptSource.equals(FEATURE_EXTRACTOR + "_extra_logging")) { + ScoreScript.Factory factory = (p, lookup, searcher) -> new ScoreScript.LeafFactory() { + { + if (!p.containsKey(FEATURE_VECTOR)) { + throw new IllegalArgumentException("Missing parameter [" + FEATURE_VECTOR + "]"); + } + if (!p.containsKey(EXTRA_LOGGING)) { + throw new IllegalArgumentException("Missing parameter [" + EXTRA_LOGGING + "]"); + } + } - @Override - public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { - return new ScoreScript(p, lookup, searcher, ctx) { - - @Override - public double execute(ExplanationHolder explanation) { - Map extraLoggingMap = ((Supplier>) getParams() - .get(EXTRA_LOGGING)).get(); - if (extraLoggingMap != null) { - extraLoggingMap.put("extra_float", 10.0f); - extraLoggingMap.put("extra_string", "additional_info"); - } - return 1.0d; - } - }; - } + @Override + public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { + return new ScoreScript(p, lookup, searcher, ctx) { @Override - public boolean needs_score() { - return false; + public double execute(ExplanationHolder explanation) { + Map extraLoggingMap = ((Supplier>) getParams() + .get(EXTRA_LOGGING)).get(); + if (extraLoggingMap != null) { + extraLoggingMap.put("extra_float", 10.0f); + extraLoggingMap.put("extra_string", "additional_info"); + } + return 1.0d; } }; + } + + @Override + public boolean needs_score() { + return false; + } + }; return context.factoryClazz.cast(factory); } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/ListStoresActionIT.java b/src/javaRestTest/java/com/o19s/es/ltr/action/ListStoresActionIT.java index db6feed5..cc41d272 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/ListStoresActionIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/ListStoresActionIT.java @@ -16,6 +16,13 @@ package com.o19s.es.ltr.action; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.indexName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import com.o19s.es.ltr.LtrTestUtils; import com.o19s.es.ltr.action.ListStoresAction.IndexStoreInfo; import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionResponse; @@ -24,22 +31,14 @@ import com.o19s.es.ltr.feature.store.StoredLtrModel; import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.indexName; - public class ListStoresActionIT extends BaseIntegrationTest { public void testListStore() throws Exception { createStore(indexName("test2")); createStore(indexName("test3")); Map infos = new HashMap<>(); - String[] stores = new String[]{IndexFeatureStore.DEFAULT_STORE, indexName("test2"), indexName("test3")}; + String[] stores = new String[] { IndexFeatureStore.DEFAULT_STORE, indexName("test2"), indexName("test3") }; for (String store : stores) { - infos.put(IndexFeatureStore.storeName(store), - new IndexStoreInfo(store, IndexFeatureStore.VERSION, addElements(store))); + infos.put(IndexFeatureStore.storeName(store), new IndexStoreInfo(store, IndexFeatureStore.VERSION, addElements(store))); } ListStoresActionResponse resp = new ListStoresAction.ListStoresActionBuilder(client()).execute().get(); assertEquals(infos.size(), resp.getStores().size()); @@ -66,7 +65,7 @@ private Map addElements(String store) throws Exception { return counts; } - private void addElements(String store, int nFeatures, int nSets,int nModels) throws Exception { + private void addElements(String store, int nFeatures, int nSets, int nModels) throws Exception { for (int i = 0; i < nFeatures; i++) { StoredFeature feat = LtrTestUtils.randomFeature("feature" + i); addElement(feat, store); diff --git a/src/javaRestTest/java/com/o19s/es/ltr/action/ValidatingFeatureStoreActionIT.java b/src/javaRestTest/java/com/o19s/es/ltr/action/ValidatingFeatureStoreActionIT.java index 5f2c5bf7..0ad01a15 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/action/ValidatingFeatureStoreActionIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/action/ValidatingFeatureStoreActionIT.java @@ -16,22 +16,23 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import com.o19s.es.ltr.ranker.parser.LinearRankerParser; -import org.opensearch.index.query.QueryBuilders; -import org.hamcrest.CoreMatchers; +import static java.util.Collections.singletonList; +import static org.hamcrest.CoreMatchers.instanceOf; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutionException; -import static java.util.Collections.singletonList; -import static org.hamcrest.CoreMatchers.instanceOf; +import org.hamcrest.CoreMatchers; +import org.opensearch.index.query.QueryBuilders; + +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import com.o19s.es.ltr.ranker.parser.LinearRankerParser; public class ValidatingFeatureStoreActionIT extends BaseIntegrationTest { public void testValidateFeature() throws ExecutionException, InterruptedException { @@ -40,8 +41,8 @@ public void testValidateFeature() throws ExecutionException, InterruptedExceptio StoredFeature feature = new StoredFeature("test", singletonList("query_string"), "mustache", brokenQuery); Map params = new HashMap<>(); params.put("query_string", "a query"); - Throwable e = expectThrows(ExecutionException.class, - () -> addElement(feature, new FeatureValidation("test_index", params))).getCause(); + Throwable e = expectThrows(ExecutionException.class, () -> addElement(feature, new FeatureValidation("test_index", params))) + .getCause(); assertThat(e, instanceOf(IllegalArgumentException.class)); assertThat(e.getMessage(), CoreMatchers.containsString("Cannot store element, validation failed.")); } @@ -56,8 +57,10 @@ public void testValidateFeatureSet() throws ExecutionException, InterruptedExcep Map params = new HashMap<>(); params.put("query_string", "a query"); StoredFeatureSet brokenFeatureSet = new StoredFeatureSet("my_feature_set", Arrays.asList(feature, brokenFeature)); - Throwable e = expectThrows(ExecutionException.class, - () -> addElement(brokenFeatureSet, new FeatureValidation("test_index", params))).getCause(); + Throwable e = expectThrows( + ExecutionException.class, + () -> addElement(brokenFeatureSet, new FeatureValidation("test_index", params)) + ).getCause(); assertThat(e, instanceOf(IllegalArgumentException.class)); assertThat(e.getMessage(), CoreMatchers.containsString("Cannot store element, validation failed.")); } @@ -72,11 +75,13 @@ public void testValidateModel() throws ExecutionException, InterruptedException Map params = new HashMap<>(); params.put("query_string", "a query"); String model = "{\"test\": 2.1, \"broken\": 4.3}"; - StoredLtrModel brokenModel = new StoredLtrModel("broken_model", - new StoredFeatureSet("my_feature_set", Arrays.asList(feature, brokenFeature)), - new StoredLtrModel.LtrModelDefinition(LinearRankerParser.TYPE, model, true)); - Throwable e = expectThrows(ExecutionException.class, () -> addElement(brokenModel, - new FeatureValidation("test_index", params))).getCause(); + StoredLtrModel brokenModel = new StoredLtrModel( + "broken_model", + new StoredFeatureSet("my_feature_set", Arrays.asList(feature, brokenFeature)), + new StoredLtrModel.LtrModelDefinition(LinearRankerParser.TYPE, model, true) + ); + Throwable e = expectThrows(ExecutionException.class, () -> addElement(brokenModel, new FeatureValidation("test_index", params))) + .getCause(); assertThat(e, instanceOf(IllegalArgumentException.class)); assertThat(e.getMessage(), CoreMatchers.containsString("Cannot store element, validation failed.")); } @@ -112,11 +117,15 @@ public void testValidationOnCreateModelFromSet() throws ExecutionException, Inte StoredFeatureSet brokenFeatureSet = new StoredFeatureSet("my_feature_set", Arrays.asList(feature, brokenFeature)); // Store a broken feature set addElement(brokenFeatureSet); - CreateModelFromSetAction.CreateModelFromSetRequestBuilder request = - new CreateModelFromSetAction.CreateModelFromSetRequestBuilder(client()); + CreateModelFromSetAction.CreateModelFromSetRequestBuilder request = new CreateModelFromSetAction.CreateModelFromSetRequestBuilder( + client() + ); request.request().setValidation(new FeatureValidation("test_index", params)); - StoredLtrModel.LtrModelDefinition definition = new StoredLtrModel.LtrModelDefinition("model/linear", - "{\"test\": 2.1, \"broken\": 4.3}", true); + StoredLtrModel.LtrModelDefinition definition = new StoredLtrModel.LtrModelDefinition( + "model/linear", + "{\"test\": 2.1, \"broken\": 4.3}", + true + ); request.withoutVersion(IndexFeatureStore.DEFAULT_STORE, "my_feature_set", "broken_model", definition); request.request().setValidation(new FeatureValidation("test_index", params)); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, request::get); diff --git a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java index 188e0f7d..fc640f7c 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -16,14 +16,17 @@ package com.o19s.es.ltr.logging; -import com.o19s.es.ltr.LtrTestUtils; -import com.o19s.es.ltr.action.BaseIntegrationTest; -import com.o19s.es.ltr.feature.store.ScriptFeature; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.query.StoredLtrQueryBuilder; -import com.o19s.es.ltr.ranker.parser.LinearRankerParserTests; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.apache.lucene.search.join.ScoreMode; import org.apache.lucene.tests.util.TestUtil; import org.opensearch.ExceptionsHelper; @@ -42,71 +45,135 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.rescore.QueryRescorerBuilder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.ltr.LtrTestUtils; +import com.o19s.es.ltr.action.BaseIntegrationTest; +import com.o19s.es.ltr.feature.store.ScriptFeature; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.query.StoredLtrQueryBuilder; +import com.o19s.es.ltr.ranker.parser.LinearRankerParserTests; public class LoggingIT extends BaseIntegrationTest { public static final float FACTOR = 1.2F; public void prepareModels() throws Exception { List features = new ArrayList<>(3); - features.add(new StoredFeature("text_feature1", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field1", "{{query}}").toString())); - features.add(new StoredFeature("text_feature2", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field2", "{{query}}").toString())); - features.add(new StoredFeature("numeric_feature1", Collections.singletonList("query"), "mustache", - new FunctionScoreQueryBuilder(QueryBuilders.matchAllQuery(), new FieldValueFactorFunctionBuilder("scorefield1") - .factor(FACTOR) - .modifier(FieldValueFactorFunction.Modifier.LN2P) - .missing(0F)).scoreMode(FunctionScoreQuery.ScoreMode.MULTIPLY).toString())); - features.add(new StoredFeature("derived_feature", Collections.singletonList("query"), "derived_expression", - "100")); + features + .add( + new StoredFeature( + "text_feature1", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString() + ) + ); + features + .add( + new StoredFeature( + "text_feature2", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field2", "{{query}}").toString() + ) + ); + features + .add( + new StoredFeature( + "numeric_feature1", + Collections.singletonList("query"), + "mustache", + new FunctionScoreQueryBuilder( + QueryBuilders.matchAllQuery(), + new FieldValueFactorFunctionBuilder("scorefield1") + .factor(FACTOR) + .modifier(FieldValueFactorFunction.Modifier.LN2P) + .missing(0F) + ).scoreMode(FunctionScoreQuery.ScoreMode.MULTIPLY).toString() + ) + ); + features.add(new StoredFeature("derived_feature", Collections.singletonList("query"), "derived_expression", "100")); StoredFeatureSet set = new StoredFeatureSet("my_set", features); addElement(set); - StoredLtrModel model = new StoredLtrModel("my_model", set, - new StoredLtrModel.LtrModelDefinition("model/linear", - LinearRankerParserTests.generateRandomModelString(set), true)); + StoredLtrModel model = new StoredLtrModel( + "my_model", + set, + new StoredLtrModel.LtrModelDefinition("model/linear", LinearRankerParserTests.generateRandomModelString(set), true) + ); addElement(model); } + public void prepareModelsExtraLogging() throws Exception { List features = new ArrayList<>(3); - features.add(new StoredFeature("text_feature1", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field1", "{{query}}").toString())); - features.add(new StoredFeature("text_feature2", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field2", "{{query}}").toString())); - features.add(new StoredFeature("numeric_feature1", Collections.singletonList("query"), "mustache", - new FunctionScoreQueryBuilder(QueryBuilders.matchAllQuery(), new FieldValueFactorFunctionBuilder("scorefield1") - .factor(FACTOR) - .modifier(FieldValueFactorFunction.Modifier.LN2P) - .missing(0F)).scoreMode(FunctionScoreQuery.ScoreMode.MULTIPLY).toString())); - features.add(new StoredFeature("derived_feature", Collections.singletonList("query"), "derived_expression", - "100")); - features.add(new StoredFeature("extra_logging_feature", Arrays.asList("query"), ScriptFeature.TEMPLATE_LANGUAGE, - "{\"lang\": \"native\", \"source\": \"feature_extractor_extra_logging\", \"params\": {}}")); + features + .add( + new StoredFeature( + "text_feature1", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString() + ) + ); + features + .add( + new StoredFeature( + "text_feature2", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field2", "{{query}}").toString() + ) + ); + features + .add( + new StoredFeature( + "numeric_feature1", + Collections.singletonList("query"), + "mustache", + new FunctionScoreQueryBuilder( + QueryBuilders.matchAllQuery(), + new FieldValueFactorFunctionBuilder("scorefield1") + .factor(FACTOR) + .modifier(FieldValueFactorFunction.Modifier.LN2P) + .missing(0F) + ).scoreMode(FunctionScoreQuery.ScoreMode.MULTIPLY).toString() + ) + ); + features.add(new StoredFeature("derived_feature", Collections.singletonList("query"), "derived_expression", "100")); + features + .add( + new StoredFeature( + "extra_logging_feature", + Arrays.asList("query"), + ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"native\", \"source\": \"feature_extractor_extra_logging\", \"params\": {}}" + ) + ); StoredFeatureSet set = new StoredFeatureSet("my_set", features); addElement(set); - StoredLtrModel model = new StoredLtrModel("my_model", set, - new StoredLtrModel.LtrModelDefinition("model/linear", - LinearRankerParserTests.generateRandomModelString(set), true)); + StoredLtrModel model = new StoredLtrModel( + "my_model", + set, + new StoredLtrModel.LtrModelDefinition("model/linear", LinearRankerParserTests.generateRandomModelString(set), true) + ); addElement(model); } + public void prepareExternalScriptFeatures() throws Exception { List features = new ArrayList<>(3); - features.add(new StoredFeature("test_inject", Arrays.asList(), ScriptFeature.TEMPLATE_LANGUAGE, - "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + - "\"analyzer\": \"analyzerParam\", " + - "\"terms\": \"termsParam\", " + - "\"fields\": \"fieldsParam\" } } }")); + features + .add( + new StoredFeature( + "test_inject", + Arrays.asList(), + ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + + "\"analyzer\": \"analyzerParam\", " + + "\"terms\": \"termsParam\", " + + "\"fields\": \"fieldsParam\" } } }" + ) + ); StoredFeatureSet set = new StoredFeatureSet("my_set", features); addElement(set); @@ -114,11 +181,18 @@ public void prepareExternalScriptFeatures() throws Exception { public void prepareInternalScriptFeatures() throws Exception { List features = new ArrayList<>(3); - features.add(new StoredFeature("test_inject", Arrays.asList("query"), ScriptFeature.TEMPLATE_LANGUAGE, - "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + - "\"analyzer\": \"!standard\", " + - "\"terms\": [\"found\"], " + - "\"fields\": [\"field1\"] } } }")); + features + .add( + new StoredFeature( + "test_inject", + Arrays.asList("query"), + ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"inject\", \"source\": \"df\", \"params\": {\"term_stat\": { " + + "\"analyzer\": \"!standard\", " + + "\"terms\": [\"found\"], " + + "\"fields\": [\"field1\"] } } }" + ) + ); StoredFeatureSet set = new StoredFeatureSet("my_set", features); addElement(set); @@ -127,50 +201,53 @@ public void prepareInternalScriptFeatures() throws Exception { public void testFailures() throws Exception { prepareModels(); buildIndex(); - QueryBuilder query = QueryBuilders.matchQuery("field1", "found") - .boost(random().nextInt(3)) - .queryName("not_sltr"); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false))); - - assertExcWithMessage(() -> client().prepareSearch("test_index") - .setSource(sourceBuilder).get(), IllegalArgumentException.class, "No query named [test] found"); - - SearchSourceBuilder sourceBuilder2 = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "not_sltr", false))); - - assertExcWithMessage(() -> client().prepareSearch("test_index") - .setSource(sourceBuilder2).get(), IllegalArgumentException.class, "Query named [not_sltr] must be a " + - "[sltr] query [TermQuery] found"); - - SearchSourceBuilder sourceBuilder3 = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addRescoreLogging("first_log", 0, false))); - assertExcWithMessage(() -> client().prepareSearch("test_index") - .setSource(sourceBuilder3).get(), IllegalArgumentException.class, "rescore index [0] is out of bounds, " + - "only [0]"); - - SearchSourceBuilder sourceBuilder4 = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(QueryBuilders.matchAllQuery())) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addRescoreLogging("first_log", 0, false))); - assertExcWithMessage(() -> client().prepareSearch("test_index") - .setSource(sourceBuilder4).get(), IllegalArgumentException.class, "Expected a [sltr] query but found " + - "a [MatchAllDocsQuery] at index [0]"); + QueryBuilder query = QueryBuilders.matchQuery("field1", "found").boost(random().nextInt(3)).queryName("not_sltr"); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false))); + + assertExcWithMessage( + () -> client().prepareSearch("test_index").setSource(sourceBuilder).get(), + IllegalArgumentException.class, + "No query named [test] found" + ); + + SearchSourceBuilder sourceBuilder2 = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addQueryLogging("first_log", "not_sltr", false))); + + assertExcWithMessage( + () -> client().prepareSearch("test_index").setSource(sourceBuilder2).get(), + IllegalArgumentException.class, + "Query named [not_sltr] must be a " + "[sltr] query [TermQuery] found" + ); + + SearchSourceBuilder sourceBuilder3 = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addRescoreLogging("first_log", 0, false))); + assertExcWithMessage( + () -> client().prepareSearch("test_index").setSource(sourceBuilder3).get(), + IllegalArgumentException.class, + "rescore index [0] is out of bounds, " + "only [0]" + ); + + SearchSourceBuilder sourceBuilder4 = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(QueryBuilders.matchAllQuery())) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addRescoreLogging("first_log", 0, false))); + assertExcWithMessage( + () -> client().prepareSearch("test_index").setSource(sourceBuilder4).get(), + IllegalArgumentException.class, + "Expected a [sltr] query but found " + "a [MatchAllDocsQuery] at index [0]" + ); } private void assertExcWithMessage(ThrowingRunnable r, Class exc, String msg) { @@ -192,27 +269,32 @@ public void testLog() throws Exception { Collections.shuffle(idsColl, random()); String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test") - .boost(random().nextInt(3)); + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); StoredLtrQueryBuilder sbuilder_rescore = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test_rescore") - .boost(random().nextInt(3)); - - QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + .featureSetName("my_set") + .params(params) + .queryName("test_rescore") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery().addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHits(docs, resp); @@ -223,37 +305,45 @@ public void testLog() throws Exception { sbuilder_rescore.modelName("my_model"); sbuilder_rescore.boost(random().nextInt(3)); - query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())).filter(QueryBuilders.idsQuery().addIds(ids)); + sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp2 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHits(docs, resp2); - query = QueryBuilders.boolQuery() - .must(new WrapperQueryBuilder(sbuilder.toString())) - .must( - QueryBuilders.nestedQuery( + query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .must( + QueryBuilders + .nestedQuery( "nesteddocs1", QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("nesteddocs1.field1", "nestedvalue")), ScoreMode.None - ).innerHit(new InnerHitBuilder()) - ); - sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + ) + .innerHit(new InnerHitBuilder()) + ); + sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp3 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHits(docs, resp3); } @@ -268,27 +358,32 @@ public void testLogExtraLogging() throws Exception { Collections.shuffle(idsColl, random()); String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test") - .boost(random().nextInt(3)); + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); StoredLtrQueryBuilder sbuilder_rescore = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test_rescore") - .boost(random().nextInt(3)); - - QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + .featureSetName("my_set") + .params(params) + .queryName("test_rescore") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery().addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHitsExtraLogging(docs, resp); @@ -299,37 +394,45 @@ public void testLogExtraLogging() throws Exception { sbuilder_rescore.modelName("my_model"); sbuilder_rescore.boost(random().nextInt(3)); - query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())).filter(QueryBuilders.idsQuery().addIds(ids)); + sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp2 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHitsExtraLogging(docs, resp2); - query = QueryBuilders.boolQuery() - .must(new WrapperQueryBuilder(sbuilder.toString())) - .must( - QueryBuilders.nestedQuery( - "nesteddocs1", - QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("nesteddocs1.field1", "nestedvalue")), - ScoreMode.None - ).innerHit(new InnerHitBuilder()) - ); - sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false) - .addRescoreLogging("second_log", 0, true))); + query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .must( + QueryBuilders + .nestedQuery( + "nesteddocs1", + QueryBuilders.boolQuery().filter(QueryBuilders.termQuery("nesteddocs1.field1", "nestedvalue")), + ScoreMode.None + ) + .innerHit(new InnerHitBuilder()) + ); + sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext( + Collections + .singletonList( + new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false).addRescoreLogging("second_log", 0, true) + ) + ); SearchResponse resp3 = client().prepareSearch("test_index").setSource(sourceBuilder).get(); assertSearchHitsExtraLogging(docs, resp3); } @@ -345,19 +448,20 @@ public void testScriptLogInternalParams() throws Exception { Collections.shuffle(idsColl, random()); String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test") - .boost(random().nextInt(3)); - - QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false))); + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery().addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false))); SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); @@ -369,7 +473,7 @@ public void testScriptLogInternalParams() throws Exception { List> log = logs.get("first_log"); assertEquals(log.get(0).get("name"), "test_inject"); - assertTrue((Float)log.get(0).get("value") > 0.0F); + assertTrue((Float) log.get(0).get("value") > 0.0F); } public void testScriptLogExternalParams() throws Exception { @@ -391,19 +495,20 @@ public void testScriptLogExternalParams() throws Exception { Collections.shuffle(idsColl, random()); String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test") - .boost(random().nextInt(3)); - - QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false))); + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery().addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false))); SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); @@ -415,7 +520,7 @@ public void testScriptLogExternalParams() throws Exception { List> log = logs.get("first_log"); assertEquals(log.get(0).get("name"), "test_inject"); - assertTrue((Float)log.get(0).get("value") > 0.0F); + assertTrue((Float) log.get(0).get("value") > 0.0F); } public void testScriptLogInvalidExternalParams() throws Exception { @@ -429,27 +534,30 @@ public void testScriptLogInvalidExternalParams() throws Exception { Collections.shuffle(idsColl, random()); String[] ids = idsColl.subList(0, TestUtil.nextInt(random(), 5, 15)).toArray(new String[0]); StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .featureSetName("my_set") - .params(params) - .queryName("test") - .boost(random().nextInt(3)); - - QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())) - .filter(QueryBuilders.idsQuery().addIds(ids)); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) - .fetchSource(false) - .size(10) - .ext(Collections.singletonList( - new LoggingSearchExtBuilder() - .addQueryLogging("first_log", "test", false))); - - assertExcWithMessage(() -> client().prepareSearch("test_index") - .setSource(sourceBuilder).get(), - IllegalArgumentException.class, "Term Stats injection requires fields and terms"); + .featureSetName("my_set") + .params(params) + .queryName("test") + .boost(random().nextInt(3)); + + QueryBuilder query = QueryBuilders + .boolQuery() + .must(new WrapperQueryBuilder(sbuilder.toString())) + .filter(QueryBuilders.idsQuery().addIds(ids)); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .query(query) + .fetchSource(false) + .size(10) + .ext(Collections.singletonList(new LoggingSearchExtBuilder().addQueryLogging("first_log", "test", false))); + + assertExcWithMessage( + () -> client().prepareSearch("test_index").setSource(sourceBuilder).get(), + IllegalArgumentException.class, + "Term Stats injection requires fields and terms" + ); } protected void assertSearchHits(Map docs, SearchResponse resp) { - for (SearchHit hit: resp.getHits()) { + for (SearchHit hit : resp.getHits()) { assertTrue(hit.getFields().containsKey("_ltrlog")); Map>> logs = hit.getFields().get("_ltrlog").getValue(); assertTrue(logs.containsKey("first_log")); @@ -465,8 +573,8 @@ protected void assertSearchHits(Map docs, SearchResponse resp) { assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertTrue((Float)log1.get(0).get("value") > 0F); - assertTrue((Float)log2.get(0).get("value") > 0F); + assertTrue((Float) log1.get(0).get("value") > 0F); + assertTrue((Float) log2.get(0).get("value") > 0F); assertEquals(log1.get(1).get("name"), "text_feature2"); assertFalse(log1.get(1).containsKey("value")); @@ -478,20 +586,20 @@ protected void assertSearchHits(Map docs, SearchResponse resp) { assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertTrue((Float)log1.get(1).get("value") > 0F); - assertTrue((Float)log2.get(1).get("value") > 0F); + assertTrue((Float) log1.get(1).get("value") > 0F); + assertTrue((Float) log2.get(1).get("value") > 0F); assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertEquals(0F, (Float)log2.get(0).get("value"), 0F); + assertEquals(0F, (Float) log2.get(0).get("value"), 0F); } float score = (float) Math.log1p((d.scorefield1 * FACTOR) + 1); assertEquals(log1.get(2).get("name"), "numeric_feature1"); assertEquals(log2.get(2).get("name"), "numeric_feature1"); - assertEquals(score, (Float)log1.get(2).get("value"), Math.ulp(score)); - assertEquals(score, (Float)log2.get(2).get("value"), Math.ulp(score)); + assertEquals(score, (Float) log1.get(2).get("value"), Math.ulp(score)); + assertEquals(score, (Float) log2.get(2).get("value"), Math.ulp(score)); assertEquals(log1.get(3).get("name"), "derived_feature"); assertEquals(log2.get(3).get("name"), "derived_feature"); @@ -504,7 +612,7 @@ protected void assertSearchHits(Map docs, SearchResponse resp) { @SuppressWarnings("unchecked") protected void assertSearchHitsExtraLogging(Map docs, SearchResponse resp) { - for (SearchHit hit: resp.getHits()) { + for (SearchHit hit : resp.getHits()) { assertTrue(hit.getFields().containsKey("_ltrlog")); Map>> logs = hit.getFields().get("_ltrlog").getValue(); assertTrue(logs.containsKey("first_log")); @@ -520,8 +628,8 @@ protected void assertSearchHitsExtraLogging(Map docs, SearchRespons assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertTrue((Float)log1.get(0).get("value") > 0F); - assertTrue((Float)log2.get(0).get("value") > 0F); + assertTrue((Float) log1.get(0).get("value") > 0F); + assertTrue((Float) log2.get(0).get("value") > 0F); assertEquals(log1.get(1).get("name"), "text_feature2"); assertFalse(log1.get(1).containsKey("value")); @@ -533,20 +641,20 @@ protected void assertSearchHitsExtraLogging(Map docs, SearchRespons assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertTrue((Float)log1.get(1).get("value") > 0F); - assertTrue((Float)log2.get(1).get("value") > 0F); + assertTrue((Float) log1.get(1).get("value") > 0F); + assertTrue((Float) log2.get(1).get("value") > 0F); assertEquals(log1.get(0).get("name"), "text_feature1"); assertEquals(log2.get(0).get("name"), "text_feature1"); - assertEquals(0F, (Float)log2.get(0).get("value"), 0F); + assertEquals(0F, (Float) log2.get(0).get("value"), 0F); } float score = (float) Math.log1p((d.scorefield1 * FACTOR) + 1); assertEquals(log1.get(2).get("name"), "numeric_feature1"); assertEquals(log2.get(2).get("name"), "numeric_feature1"); - assertEquals(score, (Float)log1.get(2).get("value"), Math.ulp(score)); - assertEquals(score, (Float)log2.get(2).get("value"), Math.ulp(score)); + assertEquals(score, (Float) log1.get(2).get("value"), Math.ulp(score)); + assertEquals(score, (Float) log2.get(2).get("value"), Math.ulp(score)); assertEquals(log1.get(3).get("name"), "derived_feature"); assertEquals(log2.get(3).get("name"), "derived_feature"); @@ -563,8 +671,8 @@ protected void assertSearchHitsExtraLogging(Map docs, SearchRespons assertEquals(log1.get(5).get("name"), "extra_logging"); assertEquals(log2.get(5).get("name"), "extra_logging"); - Map extraMap1 = (Map) log1.get(5).get("value"); - Map extraMap2 = (Map) log2.get(5).get("value"); + Map extraMap1 = (Map) log1.get(5).get("value"); + Map extraMap2 = (Map) log2.get(5).get("value"); assertEquals(2, extraMap1.size()); assertEquals(2, extraMap2.size()); @@ -575,11 +683,13 @@ protected void assertSearchHitsExtraLogging(Map docs, SearchRespons } } - public Map buildIndex() { - client().admin().indices().prepareCreate("test_index") - .setMapping( - "{\"properties\":{\"scorefield1\": {\"type\": \"float\"}, \"nesteddocs1\": {\"type\": \"nested\"}}}}") - .get(); + public Map buildIndex() { + client() + .admin() + .indices() + .prepareCreate("test_index") + .setMapping("{\"properties\":{\"scorefield1\": {\"type\": \"float\"}, \"nesteddocs1\": {\"type\": \"nested\"}}}}") + .get(); int numDocs = TestUtil.nextInt(random(), 20, 100); Map docs = new HashMap<>(); @@ -588,16 +698,14 @@ public Map buildIndex() { int numNestedDocs = TestUtil.nextInt(random(), 1, 20); List nesteddocs1 = new ArrayList<>(); for (int j = 0; j < numNestedDocs; j++) { - nesteddocs1.add( - new NestedDoc( - "nestedvalue", - Math.abs(random().nextFloat()))); + nesteddocs1.add(new NestedDoc("nestedvalue", Math.abs(random().nextFloat()))); } Doc d = new Doc( - field1IsFound ? "found" : "notfound", - field1IsFound ? "notfound" : "found", - Math.abs(random().nextFloat()), - nesteddocs1); + field1IsFound ? "found" : "notfound", + field1IsFound ? "notfound" : "found", + Math.abs(random().nextFloat()), + nesteddocs1 + ); indexDoc(d); docs.put(d.id, d); } @@ -606,9 +714,10 @@ public Map buildIndex() { } public void indexDoc(Doc d) { - IndexResponse resp = client().prepareIndex("test_index") - .setSource("field1", d.field1, "field2", d.field2, "scorefield1", d.scorefield1, "nesteddocs1", d.getNesteddocs1()) - .get(); + IndexResponse resp = client() + .prepareIndex("test_index") + .setSource("field1", d.field1, "field2", d.field2, "scorefield1", d.scorefield1, "nesteddocs1", d.getNesteddocs1()) + .get(); d.id = resp.getId(); } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java index b5139c8e..8cf891a8 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java @@ -17,6 +17,22 @@ package com.o19s.es.ltr.query; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.opensearch.action.search.SearchRequestBuilder; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.search.rescore.QueryRescoreMode; +import org.opensearch.search.rescore.QueryRescorerBuilder; + import com.o19s.es.ltr.LtrTestUtils; import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequestBuilder; import com.o19s.es.ltr.action.BaseIntegrationTest; @@ -28,48 +44,41 @@ import com.o19s.es.ltr.feature.store.StoredFeature; import com.o19s.es.ltr.feature.store.StoredLtrModel; import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.action.search.SearchRequestBuilder; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.WrapperQueryBuilder; -import org.opensearch.search.rescore.QueryRescoreMode; -import org.opensearch.search.rescore.QueryRescorerBuilder; -import org.hamcrest.CoreMatchers; -import org.hamcrest.Matchers; - -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.ExecutionException; /** * Created by doug on 12/29/16. */ public class StoredLtrQueryIT extends BaseIntegrationTest { - private static final String SIMPLE_MODEL = "{" + - "\"feature1\": 1," + - "\"feature2\": -1," + - "\"feature3\": 10," + - "\"feature4\": 1," + - "\"feature5\": 1," + - "\"feature6\": 1" + - "}"; - - private static final String SIMPLE_SCRIPT_MODEL = "{" + - "\"feature1\": 1," + - "\"feature6\": 1" + - "}"; + private static final String SIMPLE_MODEL = "{" + + "\"feature1\": 1," + + "\"feature2\": -1," + + "\"feature3\": 10," + + "\"feature4\": 1," + + "\"feature5\": 1," + + "\"feature6\": 1" + + "}"; + private static final String SIMPLE_SCRIPT_MODEL = "{" + "\"feature1\": 1," + "\"feature6\": 1" + "}"; public void testScriptFeatureUseCase() throws Exception { - addElement(new StoredFeature("feature1", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field1", "{{query}}").toString())); - addElement(new StoredFeature("feature6", Arrays.asList("query", "extra_multiplier_ltr"), ScriptFeature.TEMPLATE_LANGUAGE, - "{\"lang\": \"native\", \"source\": \"feature_extractor\", \"params\": { \"dependent_feature\": \"feature1\"," + - " \"extra_script_params\" : {\"extra_multiplier_ltr\": \"extra_multiplier\"}}}")); + addElement( + new StoredFeature( + "feature1", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString() + ) + ); + addElement( + new StoredFeature( + "feature6", + Arrays.asList("query", "extra_multiplier_ltr"), + ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"native\", \"source\": \"feature_extractor\", \"params\": { \"dependent_feature\": \"feature1\"," + + " \"extra_script_params\" : {\"extra_multiplier_ltr\": \"extra_multiplier\"}}}" + ) + ); AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("my_set"); @@ -80,21 +89,30 @@ public void testScriptFeatureUseCase() throws Exception { long version = builder.get().getResponse().getVersion(); CreateModelFromSetRequestBuilder createModelFromSetRequestBuilder = new CreateModelFromSetRequestBuilder(client()); - createModelFromSetRequestBuilder.withVersion(IndexFeatureStore.DEFAULT_STORE, "my_set", version, - "my_model", new StoredLtrModel.LtrModelDefinition("model/linear", SIMPLE_SCRIPT_MODEL, true)); + createModelFromSetRequestBuilder + .withVersion( + IndexFeatureStore.DEFAULT_STORE, + "my_set", + version, + "my_model", + new StoredLtrModel.LtrModelDefinition("model/linear", SIMPLE_SCRIPT_MODEL, true) + ); createModelFromSetRequestBuilder.get(); buildIndex(); Map params = new HashMap<>(); params.put("query", "hello"); params.put("dependent_feature", new HashMap<>()); params.put("extra_multiplier_ltr", 100.0d); - SearchRequestBuilder sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + SearchRequestBuilder sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()).modelName("my_model").params(params).toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); SearchResponse sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); @@ -103,19 +121,49 @@ public void testScriptFeatureUseCase() throws Exception { } public void testFullUsecase() throws Exception { - addElement(new StoredFeature("feature1", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field1", "{{query}}").toString())); - addElement(new StoredFeature("feature2", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field2", "{{query}}").toString())); - addElement(new StoredFeature("feature3", Collections.singletonList("query"), "derived_expression", - "(feature1 - feature2) > 0 ? 1 : -1")); - addElement(new StoredFeature("feature4", Collections.singletonList("query"), "mustache", - QueryBuilders.matchQuery("field1", "{{query}}").toString())); - addElement(new StoredFeature("feature5", Collections.singletonList("multiplier"), "derived_expression", - "(feature1 - feature2) > 0 ? feature1 * multiplier: feature2 * multiplier")); - addElement(new StoredFeature("feature6", Collections.singletonList("query"), ScriptFeature.TEMPLATE_LANGUAGE, - "{\"lang\": \"native\", \"source\": \"feature_extractor\", \"params\": { \"dependent_feature\": \"feature1\"}}")); - + addElement( + new StoredFeature( + "feature1", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString() + ) + ); + addElement( + new StoredFeature( + "feature2", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field2", "{{query}}").toString() + ) + ); + addElement( + new StoredFeature("feature3", Collections.singletonList("query"), "derived_expression", "(feature1 - feature2) > 0 ? 1 : -1") + ); + addElement( + new StoredFeature( + "feature4", + Collections.singletonList("query"), + "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString() + ) + ); + addElement( + new StoredFeature( + "feature5", + Collections.singletonList("multiplier"), + "derived_expression", + "(feature1 - feature2) > 0 ? feature1 * multiplier: feature2 * multiplier" + ) + ); + addElement( + new StoredFeature( + "feature6", + Collections.singletonList("query"), + ScriptFeature.TEMPLATE_LANGUAGE, + "{\"lang\": \"native\", \"source\": \"feature_extractor\", \"params\": { \"dependent_feature\": \"feature1\"}}" + ) + ); AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("my_set"); @@ -139,8 +187,14 @@ public void testFullUsecase() throws Exception { long version = builder.get().getResponse().getVersion(); CreateModelFromSetRequestBuilder createModelFromSetRequestBuilder = new CreateModelFromSetRequestBuilder(client()); - createModelFromSetRequestBuilder.withVersion(IndexFeatureStore.DEFAULT_STORE, "my_set", version, - "my_model", new StoredLtrModel.LtrModelDefinition("model/linear", SIMPLE_MODEL, true)); + createModelFromSetRequestBuilder + .withVersion( + IndexFeatureStore.DEFAULT_STORE, + "my_set", + version, + "my_model", + new StoredLtrModel.LtrModelDefinition("model/linear", SIMPLE_MODEL, true) + ); createModelFromSetRequestBuilder.get(); buildIndex(); Map params = new HashMap<>(); @@ -149,13 +203,16 @@ public void testFullUsecase() throws Exception { params.put("query", negativeScore ? "bonjour" : "hello"); params.put("multiplier", negativeScore ? Integer.parseInt("-1") : 1.0); params.put("dependent_feature", new HashMap<>()); - SearchRequestBuilder sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + SearchRequestBuilder sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()).modelName("my_model").params(params).toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); SearchResponse sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); @@ -170,13 +227,16 @@ public void testFullUsecase() throws Exception { params.put("query", negativeScore ? "bonjour" : "hello"); params.put("multiplier", negativeScore ? -1 : 1.0); params.put("dependent_feature", new HashMap<>()); - sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()).modelName("my_model").params(params).toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); @@ -188,79 +248,111 @@ public void testFullUsecase() throws Exception { } // Test profiling - sb = client().prepareSearch("test_index") - .setProfile(true) - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setProfile(true) + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()).modelName("my_model").params(params).toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertThat(sr.getProfileResults().isEmpty(), Matchers.equalTo(false)); - //we use only feature4 score and ignore other scores + // we use only feature4 score and ignore other scores params.put("query", "hello"); - sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).activeFeatures(Collections.singletonList("feature4")).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .modelName("my_model") + .params(params) + .activeFeatures(Collections.singletonList("feature4")) + .toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.0f)); assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThanOrEqualTo(1.0f)); - //we use feature 5 with query time positive int multiplier passed to feature5 + // we use feature 5 with query time positive int multiplier passed to feature5 params.put("query", "hello"); params.put("multiplier", Integer.parseInt("100")); - sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).activeFeatures(Arrays.asList("feature1", "feature2", "feature5")).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .modelName("my_model") + .params(params) + .activeFeatures(Arrays.asList("feature1", "feature2", "feature5")) + .toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(28.0f)); assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(30.0f)); - //we use feature 5 with query time negative double multiplier passed to feature5 + // we use feature 5 with query time negative double multiplier passed to feature5 params.put("query", "hello"); params.put("multiplier", Double.parseDouble("-100.55")); - sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).activeFeatures(Arrays.asList("feature1", "feature2", "feature5")).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .modelName("my_model") + .params(params) + .activeFeatures(Arrays.asList("feature1", "feature2", "feature5")) + .toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); assertThat(sr.getHits().getAt(0).getScore(), Matchers.lessThan(-28.0f)); assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(-30.0f)); - //we use feature1 and feature6(ScriptFeature) + // we use feature1 and feature6(ScriptFeature) params.put("query", "hello"); params.put("dependent_feature", new HashMap<>()); - sb = client().prepareSearch("test_index") - .setQuery(QueryBuilders.matchQuery("field1", "world")) - .setRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) - .modelName("my_model").params(params).activeFeatures(Arrays.asList("feature1", "feature6")).toString())) - .setScoreMode(QueryRescoreMode.Total) - .setQueryWeight(0) - .setRescoreQueryWeight(1)); + sb = client() + .prepareSearch("test_index") + .setQuery(QueryBuilders.matchQuery("field1", "world")) + .setRescorer( + new QueryRescorerBuilder( + new WrapperQueryBuilder( + new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .modelName("my_model") + .params(params) + .activeFeatures(Arrays.asList("feature1", "feature6")) + .toString() + ) + ).setScoreMode(QueryRescoreMode.Total).setQueryWeight(0).setRescoreQueryWeight(1) + ); sr = sb.get(); assertEquals(1, sr.getHits().getTotalHits().value); assertThat(sr.getHits().getAt(0).getScore(), Matchers.greaterThan(0.2876f + 2.876f)); StoredLtrModel model = getElement(StoredLtrModel.class, StoredLtrModel.TYPE, "my_model"); - CachesStatsNodesResponse stats = client().execute(CachesStatsAction.INSTANCE, - new CachesStatsAction.CachesStatsNodesRequest()).get(); + CachesStatsNodesResponse stats = client() + .execute(CachesStatsAction.INSTANCE, new CachesStatsAction.CachesStatsNodesRequest()) + .get(); assertEquals(1, stats.getAll().getTotal().getCount()); assertEquals(model.compile(parserFactory()).ramBytesUsed(), stats.getAll().getTotal().getRam()); assertEquals(1, stats.getAll().getModels().getCount()); @@ -274,33 +366,33 @@ public void testFullUsecase() throws Exception { clearCache.clearModel(IndexFeatureStore.DEFAULT_STORE, "my_model"); client().execute(ClearCachesAction.INSTANCE, clearCache).get(); - stats = client().execute(CachesStatsAction.INSTANCE, - new CachesStatsAction.CachesStatsNodesRequest()).get(); + stats = client().execute(CachesStatsAction.INSTANCE, new CachesStatsAction.CachesStatsNodesRequest()).get(); assertEquals(0, stats.getAll().getTotal().getCount()); assertEquals(0, stats.getAll().getTotal().getRam()); } public void testInvalidDerived() throws Exception { - addElement(new StoredFeature("bad_df", Collections.singletonList("query"), "derived_expression", - "what + is + this")); + addElement(new StoredFeature("bad_df", Collections.singletonList("query"), "derived_expression", "what + is + this")); AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client()); builder.request().setFeatureSet("my_bad_set"); builder.request().setFeatureNameQuery("bad_df"); builder.request().setStore(IndexFeatureStore.DEFAULT_STORE); - assertThat(expectThrows(ExecutionException.class, () -> builder.execute().get()).getMessage(), - CoreMatchers.containsString("refers to unknown feature")); + assertThat( + expectThrows(ExecutionException.class, () -> builder.execute().get()).getMessage(), + CoreMatchers.containsString("refers to unknown feature") + ); } public void buildIndex() { client().admin().indices().prepareCreate("test_index").get(); - client().prepareIndex("test_index") - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .setSource("field1", "hello world", "field2", "bonjour world") - .get(); + client() + .prepareIndex("test_index") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setSource("field1", "hello world", "field2", "bonjour world") + .get(); } - } diff --git a/src/main/java/com/o19s/es/explore/ExplorerQuery.java b/src/main/java/com/o19s/es/explore/ExplorerQuery.java index d22d6d74..e14f19ff 100644 --- a/src/main/java/com/o19s/es/explore/ExplorerQuery.java +++ b/src/main/java/com/o19s/es/explore/ExplorerQuery.java @@ -16,33 +16,32 @@ package com.o19s.es.explore; -import org.apache.lucene.search.QueryVisitor; -import org.opensearch.ltr.settings.LTRSettings; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; +import java.io.IOException; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.ConstantScoreScorer; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BooleanClause; - -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.Weight; import org.apache.lucene.search.similarities.ClassicSimilarity; - -import java.io.IOException; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; +import org.opensearch.ltr.settings.LTRSettings; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; public class ExplorerQuery extends Query { private final Query query; @@ -56,26 +55,25 @@ public ExplorerQuery(Query query, String type, LTRStats ltrStats) { } private boolean isCollectionScoped() { - return type.endsWith("_count") - || type.endsWith("_df") - || type.endsWith("_idf") - || type.endsWith(("_ttf")); + return type.endsWith("_count") || type.endsWith("_df") || type.endsWith("_idf") || type.endsWith(("_ttf")); } - public Query getQuery() { return this.query; } + public Query getQuery() { + return this.query; + } - public String getType() { return this.type; } + public String getType() { + return this.type; + } @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(ExplorerQuery other) { - return Objects.equals(query, other.query) - && Objects.equals(type, other.type); + return Objects.equals(query, other.query) && Objects.equals(type, other.type); } @Override @@ -122,7 +120,7 @@ private Weight createWeightInternal(IndexSearcher searcher, ScoreMode scoreMode, for (Term term : terms) { TermStates ctx = TermStates.build(searcher, term, scoreMode.needsScores()); - if(ctx != null && ctx.docFreq() > 0){ + if (ctx != null && ctx.docFreq() > 0) { TermStatistics tStats = searcher.termStatistics(term, ctx.docFreq(), ctx.totalTermFreq()); df_stats.add(tStats.docFreq()); idf_stats.add(sim.idf(tStats.docFreq(), searcher.collectionStatistics(term.field()).docCount())); @@ -206,9 +204,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio Scorer scorer = scorer(context); int newDoc = scorer.iterator().advance(doc); assert newDoc == doc; // this is a DocIdSetIterator.all - return Explanation.match( - scorer.score(), - "Stat Score: " + type); + return Explanation.match(scorer.score(), "Stat Score: " + type); } @Override @@ -238,12 +234,10 @@ public boolean isCacheable(LeafReaderContext ctx) { } private BooleanClause makeBooleanClause(Term term, String type) throws IllegalArgumentException { - if(type.endsWith("_raw_tf")) { - return new BooleanClause(new PostingsExplorerQuery(term, PostingsExplorerQuery.Type.TF), - BooleanClause.Occur.SHOULD); - }else if(type.endsWith("_raw_tp")) { - return new BooleanClause(new PostingsExplorerQuery(term, PostingsExplorerQuery.Type.TP), - BooleanClause.Occur.SHOULD); + if (type.endsWith("_raw_tf")) { + return new BooleanClause(new PostingsExplorerQuery(term, PostingsExplorerQuery.Type.TF), BooleanClause.Occur.SHOULD); + } else if (type.endsWith("_raw_tp")) { + return new BooleanClause(new PostingsExplorerQuery(term, PostingsExplorerQuery.Type.TP), BooleanClause.Occur.SHOULD); } throw new IllegalArgumentException("Unknown ExplorerQuery type [" + type + "]"); } @@ -269,9 +263,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio if (scorer != null) { int newDoc = scorer.iterator().advance(doc); if (newDoc == doc) { - return Explanation.match( - scorer.score(), - "Stat Score: " + type); + return Explanation.match(scorer.score(), "Stat Score: " + type); } } return Explanation.noMatch("no matching term"); diff --git a/src/main/java/com/o19s/es/explore/ExplorerQueryBuilder.java b/src/main/java/com/o19s/es/explore/ExplorerQueryBuilder.java index d805ce9e..0d594f6f 100644 --- a/src/main/java/com/o19s/es/explore/ExplorerQueryBuilder.java +++ b/src/main/java/com/o19s/es/explore/ExplorerQueryBuilder.java @@ -15,8 +15,9 @@ package com.o19s.es.explore; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; +import java.io.IOException; +import java.util.Objects; + import org.apache.lucene.search.Query; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; @@ -31,9 +32,8 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.Rewriteable; - -import java.io.IOException; -import java.util.Objects; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; public class ExplorerQueryBuilder extends AbstractQueryBuilder implements NamedWriteable { public static final String NAME = "match_explorer"; @@ -44,11 +44,7 @@ public class ExplorerQueryBuilder extends AbstractQueryBuilder(NAME, ExplorerQueryBuilder::new); - PARSER.declareObject( - ExplorerQueryBuilder::query, - (parser, context) -> parseInnerQueryBuilder(parser), - QUERY_NAME - ); + PARSER.declareObject(ExplorerQueryBuilder::query, (parser, context) -> parseInnerQueryBuilder(parser), QUERY_NAME); PARSER.declareString(ExplorerQueryBuilder::statsType, TYPE_NAME); declareStandardFields(PARSER); } @@ -57,9 +53,7 @@ public class ExplorerQueryBuilder extends AbstractQueryBuilder 0) { - for(ChildScorable child : subScorer.getChildren()) { + if (subScorer.getChildren().size() > 0) { + for (ChildScorable child : subScorer.getChildren()) { assert child.child instanceof PostingsExplorerQuery.PostingsExplorerScorer; - if(child.child.docID() == docID()) { + if (child.child.docID() == docID()) { ((PostingsExplorerQuery.PostingsExplorerScorer) child.child).setType(type); tf_stats.add(child.child.score()); } @@ -52,25 +52,25 @@ public float score() throws IOException { } float retval; - switch(type) { - case("sum_raw_tf"): + switch (type) { + case ("sum_raw_tf"): retval = tf_stats.getSum(); break; - case("mean_raw_tf"): + case ("mean_raw_tf"): retval = tf_stats.getMean(); break; - case("max_raw_tf"): - case("max_raw_tp"): + case ("max_raw_tf"): + case ("max_raw_tp"): retval = tf_stats.getMax(); break; - case("min_raw_tf"): - case("min_raw_tp"): + case ("min_raw_tf"): + case ("min_raw_tp"): retval = tf_stats.getMin(); break; - case("stddev_raw_tf"): + case ("stddev_raw_tf"): retval = tf_stats.getStdDev(); break; - case("avg_raw_tp"): + case ("avg_raw_tp"): retval = tf_stats.getMean(); break; default: @@ -85,7 +85,6 @@ public int docID() { return subScorer.docID(); } - @Override public DocIdSetIterator iterator() { return subScorer.iterator(); diff --git a/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java b/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java index 5e740e37..71a2e75a 100644 --- a/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java +++ b/src/main/java/com/o19s/es/explore/PostingsExplorerQuery.java @@ -16,8 +16,12 @@ package com.o19s.es.explore; -import com.o19s.es.ltr.utils.CheckedBiFunction; -import org.apache.lucene.index.IndexReaderContext; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; @@ -25,21 +29,17 @@ import org.apache.lucene.index.TermState; import org.apache.lucene.index.TermStates; import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Weight; import org.opensearch.ltr.settings.LTRSettings; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Objects; -import java.util.Set; +import com.o19s.es.ltr.utils.CheckedBiFunction; public class PostingsExplorerQuery extends Query { private final Term term; @@ -67,8 +67,8 @@ public String toString(String field) { @Override public boolean equals(Object obj) { return this.sameClassAs(obj) - && this.term.equals(((PostingsExplorerQuery) obj).term) - && this.type.equals(((PostingsExplorerQuery) obj).type); + && this.term.equals(((PostingsExplorerQuery) obj).term) + && this.type.equals(((PostingsExplorerQuery) obj).type); } @Override @@ -77,16 +77,13 @@ public int hashCode() { } @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { if (!LTRSettings.isLTRPluginEnabled()) { throw new IllegalStateException("LTR plugin is disabled. To enable, update ltr.plugin.enabled to true"); } assert scoreMode.needsScores() : "Should not be used in filtering mode"; - return new PostingsExplorerWeight(this, this.term, TermStates.build(searcher, this.term, - scoreMode.needsScores()), - this.type); + return new PostingsExplorerWeight(this, this.term, TermStates.build(searcher, this.term, scoreMode.needsScores()), this.type); } /** @@ -130,16 +127,14 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio Scorer scorer = this.scorer(context); int newDoc = scorer.iterator().advance(doc); if (newDoc == doc) { - return Explanation - .match(scorer.score(), "weight(" + this.getQuery() + " in doc " + newDoc + ")"); + return Explanation.match(scorer.score(), "weight(" + this.getQuery() + " in doc " + newDoc + ")"); } return Explanation.noMatch("no matching term"); } @Override public Scorer scorer(LeafReaderContext context) throws IOException { - assert this.termStates != null && this.termStates - .wasBuiltFor(ReaderUtil.getTopLevelContext(context)); + assert this.termStates != null && this.termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context)); TermState state = this.termStates.get(context); if (state == null) { return null; @@ -205,6 +200,7 @@ static class TPScorer extends PostingsExplorerScorer { TPScorer(Weight weight, PostingsEnum postingsEnum) { super(weight, postingsEnum); } + @Override public float score() throws IOException { if (this.postingsEnum.freq() <= 0) { @@ -212,23 +208,23 @@ public float score() throws IOException { } ArrayList positions = new ArrayList(); - for (int i=0;i this.max) { + if (val > this.max) { this.max = val; } } @@ -62,7 +62,7 @@ public ArrayList getData() { return data; } - public int getSize(){ + public int getSize() { return data.size(); } @@ -89,7 +89,7 @@ public float getSum() { float sum = 0.0f; - for(float a : data) { + for (float a : data) { sum += a; } @@ -101,9 +101,9 @@ public float getVariance() { float mean = getMean(); float temp = 0.0f; - for(float a : data) - temp += (a-mean)*(a-mean); - return temp/data.size(); + for (float a : data) + temp += (a - mean) * (a - mean); + return temp / data.size(); } public float getStdDev() { @@ -113,7 +113,7 @@ public float getStdDev() { } public float getAggr(AggrType type) { - switch(type) { + switch (type) { case AVG: return getMean(); case MAX: diff --git a/src/main/java/com/o19s/es/ltr/LtrQueryContext.java b/src/main/java/com/o19s/es/ltr/LtrQueryContext.java index 30fabd0b..9735cde7 100644 --- a/src/main/java/com/o19s/es/ltr/LtrQueryContext.java +++ b/src/main/java/com/o19s/es/ltr/LtrQueryContext.java @@ -16,11 +16,11 @@ package com.o19s.es.ltr; -import org.opensearch.index.query.QueryShardContext; - import java.util.Collections; import java.util.Set; +import org.opensearch.index.query.QueryShardContext; + /** * LTR queryShardContext used to track information needed for building lucene queries */ @@ -46,6 +46,6 @@ public boolean isFeatureActive(String featureName) { } public Set getActiveFeatures() { - return activeFeatures==null? Collections.emptySet(): Collections.unmodifiableSet(activeFeatures); + return activeFeatures == null ? Collections.emptySet() : Collections.unmodifiableSet(activeFeatures); } } diff --git a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java index f7b2617b..c55c05d2 100644 --- a/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java +++ b/src/main/java/com/o19s/es/ltr/LtrQueryParserPlugin.java @@ -16,19 +16,75 @@ */ package com.o19s.es.ltr; -import ciir.umass.edu.learning.RankerFactory; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.Collections.unmodifiableList; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.lucene.analysis.core.KeywordTokenizer; +import org.apache.lucene.analysis.miscellaneous.LengthFilter; +import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; +import org.opensearch.action.ActionRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedFunction; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry.Entry; +import org.opensearch.core.index.Index; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.env.Environment; +import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.analysis.PreConfiguredTokenFilter; +import org.opensearch.index.analysis.PreConfiguredTokenizer; import org.opensearch.ltr.breaker.LTRCircuitBreakerService; -import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.ltr.rest.RestStatsLTRAction; +import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.ltr.stats.LTRStat; import org.opensearch.ltr.stats.LTRStats; import org.opensearch.ltr.stats.StatName; import org.opensearch.ltr.stats.suppliers.CacheStatsOnNodeSupplier; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.ltr.stats.suppliers.PluginHealthStatusSupplier; import org.opensearch.ltr.stats.suppliers.StoreStatsSupplier; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.ltr.transport.LTRStatsAction; import org.opensearch.ltr.transport.TransportLTRStatsAction; +import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.AnalysisPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.ScriptPlugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.repositories.RepositoriesService; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.script.ScriptContext; +import org.opensearch.script.ScriptEngine; +import org.opensearch.script.ScriptService; +import org.opensearch.search.fetch.FetchSubPhase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.watcher.ResourceWatcherService; + import com.o19s.es.explore.ExplorerQueryBuilder; import com.o19s.es.ltr.action.AddFeaturesToSetAction; import com.o19s.es.ltr.action.CachesStatsAction; @@ -59,71 +115,17 @@ import com.o19s.es.ltr.ranker.parser.XGBoostJsonParser; import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; import com.o19s.es.ltr.ranker.ranklib.RanklibModelParser; +import com.o19s.es.ltr.rest.RestAddFeatureToSet; import com.o19s.es.ltr.rest.RestCreateModelFromSet; import com.o19s.es.ltr.rest.RestFeatureManager; +import com.o19s.es.ltr.rest.RestFeatureStoreCaches; import com.o19s.es.ltr.rest.RestSearchStoreElements; import com.o19s.es.ltr.rest.RestStoreManager; -import com.o19s.es.ltr.rest.RestAddFeatureToSet; -import com.o19s.es.ltr.rest.RestFeatureStoreCaches; import com.o19s.es.ltr.utils.FeatureStoreLoader; import com.o19s.es.ltr.utils.Suppliers; import com.o19s.es.termstat.TermStatQueryBuilder; -import org.apache.lucene.analysis.core.KeywordTokenizer; -import org.apache.lucene.analysis.miscellaneous.LengthFilter; -import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter; -import org.opensearch.action.ActionRequest; -import org.opensearch.core.action.ActionResponse; -import org.opensearch.client.Client; -import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNodes; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.CheckedFunction; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.common.io.stream.NamedWriteableRegistry.Entry; -import org.opensearch.common.settings.ClusterSettings; -import org.opensearch.common.settings.IndexScopedSettings; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.settings.SettingsFilter; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.env.Environment; -import org.opensearch.env.NodeEnvironment; -import org.opensearch.core.index.Index; -import org.opensearch.index.analysis.PreConfiguredTokenFilter; -import org.opensearch.index.analysis.PreConfiguredTokenizer; -import org.opensearch.monitor.jvm.JvmService; -import org.opensearch.plugins.ActionPlugin; -import org.opensearch.plugins.AnalysisPlugin; -import org.opensearch.plugins.Plugin; -import org.opensearch.plugins.ScriptPlugin; -import org.opensearch.plugins.SearchPlugin; -import org.opensearch.repositories.RepositoriesService; -import org.opensearch.rest.RestController; -import org.opensearch.rest.RestHandler; -import org.opensearch.script.ScriptContext; -import org.opensearch.script.ScriptEngine; -import org.opensearch.script.ScriptService; -import org.opensearch.search.fetch.FetchSubPhase; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.watcher.ResourceWatcherService; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static java.util.Collections.unmodifiableList; +import ciir.umass.edu.learning.RankerFactory; public class LtrQueryParserPlugin extends Plugin implements SearchPlugin, ScriptPlugin, ActionPlugin, AnalysisPlugin { public static final String LTR_BASE_URI = "/_plugins/_ltr"; @@ -137,10 +139,10 @@ public LtrQueryParserPlugin(Settings settings) { // Use memoize to Lazy load the RankerFactory as it's a heavy object to construct Supplier ranklib = Suppliers.memoize(RankerFactory::new); parserFactory = new LtrRankerParserFactory.Builder() - .register(RanklibModelParser.TYPE, () -> new RanklibModelParser(ranklib.get())) - .register(LinearRankerParser.TYPE, LinearRankerParser::new) - .register(XGBoostJsonParser.TYPE, XGBoostJsonParser::new) - .build(); + .register(RanklibModelParser.TYPE, () -> new RanklibModelParser(ranklib.get())) + .register(LinearRankerParser.TYPE, LinearRankerParser::new) + .register(XGBoostJsonParser.TYPE, XGBoostJsonParser::new) + .build(); ltrStats = getInitialStats(); } @@ -148,27 +150,27 @@ public LtrQueryParserPlugin(Settings settings) { public List> getQueries() { return asList( - new QuerySpec<>( - ExplorerQueryBuilder.NAME, - (input) -> new ExplorerQueryBuilder(input, ltrStats), - (ctx) -> ExplorerQueryBuilder.fromXContent(ctx, ltrStats) - ), - new QuerySpec<>( - LtrQueryBuilder.NAME, - (input) -> new LtrQueryBuilder(input, ltrStats), - (ctx) -> LtrQueryBuilder.fromXContent(ctx, ltrStats) - ), - new QuerySpec<>( - StoredLtrQueryBuilder.NAME, - (input) -> new StoredLtrQueryBuilder(getFeatureStoreLoader(), input, ltrStats), - (ctx) -> StoredLtrQueryBuilder.fromXContent(getFeatureStoreLoader(), ctx, ltrStats) - ), - new QuerySpec<>(TermStatQueryBuilder.NAME, TermStatQueryBuilder::new, TermStatQueryBuilder::fromXContent), - new QuerySpec<>( - ValidatingLtrQueryBuilder.NAME, - (input) -> new ValidatingLtrQueryBuilder(input, parserFactory, ltrStats), - (ctx) -> ValidatingLtrQueryBuilder.fromXContent(ctx, parserFactory, ltrStats) - ) + new QuerySpec<>( + ExplorerQueryBuilder.NAME, + (input) -> new ExplorerQueryBuilder(input, ltrStats), + (ctx) -> ExplorerQueryBuilder.fromXContent(ctx, ltrStats) + ), + new QuerySpec<>( + LtrQueryBuilder.NAME, + (input) -> new LtrQueryBuilder(input, ltrStats), + (ctx) -> LtrQueryBuilder.fromXContent(ctx, ltrStats) + ), + new QuerySpec<>( + StoredLtrQueryBuilder.NAME, + (input) -> new StoredLtrQueryBuilder(getFeatureStoreLoader(), input, ltrStats), + (ctx) -> StoredLtrQueryBuilder.fromXContent(getFeatureStoreLoader(), ctx, ltrStats) + ), + new QuerySpec<>(TermStatQueryBuilder.NAME, TermStatQueryBuilder::new, TermStatQueryBuilder::fromXContent), + new QuerySpec<>( + ValidatingLtrQueryBuilder.NAME, + (input) -> new ValidatingLtrQueryBuilder(input, parserFactory, ltrStats), + (ctx) -> ValidatingLtrQueryBuilder.fromXContent(ctx, parserFactory, ltrStats) + ) ); } @@ -180,7 +182,8 @@ public List getFetchSubPhases(FetchPhaseConstructionContext conte @Override public List> getSearchExts() { return singletonList( - new SearchExtSpec<>(LoggingSearchExtBuilder.NAME, LoggingSearchExtBuilder::new, LoggingSearchExtBuilder::parse)); + new SearchExtSpec<>(LoggingSearchExtBuilder.NAME, LoggingSearchExtBuilder::new, LoggingSearchExtBuilder::parse) + ); } @Override @@ -189,10 +192,15 @@ public ScriptEngine getScriptEngine(Settings settings, Collection getRestHandlers(Settings settings, RestController restController, - ClusterSettings clusterSettings, IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster) { + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { List list = new ArrayList<>(); for (String type : ValidatingLtrQueryBuilder.SUPPORTED_TYPES) { @@ -210,23 +218,28 @@ public List getRestHandlers(Settings settings, RestController restC @Override public List> getActions() { - return unmodifiableList(asList( + return unmodifiableList( + asList( new ActionHandler<>(FeatureStoreAction.INSTANCE, TransportFeatureStoreAction.class), new ActionHandler<>(CachesStatsAction.INSTANCE, TransportCacheStatsAction.class), new ActionHandler<>(ClearCachesAction.INSTANCE, TransportClearCachesAction.class), new ActionHandler<>(AddFeaturesToSetAction.INSTANCE, TransportAddFeatureToSetAction.class), new ActionHandler<>(CreateModelFromSetAction.INSTANCE, TransportCreateModelFromSetAction.class), new ActionHandler<>(ListStoresAction.INSTANCE, TransportListStoresAction.class), - new ActionHandler<>(LTRStatsAction.INSTANCE, TransportLTRStatsAction.class))); + new ActionHandler<>(LTRStatsAction.INSTANCE, TransportLTRStatsAction.class) + ) + ); } @Override public List getNamedWriteables() { - return unmodifiableList(asList( + return unmodifiableList( + asList( new Entry(StorableElement.class, StoredFeature.TYPE, StoredFeature::new), new Entry(StorableElement.class, StoredFeatureSet.TYPE, StoredFeatureSet::new), new Entry(StorableElement.class, StoredLtrModel.TYPE, StoredLtrModel::new) - )); + ) + ); } @Override @@ -237,17 +250,25 @@ public List> getContexts() { @Override public List getNamedXContent() { - return unmodifiableList(asList( - new NamedXContentRegistry.Entry(StorableElement.class, - new ParseField(StoredFeature.TYPE), - (CheckedFunction) StoredFeature::parse), - new NamedXContentRegistry.Entry(StorableElement.class, - new ParseField(StoredFeatureSet.TYPE), - (CheckedFunction) StoredFeatureSet::parse), - new NamedXContentRegistry.Entry(StorableElement.class, - new ParseField(StoredLtrModel.TYPE), - (CheckedFunction) StoredLtrModel::parse) - )); + return unmodifiableList( + asList( + new NamedXContentRegistry.Entry( + StorableElement.class, + new ParseField(StoredFeature.TYPE), + (CheckedFunction) StoredFeature::parse + ), + new NamedXContentRegistry.Entry( + StorableElement.class, + new ParseField(StoredFeatureSet.TYPE), + (CheckedFunction) StoredFeatureSet::parse + ), + new NamedXContentRegistry.Entry( + StorableElement.class, + new ParseField(StoredLtrModel.TYPE), + (CheckedFunction) StoredLtrModel::parse + ) + ) + ); } @Override @@ -255,26 +276,29 @@ public List> getSettings() { List> list1 = LTRSettings.getInstance().getSettings(); List> list2 = asList( - IndexFeatureStore.STORE_VERSION_PROP, - Caches.LTR_CACHE_MEM_SETTING, - Caches.LTR_CACHE_EXPIRE_AFTER_READ, - Caches.LTR_CACHE_EXPIRE_AFTER_WRITE); + IndexFeatureStore.STORE_VERSION_PROP, + Caches.LTR_CACHE_MEM_SETTING, + Caches.LTR_CACHE_EXPIRE_AFTER_READ, + Caches.LTR_CACHE_EXPIRE_AFTER_WRITE + ); return unmodifiableList(Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList())); } @Override - public Collection createComponents(Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier) { + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { clusterService.addListener(event -> { for (Index i : event.indicesDeleted()) { if (IndexFeatureStore.isIndexStore(i.getName())) { @@ -293,32 +317,31 @@ public Collection createComponents(Client client, } private void addStats( - final Client client, - final ClusterService clusterService, - final LTRCircuitBreakerService ltrCircuitBreakerService + final Client client, + final ClusterService clusterService, + final LTRCircuitBreakerService ltrCircuitBreakerService ) { final StoreStatsSupplier storeStatsSupplier = StoreStatsSupplier.create(client, clusterService); ltrStats.addStats(StatName.LTR_STORES_STATS.getName(), new LTRStat<>(true, storeStatsSupplier)); - final PluginHealthStatusSupplier pluginHealthStatusSupplier = PluginHealthStatusSupplier.create( - client, clusterService, ltrCircuitBreakerService); + final PluginHealthStatusSupplier pluginHealthStatusSupplier = PluginHealthStatusSupplier + .create(client, clusterService, ltrCircuitBreakerService); ltrStats.addStats(StatName.LTR_PLUGIN_STATUS.getName(), new LTRStat<>(true, pluginHealthStatusSupplier)); } private LTRStats getInitialStats() { Map> stats = new HashMap<>(); - stats.put(StatName.LTR_CACHE_STATS.getName(), - new LTRStat<>(false, new CacheStatsOnNodeSupplier(caches))); - stats.put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - stats.put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); + stats.put(StatName.LTR_CACHE_STATS.getName(), new LTRStat<>(false, new CacheStatsOnNodeSupplier(caches))); + stats.put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + stats.put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); return new LTRStats((stats)); } protected FeatureStoreLoader getFeatureStoreLoader() { - return (storeName, clientSupplier) -> - new CachedFeatureStore(new IndexFeatureStore(storeName, clientSupplier, parserFactory), caches); + return (storeName, clientSupplier) -> new CachedFeatureStore( + new IndexFeatureStore(storeName, clientSupplier, parserFactory), + caches + ); } // A simplified version of some token filters needed by the feature stores. @@ -329,15 +352,18 @@ protected FeatureStoreLoader getFeatureStoreLoader() { @Override public List getPreConfiguredTokenFilters() { - return Arrays.asList( - PreConfiguredTokenFilter.singleton("ltr_edge_ngram", true, - (ts) -> new EdgeNGramTokenFilter(ts, 1, STORABLE_ELEMENT_MAX_NAME_SIZE, false)), - PreConfiguredTokenFilter.singleton("ltr_length", true, - (ts) -> new LengthFilter(ts, 0, STORABLE_ELEMENT_MAX_NAME_SIZE))); + return Arrays + .asList( + PreConfiguredTokenFilter + .singleton("ltr_edge_ngram", true, (ts) -> new EdgeNGramTokenFilter(ts, 1, STORABLE_ELEMENT_MAX_NAME_SIZE, false)), + PreConfiguredTokenFilter.singleton("ltr_length", true, (ts) -> new LengthFilter(ts, 0, STORABLE_ELEMENT_MAX_NAME_SIZE)) + ); } public List getPreConfiguredTokenizers() { - return Collections.singletonList(PreConfiguredTokenizer.singleton("ltr_keyword", - () -> new KeywordTokenizer(KeywordTokenizer.DEFAULT_BUFFER_SIZE))); + return Collections + .singletonList( + PreConfiguredTokenizer.singleton("ltr_keyword", () -> new KeywordTokenizer(KeywordTokenizer.DEFAULT_BUFFER_SIZE)) + ); } } diff --git a/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java b/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java index 7d3a6865..e07652f5 100644 --- a/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/AddFeaturesToSetAction.java @@ -16,27 +16,28 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.FeatureValidation; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.List; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.action.index.IndexResponse; import org.opensearch.client.OpenSearchClient; +import org.opensearch.common.xcontent.StatusToXContentObject; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable.Reader; -import org.opensearch.common.xcontent.StatusToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - -import static org.opensearch.action.ValidateActions.addValidationError; +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StoredFeature; public class AddFeaturesToSetAction extends ActionType { public static final AddFeaturesToSetAction INSTANCE = new AddFeaturesToSetAction(); @@ -66,11 +67,9 @@ public static class AddFeaturesToSetRequest extends ActionRequest { private String routing; private FeatureValidation validation; - public AddFeaturesToSetRequest() { - } - + public AddFeaturesToSetRequest() {} - public AddFeaturesToSetRequest(StreamInput in) throws IOException { + public AddFeaturesToSetRequest(StreamInput in) throws IOException { super(in); store = in.readString(); features = in.readList(StoredFeature::new); @@ -83,7 +82,6 @@ public AddFeaturesToSetRequest(StreamInput in) throws IOException { validation = in.readOptionalWriteable(FeatureValidation::new); } - @Override public ActionRequestValidationException validate() { ActionRequestValidationException arve = null; diff --git a/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java b/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java index 4fc22bfe..12b5129f 100644 --- a/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/CachesStatsAction.java @@ -16,8 +16,11 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesResponse; -import com.o19s.es.ltr.feature.store.index.Caches; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionType; import org.opensearch.action.FailedNodeException; @@ -33,10 +36,8 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesResponse; +import com.o19s.es.ltr.feature.store.index.Caches; public class CachesStatsAction extends ActionType { public static final String NAME = "cluster:admin/ltr/caches/stats"; @@ -117,6 +118,7 @@ public StatDetails getAll() { return allStores; } } + public static class CachesStatsNodeResponse extends BaseNodeResponse { private StatDetails allStores; private Map byStore; @@ -159,6 +161,7 @@ public StatDetails getAllStores() { return allStores; } } + public static class StatDetails implements Writeable, ToXContent { private Stat total; private Stat features; @@ -223,12 +226,13 @@ public void doSum(StatDetails other) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() - .field("total", total) - .field("features", features) - .field("featuresets", featuresets) - .field("models", models) - .endObject(); + return builder + .startObject() + .field("total", total) + .field("features", features) + .field("featuresets", featuresets) + .field("models", models) + .endObject(); } public Stat getTotal() { @@ -282,17 +286,13 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() - .field("ram", ram) - .field("count", count) - .endObject(); + return builder.startObject().field("ram", ram).field("count", count).endObject(); } } } - public static class CachesStatsActionBuilder extends - ActionRequestBuilder { - public CachesStatsActionBuilder(OpenSearchClient client){ + public static class CachesStatsActionBuilder extends ActionRequestBuilder { + public CachesStatsActionBuilder(OpenSearchClient client) { super(client, INSTANCE, new CachesStatsNodesRequest()); } } diff --git a/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java b/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java index c9999136..1f0f4266 100644 --- a/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java +++ b/src/main/java/com/o19s/es/ltr/action/ClearCachesAction.java @@ -16,7 +16,12 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.ActionType; @@ -29,13 +34,9 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; import org.opensearch.core.common.io.stream.Writeable.Reader; -import static org.opensearch.action.ValidateActions.addValidationError; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; public class ClearCachesAction extends ActionType { public static final String NAME = "cluster:admin/ltr/caches"; @@ -61,7 +62,6 @@ public static class ClearCachesNodesRequest extends BaseNodesRequest responses, - List failures) { + public ClearCachesNodesResponse( + ClusterName clusterName, + List responses, + List failures + ) { super(clusterName, responses, failures); } @@ -172,6 +175,5 @@ public ClearCachesNodeResponse(DiscoveryNode node) { super(node); } - } } diff --git a/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java b/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java index d3f2565d..ff68ca58 100644 --- a/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/CreateModelFromSetAction.java @@ -16,25 +16,26 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetResponse; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StoredLtrModel; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.action.index.IndexResponse; import org.opensearch.client.OpenSearchClient; +import org.opensearch.common.xcontent.StatusToXContentObject; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.common.xcontent.StatusToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; - -import static org.opensearch.action.ValidateActions.addValidationError; +import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetResponse; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StoredLtrModel; public class CreateModelFromSetAction extends ActionType { public static final String NAME = "cluster:admin/ltr/store/create-model-from-set"; @@ -44,16 +45,20 @@ protected CreateModelFromSetAction() { super(NAME, CreateModelFromSetResponse::new); } - - public static class CreateModelFromSetRequestBuilder extends ActionRequestBuilder { + public static class CreateModelFromSetRequestBuilder extends + ActionRequestBuilder { public CreateModelFromSetRequestBuilder(OpenSearchClient client) { super(client, INSTANCE, new CreateModelFromSetRequest()); } - public CreateModelFromSetRequestBuilder withVersion(String store, String featureSetName, long expectedSetVersion, - String modelName, StoredLtrModel.LtrModelDefinition definition) { + public CreateModelFromSetRequestBuilder withVersion( + String store, + String featureSetName, + long expectedSetVersion, + String modelName, + StoredLtrModel.LtrModelDefinition definition + ) { request.store = store; request.featureSetName = featureSetName; request.expectedSetVersion = expectedSetVersion; @@ -62,8 +67,12 @@ public CreateModelFromSetRequestBuilder withVersion(String store, String feature return this; } - public CreateModelFromSetRequestBuilder withoutVersion(String store, String featureSetName, String modelName, - StoredLtrModel.LtrModelDefinition definition) { + public CreateModelFromSetRequestBuilder withoutVersion( + String store, + String featureSetName, + String modelName, + StoredLtrModel.LtrModelDefinition definition + ) { request.store = store; request.featureSetName = featureSetName; request.expectedSetVersion = null; diff --git a/src/main/java/com/o19s/es/ltr/action/FeatureStoreAction.java b/src/main/java/com/o19s/es/ltr/action/FeatureStoreAction.java index 43383159..b37a4613 100644 --- a/src/main/java/com/o19s/es/ltr/action/FeatureStoreAction.java +++ b/src/main/java/com/o19s/es/ltr/action/FeatureStoreAction.java @@ -16,28 +16,29 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Objects; + import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.action.index.IndexResponse; import org.opensearch.client.OpenSearchClient; +import org.opensearch.common.xcontent.StatusToXContentObject; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable.Reader; -import org.opensearch.common.xcontent.StatusToXContentObject; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Objects; - -import static org.opensearch.action.ValidateActions.addValidationError; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class FeatureStoreAction extends ActionType { public static final String NAME = "cluster:admin/ltr/featurestore/data"; @@ -52,8 +53,7 @@ public Reader getResponseReader() { return FeatureStoreResponse::new; } - public static class FeatureStoreRequestBuilder - extends ActionRequestBuilder { + public static class FeatureStoreRequestBuilder extends ActionRequestBuilder { public FeatureStoreRequestBuilder(OpenSearchClient client, FeatureStoreAction action) { super(client, action, new FeatureStoreRequest()); } diff --git a/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java b/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java index d117c50e..2aef7a07 100644 --- a/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/LTRStatsAction.java @@ -16,6 +16,12 @@ package com.o19s.es.ltr.action; +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionType; import org.opensearch.action.FailedNodeException; @@ -31,13 +37,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.transport.TransportRequest; - -import java.io.IOException; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; public class LTRStatsAction extends ActionType { public static final String NAME = "cluster:admin/ltr/stats"; @@ -47,8 +46,7 @@ public LTRStatsAction() { super(NAME, LTRStatsNodesResponse::new); } - public static class LTRStatsRequestBuilder - extends ActionRequestBuilder { + public static class LTRStatsRequestBuilder extends ActionRequestBuilder { private static final String[] nodeIds = null; public LTRStatsRequestBuilder(OpenSearchClient client) { @@ -152,8 +150,12 @@ public LTRStatsNodesResponse(StreamInput in) throws IOException { clusterStats = in.readMap(); } - public LTRStatsNodesResponse(ClusterName clusterName, List nodeResponses, - List failures, Map clusterStats) { + public LTRStatsNodesResponse( + ClusterName clusterName, + List nodeResponses, + List failures, + Map clusterStats + ) { super(clusterName, nodeResponses, failures); this.clusterStats = clusterStats; } diff --git a/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java b/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java index 37849d58..9c0c5773 100644 --- a/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java +++ b/src/main/java/com/o19s/es/ltr/action/ListStoresAction.java @@ -16,14 +16,18 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionResponse; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import org.opensearch.action.ActionRequestBuilder; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.core.action.ActionResponse; import org.opensearch.action.ActionType; import org.opensearch.action.support.master.MasterNodeReadRequest; import org.opensearch.client.OpenSearchClient; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -32,11 +36,8 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; +import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionResponse; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class ListStoresAction extends ActionType { public static final String NAME = "cluster:admin/ltr/featurestore/list"; @@ -80,9 +81,7 @@ public ListStoresActionResponse(List info) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() - .field("stores", stores) - .endObject(); + return builder.startObject().field("stores", stores).endObject(); } @Override @@ -107,6 +106,7 @@ public IndexStoreInfo(String indexName, int version, Map counts this.version = version; this.counts = counts; } + public IndexStoreInfo(StreamInput in) throws IOException { storeName = in.readString(); indexName = in.readString(); @@ -124,7 +124,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() + return builder + .startObject() .field("store", storeName) .field("index", indexName) .field("version", version) @@ -149,9 +150,8 @@ public Map getCounts() { } } - public static class ListStoresActionBuilder extends - ActionRequestBuilder { - public ListStoresActionBuilder(OpenSearchClient client){ + public static class ListStoresActionBuilder extends ActionRequestBuilder { + public ListStoresActionBuilder(OpenSearchClient client) { super(client, INSTANCE, new ListStoresActionRequest()); } } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java b/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java index be2039fd..5695af24 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportAddFeatureToSetAction.java @@ -16,17 +16,14 @@ package com.o19s.es.ltr.action; -import org.opensearch.ltr.breaker.LTRCircuitBreakerService; -import org.opensearch.ltr.exception.LimitExceededException; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequest; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.core.action.ActionListener; +import static org.opensearch.core.action.ActionListener.wrap; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.TransportGetAction; @@ -40,21 +37,25 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.CountDown; +import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ltr.breaker.LTRCircuitBreakerService; +import org.opensearch.ltr.exception.LimitExceededException; import org.opensearch.search.SearchHit; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - -import static org.opensearch.core.action.ActionListener.wrap; +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequest; +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetResponse; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class TransportAddFeatureToSetAction extends HandledTransportAction { private final ClusterService clusterService; @@ -64,12 +65,18 @@ public class TransportAddFeatureToSetAction extends HandledTransportAction listener, - ClusterService clusterService, TransportSearchAction searchAction, TransportGetAction getAction, - TransportFeatureStoreAction featureStoreAction) { + AsyncAction( + Task task, + AddFeaturesToSetRequest request, + ActionListener listener, + ClusterService clusterService, + TransportSearchAction searchAction, + TransportGetAction getAction, + TransportFeatureStoreAction featureStoreAction + ) { this.task = task; this.listener = listener; this.featureSetName = request.getFeatureSet(); @@ -153,8 +167,8 @@ private void start() { featuresRef.set(features); } GetRequest getRequest = new GetRequest(store) - .id(StorableElement.generateId(StoredFeatureSet.TYPE, featureSetName)) - .routing(routing); + .id(StorableElement.generateId(StoredFeatureSet.TYPE, featureSetName)) + .routing(routing); getRequest.setParentTask(clusterService.localNode().getId(), task.getId()); getAction.execute(getRequest, wrap(this::onGetResponse, this::onGetFailure)); @@ -178,7 +192,7 @@ private void fetchFeaturesFromStore() { BoolQueryBuilder bq = QueryBuilders.boolQuery(); bq.must(nameQuery); bq.must(QueryBuilders.matchQuery("type", StoredFeature.TYPE)); - //srequest.types(IndexFeatureStore.ES_TYPE); + // srequest.types(IndexFeatureStore.ES_TYPE); srequest.source().query(bq); srequest.source().fetchSource(true); srequest.source().size(StoredFeatureSet.MAX_FEATURES); @@ -280,16 +294,15 @@ private void updateSet(StoredFeatureSet set) { long version = this.version.get(); final FeatureStoreRequest frequest; if (version > 0) { - frequest = new FeatureStoreRequest(store, set, version); + frequest = new FeatureStoreRequest(store, set, version); } else { frequest = new FeatureStoreRequest(store, set, FeatureStoreRequest.Action.CREATE); } frequest.setRouting(routing); frequest.setParentTask(clusterService.localNode().getId(), task.getId()); frequest.setValidation(validation); - featureStoreAction.execute(frequest, wrap( - (r) -> listener.onResponse(new AddFeaturesToSetResponse(r.getResponse())), - listener::onFailure)); + featureStoreAction + .execute(frequest, wrap((r) -> listener.onResponse(new AddFeaturesToSetResponse(r.getResponse())), listener::onFailure)); } } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java b/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java index f22e2269..164a73da 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportCacheStatsAction.java @@ -16,10 +16,9 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodeResponse; -import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesRequest; -import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesResponse; -import com.o19s.es.ltr.feature.store.index.Caches; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.BaseNodeRequest; @@ -27,33 +26,50 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.List; +import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodeResponse; +import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesRequest; +import com.o19s.es.ltr.action.CachesStatsAction.CachesStatsNodesResponse; +import com.o19s.es.ltr.feature.store.index.Caches; -public class TransportCacheStatsAction extends TransportNodesAction { +public class TransportCacheStatsAction extends + TransportNodesAction { private final Caches caches; @Inject - public TransportCacheStatsAction(Settings settings, ThreadPool threadPool, - ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - Caches caches) { - super(CachesStatsAction.NAME, threadPool, clusterService, transportService, - actionFilters, CachesStatsNodesRequest::new, CachesStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, CachesStatsAction.CachesStatsNodeResponse.class); + public TransportCacheStatsAction( + Settings settings, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + Caches caches + ) { + super( + CachesStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + CachesStatsNodesRequest::new, + CachesStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + CachesStatsAction.CachesStatsNodeResponse.class + ); this.caches = caches; } @Override - protected CachesStatsNodesResponse newResponse(CachesStatsNodesRequest request, List responses, - List failures) { + protected CachesStatsNodesResponse newResponse( + CachesStatsNodesRequest request, + List responses, + List failures + ) { return new CachesStatsNodesResponse(clusterService.getClusterName(), responses, failures); } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java b/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java index e93dc8f0..16c70dd2 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportClearCachesAction.java @@ -16,10 +16,9 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodeResponse; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesRequest; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; -import com.o19s.es.ltr.feature.store.index.Caches; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.BaseNodeRequest; @@ -27,33 +26,51 @@ import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.List; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodeResponse; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesRequest; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; +import com.o19s.es.ltr.feature.store.index.Caches; -public class TransportClearCachesAction extends TransportNodesAction { +public class TransportClearCachesAction extends + TransportNodesAction { private final Caches caches; @Inject - public TransportClearCachesAction(Settings settings, ThreadPool threadPool, - ClusterService clusterService, TransportService transportService, - ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - Caches caches) { - super(ClearCachesAction.NAME, threadPool, clusterService, transportService, actionFilters, - ClearCachesNodesRequest::new, ClearCachesNodeRequest::new, ThreadPool.Names.MANAGEMENT, ClearCachesNodeResponse.class); + public TransportClearCachesAction( + Settings settings, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + Caches caches + ) { + super( + ClearCachesAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + ClearCachesNodesRequest::new, + ClearCachesNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + ClearCachesNodeResponse.class + ); this.caches = caches; } @Override - protected ClearCachesNodesResponse newResponse(ClearCachesNodesRequest request, List responses, - List failures) { + protected ClearCachesNodesResponse newResponse( + ClearCachesNodesRequest request, + List responses, + List failures + ) { return new ClearCachesNodesResponse(clusterService.getClusterName(), responses, failures); } @@ -71,20 +88,20 @@ protected ClearCachesNodeResponse newNodeResponse(StreamInput in) throws IOExcep protected ClearCachesNodeResponse nodeOperation(ClearCachesNodeRequest request) { ClearCachesNodesRequest r = request.request; switch (r.getOperation()) { - case ClearStore: - caches.evict(r.getStore()); - break; - case ClearFeature: - caches.evictFeature(r.getStore(), r.getName()); - break; - case ClearFeatureSet: - caches.evictFeatureSet(r.getStore(), r.getName()); - break; - case ClearModel: - caches.evictModel(r.getStore(), r.getName()); - break; - default: - throw new RuntimeException("Unsupported operation [" + r.getOperation() + "]"); + case ClearStore: + caches.evict(r.getStore()); + break; + case ClearFeature: + caches.evictFeature(r.getStore(), r.getName()); + break; + case ClearFeatureSet: + caches.evictFeatureSet(r.getStore(), r.getName()); + break; + case ClearModel: + caches.evictModel(r.getStore(), r.getName()); + break; + default: + throw new RuntimeException("Unsupported operation [" + r.getOperation() + "]"); } return new ClearCachesNodeResponse(clusterService.localNode()); } @@ -103,7 +120,6 @@ public ClearCachesNodeRequest(ClearCachesNodesRequest req) { request = new ClearCachesNodesRequest(in); } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java b/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java index 46e43bee..1abc5387 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportCreateModelFromSetAction.java @@ -16,16 +16,8 @@ package com.o19s.es.ltr.action; -import org.opensearch.ltr.breaker.LTRCircuitBreakerService; -import org.opensearch.ltr.exception.LimitExceededException; -import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequest; -import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetResponse; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.core.action.ActionListener; +import java.io.IOException; + import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.TransportGetAction; @@ -35,11 +27,20 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ltr.breaker.LTRCircuitBreakerService; +import org.opensearch.ltr.exception.LimitExceededException; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.io.IOException; +import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequest; +import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetResponse; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class TransportCreateModelFromSetAction extends HandledTransportAction { private final ClusterService clusterService; @@ -48,12 +49,17 @@ public class TransportCreateModelFromSetAction extends HandledTransportAction this.doStore(task, r, request, listener), listener::onFailure)); } - private void doStore(Task parentTask, GetResponse response, CreateModelFromSetRequest request, - ActionListener listener) { + private void doStore( + Task parentTask, + GetResponse response, + CreateModelFromSetRequest request, + ActionListener listener + ) { if (!response.isExists()) { throw new IllegalArgumentException("Stored feature set [" + request.getFeatureSetName() + "] does not exist"); } if (request.getExpectedSetVersion() != null && request.getExpectedSetVersion() != response.getVersion()) { - throw new IllegalArgumentException("Stored feature set [" + request.getFeatureSetName() + "]" + - " has version [" + response.getVersion() + "] but [" + request.getExpectedSetVersion() + "] was expected."); + throw new IllegalArgumentException( + "Stored feature set [" + + request.getFeatureSetName() + + "]" + + " has version [" + + response.getVersion() + + "] but [" + + request.getExpectedSetVersion() + + "] was expected." + ); } final StoredFeatureSet set; try { set = IndexFeatureStore.parse(StoredFeatureSet.class, StoredFeatureSet.TYPE, response.getSourceAsBytesRef()); - } catch(IOException ioe) { + } catch (IOException ioe) { throw new IllegalStateException("Cannot parse stored feature set [" + request.getFeatureSetName() + "]", ioe); } // Model will be parsed & checked by TransportFeatureStoreAction @@ -97,9 +116,11 @@ private void doStore(Task parentTask, GetResponse response, CreateModelFromSetRe featureStoreRequest.setRouting(request.getRouting()); featureStoreRequest.setParentTask(clusterService.localNode().getId(), parentTask.getId()); featureStoreRequest.setValidation(request.getValidation()); - featureStoreAction.execute(featureStoreRequest, ActionListener.wrap( - (r) -> listener.onResponse(new CreateModelFromSetResponse(r.getResponse())), - listener::onFailure)); + featureStoreAction + .execute( + featureStoreRequest, + ActionListener.wrap((r) -> listener.onResponse(new CreateModelFromSetResponse(r.getResponse())), listener::onFailure) + ); } } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java b/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java index d67852c7..91023b29 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportFeatureStoreAction.java @@ -16,23 +16,13 @@ package com.o19s.es.ltr.action; -import org.opensearch.ltr.breaker.LTRCircuitBreakerService; -import org.opensearch.ltr.exception.LimitExceededException; -import org.opensearch.ltr.stats.LTRStats; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesRequest; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; -import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import com.o19s.es.ltr.query.ValidatingLtrQueryBuilder; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; +import static org.opensearch.core.action.ActionListener.wrap; + +import java.io.IOException; +import java.util.Optional; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.index.IndexAction; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchAction; @@ -44,13 +34,24 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ltr.breaker.LTRCircuitBreakerService; +import org.opensearch.ltr.exception.LimitExceededException; +import org.opensearch.ltr.stats.LTRStats; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.Optional; - -import static org.opensearch.core.action.ActionListener.wrap; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesRequest; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreRequest; +import com.o19s.es.ltr.action.FeatureStoreAction.FeatureStoreResponse; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import com.o19s.es.ltr.query.ValidatingLtrQueryBuilder; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; public class TransportFeatureStoreAction extends HandledTransportAction { private final LtrRankerParserFactory factory; @@ -61,15 +62,17 @@ public class TransportFeatureStoreAction extends HandledTransportAction store(request, task, listener), ltrStats); + validate(request.getValidation(), request.getStorableElement(), task, listener, () -> store(request, task, listener), ltrStats); } else { store(request, task, listener); } @@ -102,15 +108,15 @@ protected void doExecute(Task task, FeatureStoreRequest request, ActionListener< private Optional buildClearCache(FeatureStoreRequest request) throws IOException { if (request.getAction() == FeatureStoreRequest.Action.UPDATE) { - ClearCachesAction.ClearCachesNodesRequest clearCachesNodesRequest = new ClearCachesAction.ClearCachesNodesRequest(); - switch (request.getStorableElement().type()) { - case StoredFeature.TYPE: - clearCachesNodesRequest.clearFeature(request.getStore(), request.getStorableElement().name()); - return Optional.of(clearCachesNodesRequest); - case StoredFeatureSet.TYPE: - clearCachesNodesRequest.clearFeatureSet(request.getStore(), request.getStorableElement().name()); - return Optional.of(clearCachesNodesRequest); - } + ClearCachesAction.ClearCachesNodesRequest clearCachesNodesRequest = new ClearCachesAction.ClearCachesNodesRequest(); + switch (request.getStorableElement().type()) { + case StoredFeature.TYPE: + clearCachesNodesRequest.clearFeature(request.getStore(), request.getStorableElement().name()); + return Optional.of(clearCachesNodesRequest); + case StoredFeatureSet.TYPE: + clearCachesNodesRequest.clearFeatureSet(request.getStore(), request.getStorableElement().name()); + return Optional.of(clearCachesNodesRequest); + } } return Optional.empty(); } @@ -118,13 +124,14 @@ private Optional buildClearCache(FeatureStoreRequest re private IndexRequest buildIndexRequest(Task parentTask, FeatureStoreRequest request) throws IOException { StorableElement elt = request.getStorableElement(); - IndexRequest indexRequest = client.prepareIndex(request.getStore()) - .setId(elt.id()) - .setCreate(request.getAction() == FeatureStoreRequest.Action.CREATE) - .setRouting(request.getRouting()) - .setSource(IndexFeatureStore.toSource(elt)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .request(); + IndexRequest indexRequest = client + .prepareIndex(request.getStore()) + .setId(elt.id()) + .setCreate(request.getAction() == FeatureStoreRequest.Action.CREATE) + .setRouting(request.getRouting()) + .setSource(IndexFeatureStore.toSource(elt)) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .request(); indexRequest.setParentTask(clusterService.localNode().getId(), parentTask.getId()); return indexRequest; } @@ -138,8 +145,10 @@ private void precheck(FeatureStoreRequest request) { try { model.compile(factory); } catch (Exception e) { - throw new IllegalArgumentException("Error while parsing model [" + model.name() + "]" + - " with type [" + model.rankingModelType() + "]", e); + throw new IllegalArgumentException( + "Error while parsing model [" + model.name() + "]" + " with type [" + model.rankingModelType() + "]", + e + ); } } else if (request.getStorableElement() instanceof StoredFeatureSet) { StoredFeatureSet set = (StoredFeatureSet) request.getStorableElement(); @@ -160,14 +169,15 @@ private void precheck(FeatureStoreRequest request) { * @param onSuccess action ro run when the validation is successfull * @param ltrStats LTR stats */ - private void validate(FeatureValidation validation, - StorableElement element, - Task task, - ActionListener listener, - Runnable onSuccess, - LTRStats ltrStats) { - ValidatingLtrQueryBuilder ltrBuilder = new ValidatingLtrQueryBuilder(element, - validation, factory, ltrStats); + private void validate( + FeatureValidation validation, + StorableElement element, + Task task, + ActionListener listener, + Runnable onSuccess, + LTRStats ltrStats + ) { + ValidatingLtrQueryBuilder ltrBuilder = new ValidatingLtrQueryBuilder(element, validation, factory, ltrStats); SearchRequestBuilder builder = new SearchRequestBuilder(client, SearchAction.INSTANCE); builder.setIndices(validation.getIndex()); builder.setQuery(ltrBuilder); @@ -177,14 +187,15 @@ private void validate(FeatureValidation validation, builder.setTerminateAfter(1000); builder.request().setParentTask(clusterService.localNode().getId(), task.getId()); builder.execute(wrap((r) -> { - if (r.getFailedShards() > 0) { - ShardSearchFailure failure = r.getShardFailures()[0]; - throw new IllegalArgumentException("Validating the element caused " + r.getFailedShards() + - " shard failures, see root cause: " + failure.reason(), failure.getCause()); - } - onSuccess.run(); - }, - (e) -> listener.onFailure(new IllegalArgumentException("Cannot store element, validation failed.", e)))); + if (r.getFailedShards() > 0) { + ShardSearchFailure failure = r.getShardFailures()[0]; + throw new IllegalArgumentException( + "Validating the element caused " + r.getFailedShards() + " shard failures, see root cause: " + failure.reason(), + failure.getCause() + ); + } + onSuccess.run(); + }, (e) -> listener.onFailure(new IllegalArgumentException("Cannot store element, validation failed.", e)))); } /** @@ -195,17 +206,13 @@ private void store(FeatureStoreRequest request, Task task, ActionListener clearCachesNodesRequest = buildClearCache(request); IndexRequest indexRequest = buildIndexRequest(task, request); - client.execute(IndexAction.INSTANCE, indexRequest, wrap( - (r) -> { - // Run and forget, log only if something bad happens - // but don't wait for the action to be done nor set the parent task. - clearCachesNodesRequest.ifPresent((req) -> clearCachesAction.execute(req, wrap( - (r2) -> { - }, - (e) -> logger.error("Failed to clear cache", e)))); - listener.onResponse(new FeatureStoreResponse(r)); - }, - listener::onFailure)); + client.execute(IndexAction.INSTANCE, indexRequest, wrap((r) -> { + // Run and forget, log only if something bad happens + // but don't wait for the action to be done nor set the parent task. + clearCachesNodesRequest + .ifPresent((req) -> clearCachesAction.execute(req, wrap((r2) -> {}, (e) -> logger.error("Failed to clear cache", e)))); + listener.onResponse(new FeatureStoreResponse(r)); + }, listener::onFailure)); } catch (IOException ioe) { listener.onFailure(ioe); } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java b/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java index ea7d5916..70993b12 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportLTRStatsAction.java @@ -16,54 +16,67 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeRequest; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeResponse; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesRequest; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesResponse; -import org.opensearch.ltr.stats.LTRStats; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.ltr.stats.LTRStats; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeRequest; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeResponse; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesRequest; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesResponse; public class TransportLTRStatsAction extends - TransportNodesAction { + TransportNodesAction { private final LTRStats ltrStats; @Inject - public TransportLTRStatsAction(ThreadPool threadPool, - ClusterService clusterService, - TransportService transportService, - ActionFilters actionFilters, - LTRStats ltrStats) { - super(LTRStatsAction.NAME, threadPool, clusterService, transportService, - actionFilters, LTRStatsNodesRequest::new, LTRStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, LTRStatsNodeResponse.class); + public TransportLTRStatsAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + LTRStats ltrStats + ) { + super( + LTRStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + LTRStatsNodesRequest::new, + LTRStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + LTRStatsNodeResponse.class + ); this.ltrStats = ltrStats; } @Override - protected LTRStatsNodesResponse newResponse(LTRStatsNodesRequest request, - List nodeResponses, - List failures) { + protected LTRStatsNodesResponse newResponse( + LTRStatsNodesRequest request, + List nodeResponses, + List failures + ) { Set statsToBeRetrieved = request.getStatsToBeRetrieved(); - Map clusterStats = - ltrStats.getClusterStats() - .entrySet() - .stream() - .filter(e -> statsToBeRetrieved.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); + Map clusterStats = ltrStats + .getClusterStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); return new LTRStatsNodesResponse(clusterService.getClusterName(), nodeResponses, failures, clusterStats); } @@ -83,12 +96,12 @@ protected LTRStatsNodeResponse nodeOperation(LTRStatsNodeRequest request) { LTRStatsNodesRequest nodesRequest = request.getLTRStatsNodesRequest(); Set statsToBeRetrieved = nodesRequest.getStatsToBeRetrieved(); - Map statValues = - ltrStats.getNodeStats() - .entrySet() - .stream() - .filter(e -> statsToBeRetrieved.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); + Map statValues = ltrStats + .getNodeStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); return new LTRStatsNodeResponse(clusterService.localNode(), statValues); } } diff --git a/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java b/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java index d980ca68..cad0e88c 100644 --- a/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java +++ b/src/main/java/com/o19s/es/ltr/action/TransportListStoresAction.java @@ -16,10 +16,20 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionRequest; -import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionResponse; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.core.action.ActionListener; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.STORE_VERSION_PROP; +import static java.util.stream.Collectors.toMap; +import static org.opensearch.common.collect.Tuple.tuple; +import static org.opensearch.core.action.ActionListener.wrap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; + import org.opensearch.action.admin.cluster.state.ClusterStateRequest; import org.opensearch.action.search.MultiSearchRequestBuilder; import org.opensearch.action.search.MultiSearchResponse; @@ -34,8 +44,9 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.collect.Tuple; import org.opensearch.common.inject.Inject; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.bucket.MultiBucketsAggregation; @@ -43,32 +54,35 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Stream; - -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.STORE_VERSION_PROP; -import static java.util.stream.Collectors.toMap; -import static org.opensearch.core.action.ActionListener.wrap; -import static org.opensearch.common.collect.Tuple.tuple; +import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionRequest; +import com.o19s.es.ltr.action.ListStoresAction.ListStoresActionResponse; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class TransportListStoresAction extends TransportMasterNodeReadAction { private final Client client; @Inject - public TransportListStoresAction(Settings settings, TransportService transportService,ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver, Client client) { - super(ListStoresAction.NAME, transportService, clusterService, threadPool, - actionFilters, ListStoresActionRequest::new, indexNameExpressionResolver); + public TransportListStoresAction( + Settings settings, + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + Client client + ) { + super( + ListStoresAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + ListStoresActionRequest::new, + indexNameExpressionResolver + ); this.client = client; } - + @Override protected String executor() { return ThreadPool.Names.SAME; @@ -80,21 +94,25 @@ protected ListStoresActionResponse read(StreamInput in) throws IOException { } @Override - protected void masterOperation(ListStoresActionRequest request, ClusterState state, - ActionListener listener) throws Exception { - String[] names = indexNameExpressionResolver.concreteIndexNames(state, - new ClusterStateRequest().indices(IndexFeatureStore.DEFAULT_STORE, IndexFeatureStore.STORE_PREFIX + "*")); + protected void masterOperation(ListStoresActionRequest request, ClusterState state, ActionListener listener) + throws Exception { + String[] names = indexNameExpressionResolver + .concreteIndexNames( + state, + new ClusterStateRequest().indices(IndexFeatureStore.DEFAULT_STORE, IndexFeatureStore.STORE_PREFIX + "*") + ); final MultiSearchRequestBuilder req = client.prepareMultiSearch(); final List> versions = new ArrayList<>(); - Stream.of(names) - .filter(IndexFeatureStore::isIndexStore) - .map((s) -> clusterService.state().metadata().getIndices().get(s)) - .filter(Objects::nonNull) - .filter((im) -> STORE_VERSION_PROP.exists(im.getSettings())) - .forEach((m) -> { - req.add(countSearchRequest(m)); - versions.add(tuple(m.getIndex().getName(),STORE_VERSION_PROP.get(m.getSettings()))); - }); + Stream + .of(names) + .filter(IndexFeatureStore::isIndexStore) + .map((s) -> clusterService.state().metadata().getIndices().get(s)) + .filter(Objects::nonNull) + .filter((im) -> STORE_VERSION_PROP.exists(im.getSettings())) + .forEach((m) -> { + req.add(countSearchRequest(m)); + versions.add(tuple(m.getIndex().getName(), STORE_VERSION_PROP.get(m.getSettings()))); + }); if (versions.isEmpty()) { listener.onResponse(new ListStoresActionResponse(Collections.emptyList())); } else { @@ -102,12 +120,12 @@ protected void masterOperation(ListStoresActionRequest request, ClusterState sta } } - private SearchRequestBuilder countSearchRequest(IndexMetadata meta) { - return client.prepareSearch(meta.getIndex().getName()) - .setQuery(QueryBuilders.matchAllQuery()) - .setSize(0) - .addAggregation(AggregationBuilders.terms("type").field("type").size(100)); + return client + .prepareSearch(meta.getIndex().getName()) + .setQuery(QueryBuilders.matchAllQuery()) + .setSize(0) + .addAggregation(AggregationBuilders.terms("type").field("type").size(100)); } private ListStoresActionResponse toResponse(MultiSearchResponse response, List> versions) { @@ -120,14 +138,11 @@ private ListStoresActionResponse toResponse(MultiSearchResponse response, List idxAndVersion = vs.next(); Map counts = Collections.emptyMap(); if (!it.isFailure()) { - Terms aggs = it.getResponse() - .getAggregations() - .get("type"); + Terms aggs = it.getResponse().getAggregations().get("type"); counts = aggs - .getBuckets() - .stream() - .collect(toMap(MultiBucketsAggregation.Bucket::getKeyAsString, - (b) -> (int) b.getDocCount())); + .getBuckets() + .stream() + .collect(toMap(MultiBucketsAggregation.Bucket::getKeyAsString, (b) -> (int) b.getDocCount())); } infos.add(new ListStoresAction.IndexStoreInfo(idxAndVersion.v1(), idxAndVersion.v2(), counts)); } diff --git a/src/main/java/com/o19s/es/ltr/feature/Feature.java b/src/main/java/com/o19s/es/ltr/feature/Feature.java index 4337a038..9e17a269 100644 --- a/src/main/java/com/o19s/es/ltr/feature/Feature.java +++ b/src/main/java/com/o19s/es/ltr/feature/Feature.java @@ -16,10 +16,11 @@ package com.o19s.es.ltr.feature; -import com.o19s.es.ltr.LtrQueryContext; +import java.util.Map; + import org.apache.lucene.search.Query; -import java.util.Map; +import com.o19s.es.ltr.LtrQueryContext; /** * A feature that can be transformed into a lucene query @@ -57,6 +58,5 @@ default Feature optimize() { * * @param set the feature-set to validate the current feature against */ - default void validate(FeatureSet set) { - } + default void validate(FeatureSet set) {} } diff --git a/src/main/java/com/o19s/es/ltr/feature/FeatureSet.java b/src/main/java/com/o19s/es/ltr/feature/FeatureSet.java index 5cc9ac8d..cd4ace70 100644 --- a/src/main/java/com/o19s/es/ltr/feature/FeatureSet.java +++ b/src/main/java/com/o19s/es/ltr/feature/FeatureSet.java @@ -16,12 +16,13 @@ package com.o19s.es.ltr.feature; -import com.o19s.es.ltr.LtrQueryContext; -import org.apache.lucene.search.Query; - import java.util.List; import java.util.Map; +import org.apache.lucene.search.Query; + +import com.o19s.es.ltr.LtrQueryContext; + /** * A set of features. * Features can be identified by their name or ordinal. @@ -88,6 +89,5 @@ default FeatureSet optimize() { return this; } - default void validate() { - } + default void validate() {} } diff --git a/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java b/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java index df12a97b..72c90584 100644 --- a/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java +++ b/src/main/java/com/o19s/es/ltr/feature/FeatureValidation.java @@ -16,6 +16,10 @@ package com.o19s.es.ltr.feature; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -26,10 +30,6 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.Map; -import java.util.Objects; - /** * Simple object to store the parameters needed to validate stored elements: * - The list of template params to replace @@ -37,8 +37,10 @@ */ public class FeatureValidation implements Writeable, ToXContentObject { @SuppressWarnings("unchecked") - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("feature_validation", - (Object[] args) -> new FeatureValidation((String) args[0], (Map) args[1])); + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "feature_validation", + (Object[] args) -> new FeatureValidation((String) args[0], (Map) args[1]) + ); public static final ParseField INDEX = new ParseField("index"); @@ -46,8 +48,7 @@ public class FeatureValidation implements Writeable, ToXContentObject { static { PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); - PARSER.declareField(ConstructingObjectParser.constructorArg(), XContentParser::map, - PARAMS, ObjectParser.ValueType.OBJECT); + PARSER.declareField(ConstructingObjectParser.constructorArg(), XContentParser::map, PARAMS, ObjectParser.ValueType.OBJECT); } private final String index; @@ -79,11 +80,12 @@ public Map getParams() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; FeatureValidation that = (FeatureValidation) o; - return Objects.equals(index, that.index) && - Objects.equals(params, that.params); + return Objects.equals(index, that.index) && Objects.equals(params, that.params); } @Override @@ -93,9 +95,6 @@ public int hashCode() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() - .field(INDEX.getPreferredName(), index) - .field(PARAMS.getPreferredName(), this.params) - .endObject(); + return builder.startObject().field(INDEX.getPreferredName(), index).field(PARAMS.getPreferredName(), this.params).endObject(); } } diff --git a/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeature.java b/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeature.java index 637ed17f..7fb873d3 100644 --- a/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeature.java @@ -16,20 +16,21 @@ package com.o19s.es.ltr.feature; -import com.o19s.es.ltr.LtrQueryContext; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; import org.opensearch.common.Nullable; import org.opensearch.ltr.settings.LTRSettings; -import java.io.IOException; -import java.util.Map; -import java.util.Objects; +import com.o19s.es.ltr.LtrQueryContext; /** * A prebuilt featured query, needed by query builders @@ -44,7 +45,8 @@ public PrebuiltFeature(@Nullable String name, Query query) { this.query = Objects.requireNonNull(query); } - @Override @Nullable + @Override + @Nullable public String name() { return name; } @@ -69,8 +71,7 @@ public boolean equals(Object o) { return false; } PrebuiltFeature other = (PrebuiltFeature) o; - return Objects.equals(name, other.name) - && Objects.equals(query, other.query); + return Objects.equals(name, other.name) && Objects.equals(query, other.query); } @Override diff --git a/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeatureSet.java b/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeatureSet.java index 27e2e730..eac4367e 100644 --- a/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeatureSet.java +++ b/src/main/java/com/o19s/es/ltr/feature/PrebuiltFeatureSet.java @@ -16,16 +16,17 @@ package com.o19s.es.ltr.feature; -import com.o19s.es.ltr.LtrQueryContext; -import org.apache.lucene.search.Query; -import org.opensearch.common.Nullable; - import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.IntStream; +import org.apache.lucene.search.Query; +import org.opensearch.common.Nullable; + +import com.o19s.es.ltr.LtrQueryContext; + public class PrebuiltFeatureSet implements FeatureSet { private final List features; private final String name; @@ -81,9 +82,10 @@ private int findFeatureIndexByName(String featureName) { // slow, not meant for runtime usage, mostly needed for tests // would make sense to implement a Map to do this once // feature names are mandatory and unique. - return IntStream.range(0, features.size()) - .filter(i -> Objects.equals(((PrebuiltFeature)features.get(i)).name(), featureName)) - .findFirst() - .orElse(-1); + return IntStream + .range(0, features.size()) + .filter(i -> Objects.equals(((PrebuiltFeature) features.get(i)).name(), featureName)) + .findFirst() + .orElse(-1); } } diff --git a/src/main/java/com/o19s/es/ltr/feature/PrebuiltLtrModel.java b/src/main/java/com/o19s/es/ltr/feature/PrebuiltLtrModel.java index 749b5d50..4234f198 100644 --- a/src/main/java/com/o19s/es/ltr/feature/PrebuiltLtrModel.java +++ b/src/main/java/com/o19s/es/ltr/feature/PrebuiltLtrModel.java @@ -16,10 +16,10 @@ package com.o19s.es.ltr.feature; -import com.o19s.es.ltr.ranker.LtrRanker; - import java.util.Objects; +import com.o19s.es.ltr.ranker.LtrRanker; + /** * Prebuilt model */ diff --git a/src/main/java/com/o19s/es/ltr/feature/store/CompiledLtrModel.java b/src/main/java/com/o19s/es/ltr/feature/store/CompiledLtrModel.java index 13c279ee..22780400 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/CompiledLtrModel.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/CompiledLtrModel.java @@ -16,14 +16,15 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.LtrModel; -import com.o19s.es.ltr.ranker.LtrRanker; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.LtrModel; +import com.o19s.es.ltr.ranker.LtrRanker; public class CompiledLtrModel implements LtrModel, Accountable { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(StoredLtrModel.class); @@ -67,9 +68,10 @@ public FeatureSet featureSet() { */ @Override public long ramBytesUsed() { - return BASE_RAM_USED + name.length() * Character.BYTES + NUM_BYTES_ARRAY_HEADER - + (set instanceof Accountable ? ((Accountable)set).ramBytesUsed() : set.size() * NUM_BYTES_OBJECT_HEADER) - + (ranker instanceof Accountable ? - ((Accountable)ranker).ramBytesUsed() : set.size() * NUM_BYTES_OBJECT_HEADER); + return BASE_RAM_USED + name.length() * Character.BYTES + NUM_BYTES_ARRAY_HEADER + (set instanceof Accountable + ? ((Accountable) set).ramBytesUsed() + : set.size() * NUM_BYTES_OBJECT_HEADER) + (ranker instanceof Accountable + ? ((Accountable) ranker).ramBytesUsed() + : set.size() * NUM_BYTES_OBJECT_HEADER); } } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java b/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java index 6279f217..45224123 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplier.java @@ -19,11 +19,10 @@ import java.util.Map; import java.util.function.Supplier; +public class ExtraLoggingSupplier implements Supplier> { + protected Supplier> supplier; -public class ExtraLoggingSupplier implements Supplier> { - protected Supplier> supplier; - - public void setSupplier(Supplier> supplier) { + public void setSupplier(Supplier> supplier) { this.supplier = supplier; } @@ -39,4 +38,4 @@ public Map get() { } return null; } -} \ No newline at end of file +} diff --git a/src/main/java/com/o19s/es/ltr/feature/store/FeatureNormDefinition.java b/src/main/java/com/o19s/es/ltr/feature/store/FeatureNormDefinition.java index 9b3b25fc..e51f5720 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/FeatureNormDefinition.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/FeatureNormDefinition.java @@ -16,10 +16,11 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; + /** * Parsed feature norm from model definition */ diff --git a/src/main/java/com/o19s/es/ltr/feature/store/FeatureStore.java b/src/main/java/com/o19s/es/ltr/feature/store/FeatureStore.java index d4eed864..68ec9b79 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/FeatureStore.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/FeatureStore.java @@ -16,11 +16,11 @@ package com.o19s.es.ltr.feature.store; +import java.io.IOException; + import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; -import java.io.IOException; - /** * A feature store */ diff --git a/src/main/java/com/o19s/es/ltr/feature/store/FeatureSupplier.java b/src/main/java/com/o19s/es/ltr/feature/store/FeatureSupplier.java index 2f0e828f..452625f2 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/FeatureSupplier.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/FeatureSupplier.java @@ -16,15 +16,14 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.LtrRanker; - import java.util.AbstractMap; import java.util.AbstractSet; import java.util.Iterator; import java.util.Set; import java.util.function.Supplier; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.LtrRanker; public class FeatureSupplier extends AbstractMap implements Supplier { private Supplier vectorSupplier; @@ -134,4 +133,3 @@ private LtrRanker.FeatureVector getFeatureVector() { } } - diff --git a/src/main/java/com/o19s/es/ltr/feature/store/MinMaxFeatureNormDefinition.java b/src/main/java/com/o19s/es/ltr/feature/store/MinMaxFeatureNormDefinition.java index d1eb39c3..0a6e2955 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/MinMaxFeatureNormDefinition.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/MinMaxFeatureNormDefinition.java @@ -16,15 +16,16 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.ranker.normalizer.MinMaxFeatureNormalizer; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import java.io.IOException; + import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; +import com.o19s.es.ltr.ranker.normalizer.MinMaxFeatureNormalizer; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; /** * Parsing and serialization for a min/max normalizer @@ -112,16 +113,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof MinMaxFeatureNormDefinition)) return false; + if (this == o) + return true; + if (!(o instanceof MinMaxFeatureNormDefinition)) + return false; MinMaxFeatureNormDefinition that = (MinMaxFeatureNormDefinition) o; - if (!this.featureName.equals(that.featureName)) return false; - if (this.minimum != that.minimum) return false; - if (this.maximum != that.maximum) return false; + if (!this.featureName.equals(that.featureName)) + return false; + if (this.minimum != that.minimum) + return false; + if (this.maximum != that.maximum) + return false; return true; } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/OptimizedFeatureSet.java b/src/main/java/com/o19s/es/ltr/feature/store/OptimizedFeatureSet.java index 109f4af9..7bf29847 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/OptimizedFeatureSet.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/OptimizedFeatureSet.java @@ -16,21 +16,22 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; import java.util.ArrayList; import java.util.List; import java.util.Map; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; + +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; public class OptimizedFeatureSet implements FeatureSet, Accountable { private final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(StoredFeatureSet.class); @@ -53,8 +54,8 @@ public String name() { @Override public List toQueries(LtrQueryContext context, Map params) { List queries = new ArrayList<>(features.size()); - for(Feature feature : features) { - if(context.isFeatureActive(feature.name())) { + for (Feature feature : features) { + if (context.isFeatureActive(feature.name())) { queries.add(feature.doToQuery(context, this, params)); } else { queries.add(new MatchNoDocsQuery("Feature " + feature.name() + " deactivated")); @@ -101,12 +102,15 @@ public void validate() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; OptimizedFeatureSet that = (OptimizedFeatureSet) o; - if (!name.equals(that.name)) return false; + if (!name.equals(that.name)) + return false; return features.equals(that.features); } @@ -122,8 +126,9 @@ public int hashCode() { */ @Override public long ramBytesUsed() { - return BASE_RAM_USED + - featureMap.size() * NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER + - features.stream().mapToLong((f) -> f instanceof Accountable ? ((Accountable)f).ramBytesUsed() : 1).sum(); + return BASE_RAM_USED + featureMap.size() * NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER + features + .stream() + .mapToLong((f) -> f instanceof Accountable ? ((Accountable) f).ramBytesUsed() : 1) + .sum(); } } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledExpressionFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledExpressionFeature.java index 6aeaa902..b8846f92 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledExpressionFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledExpressionFeature.java @@ -16,15 +16,7 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.query.DerivedExpressionQuery; -import com.o19s.es.ltr.utils.Scripting; -import org.apache.lucene.expressions.Expression; -import org.apache.lucene.search.Query; -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; import java.util.Arrays; import java.util.Collection; @@ -36,7 +28,16 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import org.apache.lucene.expressions.Expression; +import org.apache.lucene.search.Query; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; + +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.query.DerivedExpressionQuery; +import com.o19s.es.ltr.utils.Scripting; public class PrecompiledExpressionFeature implements Feature, Accountable { @@ -64,9 +65,8 @@ public static PrecompiledExpressionFeature compile(StoredFeature feature) { @Override public long ramBytesUsed() { - return BASE_RAM_USED + - (Character.BYTES * name.length()) + NUM_BYTES_ARRAY_HEADER + - (((Character.BYTES * expression.sourceText.length()) + NUM_BYTES_ARRAY_HEADER) * 2); + return BASE_RAM_USED + (Character.BYTES * name.length()) + NUM_BYTES_ARRAY_HEADER + (((Character.BYTES * expression.sourceText + .length()) + NUM_BYTES_ARRAY_HEADER) * 2); } @Override @@ -76,9 +76,10 @@ public String name() { @Override public Query doToQuery(LtrQueryContext context, FeatureSet set, Map params) { - List missingParams = queryParams.stream() - .filter((x) -> params == null || !params.containsKey(x)) - .collect(Collectors.toList()); + List missingParams = queryParams + .stream() + .filter((x) -> params == null || !params.containsKey(x)) + .collect(Collectors.toList()); if (!missingParams.isEmpty()) { String names = missingParams.stream().collect(Collectors.joining(",")); throw new IllegalArgumentException("Missing required param(s): [" + names + "]"); @@ -88,7 +89,6 @@ public Query doToQuery(LtrQueryContext context, FeatureSet set, Map getQueryParamValues() { return getQueryParamValues(); } @@ -126,9 +126,9 @@ public boolean equals(Object o) { PrecompiledExpressionFeature that = (PrecompiledExpressionFeature) o; return Objects.equals(name, that.name) - && Objects.equals(expression, that.expression) - && Objects.equals(queryParams, that.queryParams) - && Objects.equals(expressionVariables, that.expressionVariables); + && Objects.equals(expression, that.expression) + && Objects.equals(queryParams, that.queryParams) + && Objects.equals(expressionVariables, that.expressionVariables); } @Override @@ -140,12 +140,14 @@ public int hashCode() { public void validate(FeatureSet set) { for (String var : expression.variables) { if (!set.hasFeature(var) && !queryParams.contains(var)) { - throw new IllegalArgumentException("Derived feature [" + this.name + "] refers " + - "to unknown feature or parameter: [" + var + "]"); + throw new IllegalArgumentException( + "Derived feature [" + this.name + "] refers " + "to unknown feature or parameter: [" + var + "]" + ); } - if(set.hasFeature(var) && queryParams.contains(var)){ - throw new IllegalArgumentException("Duplicate name " + var + " . " + - "Cannot be used as both feature name and query parameter name"); + if (set.hasFeature(var) && queryParams.contains(var)) { + throw new IllegalArgumentException( + "Duplicate name " + var + " . " + "Cannot be used as both feature name and query parameter name" + ); } } } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledTemplateFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledTemplateFeature.java index 5bf51755..690e6e4f 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledTemplateFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/PrecompiledTemplateFeature.java @@ -16,33 +16,33 @@ package com.o19s.es.ltr.feature.store; -import com.github.mustachejava.Mustache; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.template.mustache.MustacheUtils; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + import org.apache.lucene.search.Query; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import org.opensearch.core.common.ParsingException; import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.ParsingException; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardException; import org.opensearch.index.query.Rewriteable; -import java.io.IOException; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; -import static org.opensearch.index.query.AbstractQueryBuilder.parseInnerQueryBuilder; +import com.github.mustachejava.Mustache; +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.template.mustache.MustacheUtils; public class PrecompiledTemplateFeature implements Feature, Accountable { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(StoredFeature.class); @@ -68,12 +68,10 @@ public static PrecompiledTemplateFeature compile(StoredFeature feature) { @Override public long ramBytesUsed() { - return BASE_RAM_USED + - (Character.BYTES * name.length()) + NUM_BYTES_ARRAY_HEADER + - queryParams.stream() - .mapToLong(x -> (Character.BYTES * x.length()) + - NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER).sum() + - (((Character.BYTES * templateString.length()) + NUM_BYTES_ARRAY_HEADER) * 2); + return BASE_RAM_USED + (Character.BYTES * name.length()) + NUM_BYTES_ARRAY_HEADER + queryParams + .stream() + .mapToLong(x -> (Character.BYTES * x.length()) + NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER) + .sum() + (((Character.BYTES * templateString.length()) + NUM_BYTES_ARRAY_HEADER) * 2); } @Override @@ -83,9 +81,10 @@ public String name() { @Override public Query doToQuery(LtrQueryContext context, FeatureSet set, Map params) { - List missingParams = queryParams.stream() - .filter((x) -> params == null || !params.containsKey(x)) - .collect(Collectors.toList()); + List missingParams = queryParams + .stream() + .filter((x) -> params == null || !params.containsKey(x)) + .collect(Collectors.toList()); if (!missingParams.isEmpty()) { String names = missingParams.stream().collect(Collectors.joining(",")); throw new IllegalArgumentException("Missing required param(s): [" + names + "]"); @@ -93,9 +92,10 @@ public Query doToQuery(LtrQueryContext context, FeatureSet set, Map queryParams) public static ScriptFeature compile(StoredFeature feature) { try { - XContentParser xContentParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, feature.template()); + XContentParser xContentParser = XContentType.JSON + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, feature.template()); return new ScriptFeature(feature.name(), Script.parse(xContentParser, "native"), feature.queryParams()); } catch (IOException e) { @@ -130,9 +130,7 @@ public String name() { @Override @SuppressWarnings("unchecked") public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map params) { - List missingParams = queryParams.stream() - .filter((x) -> !params.containsKey(x)) - .collect(Collectors.toList()); + List missingParams = queryParams.stream().filter((x) -> !params.containsKey(x)).collect(Collectors.toList()); if (!missingParams.isEmpty()) { String names = String.join(",", missingParams); throw new IllegalArgumentException("Missing required param(s): [" + names + "]"); @@ -172,7 +170,7 @@ public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map terms; - LtrScript(ScriptScoreFunction function, - FeatureSupplier supplier, - ExtraLoggingSupplier extraLoggingSupplier, - Set terms) { + LtrScript(ScriptScoreFunction function, FeatureSupplier supplier, ExtraLoggingSupplier extraLoggingSupplier, Set terms) { this.function = function; this.supplier = supplier; this.extraLoggingSupplier = extraLoggingSupplier; @@ -270,10 +273,10 @@ static class LtrScript extends Query implements LtrRewritableQuery { @Override public boolean equals(Object o) { - if (this == o) return true; + if (this == o) + return true; LtrScript ol = (LtrScript) o; - return sameClassAs(o) - && Objects.equals(function, ol.function); + return sameClassAs(o) && Objects.equals(function, ol.function); } @Override @@ -311,8 +314,8 @@ public Query ltrRewrite(LtrRewriteContext context) throws IOException { return this; } - @Override - public void visit(QueryVisitor visitor) { + @Override + public void visit(QueryVisitor visitor) { Set fields = terms.stream().map(Term::field).collect(Collectors.toUnmodifiableSet()); for (String field : fields) { if (visitor.acceptField(field) == false) { @@ -320,7 +323,7 @@ public void visit(QueryVisitor visitor) { } } visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this).consumeTerms(this, terms.toArray(new Term[0])); - } + } } static class LtrScriptWeight extends Weight { @@ -330,10 +333,8 @@ static class LtrScriptWeight extends Weight { private final Set terms; private final HashMap termContexts; - LtrScriptWeight(Query query, ScriptScoreFunction function, - Set terms, - IndexSearcher searcher, - ScoreMode scoreMode) throws IOException { + LtrScriptWeight(Query query, ScriptScoreFunction function, Set terms, IndexSearcher searcher, ScoreMode scoreMode) + throws IOException { super(query); this.function = function; this.terms = terms; @@ -394,14 +395,13 @@ public DocIdSetIterator iterator() { */ @Override public float getMaxScore(int upTo) throws IOException { - //TODO?? + // TODO?? return Float.POSITIVE_INFINITY; } }; } - public void extractTerms(Set terms) { - } + public void extractTerms(Set terms) {} @Override public boolean isCacheable(LeafReaderContext ctx) { diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StandardFeatureNormDefinition.java b/src/main/java/com/o19s/es/ltr/feature/store/StandardFeatureNormDefinition.java index 933282a9..876b4702 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StandardFeatureNormDefinition.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StandardFeatureNormDefinition.java @@ -16,8 +16,8 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; +import java.io.IOException; + import org.opensearch.OpenSearchException; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; @@ -26,7 +26,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; public class StandardFeatureNormDefinition implements FeatureNormDefinition { @@ -39,7 +40,6 @@ public class StandardFeatureNormDefinition implements FeatureNormDefinition { private static final ParseField STD_DEVIATION = new ParseField("standard_deviation"); private static final ParseField MEAN = new ParseField("mean"); - static { PARSER = ObjectParser.fromBuilder("standard", StandardFeatureNormDefinition::new); PARSER.declareFloat(StandardFeatureNormDefinition::setMean, MEAN); @@ -64,8 +64,7 @@ public void setMean(float mean) { public void setStdDeviation(float stdDeviation) { if (stdDeviation <= 0.0f) { - throw new OpenSearchException("Standard Deviation Must Be Positive. " + - " You passed: " + Float.toString(stdDeviation)); + throw new OpenSearchException("Standard Deviation Must Be Positive. " + " You passed: " + Float.toString(stdDeviation)); } this.stdDeviation = stdDeviation; } @@ -110,13 +109,18 @@ public StoredFeatureNormalizers.Type normType() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof StandardFeatureNormDefinition)) return false; + if (this == o) + return true; + if (!(o instanceof StandardFeatureNormDefinition)) + return false; StandardFeatureNormDefinition that = (StandardFeatureNormDefinition) o; - if (!this.featureName.equals(that.featureName)) return false; - if (this.stdDeviation != that.stdDeviation) return false; - if (this.mean != that.mean) return false; + if (!this.featureName.equals(that.featureName)) + return false; + if (this.stdDeviation != that.stdDeviation) + return false; + if (this.mean != that.mean) + return false; return true; } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StorableElement.java b/src/main/java/com/o19s/es/ltr/feature/store/StorableElement.java index b3ecce1b..0439c8c2 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StorableElement.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StorableElement.java @@ -66,10 +66,13 @@ public String getName() { void resolveName(XContentParser parser, String name) { if (this.name == null && name != null) { this.name = name; - } else if ( this.name == null /* && name == null */) { + } else if (this.name == null /* && name == null */) { throw new ParsingException(parser.getTokenLocation(), "Field [name] is mandatory"); } else if ( /* this.name != null && */ name != null && !this.name.equals(name)) { - throw new ParsingException(parser.getTokenLocation(), "Invalid [name], expected ["+name+"] but got [" + this.name+"]"); + throw new ParsingException( + parser.getTokenLocation(), + "Invalid [name], expected [" + name + "] but got [" + this.name + "]" + ); } } } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeature.java index 1ca9d6ce..e6bec9d4 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeature.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeature.java @@ -16,42 +16,43 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.template.mustache.MustacheUtils; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + import org.apache.lucene.search.Query; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.QueryShardException; import org.opensearch.index.query.ScriptQueryBuilder; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.template.mustache.MustacheUtils; public class StoredFeature implements Feature, Accountable, StorableElement { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(StoredFeature.class); @@ -133,8 +134,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } else { builder.field(TEMPLATE.getPreferredName()); // it's ok to use NamedXContentRegistry.EMPTY because we don't really parse we copy the structure... - XContentParser parser = MediaTypeRegistry.xContent(template).xContent().createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, template); + XContentParser parser = MediaTypeRegistry + .xContent(template) + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, template); builder.copyCurrentStructure(parser); } builder.endObject(); @@ -156,12 +159,20 @@ public static StoredFeature parse(XContentParser parser, String name) { throw new ParsingException(parser.getTokenLocation(), "Field [template] is mandatory"); } if (state.template instanceof String) { - return new StoredFeature(state.getName(), Collections.unmodifiableList(state.queryParams), - state.templateLanguage, (String) state.template); + return new StoredFeature( + state.getName(), + Collections.unmodifiableList(state.queryParams), + state.templateLanguage, + (String) state.template + ); } else { assert state.template instanceof XContentBuilder; - return new StoredFeature(state.getName(), Collections.unmodifiableList(state.queryParams), - state.templateLanguage, (XContentBuilder) state.template); + return new StoredFeature( + state.getName(), + Collections.unmodifiableList(state.queryParams), + state.templateLanguage, + (XContentBuilder) state.template + ); } } catch (IllegalArgumentException iae) { throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); @@ -193,9 +204,7 @@ public String type() { @Override public Query doToQuery(LtrQueryContext context, FeatureSet set, Map params) { - List missingParams = queryParams.stream() - .filter((x) -> !params.containsKey(x)) - .collect(Collectors.toList()); + List missingParams = queryParams.stream().filter((x) -> !params.containsKey(x)).collect(Collectors.toList()); if (!missingParams.isEmpty()) { String names = missingParams.stream().collect(Collectors.joining(",")); @@ -218,19 +227,25 @@ public Query doToQuery(LtrQueryContext context, FeatureSet set, Map (Character.BYTES * x.length()) + - NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER).sum() + - (Character.BYTES * templateLanguage.length()) + NUM_BYTES_ARRAY_HEADER + - (Character.BYTES * template.length()) + NUM_BYTES_ARRAY_HEADER; + return BASE_RAM_USED + (Character.BYTES * name.length()) + NUM_BYTES_ARRAY_HEADER + queryParams + .stream() + .mapToLong(x -> (Character.BYTES * x.length()) + NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER) + .sum() + (Character.BYTES * templateLanguage.length()) + NUM_BYTES_ARRAY_HEADER + (Character.BYTES * template.length()) + + NUM_BYTES_ARRAY_HEADER; } @Override diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureNormalizers.java b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureNormalizers.java index 972a5860..ce6af9f8 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureNormalizers.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureNormalizers.java @@ -16,8 +16,11 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.opensearch.OpenSearchException; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; @@ -27,14 +30,11 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; public class StoredFeatureNormalizers { - public enum Type { STANDARD, MIN_MAX; @@ -49,13 +49,12 @@ public enum Type { PARSER = (XContentParser p, Void c, String featureName) -> { // this seems really intended for switching on the key (here featureName) and making // a decision, when in reality, we want to look a layer deeper and switch on that - ObjectParser parser = new ObjectParser<>("feature_normalizers", - FeatureNormConsumer::new); + ObjectParser parser = new ObjectParser<>("feature_normalizers", FeatureNormConsumer::new); parser.declareObject(FeatureNormConsumer::setFtrNormDefn, StandardFeatureNormDefinition.PARSER, STANDARD); parser.declareObject(FeatureNormConsumer::setFtrNormDefn, MinMaxFeatureNormDefinition.PARSER, MIN_MAX); - FeatureNormConsumer parsedNorm = parser.parse(p, featureName); + FeatureNormConsumer parsedNorm = parser.parse(p, featureName); return parsedNorm.ftrNormDefn; }; @@ -66,8 +65,7 @@ public enum Type { private static class FeatureNormConsumer { FeatureNormDefinition ftrNormDefn; - FeatureNormConsumer() { - } + FeatureNormConsumer() {} public FeatureNormDefinition getFtrNormDefn() { return this.ftrNormDefn; @@ -81,7 +79,6 @@ public void setFtrNormDefn(FeatureNormDefinition ftrNormDefn) { } } - private final Map featureNormalizers; public StoredFeatureNormalizers() { @@ -90,7 +87,7 @@ public StoredFeatureNormalizers() { public StoredFeatureNormalizers(final List ftrNormDefs) { this.featureNormalizers = new HashMap<>(); - for (FeatureNormDefinition ftrNorm: ftrNormDefs) { + for (FeatureNormDefinition ftrNorm : ftrNormDefs) { this.featureNormalizers.put(ftrNorm.featureName(), ftrNorm); } } @@ -110,7 +107,7 @@ public Normalizer getNormalizer(String featureName) { public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { builder.startObject(); // begin feature norms - for (Map.Entry ftrNormDefEntry: featureNormalizers.entrySet()) { + for (Map.Entry ftrNormDefEntry : featureNormalizers.entrySet()) { builder.field(ftrNormDefEntry.getKey()); ftrNormDefEntry.getValue().toXContent(builder, params); } @@ -121,13 +118,12 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par public Map compileOrdToNorms(FeatureSet featureSet) { Map ftrNorms = new HashMap<>(); - for (Map.Entry ftrNormDefEntry: featureNormalizers.entrySet()) { + for (Map.Entry ftrNormDefEntry : featureNormalizers.entrySet()) { String featureName = ftrNormDefEntry.getValue().featureName(); Normalizer ftrNorm = ftrNormDefEntry.getValue().createFeatureNorm(); if (!featureSet.hasFeature(featureName)) { - throw new OpenSearchException("Feature " + featureName + - " not found in feature set " + featureSet.name()); + throw new OpenSearchException("Feature " + featureName + " not found in feature set " + featureSet.name()); } int ord = featureSet.featureOrdinal(featureName); @@ -138,8 +134,10 @@ public Map compileOrdToNorms(FeatureSet featureSet) { @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof StoredFeatureNormalizers)) return false; + if (this == o) + return true; + if (!(o instanceof StoredFeatureNormalizers)) + return false; StoredFeatureNormalizers that = (StoredFeatureNormalizers) o; return that.featureNormalizers.equals(this.featureNormalizers); @@ -154,8 +152,7 @@ public int numNormalizers() { return this.featureNormalizers.size(); } - - private FeatureNormDefinition createFromStreamInput(StreamInput input) throws IOException { + private FeatureNormDefinition createFromStreamInput(StreamInput input) throws IOException { Type normType = input.readEnum(Type.class); if (normType == Type.STANDARD) { return new StandardFeatureNormDefinition(input); diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureSet.java b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureSet.java index e055c29a..07a88319 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureSet.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StoredFeatureSet.java @@ -16,9 +16,19 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.RandomAccess; + import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.Accountable; @@ -31,18 +41,9 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.RandomAccess; - -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; public class StoredFeatureSet implements FeatureSet, Accountable, StorableElement { public static final int MAX_FEATURES = 10000; @@ -60,9 +61,7 @@ public class StoredFeatureSet implements FeatureSet, Accountable, StorableElemen static { PARSER = new ObjectParser<>(TYPE, ParsingState::new); PARSER.declareString(ParsingState::setName, NAME); - PARSER.declareObjectArray(ParsingState::setFeatures, - (p, c) -> StoredFeature.parse(p), - FEATURES); + PARSER.declareObjectArray(ParsingState::setFeatures, (p, c) -> StoredFeature.parse(p), FEATURES); } public static StoredFeatureSet parse(XContentParser parser) { @@ -94,8 +93,9 @@ public StoredFeatureSet(String name, List features) { for (StoredFeature feature : features) { ordinal++; if (featureMap.put(feature.name(), ordinal) != null) { - throw new IllegalArgumentException("Feature [" + feature.name() + "] defined twice in this set: " + - "feature names must be unique in a set."); + throw new IllegalArgumentException( + "Feature [" + feature.name() + "] defined twice in this set: " + "feature names must be unique in a set." + ); } } } @@ -185,10 +185,7 @@ public StoredFeatureSet append(List features) { * @throws IllegalArgumentException if the resulting size of the set exceed MAX_FEATURES */ public StoredFeatureSet merge(List mergedFeatures) { - int merged = (int) mergedFeatures.stream() - .map(StoredFeature::name) - .filter(this::hasFeature) - .count(); + int merged = (int) mergedFeatures.stream().map(StoredFeature::name).filter(this::hasFeature).count(); if (size() + (mergedFeatures.size() - merged) > MAX_FEATURES) { throw new IllegalArgumentException("The resulting feature set would be too large"); @@ -250,9 +247,10 @@ public int size() { @Override public long ramBytesUsed() { - return BASE_RAM_USED + - featureMap.size() * NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER + - features.stream().mapToLong(StoredFeature::ramBytesUsed).sum(); + return BASE_RAM_USED + featureMap.size() * NUM_BYTES_OBJECT_REF + NUM_BYTES_OBJECT_HEADER + NUM_BYTES_ARRAY_HEADER + features + .stream() + .mapToLong(StoredFeature::ramBytesUsed) + .sum(); } @Override diff --git a/src/main/java/com/o19s/es/ltr/feature/store/StoredLtrModel.java b/src/main/java/com/o19s/es/ltr/feature/store/StoredLtrModel.java index 77ce15ff..43466c8b 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/StoredLtrModel.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/StoredLtrModel.java @@ -16,29 +16,30 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.parser.LtrRankerParser; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; +import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.json.JsonXContent; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Objects; - -import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.parser.LtrRankerParser; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; public class StoredLtrModel implements StorableElement { public static final String TYPE = "model"; @@ -58,19 +59,22 @@ public class StoredLtrModel implements StorableElement { static { PARSER = new ObjectParser<>(TYPE, ParsingState::new); PARSER.declareString(ParsingState::setName, NAME); - PARSER.declareObject(ParsingState::setFeatureSet, - (parser, ctx) -> StoredFeatureSet.parse(parser), - FEATURE_SET); - PARSER.declareObject(ParsingState::setRankingModel, LtrModelDefinition.PARSER, - MODEL); + PARSER.declareObject(ParsingState::setFeatureSet, (parser, ctx) -> StoredFeatureSet.parse(parser), FEATURE_SET); + PARSER.declareObject(ParsingState::setRankingModel, LtrModelDefinition.PARSER, MODEL); } public StoredLtrModel(String name, StoredFeatureSet featureSet, LtrModelDefinition definition) { this(name, featureSet, definition.type, definition.definition, definition.modelAsString, definition.featureNormalizers); } - public StoredLtrModel(String name, StoredFeatureSet featureSet, String rankingModelType, String rankingModel, - boolean modelAsString, StoredFeatureNormalizers featureNormalizerSet) { + public StoredLtrModel( + String name, + StoredFeatureSet featureSet, + String rankingModelType, + String rankingModel, + boolean modelAsString, + StoredFeatureNormalizers featureNormalizerSet + ) { this.name = Objects.requireNonNull(name); this.featureSet = Objects.requireNonNull(featureSet); this.rankingModelType = Objects.requireNonNull(rankingModelType); @@ -168,7 +172,9 @@ public String rankingModel() { /** * @return the stored set of feature normalizers */ - public StoredFeatureNormalizers getFeatureNormalizers() { return this.parsedFtrNorms; } + public StoredFeatureNormalizers getFeatureNormalizers() { + return this.parsedFtrNorms; + } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { @@ -182,9 +188,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (modelAsString) { builder.value(rankingModel); } else { - try (XContentParser parser = JsonXContent.jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, rankingModel) - ) { + try (XContentParser parser = JsonXContent.jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, rankingModel)) { builder.copyCurrentStructure(parser); } } @@ -197,15 +201,21 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof StoredLtrModel)) return false; + if (this == o) + return true; + if (!(o instanceof StoredLtrModel)) + return false; StoredLtrModel that = (StoredLtrModel) o; - if (!name.equals(that.name)) return false; - if (!featureSet.equals(that.featureSet)) return false; - if (!rankingModelType.equals(that.rankingModelType)) return false; - if (!parsedFtrNorms.equals(that.parsedFtrNorms)) return false; + if (!name.equals(that.name)) + return false; + if (!featureSet.equals(that.featureSet)) + return false; + if (!rankingModelType.equals(that.rankingModelType)) + return false; + if (!parsedFtrNorms.equals(that.parsedFtrNorms)) + return false; return rankingModel.equals(that.rankingModel); } @@ -246,17 +256,12 @@ public static class LtrModelDefinition implements Writeable { static { PARSER = new ObjectParser<>("model", LtrModelDefinition::new); - PARSER.declareString(LtrModelDefinition::setType, - MODEL_TYPE); - PARSER.declareField((p, d, c) -> d.parseModel(p), - MODEL_DEFINITION, - ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING); - - PARSER.declareNamedObjects(LtrModelDefinition::setNamedFeatureNormalizers, - StoredFeatureNormalizers.PARSER, - FEATURE_NORMALIZERS); - } + PARSER.declareString(LtrModelDefinition::setType, MODEL_TYPE); + PARSER.declareField((p, d, c) -> d.parseModel(p), MODEL_DEFINITION, ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING); + PARSER + .declareNamedObjects(LtrModelDefinition::setNamedFeatureNormalizers, StoredFeatureNormalizers.PARSER, FEATURE_NORMALIZERS); + } private LtrModelDefinition() { this.featureNormalizers = new StoredFeatureNormalizers(); @@ -284,7 +289,6 @@ public void writeTo(StreamOutput out) throws IOException { this.featureNormalizers.writeTo(out); } - private void setType(String type) { this.type = type; } @@ -305,7 +309,9 @@ public boolean isModelAsString() { return modelAsString; } - public StoredFeatureNormalizers getFtrNorms() {return this.featureNormalizers;} + public StoredFeatureNormalizers getFtrNorms() { + return this.featureNormalizers; + } public static LtrModelDefinition parse(XContentParser parser, Void ctx) throws IOException { LtrModelDefinition def = PARSER.parse(parser, ctx); @@ -319,16 +325,16 @@ public static LtrModelDefinition parse(XContentParser parser, Void ctx) throws I } private void parseModel(XContentParser parser) throws IOException { - if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { - modelAsString = true; - definition = parser.text(); - } else { - try (XContentBuilder builder = JsonXContent.contentBuilder()) { - builder.copyCurrentStructure(parser); - modelAsString = false; - definition = builder.toString(); - } + if (parser.currentToken() == XContentParser.Token.VALUE_STRING) { + modelAsString = true; + definition = parser.text(); + } else { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + builder.copyCurrentStructure(parser); + modelAsString = false; + definition = builder.toString(); } + } } } } diff --git a/src/main/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStore.java b/src/main/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStore.java index 7a45328c..c88c1cad 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStore.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStore.java @@ -16,13 +16,14 @@ package com.o19s.es.ltr.feature.store.index; +import java.io.IOException; + +import org.opensearch.common.cache.Cache; + import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.store.CompiledLtrModel; import com.o19s.es.ltr.feature.store.FeatureStore; -import org.opensearch.common.cache.Cache; - -import java.io.IOException; /** * Cache layer on top of an {@link IndexFeatureStore} diff --git a/src/main/java/com/o19s/es/ltr/feature/store/index/Caches.java b/src/main/java/com/o19s/es/ltr/feature/store/index/Caches.java index 90a786a4..68175df9 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/index/Caches.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/index/Caches.java @@ -16,20 +16,6 @@ package com.o19s.es.ltr.feature.store.index; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import org.apache.lucene.util.Accountable; -import org.apache.lucene.util.RamUsageEstimator; -import org.opensearch.common.CheckedFunction; -import org.opensearch.common.cache.Cache; -import org.opensearch.common.cache.CacheBuilder; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.common.unit.ByteSizeValue; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.monitor.jvm.JvmInfo; - import java.io.IOException; import java.util.Iterator; import java.util.Map; @@ -41,58 +27,76 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Stream; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.common.CheckedFunction; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeValue; +import org.opensearch.monitor.jvm.JvmInfo; + +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; + /** * Store various caches used by the plugin */ public class Caches { public static final Setting LTR_CACHE_MEM_SETTING; - public static final Setting LTR_CACHE_EXPIRE_AFTER_WRITE = Setting.timeSetting("ltr.caches.expire_after_write", - TimeValue.timeValueHours(1), - TimeValue.timeValueNanos(0), - Setting.Property.NodeScope); - public static final Setting LTR_CACHE_EXPIRE_AFTER_READ = Setting.timeSetting("ltr.caches.expire_after_read", - TimeValue.timeValueHours(1), - TimeValue.timeValueNanos(0), - Setting.Property.NodeScope); + public static final Setting LTR_CACHE_EXPIRE_AFTER_WRITE = Setting + .timeSetting("ltr.caches.expire_after_write", TimeValue.timeValueHours(1), TimeValue.timeValueNanos(0), Setting.Property.NodeScope); + public static final Setting LTR_CACHE_EXPIRE_AFTER_READ = Setting + .timeSetting("ltr.caches.expire_after_read", TimeValue.timeValueHours(1), TimeValue.timeValueNanos(0), Setting.Property.NodeScope); private final Cache featureCache; private final Cache featureSetCache; private final Cache modelCache; static { - LTR_CACHE_MEM_SETTING = Setting.memorySizeSetting("ltr.caches.max_mem", - (s) -> new ByteSizeValue(Math.min(RamUsageEstimator.ONE_MB*10, - JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()/10)).toString(), - Setting.Property.NodeScope); + LTR_CACHE_MEM_SETTING = Setting + .memorySizeSetting( + "ltr.caches.max_mem", + (s) -> new ByteSizeValue(Math.min(RamUsageEstimator.ONE_MB * 10, JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() / 10)) + .toString(), + Setting.Property.NodeScope + ); } private final Map perStoreStats = new ConcurrentHashMap<>(); private final long maxWeight; public Caches(TimeValue expAfterWrite, TimeValue expAfterAccess, ByteSizeValue maxWeight) { this.featureCache = configCache(CacheBuilder.builder(), expAfterWrite, expAfterAccess, maxWeight) - .weigher(Caches::weigther) - .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) - .build(); + .weigher(Caches::weigther) + .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) + .build(); this.featureSetCache = configCache(CacheBuilder.builder(), expAfterWrite, expAfterAccess, maxWeight) - .weigher(Caches::weigther) - .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) - .build(); + .weigher(Caches::weigther) + .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) + .build(); this.modelCache = configCache(CacheBuilder.builder(), expAfterWrite, expAfterAccess, maxWeight) - .weigher((s, w) -> w.ramBytesUsed()) - .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) - .build(); + .weigher((s, w) -> w.ramBytesUsed()) + .removalListener((l) -> this.onRemove(l.getKey(), l.getValue())) + .build(); this.maxWeight = maxWeight.getBytes(); } public static long weigther(CacheKey key, Object data) { if (data instanceof Accountable) { - return ((Accountable)data).ramBytesUsed(); + return ((Accountable) data).ramBytesUsed(); } return 1; } - private CacheBuilder configCache(CacheBuilder builder, TimeValue expireAfterWrite, - TimeValue expireAfterAccess, ByteSizeValue maxWeight) { + private CacheBuilder configCache( + CacheBuilder builder, + TimeValue expireAfterWrite, + TimeValue expireAfterAccess, + ByteSizeValue maxWeight + ) { if (expireAfterWrite.nanos() > 0) { builder.setExpireAfterWrite(expireAfterWrite); } @@ -104,9 +108,7 @@ private CacheBuilder configCache(CacheBuilder builder, TimeVa } public Caches(Settings settings) { - this(LTR_CACHE_EXPIRE_AFTER_WRITE.get(settings), - LTR_CACHE_EXPIRE_AFTER_READ.get(settings), - LTR_CACHE_MEM_SETTING.get(settings)); + this(LTR_CACHE_EXPIRE_AFTER_WRITE.get(settings), LTR_CACHE_EXPIRE_AFTER_READ.get(settings), LTR_CACHE_MEM_SETTING.get(settings)); } private void onAdd(CacheKey k, Object acc) { @@ -133,8 +135,7 @@ CompiledLtrModel loadModel(CacheKey key, CheckedFunction E cacheLoad(CacheKey key, Cache cache, - CheckedFunction loader) throws IOException { + private E cacheLoad(CacheKey key, Cache cache, CheckedFunction loader) throws IOException { try { return cache.computeIfAbsent(key, (k) -> { E elt = loader.apply(k.getId()); @@ -168,8 +169,8 @@ public void evictModel(String index, String name) { private void evict(String index, Cache cache) { Iterator ite = cache.keys().iterator(); - while(ite.hasNext()) { - if(ite.next().storeName.equals(index)) { + while (ite.hasNext()) { + if (ite.next().storeName.equals(index)) { ite.remove(); } } @@ -226,12 +227,15 @@ public String getId() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; CacheKey cacheKey = (CacheKey) o; - if (!storeName.equals(cacheKey.storeName)) return false; + if (!storeName.equals(cacheKey.storeName)) + return false; return id.equals(cacheKey.id); } @@ -290,7 +294,7 @@ private int update(boolean add, Object elt) { } long ramUsed = 1; if (elt instanceof Accountable) { - ramUsed = ((Accountable)elt).ramBytesUsed(); + ramUsed = ((Accountable) elt).ramBytesUsed(); } ram.addAndGet(factor * ramUsed); diff --git a/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java b/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java index da66e1a8..f43c54d7 100644 --- a/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java +++ b/src/main/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStore.java @@ -16,15 +16,18 @@ package com.o19s.es.ltr.feature.store.index; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import com.o19s.es.ltr.feature.store.FeatureStore; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; +import static com.o19s.es.ltr.feature.store.StorableElement.generateId; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.BytesRef; @@ -36,34 +39,32 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.metadata.MetadataCreateIndexService; import org.opensearch.common.CheckedFunction; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.bytes.BytesReference; -import org.apache.logging.log4j.LogManager; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.XContentType; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.Objects; -import java.util.Optional; -import java.util.function.Supplier; -import java.util.regex.Pattern; - -import static com.o19s.es.ltr.feature.store.StorableElement.generateId; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; +import com.o19s.es.ltr.feature.store.FeatureStore; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; public class IndexFeatureStore implements FeatureStore { public static final int VERSION = 2; - public static final Setting STORE_VERSION_PROP = Setting.intSetting("index.ltrstore_version", - VERSION, -1, Integer.MAX_VALUE, Setting.Property.IndexScope); + public static final Setting STORE_VERSION_PROP = Setting + .intSetting("index.ltrstore_version", VERSION, -1, Integer.MAX_VALUE, Setting.Property.IndexScope); public static final String DEFAULT_STORE = ".ltrstore"; public static final String STORE_PREFIX = DEFAULT_STORE + "_"; private static final String MAPPING_FILE = "fstore-index-mapping.json"; @@ -83,15 +84,27 @@ public class IndexFeatureStore implements FeatureStore { private static final ObjectParser SOURCE_PARSER; static { SOURCE_PARSER = new ObjectParser<>("", true, ParserState::new); - SOURCE_PARSER.declareField(ParserState::setElement, + SOURCE_PARSER + .declareField( + ParserState::setElement, (CheckedFunction) StoredFeature::parse, - new ParseField(StoredFeature.TYPE), ObjectParser.ValueType.OBJECT); - SOURCE_PARSER.declareField(ParserState::setElement, + new ParseField(StoredFeature.TYPE), + ObjectParser.ValueType.OBJECT + ); + SOURCE_PARSER + .declareField( + ParserState::setElement, (CheckedFunction) StoredFeatureSet::parse, - new ParseField(StoredFeatureSet.TYPE), ObjectParser.ValueType.OBJECT); - SOURCE_PARSER.declareField(ParserState::setElement, + new ParseField(StoredFeatureSet.TYPE), + ObjectParser.ValueType.OBJECT + ); + SOURCE_PARSER + .declareField( + ParserState::setElement, (CheckedFunction) StoredLtrModel::parse, - new ParseField(StoredLtrModel.TYPE), ObjectParser.ValueType.OBJECT); + new ParseField(StoredLtrModel.TYPE), + ObjectParser.ValueType.OBJECT + ); } private final String index; @@ -112,19 +125,15 @@ public String getStoreName() { @Override public Feature load(final String name) throws IOException { return getAndParse(name, StoredFeature.class, StoredFeature.TYPE) - .orElseThrow( - () -> new ResourceNotFoundException("Unknown feature [" + name + "]") - ) - .optimize(); + .orElseThrow(() -> new ResourceNotFoundException("Unknown feature [" + name + "]")) + .optimize(); } @Override public FeatureSet loadSet(final String name) throws IOException { return getAndParse(name, StoredFeatureSet.class, StoredFeatureSet.TYPE) - .orElseThrow( - () -> new ResourceNotFoundException("Unknown featureset [" + name + "]") - ) - .optimize(); + .orElseThrow(() -> new ResourceNotFoundException("Unknown featureset [" + name + "]")) + .optimize(); } /** @@ -169,16 +178,14 @@ public static String storeName(String indexName) { * @return true if this index name is a possible index store, false otherwise. */ public static boolean isIndexStore(String indexName) { - return Objects.requireNonNull(indexName).equals(DEFAULT_STORE) || - (indexName.startsWith(STORE_PREFIX) && indexName.length() > STORE_PREFIX.length()); + return Objects.requireNonNull(indexName).equals(DEFAULT_STORE) + || (indexName.startsWith(STORE_PREFIX) && indexName.length() > STORE_PREFIX.length()); } @Override public CompiledLtrModel loadModel(String name) throws IOException { StoredLtrModel model = getAndParse(name, StoredLtrModel.class, StoredLtrModel.TYPE) - .orElseThrow( - () -> new ResourceNotFoundException("Unknown model [" + name + "]") - ); + .orElseThrow(() -> new ResourceNotFoundException("Unknown model [" + name + "]")); return model.compile(parserFactory); } @@ -230,19 +237,25 @@ public static E parse(Class eltClass, String type return parse(eltClass, type, bytes, 0, bytes.length); } - public static E parse(Class eltClass, String type, byte[] bytes, - int offset, int length) throws IOException { - try (XContentParser parser = MediaTypeRegistry.xContent(bytes).xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, bytes)) { + public static E parse(Class eltClass, String type, byte[] bytes, int offset, int length) + throws IOException { + try ( + XContentParser parser = MediaTypeRegistry + .xContent(bytes) + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, bytes) + ) { return parse(eltClass, type, parser); } } public static E parse(Class eltClass, String type, BytesReference bytesReference) throws IOException { BytesRef ref = bytesReference.toBytesRef(); - try (XContentParser parser = MediaTypeRegistry.xContent(ref.bytes, ref.offset, ref.length).xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, - ref.bytes, ref.offset, ref.length) + try ( + XContentParser parser = MediaTypeRegistry + .xContent(ref.bytes, ref.offset, ref.length) + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, ref.bytes, ref.offset, ref.length) ) { return parse(eltClass, type, parser); } @@ -275,8 +288,8 @@ void setElement(StorableElement element) { public static CreateIndexRequest buildIndexRequest(String indexName) { return new CreateIndexRequest(indexName) - .mapping(readResourceFile(indexName, MAPPING_FILE), XContentType.JSON) - .settings(storeIndexSettings(indexName)); + .mapping(readResourceFile(indexName, MAPPING_FILE), XContentType.JSON) + .settings(storeIndexSettings(indexName)); } private static String readResourceFile(String indexName, String resource) { @@ -285,25 +298,29 @@ private static String readResourceFile(String indexName, String resource) { is.transferTo(out); return out.toString(StandardCharsets.UTF_8.name()); } catch (Exception e) { - LOGGER.error( + LOGGER + .error( (org.apache.logging.log4j.util.Supplier) () -> new ParameterizedMessage( - "failed to create ltr feature store index [{}] with resource [{}]", - indexName, resource), e); + "failed to create ltr feature store index [{}] with resource [{}]", + indexName, + resource + ), + e + ); throw new IllegalStateException("failed to create ltr feature store index with resource [" + resource + "]", e); } } private static Settings storeIndexSettings(String indexName) { - return Settings.builder() - .put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), 1) - .put(IndexMetadata.INDEX_AUTO_EXPAND_REPLICAS_SETTING.getKey(), "0-2") - .put(STORE_VERSION_PROP.getKey(), VERSION) - .put(IndexMetadata.SETTING_PRIORITY, Integer.MAX_VALUE) - .put(IndexMetadata.SETTING_INDEX_HIDDEN, true) - .put(Settings.builder() - .loadFromSource(readResourceFile(indexName, ANALYSIS_FILE), XContentType.JSON) - .build()) - .build(); + return Settings + .builder() + .put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), 1) + .put(IndexMetadata.INDEX_AUTO_EXPAND_REPLICAS_SETTING.getKey(), "0-2") + .put(STORE_VERSION_PROP.getKey(), VERSION) + .put(IndexMetadata.SETTING_PRIORITY, Integer.MAX_VALUE) + .put(IndexMetadata.SETTING_INDEX_HIDDEN, true) + .put(Settings.builder().loadFromSource(readResourceFile(indexName, ANALYSIS_FILE), XContentType.JSON).build()) + .build(); } /** @@ -317,7 +334,10 @@ public static void validateFeatureStoreName(String storeName) { if (INVALID_NAMES.matcher(storeName).matches()) { throw new IllegalArgumentException("A featurestore name cannot be based on the words [feature], [featureset] and [model]"); } - MetadataCreateIndexService.validateIndexOrAliasName(storeName, - (name, error) -> new IllegalArgumentException("Invalid feature store name [" + name + "]: " + error)); + MetadataCreateIndexService + .validateIndexOrAliasName( + storeName, + (name, error) -> new IllegalArgumentException("Invalid feature store name [" + name + "]: " + error) + ); } } diff --git a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java index 90685a89..80feddaa 100644 --- a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java +++ b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java @@ -16,9 +16,14 @@ package com.o19s.es.ltr.logging; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.query.RankerQuery; -import com.o19s.es.ltr.ranker.LogLtrRanker; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -36,13 +41,9 @@ import org.opensearch.search.rescore.QueryRescorer; import org.opensearch.search.rescore.RescoreContext; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.query.RankerQuery; +import com.o19s.es.ltr.ranker.LogLtrRanker; public class LoggingFetchSubPhase implements FetchSubPhase { @Override @@ -56,7 +57,6 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept List loggers = new ArrayList<>(); Map namedQueries = context.parsedQuery().namedFilters(); - if (namedQueries.size() > 0) { ext.logSpecsStream().filter((l) -> l.getNamedQuery() != null).forEach((l) -> { Tuple query = extractQuery(l, namedQueries); @@ -71,45 +71,72 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept }); } - - Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F); return new LoggingFetchSubPhaseProcessor(w, loggers); } - private Tuple extractQuery(LoggingSearchExtBuilder.LogSpec - logSpec, Map namedQueries) { + private Tuple extractQuery(LoggingSearchExtBuilder.LogSpec logSpec, Map namedQueries) { Query q = namedQueries.get(logSpec.getNamedQuery()); if (q == null) { throw new IllegalArgumentException("No query named [" + logSpec.getNamedQuery() + "] found"); } - return toLogger(logSpec, inspectQuery(q) - .orElseThrow(() -> new IllegalArgumentException("Query named [" + logSpec.getNamedQuery() + - "] must be a [sltr] query [" + - ((q instanceof BoostQuery) ? ((BoostQuery) q).getQuery().getClass().getSimpleName( + return toLogger( + logSpec, + inspectQuery(q) + .orElseThrow( + () -> new IllegalArgumentException( + "Query named [" + + logSpec.getNamedQuery() + + "] must be a [sltr] query [" + + ((q instanceof BoostQuery) ? ((BoostQuery) q).getQuery().getClass().getSimpleName( - ) : q.getClass().getSimpleName()) + - "] found"))); + ) : q.getClass().getSimpleName()) + + "] found" + ) + ) + ); } - private Tuple extractRescore(LoggingSearchExtBuilder.LogSpec logSpec, - List contexts) { + private Tuple extractRescore(LoggingSearchExtBuilder.LogSpec logSpec, List contexts) { if (logSpec.getRescoreIndex() >= contexts.size()) { - throw new IllegalArgumentException("rescore index [" + logSpec.getRescoreIndex() + "] is out of bounds, only " + - "[" + contexts.size() + "] rescore context(s) are available"); + throw new IllegalArgumentException( + "rescore index [" + + logSpec.getRescoreIndex() + + "] is out of bounds, only " + + "[" + + contexts.size() + + "] rescore context(s) are available" + ); } RescoreContext context = contexts.get(logSpec.getRescoreIndex()); if (!(context instanceof QueryRescorer.QueryRescoreContext)) { - throw new IllegalArgumentException("Expected a [QueryRescoreContext] but found a " + - "[" + context.getClass().getSimpleName() + "] " + - "at index [" + logSpec.getRescoreIndex() + "]"); + throw new IllegalArgumentException( + "Expected a [QueryRescoreContext] but found a " + + "[" + + context.getClass().getSimpleName() + + "] " + + "at index [" + + logSpec.getRescoreIndex() + + "]" + ); } QueryRescorer.QueryRescoreContext qrescore = (QueryRescorer.QueryRescoreContext) context; - return toLogger(logSpec, inspectQuery(qrescore.query()) - .orElseThrow(() -> new IllegalArgumentException("Expected a [sltr] query but found a " + - "[" + qrescore.query().getClass().getSimpleName() + "] " + - "at index [" + logSpec.getRescoreIndex() + "]"))); + return toLogger( + logSpec, + inspectQuery(qrescore.query()) + .orElseThrow( + () -> new IllegalArgumentException( + "Expected a [sltr] query but found a " + + "[" + + qrescore.query().getClass().getSimpleName() + + "] " + + "at index [" + + logSpec.getRescoreIndex() + + "]" + ) + ) + ); } private Optional inspectQuery(Query q) { @@ -126,6 +153,7 @@ private Tuple toLogger(LoggingSearchExtBuilder.LogS query = query.toLoggerQuery(consumer); return new Tuple<>(query, consumer); } + static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor { private final Weight weight; private final List loggers; @@ -136,7 +164,6 @@ static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor { this.loggers = loggers; } - @Override public void setNextReader(LeafReaderContext readerContext) throws IOException { scorer = weight.scorer(readerContext); @@ -160,19 +187,18 @@ static class HitLogConsumer implements LogLtrRanker.LogConsumer { private final boolean missingAsZero; // [ - // { - // "name": "featureName", - // "value": 1.33 - // }, - // { - // "name": "otherFeatureName", - // } + // { + // "name": "featureName", + // "value": 1.33 + // }, + // { + // "name": "otherFeatureName", + // } // ] private List> currentLog; private SearchHit currentHit; private Map extraLogging; - HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) { this.name = name; this.set = set; diff --git a/src/main/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilder.java b/src/main/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilder.java index aa64662a..3a01b1f1 100644 --- a/src/main/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilder.java +++ b/src/main/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilder.java @@ -16,6 +16,12 @@ package com.o19s.es.ltr.logging; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Stream; + import org.opensearch.common.Nullable; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; @@ -28,12 +34,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.stream.Stream; - public class LoggingSearchExtBuilder extends SearchExtBuilder { public static final String NAME = "ltr_log"; @@ -66,11 +66,10 @@ public static LoggingSearchExtBuilder parse(XContentParser parser) throws IOExce try { LoggingSearchExtBuilder ext = PARSER.parse(parser, null); if (ext.logSpecs == null || ext.logSpecs.isEmpty()) { - throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] should define at least one [" + - LOG_SPECS + "]"); + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] should define at least one [" + LOG_SPECS + "]"); } return ext; - } catch(IllegalArgumentException iae) { + } catch (IllegalArgumentException iae) { throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); } } @@ -177,8 +176,10 @@ private static LogSpec parse(XContentParser parser, Void context) throws IOExcep try { LogSpec spec = PARSER.parse(parser, null); if (spec.namedQuery == null && spec.rescoreIndex == null) { - throw new ParsingException(parser.getTokenLocation(), "Either " + - "[" + NAMED_QUERY + "] or [" + RESCORE_INDEX + "] must be set."); + throw new ParsingException( + parser.getTokenLocation(), + "Either " + "[" + NAMED_QUERY + "] or [" + RESCORE_INDEX + "] must be set." + ); } if (spec.rescoreIndex != null && spec.rescoreIndex < 0) { throw new ParsingException(parser.getTokenLocation(), "[" + RESCORE_INDEX + "] must be a non-negative integer."); @@ -208,14 +209,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; LogSpec logSpec = (LogSpec) o; - if (missingAsZero != logSpec.missingAsZero) return false; - if (loggerName != null ? !loggerName.equals(logSpec.loggerName) : logSpec.loggerName != null) return false; - if (namedQuery != null ? !namedQuery.equals(logSpec.namedQuery) : logSpec.namedQuery != null) return false; + if (missingAsZero != logSpec.missingAsZero) + return false; + if (loggerName != null ? !loggerName.equals(logSpec.loggerName) : logSpec.loggerName != null) + return false; + if (namedQuery != null ? !namedQuery.equals(logSpec.namedQuery) : logSpec.namedQuery != null) + return false; return rescoreIndex != null ? rescoreIndex.equals(logSpec.rescoreIndex) : logSpec.rescoreIndex == null; } diff --git a/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java b/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java index 39d52cf7..d2ee64a9 100644 --- a/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/DerivedExpressionQuery.java @@ -16,36 +16,33 @@ package com.o19s.es.ltr.query; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.LtrRanker; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Supplier; + import org.apache.lucene.expressions.Bindings; import org.apache.lucene.expressions.Expression; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.DoubleValues; +import org.apache.lucene.search.DoubleValuesSource; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.ConstantScoreScorer; -import org.apache.lucene.search.DoubleValuesSource; -import org.apache.lucene.search.DoubleValues; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.Weight; import org.opensearch.ltr.settings.LTRSettings; - -import java.io.IOException; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.function.Supplier; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.LtrRanker; public class DerivedExpressionQuery extends Query implements LtrRewritableQuery { private final FeatureSet features; @@ -68,8 +65,8 @@ public boolean equals(Object obj) { } DerivedExpressionQuery that = (DerivedExpressionQuery) obj; return Objects.deepEquals(expression, that.expression) - && Objects.deepEquals(features, that.features) - && Objects.deepEquals(queryParamValues, that.queryParamValues); + && Objects.deepEquals(features, that.features) + && Objects.deepEquals(queryParamValues, that.queryParamValues); } @Override @@ -117,8 +114,7 @@ public boolean isCacheable(LeafReaderContext ctx) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - return new ConstantScoreScorer(this, score(), - scoreMode, DocIdSetIterator.all(context.reader().maxDoc())); + return new ConstantScoreScorer(this, score(), scoreMode, DocIdSetIterator.all(context.reader().maxDoc())); } }; } @@ -130,9 +126,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException { public boolean equals(Object obj) { assert false; // Should not be called as it is likely an indication that it'll be cached but should not... - return sameClassAs(obj) && - Objects.equals(this.query, ((FVDerivedExpressionQuery)obj).query) && - Objects.equals(this.fvSupplier, ((FVDerivedExpressionQuery)obj).fvSupplier); + return sameClassAs(obj) + && Objects.equals(this.query, ((FVDerivedExpressionQuery) obj).query) + && Objects.equals(this.fvSupplier, ((FVDerivedExpressionQuery) obj).fvSupplier); } @Override @@ -168,10 +164,10 @@ public void extractTerms(Set terms) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - Bindings bindings = new Bindings(){ + Bindings bindings = new Bindings() { @Override public DoubleValuesSource getDoubleValuesSource(String name) { - Double queryParamValue = queryParamValues.get(name); + Double queryParamValue = queryParamValues.get(name); if (queryParamValue != null) { return DoubleValuesSource.constant(queryParamValue); } @@ -188,7 +184,7 @@ public DoubleValuesSource getDoubleValuesSource(String name) { @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - Bindings bindings = new Bindings(){ + Bindings bindings = new Bindings() { @Override public DoubleValuesSource getDoubleValuesSource(String name) { return new FVDoubleValuesSource(vectorSupplier, features.featureOrdinal(name)); @@ -291,8 +287,7 @@ public boolean equals(Object o) { return false; } FVDoubleValuesSource that = (FVDoubleValuesSource) o; - return ordinal == that.ordinal && - Objects.equals(vectorSupplier, that.vectorSupplier); + return ordinal == that.ordinal && Objects.equals(vectorSupplier, that.vectorSupplier); } @Override @@ -302,10 +297,7 @@ public int hashCode() { @Override public String toString() { - return "FVDoubleValuesSource{" + - "ordinal=" + ordinal + - ", vectorSupplier=" + vectorSupplier + - '}'; + return "FVDoubleValuesSource{" + "ordinal=" + ordinal + ", vectorSupplier=" + vectorSupplier + '}'; } @Override diff --git a/src/main/java/com/o19s/es/ltr/query/LtrQueryBuilder.java b/src/main/java/com/o19s/es/ltr/query/LtrQueryBuilder.java index eb0c3d4c..0ef9c342 100644 --- a/src/main/java/com/o19s/es/ltr/query/LtrQueryBuilder.java +++ b/src/main/java/com/o19s/es/ltr/query/LtrQueryBuilder.java @@ -17,14 +17,12 @@ package com.o19s.es.ltr.query; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import com.o19s.es.ltr.feature.PrebuiltFeature; -import com.o19s.es.ltr.feature.PrebuiltFeatureSet; -import com.o19s.es.ltr.feature.PrebuiltLtrModel; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; -import com.o19s.es.ltr.utils.AbstractQueryBuilderUtils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + import org.apache.lucene.search.Query; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; @@ -39,13 +37,16 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.Rewriteable; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; import org.opensearch.script.Script; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; +import com.o19s.es.ltr.feature.PrebuiltFeature; +import com.o19s.es.ltr.feature.PrebuiltFeatureSet; +import com.o19s.es.ltr.feature.PrebuiltLtrModel; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.ranklib.RankLibScriptEngine; +import com.o19s.es.ltr.utils.AbstractQueryBuilderUtils; public class LtrQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "ltr"; @@ -55,21 +56,21 @@ public class LtrQueryBuilder extends AbstractQueryBuilder { static { PARSER = new ObjectParser<>(NAME, LtrQueryBuilder::new); declareStandardFields(PARSER); - PARSER.declareObjectArray( - LtrQueryBuilder::features, - (parser, context) -> parseInnerQueryBuilder(parser), - new ParseField("features")); - PARSER.declareField( + PARSER + .declareObjectArray(LtrQueryBuilder::features, (parser, context) -> parseInnerQueryBuilder(parser), new ParseField("features")); + PARSER + .declareField( (parser, ltr, context) -> ltr.rankerScript(Script.parse(parser, DEFAULT_SCRIPT_LANG)), - new ParseField("model"), ObjectParser.ValueType.OBJECT_OR_STRING); + new ParseField("model"), + ObjectParser.ValueType.OBJECT_OR_STRING + ); } private Script _rankLibScript; private List _features; private LTRStats _ltrStats; - public LtrQueryBuilder() { - } + public LtrQueryBuilder() {} public LtrQueryBuilder(Script _rankLibScript, List features, LTRStats ltrStats) { this._rankLibScript = _rankLibScript; @@ -85,7 +86,7 @@ public LtrQueryBuilder(StreamInput in, LTRStats ltrStats) throws IOException { } private static void doXArrayContent(String field, List clauses, XContentBuilder builder, Params params) - throws IOException { + throws IOException { if (clauses.isEmpty()) { return; } @@ -104,8 +105,7 @@ public static LtrQueryBuilder fromXContent(XContentParser parser, LTRStats ltrSt throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); } if (builder._rankLibScript == null) { - throw new ParsingException(parser.getTokenLocation(), - "[ltr] query requires a model, none specified"); + throw new ParsingException(parser.getTokenLocation(), "[ltr] query requires a model, none specified"); } builder.ltrStats(ltrStats); return builder; @@ -145,9 +145,7 @@ private Query _doToQuery(QueryShardContext context) throws IOException { } features = Collections.unmodifiableList(features); - RankLibScriptEngine.RankLibModelContainer.Factory factory = context.compile( - _rankLibScript, - RankLibScriptEngine.CONTEXT); + RankLibScriptEngine.RankLibModelContainer.Factory factory = context.compile(_rankLibScript, RankLibScriptEngine.CONTEXT); RankLibScriptEngine.RankLibModelContainer executableScript = factory.newInstance(); LtrRanker ranker = (LtrRanker) executableScript.run(); @@ -192,8 +190,7 @@ protected int doHashCode() { @Override protected boolean doEquals(LtrQueryBuilder other) { - return Objects.equals(_rankLibScript, other._rankLibScript) && - Objects.equals(_features, other._features); + return Objects.equals(_rankLibScript, other._rankLibScript) && Objects.equals(_features, other._features); } @Override @@ -215,7 +212,6 @@ public final LtrQueryBuilder ltrStats(LTRStats ltrStats) { return this; } - public List features() { return _features; } diff --git a/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java b/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java index 5ce823f2..5328db5a 100644 --- a/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/LtrRewritableQuery.java @@ -16,10 +16,10 @@ package com.o19s.es.ltr.query; -import org.apache.lucene.search.Query; - import java.io.IOException; +import org.apache.lucene.search.Query; + public interface LtrRewritableQuery { /** * Rewrite the query so that it holds the vectorSupplier and provide extra logging support diff --git a/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java b/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java index e6a44129..7970f4e5 100644 --- a/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java +++ b/src/main/java/com/o19s/es/ltr/query/LtrRewriteContext.java @@ -16,11 +16,11 @@ package com.o19s.es.ltr.query; +import java.util.function.Supplier; + import com.o19s.es.ltr.ranker.LogLtrRanker; import com.o19s.es.ltr.ranker.LtrRanker; -import java.util.function.Supplier; - /** * Contains context needed to rewrite queries to holds the vectorSupplier and provide extra logging support */ @@ -46,7 +46,7 @@ public Supplier getFeatureVectorSupplier() { */ public LogLtrRanker.LogConsumer getLogConsumer() { if (ranker instanceof LogLtrRanker) { - return ((LogLtrRanker)ranker).getLogConsumer(); + return ((LogLtrRanker) ranker).getLogConsumer(); } return null; } diff --git a/src/main/java/com/o19s/es/ltr/query/NoopScorer.java b/src/main/java/com/o19s/es/ltr/query/NoopScorer.java index 36aa7175..273b2795 100644 --- a/src/main/java/com/o19s/es/ltr/query/NoopScorer.java +++ b/src/main/java/com/o19s/es/ltr/query/NoopScorer.java @@ -15,17 +15,18 @@ */ package com.o19s.es.ltr.query; +import java.io.IOException; + import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; -import java.io.IOException; - /** * Created by doug on 2/3/17. */ public class NoopScorer extends Scorer { private final DocIdSetIterator _noopIter; + /** * Constructs a Scorer * diff --git a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java index cea4d95c..74b359a2 100644 --- a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java +++ b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java @@ -16,9 +16,37 @@ package com.o19s.es.ltr.query; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.RandomAccess; +import java.util.Set; +import java.util.stream.Stream; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DisiPriorityQueue; +import org.apache.lucene.search.DisiWrapper; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.ltr.stats.LTRStats; import org.opensearch.ltr.stats.StatName; + import com.o19s.es.ltr.LtrQueryContext; import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; @@ -27,33 +55,6 @@ import com.o19s.es.ltr.ranker.LogLtrRanker; import com.o19s.es.ltr.ranker.LtrRanker; import com.o19s.es.ltr.ranker.NullRanker; -import org.apache.lucene.index.IndexReader; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.Term; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.ConstantScoreWeight; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.ConstantScoreScorer; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.DisiPriorityQueue; -import org.apache.lucene.search.DisiWrapper; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.RandomAccess; -import java.util.Set; -import java.util.stream.Stream; /** * Lucene query designed to apply a ranking model provided by {@link LtrRanker} @@ -89,11 +90,12 @@ public class RankerQuery extends Query { private final Map featureScoreCache; private RankerQuery( - List queries, - FeatureSet features, - LtrRanker ranker, - Map featureScoreCache, - LTRStats ltrStats) { + List queries, + FeatureSet features, + LtrRanker ranker, + Map featureScoreCache, + LTRStats ltrStats + ) { this.queries = Objects.requireNonNull(queries); this.features = Objects.requireNonNull(features); this.ranker = Objects.requireNonNull(ranker); @@ -110,12 +112,12 @@ private RankerQuery( */ public static RankerQuery build(PrebuiltLtrModel model, LTRStats ltrStats) { return build( - model.ranker(), - model.featureSet(), - new LtrQueryContext(null, Collections.emptySet()), - Collections.emptyMap(), - false, - ltrStats + model.ranker(), + model.featureSet(), + new LtrQueryContext(null, Collections.emptySet()), + Collections.emptyMap(), + false, + ltrStats ); } @@ -128,27 +130,23 @@ public static RankerQuery build(PrebuiltLtrModel model, LTRStats ltrStats) { * @return the lucene query */ public static RankerQuery build( - LtrModel model, - LtrQueryContext context, - Map params, - Boolean featureScoreCacheFlag, - LTRStats ltrStats) { - return build( - model.ranker(), - model.featureSet(), - context, - params, - featureScoreCacheFlag, - ltrStats); + LtrModel model, + LtrQueryContext context, + Map params, + Boolean featureScoreCacheFlag, + LTRStats ltrStats + ) { + return build(model.ranker(), model.featureSet(), context, params, featureScoreCacheFlag, ltrStats); } private static RankerQuery build( - LtrRanker ranker, - FeatureSet features, - LtrQueryContext context, - Map params, - Boolean featureScoreCacheFlag, - LTRStats ltrStats) { + LtrRanker ranker, + FeatureSet features, + LtrQueryContext context, + Map params, + Boolean featureScoreCacheFlag, + LTRStats ltrStats + ) { List queries = features.toQueries(context, params); Map featureScoreCache = null; if (null != featureScoreCacheFlag && featureScoreCacheFlag) { @@ -157,11 +155,15 @@ private static RankerQuery build( return new RankerQuery(queries, features, ranker, featureScoreCache, ltrStats); } - public static RankerQuery buildLogQuery(LogLtrRanker.LogConsumer consumer, FeatureSet features, - LtrQueryContext context, Map params, LTRStats ltrStats) { + public static RankerQuery buildLogQuery( + LogLtrRanker.LogConsumer consumer, + FeatureSet features, + LtrQueryContext context, + Map params, + LTRStats ltrStats + ) { List queries = features.toQueries(context, params); - return new RankerQuery(queries, features, - new LogLtrRanker(consumer, features.size()), null, ltrStats); + return new RankerQuery(queries, features, new LogLtrRanker(consumer, features.size()), null, ltrStats); } public RankerQuery toLoggerQuery(LogLtrRanker.LogConsumer consumer) { @@ -192,9 +194,9 @@ public boolean equals(Object obj) { return false; } RankerQuery that = (RankerQuery) obj; - return Objects.deepEquals(queries, that.queries) && - Objects.deepEquals(features, that.features) && - Objects.equals(ranker, that.ranker); + return Objects.deepEquals(queries, that.queries) + && Objects.deepEquals(features, that.features) + && Objects.equals(ranker, that.ranker); } Stream stream() { @@ -249,8 +251,7 @@ private Weight createWeightInternal(IndexSearcher searcher, ScoreMode scoreMode, return new ConstantScoreWeight(this, boost) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - return new ConstantScoreScorer(this, score(), - scoreMode, DocIdSetIterator.all(context.reader().maxDoc())); + return new ConstantScoreScorer(this, score(), scoreMode, DocIdSetIterator.all(context.reader().maxDoc())); } @Override @@ -279,8 +280,13 @@ public static class RankerWeight extends Weight { private final FeatureSet features; private final Map featureScoreCache; - RankerWeight(RankerQuery query, List weights, FVLtrRankerWrapper ranker, FeatureSet features, - Map featureScoreCache) { + RankerWeight( + RankerQuery query, + List weights, + FVLtrRankerWrapper ranker, + FeatureSet features, + Map featureScoreCache + ) { super(query); assert weights instanceof RandomAccess; this.weights = weights; @@ -296,7 +302,7 @@ public boolean isCacheable(LeafReaderContext ctx) { public void extractTerms(Set terms) { for (Weight w : weights) { -// w.extractTerms(terms); + // w.extractTerms(terms); QueryVisitor.termCollector(terms); } } @@ -341,7 +347,11 @@ public RankerScorer scorer(LeafReaderContext context) throws IOException { } DisjunctionDISI rankerIterator = new DisjunctionDISI( - DocIdSetIterator.all(context.reader().maxDoc()), disiPriorityQueue, context.docBase, featureScoreCache); + DocIdSetIterator.all(context.reader().maxDoc()), + disiPriorityQueue, + context.docBase, + featureScoreCache + ); return new RankerScorer(scorers, rankerIterator, ranker, context.docBase, featureScoreCache); } @@ -357,8 +367,13 @@ class RankerScorer extends Scorer { private final int docBase; private final Map featureScoreCache; - RankerScorer(List scorers, DisjunctionDISI iterator, FVLtrRankerWrapper ranker, - int docBase, Map featureScoreCache) { + RankerScorer( + List scorers, + DisjunctionDISI iterator, + FVLtrRankerWrapper ranker, + int docBase, + Map featureScoreCache + ) { super(RankerWeight.this); this.scorers = scorers; this.iterator = iterator; @@ -417,10 +432,10 @@ public float score() throws IOException { return ranker.score(fv); } -// @Override -// public int freq() throws IOException { -// return scorers.size(); -// } + // @Override + // public int freq() throws IOException { + // return scorers.size(); + // } @Override public DocIdSetIterator iterator() { @@ -451,8 +466,12 @@ static class DisjunctionDISI extends DocIdSetIterator { private final int docBase; private final Map featureScoreCache; - DisjunctionDISI(DocIdSetIterator main, DisiPriorityQueue subIteratorsPriorityQueue, int docBase, - Map featureScoreCache) { + DisjunctionDISI( + DocIdSetIterator main, + DisiPriorityQueue subIteratorsPriorityQueue, + int docBase, + Map featureScoreCache + ) { this.main = main; this.subIteratorsPriorityQueue = subIteratorsPriorityQueue; this.docBase = docBase; @@ -526,8 +545,10 @@ public float score(FeatureVector point) { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; FVLtrRankerWrapper that = (FVLtrRankerWrapper) o; return Objects.equals(wrapped, that.wrapped); } diff --git a/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java b/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java index 50f3651f..4d09cff5 100644 --- a/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java +++ b/src/main/java/com/o19s/es/ltr/query/StoredLtrQueryBuilder.java @@ -16,15 +16,14 @@ package com.o19s.es.ltr.query; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import com.o19s.es.ltr.LtrQueryContext; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import com.o19s.es.ltr.feature.store.FeatureStore; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import com.o19s.es.ltr.ranker.linear.LinearRanker; -import com.o19s.es.ltr.utils.FeatureStoreLoader; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; + import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.io.stream.NamedWriteable; @@ -35,14 +34,16 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryShardContext; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; -import java.io.IOException; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; +import com.o19s.es.ltr.LtrQueryContext; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; +import com.o19s.es.ltr.feature.store.FeatureStore; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import com.o19s.es.ltr.ranker.linear.LinearRanker; +import com.o19s.es.ltr.utils.FeatureStoreLoader; /** * sltr query, build a ltr query based on a stored model. @@ -84,7 +85,6 @@ public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader) { this.storeLoader = storeLoader; } - public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader, StreamInput input, LTRStats ltrStats) throws IOException { super(input); this.storeLoader = Objects.requireNonNull(storeLoader); @@ -98,9 +98,8 @@ public StoredLtrQueryBuilder(FeatureStoreLoader storeLoader, StreamInput input, this.ltrStats = ltrStats; } - public static StoredLtrQueryBuilder fromXContent(FeatureStoreLoader storeLoader, - XContentParser parser, - LTRStats ltrStats) throws IOException { + public static StoredLtrQueryBuilder fromXContent(FeatureStoreLoader storeLoader, XContentParser parser, LTRStats ltrStats) + throws IOException { storeLoader = Objects.requireNonNull(storeLoader); final StoredLtrQueryBuilder builder = new StoredLtrQueryBuilder(storeLoader); try { @@ -175,8 +174,10 @@ protected RankerQuery doToQuery(QueryShardContext context) throws IOException { private RankerQuery doToQueryInternal(QueryShardContext context) throws IOException { String indexName = storeName != null ? IndexFeatureStore.indexName(storeName) : IndexFeatureStore.DEFAULT_STORE; FeatureStore store = storeLoader.load(indexName, context::getClient); - LtrQueryContext ltrQueryContext = new LtrQueryContext(context, - activeFeatures == null ? Collections.emptySet() : new HashSet<>(activeFeatures)); + LtrQueryContext ltrQueryContext = new LtrQueryContext( + context, + activeFeatures == null ? Collections.emptySet() : new HashSet<>(activeFeatures) + ); if (modelName != null) { CompiledLtrModel model = store.loadModel(modelName); validateActiveFeatures(model.featureSet(), ltrQueryContext); @@ -195,12 +196,12 @@ private RankerQuery doToQueryInternal(QueryShardContext context) throws IOExcept @Override protected boolean doEquals(StoredLtrQueryBuilder other) { - return Objects.equals(modelName, other.modelName) && - Objects.equals(featureScoreCacheFlag, other.featureScoreCacheFlag) && - Objects.equals(featureSetName, other.featureSetName) && - Objects.equals(storeName, other.storeName) && - Objects.equals(params, other.params) && - Objects.equals(activeFeatures, other.activeFeatures); + return Objects.equals(modelName, other.modelName) + && Objects.equals(featureScoreCacheFlag, other.featureScoreCacheFlag) + && Objects.equals(featureSetName, other.featureSetName) + && Objects.equals(storeName, other.storeName) + && Objects.equals(params, other.params) + && Objects.equals(activeFeatures, other.activeFeatures); } @Override diff --git a/src/main/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilder.java b/src/main/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilder.java index 534a48e7..6955c839 100644 --- a/src/main/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilder.java +++ b/src/main/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilder.java @@ -16,8 +16,31 @@ package com.o19s.es.ltr.query; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableSet; +import static java.util.stream.Collectors.joining; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; + +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; import org.opensearch.ltr.stats.LTRStats; import org.opensearch.ltr.stats.StatName; + import com.o19s.es.ltr.LtrQueryContext; import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; @@ -30,34 +53,11 @@ import com.o19s.es.ltr.feature.store.StoredLtrModel; import com.o19s.es.ltr.ranker.linear.LinearRanker; import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.search.Query; -import org.opensearch.core.ParseField; -import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.xcontent.ObjectParser; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.AbstractQueryBuilder; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.QueryShardException; - -import java.io.IOException; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; -import java.util.function.BiConsumer; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableSet; -import static java.util.stream.Collectors.joining; public class ValidatingLtrQueryBuilder extends AbstractQueryBuilder { - public static final Set SUPPORTED_TYPES = unmodifiableSet(new HashSet<>(asList( - StoredFeature.TYPE, - StoredFeatureSet.TYPE, - StoredLtrModel.TYPE))); + public static final Set SUPPORTED_TYPES = unmodifiableSet( + new HashSet<>(asList(StoredFeature.TYPE, StoredFeatureSet.TYPE, StoredLtrModel.TYPE)) + ); public static final String NAME = "validating_ltr_query"; private static final ParseField VALIDATION = new ParseField("validation"); @@ -66,24 +66,21 @@ public class ValidatingLtrQueryBuilder extends AbstractQueryBuilder setElem = (b, v) -> { if (b.element != null) { - throw new IllegalArgumentException("[" + b.element.type() + "] already set, only one element can be set at a time (" + - SUPPORTED_TYPES.stream().collect(joining(",")) + ")."); + throw new IllegalArgumentException( + "[" + + b.element.type() + + "] already set, only one element can be set at a time (" + + SUPPORTED_TYPES.stream().collect(joining(",")) + + ")." + ); } b.element = v; }; - PARSER.declareObject(setElem, - (parser, ctx) -> StoredFeature.parse(parser), - new ParseField(StoredFeature.TYPE)); - PARSER.declareObject(setElem, - (parser, ctx) -> StoredFeatureSet.parse(parser), - new ParseField(StoredFeatureSet.TYPE)); - PARSER.declareObject(setElem, - (parser, ctx) -> StoredLtrModel.parse(parser), - new ParseField(StoredLtrModel.TYPE)); - PARSER.declareObject((b, v) -> b.validation = v, - (p, c) -> FeatureValidation.PARSER.apply(p, null), - new ParseField("validation")); + PARSER.declareObject(setElem, (parser, ctx) -> StoredFeature.parse(parser), new ParseField(StoredFeature.TYPE)); + PARSER.declareObject(setElem, (parser, ctx) -> StoredFeatureSet.parse(parser), new ParseField(StoredFeatureSet.TYPE)); + PARSER.declareObject(setElem, (parser, ctx) -> StoredLtrModel.parse(parser), new ParseField(StoredLtrModel.TYPE)); + PARSER.declareObject((b, v) -> b.validation = v, (p, c) -> FeatureValidation.PARSER.apply(p, null), new ParseField("validation")); declareStandardFields(PARSER); } @@ -91,14 +88,17 @@ public class ValidatingLtrQueryBuilder extends AbstractQueryBuilder getExtraLoggingMap() {return null;} + + default Map getExtraLoggingMap() { + return null; + } + default void reset() {} } } diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index eb115c16..50783424 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -16,13 +16,14 @@ package com.o19s.es.ltr.ranker.dectree; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.DenseLtrRanker; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import java.util.Objects; + import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import java.util.Objects; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.DenseLtrRanker; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; /** * Naive implementation of additive decision tree. @@ -64,7 +65,7 @@ protected float score(DenseFeatureVector vector) { float sum = 0; float[] scores = vector.scores; for (int i = 0; i < trees.length; i++) { - sum += weights[i]*trees[i].eval(scores); + sum += weights[i] * trees[i].eval(scores); } return normalizer.normalize(sum); } @@ -79,13 +80,13 @@ protected int size() { */ @Override public long ramBytesUsed() { - return BASE_RAM_USED + RamUsageEstimator.sizeOf(weights) - + RamUsageEstimator.sizeOf(trees); + return BASE_RAM_USED + RamUsageEstimator.sizeOf(weights) + RamUsageEstimator.sizeOf(trees); } public interface Node extends Accountable { - boolean isLeaf(); - float eval(float[] scores); + boolean isLeaf(); + + float eval(float[] scores); } public static class Split implements Node { diff --git a/src/main/java/com/o19s/es/ltr/ranker/linear/LinearRanker.java b/src/main/java/com/o19s/es/ltr/ranker/linear/LinearRanker.java index 691ec63f..28adb95c 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/linear/LinearRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/linear/LinearRanker.java @@ -16,13 +16,14 @@ package com.o19s.es.ltr.ranker.linear; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.DenseLtrRanker; +import java.util.Arrays; +import java.util.Objects; + import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import java.util.Arrays; -import java.util.Objects; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.DenseLtrRanker; /** * Simple linear ranker that applies a dot product based @@ -45,7 +46,7 @@ protected float score(DenseFeatureVector point) { float[] scores = point.scores; float score = 0; for (int i = 0; i < weights.length; i++) { - score += weights[i]*scores[i]; + score += weights[i] * scores[i]; } return score; } @@ -57,8 +58,10 @@ protected int size() { @Override public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; LinearRanker ranker = (LinearRanker) o; diff --git a/src/main/java/com/o19s/es/ltr/ranker/normalizer/FeatureNormalizingRanker.java b/src/main/java/com/o19s/es/ltr/ranker/normalizer/FeatureNormalizingRanker.java index 999dc16e..4c43a913 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/normalizer/FeatureNormalizingRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/normalizer/FeatureNormalizingRanker.java @@ -16,12 +16,13 @@ package com.o19s.es.ltr.ranker.normalizer; -import com.o19s.es.ltr.ranker.LtrRanker; +import java.util.Map; +import java.util.Objects; + import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; -import java.util.Map; -import java.util.Objects; +import com.o19s.es.ltr.ranker.LtrRanker; public class FeatureNormalizingRanker implements LtrRanker, Accountable { @@ -55,7 +56,7 @@ public FeatureVector newFeatureVector(FeatureVector reuse) { @Override public float score(FeatureVector point) { - for (Map.Entry ordToNorm: this.ftrNorms.entrySet()) { + for (Map.Entry ordToNorm : this.ftrNorms.entrySet()) { int ord = ordToNorm.getKey(); float origFtrScore = point.getFeatureScore(ord); float normed = ordToNorm.getValue().normalize(origFtrScore); @@ -66,23 +67,26 @@ public float score(FeatureVector point) { @Override public boolean equals(Object other) { - if (other == null) return false; - if (!(other instanceof FeatureNormalizingRanker)) { + if (other == null) + return false; + if (!(other instanceof FeatureNormalizingRanker)) { return false; } - final FeatureNormalizingRanker that = (FeatureNormalizingRanker)(other); - if (that == null) return false; + final FeatureNormalizingRanker that = (FeatureNormalizingRanker) (other); + if (that == null) + return false; - if (!that.ftrNorms.equals(this.ftrNorms)) return false; - if (!that.wrapped.equals(this.wrapped)) return false; + if (!that.ftrNorms.equals(this.ftrNorms)) + return false; + if (!that.wrapped.equals(this.wrapped)) + return false; return true; } @Override public int hashCode() { - return this.wrapped.hashCode() + - (31 * this.ftrNorms.hashCode()); + return this.wrapped.hashCode() + (31 * this.ftrNorms.hashCode()); } @Override @@ -91,10 +95,10 @@ public long ramBytesUsed() { long ftrNormSize = ftrNorms.size() * (PER_FTR_NORM_RAM_USED); if (this.wrapped instanceof Accountable) { - Accountable accountable = (Accountable)this.wrapped; + Accountable accountable = (Accountable) this.wrapped; return BASE_RAM_USED + accountable.ramBytesUsed() + ftrNormSize; } else { return BASE_RAM_USED + ftrNormSize; } } -} \ No newline at end of file +} diff --git a/src/main/java/com/o19s/es/ltr/ranker/normalizer/MinMaxFeatureNormalizer.java b/src/main/java/com/o19s/es/ltr/ranker/normalizer/MinMaxFeatureNormalizer.java index 9e3773f7..aefaa4e3 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/normalizer/MinMaxFeatureNormalizer.java +++ b/src/main/java/com/o19s/es/ltr/ranker/normalizer/MinMaxFeatureNormalizer.java @@ -24,15 +24,15 @@ * See * https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html */ -public class MinMaxFeatureNormalizer implements Normalizer { +public class MinMaxFeatureNormalizer implements Normalizer { float maximum; float minimum; public MinMaxFeatureNormalizer(float minimum, float maximum) { if (minimum >= maximum) { - throw new IllegalArgumentException("Minimum " + Double.toString(minimum) + - " must be smaller than than maximum: " + - Double.toString(maximum)); + throw new IllegalArgumentException( + "Minimum " + Double.toString(minimum) + " must be smaller than than maximum: " + Double.toString(maximum) + ); } this.minimum = minimum; this.maximum = maximum; @@ -40,17 +40,21 @@ public MinMaxFeatureNormalizer(float minimum, float maximum) { @Override public float normalize(float value) { - return (value - minimum) / (maximum - minimum); + return (value - minimum) / (maximum - minimum); } @Override public boolean equals(Object other) { - if (this == other) return true; - if (!(other instanceof MinMaxFeatureNormalizer)) return false; + if (this == other) + return true; + if (!(other instanceof MinMaxFeatureNormalizer)) + return false; MinMaxFeatureNormalizer that = (MinMaxFeatureNormalizer) other; - if (this.minimum != that.minimum) return false; - if (this.maximum != that.maximum) return false; + if (this.minimum != that.minimum) + return false; + if (this.maximum != that.maximum) + return false; return true; diff --git a/src/main/java/com/o19s/es/ltr/ranker/normalizer/Normalizers.java b/src/main/java/com/o19s/es/ltr/ranker/normalizer/Normalizers.java index 9567f11b..70cc8e3c 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/normalizer/Normalizers.java +++ b/src/main/java/com/o19s/es/ltr/ranker/normalizer/Normalizers.java @@ -24,10 +24,12 @@ * Class that manages Normalizer implementations */ public class Normalizers { - private static final Map NORMALIZERS = Collections.unmodifiableMap(new HashMap() {{ - put(NOOP_NORMALIZER_NAME, new NoopNormalizer()); - put(SIGMOID_NORMALIZER_NAME, new SigmoidNormalizer()); - }}); + private static final Map NORMALIZERS = Collections.unmodifiableMap(new HashMap() { + { + put(NOOP_NORMALIZER_NAME, new NoopNormalizer()); + put(SIGMOID_NORMALIZER_NAME, new SigmoidNormalizer()); + } + }); public static final String NOOP_NORMALIZER_NAME = "noop"; public static final String SIGMOID_NORMALIZER_NAME = "sigmoid"; diff --git a/src/main/java/com/o19s/es/ltr/ranker/normalizer/StandardFeatureNormalizer.java b/src/main/java/com/o19s/es/ltr/ranker/normalizer/StandardFeatureNormalizer.java index 9bb9a948..93dba518 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/normalizer/StandardFeatureNormalizer.java +++ b/src/main/java/com/o19s/es/ltr/ranker/normalizer/StandardFeatureNormalizer.java @@ -21,13 +21,11 @@ public class StandardFeatureNormalizer implements Normalizer { private float mean; private float stdDeviation; - public StandardFeatureNormalizer(float mean, float stdDeviation) { this.mean = mean; this.stdDeviation = stdDeviation; } - @Override public float normalize(float value) { return (value - this.mean) / this.stdDeviation; @@ -35,12 +33,16 @@ public float normalize(float value) { @Override public boolean equals(Object other) { - if (this == other) return true; - if (!(other instanceof StandardFeatureNormalizer)) return false; + if (this == other) + return true; + if (!(other instanceof StandardFeatureNormalizer)) + return false; StandardFeatureNormalizer that = (StandardFeatureNormalizer) other; - if (this.mean != that.mean) return false; - if (this.stdDeviation != that.stdDeviation) return false; + if (this.mean != that.mean) + return false; + if (this.stdDeviation != that.stdDeviation) + return false; return true; @@ -53,4 +55,4 @@ public int hashCode() { return hashCode; } -} \ No newline at end of file +} diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/LinearRankerParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/LinearRankerParser.java index 6788903e..6843ebb1 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/LinearRankerParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/LinearRankerParser.java @@ -16,25 +16,24 @@ package com.o19s.es.ltr.ranker.parser; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.linear.LinearRanker; -import org.opensearch.core.common.ParsingException; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.json.JsonXContent; +import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; import java.io.IOException; -import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.XContentParser; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.linear.LinearRanker; public class LinearRankerParser implements LtrRankerParser { public static final String TYPE = "model/linear"; @Override public LinearRanker parse(FeatureSet set, String model) { - try (XContentParser parser = JsonXContent.jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, model) - ) { + try (XContentParser parser = JsonXContent.jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, model)) { return parse(parser, set); } catch (IOException e) { throw new IllegalArgumentException(e.getMessage(), e); @@ -44,7 +43,7 @@ public LinearRanker parse(FeatureSet set, String model) { private LinearRanker parse(XContentParser parser, FeatureSet set) throws IOException { float[] weights = new float[set.size()]; if (parser.nextToken() != XContentParser.Token.START_OBJECT) { - throw new ParsingException(parser.getTokenLocation(), "Expected start object but found [" + parser.currentToken() +"]"); + throw new ParsingException(parser.getTokenLocation(), "Expected start object but found [" + parser.currentToken() + "]"); } while (parser.nextToken() == XContentParser.Token.FIELD_NAME) { String fname = parser.currentName(); @@ -52,7 +51,7 @@ private LinearRanker parse(XContentParser parser, FeatureSet set) throws IOExcep throw new ParsingException(parser.getTokenLocation(), "Feature [" + fname + "] is unknown."); } if (parser.nextToken() != XContentParser.Token.VALUE_NUMBER) { - throw new ParsingException(parser.getTokenLocation(), "Expected a float but found [" + parser.currentToken() +"]"); + throw new ParsingException(parser.getTokenLocation(), "Expected a float but found [" + parser.currentToken() + "]"); } weights[set.featureOrdinal(fname)] = parser.floatValue(); } diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index b594afc2..42a5c793 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -16,25 +16,26 @@ package com.o19s.es.ltr.ranker.parser; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; -import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree.Node; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.normalizer.Normalizers; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.ListIterator; + +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.json.JsonXContent; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.ListIterator; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree.Node; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.Normalizers; /** * Parse XGBoost models generated by mjolnir (https://gerrit.wikimedia.org/r/search/MjoLniR) @@ -45,8 +46,9 @@ public class XGBoostJsonParser implements LtrRankerParser { @Override public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { XGBoostDefinition modelDefinition; - try (XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, model) + try ( + XContentParser parser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, model) ) { modelDefinition = XGBoostDefinition.parse(parser, set); } catch (IOException e) { @@ -94,8 +96,10 @@ public static XGBoostDefinition parse(XContentParser parser, FeatureSet set) thr definition.splitParserStates.add(SplitParserState.parse(parser, set)); } } else { - throw new ParsingException(parser.getTokenLocation(), "Expected [START_ARRAY] or [START_OBJECT] but got [" - + startToken + "]"); + throw new ParsingException( + parser.getTokenLocation(), + "Expected [START_ARRAY] or [START_OBJECT] but got [" + startToken + "]" + ); } if (definition.splitParserStates.size() == 0) { throw new ParsingException(parser.getTokenLocation(), "XGBoost model must define at lease one tree"); @@ -139,7 +143,7 @@ void setSplitParserStates(List splitParserStates) { Node[] getTrees(FeatureSet set) { Node[] trees = new Node[splitParserStates.size()]; ListIterator it = splitParserStates.listIterator(); - while(it.hasNext()) { + while (it.hasNext()) { trees[it.nextIndex()] = it.next().toNode(set); } return trees; @@ -158,8 +162,7 @@ private static class SplitParserState { PARSER.declareInt(SplitParserState::setLeftNodeId, new ParseField("yes")); PARSER.declareInt(SplitParserState::setMissingNodeId, new ParseField("missing")); PARSER.declareFloat(SplitParserState::setLeaf, new ParseField("leaf")); - PARSER.declareObjectArray(SplitParserState::setChildren, SplitParserState::parse, - new ParseField("children")); + PARSER.declareObjectArray(SplitParserState::setChildren, SplitParserState::parse, new ParseField("children")); PARSER.declareFloat(SplitParserState::setThreshold, new ParseField("split_condition")); } @@ -181,8 +184,10 @@ public static SplitParserState parse(XContentParser parser, FeatureSet set) { throw new ParsingException(parser.getTokenLocation(), "This split does not have all the required fields"); } if (!split.splitHasValidChildren()) { - throw new ParsingException(parser.getTokenLocation(), "Split structure is invalid, yes, no and/or" + - " missing branches does not point to the proper children."); + throw new ParsingException( + parser.getTokenLocation(), + "Split structure is invalid, yes, no and/or" + " missing branches does not point to the proper children." + ); } if (!set.hasFeature(split.split)) { throw new ParsingException(parser.getTokenLocation(), "Unknown feature [" + split.split + "]"); @@ -192,6 +197,7 @@ public static SplitParserState parse(XContentParser parser, FeatureSet set) { } return split; } + void setNodeId(Integer nodeId) { this.nodeId = nodeId; } @@ -229,8 +235,14 @@ void setChildren(List children) { } boolean splitHasAllFields() { - return nodeId != null && threshold != null && split != null && leftNodeId != null && rightNodeId != null && depth != null - && children != null && children.size() == 2; + return nodeId != null + && threshold != null + && split != null + && leftNodeId != null + && rightNodeId != null + && depth != null + && children != null + && children.size() == 2; } boolean leafHasAllFields() { @@ -238,18 +250,21 @@ boolean leafHasAllFields() { } boolean splitHasValidChildren() { - return children.size() == 2 && - leftNodeId.equals(children.get(0).nodeId) && rightNodeId.equals(children.get(1).nodeId); + return children.size() == 2 && leftNodeId.equals(children.get(0).nodeId) && rightNodeId.equals(children.get(1).nodeId); } + boolean isSplit() { return leaf == null; } - Node toNode(FeatureSet set) { if (isSplit()) { - return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set), - set.featureOrdinal(split), threshold); + return new NaiveAdditiveDecisionTree.Split( + children.get(0).toNode(set), + children.get(1).toNode(set), + set.featureOrdinal(split), + threshold + ); } else { return new NaiveAdditiveDecisionTree.Leaf(leaf); } diff --git a/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java b/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java index 9dd5eb78..6598a04f 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java +++ b/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java @@ -16,11 +16,12 @@ */ package com.o19s.es.ltr.ranker.ranklib; -import ciir.umass.edu.learning.DataPoint; -import ciir.umass.edu.utilities.RankLibError; +import java.util.Arrays; + import com.o19s.es.ltr.ranker.LtrRanker; -import java.util.Arrays; +import ciir.umass.edu.learning.DataPoint; +import ciir.umass.edu.utilities.RankLibError; /** * Implements FeatureVector but without needing to pass in a stirng @@ -28,20 +29,21 @@ */ public class DenseProgramaticDataPoint extends DataPoint implements LtrRanker.FeatureVector { private static final int RANKLIB_FEATURE_INDEX_OFFSET = 1; + public DenseProgramaticDataPoint(int numFeatures) { - this.fVals = new float[numFeatures+RANKLIB_FEATURE_INDEX_OFFSET]; // add 1 because RankLib features 1 based + this.fVals = new float[numFeatures + RANKLIB_FEATURE_INDEX_OFFSET]; // add 1 because RankLib features 1 based } public float getFeatureValue(int fid) { - if(fid > 0 && fid < this.fVals.length) { - return isUnknown(this.fVals[fid])?0.0F:this.fVals[fid]; + if (fid > 0 && fid < this.fVals.length) { + return isUnknown(this.fVals[fid]) ? 0.0F : this.fVals[fid]; } else { throw RankLibError.create("Error in DenseDataPoint::getFeatureValue(): requesting unspecified feature, fid=" + fid); } } public void setFeatureValue(int fid, float fval) { - if(fid > 0 && fid < this.fVals.length) { + if (fid > 0 && fid < this.fVals.length) { this.fVals[fid] = fval; } else { throw RankLibError.create("Error in DenseDataPoint::setFeatureValue(): feature (id=" + fid + ") not found."); @@ -59,13 +61,13 @@ public float[] getFeatureVector() { @Override public void setFeatureScore(int featureIdx, float score) { // add 1 because RankLib features 1 based - this.setFeatureValue(featureIdx+1, score); + this.setFeatureValue(featureIdx + 1, score); } @Override public float getFeatureScore(int featureIdx) { // add 1 because RankLib features 1 based - return this.getFeatureValue(featureIdx+1); + return this.getFeatureValue(featureIdx + 1); } public void reset() { diff --git a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RankLibScriptEngine.java b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RankLibScriptEngine.java index a2e4de3b..356e340f 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RankLibScriptEngine.java +++ b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RankLibScriptEngine.java @@ -16,17 +16,18 @@ */ package com.o19s.es.ltr.ranker.ranklib; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import org.opensearch.script.ScriptContext; -import org.opensearch.script.ScriptEngine; - import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.Set; +import org.opensearch.script.ScriptContext; +import org.opensearch.script.ScriptEngine; + +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; + /** * Created by doug on 12/30/16. * The ranklib models are treated like scripts in that they @@ -41,8 +42,10 @@ public class RankLibScriptEngine implements ScriptEngine { public static final String NAME = "ranklib"; public static final String EXTENSION = "ranklib"; - public static final ScriptContext CONTEXT = - new ScriptContext<>("ranklib", RankLibModelContainer.Factory.class); + public static final ScriptContext CONTEXT = new ScriptContext<>( + "ranklib", + RankLibModelContainer.Factory.class + ); private final LtrRankerParserFactory factory; public RankLibScriptEngine(LtrRankerParserFactory factory) { @@ -50,7 +53,6 @@ public RankLibScriptEngine(LtrRankerParserFactory factory) { this.factory = Objects.requireNonNull(factory); } - @Override public String getType() { return NAME; @@ -98,6 +100,6 @@ public Object run() { return _ranker; } - public void execute () {} + public void execute() {} } } diff --git a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibModelParser.java b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibModelParser.java index 107b5f2b..1f341314 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibModelParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibModelParser.java @@ -16,12 +16,13 @@ package com.o19s.es.ltr.ranker.ranklib; -import ciir.umass.edu.learning.Ranker; -import ciir.umass.edu.learning.RankerFactory; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.ranker.LtrRanker; import com.o19s.es.ltr.ranker.parser.LtrRankerParser; +import ciir.umass.edu.learning.Ranker; +import ciir.umass.edu.learning.RankerFactory; + /** * Load a ranklib model from a script file, mostly a wrapper around the * existing script that complies with the {@link LtrRankerParser} interface diff --git a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibRanker.java b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibRanker.java index c8c4d116..8c671988 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/ranklib/RanklibRanker.java @@ -16,9 +16,10 @@ package com.o19s.es.ltr.ranker.ranklib; -import ciir.umass.edu.learning.Ranker; import com.o19s.es.ltr.ranker.LtrRanker; +import ciir.umass.edu.learning.Ranker; + public class RanklibRanker implements LtrRanker { private final Ranker ranker; private final int featureSetSize; diff --git a/src/main/java/com/o19s/es/ltr/rest/AutoDetectParser.java b/src/main/java/com/o19s/es/ltr/rest/AutoDetectParser.java index 057070ff..5ac821d6 100644 --- a/src/main/java/com/o19s/es/ltr/rest/AutoDetectParser.java +++ b/src/main/java/com/o19s/es/ltr/rest/AutoDetectParser.java @@ -16,20 +16,21 @@ package com.o19s.es.ltr.rest; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; +import static com.o19s.es.ltr.query.ValidatingLtrQueryBuilder.SUPPORTED_TYPES; +import static java.util.stream.Collectors.joining; + +import java.io.IOException; + import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; - -import static com.o19s.es.ltr.query.ValidatingLtrQueryBuilder.SUPPORTED_TYPES; -import static java.util.stream.Collectors.joining; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; class AutoDetectParser { private String expectedName; @@ -39,18 +40,10 @@ class AutoDetectParser { private static final ObjectParser PARSER = new ObjectParser<>("storable_elements"); static { - PARSER.declareObject(AutoDetectParser::setElement, - StoredFeature::parse, - new ParseField(StoredFeature.TYPE)); - PARSER.declareObject(AutoDetectParser::setElement, - StoredFeatureSet::parse, - new ParseField(StoredFeatureSet.TYPE)); - PARSER.declareObject(AutoDetectParser::setElement, - StoredLtrModel::parse, - new ParseField(StoredLtrModel.TYPE)); - PARSER.declareObject((b, v) -> b.validation = v, - (p, c) -> FeatureValidation.PARSER.apply(p, null), - new ParseField("validation")); + PARSER.declareObject(AutoDetectParser::setElement, StoredFeature::parse, new ParseField(StoredFeature.TYPE)); + PARSER.declareObject(AutoDetectParser::setElement, StoredFeatureSet::parse, new ParseField(StoredFeatureSet.TYPE)); + PARSER.declareObject(AutoDetectParser::setElement, StoredLtrModel::parse, new ParseField(StoredLtrModel.TYPE)); + PARSER.declareObject((b, v) -> b.validation = v, (p, c) -> FeatureValidation.PARSER.apply(p, null), new ParseField("validation")); } AutoDetectParser(String name) { @@ -60,8 +53,10 @@ class AutoDetectParser { public void parse(XContentParser parser) throws IOException { PARSER.parse(parser, this, expectedName); if (element == null) { - throw new ParsingException(parser.getTokenLocation(), "Element of type [" + SUPPORTED_TYPES.stream().collect(joining(",")) + - "] is mandatory."); + throw new ParsingException( + parser.getTokenLocation(), + "Element of type [" + SUPPORTED_TYPES.stream().collect(joining(",")) + "] is mandatory." + ); } } @@ -71,8 +66,13 @@ public StorableElement getElement() { public void setElement(StorableElement element) { if (this.element != null) { - throw new IllegalArgumentException("[" + element.type() + "] already set, only one element can be set at a time (" + - SUPPORTED_TYPES.stream().collect(joining(",")) + ")."); + throw new IllegalArgumentException( + "[" + + element.type() + + "] already set, only one element can be set at a time (" + + SUPPORTED_TYPES.stream().collect(joining(",")) + + ")." + ); } this.element = element; } diff --git a/src/main/java/com/o19s/es/ltr/rest/FeatureStoreBaseRestHandler.java b/src/main/java/com/o19s/es/ltr/rest/FeatureStoreBaseRestHandler.java index 373d6d4b..d22d28c6 100644 --- a/src/main/java/com/o19s/es/ltr/rest/FeatureStoreBaseRestHandler.java +++ b/src/main/java/com/o19s/es/ltr/rest/FeatureStoreBaseRestHandler.java @@ -16,10 +16,11 @@ package com.o19s.es.ltr.rest; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; + public abstract class FeatureStoreBaseRestHandler extends BaseRestHandler { protected String indexName(RestRequest request) { diff --git a/src/main/java/com/o19s/es/ltr/rest/RestAddFeatureToSet.java b/src/main/java/com/o19s/es/ltr/rest/RestAddFeatureToSet.java index ca0f2c60..62b55fd1 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestAddFeatureToSet.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestAddFeatureToSet.java @@ -16,22 +16,23 @@ package com.o19s.es.ltr.rest; -import org.opensearch.ltr.settings.LTRSettings; -import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequestBuilder; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StoredFeature; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; + +import java.io.IOException; +import java.util.List; + import org.opensearch.client.node.NodeClient; import org.opensearch.core.ParseField; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestStatusToXContentListener; -import java.io.IOException; -import java.util.List; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableList; +import com.o19s.es.ltr.action.AddFeaturesToSetAction.AddFeaturesToSetRequestBuilder; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StoredFeature; public class RestAddFeatureToSet extends FeatureStoreBaseRestHandler { @@ -42,12 +43,14 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( + return unmodifiableList( + asList( new Route(RestRequest.Method.POST, "/_ltr/_featureset/{name}/_addfeatures/{query}"), new Route(RestRequest.Method.POST, "/_ltr/{store}/_featureset/{name}/_addfeatures/{query}"), new Route(RestRequest.Method.POST, "/_ltr/_featureset/{name}/_addfeatures"), new Route(RestRequest.Method.POST, "/_ltr/{store}/_featureset/{name}/_addfeatures") - )); + ) + ); } @Override @@ -74,13 +77,15 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli validation = featuresParser.validation; } if (featureQuery == null && (features == null || features.isEmpty())) { - throw new IllegalArgumentException("features must be provided as a query for the feature store " + - "or in the body, none provided"); + throw new IllegalArgumentException( + "features must be provided as a query for the feature store " + "or in the body, none provided" + ); } if (featureQuery != null && (features != null && !features.isEmpty())) { - throw new IllegalArgumentException("features must be provided as a query for the feature store " + - "or directly in the body not both"); + throw new IllegalArgumentException( + "features must be provided as a query for the feature store " + "or directly in the body not both" + ); } AddFeaturesToSetRequestBuilder builder = new AddFeaturesToSetRequestBuilder(client); @@ -99,14 +104,13 @@ static class FeaturesParserState { private List features; private FeatureValidation validation; static { - PARSER.declareObjectArray( + PARSER + .declareObjectArray( FeaturesParserState::setFeatures, (parser, context) -> StoredFeature.parse(parser), - new ParseField("features")); - PARSER.declareObject( - FeaturesParserState::setValidation, - FeatureValidation.PARSER::apply, - new ParseField("validation")); + new ParseField("features") + ); + PARSER.declareObject(FeaturesParserState::setValidation, FeatureValidation.PARSER::apply, new ParseField("validation")); } public void parse(XContentParser parser) throws IOException { diff --git a/src/main/java/com/o19s/es/ltr/rest/RestCreateModelFromSet.java b/src/main/java/com/o19s/es/ltr/rest/RestCreateModelFromSet.java index 3b0b6819..073ed78f 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestCreateModelFromSet.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestCreateModelFromSet.java @@ -16,29 +16,30 @@ package com.o19s.es.ltr.rest; -import org.opensearch.ltr.settings.LTRSettings; -import com.o19s.es.ltr.action.CreateModelFromSetAction; -import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequestBuilder; -import com.o19s.es.ltr.feature.FeatureValidation; -import com.o19s.es.ltr.feature.store.StoredLtrModel; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; + +import java.io.IOException; +import java.util.List; + import org.opensearch.ExceptionsHelper; -import org.opensearch.core.action.ActionListener; import org.opensearch.client.node.NodeClient; import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.ParsingException; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.action.RestStatusToXContentListener; -import java.io.IOException; -import java.util.List; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableList; +import com.o19s.es.ltr.action.CreateModelFromSetAction; +import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequestBuilder; +import com.o19s.es.ltr.feature.FeatureValidation; +import com.o19s.es.ltr.feature.store.StoredLtrModel; public class RestCreateModelFromSet extends FeatureStoreBaseRestHandler { @@ -49,9 +50,12 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( - new Route(RestRequest.Method.POST , "/_ltr/{store}/_featureset/{name}/_createmodel"), - new Route(RestRequest.Method.POST, "/_ltr/_featureset/{name}/_createmodel" ))); + return unmodifiableList( + asList( + new Route(RestRequest.Method.POST, "/_ltr/{store}/_featureset/{name}/_createmodel"), + new Route(RestRequest.Method.POST, "/_ltr/_featureset/{name}/_createmodel") + ) + ); } @Override @@ -79,30 +83,37 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } builder.request().setValidation(state.validation); builder.routing(routing); - return (channel) -> builder.execute(ActionListener.wrap( - response -> new RestStatusToXContentListener(channel, - (r) -> r.getResponse().getLocation(routing)).onResponse(response), - (e) -> { - final Exception exc; - final RestStatus status; - if (ExceptionsHelper.unwrap(e, VersionConflictEngineException.class) != null) { - exc = new IllegalArgumentException("Element of type [" + StoredLtrModel.TYPE + - "] are not updatable, please create a new one instead."); - exc.addSuppressed(e); - status = RestStatus.METHOD_NOT_ALLOWED; - } else { - exc = e; - status = ExceptionsHelper.status(exc); - } - - try { - channel.sendResponse(new BytesRestResponse(channel, status, exc)); - } catch (Exception inner) { - inner.addSuppressed(e); - logger.error("failed to send failure response", inner); - } - } - )); + return (channel) -> builder + .execute( + ActionListener + .wrap( + response -> new RestStatusToXContentListener( + channel, + (r) -> r.getResponse().getLocation(routing) + ).onResponse(response), + (e) -> { + final Exception exc; + final RestStatus status; + if (ExceptionsHelper.unwrap(e, VersionConflictEngineException.class) != null) { + exc = new IllegalArgumentException( + "Element of type [" + StoredLtrModel.TYPE + "] are not updatable, please create a new one instead." + ); + exc.addSuppressed(e); + status = RestStatus.METHOD_NOT_ALLOWED; + } else { + exc = e; + status = ExceptionsHelper.status(exc); + } + + try { + channel.sendResponse(new BytesRestResponse(channel, status, exc)); + } catch (Exception inner) { + inner.addSuppressed(e); + logger.error("failed to send failure response", inner); + } + } + ) + ); } private static class ParserState { @@ -143,9 +154,7 @@ private static class Model { private static final ObjectParser MODEL_PARSER = new ObjectParser<>("model", Model::new); static { MODEL_PARSER.declareString(Model::setName, new ParseField("name")); - MODEL_PARSER.declareObject(Model::setModel, - StoredLtrModel.LtrModelDefinition::parse, - new ParseField("model")); + MODEL_PARSER.declareObject(Model::setModel, StoredLtrModel.LtrModelDefinition::parse, new ParseField("model")); } String name; diff --git a/src/main/java/com/o19s/es/ltr/rest/RestFeatureManager.java b/src/main/java/com/o19s/es/ltr/rest/RestFeatureManager.java index e79f1b47..dafbc40a 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestFeatureManager.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestFeatureManager.java @@ -16,33 +16,33 @@ package com.o19s.es.ltr.rest; -import com.o19s.es.ltr.action.ClearCachesAction; -import com.o19s.es.ltr.action.FeatureStoreAction; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import org.opensearch.core.action.ActionListener; +import static com.o19s.es.ltr.feature.store.StorableElement.generateId; +import static com.o19s.es.ltr.query.ValidatingLtrQueryBuilder.SUPPORTED_TYPES; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; +import static org.opensearch.core.rest.RestStatus.NOT_FOUND; +import static org.opensearch.core.rest.RestStatus.OK; + +import java.io.IOException; +import java.util.List; + import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.action.RestStatusToXContentListener; import org.opensearch.rest.action.RestToXContentListener; -import java.io.IOException; -import java.util.List; - -import static com.o19s.es.ltr.feature.store.StorableElement.generateId; -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.ES_TYPE; -import static com.o19s.es.ltr.query.ValidatingLtrQueryBuilder.SUPPORTED_TYPES; -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableList; -import static org.opensearch.core.rest.RestStatus.NOT_FOUND; -import static org.opensearch.core.rest.RestStatus.OK; +import com.o19s.es.ltr.action.ClearCachesAction; +import com.o19s.es.ltr.action.FeatureStoreAction; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; public class RestFeatureManager extends FeatureStoreBaseRestHandler { private final String type; @@ -58,7 +58,8 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( + return unmodifiableList( + asList( new Route(RestRequest.Method.PUT, "/_ltr/{store}/_" + this.type + "/{name}"), new Route(RestRequest.Method.PUT, "/_ltr/_" + this.type + "/{name}"), new Route(RestRequest.Method.POST, "/_ltr/{store}/_" + this.type + "/{name}"), @@ -69,7 +70,8 @@ public List routes() { new Route(RestRequest.Method.GET, "/_ltr/_" + this.type + "/{name}"), new Route(RestRequest.Method.HEAD, "/_ltr/{store}/_" + this.type + "/{name}"), new Route(RestRequest.Method.HEAD, "/_ltr/_" + this.type + "/{name}") - )); + ) + ); } @Override @@ -93,34 +95,36 @@ RestChannelConsumer delete(NodeClient client, String type, String indexName, Res String name = request.param("name"); String id = generateId(type, name); String routing = request.param("routing"); - return (channel) -> { + return (channel) -> { RestStatusToXContentListener restR = new RestStatusToXContentListener<>(channel, (r) -> r.getLocation(routing)); - client.prepareDelete(indexName, id) - .setRouting(routing) - .execute(ActionListener.wrap((deleteResponse) -> { - // wrap the response so we can send another request to clear the cache - // usually we send only one transport request from the rest layer - // it's still unclear which direction we should take (thick or thin REST layer?) - ClearCachesAction.ClearCachesNodesRequest clearCache = new ClearCachesAction.ClearCachesNodesRequest(); - switch (type) { - case StoredFeature.TYPE: - clearCache.clearFeature(indexName, name); - break; - case StoredFeatureSet.TYPE: - clearCache.clearFeatureSet(indexName, name); - break; - case StoredLtrModel.TYPE: - clearCache.clearModel(indexName, name); - break; - } - client.execute(ClearCachesAction.INSTANCE, clearCache, ActionListener.wrap( - (r) -> restR.onResponse(deleteResponse), - // Is it good to fail the whole request if cache invalidation failed? - restR::onFailure - )); - }, - restR::onFailure - )); + client.prepareDelete(indexName, id).setRouting(routing).execute(ActionListener.wrap((deleteResponse) -> { + // wrap the response so we can send another request to clear the cache + // usually we send only one transport request from the rest layer + // it's still unclear which direction we should take (thick or thin REST layer?) + ClearCachesAction.ClearCachesNodesRequest clearCache = new ClearCachesAction.ClearCachesNodesRequest(); + switch (type) { + case StoredFeature.TYPE: + clearCache.clearFeature(indexName, name); + break; + case StoredFeatureSet.TYPE: + clearCache.clearFeatureSet(indexName, name); + break; + case StoredLtrModel.TYPE: + clearCache.clearModel(indexName, name); + break; + } + client + .execute( + ClearCachesAction.INSTANCE, + clearCache, + ActionListener + .wrap( + (r) -> restR.onResponse(deleteResponse), + // Is it good to fail the whole request if cache invalidation failed? + restR::onFailure + ) + ); + }, restR::onFailure)); }; } @@ -129,14 +133,12 @@ RestChannelConsumer get(NodeClient client, String type, String indexName, RestRe String name = request.param("name"); String routing = request.param("routing"); String id = generateId(type, name); - return (channel) -> client.prepareGet(indexName, id) - .setRouting(routing) - .execute(new RestToXContentListener(channel) { - @Override - protected RestStatus getStatus(final GetResponse response) { - return response.isExists() ? OK : NOT_FOUND; - } - }); + return (channel) -> client.prepareGet(indexName, id).setRouting(routing).execute(new RestToXContentListener(channel) { + @Override + protected RestStatus getStatus(final GetResponse response) { + return response.isExists() ? OK : NOT_FOUND; + } + }); } RestChannelConsumer addOrUpdate(NodeClient client, String type, String indexName, RestRequest request) throws IOException { @@ -159,14 +161,17 @@ RestChannelConsumer addOrUpdate(NodeClient client, String type, String indexName } if (request.method() == RestRequest.Method.POST && !elt.updatable()) { try { - throw new IllegalArgumentException("Element of type [" + elt.type() + "] are not updatable, " + - "please create a new one instead."); + throw new IllegalArgumentException( + "Element of type [" + elt.type() + "] are not updatable, " + "please create a new one instead." + ); } catch (IllegalArgumentException iae) { return (channel) -> channel.sendResponse(new BytesRestResponse(channel, RestStatus.METHOD_NOT_ALLOWED, iae)); } } FeatureStoreAction.FeatureStoreRequestBuilder builder = new FeatureStoreAction.FeatureStoreRequestBuilder( - client, FeatureStoreAction.INSTANCE); + client, + FeatureStoreAction.INSTANCE + ); if (request.method() == RestRequest.Method.PUT) { builder.request().setAction(FeatureStoreAction.FeatureStoreRequest.Action.CREATE); } else { diff --git a/src/main/java/com/o19s/es/ltr/rest/RestFeatureStoreCaches.java b/src/main/java/com/o19s/es/ltr/rest/RestFeatureStoreCaches.java index a013ab3c..07556983 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestFeatureStoreCaches.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestFeatureStoreCaches.java @@ -16,23 +16,24 @@ package com.o19s.es.ltr.rest; -import org.opensearch.ltr.settings.LTRSettings; -import com.o19s.es.ltr.action.CachesStatsAction; -import com.o19s.es.ltr.action.ClearCachesAction; -import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; +import static org.opensearch.core.rest.RestStatus.OK; + +import java.util.List; + import org.opensearch.client.node.NodeClient; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; import org.opensearch.rest.action.RestActions.NodesResponseRestListener; import org.opensearch.rest.action.RestBuilderListener; -import java.util.List; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableList; -import static org.opensearch.core.rest.RestStatus.OK; +import com.o19s.es.ltr.action.CachesStatsAction; +import com.o19s.es.ltr.action.ClearCachesAction; +import com.o19s.es.ltr.action.ClearCachesAction.ClearCachesNodesResponse; /** * Clear cache (default store): @@ -53,11 +54,13 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( - new Route(RestRequest.Method.POST, "/_ltr/_clearcache"), - new Route(RestRequest.Method.POST, "/_ltr/{store}/_clearcache"), - new Route(RestRequest.Method.GET, "/_ltr/_cachestats") - )); + return unmodifiableList( + asList( + new Route(RestRequest.Method.POST, "/_ltr/_clearcache"), + new Route(RestRequest.Method.POST, "/_ltr/{store}/_clearcache"), + new Route(RestRequest.Method.GET, "/_ltr/_cachestats") + ) + ); } @Override @@ -73,27 +76,25 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } } - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({ "rawtypes", "unchecked" }) private RestChannelConsumer getStats(NodeClient client) { - return (channel) -> client.execute(CachesStatsAction.INSTANCE, new CachesStatsAction.CachesStatsNodesRequest(), - new NodesResponseRestListener(channel)); + return (channel) -> client + .execute(CachesStatsAction.INSTANCE, new CachesStatsAction.CachesStatsNodesRequest(), new NodesResponseRestListener(channel)); } private RestChannelConsumer clearCache(RestRequest request, NodeClient client) { String storeName = indexName(request); ClearCachesAction.ClearCachesNodesRequest cacheRequest = new ClearCachesAction.ClearCachesNodesRequest(); cacheRequest.clearStore(storeName); - return (channel) -> client.execute(ClearCachesAction.INSTANCE, cacheRequest, - new RestBuilderListener(channel) { + return (channel) -> client + .execute(ClearCachesAction.INSTANCE, cacheRequest, new RestBuilderListener(channel) { @Override - public RestResponse buildResponse(ClearCachesNodesResponse clearCachesNodesResponse, - XContentBuilder builder) throws Exception { - builder.startObject() - .field("acknowledged", true); + public RestResponse buildResponse(ClearCachesNodesResponse clearCachesNodesResponse, XContentBuilder builder) + throws Exception { + builder.startObject().field("acknowledged", true); builder.endObject(); return new BytesRestResponse(OK, builder); } - } - ); + }); } } diff --git a/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java b/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java index 939abb17..87c849ea 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestSearchStoreElements.java @@ -15,21 +15,20 @@ */ package com.o19s.es.ltr.rest; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; -import org.opensearch.client.node.NodeClient; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.ltr.settings.LTRSettings; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestStatusToXContentListener; - -import java.util.List; - import static java.util.Arrays.asList; import static java.util.Collections.unmodifiableList; import static org.opensearch.index.query.QueryBuilders.boolQuery; import static org.opensearch.index.query.QueryBuilders.matchQuery; import static org.opensearch.index.query.QueryBuilders.termQuery; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.ltr.settings.LTRSettings; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestStatusToXContentListener; + public class RestSearchStoreElements extends FeatureStoreBaseRestHandler { private final String type; @@ -44,10 +43,9 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( - new Route(RestRequest.Method.GET, "/_ltr/{store}/_" + type), - new Route(RestRequest.Method.GET, "/_ltr/_" + type) - )); + return unmodifiableList( + asList(new Route(RestRequest.Method.GET, "/_ltr/{store}/_" + type), new Route(RestRequest.Method.GET, "/_ltr/_" + type)) + ); } @Override @@ -67,11 +65,12 @@ RestChannelConsumer search(NodeClient client, String type, String indexName, Res if (prefix != null && !prefix.isEmpty()) { qb.must(matchQuery("name.prefix", prefix)); } - return (channel) -> client.prepareSearch(indexName) - .setQuery(qb) - .setSize(size) - .setFrom(from) - .execute(new RestStatusToXContentListener<>(channel)); + return (channel) -> client + .prepareSearch(indexName) + .setQuery(qb) + .setSize(size) + .setFrom(from) + .execute(new RestStatusToXContentListener<>(channel)); } } diff --git a/src/main/java/com/o19s/es/ltr/rest/RestStoreManager.java b/src/main/java/com/o19s/es/ltr/rest/RestStoreManager.java index 87582f8d..98798712 100644 --- a/src/main/java/com/o19s/es/ltr/rest/RestStoreManager.java +++ b/src/main/java/com/o19s/es/ltr/rest/RestStoreManager.java @@ -15,26 +15,27 @@ */ package com.o19s.es.ltr.rest; -import com.o19s.es.ltr.action.ListStoresAction; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; + +import java.io.IOException; +import java.util.List; + import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.admin.indices.exists.indices.IndicesExistsResponse; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.ltr.settings.LTRSettings; +import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; -import org.opensearch.core.rest.RestStatus; import org.opensearch.rest.action.RestBuilderListener; import org.opensearch.rest.action.RestToXContentListener; -import org.opensearch.rest.BaseRestHandler; -import java.io.IOException; -import java.util.List; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableList; +import com.o19s.es.ltr.action.ListStoresAction; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class RestStoreManager extends FeatureStoreBaseRestHandler { @Override @@ -44,7 +45,8 @@ public String getName() { @Override public List routes() { - return unmodifiableList(asList( + return unmodifiableList( + asList( new Route(RestRequest.Method.PUT, "/_ltr/{store}"), new Route(RestRequest.Method.PUT, "/_ltr"), new Route(RestRequest.Method.POST, "/_ltr/{store}"), @@ -53,7 +55,8 @@ public List routes() { new Route(RestRequest.Method.DELETE, "/_ltr"), new Route(RestRequest.Method.GET, "/_ltr"), new Route(RestRequest.Method.GET, "/_ltr/{store}") - )); + ) + ); } /** @@ -99,34 +102,28 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } RestChannelConsumer listStores(NodeClient client) { - return (channel) -> new ListStoresAction.ListStoresActionBuilder(client).execute( - new RestToXContentListener<>(channel) - ); + return (channel) -> new ListStoresAction.ListStoresActionBuilder(client).execute(new RestToXContentListener<>(channel)); } RestChannelConsumer getStore(NodeClient client, String indexName) { - return (channel) -> client.admin().indices().prepareExists(indexName) - .execute(new RestBuilderListener(channel) { - @Override - public RestResponse buildResponse( - IndicesExistsResponse indicesExistsResponse, - XContentBuilder builder - ) throws Exception { - builder.startObject() - .field("exists", indicesExistsResponse.isExists()) - .endObject() - .close(); - return new BytesRestResponse( - indicesExistsResponse.isExists() ? RestStatus.OK : RestStatus.NOT_FOUND, - builder - ); - } - }); + return (channel) -> client + .admin() + .indices() + .prepareExists(indexName) + .execute(new RestBuilderListener(channel) { + @Override + public RestResponse buildResponse(IndicesExistsResponse indicesExistsResponse, XContentBuilder builder) throws Exception { + builder.startObject().field("exists", indicesExistsResponse.isExists()).endObject().close(); + return new BytesRestResponse(indicesExistsResponse.isExists() ? RestStatus.OK : RestStatus.NOT_FOUND, builder); + } + }); } RestChannelConsumer createIndex(NodeClient client, String indexName) { - return (channel) -> client.admin().indices() - .create(IndexFeatureStore.buildIndexRequest(indexName), new RestToXContentListener<>(channel)); + return (channel) -> client + .admin() + .indices() + .create(IndexFeatureStore.buildIndexRequest(indexName), new RestToXContentListener<>(channel)); } RestChannelConsumer deleteIndex(NodeClient client, String indexName) { diff --git a/src/main/java/com/o19s/es/ltr/utils/AbstractQueryBuilderUtils.java b/src/main/java/com/o19s/es/ltr/utils/AbstractQueryBuilderUtils.java index b92242c5..d0726e08 100644 --- a/src/main/java/com/o19s/es/ltr/utils/AbstractQueryBuilderUtils.java +++ b/src/main/java/com/o19s/es/ltr/utils/AbstractQueryBuilderUtils.java @@ -15,14 +15,14 @@ */ package com.o19s.es.ltr.utils; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.index.query.QueryBuilder; - import java.io.IOException; import java.util.ArrayList; import java.util.List; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.index.query.QueryBuilder; + /** * Contains a few methods copied from the AbstractQueryBuilder class. These methods are not accessible from sub classes * that do not reside in the same package. @@ -49,5 +49,4 @@ public static List readQueries(StreamInput in) throws IOException return queries; } - } diff --git a/src/main/java/com/o19s/es/ltr/utils/FeatureStoreLoader.java b/src/main/java/com/o19s/es/ltr/utils/FeatureStoreLoader.java index 11da3dc5..87e50c0f 100644 --- a/src/main/java/com/o19s/es/ltr/utils/FeatureStoreLoader.java +++ b/src/main/java/com/o19s/es/ltr/utils/FeatureStoreLoader.java @@ -16,10 +16,12 @@ package com.o19s.es.ltr.utils; -import com.o19s.es.ltr.feature.store.FeatureStore; import java.util.function.Supplier; + import org.opensearch.client.Client; +import com.o19s.es.ltr.feature.store.FeatureStore; + @FunctionalInterface public interface FeatureStoreLoader { FeatureStore load(String storeName, Supplier clientSupplier); diff --git a/src/main/java/com/o19s/es/ltr/utils/Scripting.java b/src/main/java/com/o19s/es/ltr/utils/Scripting.java index 4454ec4c..f1c0c206 100644 --- a/src/main/java/com/o19s/es/ltr/utils/Scripting.java +++ b/src/main/java/com/o19s/es/ltr/utils/Scripting.java @@ -15,12 +15,6 @@ */ package com.o19s.es.ltr.utils; -import org.apache.lucene.expressions.Expression; -import org.apache.lucene.expressions.js.JavascriptCompiler; -import org.opensearch.SpecialPermission; -import org.opensearch.script.ClassPermission; -import org.opensearch.script.ScriptException; - import java.security.AccessControlContext; import java.security.AccessController; import java.security.PrivilegedAction; @@ -29,6 +23,12 @@ import java.util.List; import java.util.Map; +import org.apache.lucene.expressions.Expression; +import org.apache.lucene.expressions.js.JavascriptCompiler; +import org.opensearch.SpecialPermission; +import org.opensearch.script.ClassPermission; +import org.opensearch.script.ScriptException; + public class Scripting { private Scripting() {} @@ -36,7 +36,7 @@ public static Object compile(String scriptSource) { return compile(scriptSource, JavascriptCompiler.DEFAULT_FUNCTIONS); } - public static Object compile(String scriptSource, Map functions) { + public static Object compile(String scriptSource, Map functions) { // classloader created here final SecurityManager sm = System.getSecurityManager(); if (sm != null) { diff --git a/src/main/java/com/o19s/es/template/mustache/CustomMustacheFactory.java b/src/main/java/com/o19s/es/template/mustache/CustomMustacheFactory.java index feb871c6..60e5f07e 100644 --- a/src/main/java/com/o19s/es/template/mustache/CustomMustacheFactory.java +++ b/src/main/java/com/o19s/es/template/mustache/CustomMustacheFactory.java @@ -16,21 +16,6 @@ package com.o19s.es.template.mustache; -import com.fasterxml.jackson.core.io.JsonStringEncoder; -import com.github.mustachejava.Code; -import com.github.mustachejava.DefaultMustacheFactory; -import com.github.mustachejava.DefaultMustacheVisitor; -import com.github.mustachejava.Mustache; -import com.github.mustachejava.MustacheException; -import com.github.mustachejava.MustacheVisitor; -import com.github.mustachejava.TemplateContext; -import com.github.mustachejava.codes.DefaultMustache; -import com.github.mustachejava.codes.IterableCode; -import com.github.mustachejava.codes.WriteCode; -import org.opensearch.core.common.Strings; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentType; - import java.io.IOException; import java.io.StringWriter; import java.io.Writer; @@ -47,6 +32,21 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.XContentBuilder; + +import com.fasterxml.jackson.core.io.JsonStringEncoder; +import com.github.mustachejava.Code; +import com.github.mustachejava.DefaultMustacheFactory; +import com.github.mustachejava.DefaultMustacheVisitor; +import com.github.mustachejava.Mustache; +import com.github.mustachejava.MustacheException; +import com.github.mustachejava.MustacheVisitor; +import com.github.mustachejava.TemplateContext; +import com.github.mustachejava.codes.DefaultMustache; +import com.github.mustachejava.codes.IterableCode; +import com.github.mustachejava.codes.WriteCode; + /** * XXX: shamelessly copied from the mustache module */ diff --git a/src/main/java/com/o19s/es/template/mustache/CustomReflectionObjectHandler.java b/src/main/java/com/o19s/es/template/mustache/CustomReflectionObjectHandler.java index 1eeb3605..641564de 100644 --- a/src/main/java/com/o19s/es/template/mustache/CustomReflectionObjectHandler.java +++ b/src/main/java/com/o19s/es/template/mustache/CustomReflectionObjectHandler.java @@ -16,9 +16,6 @@ package com.o19s.es.template.mustache; -import com.github.mustachejava.reflect.ReflectionObjectHandler; -import org.opensearch.common.util.iterable.Iterables; - import java.lang.reflect.Array; import java.util.AbstractMap; import java.util.Collection; @@ -27,6 +24,10 @@ import java.util.Map; import java.util.Set; +import org.opensearch.common.util.iterable.Iterables; + +import com.github.mustachejava.reflect.ReflectionObjectHandler; + /** * XXX: shamelessly copied from the mustache module */ diff --git a/src/main/java/com/o19s/es/template/mustache/MustacheUtils.java b/src/main/java/com/o19s/es/template/mustache/MustacheUtils.java index dc334840..ab0d7888 100644 --- a/src/main/java/com/o19s/es/template/mustache/MustacheUtils.java +++ b/src/main/java/com/o19s/es/template/mustache/MustacheUtils.java @@ -16,21 +16,21 @@ package com.o19s.es.template.mustache; -import com.github.mustachejava.Mustache; -import com.github.mustachejava.MustacheException; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.logging.log4j.util.Supplier; -import org.opensearch.SpecialPermission; -import org.apache.logging.log4j.LogManager; - - import java.io.StringReader; import java.io.StringWriter; import java.security.AccessController; import java.security.PrivilegedAction; import java.util.Map; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.logging.log4j.util.Supplier; +import org.opensearch.SpecialPermission; + +import com.github.mustachejava.Mustache; +import com.github.mustachejava.MustacheException; + public class MustacheUtils { public static final String TEMPLATE_LANGUAGE = "mustache"; private static final Logger logger = LogManager.getLogger(MustacheUtils.class); diff --git a/src/main/java/com/o19s/es/termstat/TermStatQuery.java b/src/main/java/com/o19s/es/termstat/TermStatQuery.java index aada0fe3..2e0d056e 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatQuery.java +++ b/src/main/java/com/o19s/es/termstat/TermStatQuery.java @@ -16,28 +16,29 @@ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + import org.apache.lucene.expressions.Expression; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; import org.opensearch.ltr.settings.LTRSettings; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import com.o19s.es.explore.StatisticsHelper; +import com.o19s.es.explore.StatisticsHelper.AggrType; public class TermStatQuery extends Query { private Expression expr; @@ -52,26 +53,33 @@ public TermStatQuery(Expression expr, AggrType aggr, AggrType posAggr, Set this.terms = terms; } - public Expression getExpr() { return this.expr; } - public AggrType getAggr() { return this.aggr; } - public AggrType getPosAggr() { return this.posAggr; } - public Set getTerms() { return this.terms; } + + public AggrType getAggr() { + return this.aggr; + } + + public AggrType getPosAggr() { + return this.posAggr; + } + + public Set getTerms() { + return this.terms; + } @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") @Override public boolean equals(Object other) { - return sameClassAs(other) && - equalsTo(getClass().cast(other)); + return sameClassAs(other) && equalsTo(getClass().cast(other)); } private boolean equalsTo(TermStatQuery other) { return Objects.equals(expr.sourceText, other.expr.sourceText) - && Objects.equals(aggr, other.aggr) - && Objects.equals(posAggr, other.posAggr) - && Objects.equals(terms, other.terms); + && Objects.equals(aggr, other.aggr) + && Objects.equals(posAggr, other.posAggr) + && Objects.equals(terms, other.terms); } @Override @@ -80,7 +88,9 @@ public Query rewrite(IndexReader reader) throws IOException { } @Override - public int hashCode() { return Objects.hash(expr.sourceText, aggr, posAggr, terms); } + public int hashCode() { + return Objects.hash(expr.sourceText, aggr, posAggr, terms); + } @Override public String toString(String field) { @@ -88,8 +98,7 @@ public String toString(String field) { } @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { if (!LTRSettings.isLTRPluginEnabled()) { throw new IllegalStateException("LTR plugin is disabled. To enable, update ltr.plugin.enabled to true"); } @@ -109,12 +118,8 @@ static class TermStatWeight extends Weight { private final Set terms; private final Map termContexts; - TermStatWeight(IndexSearcher searcher, - TermStatQuery tsq, - Set terms, - ScoreMode scoreMode, - AggrType aggr, - AggrType posAggr) throws IOException { + TermStatWeight(IndexSearcher searcher, TermStatQuery tsq, Set terms, ScoreMode scoreMode, AggrType aggr, AggrType posAggr) + throws IOException { super(tsq); this.searcher = searcher; this.expression = tsq.expr; @@ -148,8 +153,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio Scorer scorer = this.scorer(context); int newDoc = scorer.iterator().advance(doc); if (newDoc == doc) { - return Explanation - .match(scorer.score(), "weight(" + this.expression.sourceText + " in doc " + newDoc + ")"); + return Explanation.match(scorer.score(), "weight(" + this.expression.sourceText + " in doc " + newDoc + ")"); } return Explanation.noMatch("no matching term"); } @@ -167,9 +171,7 @@ public boolean isCacheable(LeafReaderContext ctx) { @Override public void visit(QueryVisitor visitor) { - Term[] acceptedTerms = terms.stream().filter( - t -> visitor.acceptField(t.field()) - ).toArray(Term[]::new); + Term[] acceptedTerms = terms.stream().filter(t -> visitor.acceptField(t.field())).toArray(Term[]::new); if (acceptedTerms.length > 0) { QueryVisitor v = visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this); diff --git a/src/main/java/com/o19s/es/termstat/TermStatQueryBuilder.java b/src/main/java/com/o19s/es/termstat/TermStatQueryBuilder.java index d5ae474f..80732543 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatQueryBuilder.java +++ b/src/main/java/com/o19s/es/termstat/TermStatQueryBuilder.java @@ -15,9 +15,14 @@ */ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Set; -import com.o19s.es.ltr.utils.Scripting; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute; @@ -36,13 +41,8 @@ import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryShardContext; -import java.io.IOException; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Objects; -import java.util.Set; +import com.o19s.es.explore.StatisticsHelper.AggrType; +import com.o19s.es.ltr.utils.Scripting; public class TermStatQueryBuilder extends AbstractQueryBuilder implements NamedWriteable { public static final String NAME = "term_stat"; @@ -75,9 +75,7 @@ public class TermStatQueryBuilder extends AbstractQueryBuilder text) { - this.fields = text.toArray(new String[]{}); + this.fields = text.toArray(new String[] {}); return this; } - public String posAggr() { return pos_aggr; } + public TermStatQueryBuilder posAggr(String pos_aggr) { this.pos_aggr = pos_aggr; return this; } - public String[] terms() { return terms; } + public String[] terms() { + return terms; + } + public TermStatQueryBuilder terms(String[] terms) { this.terms = terms; return this; } + public TermStatQueryBuilder terms(List terms) { - this.terms = terms.toArray(new String[]{}); + this.terms = terms.toArray(new String[] {}); return this; } } diff --git a/src/main/java/com/o19s/es/termstat/TermStatScorer.java b/src/main/java/com/o19s/es/termstat/TermStatScorer.java index 8382ba71..a7379439 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatScorer.java +++ b/src/main/java/com/o19s/es/termstat/TermStatScorer.java @@ -15,13 +15,15 @@ */ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + import org.apache.lucene.expressions.Bindings; import org.apache.lucene.expressions.Expression; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; - import org.apache.lucene.index.TermStates; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; @@ -30,10 +32,8 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; +import com.o19s.es.explore.StatisticsHelper; +import com.o19s.es.explore.StatisticsHelper.AggrType; public class TermStatScorer extends Scorer { private final DocIdSetIterator iter; @@ -48,15 +48,17 @@ public class TermStatScorer extends Scorer { private final ScoreMode scoreMode; private final Map termContexts; - public TermStatScorer(TermStatQuery.TermStatWeight weight, - IndexSearcher searcher, - LeafReaderContext context, - Expression compiledExpression, - Set terms, - ScoreMode scoreMode, - AggrType aggr, - AggrType posAggr, - Map termContexts) { + public TermStatScorer( + TermStatQuery.TermStatWeight weight, + IndexSearcher searcher, + LeafReaderContext context, + Expression compiledExpression, + Set terms, + ScoreMode scoreMode, + AggrType aggr, + AggrType posAggr, + Map termContexts + ) { super(weight); this.context = context; this.compiledExpression = compiledExpression; @@ -69,6 +71,7 @@ public TermStatScorer(TermStatQuery.TermStatWeight weight, this.iter = DocIdSetIterator.all(context.reader().maxDoc()); } + @Override public DocIdSetIterator iterator() { return iter; @@ -90,7 +93,7 @@ public float score() throws IOException { // Prepare computed statistics StatisticsHelper computed = new StatisticsHelper(); HashMap termStatDict = new HashMap<>(); - Bindings bindings = new Bindings(){ + Bindings bindings = new Bindings() { @Override public DoubleValuesSource getDoubleValuesSource(String name) { return DoubleValuesSource.constant(termStatDict.get(name)); @@ -102,7 +105,7 @@ public DoubleValuesSource getDoubleValuesSource(String name) { return 0.0f; } - for(int i = 0; i < tsq.size(); i++) { + for (int i = 0; i < tsq.size(); i++) { // Update the term stat dictionary for the current term termStatDict.put("df", tsq.get("df").get(i)); termStatDict.put("idf", tsq.get("idf").get(i)); diff --git a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java index 2d3d29f2..1ce87b3e 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java +++ b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java @@ -15,34 +15,34 @@ */ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import java.io.IOException; +import java.util.AbstractMap; +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; import org.apache.lucene.index.Term; -import org.apache.lucene.index.TermsEnum; import org.apache.lucene.index.TermState; import org.apache.lucene.index.TermStates; +import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.ClassicSimilarity; -import java.io.IOException; -import java.util.AbstractMap; -import java.util.AbstractSet; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Set; - +import com.o19s.es.explore.StatisticsHelper; +import com.o19s.es.explore.StatisticsHelper.AggrType; -public class TermStatSupplier extends AbstractMap> { - private final List ACCEPTED_KEYS = Arrays.asList(new String[]{"df", "idf", "tf", "ttf", "tp"}); +public class TermStatSupplier extends AbstractMap> { + private final List ACCEPTED_KEYS = Arrays.asList(new String[] { "df", "idf", "tf", "ttf", "tp" }); private AggrType posAggrType = AggrType.AVG; private final ClassicSimilarity sim; @@ -59,9 +59,14 @@ public TermStatSupplier() { this.tp_stats = new StatisticsHelper(); } - public void bump (IndexSearcher searcher, LeafReaderContext context, - int docID, Set terms, - ScoreMode scoreMode, Map termContexts) throws IOException { + public void bump( + IndexSearcher searcher, + LeafReaderContext context, + int docID, + Set terms, + ScoreMode scoreMode, + Map termContexts + ) throws IOException { df_stats.getData().clear(); idf_stats.getData().clear(); tf_stats.getData().clear(); @@ -77,8 +82,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, TermStates termStates = termContexts.get(term); - assert termStates != null && termStates - .wasBuiltFor(ReaderUtil.getTopLevelContext(context)); + assert termStates != null && termStates.wasBuiltFor(ReaderUtil.getTopLevelContext(context)); TermState state = termStates.get(context); @@ -100,12 +104,12 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, postingsEnum = termsEnum.postings(postingsEnum, PostingsEnum.ALL); // Verify document is in postings - if (postingsEnum.advance(docID) == docID){ + if (postingsEnum.advance(docID) == docID) { matchedTermCount++; tf_stats.add(postingsEnum.freq()); - if(postingsEnum.freq() > 0) { + if (postingsEnum.freq() > 0) { StatisticsHelper positions = new StatisticsHelper(); for (int i = 0; i < postingsEnum.freq(); i++) { positions.add((float) postingsEnum.nextPosition() + 1); @@ -115,7 +119,7 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, } else { tp_stats.add(0.0f); } - // If document isn't in postings default to 0 for tf/tp + // If document isn't in postings default to 0 for tf/tp } else { tf_stats.add(0.0f); tp_stats.add(0.0f); @@ -150,7 +154,7 @@ public boolean containsKey(Object statType) { public ArrayList get(Object statType) { String key = (String) statType; - switch(key) { + switch (key) { case "df": return df_stats.getData(); @@ -236,4 +240,3 @@ private void insertZeroes() { tp_stats.add(0.0f); } } - diff --git a/src/main/java/org/opensearch/ltr/breaker/LTRCircuitBreakerService.java b/src/main/java/org/opensearch/ltr/breaker/LTRCircuitBreakerService.java index 84e4b7e9..9621a218 100644 --- a/src/main/java/org/opensearch/ltr/breaker/LTRCircuitBreakerService.java +++ b/src/main/java/org/opensearch/ltr/breaker/LTRCircuitBreakerService.java @@ -15,14 +15,14 @@ package org.opensearch.ltr.breaker; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ltr.settings.LTRSettings; import org.opensearch.monitor.jvm.JvmService; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; - /** * Class {@code LTRCircuitBreakerService} provide storing, retrieving circuit breakers functions. * diff --git a/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java b/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java index 43f35b8a..db4644bf 100644 --- a/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java +++ b/src/main/java/org/opensearch/ltr/rest/RestStatsLTRAction.java @@ -15,13 +15,8 @@ package org.opensearch.ltr.rest; -import org.opensearch.client.node.NodeClient; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.transport.LTRStatsAction; -import org.opensearch.ltr.transport.LTRStatsRequest; -import org.opensearch.rest.BaseRestHandler; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestActions; +import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_BASE_URI; +import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_LEGACY_BASE_URI; import java.io.IOException; import java.util.HashSet; @@ -30,8 +25,13 @@ import java.util.Set; import java.util.stream.Collectors; -import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_BASE_URI; -import static com.o19s.es.ltr.LtrQueryParserPlugin.LTR_LEGACY_BASE_URI; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.transport.LTRStatsAction; +import org.opensearch.ltr.transport.LTRStatsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; /** * Provide an API to get information on the plugin usage and @@ -62,41 +62,40 @@ public List routes() { @Override public List replacedRoutes() { - return List.of( + return List + .of( new ReplacedRoute( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/"), - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/") + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/") ), new ReplacedRoute( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/{stat}"), - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/{stat}") + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/{nodeId}/stats/{stat}"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/{nodeId}/stats/{stat}") ), new ReplacedRoute( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/"), - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/") + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/") ), new ReplacedRoute( - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/{stat}"), - RestRequest.Method.GET, - String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/{stat}") + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_BASE_URI, "/stats/{stat}"), + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s%s", LTR_LEGACY_BASE_URI, "/stats/{stat}") ) - ); + ); } @Override - @SuppressWarnings({"rawtypes", "unchecked"}) + @SuppressWarnings({ "rawtypes", "unchecked" }) protected RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException { final LTRStatsRequest ltrStatsRequest = getRequest(request); - return (channel) -> client.execute(LTRStatsAction.INSTANCE, - ltrStatsRequest, - new RestActions.NodesResponseRestListener(channel)); + return (channel) -> client.execute(LTRStatsAction.INSTANCE, ltrStatsRequest, new RestActions.NodesResponseRestListener(channel)); } /** @@ -106,9 +105,7 @@ protected RestChannelConsumer prepareRequest(final RestRequest request, final No * @return LTRStatsRequest */ private LTRStatsRequest getRequest(final RestRequest request) { - final LTRStatsRequest ltrStatsRequest = new LTRStatsRequest( - splitCommaSeparatedParam(request, "nodeId") - ); + final LTRStatsRequest ltrStatsRequest = new LTRStatsRequest(splitCommaSeparatedParam(request, "nodeId")); ltrStatsRequest.timeout(request.param("timeout")); final List requestedStats = List.of(splitCommaSeparatedParam(request, "stat")); @@ -125,31 +122,26 @@ private LTRStatsRequest getRequest(final RestRequest request) { return ltrStatsRequest; } - private Set getStatsToBeRetrieved( - final RestRequest request, - final Set validStats, - final List requestedStats) { + private Set getStatsToBeRetrieved(final RestRequest request, final Set validStats, final List requestedStats) { if (requestedStats.contains(LTRStatsRequest.ALL_STATS_KEY)) { - throw new IllegalArgumentException(String.format(Locale.ROOT, "Request %s contains both %s and individual stats", - request.path(), LTRStatsRequest.ALL_STATS_KEY)); + throw new IllegalArgumentException( + String + .format(Locale.ROOT, "Request %s contains both %s and individual stats", request.path(), LTRStatsRequest.ALL_STATS_KEY) + ); } - final Set invalidStats = requestedStats.stream() - .filter(s -> !validStats.contains(s)) - .collect(Collectors.toSet()); + final Set invalidStats = requestedStats.stream().filter(s -> !validStats.contains(s)).collect(Collectors.toSet()); if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException( - unrecognized(request, invalidStats, new HashSet<>(requestedStats), "stat")); + throw new IllegalArgumentException(unrecognized(request, invalidStats, new HashSet<>(requestedStats), "stat")); } return new HashSet<>(requestedStats); } private boolean isAllStatsRequested(final List requestedStats) { - return requestedStats.isEmpty() - || (requestedStats.size() == 1 && requestedStats.contains(LTRStatsRequest.ALL_STATS_KEY)); + return requestedStats.isEmpty() || (requestedStats.size() == 1 && requestedStats.contains(LTRStatsRequest.ALL_STATS_KEY)); } private String[] splitCommaSeparatedParam(final RestRequest request, final String paramName) { diff --git a/src/main/java/org/opensearch/ltr/settings/LTRSettings.java b/src/main/java/org/opensearch/ltr/settings/LTRSettings.java index 3c1616f9..7305aed8 100644 --- a/src/main/java/org/opensearch/ltr/settings/LTRSettings.java +++ b/src/main/java/org/opensearch/ltr/settings/LTRSettings.java @@ -15,11 +15,9 @@ package org.opensearch.ltr.settings; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; +import static java.util.Collections.unmodifiableMap; +import static org.opensearch.common.settings.Setting.Property.Dynamic; +import static org.opensearch.common.settings.Setting.Property.NodeScope; import java.util.ArrayList; import java.util.HashMap; @@ -27,9 +25,11 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import static java.util.Collections.unmodifiableMap; -import static org.opensearch.common.settings.Setting.Property.Dynamic; -import static org.opensearch.common.settings.Setting.Property.NodeScope; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; public class LTRSettings { @@ -76,12 +76,10 @@ public static synchronized LTRSettings getInstance() { private void setSettingsUpdateConsumers() { for (Setting setting : settings.values()) { - clusterService.getClusterSettings().addSettingsUpdateConsumer( - setting, - newVal -> { - logger.info("[LTR] The value of setting [{}] changed to [{}]", setting.getKey(), newVal); - latestSettings.put(setting.getKey(), newVal); - }); + clusterService.getClusterSettings().addSettingsUpdateConsumer(setting, newVal -> { + logger.info("[LTR] The value of setting [{}] changed to [{}]", setting.getKey(), newVal); + latestSettings.put(setting.getKey(), newVal); + }); } } diff --git a/src/main/java/org/opensearch/ltr/stats/LTRStat.java b/src/main/java/org/opensearch/ltr/stats/LTRStat.java index 49e8577f..d05ff264 100644 --- a/src/main/java/org/opensearch/ltr/stats/LTRStat.java +++ b/src/main/java/org/opensearch/ltr/stats/LTRStat.java @@ -15,10 +15,10 @@ package org.opensearch.ltr.stats; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; - import java.util.function.Supplier; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; + /** * Class represents a stat the plugin keeps track of */ @@ -60,8 +60,7 @@ public T getValue() { */ public void increment() { if (!(supplier instanceof CounterSupplier)) { - throw new UnsupportedOperationException( - "cannot increment the supplier: " + supplier.getClass().getName()); + throw new UnsupportedOperationException("cannot increment the supplier: " + supplier.getClass().getName()); } ((CounterSupplier) supplier).increment(); } diff --git a/src/main/java/org/opensearch/ltr/stats/LTRStats.java b/src/main/java/org/opensearch/ltr/stats/LTRStats.java index 2d444923..d03717c0 100644 --- a/src/main/java/org/opensearch/ltr/stats/LTRStats.java +++ b/src/main/java/org/opensearch/ltr/stats/LTRStats.java @@ -73,7 +73,6 @@ public void addStats(String key, LTRStat stat) { this.stats.put(key, stat); } - /** * Get a map of the stats that are kept at the node level * @@ -93,9 +92,10 @@ public Map> getClusterStats() { } private Map> getClusterOrNodeStats(Boolean isClusterStats) { - return stats.entrySet() - .stream() - .filter(e -> e.getValue().isClusterLevel() == isClusterStats) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + return stats + .entrySet() + .stream() + .filter(e -> e.getValue().isClusterLevel() == isClusterStats) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } } diff --git a/src/main/java/org/opensearch/ltr/stats/StatName.java b/src/main/java/org/opensearch/ltr/stats/StatName.java index b5a51e90..1c5b0516 100644 --- a/src/main/java/org/opensearch/ltr/stats/StatName.java +++ b/src/main/java/org/opensearch/ltr/stats/StatName.java @@ -37,8 +37,6 @@ public String getName() { } public static Set getNames() { - return Arrays.stream(StatName.values()) - .map(StatName::getName) - .collect(Collectors.toSet()); + return Arrays.stream(StatName.values()).map(StatName::getName).collect(Collectors.toSet()); } } diff --git a/src/main/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplier.java b/src/main/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplier.java index 7c334185..3f0a52b1 100644 --- a/src/main/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplier.java +++ b/src/main/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplier.java @@ -15,14 +15,15 @@ package org.opensearch.ltr.stats.suppliers; -import com.o19s.es.ltr.feature.store.index.Caches; -import org.opensearch.common.cache.Cache; - import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.function.Supplier; +import org.opensearch.common.cache.Cache; + +import com.o19s.es.ltr.feature.store.index.Caches; + /** * Aggregate stats on the cache used by the plugin per node. */ diff --git a/src/main/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplier.java b/src/main/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplier.java index 521db9f2..7e66cc10 100644 --- a/src/main/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplier.java +++ b/src/main/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplier.java @@ -15,14 +15,14 @@ package org.opensearch.ltr.stats.suppliers; +import java.util.List; +import java.util.function.Supplier; + import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.ltr.breaker.LTRCircuitBreakerService; import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; -import java.util.List; -import java.util.function.Supplier; - /** * Supplier for an overall plugin health status, which is based on the * aggregate store health and the circuit breaker state. @@ -35,8 +35,7 @@ public class PluginHealthStatusSupplier implements Supplier { private final StoreUtils storeUtils; private final LTRCircuitBreakerService ltrCircuitBreakerService; - protected PluginHealthStatusSupplier(StoreUtils storeUtils, - LTRCircuitBreakerService ltrCircuitBreakerService) { + protected PluginHealthStatusSupplier(StoreUtils storeUtils, LTRCircuitBreakerService ltrCircuitBreakerService) { this.storeUtils = storeUtils; this.ltrCircuitBreakerService = ltrCircuitBreakerService; } @@ -51,9 +50,7 @@ public String get() { private String getAggregateStoresStatus() { List storeNames = storeUtils.getAllLtrStoreNames(); - return storeNames.stream() - .map(storeUtils::getLtrStoreHealthStatus) - .reduce(STATUS_GREEN, this::combineStatuses); + return storeNames.stream().map(storeUtils::getLtrStoreHealthStatus).reduce(STATUS_GREEN, this::combineStatuses); } private String combineStatuses(String status1, String status2) { @@ -67,9 +64,10 @@ private String combineStatuses(String status1, String status2) { } public static PluginHealthStatusSupplier create( - final Client client, - final ClusterService clusterService, - LTRCircuitBreakerService ltrCircuitBreakerService) { + final Client client, + final ClusterService clusterService, + LTRCircuitBreakerService ltrCircuitBreakerService + ) { return new PluginHealthStatusSupplier(new StoreUtils(client, clusterService), ltrCircuitBreakerService); } } diff --git a/src/main/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplier.java b/src/main/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplier.java index d419a223..2eaa4dd6 100644 --- a/src/main/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplier.java +++ b/src/main/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplier.java @@ -15,16 +15,16 @@ package org.opensearch.ltr.stats.suppliers; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; - import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; + /** * A supplier to provide stats on the LTR stores. It retrieves basic information * on the store, such as the health of the underlying index and number of documents diff --git a/src/main/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtils.java b/src/main/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtils.java index ff77de25..c3163d2b 100644 --- a/src/main/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtils.java +++ b/src/main/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtils.java @@ -15,9 +15,13 @@ package org.opensearch.ltr.stats.suppliers.utils; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + import org.opensearch.action.admin.cluster.state.ClusterStateRequest; import org.opensearch.action.search.SearchType; import org.opensearch.client.Client; @@ -30,12 +34,9 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; /** * A utility class to provide details on the LTR stores. It queries the underlying @@ -61,9 +62,11 @@ public boolean checkLtrStoreExists(String storeName) { } public List getAllLtrStoreNames() { - String[] names = indexNameExpressionResolver.concreteIndexNames(clusterService.state(), - new ClusterStateRequest().indices( - IndexFeatureStore.DEFAULT_STORE, IndexFeatureStore.STORE_PREFIX + "*")); + String[] names = indexNameExpressionResolver + .concreteIndexNames( + clusterService.state(), + new ClusterStateRequest().indices(IndexFeatureStore.DEFAULT_STORE, IndexFeatureStore.STORE_PREFIX + "*") + ); return Arrays.asList(names); } @@ -72,8 +75,8 @@ public String getLtrStoreHealthStatus(String storeName) { throw new IndexNotFoundException(storeName); } ClusterIndexHealth indexHealth = new ClusterIndexHealth( - clusterService.state().metadata().index(storeName), - clusterService.state().getRoutingTable().index(storeName) + clusterService.state().metadata().index(storeName), + clusterService.state().getRoutingTable().index(storeName) ); return indexHealth.getStatus().name().toLowerCase(Locale.getDefault()); @@ -118,10 +121,11 @@ public long getModelCount(String storeName) { } private SearchHits searchStore(String storeName, String docType) { - return client.prepareSearch(storeName) - .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) - .setQuery(QueryBuilders.termQuery("type", docType)) - .get() - .getHits(); + return client + .prepareSearch(storeName) + .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) + .setQuery(QueryBuilders.termQuery("type", docType)) + .get() + .getHits(); } } diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java index 1dc280a6..49b77a54 100644 --- a/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsAction.java @@ -27,8 +27,7 @@ protected LTRStatsAction() { super(NAME, LTRStatsNodesResponse::new); } - public static class LTRStatsRequestBuilder - extends ActionRequestBuilder { + public static class LTRStatsRequestBuilder extends ActionRequestBuilder { private static final String[] nodeIds = null; protected LTRStatsRequestBuilder(final OpenSearchClient client) { diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java index d66f2e15..f7a9e323 100644 --- a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeRequest.java @@ -15,12 +15,11 @@ package org.opensearch.ltr.transport; +import java.io.IOException; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.transport.TransportRequest; - -import java.io.IOException; /** * LTRStatsNodeRequest to get a node stat @@ -46,4 +45,4 @@ public void writeTo(final StreamOutput out) throws IOException { super.writeTo(out); ltrStatsRequest.writeTo(out); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java index afea1796..5b06dc18 100644 --- a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodeResponse.java @@ -15,6 +15,9 @@ package org.opensearch.ltr.transport; +import java.io.IOException; +import java.util.Map; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; @@ -22,9 +25,6 @@ import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; - public class LTRStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { private final Map statsMap; diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java index db7c3570..b59604a3 100644 --- a/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsNodesResponse.java @@ -15,6 +15,10 @@ package org.opensearch.ltr.transport; +import java.io.IOException; +import java.util.List; +import java.util.Map; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -23,10 +27,6 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; -import java.util.Map; - public class LTRStatsNodesResponse extends BaseNodesResponse implements ToXContent { private static final String NODES_KEY = "nodes"; private final Map clusterStats; @@ -37,9 +37,11 @@ public LTRStatsNodesResponse(final StreamInput in) throws IOException { } public LTRStatsNodesResponse( - final ClusterName clusterName, - final List nodeResponses, - final List failures, Map clusterStats) { + final ClusterName clusterName, + final List nodeResponses, + final List failures, + Map clusterStats + ) { super(clusterName, nodeResponses, failures); this.clusterStats = clusterStats; } diff --git a/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java b/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java index f618a296..c4ca8386 100644 --- a/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java +++ b/src/main/java/org/opensearch/ltr/transport/LTRStatsRequest.java @@ -15,15 +15,14 @@ package org.opensearch.ltr.transport; +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import java.io.IOException; -import java.util.HashSet; -import java.util.Set; - public class LTRStatsRequest extends BaseNodesRequest { /** @@ -87,4 +86,4 @@ public void writeTo(final StreamOutput out) throws IOException { super.writeTo(out); out.writeStringCollection(statsToBeRetrieved); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java b/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java index 7303c5cc..2bd3877b 100644 --- a/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java +++ b/src/main/java/org/opensearch/ltr/transport/TransportLTRStatsAction.java @@ -15,6 +15,12 @@ package org.opensearch.ltr.transport; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; @@ -25,45 +31,48 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - public class TransportLTRStatsAction extends - TransportNodesAction { + TransportNodesAction { private final LTRStats ltrStats; @Inject public TransportLTRStatsAction( - final ThreadPool threadPool, - final ClusterService clusterService, - final TransportService transportService, - final ActionFilters actionFilters, - final LTRStats ltrStats) { + final ThreadPool threadPool, + final ClusterService clusterService, + final TransportService transportService, + final ActionFilters actionFilters, + final LTRStats ltrStats + ) { - super(LTRStatsAction.NAME, threadPool, clusterService, transportService, - actionFilters, LTRStatsRequest::new, LTRStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, LTRStatsNodeResponse.class); + super( + LTRStatsAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + LTRStatsRequest::new, + LTRStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + LTRStatsNodeResponse.class + ); this.ltrStats = ltrStats; } @Override protected LTRStatsNodesResponse newResponse( - final LTRStatsRequest request, - final List nodeResponses, - final List failures) { + final LTRStatsRequest request, + final List nodeResponses, + final List failures + ) { final Set statsToBeRetrieved = request.getStatsToBeRetrieved(); - final Map clusterStats = ltrStats.getClusterStats() - .entrySet() - .stream() - .filter(e -> statsToBeRetrieved.contains(e.getKey())) - .collect( - Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue()) - ); + final Map clusterStats = ltrStats + .getClusterStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); return new LTRStatsNodesResponse(clusterService.getClusterName(), nodeResponses, failures, clusterStats); } @@ -83,13 +92,12 @@ protected LTRStatsNodeResponse nodeOperation(final LTRStatsNodeRequest request) final LTRStatsRequest ltrStatsRequest = request.getLTRStatsNodesRequest(); final Set statsToBeRetrieved = ltrStatsRequest.getStatsToBeRetrieved(); - final Map statValues = ltrStats.getNodeStats() - .entrySet() - .stream() - .filter(e -> statsToBeRetrieved.contains(e.getKey())) - .collect( - Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue()) - ); + final Map statValues = ltrStats + .getNodeStats() + .entrySet() + .stream() + .filter(e -> statsToBeRetrieved.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getValue())); return new LTRStatsNodeResponse(clusterService.localNode(), statValues); } } diff --git a/src/test/java/com/o19s/es/explore/ExplorerQueryBuilderTests.java b/src/test/java/com/o19s/es/explore/ExplorerQueryBuilderTests.java index 03ecdf18..41010c18 100644 --- a/src/test/java/com/o19s/es/explore/ExplorerQueryBuilderTests.java +++ b/src/test/java/com/o19s/es/explore/ExplorerQueryBuilderTests.java @@ -15,39 +15,43 @@ */ package com.o19s.es.explore; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import com.o19s.es.ltr.LtrQueryParserPlugin; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableMap; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; + import org.apache.lucene.search.Query; import org.opensearch.core.common.ParsingException; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.plugins.Plugin; import org.opensearch.test.AbstractQueryTestCase; import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import java.io.IOException; -import java.util.Collection; -import java.util.HashMap; -import static java.util.Collections.unmodifiableMap; -import static java.util.Arrays.asList; -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.ltr.LtrQueryParserPlugin; public class ExplorerQueryBuilderTests extends AbstractQueryTestCase { // TODO: Remove the TestGeoShapeFieldMapperPlugin once upstream has completed the migration. protected Collection> getPlugins() { return asList(LtrQueryParserPlugin.class, TestGeoShapeFieldMapperPlugin.class); } - private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); + + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); + @Override protected ExplorerQueryBuilder doCreateTestQueryBuilder() { ExplorerQueryBuilder builder = new ExplorerQueryBuilder(); @@ -58,18 +62,18 @@ protected ExplorerQueryBuilder doCreateTestQueryBuilder() { } public void testParse() throws Exception { - String query = " {" + - " \"match_explorer\": {" + - " \"query\": {" + - " \"match\": {" + - " \"title\": \"test\"" + - " }" + - " }," + - " \"type\": \"stddev_raw_tf\"" + - " }" + - "}"; - - ExplorerQueryBuilder builder = (ExplorerQueryBuilder)parseQuery(query); + String query = " {" + + " \"match_explorer\": {" + + " \"query\": {" + + " \"match\": {" + + " \"title\": \"test\"" + + " }" + + " }," + + " \"type\": \"stddev_raw_tf\"" + + " }" + + "}"; + + ExplorerQueryBuilder builder = (ExplorerQueryBuilder) parseQuery(query); assertNotNull(builder.query()); assertEquals(builder.statsType(), "stddev_raw_tf"); @@ -91,25 +95,21 @@ public void testMustRewrite() throws IOException { } public void testMissingQuery() throws Exception { - String query = " {" + - " \"match_explorer\": {" + - " \"type\": \"stddev_raw_tf\"" + - " }" + - "}"; + String query = " {" + " \"match_explorer\": {" + " \"type\": \"stddev_raw_tf\"" + " }" + "}"; expectThrows(ParsingException.class, () -> parseQuery(query)); } public void testMissingType() throws Exception { - String query = " {" + - " \"match_explorer\": {" + - " \"query\": {" + - " \"match\": {" + - " \"title\": \"test\"" + - " }" + - " }" + - " }" + - "}"; + String query = " {" + + " \"match_explorer\": {" + + " \"query\": {" + + " \"match\": {" + + " \"title\": \"test\"" + + " }" + + " }" + + " }" + + "}"; expectThrows(ParsingException.class, () -> parseQuery(query)); } diff --git a/src/test/java/com/o19s/es/explore/ExplorerQueryTests.java b/src/test/java/com/o19s/es/explore/ExplorerQueryTests.java index fad5ad73..cbaf16fd 100644 --- a/src/test/java/com/o19s/es/explore/ExplorerQueryTests.java +++ b/src/test/java/com/o19s/es/explore/ExplorerQueryTests.java @@ -15,10 +15,12 @@ */ package com.o19s.es.explore; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; +import static java.util.Collections.unmodifiableMap; +import static org.hamcrest.Matchers.equalTo; + +import java.util.HashMap; +import java.util.Map; + import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StoredField; @@ -37,15 +39,13 @@ import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.common.lucene.Lucene; import org.junit.After; import org.junit.Before; - -import java.util.HashMap; -import java.util.Map; - -import static java.util.Collections.unmodifiableMap; -import static org.hamcrest.Matchers.equalTo; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; public class ExplorerQueryTests extends LuceneTestCase { private Directory dir; @@ -55,19 +55,18 @@ public class ExplorerQueryTests extends LuceneTestCase { // Some simple documents to index private final String[] docs = new String[] { - "how now brown cow", - "brown is the color of cows", - "brown cow", - "banana cows are yummy", - "dance with monkeys and do not stop to dance", - "break on through to the other side... break on through to the other side... break on through to the other side" - }; + "how now brown cow", + "brown is the color of cows", + "brown cow", + "banana cows are yummy", + "dance with monkeys and do not stop to dance", + "break on through to the other side... break on through to the other side... break on through to the other side" }; @Before public void setupIndex() throws Exception { dir = new ByteBuffersDirectory(); - try(IndexWriter indexWriter = new IndexWriter(dir, new IndexWriterConfig(Lucene.STANDARD_ANALYZER))) { + try (IndexWriter indexWriter = new IndexWriter(dir, new IndexWriterConfig(Lucene.STANDARD_ANALYZER))) { for (int i = 0; i < docs.length; i++) { Document doc = new Document(); doc.add(new Field("_id", Integer.toString(i + 1), StoredField.TYPE)); @@ -79,10 +78,8 @@ public void setupIndex() throws Exception { reader = DirectoryReader.open(dir); searcher = new IndexSearcher(reader); Map> stats = new HashMap<>(); - stats.put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - stats.put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); + stats.put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + stats.put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); ltrStats = new LTRStats(unmodifiableMap(stats)); } diff --git a/src/test/java/com/o19s/es/explore/StatisticsHelperTests.java b/src/test/java/com/o19s/es/explore/StatisticsHelperTests.java index bd79eb0b..7fde0ea2 100644 --- a/src/test/java/com/o19s/es/explore/StatisticsHelperTests.java +++ b/src/test/java/com/o19s/es/explore/StatisticsHelperTests.java @@ -18,14 +18,12 @@ import org.apache.lucene.tests.util.LuceneTestCase; public class StatisticsHelperTests extends LuceneTestCase { - private final float[] dataset = new float[] { - 0.0f, -5.0f, 10.0f, 5.0f - }; + private final float[] dataset = new float[] { 0.0f, -5.0f, 10.0f, 5.0f }; public void testStats() throws Exception { StatisticsHelper stats = new StatisticsHelper(); - for(float f : dataset) { + for (float f : dataset) { stats.add(f); } diff --git a/src/test/java/com/o19s/es/ltr/LtrQueryContextTests.java b/src/test/java/com/o19s/es/ltr/LtrQueryContextTests.java index 7cdf8652..0923fc0d 100644 --- a/src/test/java/com/o19s/es/ltr/LtrQueryContextTests.java +++ b/src/test/java/com/o19s/es/ltr/LtrQueryContextTests.java @@ -1,11 +1,11 @@ package com.o19s.es.ltr; -import org.apache.lucene.tests.util.LuceneTestCase; - import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import org.apache.lucene.tests.util.LuceneTestCase; + /* * Copyright [2017] Wikimedia Foundation * @@ -56,4 +56,3 @@ public void testGetActiveFeatures() { } } - diff --git a/src/test/java/com/o19s/es/ltr/LtrTestUtils.java b/src/test/java/com/o19s/es/ltr/LtrTestUtils.java index 02f96447..f9cb5755 100644 --- a/src/test/java/com/o19s/es/ltr/LtrTestUtils.java +++ b/src/test/java/com/o19s/es/ltr/LtrTestUtils.java @@ -16,6 +16,21 @@ package com.o19s.es.ltr; +import static org.apache.lucene.tests.util.LuceneTestCase.random; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.function.IntFunction; + +import org.apache.lucene.tests.util.TestUtil; +import org.opensearch.common.CheckedFunction; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.WrapperQueryBuilder; + import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.store.CompiledLtrModel; @@ -36,20 +51,6 @@ import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.parser.LinearRankerParser; import com.o19s.es.ltr.utils.FeatureStoreLoader; -import org.apache.lucene.tests.util.TestUtil; -import org.opensearch.common.CheckedFunction; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.index.query.WrapperQueryBuilder; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.function.IntFunction; - -import static org.apache.lucene.tests.util.LuceneTestCase.random; public class LtrTestUtils { @@ -80,8 +81,7 @@ public static CompiledLtrModel buildRandomModel() throws IOException { if (ftrNorms.size() > 0) { ranker = new FeatureNormalizingRanker(ranker, ftrNorms); } - return new CompiledLtrModel(TestUtil.randomSimpleString(random(), 5, 10), set, - ranker); + return new CompiledLtrModel(TestUtil.randomSimpleString(random(), 5, 10), set, ranker); } public static StoredLtrModel randomLinearModel(String name, StoredFeatureSet set) throws IOException { @@ -101,8 +101,7 @@ public static StoredFeatureNormalizers randomFtrNorms(FeatureSet set) { if (random().nextBoolean()) { continue; - } - else { + } else { Feature ftr = set.feature(i); @@ -130,23 +129,30 @@ public static LtrRanker buildRandomRanker(int fSize) { if (random().nextBoolean()) { ranker = LinearRankerTests.generateRandomRanker(fSize); } else { - ranker = NaiveAdditiveDecisionTreeTests.generateRandomDecTree(fSize, TestUtil.nextInt(random(), 1, 50), - 5, 50, null); + ranker = NaiveAdditiveDecisionTreeTests.generateRandomDecTree(fSize, TestUtil.nextInt(random(), 1, 50), 5, 50, null); } return ranker; } public static FeatureStoreLoader nullLoader() { - return (storeName, client) -> {throw new IllegalStateException("Invalid state, this query cannot be " + - "built without a valid store loader. Your are seeing this exception because you attempt to call " + - "doToQuery on a " + StoredLtrQueryBuilder.class.getSimpleName() + " instance that was built with " + - "an invalid FeatureStoreLoader. If you are trying to run integration tests with this query consider " + - "wrapping it inside a " + WrapperQueryBuilder.class.getSimpleName() + ":\n" + - "\tnew WrapperQueryBuilder(sltrBuilder.toString())\n" + - "This will force elastic to initialize the feature loader properly");}; + return (storeName, client) -> { + throw new IllegalStateException( + "Invalid state, this query cannot be " + + "built without a valid store loader. Your are seeing this exception because you attempt to call " + + "doToQuery on a " + + StoredLtrQueryBuilder.class.getSimpleName() + + " instance that was built with " + + "an invalid FeatureStoreLoader. If you are trying to run integration tests with this query consider " + + "wrapping it inside a " + + WrapperQueryBuilder.class.getSimpleName() + + ":\n" + + "\tnew WrapperQueryBuilder(sltrBuilder.toString())\n" + + "This will force elastic to initialize the feature loader properly" + ); + }; } - public static Function wrapFuncion(CheckedFunction f) { + public static Function wrapFuncion(CheckedFunction f) { return (p) -> { try { return f.apply(p); @@ -156,7 +162,7 @@ public static Function wrapFuncion(CheckedFuncti }; } - public static IntFunction wrapIntFuncion(CheckedFunction f) { + public static IntFunction wrapIntFuncion(CheckedFunction f) { return (p) -> { try { return f.apply(p); @@ -170,46 +176,42 @@ public static FeatureStoreLoader wrapMemStore(MemStore store) { return (storeName, client) -> store; } - public static String testFeatureSetString() { - return "{\n" + - "\"name\": \"movie_features\",\n" + - "\"type\": \"featureset\",\n" + - "\"featureset\": {\n" + - " \"name\": \"movie_features\",\n" + - " \"features\": [\n" + - " {\n" + - " \"name\": \"1\",\n" + - " \"params\": [\n" + - " \"keywords\"\n" + - " ],\n" + - " \"template_language\": \"mustache\",\n" + - " \"template\": {\n" + - " \"match\": {\n" + - " \"title\": \"{{keywords}}\"\n" + - " }\n" + - " }\n" + - " },\n" + - " {\n" + - " \"name\": \"2\",\n" + - " \"params\": [\n" + - " \"keywords\"\n" + - " ],\n" + - " \"template_language\": \"mustache\",\n" + - " \"template\": {\n" + - " \"match\": {\n" + - " \"overview\": \"{{keywords}}\"\n" + - " }\n" + - " }\n" + - " }\n" + - " ]\n" + - "}\n}"; + return "{\n" + + "\"name\": \"movie_features\",\n" + + "\"type\": \"featureset\",\n" + + "\"featureset\": {\n" + + " \"name\": \"movie_features\",\n" + + " \"features\": [\n" + + " {\n" + + " \"name\": \"1\",\n" + + " \"params\": [\n" + + " \"keywords\"\n" + + " ],\n" + + " \"template_language\": \"mustache\",\n" + + " \"template\": {\n" + + " \"match\": {\n" + + " \"title\": \"{{keywords}}\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"name\": \"2\",\n" + + " \"params\": [\n" + + " \"keywords\"\n" + + " ],\n" + + " \"template_language\": \"mustache\",\n" + + " \"template\": {\n" + + " \"match\": {\n" + + " \"overview\": \"{{keywords}}\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}\n}"; } public static String testModelString() { - return "{\n" + - "\"name\": \"movie_model\",\n" + - "\"type\": \"model\"" + - "\n}"; + return "{\n" + "\"name\": \"movie_model\",\n" + "\"type\": \"model\"" + "\n}"; } } diff --git a/src/test/java/com/o19s/es/ltr/action/TransportLTRStatsActionTests.java b/src/test/java/com/o19s/es/ltr/action/TransportLTRStatsActionTests.java index 970dd63d..41156b08 100644 --- a/src/test/java/com/o19s/es/ltr/action/TransportLTRStatsActionTests.java +++ b/src/test/java/com/o19s/es/ltr/action/TransportLTRStatsActionTests.java @@ -16,25 +16,26 @@ package com.o19s.es.ltr.action; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeRequest; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeResponse; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesRequest; -import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesResponse; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.action.FailedNodeException; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.test.OpenSearchIntegTestCase; -import org.opensearch.transport.TransportService; -import org.junit.Before; +import static org.mockito.Mockito.mock; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.mockito.Mockito.mock; +import org.junit.Before; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.transport.TransportService; + +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeRequest; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodeResponse; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesRequest; +import com.o19s.es.ltr.action.LTRStatsAction.LTRStatsNodesResponse; public class TransportLTRStatsActionTests extends OpenSearchIntegTestCase { @@ -53,11 +54,11 @@ public void setup() throws Exception { ltrStats = new LTRStats(statsMap); action = new TransportLTRStatsAction( - client().threadPool(), - clusterService(), - mock(TransportService.class), - mock(ActionFilters.class), - ltrStats + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + ltrStats ); } diff --git a/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java b/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java index 245ed8bf..64ab1803 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/ExtraLoggingSupplierTests.java @@ -16,12 +16,13 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.ranker.LogLtrRanker; -import org.apache.lucene.tests.util.LuceneTestCase; - import java.util.HashMap; import java.util.Map; +import org.apache.lucene.tests.util.LuceneTestCase; + +import com.o19s.es.ltr.ranker.LogLtrRanker; + public class ExtraLoggingSupplierTests extends LuceneTestCase { public void testGetWithConsumerNotSet() { ExtraLoggingSupplier supplier = new ExtraLoggingSupplier(); @@ -41,14 +42,14 @@ public void testGetWithSuppliedNull() { } public void testGetWithSuppliedMap() { - Map extraLoggingMap = new HashMap<>(); + Map extraLoggingMap = new HashMap<>(); LogLtrRanker.LogConsumer consumer = new LogLtrRanker.LogConsumer() { @Override public void accept(int featureOrdinal, float score) {} @Override - public Map getExtraLoggingMap() { + public Map getExtraLoggingMap() { return extraLoggingMap; } }; diff --git a/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java b/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java index fa0d07bf..061d7a72 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/FeatureSupplierTests.java @@ -16,18 +16,19 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.index.query.QueryBuilders; +import static java.util.Collections.singletonList; import java.util.Collections; import java.util.Iterator; import java.util.Map; import java.util.Set; -import static java.util.Collections.singletonList; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.opensearch.index.query.QueryBuilders; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker; @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") public class FeatureSupplierTests extends LuceneTestCase { @@ -64,7 +65,7 @@ public void testGetFeatureScore() { assertNull(featureSupplier.get("bad_test")); } - public void testEntrySetWhenFeatureVectorNotSet(){ + public void testEntrySetWhenFeatureVectorNotSet() { FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet()); Set> entrySet = featureSupplier.entrySet(); assertTrue(entrySet.isEmpty()); @@ -74,7 +75,7 @@ public void testEntrySetWhenFeatureVectorNotSet(){ assertEquals(0, entrySet.size()); } - public void testEntrySetWhenFeatureVectorIsSet(){ + public void testEntrySetWhenFeatureVectorIsSet() { FeatureSupplier featureSupplier = new FeatureSupplier(getFeatureSet()); LtrRanker.FeatureVector featureVector = new DenseFeatureVector(1); featureVector.setFeatureScore(0, 10.0f); @@ -85,10 +86,9 @@ public void testEntrySetWhenFeatureVectorIsSet(){ Iterator> iterator = entrySet.iterator(); assertTrue(iterator.hasNext()); Map.Entry item = iterator.next(); - assertEquals("test",item.getKey()); - assertEquals(10.0f,item.getValue(), 0.0f); + assertEquals("test", item.getKey()); + assertEquals(10.0f, item.getValue(), 0.0f); assertEquals(1, entrySet.size()); } } - diff --git a/src/test/java/com/o19s/es/ltr/feature/store/MemStore.java b/src/test/java/com/o19s/es/ltr/feature/store/MemStore.java index 209a7f07..c65a87ff 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/MemStore.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/MemStore.java @@ -16,13 +16,13 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; - import java.io.IOException; import java.util.HashMap; import java.util.Map; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; + /** * in memory test store */ diff --git a/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureParserTests.java b/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureParserTests.java index a67a990d..fa80d581 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureParserTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureParserTests.java @@ -16,30 +16,29 @@ package com.o19s.es.ltr.feature.store; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThan; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; + +import java.io.IOException; +import java.util.Arrays; + import org.apache.lucene.tests.util.LuceneTestCase; +import org.hamcrest.CoreMatchers; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.AbstractQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.hamcrest.CoreMatchers; - -import java.io.IOException; -import java.util.Arrays; - -import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; -import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.lessThan; public class StoredFeatureParserTests extends LuceneTestCase { public void testParseFeatureAsJson() throws IOException { @@ -50,13 +49,15 @@ public void testParseFeatureAsJson() throws IOException { } public static String generateTestFeature(String name) { - return "{\n" + - "\"name\": \""+name+"\",\n" + - "\"params\": [\"param1\", \"param2\"],\n" + - "\"template_language\": \"mustache\",\n" + - "\"template\": \n" + - new MatchQueryBuilder("match_field", "match_word").toString() + - "\n}\n"; + return "{\n" + + "\"name\": \"" + + name + + "\",\n" + + "\"params\": [\"param1\", \"param2\"],\n" + + "\"template_language\": \"mustache\",\n" + + "\"template\": \n" + + new MatchQueryBuilder("match_field", "match_word").toString() + + "\n}\n"; } public static String generateTestFeature() { @@ -67,29 +68,25 @@ public void assertTestFeature(StoredFeature feature) { assertEquals("testFeature", feature.name()); assertArrayEquals(Arrays.asList("param1", "param2").toArray(), feature.queryParams().toArray()); assertEquals("mustache", feature.templateLanguage()); - assertEquals(writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")), - feature.template()); + assertEquals(writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")), feature.template()); assertFalse(feature.templateAsString()); } public void testParseFeatureAsString() throws IOException { - String featureString = "{\n" + - "\"name\": \"testFeature\",\n" + - "\"params\": [\"param1\", \"param2\"],\n" + - "\"template_language\": \"mustache\",\n" + - "\"template\": \"" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) - .replace("\"", "\\\"") + - "\"\n}\n"; - + String featureString = "{\n" + + "\"name\": \"testFeature\",\n" + + "\"params\": [\"param1\", \"param2\"],\n" + + "\"template_language\": \"mustache\",\n" + + "\"template\": \"" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")).replace("\"", "\\\"") + + "\"\n}\n"; StoredFeature feature = parse(featureString); assertEquals("testFeature", feature.name()); assertArrayEquals(Arrays.asList("param1", "param2").toArray(), feature.queryParams().toArray()); assertEquals("mustache", feature.templateLanguage()); - assertEquals(writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")), - feature.template()); + assertEquals(writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")), feature.template()); assertTrue(feature.templateAsString()); } @@ -103,145 +100,140 @@ public void testToXContent() throws IOException { } public void testParseErrorOnMissingName() throws IOException { - String featureString = "{\n" + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), - equalTo("Field [name] is mandatory")); + String featureString = "{\n" + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), equalTo("Field [name] is mandatory")); } public void testParseWithExternalName() throws IOException { - String featureString = "{\n" + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; + String featureString = "{\n" + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; StoredFeature set = parse(featureString, "my_feature"); assertEquals("my_feature", set.name()); } public void testParseWithInconsistentExternalName() throws IOException { - String featureString = "{\n" + - "\"name\": \"testFeature\",\n" + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; - assertThat(expectThrows(ParsingException.class, - () -> parse(featureString, "testFeature2")).getMessage(), - CoreMatchers.equalTo("Invalid [name], expected [testFeature2] but got [testFeature]")); + String featureString = "{\n" + + "\"name\": \"testFeature\",\n" + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(featureString, "testFeature2")).getMessage(), + CoreMatchers.equalTo("Invalid [name], expected [testFeature2] but got [testFeature]") + ); } public void testParseErrorOnBadTemplate() throws IOException { - String featureString = "{\n" + - "\"name\": \"testFeature\",\n" + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \"{{hop\"" + - "}"; - assertThat(expectThrows(IllegalArgumentException.class, () -> parse(featureString).optimize()).getMessage(), - containsString("Improperly closed variable")); + String featureString = "{\n" + + "\"name\": \"testFeature\",\n" + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \"{{hop\"" + + "}"; + assertThat( + expectThrows(IllegalArgumentException.class, () -> parse(featureString).optimize()).getMessage(), + containsString("Improperly closed variable") + ); } public void testParseErrorOnMissingTemplate() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\"\n" + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), - equalTo("Field [template] is mandatory")); + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\"\n" + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), equalTo("Field [template] is mandatory")); } public void testParseErrorOnUnknownField() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\n\"bogusField\":\"oops\"," + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), - containsString("bogusField")); + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\n\"bogusField\":\"oops\"," + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(featureString)).getMessage(), containsString("bogusField")); } public void testParseWithoutParams() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; StoredFeature feat = parse(featureString); assertTrue(feat.queryParams().isEmpty()); } public void testParseWithEmptyParams() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[]," + - "\"template_language\":\"mustache\",\n" + - "\"template\": \n" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + - "}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[]," + + "\"template_language\":\"mustache\",\n" + + "\"template\": \n" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) + + "}"; StoredFeature feat = parse(featureString); assertTrue(feat.queryParams().isEmpty()); } public void testRamBytesUsed() throws IOException, InterruptedException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\":\"" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) - .replace("\"", "\\\"") + - "\"}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\":\"" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")).replace("\"", "\\\"") + + "\"}"; StoredFeature feature = parse(featureString); - long approxSize = featureString.length()*Character.BYTES; - assertThat(feature.ramBytesUsed(), - allOf(greaterThan((long) (approxSize*0.66)), - lessThan((long) (approxSize*1.33)))); + long approxSize = featureString.length() * Character.BYTES; + assertThat(feature.ramBytesUsed(), allOf(greaterThan((long) (approxSize * 0.66)), lessThan((long) (approxSize * 1.33)))); } public void testExpressionOptimization() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"template_language\":\"derived_expression\",\n" + - "\"template\":\"Math.random()" + - "\"}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"template_language\":\"derived_expression\",\n" + + "\"template\":\"Math.random()" + + "\"}"; StoredFeature feature = parse(featureString); assertThat(feature.optimize(), instanceOf(PrecompiledExpressionFeature.class)); } public void testMustacheOptimization() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"mustache\",\n" + - "\"template\":\"" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) - .replace("\"", "\\\"") + - "\"}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"mustache\",\n" + + "\"template\":\"" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")).replace("\"", "\\\"") + + "\"}"; StoredFeature feature = parse(featureString); assertThat(feature.optimize(), instanceOf(PrecompiledTemplateFeature.class)); } public void testDontOptimizeOnThirdPartyTemplateEngine() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\",\"param2\"]," + - "\"template_language\":\"third_party_template_engine\",\n" + - "\"template\":\"" + - writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")) - .replace("\"", "\\\"") + - "\"}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\",\"param2\"]," + + "\"template_language\":\"third_party_template_engine\",\n" + + "\"template\":\"" + + writeAsNonFormattedString(new MatchQueryBuilder("match_field", "match_word")).replace("\"", "\\\"") + + "\"}"; StoredFeature feature = parse(featureString); assertSame(feature, feature.optimize()); } @@ -251,11 +243,10 @@ static StoredFeature parse(String featureString) throws IOException { } static StoredFeature parse(String featureString, String defaultName) throws IOException { - return StoredFeature.parse(jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, featureString), defaultName); + return StoredFeature.parse(jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, featureString), defaultName); } private String writeAsNonFormattedString(AbstractQueryBuilder builder) { return Strings.toString(MediaTypeRegistry.JSON, builder, false, false); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureSetParserTests.java b/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureSetParserTests.java index 9b2538c5..1fec8dce 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureSetParserTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/StoredFeatureSetParserTests.java @@ -16,17 +16,16 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.query.DerivedExpressionQuery; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.Strings; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.index.query.MatchQueryBuilder; +import static org.apache.lucene.tests.util.TestUtil.randomRealisticUnicodeString; +import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThan; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; import java.io.IOException; import java.util.ArrayList; @@ -37,16 +36,16 @@ import java.util.Set; import java.util.function.Consumer; -import static org.apache.lucene.tests.util.TestUtil.randomRealisticUnicodeString; -import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; -import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; -import static org.hamcrest.CoreMatchers.allOf; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.lessThan; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchQueryBuilder; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.query.DerivedExpressionQuery; public class StoredFeatureSetParserTests extends LuceneTestCase { @@ -75,11 +74,12 @@ private void assertFeatureSet(StoredFeatureSet set, List features ramSize += actual.ramBytesUsed(); } assertFalse(set.hasFeature(unknownName())); - assertThat(expectThrows(IllegalArgumentException.class, - () -> set.feature(unknownName())).getMessage(), - containsString("Unknown feature")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> set.feature(unknownName())).getMessage(), + containsString("Unknown feature") + ); - assertThat(set.ramBytesUsed(), allOf(greaterThan((long) (ramSize*0.66)), lessThan((long) (ramSize*1.33)))); + assertThat(set.ramBytesUsed(), allOf(greaterThan((long) (ramSize * 0.66)), lessThan((long) (ramSize * 1.33)))); } public void testToXContent() throws IOException { @@ -94,22 +94,28 @@ public void testToXContent() throws IOException { } public void testParseErrorOnDups() throws IOException { - String set = "{\"name\" : \"my_set\",\n" + - "\"features\": [\n" + - StoredFeatureParserTests.generateTestFeature() + "," + - StoredFeatureParserTests.generateTestFeature() + - "]}"; - assertThat(expectThrows(ParsingException.class, - () -> parse(set)).getMessage(), - containsString("feature names must be unique in a set")); + String set = "{\"name\" : \"my_set\",\n" + + "\"features\": [\n" + + StoredFeatureParserTests.generateTestFeature() + + "," + + StoredFeatureParserTests.generateTestFeature() + + "]}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(set)).getMessage(), + containsString("feature names must be unique in a set") + ); } public void testExpressionMissingQueryParameter() throws IOException { FeatureSet optimizedFeatureSet = getFeatureSet(); assertThat(optimizedFeatureSet.feature(0), instanceOf(PrecompiledExpressionFeature.class)); - assertThat(expectThrows(IllegalArgumentException.class, - () -> optimizedFeatureSet.feature(0).doToQuery(null, optimizedFeatureSet, new HashMap<>())).getMessage(), - containsString("Missing required param(s): [param1]")); + assertThat( + expectThrows( + IllegalArgumentException.class, + () -> optimizedFeatureSet.feature(0).doToQuery(null, optimizedFeatureSet, new HashMap<>()) + ).getMessage(), + containsString("Missing required param(s): [param1]") + ); } public void testExpressionInvalidQueryParameter() throws IOException { @@ -117,9 +123,11 @@ public void testExpressionInvalidQueryParameter() throws IOException { assertThat(optimizedFeatureSet.feature(0), instanceOf(PrecompiledExpressionFeature.class)); Map params = new HashMap<>(); params.put("param1", "NaN"); - assertThat(expectThrows(IllegalArgumentException.class, - () -> optimizedFeatureSet.feature(0).doToQuery(null, optimizedFeatureSet, params)).getMessage(), - containsString("parameter: param1 expected to be of type Double")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> optimizedFeatureSet.feature(0).doToQuery(null, optimizedFeatureSet, params)) + .getMessage(), + containsString("parameter: param1 expected to be of type Double") + ); } public void testExpressionIntegerQueryParameter() throws IOException { @@ -142,49 +150,35 @@ private void assertDerivedExpressionQuery(Object param) throws IOException { assertThat(optimizedFeatureSet.feature(0).doToQuery(null, optimizedFeatureSet, params), instanceOf(DerivedExpressionQuery.class)); } - private FeatureSet getFeatureSet() throws IOException { - String featureString = "{\n" + - "\"name\":\"testFeature\"," + - "\"params\":[\"param1\"]," + - "\"template_language\":\"derived_expression\",\n" + - "\"template\":\"log10(param1)" + - "\"}"; - String set = "{\"name\" : \"my_set\",\n" + - "\"features\": [\n" + - featureString + - "]}"; + String featureString = "{\n" + + "\"name\":\"testFeature\"," + + "\"params\":[\"param1\"]," + + "\"template_language\":\"derived_expression\",\n" + + "\"template\":\"log10(param1)" + + "\"}"; + String set = "{\"name\" : \"my_set\",\n" + "\"features\": [\n" + featureString + "]}"; StoredFeatureSet featureSet = parse(set); return featureSet.optimize(); } public void testParseErrorOnMissingName() throws IOException { - String missingName = "{" + - "\"features\": [\n" + - StoredFeatureParserTests.generateTestFeature() + - "]}"; - assertThat(expectThrows(ParsingException.class, - () -> parse(missingName)).getMessage(), - equalTo("Field [name] is mandatory")); + String missingName = "{" + "\"features\": [\n" + StoredFeatureParserTests.generateTestFeature() + "]}"; + assertThat(expectThrows(ParsingException.class, () -> parse(missingName)).getMessage(), equalTo("Field [name] is mandatory")); } public void testParseWithExternalName() throws IOException { - String missingName = "{" + - "\"features\": [\n" + - StoredFeatureParserTests.generateTestFeature() + - "]}"; + String missingName = "{" + "\"features\": [\n" + StoredFeatureParserTests.generateTestFeature() + "]}"; StoredFeatureSet set = parse(missingName, "my_set"); assertEquals("my_set", set.name()); } public void testParseWithInconsistentExternalName() throws IOException { - String set = "{\"name\" : \"my_set\",\n" + - "\"features\": [\n" + - StoredFeatureParserTests.generateTestFeature() + - "]}"; - assertThat(expectThrows(ParsingException.class, - () -> parse(set, "my_set2")).getMessage(), - equalTo("Invalid [name], expected [my_set2] but got [my_set]")); + String set = "{\"name\" : \"my_set\",\n" + "\"features\": [\n" + StoredFeatureParserTests.generateTestFeature() + "]}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(set, "my_set2")).getMessage(), + equalTo("Invalid [name], expected [my_set2] but got [my_set]") + ); } public void testParseErrorOnMissingSet() throws IOException { @@ -194,32 +188,30 @@ public void testParseErrorOnMissingSet() throws IOException { } public void testParseErrorOnEmptySet() throws IOException { - String missingList = "{ \"name\": \"my_set\"," + - "\"features\": []}"; + String missingList = "{ \"name\": \"my_set\"," + "\"features\": []}"; StoredFeatureSet set = parse(missingList); assertEquals(0, set.size()); } public void testParseErrorOnExtraField() throws IOException { - String set = "{\"name\" : \"my_set\",\n" + - "\"random_field\": \"oops\"," + - "\"features\": [\n" + - StoredFeatureParserTests.generateTestFeature() + - "]}"; - assertThat(expectThrows(ParsingException.class, - () -> parse(set)).getMessage(), - containsString("[2:1] [featureset] unknown field [random_field]")); + String set = "{\"name\" : \"my_set\",\n" + + "\"random_field\": \"oops\"," + + "\"features\": [\n" + + StoredFeatureParserTests.generateTestFeature() + + "]}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(set)).getMessage(), + containsString("[2:1] [featureset] unknown field [random_field]") + ); } private static StoredFeatureSet parse(String missingName) throws IOException { - return StoredFeatureSet.parse(jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, missingName)); + return StoredFeatureSet.parse(jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, missingName)); } private static StoredFeatureSet parse(String missingName, String defaultName) throws IOException { - return StoredFeatureSet.parse(jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, missingName), defaultName); + return StoredFeatureSet.parse(jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, missingName), defaultName); } public static StoredFeature buildRandomFeature() throws IOException { @@ -229,18 +221,27 @@ public static StoredFeature buildRandomFeature() throws IOException { public static StoredFeature buildRandomFeature(String name) throws IOException { return StoredFeatureParserTests.parse(generateRandomFeature(name)); } + private static String generateRandomFeature() { return generateRandomFeature(rName()); } private static String generateRandomFeature(String name) { - return "{\n" + - "\"name\": \"" + name + "\",\n" + - "\"params\": [\"" + rName() + "\", \"" + rName() + "\"],\n" + - "\"template_language\": \"" + rName() + "\",\n" + - "\"template\": \n" + - new MatchQueryBuilder(rName(), randomRealisticUnicodeString(random())).toString() + - "\n}\n"; + return "{\n" + + "\"name\": \"" + + name + + "\",\n" + + "\"params\": [\"" + + rName() + + "\", \"" + + rName() + + "\"],\n" + + "\"template_language\": \"" + + rName() + + "\",\n" + + "\"template\": \n" + + new MatchQueryBuilder(rName(), randomRealisticUnicodeString(random())).toString() + + "\n}\n"; } private static String rName() { @@ -273,22 +274,20 @@ public static String generateRandomFeatureSet(Consumer features) } public static String generateRandomFeatureSet(String name, Consumer features) throws IOException { - return generateRandomFeatureSet(name, features, random().nextInt(20)+1); + return generateRandomFeatureSet(name, features, random().nextInt(20) + 1); } public static String generateRandomFeatureSet(String name, Consumer features, int nbFeat) throws IOException { StringBuilder sb = new StringBuilder(); - sb.append("{\"name\" : \"") - .append(name) - .append("\",\n"); + sb.append("{\"name\" : \"").append(name).append("\",\n"); sb.append("\"features\":["); boolean first = true; // Simply avoid adding the same feature twice because of random string Set addedFeatures = new HashSet<>(); - while(nbFeat-->0) { + while (nbFeat-- > 0) { String featureString = generateRandomFeature(); - StoredFeature feature = StoredFeature.parse(jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, featureString)); + StoredFeature feature = StoredFeature + .parse(jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, featureString)); if (!addedFeatures.add(feature.name())) { continue; } @@ -304,4 +303,4 @@ public static String generateRandomFeatureSet(String name, Consumer features = IntStream.range(0, StoredFeatureSet.MAX_FEATURES) - .mapToObj(wrapIntFuncion((i) -> randomFeature("feat" + i))) - .collect(Collectors.toList()); + List features = IntStream + .range(0, StoredFeatureSet.MAX_FEATURES) + .mapToObj(wrapIntFuncion((i) -> randomFeature("feat" + i))) + .collect(Collectors.toList()); StoredFeatureSet set_v2 = set_v1.append(features); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, - () -> set_v2.append(singletonList(randomFeature("new_feat")))); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> set_v2.append(singletonList(randomFeature("new_feat"))) + ); assertThat(iae.getMessage(), equalTo("The resulting feature set would be too large")); } public void testMergeMaxSize() throws IOException { StoredFeatureSet set_v1 = new StoredFeatureSet("name", emptyList()); assertEquals(0, set_v1.size()); - //noinspection ConstantConditions + // noinspection ConstantConditions assert StoredFeatureSet.MAX_FEATURES > 10; - List features = IntStream.range(0, StoredFeatureSet.MAX_FEATURES - 2) - .mapToObj(wrapIntFuncion((i) -> randomFeature("feat" + i))) - .collect(Collectors.toList()); - StoredFeatureSet set_v2 = set_v1.append(features) - .merge(asList(randomFeature("feat0"), - randomFeature("feat9"), - randomFeature("new1"), - randomFeature("new2"))); + List features = IntStream + .range(0, StoredFeatureSet.MAX_FEATURES - 2) + .mapToObj(wrapIntFuncion((i) -> randomFeature("feat" + i))) + .collect(Collectors.toList()); + StoredFeatureSet set_v2 = set_v1 + .append(features) + .merge(asList(randomFeature("feat0"), randomFeature("feat9"), randomFeature("new1"), randomFeature("new2"))); assertEquals(StoredFeatureSet.MAX_FEATURES, set_v2.size()); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, - () -> set_v2.merge(asList(randomFeature("new4"), randomFeature("feat9")))); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> set_v2.merge(asList(randomFeature("new4"), randomFeature("feat9"))) + ); assertThat(iae.getMessage(), equalTo("The resulting feature set would be too large")); } @@ -115,4 +119,4 @@ public void testMerge() throws IOException { assertSame(feat2_v1, set_v3.feature(1)); assertSame(feat3_v1, set_v3.feature(2)); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/feature/store/StoredLtrModelParserTests.java b/src/test/java/com/o19s/es/ltr/feature/store/StoredLtrModelParserTests.java index 17a3c557..d945139b 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/StoredLtrModelParserTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/StoredLtrModelParserTests.java @@ -16,31 +16,32 @@ package com.o19s.es.ltr.feature.store; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.linear.LinearRanker; -import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; -import com.o19s.es.ltr.ranker.normalizer.MinMaxFeatureNormalizer; -import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import org.apache.lucene.util.BytesRef; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; + +import java.io.IOException; + import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.core.common.ParsingException; +import org.apache.lucene.util.BytesRef; import org.opensearch.common.Randomness; -import org.opensearch.core.common.io.stream.ByteBufferStreamInput; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.ByteBufferStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.XContentType; -import java.io.IOException; - -import static org.opensearch.core.xcontent.NamedXContentRegistry.EMPTY; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.linear.LinearRanker; +import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; +import com.o19s.es.ltr.ranker.normalizer.MinMaxFeatureNormalizer; +import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; public class StoredLtrModelParserTests extends LuceneTestCase { private LtrRanker ranker; @@ -48,60 +49,58 @@ public class StoredLtrModelParserTests extends LuceneTestCase { public void setUp() throws Exception { super.setUp(); - ranker = new LinearRanker(new float[]{1F,2F,3F}); - factory = new LtrRankerParserFactory.Builder() - .register("model/dummy", () -> (set, model) -> ranker) - .build(); + ranker = new LinearRanker(new float[] { 1F, 2F, 3F }); + factory = new LtrRankerParserFactory.Builder().register("model/dummy", () -> (set, model) -> ranker).build(); } public String getTestModel() throws IOException { - return "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; + return "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; } public String getTestModelAsXContent() throws IOException { - return "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": [\"completely ignored\"]\n"+ - " }" + - "}"; + return "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": [\"completely ignored\"]\n" + + " }" + + "}"; } public String getSimpleFeatureSet() { - String inlineFeatureSet = "{" + - "\"name\": \"normed_model\"," + - " \"features\": [{" + - " \"name\": \"feature_1\"," + - " \"params\": [\"keywords\"]," + - " \"template\": {" + - " \"match\": {" + - " \"a_field\": {" + - " \"query\": \"test1\"" + - " }" + - " }" + - " }" + - " }," + - " {" + - " \"name\": \"feature_2\"," + - " \"params\": [\"keywords\"]," + - " \"template\": {" + - " \"match\": {" + - " \"esyww\": {" + - " \"query\": \"test1\"" + - " }}}}]}"; + String inlineFeatureSet = "{" + + "\"name\": \"normed_model\"," + + " \"features\": [{" + + " \"name\": \"feature_1\"," + + " \"params\": [\"keywords\"]," + + " \"template\": {" + + " \"match\": {" + + " \"a_field\": {" + + " \"query\": \"test1\"" + + " }" + + " }" + + " }" + + " }," + + " {" + + " \"name\": \"feature_2\"," + + " \"params\": [\"keywords\"]," + + " \"template\": {" + + " \"match\": {" + + " \"esyww\": {" + + " \"query\": \"test1\"" + + " }}}}]}"; return inlineFeatureSet; } @@ -127,26 +126,27 @@ private void assertTestModelAsXContent(StoredLtrModel model) throws IOException } public void testCompileFeatureNorms() throws IOException { - String modelJson = "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + getSimpleFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\",\n" + - " \"feature_normalizers\": {\n" + - " \"feature_1\": { \"standard\":" + - " {\"mean\": 1.25," + - " \"standard_deviation\": 0.25}}}" + - " }" + - "}"; + String modelJson = "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + getSimpleFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\",\n" + + " \"feature_normalizers\": {\n" + + " \"feature_1\": { \"standard\":" + + " {\"mean\": 1.25," + + " \"standard_deviation\": 0.25}}}" + + " }" + + "}"; StoredLtrModel model = parse(modelJson); CompiledLtrModel compiledModel = model.compile(factory); LtrRanker ranker = compiledModel.ranker(); assertEquals(ranker.getClass(), FeatureNormalizingRanker.class); - FeatureNormalizingRanker normRanker = (FeatureNormalizingRanker)ranker; + FeatureNormalizingRanker normRanker = (FeatureNormalizingRanker) ranker; LtrRanker.FeatureVector ftrVector = normRanker.newFeatureVector(null); @@ -161,26 +161,27 @@ public void testCompileFeatureNorms() throws IOException { } public void testFeatureStdNormParsing() throws IOException { - String modelJson = "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + getSimpleFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\",\n"+ - " \"feature_normalizers\": {\n"+ - " \"feature_1\": { \"standard\":" + - " {\"mean\": 1.25," + - " \"standard_deviation\": 0.25}}}" + - " }" + - "}"; + String modelJson = "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + getSimpleFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\",\n" + + " \"feature_normalizers\": {\n" + + " \"feature_1\": { \"standard\":" + + " {\"mean\": 1.25," + + " \"standard_deviation\": 0.25}}}" + + " }" + + "}"; StoredLtrModel model = parse(modelJson); StoredFeatureNormalizers ftrNormSet = model.getFeatureNormalizers(); assertNotNull(ftrNormSet); - StandardFeatureNormalizer stdFtrNorm = (StandardFeatureNormalizer)ftrNormSet.getNormalizer("feature_1"); + StandardFeatureNormalizer stdFtrNorm = (StandardFeatureNormalizer) ftrNormSet.getNormalizer("feature_1"); assertNotNull(stdFtrNorm); float expectedMean = 1.25f; @@ -192,7 +193,7 @@ public void testFeatureStdNormParsing() throws IOException { StoredLtrModel reparsedModel = reparseModel(model); ftrNormSet = reparsedModel.getFeatureNormalizers(); - stdFtrNorm = (StandardFeatureNormalizer)ftrNormSet.getNormalizer("feature_1"); + stdFtrNorm = (StandardFeatureNormalizer) ftrNormSet.getNormalizer("feature_1"); testVal = Randomness.get().nextFloat(); expectedNormalized = (testVal - expectedMean) / expectedStdDev; @@ -202,26 +203,27 @@ public void testFeatureStdNormParsing() throws IOException { } public void testFeatureMinMaxParsing() throws IOException { - String modelJson = "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + getSimpleFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\",\n"+ - " \"feature_normalizers\": {\n"+ - " \"feature_2\": { \"min_max\":" + - " {\"minimum\": 0.05," + - " \"maximum\": 1.25}}}" + - " }" + - "}"; + String modelJson = "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + getSimpleFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\",\n" + + " \"feature_normalizers\": {\n" + + " \"feature_2\": { \"min_max\":" + + " {\"minimum\": 0.05," + + " \"maximum\": 1.25}}}" + + " }" + + "}"; StoredLtrModel model = parse(modelJson); StoredFeatureNormalizers ftrNormSet = model.getFeatureNormalizers(); assertNotNull(ftrNormSet); - MinMaxFeatureNormalizer minMaxFtrNorm = (MinMaxFeatureNormalizer)ftrNormSet.getNormalizer("feature_2"); + MinMaxFeatureNormalizer minMaxFtrNorm = (MinMaxFeatureNormalizer) ftrNormSet.getNormalizer("feature_2"); float expectedMin = 0.05f; float expectedMax = 1.25f; @@ -231,7 +233,7 @@ public void testFeatureMinMaxParsing() throws IOException { StoredLtrModel reparsedModel = reparseModel(model); ftrNormSet = reparsedModel.getFeatureNormalizers(); - minMaxFtrNorm = (MinMaxFeatureNormalizer)ftrNormSet.getNormalizer("feature_2"); + minMaxFtrNorm = (MinMaxFeatureNormalizer) ftrNormSet.getNormalizer("feature_2"); testVal = Randomness.get().nextFloat(); expectedNormalized = (testVal - expectedMin) / (expectedMax - expectedMin); @@ -248,19 +250,20 @@ public StoredLtrModel reparseModel(StoredLtrModel srcModel) throws IOException { } public void testSerialization() throws IOException { - String modelJson = "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + getSimpleFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\",\n"+ - " \"feature_normalizers\": {\n"+ - " \"feature_2\": { \"min_max\":" + - " {\"minimum\": 1.0," + - " \"maximum\": 1.25}}}" + - " }" + - "}"; + String modelJson = "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + getSimpleFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\",\n" + + " \"feature_normalizers\": {\n" + + " \"feature_2\": { \"min_max\":" + + " {\"minimum\": 1.0," + + " \"maximum\": 1.25}}}" + + " }" + + "}"; StoredLtrModel model = parse(modelJson); @@ -279,16 +282,15 @@ public void testSerialization() throws IOException { } public void testSerializationModelDef() throws IOException { - String modelDefnJson = "{\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\",\n"+ - " \"feature_normalizers\": {\n"+ - " \"feature_2\": { \"min_max\":" + - " {\"minimum\": 1.0," + - " \"maximum\": 1.25}}}}"; - - XContentParser xContent = jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, modelDefnJson); + String modelDefnJson = "{\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\",\n" + + " \"feature_normalizers\": {\n" + + " \"feature_2\": { \"min_max\":" + + " {\"minimum\": 1.0," + + " \"maximum\": 1.25}}}}"; + + XContentParser xContent = jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, modelDefnJson); StoredLtrModel.LtrModelDefinition modelDef = StoredLtrModel.LtrModelDefinition.parse(xContent, null); BytesStreamOutput out = new BytesStreamOutput(); @@ -307,25 +309,24 @@ public void testSerializationModelDef() throws IOException { // aparo: disabled because it comes first of OpenSearch > 7.10 // public void testSerializationUpgradeBinaryStream() throws IOException { - // // Below is base64 encoded a model with no feature norm data - // // to ensure proper parsing of a binary stream missing ftr norms - // // - // // String modelDefnJson = "{\n" + - // // " \"type\": \"model/dummy\",\n" + - // // " \"definition\": \"completely ignored\"}"; - // String base64Encoded = "C21vZGVsL2R1bW15EmNvbXBsZXRlbHkgaWdub3JlZAE="; - // byte[] bytes = Base64.getDecoder().decode(base64Encoded); - // StreamInput input = ByteBufferStreamInput.wrap(bytes, 0, bytes.length); - // input.setVersion(Version.V_7_6_0); - - // StoredLtrModel.LtrModelDefinition modelUnserialized = new StoredLtrModel.LtrModelDefinition(input); - // assertEquals(modelUnserialized.getDefinition(), "completely ignored"); - // assertEquals(modelUnserialized.getType(), "model/dummy"); - // assertEquals(modelUnserialized.getFtrNorms().numNormalizers(), 0); + // // Below is base64 encoded a model with no feature norm data + // // to ensure proper parsing of a binary stream missing ftr norms + // // + // // String modelDefnJson = "{\n" + + // // " \"type\": \"model/dummy\",\n" + + // // " \"definition\": \"completely ignored\"}"; + // String base64Encoded = "C21vZGVsL2R1bW15EmNvbXBsZXRlbHkgaWdub3JlZAE="; + // byte[] bytes = Base64.getDecoder().decode(base64Encoded); + // StreamInput input = ByteBufferStreamInput.wrap(bytes, 0, bytes.length); + // input.setVersion(Version.V_7_6_0); + + // StoredLtrModel.LtrModelDefinition modelUnserialized = new StoredLtrModel.LtrModelDefinition(input); + // assertEquals(modelUnserialized.getDefinition(), "completely ignored"); + // assertEquals(modelUnserialized.getType(), "model/dummy"); + // assertEquals(modelUnserialized.getFtrNorms().numNormalizers(), 0); // } - public void testToXContent() throws IOException { StoredLtrModel model = parse(getTestModel()); @@ -342,84 +343,85 @@ public void testToXContent() throws IOException { } public void testParseFailureOnMissingName() throws IOException { - String modelString = "{\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), - equalTo("Field [name] is mandatory")); + String modelString = "{\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), equalTo("Field [name] is mandatory")); } public void testParseWithExternalName() throws IOException { - String modelString = "{\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; + String modelString = "{\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; StoredLtrModel model = parse(modelString, "myModel"); assertEquals("myModel", model.name()); } public void testParseWithInconsistentName() throws IOException { - String modelString = "{\n" + - " \"name\": \"myModel\"," + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(modelString, "myModel2")).getMessage(), - equalTo("Invalid [name], expected [myModel2] but got [myModel]")); + String modelString = "{\n" + + " \"name\": \"myModel\"," + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(modelString, "myModel2")).getMessage(), + equalTo("Invalid [name], expected [myModel2] but got [myModel]") + ); } public void testParseFailureOnMissingModel() throws IOException { - String modelString = "{\n" + - " \"name\":\"my_model\",\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), - equalTo("Field [model] is mandatory")); + String modelString = "{\n" + + " \"name\":\"my_model\",\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), equalTo("Field [model] is mandatory")); } public void testParseFailureOnMissingFeatureSet() throws IOException { - String modelString = "{\n" + - " \"name\":\"my_model\",\n" + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), - equalTo("Field [feature_set] is mandatory")); + String modelString = "{\n" + + " \"name\":\"my_model\",\n" + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; + assertThat( + expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), + equalTo("Field [feature_set] is mandatory") + ); } public void testParseFailureOnBogusField() throws IOException { - String modelString = "{\n" + - " \"name\":\"my_model\",\n" + - " \"bogusField\": \"foo\",\n" + - " \"feature_set\":" + - StoredFeatureSetParserTests.generateRandomFeatureSet() + - "," + - " \"model\": {\n" + - " \"type\": \"model/dummy\",\n" + - " \"definition\": \"completely ignored\"\n"+ - " }" + - "}"; - assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), - containsString("bogusField")); + String modelString = "{\n" + + " \"name\":\"my_model\",\n" + + " \"bogusField\": \"foo\",\n" + + " \"feature_set\":" + + StoredFeatureSetParserTests.generateRandomFeatureSet() + + "," + + " \"model\": {\n" + + " \"type\": \"model/dummy\",\n" + + " \"definition\": \"completely ignored\"\n" + + " }" + + "}"; + assertThat(expectThrows(ParsingException.class, () -> parse(modelString)).getMessage(), containsString("bogusField")); } private StoredLtrModel parse(String jsonString) throws IOException { @@ -427,7 +429,6 @@ private StoredLtrModel parse(String jsonString) throws IOException { } private StoredLtrModel parse(String jsonString, String name) throws IOException { - return StoredLtrModel.parse(jsonXContent.createParser(EMPTY, - LoggingDeprecationHandler.INSTANCE, jsonString), name); + return StoredLtrModel.parse(jsonXContent.createParser(EMPTY, LoggingDeprecationHandler.INSTANCE, jsonString), name); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStoreTests.java b/src/test/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStoreTests.java index 0fc2a86b..4baae5c8 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStoreTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/index/CachedFeatureStoreTests.java @@ -16,20 +16,21 @@ package com.o19s.es.ltr.feature.store.index; -import com.o19s.es.ltr.LtrTestUtils; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import com.o19s.es.ltr.feature.store.MemStore; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; + import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; import org.opensearch.common.settings.Settings; -import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.unit.ByteSizeValue; -import java.io.IOException; - -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.ltr.LtrTestUtils; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; +import com.o19s.es.ltr.feature.store.MemStore; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; public class CachedFeatureStoreTests extends LuceneTestCase { private final MemStore memStore = new MemStore(); @@ -53,8 +54,7 @@ public void testCachedFeature() throws IOException { assertEquals(1, caches.getPerStoreStats(memStore.getStoreName()).featureCount()); assertEquals(feat.ramBytesUsed(), caches.getPerStoreStats(memStore.getStoreName()).totalRam()); assertEquals(1, caches.getPerStoreStats(memStore.getStoreName()).totalCount()); - assertThat(expectThrows(IOException.class, () -> store.load("unk")).getCause(), - instanceOf(IllegalArgumentException.class)); + assertThat(expectThrows(IOException.class, () -> store.load("unk")).getCause(), instanceOf(IllegalArgumentException.class)); } public void testCachedFeatureSet() throws IOException { @@ -71,8 +71,7 @@ public void testCachedFeatureSet() throws IOException { assertEquals(set.ramBytesUsed(), caches.getPerStoreStats(memStore.getStoreName()).totalRam()); assertEquals(1, caches.getPerStoreStats(memStore.getStoreName()).totalCount()); - assertThat(expectThrows(IOException.class, () -> store.loadSet("unk")).getCause(), - instanceOf(IllegalArgumentException.class)); + assertThat(expectThrows(IOException.class, () -> store.loadSet("unk")).getCause(), instanceOf(IllegalArgumentException.class)); } public void testCachedModelSet() throws IOException { @@ -88,8 +87,7 @@ public void testCachedModelSet() throws IOException { assertEquals(1, caches.getPerStoreStats(memStore.getStoreName()).modelCount()); assertEquals(model.ramBytesUsed(), caches.getPerStoreStats(memStore.getStoreName()).modelRam()); assertEquals(1, caches.getPerStoreStats(memStore.getStoreName()).totalCount()); - assertThat(expectThrows(IOException.class, () -> store.loadModel("unk")).getCause(), - instanceOf(IllegalArgumentException.class)); + assertThat(expectThrows(IOException.class, () -> store.loadModel("unk")).getCause(), instanceOf(IllegalArgumentException.class)); } public void testWontBlowUp() throws IOException { @@ -190,4 +188,4 @@ public void testCacheStatsIsolation() throws IOException { caches.evict(two.getStoreName()); assertTrue(caches.getCachedStoreNames().isEmpty()); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStoreTests.java b/src/test/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStoreTests.java index 597a1f1f..e582797b 100644 --- a/src/test/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStoreTests.java +++ b/src/test/java/com/o19s/es/ltr/feature/store/index/IndexFeatureStoreTests.java @@ -16,43 +16,44 @@ package com.o19s.es.ltr.feature.store.index; -import com.o19s.es.ltr.LtrTestUtils; -import com.o19s.es.ltr.feature.store.StorableElement; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureNormalizers; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.feature.store.StoredLtrModel; -import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import org.apache.lucene.util.BytesRef; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.STORE_PREFIX; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.indexName; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.isIndexStore; +import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.storeName; +import static org.apache.lucene.tests.util.TestUtil.randomRealisticUnicodeString; +import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; + import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.BytesRef; import org.hamcrest.MatcherAssert; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.get.GetRequestBuilder; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Supplier; - -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.STORE_PREFIX; -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.indexName; -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.isIndexStore; -import static com.o19s.es.ltr.feature.store.index.IndexFeatureStore.storeName; -import static org.apache.lucene.tests.util.TestUtil.randomRealisticUnicodeString; -import static org.apache.lucene.tests.util.TestUtil.randomSimpleString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import com.o19s.es.ltr.LtrTestUtils; +import com.o19s.es.ltr.feature.store.StorableElement; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureNormalizers; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.feature.store.StoredLtrModel; +import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; public class IndexFeatureStoreTests extends LuceneTestCase { @@ -75,13 +76,11 @@ public void testThat_exceptionIsThrown_forNonExistingFeature() { when(getResponseMock.isExists()).thenReturn(false); IndexFeatureStore store = new IndexFeatureStore("index", clientSupplier, mock(LtrRankerParserFactory.class)); - MatcherAssert.assertThat( - expectThrows( - ResourceNotFoundException.class, - () -> store.load("my_feature") - ).getMessage(), + MatcherAssert + .assertThat( + expectThrows(ResourceNotFoundException.class, () -> store.load("my_feature")).getMessage(), equalTo("Unknown feature [my_feature]") - ); + ); } public void testThat_exceptionIsThrown_forNonExistingFeatureSet() { @@ -89,13 +88,11 @@ public void testThat_exceptionIsThrown_forNonExistingFeatureSet() { when(getResponseMock.isExists()).thenReturn(false); IndexFeatureStore store = new IndexFeatureStore("index", clientSupplier, mock(LtrRankerParserFactory.class)); - MatcherAssert.assertThat( - expectThrows( - ResourceNotFoundException.class, - () -> store.loadSet("my_feature_set") - ).getMessage(), + MatcherAssert + .assertThat( + expectThrows(ResourceNotFoundException.class, () -> store.loadSet("my_feature_set")).getMessage(), equalTo("Unknown featureset [my_feature_set]") - ); + ); } public void testThat_exceptionIsThrown_forNonExistingModel() { @@ -103,25 +100,26 @@ public void testThat_exceptionIsThrown_forNonExistingModel() { when(getResponseMock.isExists()).thenReturn(false); IndexFeatureStore store = new IndexFeatureStore("index", clientSupplier, mock(LtrRankerParserFactory.class)); - MatcherAssert.assertThat( - expectThrows( - ResourceNotFoundException.class, - () -> store.loadModel("my_model") - ).getMessage(), + MatcherAssert + .assertThat( + expectThrows(ResourceNotFoundException.class, () -> store.loadModel("my_model")).getMessage(), equalTo("Unknown model [my_model]") - ); + ); } public void testParse() throws Exception { parseAssertions(LtrTestUtils.randomFeature()); parseAssertions(LtrTestUtils.randomFeatureSet()); - parseAssertions(new StoredLtrModel( + parseAssertions( + new StoredLtrModel( randomSimpleString(random(), 5, 10), LtrTestUtils.randomFeatureSet(), randomSimpleString(random(), 5, 10), randomRealisticUnicodeString(random(), 5, 1000), true, - new StoredFeatureNormalizers())); + new StoredFeatureNormalizers() + ) + ); } public void testIsIndexName() { @@ -160,22 +158,19 @@ public void testBadValues() throws IOException { Map map = new HashMap<>(); XContentBuilder builder = XContentBuilder.builder(Requests.INDEX_CONTENT_TYPE.xContent()); BytesReference bytes = BytesReference.bytes(builder.map(map)); - assertThat(expectThrows(IllegalArgumentException.class, - () -> IndexFeatureStore.parse(StoredFeature.class, StoredFeature.TYPE, bytes)) - .getMessage(), equalTo("No StorableElement found.")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> IndexFeatureStore.parse(StoredFeature.class, StoredFeature.TYPE, bytes)) + .getMessage(), + equalTo("No StorableElement found.") + ); builder = XContentBuilder.builder(Requests.INDEX_CONTENT_TYPE.xContent()); map.put("featureset", LtrTestUtils.randomFeatureSet()); BytesReference bytes2 = BytesReference.bytes(builder.map(map)); assertThat( - expectThrows( - IllegalArgumentException.class, - () -> IndexFeatureStore.parse(StoredFeature.class, StoredFeature.TYPE, bytes2) - ).getMessage(), - equalTo( - "Expected an element of type [" + StoredFeature.TYPE + "] but" + - " got [" + StoredFeatureSet.TYPE + "]." - ) + expectThrows(IllegalArgumentException.class, () -> IndexFeatureStore.parse(StoredFeature.class, StoredFeature.TYPE, bytes2)) + .getMessage(), + equalTo("Expected an element of type [" + StoredFeature.TYPE + "] but" + " got [" + StoredFeatureSet.TYPE + "].") ); } @@ -194,11 +189,12 @@ private void parseAssertions(StorableElement elt) throws IOException { } private void assertNameAndTypes(StorableElement elt, BytesReference ref) throws IOException { - XContentParser parser = Requests.INDEX_CONTENT_TYPE.xContent().createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, ref.streamInput()); - Map map = parser.map(); + XContentParser parser = Requests.INDEX_CONTENT_TYPE + .xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, ref.streamInput()); + Map map = parser.map(); assertEquals(elt.name(), map.get("name")); assertEquals(elt.type(), map.get("type")); assertTrue(map.containsKey(elt.type())); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java index b25f8edf..b84bab08 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java @@ -16,17 +16,18 @@ package com.o19s.es.ltr.logging; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import com.o19s.es.ltr.feature.PrebuiltFeature; -import com.o19s.es.ltr.feature.PrebuiltFeatureSet; -import com.o19s.es.ltr.feature.PrebuiltLtrModel; -import com.o19s.es.ltr.logging.LoggingFetchSubPhase.LoggingFetchSubPhaseProcessor; -import com.o19s.es.ltr.query.RankerQuery; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.linear.LinearRankerTests; +import static java.util.Collections.unmodifiableMap; +import static org.opensearch.common.lucene.search.function.FieldValueFactorFunction.Modifier.LN2P; +import static org.opensearch.index.fielddata.IndexNumericFieldData.NumericType.FLOAT; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -48,47 +49,46 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; +import org.junit.AfterClass; +import org.junit.BeforeClass; import org.opensearch.common.lucene.search.function.CombineFunction; import org.opensearch.common.lucene.search.function.FieldValueFactorFunction; import org.opensearch.common.lucene.search.function.FunctionScoreQuery; -import org.opensearch.core.common.text.Text; import org.opensearch.index.fielddata.plain.SortedNumericIndexFieldData; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.FetchSubPhase; import org.opensearch.search.fetch.FetchSubPhaseProcessor; import org.opensearch.search.lookup.SourceLookup; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.UUID; - -import static java.util.Collections.unmodifiableMap; -import static org.opensearch.common.lucene.search.function.FieldValueFactorFunction.Modifier.LN2P; -import static org.opensearch.index.fielddata.IndexNumericFieldData.NumericType.FLOAT; +import com.o19s.es.ltr.feature.PrebuiltFeature; +import com.o19s.es.ltr.feature.PrebuiltFeatureSet; +import com.o19s.es.ltr.feature.PrebuiltLtrModel; +import com.o19s.es.ltr.logging.LoggingFetchSubPhase.LoggingFetchSubPhaseProcessor; +import com.o19s.es.ltr.query.RankerQuery; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.linear.LinearRankerTests; public class LoggingFetchSubPhaseTests extends LuceneTestCase { public static final float FACTOR = 1.2F; private static Directory directory; private static IndexSearcher searcher; - private static Map docs; - private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); + private static Map docs; + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); @BeforeClass public static void init() throws Exception { directory = newDirectory(random()); - try(IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig(new StandardAnalyzer()))) { + try (IndexWriter writer = new IndexWriter(directory, newIndexWriterConfig(new StandardAnalyzer()))) { int nDoc = TestUtil.nextInt(random(), 20, 100); docs = new HashMap<>(); for (int i = 0; i < nDoc; i++) { @@ -122,9 +122,9 @@ public void testLogging() throws IOException { query1 = query1.toLoggerQuery(logger1); query2 = query2.toLoggerQuery(logger2); BooleanQuery query = new BooleanQuery.Builder() - .add(new BooleanClause(query1, BooleanClause.Occur.MUST)) - .add(new BooleanClause(query2, BooleanClause.Occur.MUST)) - .build(); + .add(new BooleanClause(query1, BooleanClause.Occur.MUST)) + .add(new BooleanClause(query2, BooleanClause.Occur.MUST)) + .build(); LoggingFetchSubPhase subPhase = new LoggingFetchSubPhase(); Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1.0F); List loggers = Arrays.asList(logger1, logger2); @@ -151,20 +151,20 @@ public void testLogging() throws IOException { assertTrue(log1.get(0).containsKey("value")); assertEquals((Float) 0.0F, log1.get(0).get("value")); assertTrue(log2.get(0).containsKey("value")); - assertTrue((Float)log2.get(0).get("value") > 0F); + assertTrue((Float) log2.get(0).get("value") > 0F); } - int bits = (int)(long) d.getField("score").numericValue(); + int bits = (int) (long) d.getField("score").numericValue(); float rawScore = Float.intBitsToFloat(bits); - double expectedScore = rawScore*FACTOR; - expectedScore = Math.log1p(expectedScore+1); - assertEquals((float) expectedScore, (Float)log1.get(1).get("value"), Math.ulp((float)expectedScore)); - assertEquals((float) expectedScore, (Float)log1.get(1).get("value"), Math.ulp((float)expectedScore)); + double expectedScore = rawScore * FACTOR; + expectedScore = Math.log1p(expectedScore + 1); + assertEquals((float) expectedScore, (Float) log1.get(1).get("value"), Math.ulp((float) expectedScore)); + assertEquals((float) expectedScore, (Float) log1.get(1).get("value"), Math.ulp((float) expectedScore)); } } public SearchHit[] preprocessRandomHits(FetchSubPhaseProcessor processor) throws IOException { int minHits = TestUtil.nextInt(random(), 5, 10); - int maxHits = TestUtil.nextInt(random(), minHits, minHits+10); + int maxHits = TestUtil.nextInt(random(), minHits, minHits + 10); List hits = new ArrayList<>(maxHits); searcher.search(new MatchAllDocsQuery(), new SimpleCollector() { /** @@ -189,12 +189,7 @@ public void collect(int doc) throws IOException { if (hits.size() < minHits || (random().nextBoolean() && hits.size() < maxHits)) { Document d = context.reader().document(doc); String id = d.get("id"); - SearchHit hit = new SearchHit( - doc, - id, - random().nextBoolean() ? new HashMap<>() : null, - null - ); + SearchHit hit = new SearchHit(doc, id, random().nextBoolean() ? new HashMap<>() : null, null); processor.process(new FetchSubPhase.HitContext(hit, context, doc, new SourceLookup())); hits.add(hit); } @@ -225,9 +220,13 @@ public RankerQuery buildQuery(String text) { } public Query buildFunctionScore() { - FieldValueFactorFunction fieldValueFactorFunction = new FieldValueFactorFunction("score", FACTOR, LN2P, 0D, - new SortedNumericIndexFieldData("score", FLOAT)); - return new FunctionScoreQuery(new MatchAllDocsQuery(), - fieldValueFactorFunction, CombineFunction.MULTIPLY, 0F, Float.MAX_VALUE); + FieldValueFactorFunction fieldValueFactorFunction = new FieldValueFactorFunction( + "score", + FACTOR, + LN2P, + 0D, + new SortedNumericIndexFieldData("score", FLOAT) + ); + return new FunctionScoreQuery(new MatchAllDocsQuery(), fieldValueFactorFunction, CombineFunction.MULTIPLY, 0F, Float.MAX_VALUE); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilderTests.java b/src/test/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilderTests.java index 975a5093..a151d24f 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilderTests.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingSearchExtBuilderTests.java @@ -16,21 +16,21 @@ package com.o19s.es.ltr.logging; -import org.opensearch.core.common.ParsingException; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.opensearch.test.OpenSearchTestCase; +import static com.o19s.es.ltr.logging.LoggingSearchExtBuilder.parse; +import static org.hamcrest.CoreMatchers.containsString; import java.io.IOException; import java.util.List; import java.util.stream.Collectors; -import static com.o19s.es.ltr.logging.LoggingSearchExtBuilder.parse; -import static org.hamcrest.CoreMatchers.containsString; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; public class LoggingSearchExtBuilderTests extends OpenSearchTestCase { public LoggingSearchExtBuilder buildTestExt() { @@ -43,11 +43,11 @@ public LoggingSearchExtBuilder buildTestExt() { } public String getTestExtAsString() { - return "{\"log_specs\":[" + - "{\"name\":\"name1\",\"named_query\":\"query1\",\"missing_as_zero\":true}," + - "{\"named_query\":\"query2\"}," + - "{\"name\":\"rescore0\",\"rescore_index\":0,\"missing_as_zero\":true}," + - "{\"rescore_index\":1}]}"; + return "{\"log_specs\":[" + + "{\"name\":\"name1\",\"named_query\":\"query1\",\"missing_as_zero\":true}," + + "{\"named_query\":\"query2\"}," + + "{\"name\":\"rescore0\",\"rescore_index\":0,\"missing_as_zero\":true}," + + "{\"rescore_index\":1}]}"; } public void testEquals() { @@ -83,38 +83,26 @@ public void testSer() throws IOException { public void testFailOnNoLogSpecs() throws IOException { String data = "{}"; - ParsingException exc = expectThrows(ParsingException.class, - () ->parse(createParser(JsonXContent.jsonXContent, data))); - assertThat(exc.getMessage(), - containsString("should define at least one [log_specs]")); + ParsingException exc = expectThrows(ParsingException.class, () -> parse(createParser(JsonXContent.jsonXContent, data))); + assertThat(exc.getMessage(), containsString("should define at least one [log_specs]")); } public void testFailOnEmptyLogSpecs() throws IOException { String data = "{\"log_specs\":[]}"; - ParsingException exc = expectThrows(ParsingException.class, - () ->parse(createParser(JsonXContent.jsonXContent, data))); - assertThat(exc.getMessage(), - containsString("should define at least one [log_specs]")); + ParsingException exc = expectThrows(ParsingException.class, () -> parse(createParser(JsonXContent.jsonXContent, data))); + assertThat(exc.getMessage(), containsString("should define at least one [log_specs]")); } public void testFailOnBadLogSpec() throws IOException { - String data = "{\"log_specs\":[" + - "{\"name\":\"name1\",\"missing_as_zero\":true}," + - "]}"; - ParsingException exc = expectThrows(ParsingException.class, - () ->parse(createParser(JsonXContent.jsonXContent, data))); - assertThat(exc.getCause().getCause().getMessage(), - containsString("Either [named_query] or [rescore_index] must be set")); + String data = "{\"log_specs\":[" + "{\"name\":\"name1\",\"missing_as_zero\":true}," + "]}"; + ParsingException exc = expectThrows(ParsingException.class, () -> parse(createParser(JsonXContent.jsonXContent, data))); + assertThat(exc.getCause().getCause().getMessage(), containsString("Either [named_query] or [rescore_index] must be set")); } public void testFailOnNegativeRescoreIndex() throws IOException { - String data = "{\"log_specs\":[" + - "{\"name\":\"name1\",\"rescore_index\":-1, \"missing_as_zero\":true}," + - "]}"; - ParsingException exc = expectThrows(ParsingException.class, - () ->parse(createParser(JsonXContent.jsonXContent, data))); - assertThat(exc.getCause().getCause().getMessage(), - containsString("non-negative")); + String data = "{\"log_specs\":[" + "{\"name\":\"name1\",\"rescore_index\":-1, \"missing_as_zero\":true}," + "]}"; + ParsingException exc = expectThrows(ParsingException.class, () -> parse(createParser(JsonXContent.jsonXContent, data))); + assertThat(exc.getCause().getCause().getMessage(), containsString("non-negative")); } public void assertTestExt(LoggingSearchExtBuilder actual) { @@ -144,4 +132,4 @@ public void assertTestExt(LoggingSearchExtBuilder actual) { assertEquals((Integer) 1, l.getRescoreIndex()); assertFalse(l.isMissingAsZero()); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/query/LtrQueryBuilderTests.java b/src/test/java/com/o19s/es/ltr/query/LtrQueryBuilderTests.java index 93cd7df3..ed4ab8de 100644 --- a/src/test/java/com/o19s/es/ltr/query/LtrQueryBuilderTests.java +++ b/src/test/java/com/o19s/es/ltr/query/LtrQueryBuilderTests.java @@ -16,33 +16,34 @@ */ package com.o19s.es.ltr.query; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import com.o19s.es.ltr.LtrQueryParserPlugin; +import static java.util.Collections.unmodifiableMap; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + import org.apache.lucene.search.Query; import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.WrapperQueryBuilder; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.plugins.Plugin; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; import org.opensearch.test.AbstractQueryTestCase; import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; - -import static java.util.Collections.unmodifiableMap; -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.ltr.LtrQueryParserPlugin; /** * Created by doug on 12/27/16. @@ -53,93 +54,98 @@ public class LtrQueryBuilderTests extends AbstractQueryTestCase protected Collection> getPlugins() { return Arrays.asList(LtrQueryParserPlugin.class, TestGeoShapeFieldMapperPlugin.class); } - private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); - private static final String simpleModel = "## LambdaMART\\n" + - "## name:foo\\n" + - "## No. of trees = 1\\n" + - "## No. of leaves = 10\\n" + - "## No. of threshold candidates = 256\\n" + - "## Learning rate = 0.1\\n" + - "## Stop early = 100\\n" + - "\\n" + - "\\n" + - " \\n" + - " \\n" + - " 1 \\n" + - " 0.45867884 \\n" + - " \\n" + - " 1 \\n" + - " 0.0 \\n" + - " \\n" + - " -2.0 \\n" + - " \\n" + - " \\n" + - " -1.3413081169128418 \\n" + - " \\n" + - " \\n" + - " \\n" + - " 1 \\n" + - " 0.6115718 \\n" + - " \\n" + - " 0.3089442849159241 \\n" + - " \\n" + - " \\n" + - " 2.0 \\n" + - " \\n" + - " \\n" + - " \\n" + - " \\n" + - ""; + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); + + private static final String simpleModel = "## LambdaMART\\n" + + "## name:foo\\n" + + "## No. of trees = 1\\n" + + "## No. of leaves = 10\\n" + + "## No. of threshold candidates = 256\\n" + + "## Learning rate = 0.1\\n" + + "## Stop early = 100\\n" + + "\\n" + + "\\n" + + " \\n" + + " \\n" + + " 1 \\n" + + " 0.45867884 \\n" + + " \\n" + + " 1 \\n" + + " 0.0 \\n" + + " \\n" + + " -2.0 \\n" + + " \\n" + + " \\n" + + " -1.3413081169128418 \\n" + + " \\n" + + " \\n" + + " \\n" + + " 1 \\n" + + " 0.6115718 \\n" + + " \\n" + + " 0.3089442849159241 \\n" + + " \\n" + + " \\n" + + " 2.0 \\n" + + " \\n" + + " \\n" + + " \\n" + + " \\n" + + ""; public void testCachedQueryParsing() throws IOException { String scriptSpec = "{\"source\": \"" + simpleModel + "\"}"; - String ltrQuery = "{ " + - " \"ltr\": {" + - " \"model\": " + scriptSpec + ", " + - " \"features\": [ " + - " {\"match\": { " + - " \"foo\": \"bar\" " + - " }}, " + - " {\"match\": { " + - " \"baz\": \"sham\" " + - " }} " + - " ] " + - " } " + - "}"; - LtrQueryBuilder queryBuilder = (LtrQueryBuilder)parseQuery(ltrQuery); + String ltrQuery = "{ " + + " \"ltr\": {" + + " \"model\": " + + scriptSpec + + ", " + + " \"features\": [ " + + " {\"match\": { " + + " \"foo\": \"bar\" " + + " }}, " + + " {\"match\": { " + + " \"baz\": \"sham\" " + + " }} " + + " ] " + + " } " + + "}"; + LtrQueryBuilder queryBuilder = (LtrQueryBuilder) parseQuery(ltrQuery); } public void testNamedFeatures() throws IOException { String scriptSpec = "{\"source\": \"" + simpleModel + "\"}"; - String ltrQuery = "{ " + - " \"ltr\": {" + - " \"model\": " + scriptSpec + ", " + - " \"features\": [ " + - " {\"match\": { " + - " \"foo\": { " + - " \"query\": \"bar\", " + - " \"_name\": \"bar_query\" " + - " }}}, " + - " {\"match\": { " + - " \"baz\": {" + - " \"query\": \"sham\"," + - " \"_name\": \"sham_query\" " + - " }}} " + - " ] " + - " } " + - "}"; - LtrQueryBuilder queryBuilder = (LtrQueryBuilder)parseQuery(ltrQuery); + String ltrQuery = "{ " + + " \"ltr\": {" + + " \"model\": " + + scriptSpec + + ", " + + " \"features\": [ " + + " {\"match\": { " + + " \"foo\": { " + + " \"query\": \"bar\", " + + " \"_name\": \"bar_query\" " + + " }}}, " + + " {\"match\": { " + + " \"baz\": {" + + " \"query\": \"sham\"," + + " \"_name\": \"sham_query\" " + + " }}} " + + " ] " + + " } " + + "}"; + LtrQueryBuilder queryBuilder = (LtrQueryBuilder) parseQuery(ltrQuery); queryBuilder.ltrStats(ltrStats); QueryShardContext context = createShardContext(); - RankerQuery query = (RankerQuery)queryBuilder.toQuery(context); + RankerQuery query = (RankerQuery) queryBuilder.toQuery(context); assertEquals(query.getFeature(0).name(), "bar_query"); assertEquals(query.getFeature(1).name(), "sham_query"); @@ -148,26 +154,28 @@ public void testNamedFeatures() throws IOException { public void testUnnamedFeatures() throws IOException { String scriptSpec = "{\"source\": \"" + simpleModel + "\"}"; - String ltrQuery = "{ " + - " \"ltr\": {" + - " \"model\": " + scriptSpec + ", " + - " \"features\": [ " + - " {\"match\": { " + - " \"foo\": { " + - " \"query\": \"bar\" " + - " }}}, " + - " {\"match\": { " + - " \"baz\": {" + - " \"query\": \"sham\"," + - " \"_name\": \"\" " + - " }}} " + - " ] " + - " } " + - "}"; - LtrQueryBuilder queryBuilder = (LtrQueryBuilder)parseQuery(ltrQuery); + String ltrQuery = "{ " + + " \"ltr\": {" + + " \"model\": " + + scriptSpec + + ", " + + " \"features\": [ " + + " {\"match\": { " + + " \"foo\": { " + + " \"query\": \"bar\" " + + " }}}, " + + " {\"match\": { " + + " \"baz\": {" + + " \"query\": \"sham\"," + + " \"_name\": \"\" " + + " }}} " + + " ] " + + " } " + + "}"; + LtrQueryBuilder queryBuilder = (LtrQueryBuilder) parseQuery(ltrQuery); queryBuilder.ltrStats(ltrStats); QueryShardContext context = createShardContext(); - RankerQuery query = (RankerQuery)queryBuilder.toQuery(context); + RankerQuery query = (RankerQuery) queryBuilder.toQuery(context); assertNull(query.getFeature(0).name()); assertEquals(query.getFeature(1).name(), ""); @@ -191,15 +199,17 @@ public void testCacheability() throws IOException { @Override protected LtrQueryBuilder doCreateTestQueryBuilder() { LtrQueryBuilder builder = new LtrQueryBuilder(); - builder.features(Arrays.asList( - new MatchQueryBuilder("foo", "bar"), - new MatchQueryBuilder("baz", "sham") - )); - builder.rankerScript(new Script(ScriptType.INLINE, "ranklib", - // Remove escape sequences - simpleModel.replace("\\\"", "\"") - .replace("\\n", "\n"), - Collections.emptyMap())); + builder.features(Arrays.asList(new MatchQueryBuilder("foo", "bar"), new MatchQueryBuilder("baz", "sham"))); + builder + .rankerScript( + new Script( + ScriptType.INLINE, + "ranklib", + // Remove escape sequences + simpleModel.replace("\\\"", "\"").replace("\\n", "\n"), + Collections.emptyMap() + ) + ); builder.ltrStats(ltrStats); return builder; } diff --git a/src/test/java/com/o19s/es/ltr/query/LtrQueryTests.java b/src/test/java/com/o19s/es/ltr/query/LtrQueryTests.java index a9fe0d9f..426b0123 100644 --- a/src/test/java/com/o19s/es/ltr/query/LtrQueryTests.java +++ b/src/test/java/com/o19s/es/ltr/query/LtrQueryTests.java @@ -16,29 +16,19 @@ */ package com.o19s.es.ltr.query; -import ciir.umass.edu.learning.DataPoint; -import ciir.umass.edu.learning.RANKER_TYPE; -import ciir.umass.edu.learning.RankList; -import ciir.umass.edu.learning.Ranker; -import ciir.umass.edu.learning.RankerFactory; -import ciir.umass.edu.learning.RankerTrainer; -import ciir.umass.edu.metric.NDCGScorer; -import ciir.umass.edu.utilities.MyThreadPool; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.PrebuiltFeature; -import com.o19s.es.ltr.feature.PrebuiltFeatureSet; -import com.o19s.es.ltr.feature.PrebuiltLtrModel; -import com.o19s.es.ltr.ranker.LogLtrRanker; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; -import com.o19s.es.ltr.ranker.ranklib.DenseProgramaticDataPoint; -import com.o19s.es.ltr.ranker.ranklib.RanklibRanker; +import static java.util.Collections.unmodifiableMap; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.stream.Collectors; + import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.Field.Store; @@ -46,7 +36,6 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.index.Term; import org.apache.lucene.misc.SweetSpotSimilarity; import org.apache.lucene.queries.BlendedTermQuery; @@ -80,47 +69,59 @@ import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.TFIDFSimilarity; import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.common.lucene.search.function.FunctionScoreQuery; -import org.opensearch.common.lucene.search.function.WeightFactorFunction; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; +import org.opensearch.common.lucene.search.function.FunctionScoreQuery; +import org.opensearch.common.lucene.search.function.WeightFactorFunction; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; -import java.util.stream.Collectors; -import static java.util.Collections.unmodifiableMap; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.PrebuiltFeature; +import com.o19s.es.ltr.feature.PrebuiltFeatureSet; +import com.o19s.es.ltr.feature.PrebuiltLtrModel; +import com.o19s.es.ltr.ranker.LogLtrRanker; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; +import com.o19s.es.ltr.ranker.ranklib.DenseProgramaticDataPoint; +import com.o19s.es.ltr.ranker.ranklib.RanklibRanker; + +import ciir.umass.edu.learning.DataPoint; +import ciir.umass.edu.learning.RANKER_TYPE; +import ciir.umass.edu.learning.RankList; +import ciir.umass.edu.learning.Ranker; +import ciir.umass.edu.learning.RankerFactory; +import ciir.umass.edu.learning.RankerTrainer; +import ciir.umass.edu.metric.NDCGScorer; +import ciir.umass.edu.utilities.MyThreadPool; @LuceneTestCase.SuppressSysoutChecks(bugUrl = "RankURL does this when training models... ") public class LtrQueryTests extends LuceneTestCase { // Number of ULPs allowed when checking scores equality private static final int SCORE_NB_ULP_PREC = 1; + private int[] range(int start, int stop) { + int[] result = new int[stop - start]; - private int[] range(int start, int stop) - { - int[] result = new int[stop-start]; - - for(int i=0;i>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); private Field newField(String name, String value, Store stored) { FieldType tagsFieldType = new FieldType(); @@ -137,15 +138,13 @@ private Field newField(String name, String value, Store stored) { private Similarity similarity; // docs with doc ids array index - private final String[] docs = new String[] { "how now brown cow", - "brown is the color of cows", - "brown cow", - "banana cows are yummy"}; + private final String[] docs = new String[] { "how now brown cow", "brown is the color of cows", "brown cow", "banana cows are yummy" }; @Before public void setupIndex() throws IOException { dirUnderTest = newDirectory(); - List sims = Arrays.asList( + List sims = Arrays + .asList( new ClassicSimilarity(), new SweetSpotSimilarity(), // extends Classic new BM25Similarity(), @@ -170,7 +169,6 @@ public void setupIndex() throws IOException { indexWriterUnderTest.forceMerge(1); indexWriterUnderTest.flush(); - indexReaderUnderTest = indexWriterUnderTest.getReader(); searcherUnderTest = newSearcher(indexReaderUnderTest); searcherUnderTest.setSimilarity(similarity); @@ -206,6 +204,7 @@ public void reset() { private LeafReaderContext context; private Scorable scorer; + /** * Indicates what features are required from the scorer. */ @@ -235,13 +234,15 @@ public void collect(int doc) throws IOException { return featuresPerDoc; } - public List makeQueryJudgements(int qid, - Map> featuresPerDoc, - int modelSize, - Float[] relevanceGradesPerDoc, - Map ftrNorms) { - assert(featuresPerDoc.size() == docs.length); - assert(relevanceGradesPerDoc.length == docs.length); + public List makeQueryJudgements( + int qid, + Map> featuresPerDoc, + int modelSize, + Float[] relevanceGradesPerDoc, + Map ftrNorms + ) { + assert (featuresPerDoc.size() == docs.length); + assert (relevanceGradesPerDoc.length == docs.length); List rVal = new ArrayList<>(); SortedMap points = new TreeMap<>(); @@ -250,15 +251,13 @@ public List makeQueryJudgements(int qid, int docId = Integer.decode(doc); dp.setLabel(relevanceGradesPerDoc[docId]); dp.setID(String.valueOf(qid)); - vector.forEach( - (final Integer ftrOrd, Float score) -> { - Normalizer ftrNorm = ftrNorms.get(ftrOrd); - if (ftrNorm != null) { - score = ftrNorm.normalize(score); - } - dp.setFeatureScore(ftrOrd, score); + vector.forEach((final Integer ftrOrd, Float score) -> { + Normalizer ftrNorm = ftrNorms.get(ftrOrd); + if (ftrNorm != null) { + score = ftrNorm.normalize(score); } - ); + dp.setFeatureScore(ftrOrd, score); + }); points.put(docId, dp); }); points.forEach((k, v) -> rVal.add(v)); @@ -268,7 +267,7 @@ public List makeQueryJudgements(int qid, public void checkFeatureNames(Explanation expl, List features) { Explanation[] expls = expl.getDetails(); int ftrIdx = 0; - for (Explanation ftrExpl: expls) { + for (Explanation ftrExpl : expls) { String ftrName = features.get(ftrIdx).name(); String expectedFtrName; if (ftrName == null) { @@ -277,30 +276,30 @@ public void checkFeatureNames(Explanation expl, List features) expectedFtrName = "Feature " + ftrIdx + "(" + ftrName + "):"; } - String ftrExplainStart = ftrExpl.getDescription().substring(0,expectedFtrName.length()); + String ftrExplainStart = ftrExpl.getDescription().substring(0, expectedFtrName.length()); assertEquals(expectedFtrName, ftrExplainStart); ftrIdx++; } } - public void checkModelWithFeatures(List features, int[] modelFeatures, - Map ftrNorms) throws IOException { + public void checkModelWithFeatures(List features, int[] modelFeatures, Map ftrNorms) + throws IOException { // Each RankList needed for training corresponds to one query, // or that apperas how RankLib wants the data List samples = new ArrayList<>(); - Map> rawFeaturesPerDoc = getFeatureScores(features, 0.0f); + Map> rawFeaturesPerDoc = getFeatureScores(features, 0.0f); if (ftrNorms == null) { ftrNorms = new HashMap<>(); } // Normalize prior to training - // these ranklists have been normalized for training - RankList rl = new RankList(makeQueryJudgements(0, rawFeaturesPerDoc, features.size(), - new Float[] {3.0f, 2.0f, 4.0f, 0.0f}, ftrNorms)); + RankList rl = new RankList( + makeQueryJudgements(0, rawFeaturesPerDoc, features.size(), new Float[] { 3.0f, 2.0f, 4.0f, 0.0f }, ftrNorms) + ); samples.add(rl); int[] featuresToUse = modelFeatures; @@ -311,23 +310,25 @@ public void checkModelWithFeatures(List features, int[] modelFe // each RankList appears to correspond to a // query RankerTrainer trainer = new RankerTrainer(); - Ranker ranker = trainer.train(/*what type of model ot train*/RANKER_TYPE.RANKNET, - /*The training data*/ samples - /*which features to use*/, featuresToUse - /*how to score ranking*/, new NDCGScorer()); - float[] scores = {(float)ranker.eval(rl.get(0)), (float)ranker.eval(rl.get(1)), - (float)ranker.eval(rl.get(2)), (float)ranker.eval(rl.get(3))}; + Ranker ranker = trainer.train(/*what type of model ot train*/RANKER_TYPE.RANKNET, /*The training data*/ samples + /*which features to use*/, featuresToUse + /*how to score ranking*/, new NDCGScorer()); + float[] scores = { + (float) ranker.eval(rl.get(0)), + (float) ranker.eval(rl.get(1)), + (float) ranker.eval(rl.get(2)), + (float) ranker.eval(rl.get(3)) }; // Ok now lets rerun that as a Lucene Query RankerQuery ltrQuery = toRankerQuery(features, ranker, ftrNorms); TopDocs topDocs = searcherUnderTest.search(ltrQuery, 10); ScoreDoc[] scoreDocs = topDocs.scoreDocs; - assert(scoreDocs.length == docs.length); + assert (scoreDocs.length == docs.length); ScoreDoc sc = scoreDocs[0]; scoreDocs[0] = scoreDocs[2]; scoreDocs[2] = sc; - for (ScoreDoc scoreDoc: scoreDocs) { + for (ScoreDoc scoreDoc : scoreDocs) { assertScoresMatch(features, scores, ltrQuery, scoreDoc); } @@ -336,35 +337,46 @@ public void checkModelWithFeatures(List features, int[] modelFe String modelAsStr = ranker.model(); RankerFactory rankerFactory = new RankerFactory(); Ranker rankerAgain = rankerFactory.loadRankerFromString(modelAsStr); - float[] scoresAgain = {(float)ranker.eval(rl.get(0)), (float)ranker.eval(rl.get(1)), - (float)ranker.eval(rl.get(2)), (float)ranker.eval(rl.get(3))}; + float[] scoresAgain = { + (float) ranker.eval(rl.get(0)), + (float) ranker.eval(rl.get(1)), + (float) ranker.eval(rl.get(2)), + (float) ranker.eval(rl.get(3)) }; topDocs = searcherUnderTest.search(ltrQuery, 10); scoreDocs = topDocs.scoreDocs; - assert(scoreDocs.length == docs.length); - for (ScoreDoc scoreDoc: scoreDocs) { + assert (scoreDocs.length == docs.length); + for (ScoreDoc scoreDoc : scoreDocs) { assertScoresMatch(features, scoresAgain, ltrQuery, scoreDoc); } } - private void assertScoresMatch(List features, float[] scores, - RankerQuery ltrQuery, ScoreDoc scoreDoc) throws IOException { + private void assertScoresMatch(List features, float[] scores, RankerQuery ltrQuery, ScoreDoc scoreDoc) + throws IOException { Document d = searcherUnderTest.doc(scoreDoc.doc); String idVal = d.get("id"); int docId = Integer.decode(idVal); float modelScore = scores[docId]; float queryScore = scoreDoc.score; - assertEquals("Scores match with similarity " + similarity.getClass(), modelScore, - queryScore, SCORE_NB_ULP_PREC *Math.ulp(modelScore)); + assertEquals( + "Scores match with similarity " + similarity.getClass(), + modelScore, + queryScore, + SCORE_NB_ULP_PREC * Math.ulp(modelScore) + ); if (!(similarity instanceof TFIDFSimilarity)) { // There are precision issues with these similarities when using explain // It produces 0.56103003 for feat:0 in doc1 using score() but 0.5610301 using explain Explanation expl = searcherUnderTest.explain(ltrQuery, docId); - assertEquals("Explain scores match with similarity " + similarity.getClass(), expl.getValue().floatValue(), - queryScore, 5 * Math.ulp(modelScore)); + assertEquals( + "Explain scores match with similarity " + similarity.getClass(), + expl.getValue().floatValue(), + queryScore, + 5 * Math.ulp(modelScore) + ); checkFeatureNames(expl, features); } } @@ -380,74 +392,75 @@ private RankerQuery toRankerQuery(List features, Ranker ranker, public void testTrainModel() throws IOException { String userQuery = "brown cow"; - List features = Arrays.asList( - new TermQuery(new Term("field", userQuery.split(" ")[0])), - new PhraseQuery("field", userQuery.split(" "))); + List features = Arrays + .asList(new TermQuery(new Term("field", userQuery.split(" ")[0])), new PhraseQuery("field", userQuery.split(" "))); checkModelWithFeatures(toPrebuildFeatureWithNoName(features), null, null); } public void testSubsetFeaturesFuncScore() throws IOException { - // public LambdaMART(List samples, int[] features, MetricScorer scorer) { + // public LambdaMART(List samples, int[] features, MetricScorer scorer) { String userQuery = "brown cow"; Query baseQuery = new MatchAllDocsQuery(); - List features = Arrays.asList( - new TermQuery(new Term("field", userQuery.split(" ")[0])), + List features = Arrays + .asList( + new TermQuery(new Term("field", userQuery.split(" ")[0])), new PhraseQuery("field", userQuery.split(" ")), - new FunctionScoreQuery(baseQuery, new WeightFactorFunction(1.0f)) ); - checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] {1}, null); + new FunctionScoreQuery(baseQuery, new WeightFactorFunction(1.0f)) + ); + checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] { 1 }, null); } public void testSubsetFeaturesTermQ() throws IOException { - // public LambdaMART(List samples, int[] features, MetricScorer scorer) { + // public LambdaMART(List samples, int[] features, MetricScorer scorer) { String userQuery = "brown cow"; Query baseQuery = new MatchAllDocsQuery(); - List features = Arrays.asList( - new TermQuery(new Term("field", userQuery.split(" ")[0])), + List features = Arrays + .asList( + new TermQuery(new Term("field", userQuery.split(" ")[0])), new PhraseQuery("field", userQuery.split(" ")), - new PhraseQuery(1, "field", userQuery.split(" ") )); - checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] {1}, null); + new PhraseQuery(1, "field", userQuery.split(" ")) + ); + checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] { 1 }, null); } - public void testExplainWithNames() throws IOException { - // public LambdaMART(List samples, int[] features, MetricScorer scorer) { + // public LambdaMART(List samples, int[] features, MetricScorer scorer) { String userQuery = "brown cow"; - List features = Arrays.asList( - new PrebuiltFeature("funky_term_q", new TermQuery(new Term("field", userQuery.split(" ")[0]))), - new PrebuiltFeature("funky_phrase_q", new PhraseQuery("field", userQuery.split(" ")))); + List features = Arrays + .asList( + new PrebuiltFeature("funky_term_q", new TermQuery(new Term("field", userQuery.split(" ")[0]))), + new PrebuiltFeature("funky_phrase_q", new PhraseQuery("field", userQuery.split(" "))) + ); checkModelWithFeatures(features, null, null); } public void testOnRewrittenQueries() throws IOException { String userQuery = "brown cow"; - Term[] termsToBlend = new Term[]{new Term("field", userQuery.split(" ")[0])}; + Term[] termsToBlend = new Term[] { new Term("field", userQuery.split(" ")[0]) }; Query blended = BlendedTermQuery.dismaxBlendedQuery(termsToBlend, 1f); - List features = Arrays.asList(new TermQuery(new Term("field", userQuery.split(" ")[0])), blended); + List features = Arrays.asList(new TermQuery(new Term("field", userQuery.split(" ")[0])), blended); checkModelWithFeatures(toPrebuildFeatureWithNoName(features), null, null); } private List toPrebuildFeatureWithNoName(List features) { - return features.stream() - .map(x -> new PrebuiltFeature(null, x)) - .collect(Collectors.toList()); + return features.stream().map(x -> new PrebuiltFeature(null, x)).collect(Collectors.toList()); } public void testNoMatchQueries() throws IOException { String userQuery = "brown cow"; - Term[] termsToBlend = new Term[]{new Term("field", userQuery.split(" ")[0])}; + Term[] termsToBlend = new Term[] { new Term("field", userQuery.split(" ")[0]) }; Query blended = BlendedTermQuery.dismaxBlendedQuery(termsToBlend, 1f); - List features = Arrays.asList( - new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))), - new PrebuiltFeature(null, blended)); + List features = Arrays + .asList(new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))), new PrebuiltFeature(null, blended)); checkModelWithFeatures(features, null, null); } @@ -455,9 +468,11 @@ public void testNoMatchQueries() throws IOException { public void testMatchingNormalizedQueries() throws IOException { String userQuery = "brown cow"; - List features = Arrays.asList( - new PrebuiltFeature(null, new TermQuery(new Term("field", "brown"))), - new PrebuiltFeature(null, new TermQuery(new Term("field", "cow")))); + List features = Arrays + .asList( + new PrebuiltFeature(null, new TermQuery(new Term("field", "brown"))), + new PrebuiltFeature(null, new TermQuery(new Term("field", "cow"))) + ); Map ftrNorms = new HashMap<>(); ftrNorms.put(0, new StandardFeatureNormalizer(1, 0.5f)); ftrNorms.put(1, new StandardFeatureNormalizer(1, 0.5f)); @@ -465,11 +480,12 @@ public void testMatchingNormalizedQueries() throws IOException { checkModelWithFeatures(features, null, ftrNorms); } - public void testNoMatchNormalizedQueries() throws IOException { - List features = Arrays.asList( - new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))), - new PrebuiltFeature(null, new TermQuery(new Term("field", "othermissingterm")))); + List features = Arrays + .asList( + new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))), + new PrebuiltFeature(null, new TermQuery(new Term("field", "othermissingterm"))) + ); Map ftrNorms = new HashMap<>(); ftrNorms.put(0, new StandardFeatureNormalizer(0.5f, 1)); ftrNorms.put(1, new StandardFeatureNormalizer(0.7f, 0.2f)); diff --git a/src/test/java/com/o19s/es/ltr/query/StoredLtrQueryBuilderTests.java b/src/test/java/com/o19s/es/ltr/query/StoredLtrQueryBuilderTests.java index 9eb029ab..52e36f98 100644 --- a/src/test/java/com/o19s/es/ltr/query/StoredLtrQueryBuilderTests.java +++ b/src/test/java/com/o19s/es/ltr/query/StoredLtrQueryBuilderTests.java @@ -16,66 +16,67 @@ package com.o19s.es.ltr.query; -import org.opensearch.ltr.stats.LTRStat; -import org.opensearch.ltr.stats.LTRStats; -import org.opensearch.ltr.stats.StatName; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; -import com.o19s.es.ltr.LtrQueryParserPlugin; -import com.o19s.es.ltr.LtrTestUtils; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import com.o19s.es.ltr.feature.store.MemStore; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.linear.LinearRanker; -import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; -import com.o19s.es.ltr.utils.FeatureStoreLoader; +import static java.util.Collections.unmodifiableMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; -import org.opensearch.core.common.io.stream.ByteBufferStreamInput; +import org.junit.Before; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.common.lucene.search.function.FieldValueFactorFunction; import org.opensearch.common.lucene.search.function.FunctionScoreQuery; import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.ByteBufferStreamInput; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.Rewriteable; import org.opensearch.index.query.functionscore.FieldValueFactorFunctionBuilder; import org.opensearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.opensearch.ltr.stats.LTRStat; +import org.opensearch.ltr.stats.LTRStats; +import org.opensearch.ltr.stats.StatName; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; import org.opensearch.plugins.Plugin; import org.opensearch.test.AbstractQueryTestCase; import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import org.junit.Before; - -import java.io.IOException; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import static java.util.Collections.unmodifiableMap; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.ltr.LtrQueryParserPlugin; +import com.o19s.es.ltr.LtrTestUtils; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; +import com.o19s.es.ltr.feature.store.MemStore; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.linear.LinearRanker; +import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; +import com.o19s.es.ltr.utils.FeatureStoreLoader; public class StoredLtrQueryBuilderTests extends AbstractQueryTestCase { private static final MemStore store = new MemStore(); - private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); // TODO: Remove the TestGeoShapeFieldMapperPlugin once upstream has completed the migration. protected Collection> getPlugins() { @@ -98,24 +99,32 @@ protected Set getObjectsHoldingArbitraryContent() { public void setUp() throws Exception { super.setUp(); store.clear(); - StoredFeature feature1 = new StoredFeature("match1", Collections.singletonList("query_string"), - "mustache", - new MatchQueryBuilder("field1", "{{query_string}}").toString()); - StoredFeature feature2 = new StoredFeature("match2", Collections.singletonList("query_string"), - "mustache", - new MatchQueryBuilder("field2", "{{query_string}}").toString()); - StoredFeature feature3 = new StoredFeature("score3", Collections.emptyList(), - "mustache", - new FunctionScoreQueryBuilder(new FieldValueFactorFunctionBuilder("scorefield2") - .factor(1.2F) - .modifier(FieldValueFactorFunction.Modifier.LN2P) - .missing(0F)).toString()); + StoredFeature feature1 = new StoredFeature( + "match1", + Collections.singletonList("query_string"), + "mustache", + new MatchQueryBuilder("field1", "{{query_string}}").toString() + ); + StoredFeature feature2 = new StoredFeature( + "match2", + Collections.singletonList("query_string"), + "mustache", + new MatchQueryBuilder("field2", "{{query_string}}").toString() + ); + StoredFeature feature3 = new StoredFeature( + "score3", + Collections.emptyList(), + "mustache", + new FunctionScoreQueryBuilder( + new FieldValueFactorFunctionBuilder("scorefield2").factor(1.2F).modifier(FieldValueFactorFunction.Modifier.LN2P).missing(0F) + ).toString() + ); StoredFeatureSet set = new StoredFeatureSet("set1", Arrays.asList(feature1, feature2, feature3)); store.add(set); - LtrRanker ranker = new LinearRanker(new float[] {0.1F, 0.2F, 0.3F}); + LtrRanker ranker = new LinearRanker(new float[] { 0.1F, 0.2F, 0.3F }); Map ftrNorms = new HashMap<>(); - ftrNorms.put(2, new StandardFeatureNormalizer(1.0f,0.5f)); + ftrNorms.put(2, new StandardFeatureNormalizer(1.0f, 0.5f)); ranker = new FeatureNormalizingRanker(ranker, ftrNorms); CompiledLtrModel model = new CompiledLtrModel("model1", set, ranker); @@ -145,14 +154,18 @@ public void testMissingParams() { builder.ltrStats(ltrStats); builder.modelName("model1"); - assertThat(expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), - equalTo("Missing required param(s): [query_string]")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), + equalTo("Missing required param(s): [query_string]") + ); Map params = new HashMap<>(); params.put("query_string2", "a wonderful query"); builder.params(params); - assertThat(expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), - equalTo("Missing required param(s): [query_string]")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), + equalTo("Missing required param(s): [query_string]") + ); } @@ -161,8 +174,10 @@ public void testInvalidActiveFeatures() { builder.modelName("model1"); builder.activeFeatures(Collections.singletonList("non_existent_feature")); builder.ltrStats(ltrStats); - assertThat(expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), - equalTo("Feature: [non_existent_feature] provided in active_features does not exist")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> builder.toQuery(createShardContext())).getMessage(), + equalTo("Feature: [non_existent_feature] provided in active_features does not exist") + ); } public void testSerDe() throws IOException { @@ -175,7 +190,10 @@ public void testSerDe() throws IOException { BytesRef ref = out.bytes().toBytesRef(); StreamInput input = ByteBufferStreamInput.wrap(ref.bytes, ref.offset, ref.length); StoredLtrQueryBuilder builderFromInputStream = new StoredLtrQueryBuilder( - LtrTestUtils.wrapMemStore(StoredLtrQueryBuilderTests.store), input, ltrStats); + LtrTestUtils.wrapMemStore(StoredLtrQueryBuilderTests.store), + input, + ltrStats + ); List expected = Collections.singletonList("match1"); assertEquals(expected, builderFromInputStream.activeFeatures()); } @@ -205,8 +223,7 @@ private void assertQueryClass(Class clazz, boolean setActiveFeature) throws I } @Override - protected void doAssertLuceneQuery(StoredLtrQueryBuilder queryBuilder, - Query query, QueryShardContext context) throws IOException { + protected void doAssertLuceneQuery(StoredLtrQueryBuilder queryBuilder, Query query, QueryShardContext context) throws IOException { assertThat(query, instanceOf(RankerQuery.class)); RankerQuery rquery = (RankerQuery) query; Iterator ite = rquery.stream().iterator(); @@ -238,10 +255,9 @@ protected void doAssertLuceneQuery(StoredLtrQueryBuilder queryBuilder, assertTrue(ite.hasNext()); featureQuery = ite.next(); - builder = new FunctionScoreQueryBuilder(new FieldValueFactorFunctionBuilder("scorefield2") - .factor(1.2F) - .modifier(FieldValueFactorFunction.Modifier.LN2P) - .missing(0F)); + builder = new FunctionScoreQueryBuilder( + new FieldValueFactorFunctionBuilder("scorefield2").factor(1.2F).modifier(FieldValueFactorFunction.Modifier.LN2P).missing(0F) + ); qcontext = createShardContext(); expected = Rewriteable.rewrite(builder, qcontext).toQuery(qcontext); assertEquals(expected, featureQuery); diff --git a/src/test/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilderTests.java b/src/test/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilderTests.java index ab1a5fb5..697e5936 100644 --- a/src/test/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilderTests.java +++ b/src/test/java/com/o19s/es/ltr/query/ValidatingLtrQueryBuilderTests.java @@ -16,10 +16,36 @@ package com.o19s.es.ltr.query; +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableMap; +import static java.util.stream.Collectors.joining; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; +import org.junit.runner.RunWith; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.ltr.stats.LTRStat; import org.opensearch.ltr.stats.LTRStats; import org.opensearch.ltr.stats.StatName; import org.opensearch.ltr.stats.suppliers.CounterSupplier; +import org.opensearch.plugins.Plugin; +import org.opensearch.test.AbstractQueryTestCase; +import org.opensearch.test.TestGeoShapeFieldMapperPlugin; + import com.carrotsearch.randomizedtesting.RandomizedRunner; import com.o19s.es.ltr.LtrQueryParserPlugin; import com.o19s.es.ltr.feature.FeatureValidation; @@ -30,44 +56,19 @@ import com.o19s.es.ltr.feature.store.StoredLtrModel; import com.o19s.es.ltr.ranker.parser.LinearRankerParser; import com.o19s.es.ltr.ranker.parser.LtrRankerParserFactory; -import org.apache.lucene.search.MatchNoDocsQuery; -import org.apache.lucene.search.Query; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.plugins.Plugin; -import org.opensearch.test.AbstractQueryTestCase; -import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import org.junit.runner.RunWith; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; -import java.util.function.BiFunction; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import static java.util.Arrays.asList; -import static java.util.Collections.unmodifiableMap; -import static java.util.stream.Collectors.joining; -import static org.hamcrest.CoreMatchers.instanceOf; @RunWith(RandomizedRunner.class) public class ValidatingLtrQueryBuilderTests extends AbstractQueryTestCase { private final LtrRankerParserFactory factory = new LtrRankerParserFactory.Builder() - .register(LinearRankerParser.TYPE, LinearRankerParser::new) - .build(); + .register(LinearRankerParser.TYPE, LinearRankerParser::new) + .build(); - private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() {{ - put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), - new LTRStat<>(false, new CounterSupplier())); - }})); + private LTRStats ltrStats = new LTRStats(unmodifiableMap(new HashMap>() { + { + put(StatName.LTR_REQUEST_TOTAL_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + put(StatName.LTR_REQUEST_ERROR_COUNT.getName(), new LTRStat<>(false, new CounterSupplier())); + } + })); // TODO: Remove the TestGeoShapeFieldMapperPlugin once upstream has completed the migration. protected Collection> getPlugins() { @@ -76,8 +77,7 @@ protected Collection> getPlugins() { @Override protected Set getObjectsHoldingArbitraryContent() { - return new HashSet<>(asList(FeatureValidation.PARAMS.getPreferredName(), - StoredFeature.TEMPLATE.getPreferredName())); + return new HashSet<>(asList(FeatureValidation.PARAMS.getPreferredName(), StoredFeature.TEMPLATE.getPreferredName())); } /** @@ -86,20 +86,24 @@ protected Set getObjectsHoldingArbitraryContent() { @Override protected ValidatingLtrQueryBuilder doCreateTestQueryBuilder() { StorableElement element; - Function buildFeature = (n) -> new StoredFeature(n, - Collections.singletonList("query_string"), "mustache", - QueryBuilders.matchQuery("test", "{{query_string}}").toString()); - BiFunction buildFeatureSet = (i, name) -> new StoredFeatureSet(name, IntStream.range(0, i) - .mapToObj((idx) -> buildFeature.apply("feature" + idx)) - .collect(Collectors.toList())); - Function buildModel = (name) -> new StoredLtrModel(name, - buildFeatureSet.apply(5, "the_feature_set"), - "model/linear", - IntStream.range(0, 5) - .mapToObj((i) -> "\"feature" + i + "\": " + random().nextFloat()) - .collect(joining(",", "{", "}")), - true, - new StoredFeatureNormalizers()); + Function buildFeature = (n) -> new StoredFeature( + n, + Collections.singletonList("query_string"), + "mustache", + QueryBuilders.matchQuery("test", "{{query_string}}").toString() + ); + BiFunction buildFeatureSet = (i, name) -> new StoredFeatureSet( + name, + IntStream.range(0, i).mapToObj((idx) -> buildFeature.apply("feature" + idx)).collect(Collectors.toList()) + ); + Function buildModel = (name) -> new StoredLtrModel( + name, + buildFeatureSet.apply(5, "the_feature_set"), + "model/linear", + IntStream.range(0, 5).mapToObj((i) -> "\"feature" + i + "\": " + random().nextFloat()).collect(joining(",", "{", "}")), + true, + new StoredFeatureNormalizers() + ); int type = randomInt(2); switch (type) { diff --git a/src/test/java/com/o19s/es/ltr/ranker/DenseLtrRankerTests.java b/src/test/java/com/o19s/es/ltr/ranker/DenseLtrRankerTests.java index a3ba1d3e..2ad03a14 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/DenseLtrRankerTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/DenseLtrRankerTests.java @@ -24,18 +24,18 @@ public void newFeatureVector() throws Exception { DummyDenseRanker ranker = new DummyDenseRanker(modelSize); DenseFeatureVector vector = ranker.newFeatureVector(null); assertNotNull(vector); - for(int i = 0; i < modelSize; i++) { + for (int i = 0; i < modelSize; i++) { assertEquals(0, vector.getFeatureScore(0), Math.ulp(0)); } float[] points = vector.scores; assertEquals(points.length, 2); - for(int i = 0; i < modelSize; i++) { + for (int i = 0; i < modelSize; i++) { vector.setFeatureScore(0, random().nextFloat()); } LtrRanker.FeatureVector vector2 = ranker.newFeatureVector(vector); assertSame(vector, vector2); - for(int i = 0; i < modelSize; i++) { + for (int i = 0; i < modelSize; i++) { assertEquals(0, vector.getFeatureScore(0), Math.ulp(0)); } } @@ -62,4 +62,4 @@ public String name() { return "dummy"; } } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/LogLtrRankerTests.java b/src/test/java/com/o19s/es/ltr/ranker/LogLtrRankerTests.java index 9e003f32..17138e89 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/LogLtrRankerTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/LogLtrRankerTests.java @@ -16,10 +16,11 @@ package com.o19s.es.ltr.ranker; -import com.o19s.es.ltr.ranker.linear.LinearRankerTests; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; +import com.o19s.es.ltr.ranker.linear.LinearRankerTests; + public class LogLtrRankerTests extends LuceneTestCase { public void testNewFeatureVector() throws Exception { int modelSize = TestUtil.nextInt(random(), 1, 20); @@ -51,4 +52,4 @@ public void score() throws Exception { } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 9606a6e2..af8a346c 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -16,19 +16,13 @@ package com.o19s.es.ltr.ranker.dectree; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.PrebuiltFeature; -import com.o19s.es.ltr.feature.PrebuiltFeatureSet; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.linear.LinearRankerTests; -import com.o19s.es.ltr.ranker.normalizer.Normalizer; -import com.o19s.es.ltr.ranker.normalizer.Normalizers; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.search.MatchAllDocsQuery; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.tests.util.TestUtil; -import org.apache.logging.log4j.LogManager; +import static org.apache.lucene.tests.util.TestUtil.nextInt; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; +import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.core.AllOf.allOf; import java.io.BufferedReader; import java.io.IOException; @@ -44,19 +38,31 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_REF; -import static org.apache.lucene.tests.util.TestUtil.nextInt; -import static org.hamcrest.Matchers.greaterThan; -import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.core.AllOf.allOf; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.PrebuiltFeature; +import com.o19s.es.ltr.feature.PrebuiltFeatureSet; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.linear.LinearRankerTests; +import com.o19s.es.ltr.ranker.normalizer.Normalizer; +import com.o19s.es.ltr.ranker.normalizer.Normalizers; public class NaiveAdditiveDecisionTreeTests extends LuceneTestCase { static final Logger LOG = LogManager.getLogger(NaiveAdditiveDecisionTreeTests.class); + public void testName() { - NaiveAdditiveDecisionTree dectree = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0], - new float[0], 0, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree dectree = new NaiveAdditiveDecisionTree( + new NaiveAdditiveDecisionTree.Node[0], + new float[0], + 0, + Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME) + ); assertEquals("naive_additive_decision_tree", dectree.name()); } @@ -67,7 +73,7 @@ public void testScore() throws IOException { vector.setFeatureScore(1, 2); vector.setFeatureScore(2, 3); - float expected = 1.2F*3.4F + 3.2F*2.8F; + float expected = 1.2F * 3.4F + 3.2F * 2.8F; assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } @@ -78,16 +84,14 @@ public void testSigmoidScore() throws IOException { vector.setFeatureScore(1, 2); vector.setFeatureScore(2, 3); - float expected = 1.2F*3.4F + 3.2F*2.8F; + float expected = 1.2F * 3.4F + 3.2F * 2.8F; expected = (float) (1 / (1 + Math.exp(-expected))); assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } public void testPerfAndRobustness() { SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); - NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, - 100, 1000, - 5, 50, counts); + NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, 100, 1000, 5, 50, counts); DenseFeatureVector vector = ranker.newFeatureVector(null); int nPass = TestUtil.nextInt(random(), 10, 8916); @@ -101,36 +105,51 @@ public void testPerfAndRobustness() { ranker.score(vector); } time += System.currentTimeMillis(); - LOG.info("Scored {} docs with {} trees/{} features within {}ms ({} ms/doc), " + - "{} nodes ({} splits & {} leaves) ", - nPass, counts.trees.get(), ranker.size(), time, (float) time / (float) nPass, - counts.nodes.get(), counts.splits.get(), counts.leaves.get()); + LOG + .info( + "Scored {} docs with {} trees/{} features within {}ms ({} ms/doc), " + "{} nodes ({} splits & {} leaves) ", + nPass, + counts.trees.get(), + ranker.size(), + time, + (float) time / (float) nPass, + counts.nodes.get(), + counts.splits.get(), + counts.leaves.get() + ); } public void testRamSize() { SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); - NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, - 100, 1000, - 5, 50, counts); + NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, 100, 1000, 5, 50, counts); long actualSize = ranker.ramBytesUsed(); long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2); expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; - assertThat(actualSize, allOf( - greaterThan((long) (expectedApprox*0.66F)), - lessThan((long) (expectedApprox*1.33F)))); + assertThat(actualSize, allOf(greaterThan((long) (expectedApprox * 0.66F)), lessThan((long) (expectedApprox * 1.33F)))); } - public static NaiveAdditiveDecisionTree generateRandomDecTree(int minFeatures, int maxFeatures, int minTrees, - int maxTrees, int minDepth, int maxDepth, - RandomTreeGeneratorStatsCollector collector) { + public static NaiveAdditiveDecisionTree generateRandomDecTree( + int minFeatures, + int maxFeatures, + int minTrees, + int maxTrees, + int minDepth, + int maxDepth, + RandomTreeGeneratorStatsCollector collector + ) { int nFeat = nextInt(random(), minFeatures, maxFeatures); int nbTrees = nextInt(random(), minTrees, maxTrees); return generateRandomDecTree(nFeat, nbTrees, minDepth, maxDepth, collector); } - public static NaiveAdditiveDecisionTree generateRandomDecTree(int nbFeatures, int nbTree, int minDepth, - int maxDepth, RandomTreeGeneratorStatsCollector collector) { + public static NaiveAdditiveDecisionTree generateRandomDecTree( + int nbFeatures, + int nbTree, + int minDepth, + int maxDepth, + RandomTreeGeneratorStatsCollector collector + ) { NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[nbTree]; float[] weights = LinearRankerTests.generateRandomWeights(nbTree); for (int i = 0; i < nbTree; i++) { @@ -143,8 +162,12 @@ public static NaiveAdditiveDecisionTree generateRandomDecTree(int nbFeatures, in } public void testSize() { - NaiveAdditiveDecisionTree ranker = new NaiveAdditiveDecisionTree(new NaiveAdditiveDecisionTree.Node[0], - new float[0], 3, Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + NaiveAdditiveDecisionTree ranker = new NaiveAdditiveDecisionTree( + new NaiveAdditiveDecisionTree.Node[0], + new float[0], + 3, + Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME) + ); assertEquals(ranker.size(), 3); } @@ -159,7 +182,7 @@ private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer norm List treesAndWeight = parser.parseTrees(); float[] weights = new float[treesAndWeight.size()]; NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[treesAndWeight.size()]; - for(int i = 0; i < treesAndWeight.size(); i++) { + for (int i = 0; i < treesAndWeight.size(); i++) { weights[i] = treesAndWeight.get(i).weight; trees[i] = treesAndWeight.get(i).tree; } @@ -169,9 +192,10 @@ private NaiveAdditiveDecisionTree parseTreeModel(String textRes, Normalizer norm private static class TreeTextParser { FeatureSet set; Iterator lines; + private TreeTextParser(InputStream is, FeatureSet set) throws IOException { List lines = new ArrayList<>(); - try(BufferedReader br = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")))) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")))) { String line; while ((line = br.readLine()) != null) { lines.add(line); @@ -180,12 +204,13 @@ private TreeTextParser(InputStream is, FeatureSet set) throws IOException { this.set = set; this.lines = lines.iterator(); } + private List parseTrees() { List trees = new ArrayList<>(); - while(lines.hasNext()) { + while (lines.hasNext()) { String line = lines.next(); - if(line.startsWith("- tree")) { + if (line.startsWith("- tree")) { TreeAndWeight tree = new TreeAndWeight(); tree.weight = extractLastFloat(line); tree.tree = parseTree(); @@ -202,11 +227,11 @@ NaiveAdditiveDecisionTree.Node parseTree() { throw new IllegalArgumentException("Invalid tree"); } line = lines.next(); - } while(line.startsWith("#")); + } while (line.startsWith("#")); if (line.contains("- output")) { return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line)); - } else if(line.contains("- split")) { + } else if (line.contains("- split")) { String featName = line.split(":")[1]; int ord = set.featureOrdinal(featName); if (ord < 0 || ord > set.size()) { @@ -216,8 +241,7 @@ NaiveAdditiveDecisionTree.Node parseTree() { NaiveAdditiveDecisionTree.Node right = parseTree(); NaiveAdditiveDecisionTree.Node left = parseTree(); - return new NaiveAdditiveDecisionTree.Split(left, right, - ord, threshold); + return new NaiveAdditiveDecisionTree.Split(left, right, ord, threshold); } else { throw new IllegalArgumentException("Invalid tree"); } @@ -226,7 +250,7 @@ NaiveAdditiveDecisionTree.Node parseTree() { float extractLastFloat(String line) { Pattern p = Pattern.compile(".*:([0-9.]+)$"); Matcher m = p.matcher(line); - if(m.find()) { + if (m.find()) { return Float.parseFloat(m.group(1)); } throw new IllegalArgumentException("Cannot extract float from " + line); @@ -251,13 +275,17 @@ public RandomTreeGenerator(int maxFeat, int minDepth, int maxDepth, RandomTreeGe this.minDepth = minDepth; this.maxDepth = maxDepth; this.statsCollector = collector != null ? collector : RandomTreeGeneratorStatsCollector.NULL; - featureGen = () -> nextInt(random(), 0, maxFeat-1); - outputGenerator = () -> - (random().nextBoolean() ? 1F : -1F) * - ((float)nextInt(random(), 0, 1000) / (float)nextInt(random(), 1, 1000)); - thresholdGenerator = (feat) -> - (random().nextBoolean() ? 1F : -1F) * - ((float)nextInt(random(), 0, 1000) / (float)nextInt(random(), 1, 1000)); + featureGen = () -> nextInt(random(), 0, maxFeat - 1); + outputGenerator = () -> (random().nextBoolean() ? 1F : -1F) * ((float) nextInt(random(), 0, 1000) / (float) nextInt( + random(), + 1, + 1000 + )); + thresholdGenerator = (feat) -> (random().nextBoolean() ? 1F : -1F) * ((float) nextInt(random(), 0, 1000) / (float) nextInt( + random(), + 1, + 1000 + )); leafDecider = () -> random().nextBoolean(); } @@ -268,7 +296,7 @@ NaiveAdditiveDecisionTree.Node genTree() { NaiveAdditiveDecisionTree.Node newNode(int depth) { statsCollector.newNode(); - if (depth>=maxDepth) { + if (depth >= maxDepth) { return newLeaf(depth); } else if (depth <= minDepth) { return newSplit(depth); @@ -293,10 +321,15 @@ private NaiveAdditiveDecisionTree.Node newLeaf(int depth) { } public interface RandomTreeGeneratorStatsCollector { - RandomTreeGeneratorStatsCollector NULL = new RandomTreeGeneratorStatsCollector() {}; + RandomTreeGeneratorStatsCollector NULL = new RandomTreeGeneratorStatsCollector() { + }; + default void newSplit(int depth, int feature, float thresh) {} + default void newLeaf(int depth, float output) {} + default void newNode() {} + default void newTree() {} } @@ -326,4 +359,4 @@ public void newTree() { trees.incrementAndGet(); } } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/linear/LinearRankerTests.java b/src/test/java/com/o19s/es/ltr/ranker/linear/LinearRankerTests.java index ce0d1c17..409c0e01 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/linear/LinearRankerTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/linear/LinearRankerTests.java @@ -16,50 +16,49 @@ package com.o19s.es.ltr.ranker.linear; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker; -import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTreeTests; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.tests.util.TestUtil; - +import static org.apache.lucene.tests.util.TestUtil.nextInt; import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; import static org.apache.lucene.util.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; -import static org.apache.lucene.tests.util.TestUtil.nextInt; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.core.AllOf.allOf; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.tests.util.TestUtil; + +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTreeTests; + public class LinearRankerTests extends LuceneTestCase { private static final Logger LOG = LogManager.getLogger(NaiveAdditiveDecisionTreeTests.class); public void testName() throws Exception { - LinearRanker ranker = new LinearRanker(new float[]{1,2}); + LinearRanker ranker = new LinearRanker(new float[] { 1, 2 }); assertEquals("linear", ranker.name()); } public void testScore() { - LinearRanker ranker = new LinearRanker(new float[]{1,2,3}); + LinearRanker ranker = new LinearRanker(new float[] { 1, 2, 3 }); LtrRanker.FeatureVector point = ranker.newFeatureVector(null); point.setFeatureScore(0, 2); point.setFeatureScore(1, 3); point.setFeatureScore(2, 4); - float expected = 1F*2F + 2F*3F + 3F*4F; + float expected = 1F * 2F + 2F * 3F + 3F * 4F; assertEquals(expected, ranker.score(point), Math.ulp(expected)); } public void testSize() { - LinearRanker ranker = new LinearRanker(new float[]{1,2,3}); + LinearRanker ranker = new LinearRanker(new float[] { 1, 2, 3 }); assertEquals(ranker.size(), 3); } public void testRamSize() { LinearRanker ranker = generateRandomRanker(1, 1000); - int expectedSize = ranker.size()*Float.BYTES + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER; - assertThat(ranker.ramBytesUsed(), - allOf(greaterThan((long) (expectedSize*0.66F)), - lessThan((long) (expectedSize*1.33F)))); + int expectedSize = ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER; + assertThat(ranker.ramBytesUsed(), allOf(greaterThan((long) (expectedSize * 0.66F)), lessThan((long) (expectedSize * 1.33F)))); } public void testPerfAndRobustness() { @@ -77,13 +76,13 @@ public void testPerfAndRobustness() { ranker.score(vector); } time += System.currentTimeMillis(); - LOG.info("Scored {} docs with {} features within {}ms ({} ms/doc)", - nPass, ranker.size(), time, (float) time / (float) nPass); + LOG.info("Scored {} docs with {} features within {}ms ({} ms/doc)", nPass, ranker.size(), time, (float) time / (float) nPass); } public static LinearRanker generateRandomRanker(int minsize, int maxsize) { return generateRandomRanker(nextInt(random(), minsize, maxsize)); } + public static LinearRanker generateRandomRanker(int size) { return new LinearRanker(generateRandomWeights(size)); } @@ -96,7 +95,7 @@ public static float[] generateRandomWeights(int s) { public static void fillRandomWeights(float[] weights) { for (int i = 0; i < weights.length; i++) { - weights[i] = (float) nextInt(random(),1, 100000) / (float) nextInt(random(), 1, 100000); + weights[i] = (float) nextInt(random(), 1, 100000) / (float) nextInt(random(), 1, 100000); } } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/normalizer/NormalizersTests.java b/src/test/java/com/o19s/es/ltr/ranker/normalizer/NormalizersTests.java index b360ec9b..2e0c0040 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/normalizer/NormalizersTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/normalizer/NormalizersTests.java @@ -27,8 +27,10 @@ public void testGet() { } public void testInvalidName() { - assertThat(expectThrows(IllegalArgumentException.class, () -> Normalizers.get("not_normalizer")).getMessage(), - CoreMatchers.containsString("is not a valid Normalizer")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> Normalizers.get("not_normalizer")).getMessage(), + CoreMatchers.containsString("is not a valid Normalizer") + ); } public void testExists() { diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/LinearRankerParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/LinearRankerParserTests.java index 254ab8bc..d2eeeadb 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/LinearRankerParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/LinearRankerParserTests.java @@ -16,22 +16,22 @@ package com.o19s.es.ltr.ranker.parser; +import static java.util.Collections.singletonList; + +import java.io.IOException; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Assert; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.XContentBuilder; + import com.o19s.es.ltr.LtrTestUtils; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.store.StoredFeatureSet; import com.o19s.es.ltr.ranker.DenseFeatureVector; import com.o19s.es.ltr.ranker.linear.LinearRanker; import com.o19s.es.ltr.ranker.linear.LinearRankerTests; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.Strings; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.json.JsonXContent; -import org.junit.Assert; - -import java.io.IOException; - -import static java.util.Collections.singletonList; public class LinearRankerParserTests extends LuceneTestCase { public void testParse() throws IOException { @@ -95,4 +95,4 @@ public static String generateRandomModelString(FeatureSet set) throws IOExceptio builder.endObject().close(); return builder.toString(); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/LtrRankerParserFactoryTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/LtrRankerParserFactoryTests.java index 0d329746..91595763 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/LtrRankerParserFactoryTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/LtrRankerParserFactoryTests.java @@ -16,28 +16,25 @@ package com.o19s.es.ltr.ranker.parser; -import org.apache.lucene.tests.util.LuceneTestCase; - import static org.hamcrest.CoreMatchers.containsString; +import org.apache.lucene.tests.util.LuceneTestCase; + public class LtrRankerParserFactoryTests extends LuceneTestCase { public void testGetParser() { LtrRankerParser parser = (set, model) -> null; - LtrRankerParserFactory factory = new LtrRankerParserFactory.Builder() - .register("model/test", () -> parser) - .build(); + LtrRankerParserFactory factory = new LtrRankerParserFactory.Builder().register("model/test", () -> parser).build(); assertSame(parser, factory.getParser("model/test")); - assertThat(expectThrows(IllegalArgumentException.class, - () -> factory.getParser("model/foobar")).getMessage(), - containsString("Unsupported LtrRanker format/type [model/foobar]")); + assertThat( + expectThrows(IllegalArgumentException.class, () -> factory.getParser("model/foobar")).getMessage(), + containsString("Unsupported LtrRanker format/type [model/foobar]") + ); } public void testDeclareMultiple() { LtrRankerParser parser = (set, model) -> null; - LtrRankerParserFactory.Builder builder = new LtrRankerParserFactory.Builder() - .register("model/test", () -> parser); - expectThrows(RuntimeException.class, - () -> builder.register("model/test", () -> parser)); + LtrRankerParserFactory.Builder builder = new LtrRankerParserFactory.Builder().register("model/test", () -> parser); + expectThrows(RuntimeException.class, () -> builder.register("model/test", () -> parser)); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java index 02d8ad42..7260238c 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java @@ -16,18 +16,9 @@ package com.o19s.es.ltr.ranker.parser; -import com.o19s.es.ltr.LtrTestUtils; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.StoredFeature; -import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; -import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; -import com.o19s.es.ltr.ranker.linear.LinearRankerTests; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.core.common.ParsingException; -import org.opensearch.common.io.Streams; -import org.hamcrest.CoreMatchers; +import static com.o19s.es.ltr.LtrTestUtils.randomFeature; +import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; +import static java.util.Collections.singletonList; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -37,12 +28,22 @@ import java.util.Arrays; import java.util.List; -import static com.o19s.es.ltr.LtrTestUtils.randomFeature; -import static com.o19s.es.ltr.LtrTestUtils.randomFeatureSet; -import static java.util.Collections.singletonList; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.hamcrest.CoreMatchers; +import org.opensearch.core.common.ParsingException; + +import com.o19s.es.ltr.LtrTestUtils; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import com.o19s.es.ltr.ranker.linear.LinearRankerTests; public class XGBoostJsonParserTests extends LuceneTestCase { private final XGBoostJsonParser parser = new XGBoostJsonParser(); + public void testReadLeaf() throws IOException { String model = "[ {\"nodeid\": 0, \"leaf\": 0.234}]"; FeatureSet set = randomFeatureSet(); @@ -51,18 +52,18 @@ public void testReadLeaf() throws IOException { } public void testReadSimpleSplit() throws IOException { - String model = "[{" + - "\"nodeid\": 0," + - "\"split\":\"feat1\"," + - "\"depth\":0," + - "\"split_condition\":0.123," + - "\"yes\":1," + - "\"no\": 2," + - "\"missing\":2,"+ - "\"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]"; + String model = "[{" + + "\"nodeid\": 0," + + "\"split\":\"feat1\"," + + "\"depth\":0," + + "\"split_condition\":0.123," + + "\"yes\":1," + + "\"no\": 2," + + "\"missing\":2," + + "\"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); NaiveAdditiveDecisionTree tree = parser.parse(set, model); @@ -76,19 +77,19 @@ public void testReadSimpleSplit() throws IOException { } public void testReadSimpleSplitInObject() throws IOException { - String model = "{" + - "\"splits\": [{" + - " \"nodeid\": 0," + - " \"split\":\"feat1\"," + - " \"depth\":0," + - " \"split_condition\":0.123," + - " \"yes\":1," + - " \"no\": 2," + - " \"missing\":2,"+ - " \"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]}"; + String model = "{" + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2," + + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); NaiveAdditiveDecisionTree tree = parser.parse(set, model); @@ -102,20 +103,20 @@ public void testReadSimpleSplitInObject() throws IOException { } public void testReadSimpleSplitWithObjective() throws IOException { - String model = "{" + - "\"objective\": \"reg:linear\"," + - "\"splits\": [{" + - " \"nodeid\": 0," + - " \"split\":\"feat1\"," + - " \"depth\":0," + - " \"split_condition\":0.123," + - " \"yes\":1," + - " \"no\": 2," + - " \"missing\":2,"+ - " \"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]}"; + String model = "{" + + "\"objective\": \"reg:linear\"," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2," + + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); NaiveAdditiveDecisionTree tree = parser.parse(set, model); @@ -129,62 +130,66 @@ public void testReadSimpleSplitWithObjective() throws IOException { } public void testReadSplitWithUnknownParams() throws IOException { - String model = "{" + - "\"not_param\": \"value\"," + - "\"splits\": [{" + - " \"nodeid\": 0," + - " \"split\":\"feat1\"," + - " \"depth\":0," + - " \"split_condition\":0.123," + - " \"yes\":1," + - " \"no\": 2," + - " \"missing\":2,"+ - " \"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]}"; + String model = "{" + + "\"not_param\": \"value\"," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2," + + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), - CoreMatchers.containsString("Unable to parse XGBoost object")); + assertThat( + expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Unable to parse XGBoost object") + ); } public void testBadObjectiveParam() throws IOException { - String model = "{" + - "\"objective\": \"reg:invalid\"," + - "\"splits\": [{" + - " \"nodeid\": 0," + - " \"split\":\"feat1\"," + - " \"depth\":0," + - " \"split_condition\":0.123," + - " \"yes\":1," + - " \"no\": 2," + - " \"missing\":2,"+ - " \"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]}"; + String model = "{" + + "\"objective\": \"reg:invalid\"," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2," + + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]}"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), - CoreMatchers.containsString("Unable to parse XGBoost object")); + assertThat( + expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Unable to parse XGBoost object") + ); } public void testReadWithLogisticObjective() throws IOException { - String model = "{" + - "\"objective\": \"reg:logistic\"," + - "\"splits\": [{" + - " \"nodeid\": 0," + - " \"split\":\"feat1\"," + - " \"depth\":0," + - " \"split_condition\":0.123," + - " \"yes\":1," + - " \"no\": 2," + - " \"missing\":2,"+ - " \"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": -0.2}" + - "]}]}"; + String model = "{" + + "\"objective\": \"reg:logistic\"," + + "\"splits\": [{" + + " \"nodeid\": 0," + + " \"split\":\"feat1\"," + + " \"depth\":0," + + " \"split_condition\":0.123," + + " \"yes\":1," + + " \"no\": 2," + + " \"missing\":2," + + " \"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": -0.2}" + + "]}]}"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); NaiveAdditiveDecisionTree tree = parser.parse(set, model); @@ -198,68 +203,77 @@ public void testReadWithLogisticObjective() throws IOException { } public void testMissingField() throws IOException { - String model = "[{" + - "\"nodeid\": 0," + - "\"split\":\"feat1\"," + - "\"depth\":0," + - "\"split_condition\":0.123," + - "\"no\": 2," + - "\"missing\":2,"+ - "\"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]"; + String model = "[{" + + "\"nodeid\": 0," + + "\"split\":\"feat1\"," + + "\"depth\":0," + + "\"split_condition\":0.123," + + "\"no\": 2," + + "\"missing\":2," + + "\"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), - CoreMatchers.containsString("This split does not have all the required fields")); + assertThat( + expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("This split does not have all the required fields") + ); } public void testBadStruct() throws IOException { - String model = "[{" + - "\"nodeid\": 0," + - "\"split\":\"feat1\"," + - "\"depth\":0," + - "\"split_condition\":0.123," + - "\"yes\":1," + - "\"no\": 3," + - "\"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]"; + String model = "[{" + + "\"nodeid\": 0," + + "\"split\":\"feat1\"," + + "\"depth\":0," + + "\"split_condition\":0.123," + + "\"yes\":1," + + "\"no\": 3," + + "\"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), - CoreMatchers.containsString("Split structure is invalid, yes, no and/or")); + assertThat( + expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Split structure is invalid, yes, no and/or") + ); } public void testMissingFeat() throws IOException { - String model = "[{" + - "\"nodeid\": 0," + - "\"split\":\"feat2\"," + - "\"depth\":0," + - "\"split_condition\":0.123," + - "\"yes\":1," + - "\"no\": 2," + - "\"missing\":2,"+ - "\"children\": [" + - " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + - " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + - "]}]"; + String model = "[{" + + "\"nodeid\": 0," + + "\"split\":\"feat2\"," + + "\"depth\":0," + + "\"split_condition\":0.123," + + "\"yes\":1," + + "\"no\": 2," + + "\"missing\":2," + + "\"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]"; FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1"))); - assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), - CoreMatchers.containsString("Unknown feature [feat2]")); + assertThat( + expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(), + CoreMatchers.containsString("Unknown feature [feat2]") + ); } public void testComplexModel() throws Exception { String model = readModel("/models/xgboost-wmf.json"); List features = new ArrayList<>(); - List names = Arrays.asList("all_near_match", + List names = Arrays + .asList( + "all_near_match", "category", "heading", "incoming_links", "popularity_score", "redirect_or_suggest_dismax", "text_or_opening_text_dismax", - "title"); + "title" + ); for (String n : names) { features.add(LtrTestUtils.randomFeature(n)); } diff --git a/src/test/java/com/o19s/es/ltr/rest/FeaturesParserTests.java b/src/test/java/com/o19s/es/ltr/rest/FeaturesParserTests.java index c4cd3a91..a32000e6 100644 --- a/src/test/java/com/o19s/es/ltr/rest/FeaturesParserTests.java +++ b/src/test/java/com/o19s/es/ltr/rest/FeaturesParserTests.java @@ -16,29 +16,27 @@ package com.o19s.es.ltr.rest; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; +import static com.o19s.es.ltr.feature.store.StoredFeatureParserTests.generateTestFeature; +import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; import java.io.IOException; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static com.o19s.es.ltr.feature.store.StoredFeatureParserTests.generateTestFeature; -import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; public class FeaturesParserTests extends LuceneTestCase { public void testParseArray() throws IOException { RestAddFeatureToSet.FeaturesParserState fparser = new RestAddFeatureToSet.FeaturesParserState(); - int nFeat = random().nextInt(18)+1; - String featuresArray = IntStream.range(0, nFeat) - .mapToObj((i) -> generateTestFeature("feat" + i)) - .collect(Collectors.joining(",")); - XContentParser parser = jsonXContent.createParser(NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, "{\"features\":[" + featuresArray + "]}"); + int nFeat = random().nextInt(18) + 1; + String featuresArray = IntStream.range(0, nFeat).mapToObj((i) -> generateTestFeature("feat" + i)).collect(Collectors.joining(",")); + XContentParser parser = jsonXContent + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, "{\"features\":[" + featuresArray + "]}"); fparser.parse(parser); assertEquals(nFeat, fparser.getFeatures().size()); assertEquals("feat0", fparser.getFeatures().get(0).name()); } -} \ No newline at end of file +} diff --git a/src/test/java/com/o19s/es/termstat/TermStatQueryBuilderTests.java b/src/test/java/com/o19s/es/termstat/TermStatQueryBuilderTests.java index 4b68af9f..d931385e 100644 --- a/src/test/java/com/o19s/es/termstat/TermStatQueryBuilderTests.java +++ b/src/test/java/com/o19s/es/termstat/TermStatQueryBuilderTests.java @@ -15,9 +15,12 @@ */ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.instanceOf; + +import java.io.IOException; +import java.util.Collection; -import com.o19s.es.ltr.LtrQueryParserPlugin; import org.apache.lucene.search.Query; import org.opensearch.core.common.ParsingException; import org.opensearch.index.query.QueryShardContext; @@ -25,11 +28,8 @@ import org.opensearch.test.AbstractQueryTestCase; import org.opensearch.test.TestGeoShapeFieldMapperPlugin; -import java.io.IOException; -import java.util.Collection; - -import static java.util.Arrays.asList; -import static org.hamcrest.CoreMatchers.instanceOf; +import com.o19s.es.explore.StatisticsHelper.AggrType; +import com.o19s.es.ltr.LtrQueryParserPlugin; public class TermStatQueryBuilderTests extends AbstractQueryTestCase { // TODO: Remove the TestGeoShapeFieldMapperPlugin once upstream has completed the migration. @@ -45,22 +45,22 @@ protected TermStatQueryBuilder doCreateTestQueryBuilder() { builder.expr("tf"); builder.aggr(AggrType.AVG.getType()); builder.posAggr(AggrType.AVG.getType()); - builder.fields(new String[]{"text"}); - builder.terms(new String[]{"cow"}); + builder.fields(new String[] { "text" }); + builder.terms(new String[] { "cow" }); return builder; } public void testParse() throws Exception { - String query = " {" + - " \"term_stat\": {" + - " \"expr\": \"tf\"," + - " \"aggr\": \"min\"," + - " \"pos_aggr\": \"max\"," + - " \"fields\": [\"text\"]," + - " \"terms\": [\"cow\"]" + - " }" + - "}"; + String query = " {" + + " \"term_stat\": {" + + " \"expr\": \"tf\"," + + " \"aggr\": \"min\"," + + " \"pos_aggr\": \"max\"," + + " \"fields\": [\"text\"]," + + " \"terms\": [\"cow\"]" + + " }" + + "}"; TermStatQueryBuilder builder = (TermStatQueryBuilder) parseQuery(query); @@ -71,14 +71,14 @@ public void testParse() throws Exception { } public void testMissingExpr() throws Exception { - String query = " {" + - " \"term_stat\": {" + - " \"aggr\": \"min\"," + - " \"pos_aggr\": \"max\"," + - " \"fields\": [\"text\"]," + - " \"terms\": [\"cow\"]" + - " }" + - "}"; + String query = " {" + + " \"term_stat\": {" + + " \"aggr\": \"min\"," + + " \"pos_aggr\": \"max\"," + + " \"fields\": [\"text\"]," + + " \"terms\": [\"cow\"]" + + " }" + + "}"; expectThrows(ParsingException.class, () -> parseQuery(query)); } diff --git a/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java b/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java index dd953ebc..0ed211ac 100644 --- a/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java +++ b/src/test/java/com/o19s/es/termstat/TermStatQueryTests.java @@ -15,9 +15,11 @@ */ package com.o19s.es.termstat; -import com.o19s.es.explore.StatisticsHelper.AggrType; +import static org.hamcrest.Matchers.equalTo; + +import java.util.HashSet; +import java.util.Set; -import com.o19s.es.ltr.utils.Scripting; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.StoredField; @@ -33,16 +35,12 @@ import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.opensearch.common.lucene.Lucene; - import org.junit.After; import org.junit.Before; +import org.opensearch.common.lucene.Lucene; - -import java.util.HashSet; -import java.util.Set; - -import static org.hamcrest.Matchers.equalTo; +import com.o19s.es.explore.StatisticsHelper.AggrType; +import com.o19s.es.ltr.utils.Scripting; public class TermStatQueryTests extends LuceneTestCase { private Directory dir; @@ -51,19 +49,18 @@ public class TermStatQueryTests extends LuceneTestCase { // Some simple documents to index private final String[] docs = new String[] { - "how now brown cow", - "brown is the color of cows", - "brown cow", - "banana cows are yummy", - "dance with monkeys and do not stop to dance", - "break on through to the other side... break on through to the other side... break on through to the other side" - }; + "how now brown cow", + "brown is the color of cows", + "brown cow", + "banana cows are yummy", + "dance with monkeys and do not stop to dance", + "break on through to the other side... break on through to the other side... break on through to the other side" }; @Before public void setupIndex() throws Exception { dir = new ByteBuffersDirectory(); - try(IndexWriter indexWriter = new IndexWriter(dir, new IndexWriterConfig(Lucene.STANDARD_ANALYZER))) { + try (IndexWriter indexWriter = new IndexWriter(dir, new IndexWriterConfig(Lucene.STANDARD_ANALYZER))) { for (int i = 0; i < docs.length; i++) { Document doc = new Document(); doc.add(new Field("_id", Integer.toString(i + 1), StoredField.TYPE)); diff --git a/src/test/java/org/opensearch/ltr/LTRRestTestCase.java b/src/test/java/org/opensearch/ltr/LTRRestTestCase.java index 3adde5e6..7f895734 100644 --- a/src/test/java/org/opensearch/ltr/LTRRestTestCase.java +++ b/src/test/java/org/opensearch/ltr/LTRRestTestCase.java @@ -15,6 +15,11 @@ package org.opensearch.ltr; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.common.util.io.Streams; @@ -24,27 +29,23 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.test.rest.OpenSearchRestTestCase; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; - public class LTRRestTestCase extends OpenSearchRestTestCase { /** * Utility to update settings */ public void updateClusterSettings(String settingKey, Object value) throws Exception { - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("persistent") - .field(settingKey, value) - .endObject() - .endObject(); + XContentBuilder builder = XContentFactory + .jsonBuilder() + .startObject() + .startObject("persistent") + .field(settingKey, value) + .endObject() + .endObject(); Request request = new Request("PUT", "_cluster/settings"); request.setJsonEntity(BytesReference.bytes(builder).utf8ToString()); Response response = client().performRequest(request); - assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } /** @@ -52,16 +53,14 @@ public void updateClusterSettings(String settingKey, Object value) throws Except * @param name suffix of index name */ public void createLTRStore(String name) throws IOException { - String path = "_ltr";; + String path = "_ltr"; + ; if (name != null && !name.isEmpty()) { path = path + "/" + name; } - Request request = new Request( - "PUT", - "/" + path - ); + Request request = new Request("PUT", "/" + path); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -79,16 +78,14 @@ public void createDefaultLTRStore() throws IOException { * @param name suffix of index name */ public void deleteLTRStore(String name) throws IOException { - String path = "_ltr";; + String path = "_ltr"; + ; if (name != null && !name.isEmpty()) { path = path + "/" + name; } - Request request = new Request( - "DELETE", - "/" + path - ); + Request request = new Request("DELETE", "/" + path); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -106,28 +103,27 @@ public void deleteDefaultLTRStore() throws IOException { * @param name feature set */ public void createFeatureSet(String name) throws IOException { - Request request = new Request( - "POST", - "/_ltr/_featureset/" + name - ); + Request request = new Request("POST", "/_ltr/_featureset/" + name); + + XContentBuilder xb = XContentFactory + .jsonBuilder() + .startObject() + .startObject("featureset") + .field("name", name) + .startArray("features"); - XContentBuilder xb = XContentFactory.jsonBuilder() + for (int i = 1; i < 3; ++i) { + xb .startObject() - .startObject("featureset") - .field("name", name) - .startArray("features"); - - for (int i=1; i<3; ++i) { - xb.startObject() - .field("name", String.valueOf(i)) - .array("params", "keywords") - .field("template_language", "mustache") - .startObject("template") - .startObject("match") - .field("field"+i, "{{keywords}}") - .endObject() - .endObject() - .endObject(); + .field("name", String.valueOf(i)) + .array("params", "keywords") + .field("template_language", "mustache") + .startObject("template") + .startObject("match") + .field("field" + i, "{{keywords}}") + .endObject() + .endObject() + .endObject(); } xb.endArray().endObject().endObject(); @@ -148,10 +144,7 @@ public void createDefaultFeatureSet() throws IOException { * @param name feature set */ public void deleteFeatureSet(String name) throws IOException { - Request request = new Request( - "DELETE", - "/_ltr/_featureset/" + name - ); + Request request = new Request("DELETE", "/_ltr/_featureset/" + name); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -169,10 +162,7 @@ public void deleteDefaultFeatureSet() throws IOException { * @param name feature set */ public void getFeatureSet(String name) throws IOException { - Request request = new Request( - "GET", - "/_ltr/_featureset/" + name - ); + Request request = new Request("GET", "/_ltr/_featureset/" + name); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -193,21 +183,19 @@ public void createModel(String name) throws IOException { String defaultJsonModel = readSourceModel("/models/default-xgb-model.json"); - Request request = new Request( - "POST", - "/_ltr/_featureset/default_features/_createmodel" - ); + Request request = new Request("POST", "/_ltr/_featureset/default_features/_createmodel"); - XContentBuilder xb = XContentFactory.jsonBuilder() - .startObject() - .startObject("model") - .field("name", name) - .startObject("model") - .field("type", "model/xgboost+json") - .field("definition", defaultJsonModel) - .endObject() - .endObject() - .endObject(); + XContentBuilder xb = XContentFactory + .jsonBuilder() + .startObject() + .startObject("model") + .field("name", name) + .startObject("model") + .field("type", "model/xgboost+json") + .field("definition", defaultJsonModel) + .endObject() + .endObject() + .endObject(); request.setJsonEntity(BytesReference.bytes(xb).utf8ToString()); Response response = client().performRequest(request); @@ -226,10 +214,7 @@ public void createDefaultModel() throws IOException { * @param name feature set */ public void deleteModel(String name) throws IOException { - Request request = new Request( - "DELETE", - "/_ltr/_model/" + name - ); + Request request = new Request("DELETE", "/_ltr/_model/" + name); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -247,10 +232,7 @@ public void deleteDefaultModel() throws IOException { * @param name feature set */ public void getModel(String name) throws IOException { - Request request = new Request( - "GET", - "/_ltr/_model/" + name - ); + Request request = new Request("GET", "/_ltr/_model/" + name); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -266,7 +248,7 @@ public void getDefaultModel() throws IOException { private String readSource(String path) throws IOException { try (InputStream is = this.getClass().getResourceAsStream(path)) { ByteArrayOutputStream bos = new ByteArrayOutputStream(); - Streams.copy(is, bos); + Streams.copy(is, bos); return bos.toString(StandardCharsets.UTF_8.name()); } } diff --git a/src/test/java/org/opensearch/ltr/breaker/LTRCircuitBreakerServiceTests.java b/src/test/java/org/opensearch/ltr/breaker/LTRCircuitBreakerServiceTests.java index 801ba05a..03650e05 100644 --- a/src/test/java/org/opensearch/ltr/breaker/LTRCircuitBreakerServiceTests.java +++ b/src/test/java/org/opensearch/ltr/breaker/LTRCircuitBreakerServiceTests.java @@ -15,6 +15,13 @@ package org.opensearch.ltr.breaker; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.when; + import org.junit.Before; import org.junit.Test; import org.mockito.InjectMocks; @@ -23,13 +30,6 @@ import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Mockito.when; - public class LTRCircuitBreakerServiceTests { @InjectMocks @@ -116,4 +116,4 @@ public void testIsOpen1() { ltrCircuitBreakerService.registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(jvmService)); assertThat(ltrCircuitBreakerService.isOpen(), equalTo(true)); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/ltr/breaker/MemoryCircuitBreakerTests.java b/src/test/java/org/opensearch/ltr/breaker/MemoryCircuitBreakerTests.java index 448ddb01..559cd488 100644 --- a/src/test/java/org/opensearch/ltr/breaker/MemoryCircuitBreakerTests.java +++ b/src/test/java/org/opensearch/ltr/breaker/MemoryCircuitBreakerTests.java @@ -15,6 +15,10 @@ package org.opensearch.ltr.breaker; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.when; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -22,10 +26,6 @@ import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Mockito.when; - public class MemoryCircuitBreakerTests { @Mock @@ -74,4 +74,4 @@ public void testIsOpen3() { when(mem.getHeapUsedPercent()).thenReturn((short) 95); assertThat(breaker.isOpen(), equalTo(true)); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/ltr/settings/LTRSettingsTestIT.java b/src/test/java/org/opensearch/ltr/settings/LTRSettingsTestIT.java index e7334b76..50b2bc1f 100644 --- a/src/test/java/org/opensearch/ltr/settings/LTRSettingsTestIT.java +++ b/src/test/java/org/opensearch/ltr/settings/LTRSettingsTestIT.java @@ -15,11 +15,11 @@ package org.opensearch.ltr.settings; +import static org.hamcrest.Matchers.containsString; + import org.opensearch.client.ResponseException; import org.opensearch.ltr.LTRRestTestCase; -import static org.hamcrest.Matchers.containsString; - public class LTRSettingsTestIT extends LTRRestTestCase { public void testCreateStoreDisabled() throws Exception { diff --git a/src/test/java/org/opensearch/ltr/stats/LTRStatTests.java b/src/test/java/org/opensearch/ltr/stats/LTRStatTests.java index f94c6351..bde34534 100644 --- a/src/test/java/org/opensearch/ltr/stats/LTRStatTests.java +++ b/src/test/java/org/opensearch/ltr/stats/LTRStatTests.java @@ -15,13 +15,13 @@ package org.opensearch.ltr.stats; -import org.junit.Test; -import org.opensearch.ltr.stats.suppliers.CounterSupplier; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import org.junit.Test; +import org.opensearch.ltr.stats.suppliers.CounterSupplier; + public class LTRStatTests { @Test public void testIsClusterLevel() { @@ -52,7 +52,7 @@ public void testIncrementCounterSupplier() { } @Test(expected = UnsupportedOperationException.class) - public void testThrowExceptionIncrementNonCounterSupplier(){ + public void testThrowExceptionIncrementNonCounterSupplier() { LTRStat nonIncStat = new LTRStat<>(false, () -> "test"); nonIncStat.increment(); } diff --git a/src/test/java/org/opensearch/ltr/stats/LTRStatsTests.java b/src/test/java/org/opensearch/ltr/stats/LTRStatsTests.java index 47eec70f..95b6b980 100644 --- a/src/test/java/org/opensearch/ltr/stats/LTRStatsTests.java +++ b/src/test/java/org/opensearch/ltr/stats/LTRStatsTests.java @@ -15,17 +15,17 @@ package org.opensearch.ltr.stats; -import org.junit.Before; -import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; +import org.junit.Before; +import org.junit.Test; public class LTRStatsTests { @@ -72,8 +72,7 @@ public void testGetNodeStats() { Set> nodeStats = new HashSet<>(ltrStats.getNodeStats().values()); for (LTRStat stat : stats.values()) { - assertTrue((stat.isClusterLevel() && !nodeStats.contains(stat)) || - (!stat.isClusterLevel() && nodeStats.contains(stat))); + assertTrue((stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat))); } } @@ -83,8 +82,7 @@ public void testGetClusterStats() { Set> clusterStats = new HashSet<>(ltrStats.getClusterStats().values()); for (LTRStat stat : stats.values()) { - assertTrue((stat.isClusterLevel() && clusterStats.contains(stat)) || - (!stat.isClusterLevel() && !clusterStats.contains(stat))); + assertTrue((stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat))); } } } diff --git a/src/test/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplierTests.java b/src/test/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplierTests.java index 7b701c03..ff3a8133 100644 --- a/src/test/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplierTests.java +++ b/src/test/java/org/opensearch/ltr/stats/suppliers/CacheStatsOnNodeSupplierTests.java @@ -15,10 +15,10 @@ package org.opensearch.ltr.stats.suppliers; -import com.o19s.es.ltr.feature.Feature; -import com.o19s.es.ltr.feature.FeatureSet; -import com.o19s.es.ltr.feature.store.CompiledLtrModel; -import com.o19s.es.ltr.feature.store.index.Caches; +import static org.mockito.Mockito.when; + +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -26,9 +26,10 @@ import org.opensearch.common.cache.Cache; import org.opensearch.test.OpenSearchTestCase; -import java.util.Map; - -import static org.mockito.Mockito.when; +import com.o19s.es.ltr.feature.Feature; +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.CompiledLtrModel; +import com.o19s.es.ltr.feature.store.index.Caches; public class CacheStatsOnNodeSupplierTests extends OpenSearchTestCase { @Mock @@ -79,8 +80,7 @@ public void testGetCacheStats() { 1, 1, 0, 1, 800); } - private void assertCacheStats(Map stat, long hits, - long misses, long evictions, int entries, long memUsage) { + private void assertCacheStats(Map stat, long hits, long misses, long evictions, int entries, long memUsage) { assertEquals(hits, stat.get("hit_count")); assertEquals(misses, stat.get("miss_count")); assertEquals(evictions, stat.get("eviction_count")); diff --git a/src/test/java/org/opensearch/ltr/stats/suppliers/CounterSupplierTests.java b/src/test/java/org/opensearch/ltr/stats/suppliers/CounterSupplierTests.java index ab3bd03d..32713e7d 100644 --- a/src/test/java/org/opensearch/ltr/stats/suppliers/CounterSupplierTests.java +++ b/src/test/java/org/opensearch/ltr/stats/suppliers/CounterSupplierTests.java @@ -26,4 +26,4 @@ public void testGetAndIncrement() { counterSupplier.increment(); assertEquals((Long) 1L, counterSupplier.get()); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplierTests.java b/src/test/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplierTests.java index 7f55ac6f..c8fd21bf 100644 --- a/src/test/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplierTests.java +++ b/src/test/java/org/opensearch/ltr/stats/suppliers/PluginHealthStatusSupplierTests.java @@ -15,6 +15,11 @@ package org.opensearch.ltr.stats.suppliers; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.when; + +import java.util.Arrays; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -23,11 +28,6 @@ import org.opensearch.ltr.breaker.LTRCircuitBreakerService; import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; -import java.util.Arrays; - -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.when; - public class PluginHealthStatusSupplierTests { private PluginHealthStatusSupplier pluginHealthStatusSupplier; @@ -40,8 +40,7 @@ public class PluginHealthStatusSupplierTests { @Before public void setup() { MockitoAnnotations.openMocks(this); - pluginHealthStatusSupplier = - new PluginHealthStatusSupplier(storeUtils, ltrCircuitBreakerService); + pluginHealthStatusSupplier = new PluginHealthStatusSupplier(storeUtils, ltrCircuitBreakerService); } @Test diff --git a/src/test/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplierTests.java b/src/test/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplierTests.java index e9c58e03..836bec2e 100644 --- a/src/test/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplierTests.java +++ b/src/test/java/org/opensearch/ltr/stats/suppliers/StoreStatsSupplierTests.java @@ -15,6 +15,12 @@ package org.opensearch.ltr.stats.suppliers; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -22,12 +28,6 @@ import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; import org.opensearch.test.OpenSearchTestCase; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.mockito.Mockito.when; - public class StoreStatsSupplierTests extends OpenSearchTestCase { private static final String STORE_NAME = ".ltrstore"; @@ -68,4 +68,4 @@ public void getStoreStats_Success() { assertEquals(1, ltrStoreStats.get(StoreStatsSupplier.LTR_STORE_FEATURE_SET_COUNT)); assertEquals(5L, ltrStoreStats.get(StoreStatsSupplier.LTR_STORE_MODEL_COUNT)); } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtilsTests.java b/src/test/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtilsTests.java index cbc86fce..59c4febd 100644 --- a/src/test/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtilsTests.java +++ b/src/test/java/org/opensearch/ltr/stats/suppliers/utils/StoreUtilsTests.java @@ -15,14 +15,14 @@ package org.opensearch.ltr.stats.suppliers.utils; -import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.index.IndexNotFoundException; -import org.opensearch.ltr.stats.suppliers.utils.StoreUtils; import org.opensearch.test.OpenSearchIntegTestCase; -import java.util.Map; +import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; public class StoreUtilsTests extends OpenSearchIntegTestCase { private StoreUtils storeUtils; @@ -96,46 +96,42 @@ public void getModelCount() { assertEquals(1, storeUtils.getModelCount(IndexFeatureStore.DEFAULT_STORE)); } - private String testFeatureSet() { - return "{\n" + - "\"name\": \"movie_features\",\n" + - "\"type\": \"featureset\",\n" + - "\"featureset\": {\n" + - " \"name\": \"movie_features\",\n" + - " \"features\": [\n" + - " {\n" + - " \"name\": \"1\",\n" + - " \"params\": [\n" + - " \"keywords\"\n" + - " ],\n" + - " \"template_language\": \"mustache\",\n" + - " \"template\": {\n" + - " \"match\": {\n" + - " \"title\": \"{{keywords}}\"\n" + - " }\n" + - " }\n" + - " },\n" + - " {\n" + - " \"name\": \"2\",\n" + - " \"params\": [\n" + - " \"keywords\"\n" + - " ],\n" + - " \"template_language\": \"mustache\",\n" + - " \"template\": {\n" + - " \"match\": {\n" + - " \"overview\": \"{{keywords}}\"\n" + - " }\n" + - " }\n" + - " }\n" + - " ]\n" + - "}\n}"; + return "{\n" + + "\"name\": \"movie_features\",\n" + + "\"type\": \"featureset\",\n" + + "\"featureset\": {\n" + + " \"name\": \"movie_features\",\n" + + " \"features\": [\n" + + " {\n" + + " \"name\": \"1\",\n" + + " \"params\": [\n" + + " \"keywords\"\n" + + " ],\n" + + " \"template_language\": \"mustache\",\n" + + " \"template\": {\n" + + " \"match\": {\n" + + " \"title\": \"{{keywords}}\"\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"name\": \"2\",\n" + + " \"params\": [\n" + + " \"keywords\"\n" + + " ],\n" + + " \"template_language\": \"mustache\",\n" + + " \"template\": {\n" + + " \"match\": {\n" + + " \"overview\": \"{{keywords}}\"\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}\n}"; } private String testModel() { - return "{\n" + - "\"name\": \"movie_model\",\n" + - "\"type\": \"model\"" + - "\n}"; + return "{\n" + "\"name\": \"movie_model\",\n" + "\"type\": \"model\"" + "\n}"; } } diff --git a/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java b/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java index 9001f224..42e6d580 100644 --- a/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java +++ b/src/test/java/org/opensearch/ltr/transport/TransportLTRStatsActionTests.java @@ -15,6 +15,13 @@ package org.opensearch.ltr.transport; +import static org.mockito.Mockito.mock; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.junit.Before; import org.junit.Test; import org.opensearch.action.FailedNodeException; @@ -26,13 +33,6 @@ import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.transport.TransportService; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.mockito.Mockito.mock; - public class TransportLTRStatsActionTests extends OpenSearchIntegTestCase { private TransportLTRStatsAction action; @@ -55,11 +55,11 @@ public void setUp() throws Exception { ltrStats = new LTRStats(statsMap); action = new TransportLTRStatsAction( - client().threadPool(), - clusterService(), - mock(TransportService.class), - mock(ActionFilters.class), - ltrStats + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + ltrStats ); }