Merge pull request #66 from beehyv/inference_endpoint
Capture the inference endpoint for OpenAI clients so it can be recorded on inference spans.
kshitiz-okahu authored Oct 29, 2024
2 parents 17796c1 + 2609cff commit 43b81c1
Showing 4 changed files with 16 additions and 10 deletions.

2 changes: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
     },
     {
       "attribute": "inference_endpoint",
-      "accessor": "lambda instance,args: resolve_from_alias(instance.__dict__, ['azure_endpoint', 'api_base'])"
+      "accessor": "lambda instance,args: resolve_from_alias(instance.__dict__, ['azure_endpoint', 'api_base']) or args['inference_endpoint']"
     }
   ],
   [
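
For context: the accessor is a lambda source string that the instrumentation evaluates and calls with (instance, args). A minimal sketch of the new fallback behavior, with resolve_from_alias's body assumed from its usage here (first alias key present in the map wins) and a hypothetical stub instance:

```python
# Minimal sketch, assuming resolve_from_alias returns the first alias key
# present in the map (inferred from its usage; not the library's exact code).
def resolve_from_alias(my_map, aliases):
    for alias in aliases:
        if alias in my_map:
            return my_map[alias]
    return None

# The accessor string from the JSON above, written out as a plain lambda:
accessor = lambda instance, args: (
    resolve_from_alias(instance.__dict__, ['azure_endpoint', 'api_base'])
    or args['inference_endpoint']
)

class OpenAIClientStub:
    """Hypothetical instance with neither azure_endpoint nor api_base set."""

print(accessor(OpenAIClientStub(),
               {"inference_endpoint": "https://api.openai.com/v1/"}))
# -> https://api.openai.com/v1/
```

With this change an Azure-style endpoint on the instance still wins; the endpoint captured by the wrappers (next file) is used only as a fallback.
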
18 changes: 11 additions & 7 deletions src/monocle_apptrace/wrap_common.py
@@ -213,8 +213,8 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
         if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
             input_arg_text = get_attribute(DATA_INPUT_KEY)
             span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
-        provider_name = set_provider_name(instance)
-        instance_args = {"provider_name": provider_name}
+        provider_name, inference_endpoint = get_provider_name(instance)
+        instance_args = {"provider_name": provider_name, "inference_endpoint": inference_endpoint}
 
         process_span(to_wrap, span, instance, instance_args)

@@ -246,8 +246,8 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
         if 'haystack.components.retrievers' in to_wrap['package'] and 'haystack.retriever' in span.name:
             input_arg_text = get_attribute(DATA_INPUT_KEY)
             span.add_event(DATA_INPUT_KEY, {QUERY: input_arg_text})
-        provider_name = set_provider_name(instance)
-        instance_args = {"provider_name": provider_name}
+        provider_name, inference_endpoint = get_provider_name(instance)
+        instance_args = {"provider_name": provider_name, "inference_endpoint": inference_endpoint}
 
         process_span(to_wrap, span, instance, instance_args)
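
Both wrappers now thread the endpoint into process_span via instance_args, where accessors like the JSON entry above can read it. A rough, assumed simplification of that hand-off (process_span's actual implementation lives elsewhere in wrap_common.py and is not shown in this diff):

```python
# Assumed simplification of how process_span applies an output-processor
# entry; names mirror the JSON map shown earlier.
def process_span_sketch(output_processor, span, instance, instance_args):
    # output_processor["attributes"] is a list of entities, each a list of
    # {"attribute": ..., "accessor": ...} entries (see the JSON hunk above).
    for index, attributes in enumerate(output_processor["attributes"]):
        for attr in attributes:
            # Accessor strings are lambda sources; evaluating one yields a
            # callable taking (instance, args). resolve_from_alias must be
            # in scope for the accessor shown above to work.
            value = eval(attr["accessor"])(instance, instance_args)
            if value is not None:
                span.set_attribute(f"entity.{index + 1}.{attr['attribute']}", value)
```

This is how "inference_endpoint" ends up as entity.N.inference_endpoint on the span, as the updated test expectations below show.
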

@@ -289,12 +289,16 @@ def update_llm_endpoint(curr_span: Span, instance):
     )
 
 
-def set_provider_name(instance):
+def get_provider_name(instance):
     provider_url = ""
+    inference_endpoint = ""
     try:
         if isinstance(instance.client._client.base_url.host, str):
             provider_url = instance.client._client.base_url.host
+        if isinstance(instance.client._client.base_url, str):
+            inference_endpoint = instance.client._client.base_url
+        else:
+            inference_endpoint = str(instance.client._client.base_url)
     except:
         pass
 
@@ -309,7 +313,7 @@ def set_provider_name(instance):
         parsed_provider_url = urlparse(provider_url)
     except:
         pass
-    return parsed_provider_url.hostname or provider_url
+    return parsed_provider_url.hostname or provider_url, inference_endpoint
 
 
 def is_root_span(curr_span: Span) -> bool:
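
Pieced together from the two hunks above, the renamed get_provider_name now returns a (provider_name, inference_endpoint) tuple. A consolidated sketch, with the code between the hunks elided and simplifications marked in comments:

```python
from urllib.parse import urlparse

def get_provider_name(instance):
    provider_url = ""
    inference_endpoint = ""
    try:
        # httpx.URL-style base_url objects expose a .host attribute
        if isinstance(instance.client._client.base_url.host, str):
            provider_url = instance.client._client.base_url.host
        # capture the full endpoint whether base_url is a str or a URL object
        if isinstance(instance.client._client.base_url, str):
            inference_endpoint = instance.client._client.base_url
        else:
            inference_endpoint = str(instance.client._client.base_url)
    except:
        pass

    # ... code between the two hunks elided (not shown in the diff) ...

    try:
        parsed_provider_url = urlparse(provider_url)
    except:
        pass
    # As committed: parsed_provider_url would be undefined if urlparse raised.
    return parsed_provider_url.hostname or provider_url, inference_endpoint
```

One subtlety: when base_url is a plain str, the .host access in the first condition raises before the isinstance check runs and the bare except swallows it, so in practice the endpoint is only captured via str(base_url) on URL-like clients.
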
4 changes: 2 additions & 2 deletions tests/langchain_chat_sample.py
@@ -278,9 +278,9 @@ def format_docs(docs):
# "entity.count": 2,
# "entity.1.type": "inference.azure_oai",
# "entity.1.provider_name": "api.openai.com",
# "entity.1.inference_endpoint": "https://api.openai.com/v1/",
# "entity.2.name": "gpt-3.5-turbo-0125",
# "entity.2.type": "model.llm",
# "entity.2.model_name": "gpt-3.5-turbo-0125"
# "entity.2.type": "model.llm.gpt-3.5-turbo-0125"
# },
# "events": [
# {
2 changes: 2 additions & 0 deletions tests/langchain_custom_output_processor_test.py
@@ -170,10 +170,12 @@ def test_llm_chain(self, test_name, test_input_infra, test_output_infra, mock_po
             dataJson = json.loads(dataBodyStr) # more asserts can be added on individual fields
 
             llm_vector_store_retriever_span = [x for x in dataJson["batch"] if 'langchain.task.VectorStoreRetriever' in x["name"]][0]
+            inference_span = [x for x in dataJson["batch"] if 'langchain.task.FakeListLLM' in x["name"]][0]
 
             assert llm_vector_store_retriever_span["attributes"]["span.type"] == "retrieval"
             assert llm_vector_store_retriever_span["attributes"]["entity.1.name"] == "FAISS"
             assert llm_vector_store_retriever_span["attributes"]["entity.1.type"] == "vectorstore.FAISS"
+            assert inference_span['attributes']["entity.1.inference_endpoint"] == "https://example.com/"
 
         finally:
             os.environ.pop(test_input_infra)
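
For intuition, the new assertion passes for any instance whose nested client exposes an httpx-style base_url. A hypothetical stand-in (illustrative only; not the fixtures this test actually wires up through FakeListLLM):

```python
from types import SimpleNamespace
from monocle_apptrace.wrap_common import get_provider_name

# Hypothetical URL object mimicking httpx.URL (illustrative only).
class FakeURL:
    host = "example.com"
    def __str__(self):
        return "https://example.com/"

fake_instance = SimpleNamespace(
    client=SimpleNamespace(_client=SimpleNamespace(base_url=FakeURL()))
)

provider_name, inference_endpoint = get_provider_name(fake_instance)
assert provider_name == "example.com"
assert inference_endpoint == "https://example.com/"
```

Note that inference_endpoint keeps the full URL, trailing slash included, while provider_name reduces to the host.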
