diff --git a/README.md b/README.md
index 7c21003..1a8eff2 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,10 @@ Transformer-based zero/few shot learning components for scikit-learn pipelines.
[Documentation](https://centre-for-humanities-computing.github.io/stormtrooper/)
+## New in version 0.4.0 :fire:
+
+- You can now use OpenAI's chat models with blazing fast :zap: async inference.
+
## New in version 0.3.0 🌟
- SetFit is now part of the library and can be used in scikit-learn workflows.
@@ -71,9 +75,24 @@ predictions = classifier.predict(example_texts)
assert list(predictions) == ["atheism/christianity", "astronomy/space"]
```
+
+OpenAI models:
+
+You can now use OpenAI's chat LLMs in stormtrooper workflows.
+
+```python
+from stormtrooper import OpenAIZeroShotClassifier
+
+classifier = OpenAIZeroShotClassifier("gpt-4").fit(None, class_labels)
+```
+
+```python
+predictions = classifier.predict(example_texts)
+
+assert list(predictions) == ["atheism/christianity", "astronomy/space"]
+```
+
### Few-Shot Learning
-For few-shot tasks you can only use Generative, Text2Text (aka. promptable) or SetFit models.
+For few-shot tasks you can only use Generative, Text2Text, and OpenAI models (a.k.a. promptable models) or SetFit models.
```python
-from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier, SetFitFewShotClassifier
+from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier, SetFitFewShotClassifier, OpenAIFewShotClassifier
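+# OpenAI chat models follow the same few-shot API, for example:
+# OpenAIFewShotClassifier("gpt-4").fit(example_texts, class_labels)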
diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle
index ef8b905..f2b34f2 100644
Binary files a/docs/_build/doctrees/environment.pickle and b/docs/_build/doctrees/environment.pickle differ
diff --git a/docs/_build/doctrees/index.doctree b/docs/_build/doctrees/index.doctree
index 4a89c72..6f9f0db 100644
Binary files a/docs/_build/doctrees/index.doctree and b/docs/_build/doctrees/index.doctree differ
diff --git a/docs/_build/doctrees/openai.doctree b/docs/_build/doctrees/openai.doctree
new file mode 100644
index 0000000..39c1e34
Binary files /dev/null and b/docs/_build/doctrees/openai.doctree differ
diff --git a/docs/_build/doctrees/prompting.doctree b/docs/_build/doctrees/prompting.doctree
index caeb744..8833138 100644
Binary files a/docs/_build/doctrees/prompting.doctree and b/docs/_build/doctrees/prompting.doctree differ
diff --git a/docs/_build/html/_sources/index.rst.txt b/docs/_build/html/_sources/index.rst.txt
index 20a0d54..a8dd978 100644
--- a/docs/_build/html/_sources/index.rst.txt
+++ b/docs/_build/html/_sources/index.rst.txt
@@ -17,6 +17,14 @@ If you intend to use SetFit models as well, install stormtrooper with optional d
pip install stormtrooper[setfit]
+From version 0.4.0 you can also use OpenAI models in stormtrooper.
+
+.. code-block:: bash
+
+ pip install stormtrooper[openai]
+ export OPENAI_API_KEY="sk-..."
+
+
Usage
^^^^^^^^^
@@ -45,6 +53,7 @@ In this example I am going to use Google's FLAN-T5.
text2text
generative
setfit
+ openai
prompting
inference_on_gpu
diff --git a/docs/_build/html/_sources/openai.rst.txt b/docs/_build/html/_sources/openai.rst.txt
new file mode 100644
index 0000000..1ec22f4
--- /dev/null
+++ b/docs/_build/html/_sources/openai.rst.txt
@@ -0,0 +1,85 @@
+OpenAI models
+=================
+
+Stormtrooper gives you access to OpenAI's chat models for zero and few-shot classification.
+You get full control over temperature settings, as well as system and user prompts.
+In contrast to other packages, like scikit-llm, stormtrooper uses Python's asyncio to interact
+with OpenAI's API concurrently. This can yield a several-fold speedup on many tasks.
+
+You can also set upper limits on the number of requests and tokens per minute, so you don't exceed your quota.
+These are by default set to the limits of the paid tier of OpenAI's API.
+
+You need to install stormtrooper with its optional OpenAI dependencies.
+
+.. code-block:: bash
+
+ pip install stormtrooper[openai]
+
+You additionally need to set the OpenAI API key as an environment variable.
+
+.. code-block:: bash
+
+ export OPENAI_API_KEY="sk-..."
+ # Setting organization is optional
+ export OPENAI_ORG="org-..."
+
+.. code-block:: python
+
+ from stormtrooper import OpenAIZeroShotClassifier, OpenAIFewShotClassifier
+
+ sample_text = "It is the Electoral College's responsibility to elect the president."
+
+ labels = ["politics", "science", "other"]
+
+Here's a zero-shot example with GPT-3.5 Turbo:
+
+.. code-block:: python
+
+ model = OpenAIZeroShotClassifier("gpt-3.5-turbo").fit(None, labels)
+ predictions = model.predict([sample_text])
+ assert list(predictions) == ["politics"]
+
+And a few-shot example with GPT-4:
+
+.. code-block:: python
+
+ few_shot_examples = [
+ "Joe Biden is the president.",
+ "Liquid water was found on the moon.",
+ "Jerry likes football."
+ ]
+
+ model = OpenAIFewShotClassifier("gpt-4", temperature=0.2).fit(few_shot_examples, labels)
+ predictions = model.predict([sample_text])
+ assert list(predictions) == ["politics"]
+
+
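+You can also tune the rate limits to fit your own quota.
+Here's a minimal sketch that lowers the default limits (3,500 requests and 90,000 tokens per minute)
+via the ``max_requests_per_minute`` and ``max_tokens_per_minute`` parameters:
+
+.. code-block:: python
+
+    model = OpenAIZeroShotClassifier(
+        "gpt-3.5-turbo",
+        max_requests_per_minute=500,
+        max_tokens_per_minute=20_000,
+    ).fit(None, labels)
+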
+The format of the prompts is the same as with Stable Beluga instruct models, and an error is raised if your prompt does not follow
+this format.
+
+.. code-block:: python
+
+ prompt = """
+ ### System:
+ You are a helpful assistant
+ ### User:
+ Your task will be to classify a text document into one
+ of the following classes: {classes}.
+ Please respond with a single label that you think fits
+ the document best.
+ Classify the following piece of text:
+ '{X}'
+ ### Assistant:
+ """
+
+ model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt)
+
+
+API reference
+^^^^^^^^^^^^^
+
+.. autoclass:: stormtrooper.OpenAIZeroShotClassifier
+ :members:
+
+.. autoclass:: stormtrooper.OpenAIFewShotClassifier
+ :members:
diff --git a/docs/_build/html/_sources/prompting.rst.txt b/docs/_build/html/_sources/prompting.rst.txt
index 865f007..d0b33c0 100644
--- a/docs/_build/html/_sources/prompting.rst.txt
+++ b/docs/_build/html/_sources/prompting.rst.txt
@@ -1,7 +1,7 @@
Prompting
=========
-Text2Text and Generative models use a prompting approach for classification.
+Text2Text, Generative, and OpenAI models use a prompting approach for classification.
stormtrooper comes with default prompts, but these might not suit the model you want to use,
or your use case might require a different prompting strategy from the default.
stormtrooper allows you to specify custom prompts in these cases.
@@ -12,7 +12,7 @@ Templates
Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to
insert labels and data.
-A zero-shot prompt for an instruct Llama model like Stable Beluga would look something like this (this is the default):
+A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default):
.. code-block:: python
diff --git a/docs/_build/html/generative.html b/docs/_build/html/generative.html
index f5660d2..9d7d237 100644
--- a/docs/_build/html/generative.html
+++ b/docs/_build/html/generative.html
diff --git a/docs/_build/html/prompting.html b/docs/_build/html/prompting.html
index 9fc1849..03f26e7 100644
--- a/docs/_build/html/prompting.html
+++ b/docs/_build/html/prompting.html
@@ -170,6 +170,7 @@
-Text2Text and Generative models use a prompting approach for classification.
+
+Text2Text, Generative, and OpenAI models use a prompting approach for classification.
 stormtrooper comes with default prompts, but these might not suit the model you want to use,
 or your use case might require a different prompting strategy from the default.
 stormtrooper allows you to specify custom prompts in these cases.
diff --git a/docs/index.rst b/docs/index.rst
index 20a0d54..a8dd978 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -17,6 +17,14 @@ If you intend to use SetFit models as well, install stormtrooper with optional d
pip install stormtrooper[setfit]
+From version 0.4.0 you can also use OpenAI models in stormtrooper.
+
+.. code-block:: bash
+
+ pip install stormtrooper[openai]
+ export OPENAI_API_KEY="sk-..."
+
+
Usage
^^^^^^^^^
@@ -45,6 +53,7 @@ In this example I am going to use Google's FLAN-T5.
text2text
generative
setfit
+ openai
prompting
inference_on_gpu
diff --git a/docs/openai.rst b/docs/openai.rst
new file mode 100644
index 0000000..1ec22f4
--- /dev/null
+++ b/docs/openai.rst
@@ -0,0 +1,85 @@
+OpenAI models
+=================
+
+Stormtrooper gives you access to OpenAI's chat models for zero and few-shot classification.
+You get full control over temperature settings, as well as system and user prompts.
+In contrast to other packages, like scikit-llm, stormtrooper uses Python's asyncio to interact
+with OpenAI's API concurrently. This can yield a several-fold speedup on many tasks.
+
+You can also set upper limits on the number of requests and tokens per minute, so you don't exceed your quota.
+These are by default set to the limits of the paid tier of OpenAI's API.
+
+You need to install stormtrooper with its optional OpenAI dependencies.
+
+.. code-block:: bash
+
+ pip install stormtrooper[openai]
+
+You additionally need to set the OpenAI API key as an environment variable.
+
+.. code-block:: bash
+
+ export OPENAI_API_KEY="sk-..."
+ # Setting organization is optional
+ export OPENAI_ORG="org-..."
+
+.. code-block:: python
+
+ from stormtrooper import OpenAIZeroShotClassifier, OpenAIFewShotClassifier
+
+ sample_text = "It is the Electoral College's responsibility to elect the president."
+
+ labels = ["politics", "science", "other"]
+
+Here's a zero-shot example with GPT-3.5 Turbo:
+
+.. code-block:: python
+
+ model = OpenAIZeroShotClassifier("gpt-3.5-turbo").fit(None, labels)
+ predictions = model.predict([sample_text])
+ assert list(predictions) == ["politics"]
+
+And a few-shot example with GPT-4:
+
+.. code-block:: python
+
+ few_shot_examples = [
+ "Joe Biden is the president.",
+ "Liquid water was found on the moon.",
+ "Jerry likes football."
+ ]
+
+ model = OpenAIFewShotClassifier("gpt-4", temperature=0.2).fit(few_shot_examples, labels)
+ predictions = model.predict([sample_text])
+ assert list(predictions) == ["politics"]
+
+
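+You can also tune the rate limits to fit your own quota.
+Here's a minimal sketch that lowers the default limits (3,500 requests and 90,000 tokens per minute)
+via the ``max_requests_per_minute`` and ``max_tokens_per_minute`` parameters:
+
+.. code-block:: python
+
+    model = OpenAIZeroShotClassifier(
+        "gpt-3.5-turbo",
+        max_requests_per_minute=500,
+        max_tokens_per_minute=20_000,
+    ).fit(None, labels)
+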
+The format of the prompts is the same as with Stable Beluga instruct models, and an error is raised if your prompt does not follow
+this format.
+
+.. code-block:: python
+
+ prompt = """
+ ### System:
+ You are a helpful assistant
+ ### User:
+ Your task will be to classify a text document into one
+ of the following classes: {classes}.
+ Please respond with a single label that you think fits
+ the document best.
+ Classify the following piece of text:
+ '{X}'
+ ### Assistant:
+ """
+
+ model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt)
+
+
+API reference
+^^^^^^^^^^^^^
+
+.. autoclass:: stormtrooper.OpenAIZeroShotClassifier
+ :members:
+
+.. autoclass:: stormtrooper.OpenAIFewShotClassifier
+ :members:
diff --git a/docs/prompting.rst b/docs/prompting.rst
index 865f007..d0b33c0 100644
--- a/docs/prompting.rst
+++ b/docs/prompting.rst
@@ -1,7 +1,7 @@
Prompting
=========
-Text2Text and Generative models use a prompting approach for classification.
+Text2Text, Generative, and OpenAI models use a prompting approach for classification.
stormtrooper comes with default prompts, but these might not suit the model you want to use,
or your use case might require a different prompting strategy from the default.
stormtrooper allows you to specify custom prompts in these cases.
@@ -12,7 +12,7 @@ Templates
Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to
insert labels and data.
-A zero-shot prompt for an instruct Llama model like Stable Beluga would look something like this (this is the default):
+A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default):
.. code-block:: python