From e981542ac3d396e044ebe8a12aa6d79583241f70 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 20 Feb 2024 19:34:14 +0530 Subject: [PATCH] Fix pricing for more recipes --- recipes/DocSearch.py | 5 +++-- recipes/DocSummary.py | 3 +++ recipes/FaceInpainting.py | 6 ++++-- recipes/GoogleImageGen.py | 3 +++ recipes/Img2Img.py | 6 ++++-- recipes/ObjectInpainting.py | 6 ++++-- recipes/SmartGPT.py | 3 +++ recipes/Text2Audio.py | 3 +++ 8 files changed, 27 insertions(+), 8 deletions(-) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 6d7977c25..9e384e90e 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -205,9 +205,10 @@ def run_v2( def get_raw_price(self, state: dict) -> float: name = state.get("selected_model") try: - return llm_price[LargeLanguageModels[name]] * 2 + unit_price = llm_price[LargeLanguageModels[name]] * 2 except KeyError: - return 10 + unit_price = 10 + return unit_price * state.get("num_outputs", 1) def render_documents(state, label="**Documents**", *, key="documents"): diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 476bd5da6..fa4dbd427 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -181,6 +181,9 @@ def run(self, state: dict) -> typing.Iterator[str | None]: case _: raise NotImplementedError(f"{chain_type} not implemented") + def get_raw_price(self, state: dict) -> float: + return self.price * state.get("num_outputs", 1) + MAP_REDUCE_PROMPT = """ {documents} diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 6d787bba6..2f1b6a0b0 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -336,6 +336,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case InpaintingModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index 278128a37..a935e54cb 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -238,3 +238,6 @@ def render_example(self, state: dict): def preview_description(self, state: dict) -> str: return "Enter a Google Image Search query + your Img2Img text prompt describing how to alter the result to create a unique, relevant ai generated images for any search query." + + def get_raw_price(self, state: dict) -> float: + return super().get_raw_price(state) * state.get("num_outputs", 1) diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py index e2713aaa2..2fbd982c3 100644 --- a/recipes/Img2Img.py +++ b/recipes/Img2Img.py @@ -202,6 +202,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case Img2ImgModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py index a1e0c2449..db12760bf 100644 --- a/recipes/ObjectInpainting.py +++ b/recipes/ObjectInpainting.py @@ -314,6 +314,8 @@ def get_raw_price(self, state: dict) -> int: selected_model = state.get("selected_model") match selected_model: case InpaintingModels.dall_e.name: - return 20 + unit_price = 20 case _: - return 5 + unit_price = 5 + + return unit_price * state.get("num_outputs", 1) diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py index 0aabcfefb..10c4a3fe7 100644 --- a/recipes/SmartGPT.py +++ b/recipes/SmartGPT.py @@ -204,6 +204,9 @@ def render_steps(self): def preview_description(self, state: dict) -> str: return "SmartGPT is a cutting-edge AI technology that can be used to generate natural language responses to any given input. We have combined the power of [CoT](https://arxiv.org/abs/2305.02897), [Reflexion](https://arxiv.org/abs/2303.11366) & [DERA](https://arxiv.org/abs/2303.17071) into one pipeline so that you can use ChatGPT to its full potential! Input your prompt + a reflection/research prompt + a resolver prompt to use SmartGPT for enhanced text generation, natural language and incredible question-answer results." + def get_raw_price(self, state: dict) -> float: + return self.price * state.get("num_outputs", 1) + def answers_as_prompt(texts: list[str], sep="\n\n") -> str: return sep.join( diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py index f7ddf5aab..5140fd2c2 100644 --- a/recipes/Text2Audio.py +++ b/recipes/Text2Audio.py @@ -140,6 +140,9 @@ def render_example(self, state: dict): def preview_description(self, state: dict) -> str: return "Generate AI Music with text instruction prompts. AudiLDM is capable of generating realistic audio samples by process any text input. Learn more [here](https://huggingface.co/cvssp/audioldm-m-full)." + def get_raw_price(self, state: dict) -> float: + return super().get_raw_price(state) * state.get("num_outputs", 1) + def _render_output(state): selected_models = state.get("selected_models", [])