
Commit

Fixed the rest of ruff's claims.
PatrykWyzgowski committed Oct 11, 2024
1 parent 4d2d9a0 commit c98bda1
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion packages/ragbits-core/src/ragbits/core/llms/base.py
@@ -39,7 +39,7 @@ def client(self) -> LLMClient:
         Client for the LLM.
         """

-    def count_tokens(self, prompt: BasePrompt) -> int:
+    def count_tokens(self, prompt: BasePrompt) -> int:  # noqa: PLR6301
         """
         Counts tokens in the prompt.
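
Context for the suppression above: PLR6301 is ruff's no-self-use rule, which flags instance methods that never reference self. Keeping count_tokens as an instance method (and silencing the rule) preserves the signature for subclasses whose overrides do use instance state. A minimal sketch of the pattern, with illustrative names not taken from the ragbits codebase:

    class TokenCounterBase:
        def count_tokens(self, text: str) -> int:  # noqa: PLR6301
            # Crude default; subclasses are expected to override this with a
            # tokenizer-specific implementation that uses instance state.
            return len(text.split())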
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/clients/local.py
@@ -21,8 +21,8 @@ class LocalLLMOptions(LLMOptions):
     """
     Dataclass that represents all available LLM call options for the local LLM client.
     Each of them is described in the [HuggingFace documentation]
-    (https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation). # noqa: E501
-    """
+    (https://huggingface.co/docs/huggingface_hub/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation).
+    """  # noqa: E501

     repetition_penalty: float | None | NotGiven = NOT_GIVEN
     do_sample: bool | None | NotGiven = NOT_GIVEN
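
Why the directive moved: E501 is the line-length check, and the long documentation URL cannot be wrapped. Written inside the string literal, "# noqa: E501" is docstring text rather than a comment; for violations inside a multi-line string, ruff looks for the directive on the line with the closing quotes. A minimal sketch, with an illustrative constant not from the codebase:

    DOCS_NOTE = """
    See https://example.com/some/extremely/long/reference/url/that/exceeds/the/configured/line-length/limit
    """  # noqa: E501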
2 changes: 1 addition & 1 deletion packages/ragbits-core/src/ragbits/core/prompt/base.py
@@ -30,7 +30,7 @@ def json_mode(self) -> bool:
         """
         return self.output_schema() is not None

-    def output_schema(self) -> dict | type[BaseModel] | None:
+    def output_schema(self) -> dict | type[BaseModel] | None:  # noqa: PLR6301
         """
         Returns the schema of the desired output. Can be used to request structured output from the LLM API
         or to validate the output. Can return either a Pydantic model or a JSON schema.
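
The docstring above says output_schema may return either a Pydantic model or a JSON schema. A hedged sketch of both options in a subclass; the class and field names are illustrative, not part of ragbits:

    from pydantic import BaseModel

    class Answer(BaseModel):
        text: str
        confidence: float

    class PydanticPrompt:
        def output_schema(self) -> dict | type[BaseModel] | None:
            # Structured output described by a Pydantic model.
            return Answer

    class JsonSchemaPrompt:
        def output_schema(self) -> dict | type[BaseModel] | None:
            # The same contract expressed as a plain JSON schema dict.
            return {"type": "object", "properties": {"text": {"type": "string"}}}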
14 changes: 8 additions & 6 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
@@ -57,7 +57,7 @@ def _get_io_types(cls) -> tuple:

     @classmethod
     def _parse_template(cls, template: str) -> Template:
-        env = Environment() # nosec B701 - HTML autoescaping not needed for plain text #noqa: S701
+        env = Environment(autoescape=True)
         ast = env.parse(template)
         template_variables = meta.find_undeclared_variables(ast)
         input_fields = cls.input_type.model_fields.keys() if cls.input_type else set()
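
Background on the change above: ruff's S701 (mirroring bandit's B701) flags jinja2.Environment() because autoescape defaults to False, which can enable XSS when templates render untrusted input into HTML. The commit enables autoescaping rather than keeping the suppression comments. A minimal illustration:

    from jinja2 import Environment

    env = Environment(autoescape=True)
    template = env.from_string("Hello {{ name }}")
    # With autoescaping on, HTML-special characters in variables are escaped:
    print(template.render(name="<b>world</b>"))  # Hello &lt;b&gt;world&lt;/b&gt;

One trade-off worth noting: for plain-text prompts, autoescaping also HTML-escapes characters such as & and < in rendered values, so the rendered prompt text can differ from the raw input.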
@@ -169,15 +169,17 @@ def list_few_shots(self) -> ChatFormat:
         result: ChatFormat = []
         for user_message, assistant_message in self.few_shots + self._instace_few_shots:
             if not isinstance(user_message, str):
-                user_message = self._render_template(self.user_prompt_template, user_message)
+                user_content = self._render_template(self.user_prompt_template, user_message)
+            else:
+                user_content = user_message

             if isinstance(assistant_message, BaseModel):
-                assistant_message = assistant_message.model_dump_json()
+                assistant_content = assistant_message.model_dump_json()
             else:
-                assistant_message = str(assistant_message)
+                assistant_content = str(assistant_message)

-            result.append({"role": "user", "content": user_message})
-            result.append({"role": "assistant", "content": assistant_message})
+            result.append({"role": "user", "content": user_content})
+            result.append({"role": "assistant", "content": assistant_content})
         return result

     def output_schema(self) -> dict | type[BaseModel] | None:
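
The rewrite above stops rebinding the loop variables user_message and assistant_message inside the loop body, the pattern ruff flags as PLW2901 (redefined-loop-name); presumably that is the remaining "claim" this hunk fixes. A reduced sketch of the resulting pattern, with illustrative data:

    few_shots = [("What is 2 + 2?", 4)]

    chat = []
    for user_message, assistant_message in few_shots:
        # Bind a new name instead of reassigning the loop variable,
        # which is what PLW2901 would flag.
        assistant_content = str(assistant_message)
        chat.append({"role": "user", "content": user_message})
        chat.append({"role": "assistant", "content": assistant_content})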
