Skip to content

Commit

Permalink
Google Generators: change answers to replies (#626)
Browse files Browse the repository at this point in the history
* change answers to replies

* fix test

* fix also vertex QA
  • Loading branch information
anakin87 authored Mar 27, 2024
1 parent 422e426 commit 9392e51
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class GoogleAIGeminiGenerator:
gemini = GoogleAIGeminiGenerator(model="gemini-pro", api_key=Secret.from_token("<MY_API_KEY>"))
res = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in res["answers"]:
for answer in res["replies"]:
print(answer)
```
Expand Down Expand Up @@ -55,7 +55,7 @@ class GoogleAIGeminiGenerator:
gemini = GoogleAIGeminiGenerator(model="gemini-pro-vision", api_key=Secret.from_token("<MY_API_KEY>"))
result = gemini.run(parts = ["What can you tell me about this robots?", *images])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
```
"""
Expand Down Expand Up @@ -173,7 +173,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(answers=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates text based on the given input parts.
Expand All @@ -182,7 +182,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
A heterogeneous list of strings, `ByteStream` or `Part` objects.
:returns:
A dictionary containing the following key:
- `answers`: A list of strings or dictionaries with function calls.
- `replies`: A list of strings or dictionaries with function calls.
"""

converted_parts = [self._convert_part(p) for p in parts]
Expand All @@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
safety_settings=self._safety_settings,
)
self._model.start_chat()
answers = []
replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part.text != "":
answers.append(part.text)
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
answers.append(function_call)
replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
2 changes: 1 addition & 1 deletion integrations/google_ai/tests/generators/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,4 @@ def test_from_dict(monkeypatch):
def test_run():
gemini = GoogleAIGeminiGenerator(model="gemini-pro")
res = gemini.run("Tell me something cool")
assert len(res["answers"]) > 0
assert len(res["replies"]) > 0
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class VertexAICodeGenerator:
result = generator.run(prefix="def to_json(data):")
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
>>> ```python
Expand Down Expand Up @@ -92,17 +92,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator":
"""
return default_from_dict(cls, data)

@component.output_types(answers=List[str])
@component.output_types(replies=List[str])
def run(self, prefix: str, suffix: Optional[str] = None):
"""
Generate code using a Google Vertex AI model.
:param prefix: Code before the current point.
:param suffix: Code after the current point.
:returns: A dictionary with the following keys:
- `answers`: A list of generated code snippets.
- `replies`: A list of generated code snippets.
"""
res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs)
# Handle the case where the model returns multiple candidates
answers = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
return {"answers": answers}
replies = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class VertexAIGeminiGenerator:
gemini = VertexAIGeminiGenerator(project_id=project_id)
result = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
>>> 1. **The Origin of Life:** How and where did life begin? The answers to this ...
Expand Down Expand Up @@ -175,14 +175,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

@component.output_types(answers=List[Union[str, Dict[str, str]]])
@component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates content using the Gemini model.
:param parts: Prompt for the model.
:returns: A dictionary with the following keys:
- `answers`: A list of generated content.
- `replies`: A list of generated content.
"""
converted_parts = [self._convert_part(p) for p in parts]

Expand All @@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
tools=self._tools,
)
self._model.start_chat()
answers = []
replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
answers.append(part.text)
replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
answers.append(function_call)
replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class VertexAIImageQA:
res = qa.run(image=image, question="What color is this dog")
print(res["answers"][0])
print(res["replies"][0])
>>> white
```
Expand Down Expand Up @@ -82,14 +82,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageQA":
"""
return default_from_dict(cls, data)

@component.output_types(answers=List[str])
@component.output_types(replies=List[str])
def run(self, image: ByteStream, question: str):
"""Prompts model to answer a question about an image.
:param image: The image to ask the question about.
:param question: The question to ask.
:returns: A dictionary with the following keys:
- `answers`: A list of answers to the question.
- `replies`: A list of answers to the question.
"""
answers = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs)
return {"answers": answers}
replies = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs)
return {"replies": replies}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class VertexAITextGenerator:
generator = VertexAITextGenerator(project_id=project_id)

This comment has been minimized.

Copy link
@KEYBUMPJANE

KEYBUMPJANE Apr 26, 2024

``

res = generator.run("Tell me a good interview question for a software engineer.")
print(res["answers"][0])
print(res["replies"][0])
>>> **Question:**
>>> You are given a list of integers and a target sum.
Expand Down Expand Up @@ -109,26 +109,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator":
)
return default_from_dict(cls, data)

@component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
@component.output_types(replies=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
def run(self, prompt: str):
"""Prompts the model to generate text.
:param prompt: The prompt to use for text generation.
:returns: A dictionary with the following keys:
- `answers`: A list of generated answers.
- `replies`: A list of generated replies.
- `safety_attributes`: A dictionary with the [safety scores](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/responsible-ai#safety_attribute_descriptions)
of each answer.
- `citations`: A list of citations for each answer.
"""
res = self._model.predict(prompt=prompt, **self._kwargs)

answers = []
replies = []
safety_attributes = []
citations = []

for prediction in res.raw_prediction_response.predictions:
answers.append(prediction["content"])
replies.append(prediction["content"])
safety_attributes.append(prediction["safetyAttributes"])
citations.append(prediction["citationMetadata"]["citations"])

return {"answers": answers, "safety_attributes": safety_attributes, "citations": citations}
return {"replies": replies, "safety_attributes": safety_attributes, "citations": citations}

0 comments on commit 9392e51

Please sign in to comment.