Skip to content

Commit 58eaad0

Browse files
authored
Merge pull request #217 from vespa-engine/tgm/default-query-model
Allow default query model to be specified and define it for TextSearch
2 parents 90bbeff + adec27c commit 58eaad0

File tree

4 files changed

+38
-5
lines changed

vespa/application.py

+9
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ def _build_query_body(
208208
**kwargs
209209
) -> Dict:
210210
assert query is not None, "No 'query' specified."
211+
if not query_model:
212+
query_model = self.get_default_query_model()
211213
assert query_model is not None, "No 'query_model' specified."
212214
body = query_model.create_body(query=query)
213215
if recall is not None:
@@ -852,6 +854,13 @@ def application_package(self):
852854
else:
853855
return self._application_package
854856

857+
def get_default_query_model(self):
    """
    Get the default query model from the application package, if available.

    :return: The ``default_query_model`` attribute of ``self.application_package``, or None
        when accessing ``self.application_package`` raises ``ValueError``.
    """
    try:
        app_package = self.application_package
    except ValueError:
        # application_package raised ValueError: report "no default query model"
        # instead of propagating, so the caller can apply its own check.
        return None
    return app_package.default_query_model
863+
855864
def get_model_from_application_package(self, model_name: str):
856865
"""Get model definition from application package, if available."""
857866
app_package = self.application_package

vespa/gallery.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
QueryProfileType,
1313
QueryTypeField,
1414
)
15+
from vespa.query import QueryModel, AND, RankProfile as Ranking
1516

1617

1718
class TextSearch(ApplicationPackage):
@@ -48,11 +49,18 @@ def __init__(
4849
first_phase=" + ".join(["bm25({})".format(x) for x in text_fields]),
4950
),
5051
RankProfile(
51-
name="native_rank", first_phase="nativeRank({})".format(",".join(text_fields))
52+
name="native_rank",
53+
first_phase="nativeRank({})".format(",".join(text_fields)),
5254
),
5355
],
5456
)
55-
super().__init__(name=name, schema=[schema])
57+
super().__init__(
58+
name=name,
59+
schema=[schema],
60+
default_query_model=QueryModel(
61+
name="and_bm25", match_phase=AND(), rank_profile=Ranking(name="bm25")
62+
),
63+
)
5664

5765

5866
class QuestionAnswering(ApplicationPackage):

vespa/package.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from jinja2 import Environment, PackageLoader, select_autoescape
77

88
from vespa.json_serialization import ToJson, FromJson
9+
from vespa.query import QueryModel
910

1011

1112
class HNSW(ToJson, FromJson["HNSW"]):
@@ -1153,6 +1154,7 @@ def __init__(
11531154
create_schema_by_default: bool = True,
11541155
create_query_profile_by_default: bool = True,
11551156
tasks: Optional[List[Task]] = None,
1157+
default_query_model: Optional[QueryModel] = None
11561158
) -> None:
11571159
"""
11581160
Create a Vespa Application Package.
@@ -1173,6 +1175,7 @@ def __init__(
11731175
:param create_query_profile_by_default: Include a default :class:`QueryProfile` and :class:`QueryProfileType`
11741176
in case it is not explicitly defined by the user in the `query_profile` and `query_profile_type` parameters.
11751177
:param tasks: List of tasks to be served.
1178+
:param default_query_model: Optional QueryModel to be used as default for the application.
11761179
11771180
The easiest way to get started is to create a default application package:
11781181
@@ -1200,6 +1203,7 @@ def __init__(
12001203
self.model_configs = {}
12011204
self.stateless_model_evaluation = stateless_model_evaluation
12021205
self.models = {} if not tasks else {model.model_id: model for model in tasks}
1206+
self.default_query_model = default_query_model
12031207

12041208
@property
12051209
def schemas(self) -> List[Schema]:

vespa/test_integration_docker.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,12 @@ def execute_data_operations(
319319
),
320320
},
321321
)
322+
#
323+
# Query with 'query' without QueryModel
324+
#
325+
with self.assertRaisesRegex(AssertionError, "No 'query_model' specified."):
326+
_ = app.query(query="this should not work")
327+
322328
#
323329
# Update data
324330
#
@@ -1394,10 +1400,16 @@ def setUp(self) -> None:
13941400
#
13951401
self.app.feed_df(df)
13961402

1403+
def test_default_query_model(self):
    """Query without passing a query model and check the request body that was built."""
    result = self.app.query(query="what is finance?", debug_request=True)
    expected_request_body = {
        "yql": 'select * from sources * where (userInput("what is finance?"));',
        "ranking": {"profile": "bm25", "listFeatures": "false"},
    }
    self.assertDictEqual(expected_request_body, result.request_body)
1410+
13971411
def test_query(self):
1398-
result = self.app.query(
1399-
query="what is finance?", query_model=QueryModel(match_phase=OR())
1400-
)
1412+
result = self.app.query(query="what is finance?")
14011413
for hit in result.hits:
14021414
self.assertIn("fields", hit)
14031415

0 commit comments

Comments
 (0)