Skip to content

Commit 0001754

Browse files
authored
Merge pull request #536 from vespa-engine/kkraune/feed-util
Add utility to create a vespa feed
2 parents 0de590f + 303c374 commit 0001754

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

vespa/application.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import aiohttp
66
import asyncio
77
import concurrent.futures
8+
import json
89
from collections import Counter
910
from typing import Any, Optional, Dict, List, IO
1011

@@ -30,13 +31,14 @@
3031
)
3132

3233

33-
def parse_feed_df(df: DataFrame, include_id: bool, id_field="id") -> List[Dict[str, Any]]:
34+
def parse_feed_df(df: DataFrame, include_id: bool, id_field="id", id_prefix="") -> List[Dict[str, Any]]:
3435
"""
3536
Convert a df into batch format for feeding
3637
3738
:param df: DataFrame with the following required columns ["id"]. Additional columns are assumed to be fields.
3839
:param include_id: Include id on the fields to be fed.
3940
:param id_field: Name of the column containing the id field.
41+
:param id_prefix: Add a string prefix to ID field, e.g. "id:namespace:schema::"
4042
:return: List of Dict containing 'id' and 'fields'.
4143
"""
4244
required_columns = [id_field]
@@ -46,7 +48,7 @@ def parse_feed_df(df: DataFrame, include_id: bool, id_field="id") -> List[Dict[s
4648
records = df.to_dict(orient="records")
4749
batch = [
4850
{
49-
"id": record[id_field],
51+
"id": record[id_field] if id_prefix == "" else id_prefix + str(record[id_field]),
5052
"fields": record
5153
if include_id
5254
else {k: v for k, v in record.items() if k not in [id_field]},
@@ -56,6 +58,21 @@ def parse_feed_df(df: DataFrame, include_id: bool, id_field="id") -> List[Dict[s
5658
return batch
5759

5860

61+
def df_to_vespafeed(df: DataFrame, schema_name: str, id_field="id", namespace="") -> str:
62+
"""
63+
Convert a df into a string in Vespa JSON feed format,
64+
see https://docs.vespa.ai/en/reference/document-json-format.html
65+
66+
:param df: DataFrame with the following required columns ["id"]. Additional columns are assumed to be fields.
67+
:param schema_name: Schema name
68+
:param id_field: Name of the column containing the id field.
69+
:param namespace: Set if namespace != schema_name
70+
:return: JSON string in Vespa feed format
71+
"""
72+
return json.dumps(parse_feed_df(df, True, id_field,
73+
"id:{}:{}::".format(schema_name if namespace == "" else namespace, schema_name)))
74+
75+
5976
def raise_for_status(response: Response) -> None:
6077
"""
6178
Raises an appropriate error if necessary.
@@ -418,7 +435,7 @@ def feed_batch(
418435
:return: List of HTTP POST responses
419436
"""
420437
mini_batches = [
421-
batch[i : i + batch_size] for i in range(0, len(batch), batch_size)
438+
batch[i: i + batch_size] for i in range(0, len(batch), batch_size)
422439
]
423440
batch_http_responses = []
424441
for idx, mini_batch in enumerate(mini_batches):

vespa/test_application.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,32 @@
22

33
import json
44
import unittest
5+
6+
import pandas
57
import pytest
68
from unittest.mock import PropertyMock, patch
79
from pandas import DataFrame
810
from requests.models import HTTPError, Response
911

1012
from vespa.package import ApplicationPackage, Schema, Document
11-
from vespa.application import Vespa, parse_feed_df, raise_for_status
13+
from vespa.application import Vespa, parse_feed_df, df_to_vespafeed, raise_for_status
1214
from vespa.exceptions import VespaError
1315

1416

17+
def test_df_to_vespafeed():
18+
df = pandas.DataFrame({
19+
"id": [0, 1, 2],
20+
"body": ["body 1", "body 2", "body 3"]
21+
})
22+
feed = json.loads(df_to_vespafeed(df, "myschema", "id"))
23+
24+
assert feed[1]["fields"]["body"] == "body 2"
25+
assert feed[0]["id"] == "id:myschema:myschema::0"
26+
27+
feed = json.loads(df_to_vespafeed(df, "myschema", "id", "mynamespace"))
28+
assert feed[2]["id"] == "id:mynamespace:myschema::2"
29+
30+
1531
class TestVespa(unittest.TestCase):
1632
def test_end_point(self):
1733
self.assertEqual(

0 commit comments

Comments
 (0)