refactor: remove list comprehensions to preserve type-hints (#301)
Ref: #299
jezekra1 authored Jan 30, 2024
1 parent 67478a1 commit f10483c
Showing 6 changed files with 30 additions and 34 deletions.
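
The pattern removed throughout these examples is single-element list unpacking (`[result] = response.results`), replaced by plain indexing (`result = response.results[0]`). A minimal sketch of the two styles, using hypothetical stand-ins for the SDK's response models:

```python
from dataclasses import dataclass


@dataclass
class Result:
    generated_text: str


@dataclass
class Response:
    # One result per input prompt; a single-prompt call yields one element.
    results: list[Result]


response = Response(results=[Result(generated_text="Hello!")])

# Before: single-element list unpacking. Per the commit title, this style
# can defeat type inference for `result` in some tools.
[result] = response.results

# After: plain indexing keeps the inferred type (Result) intact.
result = response.results[0]
```

Note the behavioral difference: unpacking raises `ValueError` unless the list has exactly one element, while indexing only requires at least one (`IndexError` otherwise) and ignores any extras. These examples expect exactly one result per response, so the two are interchangeable here.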
2 changes: 1 addition & 1 deletion examples/extensions/localserver/local_server.py
@@ -100,7 +100,7 @@ def tokenize(self, input_text: str, parameters: TextTokenizationParameters) -> T

 prompts = ["Hello! How are you?", "How's the weather?"]
 for response in client.text.generation.create(model_id=FlanT5Model.model_id, inputs=prompts, parameters=parameters):
-    [result] = response.results
+    result = response.results[0]
     print(f"Prompt: {result.input_text}\nResponse: {result.generated_text}")


2 changes: 1 addition & 1 deletion examples/extra/logging_example.py
@@ -45,5 +45,5 @@ def heading(text: str) -> str:
         return_options=TextGenerationReturnOptions(input_text=True),
     ),
 ):
-    [result] = response.results
+    result = response.results[0]
     print(f"Prompt: {result.input_text}\nResponse: {result.generated_text}")
48 changes: 22 additions & 26 deletions examples/prompt/prompt.py
@@ -45,33 +45,29 @@ def heading(text: str) -> str:
 pprint(retrieve_response.result.model_dump())

 print(heading("Generate text using prompt"))
-[generation_response] = list(
-    client.text.generation.create(
-        prompt_id=prompt_id,
-        parameters=TextGenerationParameters(return_options=TextGenerationReturnOptions(input_text=True)),
-    )
-)
-[result] = generation_response.results
-print(f"Prompt: {result.input_text}")
-print(f"Answer: {result.generated_text}")
+for generation_response in client.text.generation.create(
+    prompt_id=prompt_id,
+    parameters=TextGenerationParameters(return_options=TextGenerationReturnOptions(input_text=True)),
+):
+    result = generation_response.results[0]
+    print(f"Prompt: {result.input_text}")
+    print(f"Answer: {result.generated_text}")

 print(heading("Override prompt template variables"))
-[generation_response] = list(
-    client.text.generation.create(
-        prompt_id=prompt_id,
-        parameters=TextGenerationParameters(return_options=TextGenerationReturnOptions(input_text=True)),
-        data={"meal": "pancakes", "author": "Edgar Allan Poe"},
-    )
-)
-[result] = generation_response.results
-print(f"Prompt: {result.input_text}")
-print(f"Answer: {result.generated_text}")
+for generation_response in client.text.generation.create(
+    prompt_id=prompt_id,
+    parameters=TextGenerationParameters(return_options=TextGenerationReturnOptions(input_text=True)),
+    data={"meal": "pancakes", "author": "Edgar Allan Poe"},
+):
+    result = generation_response.results[0]
+    print(f"Prompt: {result.input_text}")
+    print(f"Answer: {result.generated_text}")

-print(heading("Show all existing prompts"))
-prompt_list_response = client.prompt.list(search=prompt_name, limit=10, offset=0)
-print("Total Count: ", prompt_list_response.total_count)
-print("Results: ", prompt_list_response.results)
+print(heading("Show all existing prompts"))
+prompt_list_response = client.prompt.list(search=prompt_name, limit=10, offset=0)
+print("Total Count: ", prompt_list_response.total_count)
+print("Results: ", prompt_list_response.results)

-print(heading("Delete prompt"))
-client.prompt.delete(id=prompt_id)
-print("OK")
+print(heading("Delete prompt"))
+client.prompt.delete(id=prompt_id)
+print("OK")
8 changes: 4 additions & 4 deletions examples/text/moderation.py
@@ -46,23 +46,23 @@ def heading(text: str) -> str:
 ):
     print(f"Input text: {input_text}")
     assert response.results
-    [result] = response.results
+    result = response.results[0]

     # HAP
     assert result.hap
-    [hap] = result.hap
+    hap = result.hap[0]
     print("HAP:")
     pprint(hap.model_dump())

     # Stigma
     assert result.stigma
-    [stigma] = result.stigma
+    stigma = result.stigma[0]
     print("Stigma:")
     pprint(stigma.model_dump())

     # Implicit Hate
     assert result.implicit_hate
-    [implicit_hate] = result.implicit_hate
+    implicit_hate = result.implicit_hate[0]
     print("Implicit hate:")
     pprint(implicit_hate.model_dump())

2 changes: 1 addition & 1 deletion examples/tune/tune.py
@@ -118,7 +118,7 @@ def upload_files(client: Client, update=True):
 prompt = "Return on investment was 5.0 % , compared to a negative 4.1 % in 2009 ."
 print("Prompt: ", prompt)
 gen_params = TextGenerationParameters(decoding_method=DecodingMethod.SAMPLE, max_new_tokens=1, min_new_tokens=1)
-[gen_response] = list(client.text.generation.create(model_id=tune_result.id, inputs=[prompt]))
+gen_response = next(client.text.generation.create(model_id=tune_result.id, inputs=[prompt]))
 print("Answer: ", gen_response.results[0].generated_text)

 print(heading("Get list of tuned models"))
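In tune.py there is exactly one input prompt, so the generator is consumed with `next()` instead of being wrapped in `list(...)` and unpacked. A minimal illustration, again with a hypothetical stand-in for the SDK call:

```python
from typing import Iterator


def create(inputs: list[str]) -> Iterator[str]:
    # Stand-in: yields one generated response per input.
    for text in inputs:
        yield f"response to {text!r}"


# Before: buffer everything, then assert exactly one element via unpacking.
[gen_response] = list(create(["single prompt"]))

# After: pull only the first response; raises StopIteration if the
# generator is empty and leaves any further items unconsumed.
gen_response = next(create(["single prompt"]))
```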
2 changes: 1 addition & 1 deletion tests/e2e/test_examples.py
@@ -14,12 +14,12 @@
"huggingface_agent.py",
"tune.py",
"parallel_processing.py",
"chroma_db_embedding.py",
}
skip_for_python_3_12 = {
# These files are skipped for python >= 3.12 because transformers library cannot be installed
"local_server.py",
"huggingface_agent.py",
"chroma_db_embedding.py",
}

scripts_lt_3_12 = {script for script in all_scripts if script.name not in ignore_files | skip_for_python_3_12}
