fix(prompt/mixin): Add name property and add it to saving/loading path. (explodinggradients#1853)

This pull request adds a `name` attribute to the `PromptMixin` class in
`src/ragas/prompt/mixin.py` and uses that attribute when saving and
loading prompts. Previously, prompts belonging to different Synthesizers
(e.g. MultiHopAbstractQuerySynthesizer, MultiHopSpecificQuerySynthesizer,
SingleHopSpecificQuerySynthesizer, etc.) could resolve to the same file
path, which broke saving and loading them side by side. With this change,
each prompt file name is prefixed with the owning class's `name`:

```
themes_personas_matching_prompt_english -> single_hop_specifc_query_synthesizer_themes_personas_matching_prompt_english
query_answer_generation_prompt_english -> single_hop_specifc_query_synthesizer_query_answer_generation_prompt_english
```
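
For illustration, the new naming rule behind this mapping can be sketched as a small standalone helper (`build_prompt_path` is hypothetical and exists only in this sketch, not in ragas):

```python
import os


def build_prompt_path(path: str, prompt_name: str, language: str, name: str = "") -> str:
    # When `name` is set, prefix the file name with it so prompts from
    # different synthesizers no longer collide on the same path.
    if name == "":
        return os.path.join(path, f"{prompt_name}_{language}.json")
    return os.path.join(path, f"{name}_{prompt_name}_{language}.json")


print(build_prompt_path("prompts", "query_answer_generation_prompt", "english",
                        name="single_hop_specifc_query_synthesizer"))
# -> prompts/single_hop_specifc_query_synthesizer_query_answer_generation_prompt_english.json
```

An empty `name` falls back to the old, unprefixed file name, so classes that do not set one keep their existing prompt files.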

### Changes to `PromptMixin` class:

* Added a `name` attribute to the `PromptMixin` class.
* Modified the `save_prompts` method to include the `name` attribute in
the prompt file name.
* Modified the `load_prompts` method to include the `name` attribute in the prompt file name (a toy round-trip sketch follows below).
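
Below is that round-trip sketch, self-contained and runnable. `ToyPrompt` and `prompt_file` are stand-ins invented for this example (they mimic only the `language`/`save`/`load` surface the mixin relies on); they are not ragas classes:

```python
import json
import os
import tempfile


class ToyPrompt:
    """Stand-in with the minimal save/load surface used by PromptMixin."""

    language = "english"

    def __init__(self, instruction: str = ""):
        self.instruction = instruction

    def save(self, file_name: str) -> None:
        with open(file_name, "w") as f:
            json.dump({"instruction": self.instruction}, f)

    @classmethod
    def load(cls, file_name: str) -> "ToyPrompt":
        with open(file_name) as f:
            return cls(json.load(f)["instruction"])


def prompt_file(path: str, name: str, prompt_name: str, language: str) -> str:
    # Same naming rule as the patched save_prompts / load_prompts (see diff below).
    if name == "":
        return os.path.join(path, f"{prompt_name}_{language}.json")
    return os.path.join(path, f"{name}_{prompt_name}_{language}.json")


with tempfile.TemporaryDirectory() as path:
    prompt = ToyPrompt("Generate a query and its answer.")
    # Two owners that share a prompt_name now write two distinct files.
    for owner in (
        "single_hop_specifc_query_synthesizer",
        "multi_hop_abstract_query_synthesizer",
    ):
        prompt.save(prompt_file(path, owner, "query_answer_generation_prompt", "english"))
    reloaded = ToyPrompt.load(
        prompt_file(
            path,
            "single_hop_specifc_query_synthesizer",
            "query_answer_generation_prompt",
            "english",
        )
    )
    assert reloaded.instruction == prompt.instruction
```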

---------

Co-authored-by: jjmachan <[email protected]>
Co-authored-by: Jithin James <[email protected]>
3 people authored and sahusiddharth committed Jan 24, 2025
1 parent 69da5fd commit d277700
Showing 1 changed file with 16 additions and 7 deletions.
src/ragas/prompt/mixin.py: 23 changes (16 additions & 7 deletions)
@@ -20,8 +20,9 @@ class PromptMixin:
     eg: [BaseSynthesizer][ragas.testset.synthesizers.base.BaseSynthesizer], [MetricWithLLM][ragas.metrics.base.MetricWithLLM]
     """
 
+    name: str = ""
+
     def _get_prompts(self) -> t.Dict[str, PydanticPrompt]:
         prompts = {}
         for key, value in inspect.getmembers(self):
             if isinstance(value, PydanticPrompt):
@@ -90,10 +91,13 @@ def save_prompts(self, path: str):
         prompts = self.get_prompts()
         for prompt_name, prompt in prompts.items():
             # hash_hex = f"0x{hash(prompt) & 0xFFFFFFFFFFFFFFFF:016x}"
-            prompt_file_name = os.path.join(
-                path, f"{prompt_name}_{prompt.language}.json"
-            )
-            prompt.save(prompt_file_name)
+            if self.name == "":
+                file_name = os.path.join(path, f"{prompt_name}_{prompt.language}.json")
+            else:
+                file_name = os.path.join(
+                    path, f"{self.name}_{prompt_name}_{prompt.language}.json"
+                )
+            prompt.save(file_name)
 
     def load_prompts(self, path: str, language: t.Optional[str] = None):
         """
@@ -113,7 +117,12 @@ def load_prompts(self, path: str, language: t.Optional[str] = None):
 
         loaded_prompts = {}
         for prompt_name, prompt in self.get_prompts().items():
-            prompt_file_name = os.path.join(path, f"{prompt_name}_{language}.json")
-            loaded_prompt = prompt.__class__.load(prompt_file_name)
+            if self.name == "":
+                file_name = os.path.join(path, f"{prompt_name}_{language}.json")
+            else:
+                file_name = os.path.join(
+                    path, f"{self.name}_{prompt_name}_{language}.json"
+                )
+            loaded_prompt = prompt.__class__.load(file_name)
             loaded_prompts[prompt_name] = loaded_prompt
         return loaded_prompts
