feat(system-prompts): init module (#292)
Ref: #287

Signed-off-by: Tomas Dvorak <[email protected]>
Tomas2D authored Jan 29, 2024
1 parent b4e4d4e commit 2aa8351
Showing 14 changed files with 889 additions and 1 deletion.
3 changes: 3 additions & 0 deletions examples/system_prompt/__init__.py
@@ -0,0 +1,3 @@
"""
System Prompts
"""
44 changes: 44 additions & 0 deletions examples/system_prompt/system_prompt.py
@@ -0,0 +1,44 @@
"""
Working with system prompts
The system prompt is a pre-defined prompt that helps cue the model to exhibit the desired behavior for a specific task.
"""

from pprint import pprint

from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint>
load_dotenv()
client = Client(credentials=Credentials.from_env())


def heading(text: str) -> str:
"""Helper function for centering text."""
return "\n" + f" {text} ".center(80, "=") + "\n"


print(heading("Create a system prompt"))
prompt_name = "Simple Verbalizer"
prompt_content = """classify { "label 1", "label 2" } Input: {{input}} Output:"""
create_response = client.system_prompt.create(name=prompt_name, content=prompt_content)
system_prompt_id = create_response.result.id
print(f"System Prompt ID: {system_prompt_id}")

print(heading("Get a system prompt details"))
retrieve_response = client.system_prompt.retrieve(id=system_prompt_id)
pprint(retrieve_response.result.model_dump())

print(heading("Show all existing system prompts"))
system_prompt_list_response = client.system_prompt.list(offset=0, limit=10)
print("Total Count: ", system_prompt_list_response.total_count)
print("Results: ", system_prompt_list_response.results)

print(heading("Delete a system prompt"))
client.system_prompt.delete(id=system_prompt_id)
print("OK")
16 changes: 16 additions & 0 deletions scripts/types_generator/schema_aliases.yaml
@@ -460,6 +460,22 @@ alias:
- TextTokenizationCreateRequest_parameters
TextTokenizationReturnOptions:
- TextTokenizationCreateRequest_parameters_return_options
# SYSTEM PROMPTS -------------------------------------------------------------
SystemPrompt:
- SystemPromptCreate_result
- SystemPromptIdRetrieve_result
- SystemPromptIdUpdate_result
- SystemPromptRetrieve_results
SystemPromptAuthor:
- SystemPromptIdRetrieve_result_author
- SystemPromptCreate_result_author
- SystemPromptIdUpdate_result_author
- SystemPromptRetrieve_results_author
SystemPromptType:
- SystemPromptCreate_result_type
- SystemPromptIdRetrieve_result_type
- SystemPromptIdUpdate_result_type
- SystemPromptRetrieve_results_type
replace:
# Override schema by a different equivalent schema
# For example, override A: B means: schemas[A] = schemas[B]; del schemas[B]
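These aliases collapse the per-endpoint generated names (SystemPromptCreate_result, SystemPromptIdRetrieve_result, and so on) into the shared SystemPrompt, SystemPromptAuthor, and SystemPromptType models, so every system-prompt operation returns the same public types. A small illustrative sketch using the re-exports added in src/genai/system_prompt/__init__.py further down:

from genai.system_prompt import SystemPrompt, SystemPromptType

def is_private(prompt: SystemPrompt) -> bool:
    # The same SystemPrompt model backs create, retrieve, list, and update results.
    return prompt.type == SystemPromptType.PRIVATE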
69 changes: 69 additions & 0 deletions src/genai/_generated/api.py
@@ -250,6 +250,17 @@ class StorageProviderLocation(str, Enum):
US_EAST = "us-east"


class SystemPromptAuthor(ApiBaseModel):
first_name: Optional[str] = None
id: int
last_name: Optional[str] = None


class SystemPromptType(str, Enum):
PRIVATE = "private"
SYSTEM = "system"


class Tasks(ApiBaseModel):
csv_example: Optional[str] = None
file_format_id: Optional[int] = None
@@ -582,6 +593,38 @@ class RequestIdDeleteParametersQuery(ApiBaseModel):
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptRetrieveParametersQuery(ApiBaseModel):
limit: Optional[int] = Field(100, ge=1, le=100)
offset: Optional[int] = Field(0, ge=0)
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptCreateParametersQuery(ApiBaseModel):
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptCreateRequest(ApiBaseModel):
content: str
name: str


class SystemPromptIdDeleteParametersQuery(ApiBaseModel):
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptIdRetrieveParametersQuery(ApiBaseModel):
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptIdUpdateParametersQuery(ApiBaseModel):
version: Literal["2023-11-22"] = "2023-11-22"


class SystemPromptIdUpdateRequest(ApiBaseModel):
content: str
name: str


class TaskRetrieveParametersQuery(ApiBaseModel):
tune: Optional[bool] = True
version: Literal["2023-11-22"] = "2023-11-22"
@@ -895,6 +938,15 @@ class RequestRetrieveResults(ApiBaseModel):
version: Optional[RequestResultVersion] = None


class SystemPrompt(ApiBaseModel):
author: Optional[SystemPromptAuthor] = None
content: str
created_at: AwareDatetime
id: int
name: str
type: SystemPromptType


class TextCreateResponseModeration(ApiBaseModel):
hap: Optional[list[TextModeration]] = None
implicit_hate: Optional[list[TextModeration]] = None
@@ -1065,6 +1117,23 @@ class RequestRetrieveResponse(ApiBaseModel):
total_count: int


class SystemPromptRetrieveResponse(ApiBaseModel):
results: list[SystemPrompt]
total_count: int


class SystemPromptCreateResponse(ApiBaseModel):
result: SystemPrompt


class SystemPromptIdRetrieveResponse(ApiBaseModel):
result: SystemPrompt


class SystemPromptIdUpdateResponse(ApiBaseModel):
result: SystemPrompt


class TextChatCreateRequest(ApiBaseModel):
conversation_id: Optional[str] = None
messages: Optional[list[BaseMessage]] = Field(None, min_length=1)
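The generated SystemPrompt model carries the full resource shape (author, content, created_at, id, name, type), and the response wrappers reuse it across retrieve, create, and update. A rough sketch of how a payload validates into it, assuming ApiBaseModel behaves like a pydantic v2 BaseModel (model_validate is an assumption, consistent with the example's use of model_dump):

from genai._generated.api import SystemPrompt, SystemPromptType

payload = {
    "id": 1,
    "name": "Simple Verbalizer",
    "content": 'classify { "label 1", "label 2" } Input: {{input}} Output:',
    "created_at": "2024-01-29T12:00:00+00:00",
    "type": "private",
}
prompt = SystemPrompt.model_validate(payload)  # parses field types, coerces the enum
assert prompt.type is SystemPromptType.PRIVATE
assert prompt.author is None  # author is optional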
30 changes: 30 additions & 0 deletions src/genai/_generated/endpoints.py
@@ -120,6 +120,36 @@ class RequestIdDeleteEndpoint(ApiEndpoint):
version: str = "2023-11-22"


class SystemPromptRetrieveEndpoint(ApiEndpoint):
path: str = "/v2/system_prompts"
method: str = "GET"
version: str = "2023-11-22"


class SystemPromptCreateEndpoint(ApiEndpoint):
path: str = "/v2/system_prompts"
method: str = "POST"
version: str = "2023-11-22"


class SystemPromptIdDeleteEndpoint(ApiEndpoint):
path: str = "/v2/system_prompts/{id}"
method: str = "DELETE"
version: str = "2023-11-22"


class SystemPromptIdRetrieveEndpoint(ApiEndpoint):
path: str = "/v2/system_prompts/{id}"
method: str = "GET"
version: str = "2023-11-22"


class SystemPromptIdUpdateEndpoint(ApiEndpoint):
path: str = "/v2/system_prompts/{id}"
method: str = "PUT"
version: str = "2023-11-22"


class TaskRetrieveEndpoint(ApiEndpoint):
path: str = "/v2/tasks"
method: str = "GET"
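The five endpoint classes pin down the module's REST surface: collection operations on /v2/system_prompts and item operations on /v2/system_prompts/{id}, all against API version 2023-11-22. Read off the declarations above, the mapping is (an illustrative summary, not library code):

# Operation -> (HTTP method, path), as declared by the endpoint classes above.
SYSTEM_PROMPT_ROUTES = {
    "list": ("GET", "/v2/system_prompts"),
    "create": ("POST", "/v2/system_prompts"),
    "retrieve": ("GET", "/v2/system_prompts/{id}"),
    "update": ("PUT", "/v2/system_prompts/{id}"),
    "delete": ("DELETE", "/v2/system_prompts/{id}"),
}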
4 changes: 3 additions & 1 deletion src/genai/_utils/service/base_service.py
@@ -3,6 +3,7 @@
import logging
import re
from abc import ABC
from enum import Enum
from typing import Generic, Optional, TypeVar, Union, cast
from urllib.parse import quote

@@ -90,13 +91,14 @@ def _log_method_execution(self, name: str, /, **kwargs):
@staticmethod
def _get_endpoint(
endpoint: type[ApiEndpoint],
**params: str,
**params: Union[int, str, Enum],
) -> str:
target_endpoint = endpoint.path
if not target_endpoint:
raise ValueError("Endpoint was not found in the provided config.")

for k, v in params.items():
v = v.value if isinstance(v, Enum) else str(v)
assert_is_not_empty_string(v)
parameter_expression = f"{{{k}}}"
if parameter_expression not in target_endpoint:
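This change widens _get_endpoint's path parameters from str to int, str, or Enum and coerces each value (Enum members to their .value, everything else through str) before substituting it into the path template, which is what lets the integer system-prompt id fill /v2/system_prompts/{id}. A standalone re-implementation of just the substitution, for illustration (not the library code shown above):

from enum import Enum

def format_path(path: str, **params) -> str:
    # Mirrors the coercion above: Enum -> .value, anything else -> str.
    for key, value in params.items():
        value = value.value if isinstance(value, Enum) else str(value)
        path = path.replace(f"{{{key}}}", value)
    return path

assert format_path("/v2/system_prompts/{id}", id=123) == "/v2/system_prompts/123"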
4 changes: 4 additions & 0 deletions src/genai/client.py
@@ -16,6 +16,7 @@
from genai.model import ModelService as _ModelService
from genai.prompt import PromptService as _PromptService
from genai.request import RequestService as _RequestService
from genai.system_prompt import SystemPromptService as _SystemPromptService
from genai.text import TextService as _TextService
from genai.tune import TuneService as _TuneService
from genai.user import UserService as _UserService
@@ -32,6 +33,7 @@ class BaseServices(BaseServiceServices):
ModelService: type[_ModelService] = _ModelService
FileService: type[_FileService] = _FileService
PromptService: type[_PromptService] = _PromptService
SystemPromptService: type[_SystemPromptService] = _SystemPromptService
UserService: type[_UserService] = _UserService


@@ -60,6 +62,7 @@ class Client(BaseService[BaseConfig, BaseServices]):
model: An instance of the `ModelService` class for managing models.
file: An instance of the `FileService` class for managing files.
prompt: An instance of the `PromptService` class for working with prompts.
system_prompt: An instance of the `SystemPromptService` class for working with system prompts.
user: An instance of the `UserService` class for managing user-related operations.
"""

@@ -129,4 +132,5 @@ def __init__(
self.model = services.ModelService(api_client=api_client)
self.file = services.FileService(api_client=api_client)
self.prompt = services.PromptService(api_client=api_client)
self.system_prompt = services.SystemPromptService(api_client=api_client)
self.user = services.UserService(api_client=api_client)
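With this wiring in place, a constructed Client exposes the new service as client.system_prompt alongside the existing ones, which is exactly what the example script at the top of the diff relies on:

from genai.client import Client
from genai.credentials import Credentials

client = Client(credentials=Credentials.from_env())  # credentials from GENAI_KEY / GENAI_API
print(client.system_prompt.list(offset=0, limit=1).total_count)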
4 changes: 4 additions & 0 deletions src/genai/system_prompt/__init__.py
@@ -0,0 +1,4 @@
"""Modules containing functionalities related to system prompts"""

from genai.system_prompt.schema import *
from genai.system_prompt.system_prompt_service import *
19 changes: 19 additions & 0 deletions src/genai/system_prompt/schema.py
@@ -0,0 +1,19 @@
from genai._generated.api import (
SystemPrompt,
SystemPromptAuthor,
SystemPromptCreateResponse,
SystemPromptIdRetrieveResponse,
SystemPromptIdUpdateResponse,
SystemPromptRetrieveResponse,
SystemPromptType,
)

__all__ = [
"SystemPrompt",
"SystemPromptType",
"SystemPromptCreateResponse",
"SystemPromptIdUpdateResponse",
"SystemPromptRetrieveResponse",
"SystemPromptIdRetrieveResponse",
"SystemPromptAuthor",
]