Google Generators: change answers to replies #626

Merged 3 commits on Mar 27, 2024
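This PR renames the output key of the Google generator components from `answers` to `replies`. As an illustration (not part of the diff), here is a minimal sketch of downstream code after the rename; the import path is an assumption based on the integration's package layout, and the model name, API key, and prompt are taken from the docstring example below:

```python
from haystack.utils import Secret
from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator  # assumed import path

# Sketch only: values mirror the docstring example in the diff below.
gemini = GoogleAIGeminiGenerator(model="gemini-pro", api_key=Secret.from_token("<MY_API_KEY>"))
result = gemini.run(parts=["What is the most interesting thing you know?"])

# Before this PR the key was "answers"; the same content now lives under "replies".
for reply in result["replies"]:
    print(reply)
```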
@@ -26,7 +26,7 @@ class GoogleAIGeminiGenerator:

gemini = GoogleAIGeminiGenerator(model="gemini-pro", api_key=Secret.from_token("<MY_API_KEY>"))
res = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in res["answers"]:
for answer in res["replies"]:
print(answer)
```

@@ -55,7 +55,7 @@ class GoogleAIGeminiGenerator:

gemini = GoogleAIGeminiGenerator(model="gemini-pro-vision", api_key=Secret.from_token("<MY_API_KEY>"))
result = gemini.run(parts = ["What can you tell me about this robots?", *images])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)
```
"""
@@ -173,7 +173,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

- @component.output_types(answers=List[Union[str, Dict[str, str]]])
+ @component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates text based on the given input parts.
@@ -182,7 +182,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
A heterogeneous list of strings, `ByteStream` or `Part` objects.
:returns:
A dictionary containing the following key:
- - `answers`: A list of strings or dictionaries with function calls.
+ - `replies`: A list of strings or dictionaries with function calls.
"""

converted_parts = [self._convert_part(p) for p in parts]
@@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
safety_settings=self._safety_settings,
)
self._model.start_chat()
- answers = []
+ replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part.text != "":
- answers.append(part.text)
+ replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
- answers.append(function_call)
+ replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
integrations/google_ai/tests/generators/test_gemini.py (2 changes: 1 addition & 1 deletion)
@@ -188,4 +188,4 @@ def test_from_dict(monkeypatch):
def test_run():
gemini = GoogleAIGeminiGenerator(model="gemini-pro")
res = gemini.run("Tell me something cool")
assert len(res["answers"]) > 0
assert len(res["replies"]) > 0
@@ -24,7 +24,7 @@ class VertexAICodeGenerator:

result = generator.run(prefix="def to_json(data):")

for answer in result["answers"]:
for answer in result["replies"]:
print(answer)

>>> ```python
@@ -92,17 +92,17 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator":
"""
return default_from_dict(cls, data)

- @component.output_types(answers=List[str])
+ @component.output_types(replies=List[str])
def run(self, prefix: str, suffix: Optional[str] = None):
"""
Generate code using a Google Vertex AI model.

:param prefix: Code before the current point.
:param suffix: Code after the current point.
:returns: A dictionary with the following keys:
- - `answers`: A list of generated code snippets.
+ - `replies`: A list of generated code snippets.
"""
res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs)
# Handle the case where the model returns multiple candidates
- answers = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
- return {"answers": answers}
+ replies = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text]
+ return {"replies": replies}
@@ -35,7 +35,7 @@ class VertexAIGeminiGenerator:

gemini = VertexAIGeminiGenerator(project_id=project_id)
result = gemini.run(parts = ["What is the most interesting thing you know?"])
for answer in result["answers"]:
for answer in result["replies"]:
print(answer)

>>> 1. **The Origin of Life:** How and where did life begin? The answers to this ...
@@ -175,14 +175,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
msg = f"Unsupported type {type(part)} for part {part}"
raise ValueError(msg)

- @component.output_types(answers=List[Union[str, Dict[str, str]]])
+ @component.output_types(replies=List[Union[str, Dict[str, str]]])
def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
"""
Generates content using the Gemini model.

:param parts: Prompt for the model.
:returns: A dictionary with the following keys:
- - `answers`: A list of generated content.
+ - `replies`: A list of generated content.
"""
converted_parts = [self._convert_part(p) for p in parts]

@@ -194,16 +194,16 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]):
tools=self._tools,
)
self._model.start_chat()
- answers = []
+ replies = []
for candidate in res.candidates:
for part in candidate.content.parts:
if part._raw_part.text != "":
- answers.append(part.text)
+ replies.append(part.text)
elif part.function_call is not None:
function_call = {
"name": part.function_call.name,
"args": dict(part.function_call.args.items()),
}
- answers.append(function_call)
+ replies.append(function_call)

return {"answers": answers}
return {"replies": replies}
@@ -29,7 +29,7 @@ class VertexAIImageQA:

res = qa.run(image=image, question="What color is this dog")

print(res["answers"][0])
print(res["replies"][0])

>>> white
```
@@ -82,14 +82,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageQA":
"""
return default_from_dict(cls, data)

- @component.output_types(answers=List[str])
+ @component.output_types(replies=List[str])
def run(self, image: ByteStream, question: str):
"""Prompts model to answer a question about an image.

:param image: The image to ask the question about.
:param question: The question to ask.
:returns: A dictionary with the following keys:
- - `answers`: A list of answers to the question.
+ - `replies`: A list of answers to the question.
"""
- answers = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs)
- return {"answers": answers}
+ replies = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs)
+ return {"replies": replies}
@@ -28,7 +28,7 @@ class VertexAITextGenerator:
generator = VertexAITextGenerator(project_id=project_id)
res = generator.run("Tell me a good interview question for a software engineer.")

print(res["answers"][0])
print(res["replies"][0])

>>> **Question:**
>>> You are given a list of integers and a target sum.
@@ -109,26 +109,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator":
)
return default_from_dict(cls, data)

- @component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
+ @component.output_types(replies=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]])
def run(self, prompt: str):
"""Prompts the model to generate text.

:param prompt: The prompt to use for text generation.
:returns: A dictionary with the following keys:
- - `answers`: A list of generated answers.
+ - `replies`: A list of generated replies.
- `safety_attributes`: A dictionary with the [safety scores](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/responsible-ai#safety_attribute_descriptions)
of each answer.
- `citations`: A list of citations for each answer.
"""
res = self._model.predict(prompt=prompt, **self._kwargs)

- answers = []
+ replies = []
safety_attributes = []
citations = []

for prediction in res.raw_prediction_response.predictions:
answers.append(prediction["content"])
replies.append(prediction["content"])
safety_attributes.append(prediction["safetyAttributes"])
citations.append(prediction["citationMetadata"]["citations"])

return {"answers": answers, "safety_attributes": safety_attributes, "citations": citations}
return {"replies": replies, "safety_attributes": safety_attributes, "citations": citations}