Skip to content

Commit

Permalink
feat: file type check + small docs improvement.
Browse files Browse the repository at this point in the history
  • Loading branch information
drudilorenzo committed Apr 13, 2024
1 parent 4649eb0 commit 47a8106
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 74 deletions.
77 changes: 47 additions & 30 deletions openai_cost_logger/openai_cost_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"cost"
]

"""OpenAI cost logger"""
"""OpenAI cost logger."""
class OpenAICostLogger:
def __init__(
self,
Expand Down Expand Up @@ -68,6 +68,38 @@ def update_cost(self, response: ChatCompletion) -> None:
self.__write_cost_to_json(response)
self.__validate_cost()


def get_current_cost(self) -> float:
"""Get the current cost of the cost tracker.
Returns:
float: The current cost.
"""
return self.cost


def __get_answer_cost(self, answer: Dict) -> float:
"""Calculate the cost of the answer based on the input and output tokens.
Args:
answer (dict): The response from the model.
Returns:
float: The cost of the answer.
"""
return (self.input_cost * answer.usage.prompt_tokens) / COST_UNIT + \
(self.output_cost * answer.usage.completion_tokens) / COST_UNIT


def __validate_cost(self):
"""Check if the cost exceeds the upperbound and raise an exception if it does.
Raises:
Exception: If the cost exceeds the upperbound.
"""
if self.cost > self.cost_upperbound:
raise Exception(f"Cost exceeded upperbound: {self.cost} > {self.cost_upperbound}")


def __write_cost_to_json(self, response: ChatCompletion) -> None:
"""Write the cost to a json file.
Expand All @@ -81,10 +113,14 @@ def __write_cost_to_json(self, response: ChatCompletion) -> None:
with open(self.filepath, 'w') as file:
json.dump(data, file, indent=4)


def __check_existance_log_folder(self) -> None:
"""Check if the log folder exists and create it if it does not."""
self.filepath.parent.mkdir(parents=True, exist_ok=True)


def __build_log_file(self) -> None:
"""Create the log file with the header."""
log_file_template = {
"experiment_name": self.experiment_name,
"creation_datetime": strftime("%Y-%m-%d %H:%M:%S"),
Expand All @@ -94,41 +130,22 @@ def __build_log_file(self) -> None:
}
with open(self.filepath, 'w') as file:
json.dump(log_file_template, file, indent=4)


def __build_log_breadown_entry(self, response: ChatCompletion) -> Dict:
"""Build a json log entry for the breakdown of the cost.
Args:
response (ChatCompletion): The response from the model.
Returns:
Dict: The json log entry.
"""
return {
"cost": self.__get_answer_cost(response),
"input_tokens": response.usage.prompt_tokens,
"output_tokens": response.usage.completion_tokens,
"content": response.choices[0].message.content,
"inferred_model": response.model,
"datetime": strftime("%Y-%m-%d %H:%M:%S"),
}

def get_current_cost(self) -> float:
"""Get the current cost of the cost tracker.
Returns:
float: The current cost.
"""
return self.cost

def __get_answer_cost(self, answer: Dict) -> float:
"""Calculate the cost of the answer based on the input and output tokens.
Args:
answer (dict): The response from the model.
Returns:
float: The cost of the answer.
"""
return (self.input_cost * answer.usage.prompt_tokens) / COST_UNIT + \
(self.output_cost * answer.usage.completion_tokens) / COST_UNIT

def __validate_cost(self):
"""Check if the cost exceeds the upperbound and raise an exception if it does.
Raises:
Exception: If the cost exceeds the upperbound.
"""
if self.cost > self.cost_upperbound:
raise Exception(f"Cost exceeded upperbound: {self.cost} > {self.cost_upperbound}")
}
83 changes: 39 additions & 44 deletions openai_cost_logger/openai_cost_logger_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from openai_cost_logger.constants import DEFAULT_LOG_PATH

"""Cost logger visualizer."""
class OpenAICostLoggerViz:

@staticmethod
Expand All @@ -22,11 +23,13 @@ def get_total_cost(path: str = DEFAULT_LOG_PATH) -> float:
"""
cost = 0
for filename in os.listdir(path):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
cost += data["total_cost"]
if filename.endswith(".json"):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
cost += data["total_cost"]
return cost


@staticmethod
def print_total_cost(path: str = DEFAULT_LOG_PATH) -> None:
"""Print the total cost of all the logs in the directory.
Expand All @@ -35,9 +38,9 @@ def print_total_cost(path: str = DEFAULT_LOG_PATH) -> None:
log_folder (str, optional): Cost logs directory. Defaults to DEFAULT_LOG_PATH.
This method reads all the files in the specified directory.
"""

print(f"Total cost: {round(OpenAICostLoggerViz.get_total_cost(path), 6)} (USD)")



@staticmethod
def get_total_cost_by_model(path: str = DEFAULT_LOG_PATH) -> Dict[str, float]:
"""Return the total cost by model of all the logs in the directory.
Expand All @@ -51,13 +54,15 @@ def get_total_cost_by_model(path: str = DEFAULT_LOG_PATH) -> Dict[str, float]:
"""
cost_by_model = defaultdict(float)
for filename in os.listdir(path):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
if data["model"] not in cost_by_model:
cost_by_model[data["model"]] = 0
cost_by_model[data["model"]] += data["total_cost"]
if filename.endswith(".json"):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
if data["model"] not in cost_by_model:
cost_by_model[data["model"]] = 0
cost_by_model[data["model"]] += data["total_cost"]
return cost_by_model



def print_total_cost_by_model(path: str = DEFAULT_LOG_PATH) -> None:
"""Print the total cost by model of all the logs in the directory.
Expand All @@ -68,60 +73,50 @@ def print_total_cost_by_model(path: str = DEFAULT_LOG_PATH) -> None:
cost_by_model = OpenAICostLoggerViz.get_total_cost_by_model(path)
for model, cost in cost_by_model.items():
print(f"{model}: {round(cost, 6)} (USD)")



@staticmethod
def plot_cost_by_day(path: str = DEFAULT_LOG_PATH, last_n_days: int = None) -> None:
"""Plot the cost by day of all the logs in the directory.
def plot_cost_by_strftime(path: str = DEFAULT_LOG_PATH, strftime_aggregator: str = "%Y-%m-%d", last_n_days: int = None) -> None:
"""Plot the cost by day of all the logs in the directory aggregated using strftime_aggregator.
Args:
path (str, optional): Cost logs directory. Defaults to DEFAULT_LOG_PATH.
This method reads all the files in the specified directory.
last_n_days (int, optional): The number of last days to plot. Defaults to None.
"""
cost_by_day = defaultdict(float)
cost_by_aggregation_key = defaultdict(float)
for filename in os.listdir(path):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
creation_datetime = data["creation_datetime"]
day = creation_datetime.split(' ')[0]
cost_by_day[day] += data["total_cost"]
if filename.endswith(".json"):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
creation_datetime = datetime.strptime(data["creation_datetime"], "%Y-%m-%d %H:%M:%S")
aggregation_key = creation_datetime.strftime(strftime_aggregator)
cost_by_aggregation_key[aggregation_key] += data["total_cost"]

cost_by_day = dict(sorted(cost_by_day.items(), key=lambda x: x[0]))
cost_by_aggregation_key = dict(sorted(cost_by_aggregation_key.items(), key=lambda x: x[0]))
if last_n_days:
cost_by_day = dict(list(cost_by_day.items())[-last_n_days:])
cost_by_aggregation_key = dict(list(cost_by_aggregation_key.items())[-last_n_days:])

plt.bar(cost_by_day.keys(), cost_by_day.values(), width=0.5)
plt.bar(cost_by_aggregation_key.keys(), cost_by_aggregation_key.values(), width=0.5)
plt.xticks(rotation=30, fontsize=8)
plt.xlabel('Day')
plt.ylabel('Cost [$]')
plt.title('Cost by day')
plt.tight_layout()
plt.show()



@staticmethod
def plot_cost_by_strftime(path: str = DEFAULT_LOG_PATH, strftime_aggregator: str = "%Y-%m-%d", last_n_days: int = None) -> None:
def plot_cost_by_day(path: str = DEFAULT_LOG_PATH, last_n_days: int = None) -> None:
"""Plot the cost by day of all the logs in the directory.
Args:
path (str, optional): Cost logs directory. Defaults to DEFAULT_LOG_PATH.
This method reads all the files in the specified directory.
last_n_days (int, optional): The number of last days to plot. Defaults to None.
"""
cost_by_aggregation_key = defaultdict(float)
for filename in os.listdir(path):
with open(Path(path, filename), mode='r') as file:
data = json.load(file)
creation_datetime = datetime.strptime(data["creation_datetime"], "%Y-%m-%d %H:%M:%S")
aggregation_key = creation_datetime.strftime(strftime_aggregator)
cost_by_aggregation_key[aggregation_key] += data["total_cost"]

cost_by_aggregation_key = dict(sorted(cost_by_aggregation_key.items(), key=lambda x: x[0]))
if last_n_days:
cost_by_aggregation_key = dict(list(cost_by_aggregation_key.items())[-last_n_days:])

plt.bar(cost_by_aggregation_key.keys(), cost_by_aggregation_key.values(), width=0.5)
plt.xticks(rotation=30, fontsize=8)
plt.xlabel('Day')
plt.ylabel('Cost [$]')
plt.title('Cost by day')
plt.tight_layout()
plt.show()
OpenAICostLoggerViz.plot_cost_by_strftime(
path=path,
strftime_aggregator="%Y-%m-%d",
last_n_days=last_n_days
)

0 comments on commit 47a8106

Please sign in to comment.