Skip to content

Commit

Permalink
Merge pull request middlewarehq#548 from VipinDevelops/fix/middleware…
Browse files Browse the repository at this point in the history
…hq#546

[FIX]: Handle Error response from AI API
  • Loading branch information
VipinDevelops authored Sep 17, 2024
2 parents 592647b + 6c0bd76 commit c5657d0
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 58 deletions.
62 changes: 42 additions & 20 deletions backend/analytics_server/mhq/service/ai/ai_analytics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import requests
from http import HTTPStatus
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Union


class AIProvider(Enum):
Expand Down Expand Up @@ -44,7 +44,30 @@ def __init__(self, llm: LLM, access_token: str):
def _get_message(self, message: str, role: str = "user"):
return {"role": role, "content": message}

def _handle_api_response(self, response) -> Dict[str, Union[str, int]]:
"""
Handles the API response, returning a success or error structure that the frontend can use.
"""
if response.status_code == HTTPStatus.OK:
return {
"status": "success",
"data": response.json()["choices"][0]["message"]["content"],
}
elif response.status_code == HTTPStatus.UNAUTHORIZED:
return {
"status": "error",
"message": "Unauthorized Access: Your access token is either missing, expired, or invalid. Please ensure that you are providing a valid token. ",
}
else:
return {
"status": "error",
"message": f"Unexpected error: {response.text}",
}

def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]):
"""
Handles the request to OpenAI API for fetching completions.
"""
payload = {
"model": self.LLM_NAME_TO_MODEL_MAP[self._llm],
"temperature": 0.6,
Expand All @@ -53,13 +76,12 @@ def _open_ai_fetch_completion_open_ai(self, messages: List[Dict[str, str]]):
api_url = "https://api.openai.com/v1/chat/completions"
response = requests.post(api_url, headers=self._headers, json=payload)

print(payload, api_url, response)
if response.status_code != HTTPStatus.OK:
raise Exception(response.json())

return response.json()
return self._handle_api_response(response)

def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]):
"""
Handles the request to Fireworks AI API for fetching completions.
"""
payload = {
"model": self.LLM_NAME_TO_MODEL_MAP[self._llm],
"temperature": 0.6,
Expand All @@ -73,28 +95,28 @@ def _fireworks_ai_fetch_completions(self, messages: List[Dict[str, str]]):
api_url = "https://api.fireworks.ai/inference/v1/chat/completions"
response = requests.post(api_url, headers=self._headers, json=payload)

if response.status_code != HTTPStatus.OK:
raise Exception(response.json())

return response.json()

def _fetch_completion(self, messages: List[Dict[str, str]]):
return self._handle_api_response(response)

def _fetch_completion(
self, messages: List[Dict[str, str]]
) -> Dict[str, Union[str, int]]:
"""
Fetches the completion using the appropriate AI provider based on the LLM.
"""
if self._ai_provider == AIProvider.FIREWORKS_AI:
return self._fireworks_ai_fetch_completions(messages)["choices"][0][
"message"
]["content"]
return self._fireworks_ai_fetch_completions(messages)

if self._ai_provider == AIProvider.OPEN_AI:
return self._open_ai_fetch_completion_open_ai(messages)["choices"][0][
"message"
]["content"]
return self._open_ai_fetch_completion_open_ai(messages)

raise Exception(f"Invalid AI provider {self._ai_provider}")
return {
"status": "error",
"message": f"Invalid AI provider {self._ai_provider}",
}

def get_dora_metrics_score(
self, four_keys_data: Dict[str, float]
) -> Dict[str, str]:
) -> Dict[str, Union[str, int]]:
"""
Calculate the DORA metrics score using input data and an LLM (Language Learning Model).
Expand Down
104 changes: 66 additions & 38 deletions web-server/pages/api/internal/ai/dora_metrics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,50 +64,78 @@ const postSchema = yup.object().shape({
});

const endpoint = new Endpoint(nullSchema);

endpoint.handle.POST(postSchema, async (req, res) => {
const { data, model, access_token } = req.payload;
const dora_data = data as TeamDoraMetricsApiResponseType;

try {
const [
doraMetricsScore,
leadTimeSummary,
CFRSummary,
MTTRSummary,
deploymentFrequencySummary,
doraTrendSummary
] = await Promise.all(
[
getDoraMetricsScore,
getLeadTimeSummary,
getCFRSummary,
getMTTRSummary,
getDeploymentFrequencySummary,
getDoraTrendsCorrelationSummary
].map((fn) => fn(dora_data, model, access_token))
);

const aggregatedData = {
...doraMetricsScore,
...leadTimeSummary,
...CFRSummary,
...MTTRSummary,
...deploymentFrequencySummary,
...doraTrendSummary
};

const dora_data = data as unknown as TeamDoraMetricsApiResponseType;

const [
dora_metrics_score,
lead_time_trends_summary,
change_failure_rate_trends_summary,
mean_time_to_recovery_trends_summary,
deployment_frequency_trends_summary,
dora_trend_summary
] = await Promise.all(
[
getDoraMetricsScore,
getLeadTimeSummary,
getCFRSummary,
getMTTRSummary,
getDeploymentFrequencySummary,
getDoraTrendsCorrelationSummary
].map((f) => f(dora_data, model, access_token))
);
const compiledSummary = await getDORACompiledSummary(
aggregatedData,
model,
access_token
);

const aggregated_dora_data = {
...dora_metrics_score,
...lead_time_trends_summary,
...change_failure_rate_trends_summary,
...mean_time_to_recovery_trends_summary,
...deployment_frequency_trends_summary,
...dora_trend_summary
} as AggregatedDORAData;

const dora_compiled_summary = await getDORACompiledSummary(
aggregated_dora_data,
model,
access_token
);
const responses = {
...aggregatedData,
...compiledSummary
};

res.send({
...aggregated_dora_data,
...dora_compiled_summary
});
const { status, message } = checkForErrors(responses);

if (status === 'error') {
return res.status(400).send({ message });
}

const simplifiedData = Object.fromEntries(
Object.entries(responses).map(([key, value]) => [key, value.data])
);

return res.status(200).send(simplifiedData);
} catch (error) {
return res.status(500).send({
message: 'Internal Server Error',
error: error.message
});
}
});
const checkForErrors = (
responses: Record<string, { status: string; message: string }>
): { status: string; message: string } => {
const errorResponse = Object.values(responses).find(
(value) => value.status === 'error'
);

return errorResponse
? { status: 'error', message: errorResponse.message }
: { status: 'success', message: '' };
};

const getDoraMetricsScore = (
dora_data: TeamDoraMetricsApiResponseType,
Expand Down

0 comments on commit c5657d0

Please sign in to comment.