Skip to content

Commit

Permalink
refactor: GoogleAIGeminiGenerator - make some attributes public (#1317)
Browse files Browse the repository at this point in the history
* Make variables public

* Fix tests

* Change back to self._model

* Change to self._api_key
  • Loading branch information
sjrl authored Jan 23, 2025
1 parent b1d5375 commit d2a10fd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ def __init__(
genai.configure(api_key=api_key.resolve_value())

self._api_key = api_key
self._model_name = model
self._generation_config = generation_config
self._safety_settings = safety_settings
self._model = GenerativeModel(self._model_name)
self._streaming_callback = streaming_callback
self.model_name = model
self.generation_config = generation_config
self.safety_settings = safety_settings
self._model = GenerativeModel(self.model_name)
self.streaming_callback = streaming_callback

def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
if isinstance(config, dict):
Expand All @@ -126,13 +126,13 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
data = default_to_dict(
self,
api_key=self._api_key.to_dict(),
model=self._model_name,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
model=self.model_name,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
streaming_callback=callback_name,
)
if (generation_config := data["init_parameters"].get("generation_config")) is not None:
Expand Down Expand Up @@ -198,13 +198,13 @@ def run(
"""

# check if streaming_callback is passed
streaming_callback = streaming_callback or self._streaming_callback
streaming_callback = streaming_callback or self.streaming_callback
converted_parts = [self._convert_part(p) for p in parts]
contents = [Content(parts=converted_parts, role="user")]
res = self._model.generate_content(
contents=contents,
generation_config=self._generation_config,
safety_settings=self._safety_settings,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
stream=streaming_callback is not None,
)
self._model.start_chat()
Expand Down
28 changes: 14 additions & 14 deletions integrations/google_ai/tests/generators/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_init(monkeypatch):
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
top_k=1,
)
safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}

Expand All @@ -28,9 +28,9 @@ def test_init(monkeypatch):
safety_settings=safety_settings,
)
mock_genai_configure.assert_called_once_with(api_key="test")
assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == generation_config
assert gemini._safety_settings == safety_settings
assert gemini.model_name == "gemini-1.5-flash"
assert gemini.generation_config == generation_config
assert gemini.safety_settings == safety_settings
assert isinstance(gemini._model, GenerativeModel)


Expand Down Expand Up @@ -105,7 +105,7 @@ def test_from_dict_with_param(monkeypatch):
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"top_k": 1,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
Expand All @@ -116,16 +116,16 @@ def test_from_dict_with_param(monkeypatch):
}
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == GenerationConfig(
assert gemini.model_name == "gemini-1.5-flash"
assert gemini.generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
top_k=1,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert isinstance(gemini._model, GenerativeModel)


Expand All @@ -141,7 +141,7 @@ def test_from_dict(monkeypatch):
"generation_config": {
"temperature": 0.5,
"top_p": 0.5,
"top_k": 0.5,
"top_k": 1,
"candidate_count": 1,
"max_output_tokens": 10,
"stop_sequences": ["stop"],
Expand All @@ -152,16 +152,16 @@ def test_from_dict(monkeypatch):
}
)

assert gemini._model_name == "gemini-1.5-flash"
assert gemini._generation_config == GenerationConfig(
assert gemini.model_name == "gemini-1.5-flash"
assert gemini.generation_config == GenerationConfig(
candidate_count=1,
stop_sequences=["stop"],
max_output_tokens=10,
temperature=0.5,
top_p=0.5,
top_k=0.5,
top_k=1,
)
assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert gemini.safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}
assert isinstance(gemini._model, GenerativeModel)


Expand Down

0 comments on commit d2a10fd

Please sign in to comment.