Skip to content

Commit

Permalink
ensure lists are converted to Opensearch-compatible formats
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinblack committed Dec 15, 2023
1 parent c3cb058 commit b45b880
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion arcaflow_plugin_opensearch/opensearch_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,46 @@
from opensearch_schema import ErrorOutput, SuccessOutput, StoreDocumentRequest


def convert_to_supported_type(value) -> typing.Dict:
type_of_val = type(value)
if type_of_val == list:
new_list = []
for i in value:
new_list.append(convert_to_supported_type(i))
# A list needs to be of a consistent type or it will
# not be indexible into a system like Opensearch
return convert_to_homogenous_list(new_list)
elif type_of_val == dict:
result = {}
for k in value:
result[convert_to_supported_type(k)] = convert_to_supported_type(
value[k]
)
return result
elif type_of_val in (float, int, str, bool):
return value
elif isinstance(type_of_val, type(None)):
return str("")
else:
print("Unknown type", type_of_val, "with val", str(value))
return str(value)

def convert_to_homogenous_list(input_list: list):
# To make all types in list homogeneous, we upconvert them
# to the least commom type.
# int -> float -> str
# bool + None -> str
list_type = str()
for j in input_list:
if type(j) in (str, bool, type(None)):
list_type = str()
break
elif type(j) is float:
list_type = float()
elif type(j) is int and type(list_type) is not float:
list_type = int()
return list(map(type(list_type), input_list))

@plugin.step(
id="opensearch",
name="OpenSearch",
Expand All @@ -18,6 +58,8 @@
def store(
params: StoreDocumentRequest,
) -> typing.Tuple[str, typing.Union[SuccessOutput, ErrorOutput]]:
document = convert_to_supported_type(params.data)

try:
if params.username:
opensearch = OpenSearch(
Expand All @@ -31,7 +73,10 @@ def store(
hosts=params.url,
verify_certs=params.tls_verify,
)
resp = opensearch.index(index=params.index, body=params.data)
resp = opensearch.index(
index=params.index,
body=document,
)
if resp["result"] != "created":
raise Exception(f"Document status: {resp['_shards']}")

Expand Down

0 comments on commit b45b880

Please sign in to comment.