
Commit

address PR comments
gkaretka committed Nov 4, 2024
1 parent a0e7668 commit dd3ccfb
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions prompterator/models/openai_models.py
@@ -128,9 +128,9 @@ def build_response_format(json_schema):
 
     @staticmethod
     def enrich_model_params_of_function_calling(structured_output_config, model_params):
-        if structured_output_data.enabled:
-            if structured_output_data.method == soi.FUNCTION_CALLING:
-                schema = json.loads(structured_output_data.schema)
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
+                schema = json.loads(structured_output_config.schema)
                 function_name = ChatGPTMixin.get_function_calling_tooling_name(schema)
 
                 model_params["tools"] = ChatGPTMixin.build_function_calling_tooling(
@@ -140,17 +140,17 @@ def enrich_model_params_of_function_calling(structured_output_config, model_params):
                     "type": "function",
                     "function": {"name": function_name},
                 }
-            if structured_output_data.method == soi.RESPONSE_FORMAT:
-                schema = json.loads(structured_output_data.schema)
+            if structured_output_config.method == soi.RESPONSE_FORMAT:
+                schema = json.loads(structured_output_config.schema)
                 model_params["response_format"] = ChatGPTMixin.build_response_format(schema)
         return model_params
 
     @staticmethod
     def process_response(structured_output_config, response_data):
-        if structured_output_data.enabled:
-            if structured_output_data.method == soi.FUNCTION_CALLING:
+        if structured_output_config.enabled:
+            if structured_output_config.method == soi.FUNCTION_CALLING:
                 response_text = response_data.choices[0].message.tool_calls[0].function.arguments
-            elif structured_output_data.method == soi.RESPONSE_FORMAT:
+            elif structured_output_config.method == soi.RESPONSE_FORMAT:
                 response_text = response_data.choices[0].message.content
             else:
                 response_text = response_data.choices[0].message.content
@@ -159,16 +159,17 @@ def process_response(structured_output_config, response_data):
         return response_text
 
     def call(self, idx, input, **kwargs):
-        structured_output_data: StructuredOutputConfig = kwargs["structured_output"]
+        structured_output_config: StructuredOutputConfig = kwargs["structured_output"]
         model_params = kwargs["model_params"]
 
         try:
             model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
-                structured_output_data, model_params
+                structured_output_config, model_params
             )
         except json.JSONDecodeError as e:
             logger.error(
-                "Error occurred while loading provided json schema"
+                "Error occurred while loading provided json schema. "
+                f"Provided schema: {structured_output_config.schema} "
                 "%d. Returning an empty response.",
                 idx,
                 exc_info=e,
@@ -198,11 +199,13 @@ def call(self, idx, input, **kwargs):
             return {"idx": idx}
 
         try:
-            response_text = ChatGPTMixin.process_response(structured_output_data, response_data)
+            response_text = ChatGPTMixin.process_response(structured_output_config, response_data)
             return {"response": response_text, "data": response_data, "idx": idx}
         except KeyError as e:
             logger.error(
-                "Error occurred while processing response, response does not follow expected format"
+                "Error occurred while processing response, "
+                "response does not follow expected format. "
+                f"Response: {response_data} "
                 "%d. Returning an empty response.",
                 idx,
                 exc_info=e,
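
For orientation, here is a minimal usage sketch (not part of the commit) showing how the renamed structured_output_config is threaded through the two helpers this diff touches. The OpenAI client call, the model name, the prompt, and the way the config object is obtained are illustrative assumptions, not prompterator's actual call path.

# Hypothetical usage sketch: illustrates the structured-output flow around the
# helpers shown above; names marked as assumptions are not from the commit.
import json

from openai import OpenAI  # assumes the openai>=1.x Python client

from prompterator.models.openai_models import ChatGPTMixin  # module shown in this diff

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Assumed to be a StructuredOutputConfig with .enabled, .method and .schema,
# constructed elsewhere in prompterator; its construction is out of scope here.
structured_output_config = ...

model_params = {"temperature": 0}

# Adds "tools"/"tool_choice" for FUNCTION_CALLING or "response_format" for
# RESPONSE_FORMAT, as enrich_model_params_of_function_calling does above.
model_params = ChatGPTMixin.enrich_model_params_of_function_calling(
    structured_output_config, model_params
)

response_data = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Extract the fields as JSON."}],
    **model_params,
)

# Picks tool_calls[0].function.arguments for FUNCTION_CALLING and
# message.content otherwise, mirroring process_response above.
response_text = ChatGPTMixin.process_response(structured_output_config, response_data)
print(json.loads(response_text))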
