From 8669f8d52db4c501f1a72a13df793f2639fa3c9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Thu, 21 Nov 2024 08:57:32 +0100 Subject: [PATCH] Added missing fuzzy match to OpenAI classifiers --- stormtrooper/chat.py | 9 ++++++++- stormtrooper/openai.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/stormtrooper/chat.py b/stormtrooper/chat.py index 8ff6e09..344c4d0 100644 --- a/stormtrooper/chat.py +++ b/stormtrooper/chat.py @@ -68,9 +68,16 @@ def fit(self, X: Optional[Iterable[str]], y: Iterable[str]): self.n_classes = len(self.classes_) return self + def fuzzy_match_label(self, label: str) -> str: + if label not in self.classes_: + label, _ = process.extractOne(label, self.classes_) + return label + def get_user_prompt(self, text: str) -> str: if getattr(self, "classes_", None) is None: - raise NotFittedError("No class labels have been learnt yet, fit the model.") + raise NotFittedError( + "No class labels have been learnt yet, fit the model." + ) if getattr(self, "examples_", None) is not None: text_examples = [] for label, examples in self.examples_.items(): diff --git a/stormtrooper/openai.py b/stormtrooper/openai.py index bed923f..f107cf5 100644 --- a/stormtrooper/openai.py +++ b/stormtrooper/openai.py @@ -76,7 +76,10 @@ async def predict_one_async(self, text: str) -> str: return response.choices[0].message.content def predict_one(self, text: str) -> str: - return asyncio.run(self.predict_one_async(text)) + label = asyncio.run(self.predict_one_async(text)) + if self.fuzzy_match: + label = self.fuzzy_match_label(label) + return label async def predict_async(self, X: Iterable[str]) -> np.ndarray: if self.classes_ is None: @@ -101,4 +104,7 @@ def predict(self, X: Iterable[str]) -> np.ndarray: array of shape (n_texts) Array of string class labels. """ - return asyncio.run(self.predict_async(X)) + labels = asyncio.run(self.predict_async(X)) + if self.fuzzy_match: + labels = [self.fuzzy_match_label(label) for label in labels] + return labels