
Commit

Updating inference span for Mistral AI and changing workflow span input/output events

Signed-off-by: hansrajr <[email protected]>
Hansrajr committed Nov 21, 2024
1 parent 43712ce commit b08b7f0
Showing 5 changed files with 93 additions and 44 deletions.
18 changes: 18 additions & 0 deletions src/monocle_apptrace/metamodel/maps/llamaindex_methods.json
@@ -69,6 +69,24 @@
"wrapper_package": "wrap_common",
"wrapper_method": "allm_wrapper",
"output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
},
{
"package": "llama_index.llms.mistralai.base",
"object": "MistralAI",
"method": "chat",
"span_name": "llamaindex.mistralai",
"wrapper_package": "wrap_common",
"wrapper_method": "llm_wrapper",
"output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
},
{
"package": "llama_index.llms.mistralai.base",
"object": "MistralAI",
"method": "achat",
"span_name": "llamaindex.mistralai",
"wrapper_package": "wrap_common",
"wrapper_method": "allm_wrapper",
"output_processor": ["metamodel/maps/attributes/inference/llamaindex_entities.json"]
}
]
}
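
The two new entries above register MistralAI with the llama_index instrumentation map: the synchronous chat path goes through llm_wrapper and the async achat path through allm_wrapper, both under the span name "llamaindex.mistralai". A minimal sketch of the calls they would instrument (illustrative only, not part of this commit; it assumes the standard llama_index chat API and a MISTRAL_API_KEY in the environment):

import os
from llama_index.core.llms import ChatMessage
from llama_index.llms.mistralai import MistralAI

llm = MistralAI(api_key=os.getenv("MISTRAL_API_KEY"))
messages = [ChatMessage(role="user", content="What did the author do growing up?")]

# Sync path: MistralAI.chat is wrapped by llm_wrapper -> span "llamaindex.mistralai"
response = llm.chat(messages)
print(response)

# Async path: MistralAI.achat is wrapped by allm_wrapper (call from an event loop)
# response = await llm.achat(messages)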
35 changes: 22 additions & 13 deletions src/monocle_apptrace/wrap_common.py
@@ -15,7 +15,7 @@
DATA_OUTPUT_KEY = "data.output"
PROMPT_INPUT_KEY = "data.input"
PROMPT_OUTPUT_KEY = "data.output"
QUERY = "question"
QUERY = "input"
RESPONSE = "response"
SESSION_PROPERTIES_KEY = "session"
INFRA_SERVICE_KEY = "infra_service_name"
@@ -211,7 +211,7 @@ def pre_task_processing(to_wrap, instance, args, span):
sdk_version = version("monocle_apptrace")
span.set_attribute("monocle_apptrace.version", sdk_version)
except:
logger.warning(f"Exception finding okahu-observability version.")
logger.warning(f"Exception finding monocle-apptrace version.")
update_span_with_prompt_input(to_wrap=to_wrap, wrapped_args=args, span=span)
update_span_with_context_input(to_wrap=to_wrap, wrapped_args=args, span=span)
except:
@@ -266,7 +266,7 @@ async def allm_wrapper(tracer, to_wrap, wrapped, instance, args, kwargs):
with tracer.start_as_current_span(name) as span:
provider_name, inference_endpoint = get_provider_name(instance)
return_value = await wrapped(*args, **kwargs)
kwargs.update({"provider_name": provider_name, "inference_endpoint": inference_endpoint})
kwargs.update({"provider_name": provider_name, "inference_endpoint": inference_endpoint or getattr(instance, 'endpoint', None)})
process_span(to_wrap, span, instance, args, kwargs, return_value)
update_span_from_llm_response(response=return_value, span=span, instance=instance, args=args)

@@ -292,7 +292,7 @@ def llm_wrapper(tracer: Tracer, to_wrap, wrapped, instance, args, kwargs):
with tracer.start_as_current_span(name) as span:
provider_name, inference_endpoint = get_provider_name(instance)
return_value = wrapped(*args, **kwargs)
kwargs.update({"provider_name": provider_name, "inference_endpoint": inference_endpoint})
kwargs.update({"provider_name": provider_name, "inference_endpoint": inference_endpoint or getattr(instance, 'endpoint', None)})
process_span(to_wrap, span, instance, args, kwargs, return_value)
update_span_from_llm_response(response=return_value, span=span, instance=instance, args=args)

@@ -437,16 +437,22 @@ def update_events_for_inference_span(response, span, args):
if span.name == 'haystack.components.generators.openai.OpenAIGenerator':
args_input = get_attribute(DATA_INPUT_KEY)
span.add_event(name="data.input", attributes={"input": args_input}, )
span.add_event(name="data.output", attributes={"output": response['replies'][0]}, )
span.add_event(name="data.output", attributes={"response": response['replies'][0]}, )
if args and isinstance(args, tuple) and len(args) > 0:
if hasattr(args[0], "messages") and isinstance(args[0].messages, list):
for msg in args[0].messages:
if hasattr(msg, 'content') and hasattr(msg, 'type'):
if hasattr(msg, "content") and hasattr(msg, "type"):
if msg.type == "system":
system_message = msg.content
elif msg.type in ["user", "human"]:
user_message = msg.content
if msg.type == "system":
system_message = msg.content
elif msg.type in ["user", "human"]:
user_message = msg.content
if isinstance(args[0], list):
for msg in args[0]:
if hasattr(msg, 'content') and hasattr(msg, 'role'):
if msg.role == "system":
system_message = msg.content
elif msg.role in ["user", "human"]:
user_message = extract_query_from_content(msg.content)
if span.name == 'llamaindex.openai':
if args and isinstance(args, tuple):
chat_messages = args[0]
@@ -468,7 +474,7 @@
if hasattr(response, "content") or hasattr(response, "message"):
assistant_message = getattr(response, "content", "") or response.message.content
if assistant_message:
span.add_event(name="data.output", attributes={"assistant" if system_message else "output": assistant_message}, )
span.add_event(name="data.output", attributes={"assistant" if system_message else "response": assistant_message}, )


def extract_query_from_content(content):
@@ -533,7 +539,7 @@ def update_span_with_prompt_input(to_wrap, wrapped_args, span: Span):
input_arg_text = flatten_dict(prompt_inputs)
span.add_event(PROMPT_INPUT_KEY, input_arg_text)
elif isinstance(input_arg_text, dict):
span.add_event(PROMPT_INPUT_KEY, input_arg_text)
span.add_event(PROMPT_INPUT_KEY, {QUERY: input_arg_text['input']})
else:
span.add_event(PROMPT_INPUT_KEY, {QUERY: input_arg_text})

@@ -550,4 +556,7 @@ def update_span_with_prompt_output(to_wrap, wrapped_args, span: Span):
elif isinstance(wrapped_args, str):
span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args})
elif isinstance(wrapped_args, dict):
span.add_event(PROMPT_OUTPUT_KEY, wrapped_args)
if "langchain.schema.runnable" in package_name:
span.add_event(PROMPT_OUTPUT_KEY, {RESPONSE: wrapped_args['answer']})
else:
span.add_event(PROMPT_OUTPUT_KEY, wrapped_args)
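
Net effect of the wrap_common.py changes on workflow spans: the data.input event now keys the question as "input" instead of "question", dict inputs are reduced to their "input" field, and a LangChain runnable's dict result is recorded on data.output under "response" using only its "answer" field. A small self-contained sketch of that mapping (illustrative only, not the library code; the example strings are made up):

QUERY = "input"
RESPONSE = "response"

def prompt_input_attributes(chain_input):
    # Dict inputs (e.g. {"input": ..., "chat_history": ...}) are reduced to "input";
    # plain strings pass through under the same key.
    if isinstance(chain_input, dict):
        return {QUERY: chain_input["input"]}
    return {QUERY: chain_input}

def prompt_output_attributes(package_name, result):
    # LangChain runnables return a dict; only its "answer" field is recorded.
    if isinstance(result, dict) and "langchain.schema.runnable" in package_name:
        return {RESPONSE: result["answer"]}
    return {RESPONSE: result} if isinstance(result, str) else result

# prompt_input_attributes({"input": "What is Task Decomposition?", "chat_history": []})
#   -> {"input": "What is Task Decomposition?"}
# prompt_output_attributes("langchain.schema.runnable.base", {"answer": "It splits a task into steps."})
#   -> {"response": "It splits a task into steps."}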
42 changes: 24 additions & 18 deletions tests/langchain_chat_sample.py
@@ -23,16 +23,22 @@


# llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
llm = OpenAI(model="gpt-3.5-turbo-instruct")
# llm = AzureOpenAI(
# # engine=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
# azure_deployment=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
# api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
# api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
# azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
# temperature=0.1,
# # model="gpt-4",
# model="gpt-3.5-turbo-0125")
# llm = OpenAI(model="gpt-3.5-turbo-instruct")
llm = AzureChatOpenAI(
# engine=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
azure_deployment=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
temperature=0.7,
# model="gpt-4",
model="gpt-3.5-turbo-0125"
)

# llm = ChatMistralAI(
# model="mistral-large-latest",
# temperature=0.7,
# )

# Load, chunk and index the contents of the blog.
loader = WebBaseLoader(
@@ -56,12 +62,12 @@
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
# rag_chain = (
# {"context": retriever | format_docs, "question": RunnablePassthrough()}
# | prompt
# | llm
# | StrOutputParser()
# )


contextualize_q_system_prompt = """Given a chat history and the latest user question \
@@ -104,13 +110,13 @@ def format_docs(docs):

question = "What is Task Decomposition?"
ai_msg_1 = rag_chain.invoke({"input": question, "chat_history": chat_history})
print(ai_msg_1["answer"])
# print(ai_msg_1["answer"])
chat_history.extend([HumanMessage(content=question), ai_msg_1["answer"]])

second_question = "What are common ways of doing it?"
ai_msg_2 = rag_chain.invoke({"input": second_question, "chat_history": chat_history})

print(ai_msg_2["answer"])
# print(ai_msg_2["answer"])


# {
28 changes: 18 additions & 10 deletions tests/langchain_sample.py
@@ -10,6 +10,8 @@
from langchain_text_splitters import RecursiveCharacterTextSplitter
from monocle_apptrace.instrumentor import setup_monocle_telemetry
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from langchain_mistralai import ChatMistralAI

import os
os.environ["AZURE_OPENAI_API_DEPLOYMENT"] = ""
os.environ["AZURE_OPENAI_API_KEY"] = ""
@@ -21,17 +23,23 @@
span_processors=[BatchSpanProcessor(ConsoleSpanExporter())],
wrapper_methods=[])

# llm = OpenAI(model="gpt-3.5-turbo-instruct")
llm = AzureOpenAI(
# engine=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
azure_deployment=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
temperature=0.1,
# model="gpt-4",

model="gpt-3.5-turbo-0125")
llm = ChatMistralAI(
model="mistral-large-latest",
temperature=0.7,
)

# llm = OpenAI(model="gpt-3.5-turbo-instruct")
# llm = AzureOpenAI(
# # engine=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
# azure_deployment=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
# api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
# api_version=os.environ.get("AZURE_OPENAI_API_VERSION"),
# azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
# temperature=0.1,
# # model="gpt-4",
#
# model="gpt-3.5-turbo-0125")
# Load, chunk and index the contents of the blog.
loader = WebBaseLoader(
web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
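
langchain_sample.py now runs against ChatMistralAI instead of AzureOpenAI. A standalone sanity check of that model setup (illustrative only, not part of this commit; it assumes langchain_mistralai reads the key from the MISTRAL_API_KEY environment variable):

import os
from langchain_mistralai import ChatMistralAI

os.environ.setdefault("MISTRAL_API_KEY", "<your-key>")  # placeholder, not a real key
llm = ChatMistralAI(model="mistral-large-latest", temperature=0.7)
print(llm.invoke("What is Task Decomposition?").content)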
14 changes: 11 additions & 3 deletions tests/llama_index_sample.py
@@ -12,13 +12,19 @@
from monocle_apptrace.wrap_common import llm_wrapper
from monocle_apptrace.wrapper import WrapperMethod
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

from llama_index.llms.mistralai import MistralAI
os.environ["AZURE_OPENAI_API_DEPLOYMENT"] = ""
os.environ["AZURE_OPENAI_API_KEY"] = ""
os.environ["AZURE_OPENAI_API_VERSION"] = ""
os.environ["AZURE_OPENAI_ENDPOINT"] = ""
os.environ["OPENAI_API_KEY"] = ""
os.environ["MISTRAL_API_KEY"] = ""
setup_monocle_telemetry(
workflow_name="llama_index_1",
span_processors=[BatchSpanProcessor(ConsoleSpanExporter())],
wrapper_methods=[]
)

api_key_string = "jkfXbZDj7QrBHlBbBBYmEWZAJa2kk1Gd"
# Creating a Chroma client
# EphemeralClient operates purely in-memory, PersistentClient will also save to disk
chroma_client = chromadb.EphemeralClient()
@@ -37,7 +43,7 @@
documents, storage_context=storage_context, embed_model=embed_model
)

llm = OpenAI(temperature=0.1, model="gpt-4")
# llm = OpenAI(temperature=0.8, model="gpt-4")
# llm = AzureOpenAI(
# engine=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
# azure_deployment=os.environ.get("AZURE_OPENAI_API_DEPLOYMENT"),
@@ -49,6 +55,8 @@
#
# model="gpt-3.5-turbo-0125")

llm = MistralAI(api_key=os.getenv("MISTRAL_API_KEY"))

query_engine = index.as_query_engine(llm= llm, )
response = query_engine.query("What did the author do growing up?")

