diff --git a/README.md b/README.md index 7c21003..1a8eff2 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@ Transformer-based zero/few shot learning components for scikit-learn pipelines. [Documentation](https://centre-for-humanities-computing.github.io/stormtrooper/) +## New in version 0.4.0 :fire: + +- You can now use OpenAI's chat models with blazing fast :zap: async inference. + ## New in version 0.3.0 🌟 - SetFit is now part of the library and can be used in scikit-learn workflows. @@ -71,9 +75,24 @@ predictions = classifier.predict(example_texts) assert list(predictions) == ["atheism/christianity", "astronomy/space"] ``` +OpenAI models: +You can now use OpenAI's chat LLMs in stormtrooper workflows. + +```python +from stormtrooper import OpenAIZeroShotClassifier + +classifier = OpenAIZeroShotClassifier("gpt-4").fit(None, class_labels) +``` + +```python +predictions = classifier.predict(example_texts) + +assert list(predictions) == ["atheism/christianity", "astronomy/space"] +``` + ### Few-Shot Learning -For few-shot tasks you can only use Generative, Text2Text (aka. promptable) or SetFit models. +For few-shot tasks you can only use Generative, Text2Text, OpenAI (aka. promptable) or SetFit models. 
```python from stormtrooper import GenerativeFewShotClassifier, Text2TextFewShotClassifier, SetFitFewShotClassifier diff --git a/docs/_build/doctrees/environment.pickle b/docs/_build/doctrees/environment.pickle index ef8b905..f2b34f2 100644 Binary files a/docs/_build/doctrees/environment.pickle and b/docs/_build/doctrees/environment.pickle differ diff --git a/docs/_build/doctrees/index.doctree b/docs/_build/doctrees/index.doctree index 4a89c72..6f9f0db 100644 Binary files a/docs/_build/doctrees/index.doctree and b/docs/_build/doctrees/index.doctree differ diff --git a/docs/_build/doctrees/openai.doctree b/docs/_build/doctrees/openai.doctree new file mode 100644 index 0000000..39c1e34 Binary files /dev/null and b/docs/_build/doctrees/openai.doctree differ diff --git a/docs/_build/doctrees/prompting.doctree b/docs/_build/doctrees/prompting.doctree index caeb744..8833138 100644 Binary files a/docs/_build/doctrees/prompting.doctree and b/docs/_build/doctrees/prompting.doctree differ diff --git a/docs/_build/html/_sources/index.rst.txt b/docs/_build/html/_sources/index.rst.txt index 20a0d54..a8dd978 100644 --- a/docs/_build/html/_sources/index.rst.txt +++ b/docs/_build/html/_sources/index.rst.txt @@ -17,6 +17,14 @@ If you intend to use SetFit models as well, install stormtrooper with optional d pip install stormtrooper[setfit] +From version 0.4.0 you can also use OpenAI models in stormtrooper. + +.. code-block:: + + pip install stormtrooper[openai] + export OPENAI_API_KEY="sk-..." + + Usage ^^^^^^^^^ @@ -45,6 +53,7 @@ In this example I am going to use Google's FLAN-T5. 
text2text generative setfit + openai prompting inference_on_gpu diff --git a/docs/_build/html/_sources/openai.rst.txt b/docs/_build/html/_sources/openai.rst.txt new file mode 100644 index 0000000..1ec22f4 --- /dev/null +++ b/docs/_build/html/_sources/openai.rst.txt @@ -0,0 +1,85 @@ +OpenAI models +================= + +Stormtrooper gives you access to OpenAI's chat models for zero and few-shot classification. +You get full control over temperature settings, system and user prompts. +In contrast to other packages, like scikit-llm, stormtrooper also uses Python's asyncio to concurrently +interact with OpenAI's API. This can make many tasks several times faster. + +You can also set upper limits for the number of requests and tokens per minute, so you don't exceed your quota. +By default, these are set to the limits of the paid tier of OpenAI's API. + +You need to install stormtrooper with optional dependencies. + +.. code-block:: bash + + pip install stormtrooper[openai] + +You additionally need to set the OpenAI API key as an environment variable. + +.. code-block:: bash + + export OPENAI_API_KEY="sk-..." + # Setting organization is optional + export OPENAI_ORG="org-..." + +.. code-block:: python + + from stormtrooper import OpenAIZeroShotClassifier, OpenAIFewShotClassifier + + sample_text = "It is the Electoral College's responsibility to elect the president." + + labels = ["politics", "science", "other"] + +Here's a zero shot example with ChatGPT 3.5: + +.. code-block:: python + + model = OpenAIZeroShotClassifier("gpt-3.5-turbo").fit(None, labels) + predictions = model.predict([sample_text]) + assert list(predictions) == ["politics"] + +And a few shot example with ChatGPT 4: + +.. code-block:: python + + few_shot_examples = [ + "Joe Biden is the president.", + "Liquid water was found on the moon.", + "Jerry likes football." 
+ ] + + model = OpenAIFewShotClassifier("gpt-4", temperature=0.2).fit(few_shot_examples, labels) + predictions = model.predict([sample_text]) + assert list(predictions) == ["politics"] + + +The format of the prompts is the same as with StableBeluga instruct models, and an error is raised if your prompt does not follow +this format. + +.. code-block:: python + + prompt = """ + ### System: + You are a helpful assistant + ### User: + Your task will be to classify a text document into one + of the following classes: {classes}. + Please respond with a single label that you think fits + the document best. + Classify the following piece of text: + '{X}' + ### Assistant: + """ + + model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt) + + +API reference +^^^^^^^^^^^^^ + +.. autoclass:: stormtrooper.OpenAIZeroShotClassifier + :members: + +.. autoclass:: stormtrooper.OpenAIFewShotClassifier + :members: diff --git a/docs/_build/html/_sources/prompting.rst.txt b/docs/_build/html/_sources/prompting.rst.txt index 865f007..d0b33c0 100644 --- a/docs/_build/html/_sources/prompting.rst.txt +++ b/docs/_build/html/_sources/prompting.rst.txt @@ -1,7 +1,7 @@ Prompting ========= -Text2Text and Generative models use a prompting approach for classification. +Text2Text, Generative, and OpenAI models use a prompting approach for classification. stormtrooper comes with default prompts, but these might not suit the model you want to use, or your use case might require a different prompting strategy from the default. stormtrooper allows you to specify custom prompts in these cases. @@ -12,7 +12,7 @@ Templates Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to insert labels and data. -A zero-shot prompt for an instruct Llama model like Stable Beluga would look something like this (this is the default): +A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default): .. 
code-block:: python diff --git a/docs/_build/html/generative.html b/docs/_build/html/generative.html index f5660d2..9d7d237 100644 --- a/docs/_build/html/generative.html +++ b/docs/_build/html/generative.html @@ -5,7 +5,7 @@ - + Generative models - stormtrooper @@ -170,6 +170,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • diff --git a/docs/_build/html/genindex.html b/docs/_build/html/genindex.html index 1c97c66..037d048 100644 --- a/docs/_build/html/genindex.html +++ b/docs/_build/html/genindex.html @@ -4,7 +4,7 @@ - Index - stormtrooper + Index - stormtrooper @@ -168,6 +168,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • @@ -209,7 +210,7 @@

    Index

    -
    C | E | F | G | P | R | S | T | Z
    +
    C | E | F | G | O | P | R | S | T | Z

    C

    @@ -219,6 +220,10 @@

    C

    +
    +

    O

    + + + +
    +
    +

    P

    @@ -301,6 +324,10 @@

    P

    diff --git a/docs/_build/html/objects.inv b/docs/_build/html/objects.inv index b981df6..c6d8447 100644 Binary files a/docs/_build/html/objects.inv and b/docs/_build/html/objects.inv differ diff --git a/docs/_build/html/openai.html b/docs/_build/html/openai.html new file mode 100644 index 0000000..006b63a --- /dev/null +++ b/docs/_build/html/openai.html @@ -0,0 +1,650 @@ + + + + + + + + + OpenAI models - stormtrooper + + + + + + + + + + + + + + + + Contents + + + + + + Menu + + + + + + + + Expand + + + + + + Light mode + + + + + + + + + + + + + + Dark mode + + + + + + + Auto light/dark mode + + + + + + + + + + + + + + + + + + + +
    +
    +
    + +
    + +
    +
    + +
    + +
    +
    + +
    +
    +
    + + + + + Back to top + +
    + +
    + +
    + +
    +
    +
    +

    OpenAI models#

    +

    Stormtrooper gives you access to OpenAI’s chat models for zero and few-shot classification. +You get full control over temperature settings, system and user prompts. +In contrast to other packages, like scikit-llm, stormtrooper also uses Python’s asyncio to concurrently +interact with OpenAI’s API. This can make many tasks several times faster.
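
To illustrate why concurrent requests help, here is a minimal sketch of the general asyncio pattern. It is not stormtrooper's implementation: `fake_request`, `classify_all` and `max_concurrent` are hypothetical names, and the network call is replaced by a short sleep so the example runs without an API key.

```python
import asyncio


async def fake_request(text: str) -> str:
    # Hypothetical stand-in for a network call to a chat API.
    await asyncio.sleep(0.01)
    return text.upper()


async def classify_all(texts, max_concurrent=2):
    # The semaphore caps how many requests are in flight at the same time.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded(text):
        async with semaphore:
            return await fake_request(text)

    # gather() runs the coroutines concurrently and keeps the input order.
    return await asyncio.gather(*(bounded(t) for t in texts))


results = asyncio.run(classify_all(["a", "b", "c"]))
```

Because the waiting time of the individual requests overlaps, the total runtime is usually much closer to that of the slowest request than to the sum of all of them.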

    +

    You can also set upper limits for the number of requests and tokens per minute, so you don’t exceed your quota. +By default, these are set to the limits of the paid tier of OpenAI’s API.
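
A per-minute limit can be thought of as a budget that refills over time. The sketch below shows one common way to model this, assuming a simple refill-per-second scheme. `MinuteBudget` and `try_consume` are hypothetical names for illustration only, and stormtrooper may enforce `max_requests_per_minute` and `max_tokens_per_minute` differently.

```python
class MinuteBudget:
    """Hypothetical per-minute budget that refills continuously."""

    def __init__(self, max_per_minute: int, now: float = 0.0):
        self.capacity = float(max_per_minute)
        self.available = float(max_per_minute)
        self.last = now

    def try_consume(self, amount: float, now: float) -> bool:
        # Refill in proportion to the elapsed seconds, capped at capacity.
        elapsed = now - self.last
        refill = elapsed * self.capacity / 60.0
        self.available = min(self.capacity, self.available + refill)
        self.last = now
        if amount <= self.available:
            self.available -= amount
            return True
        return False


# The clock is passed in explicitly so the behaviour is easy to follow.
budget = MinuteBudget(max_per_minute=60, now=0.0)
first = budget.try_consume(60, now=0.0)   # uses the whole budget
second = budget.try_consume(1, now=0.0)   # nothing left yet
third = budget.try_consume(30, now=30.0)  # half a minute refills half
```

A caller that gets `False` back would typically wait and retry instead of sending the request.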

    +

    You need to install stormtrooper with optional dependencies.

    +
    pip install stormtrooper[openai]
    +
    +
    +

    You additionally need to set the OpenAI API key as an environment variable.

    +
    export OPENAI_API_KEY="sk-..."
    +# Setting organization is optional
    +export OPENAI_ORG="org-..."
    +
    +
    +
    from stormtrooper import OpenAIZeroShotClassifier, OpenAIFewShotClassifier
    +
    +sample_text = "It is the Electoral College's responsibility to elect the president."
    +
    +labels = ["politics", "science", "other"]
    +
    +
    +

    Here’s a zero shot example with ChatGPT 3.5:

    +
    model = OpenAIZeroShotClassifier("gpt-3.5-turbo").fit(None, labels)
    +predictions = model.predict([sample_text])
    +assert list(predictions) == ["politics"]
    +
    +
    +

    And a few shot example with ChatGPT 4:

    +
    few_shot_examples = [
    +  "Joe Biden is the president.",
    +  "Liquid water was found on the moon.",
    +  "Jerry likes football."
    +]
    +
    +model = OpenAIFewShotClassifier("gpt-4", temperature=0.2).fit(few_shot_examples, labels)
    +predictions = model.predict([sample_text])
    +assert list(predictions) == ["politics"]
    +
    +
    +

    The format of the prompts is the same as with StableBeluga instruct models, and an error is raised if your prompt does not follow +this format.

    +
    prompt = """
    +### System:
    +You are a helpful assistant
    +### User:
    +Your task will be to classify a text document into one
    +of the following classes: {classes}.
    +Please respond with a single label that you think fits
    +the document best.
    +Classify the following piece of text:
    +'{X}'
    +### Assistant:
    +"""
    +
    +model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt)
    +
    +
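
To make the role of the placeholders concrete, here is a small sketch of how a template like the one above is filled in with Python's `str.format()`. How stormtrooper renders the class labels internally is not shown in this document, so joining them with commas is an assumption made for illustration.

```python
prompt = """
### System:
You are a helpful assistant
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Classify the following piece of text:
'{X}'
### Assistant:
"""

labels = ["politics", "science", "other"]
text = "It is the Electoral College's responsibility to elect the president."

# Assumption: labels are joined with commas; the library may format them differently.
filled = prompt.format(classes=", ".join(labels), X=text)
print(filled)
```

Any other literal curly braces in a custom prompt would need to be doubled, since `str.format()` treats single braces as placeholders.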
    +
    +

    API reference#

    +
    +
    +class stormtrooper.OpenAIZeroShotClassifier(model_name: str = 'gpt-3.5-turbo', temperature: float = 1.0, prompt: str = "\n### System:\nYou are a classification model that is really good at following\ninstructions and produces brief answers\nthat users can use as data right away.\nPlease follow the user's instructions as precisely as you can.\n### User:\nYour task will be to classify a text document into one\nof the following classes: {classes}.\nPlease respond with a single label that you think fits\nthe document best.\nClassify the following piece of text:\n ```{X}```\n### Assistant:\n", max_new_tokens: int = 256, max_requests_per_minute: int = 3500, max_tokens_per_minute: int = 90000, max_attempts_per_request: int = 5, fuzzy_match: bool = True)#
    +

    Scikit-learn compatible zero shot classification +with OpenAI’s chat language models.

    +
    +
    Parameters:
    +
      +
    • model_name (str, default 'gpt-3.5-turbo') – Name of OpenAI chat model.

    • +
    • temperature (float, default 1.0) – What sampling temperature to use, between 0 and 2. +Higher values like 0.8 will make the output more random, +while lower values like 0.2 will make it +more focused and deterministic.

    • +
    • prompt (str, optional) – You can specify the prompt which will be used to prompt the model. +Use placeholders to indicate where the class labels and the +data should be placed in the prompt.

    • +
    • max_new_tokens (int, default 256) – Maximum number of tokens the model should generate.

    • +
    • max_requests_per_minute (int, default 3500) – Maximum number of requests to send per minute.

    • +
    • max_tokens_per_minute (int, default 90_000) – Maximum number of tokens per minute.

    • +
    • max_attempts_per_request (int, default 5) – Maximum number of times a request should be attempted if it fails +for the first time.

    • +
    • fuzzy_match (bool, default True) – Indicates whether the output labels should be fuzzy matched +to the learnt class labels. +This is useful when the model isn’t giving specific enough answers.

    • +
    +
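
The idea behind `fuzzy_match` can be sketched with the standard library. This is only an illustration of fuzzy label matching in general: `match_label` is a hypothetical helper, and stormtrooper may use a different matching library or scoring method.

```python
from difflib import get_close_matches


def match_label(raw_output: str, classes: list[str]) -> str:
    # Normalise whitespace and case before comparing.
    cleaned = raw_output.strip().lower()
    lowered = [c.lower() for c in classes]
    # cutoff=0.0 always returns the closest candidate, however distant.
    best = get_close_matches(cleaned, lowered, n=1, cutoff=0.0)
    # Map the match back to the label as it was originally spelled.
    return classes[lowered.index(best[0])]


classes = ["politics", "science", "other"]
label = match_label(" Politics.", classes)
```

The sketch assumes `classes` is not empty. With an empty list there is no candidate to return and the lookup fails.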
    +
    +
    +
    +classes_#
    +

    Class names learned from the labels.

    +
    +
    Type:
    +

    array of str

    +
    +
    +
    + +
    +
    +fit(X, y: Iterable[str])#
    +

    Learns class labels.

    +
    +
    Parameters:
    +
      +
    • X (Any) – Ignored

    • +
    • y (iterable of str) – Iterable of class labels. +Should at least contain a representative sample +of potential labels.

    • +
    +
    +
    Returns:
    +

    Fitted model.

    +
    +
    Return type:
    +

    self

    +
    +
    +
    + +
    +
    +partial_fit(X, y: Iterable[str])#
    +

    Learns class labels. +Can learn new labels if new ones are encountered in the data.

    +
    +
    Parameters:
    +
      +
    • X (Any) – Ignored

    • +
    • y (iterable of str) – Iterable of class labels.

    • +
    +
    +
    Returns:
    +

    Fitted model.

    +
    +
    Return type:
    +

    self

    +
    +
    +
    + +
    +
    +predict(X: Iterable[str]) ndarray#
    +

    Predicts most probable class label for given texts.

    +
    +
    Parameters:
    +

    X (iterable of str) – Texts to label.

    +
    +
    Returns:
    +

    Array of string class labels.

    +
    +
    Return type:
    +

    array of shape (n_texts)

    +
    +
    +
    + +
    +
    +set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') OpenAIZeroShotClassifier#
    +

    Request metadata passed to the score method.

    +

    Note that this method is only relevant if +enable_metadata_routing=True (see sklearn.set_config()). +Please see User Guide on how the routing +mechanism works.

    +

    The options for each parameter are:

    +
      +
    • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

    • +
    • False: metadata is not requested and the meta-estimator will not pass it to score.

    • +
    • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

    • +
    • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

    • +
    +

    The default (sklearn.utils.metadata_routing.UNCHANGED) retains the +existing request. This allows you to change the request for some +parameters and not others.

    +
    +

    New in version 1.3.

    +
    +
    +

    Note

    +

    This method is only relevant if this estimator is used as a +sub-estimator of a meta-estimator, e.g. used inside a +pipeline.Pipeline. Otherwise it has no effect.

    +
    +
    +
    Parameters:
    +

    sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in score.

    +
    +
    Returns:
    +

    self – The updated object.

    +
    +
    Return type:
    +

    object

    +
    +
    +
    + +
    + +
    +
    +class stormtrooper.OpenAIFewShotClassifier(model_name: str = 'gpt-3.5-turbo', temperature: float = 1.0, prompt: str = "\n### System:\nYou are a classification model that is really good at following\ninstructions and produces brief answers\nthat users can use as data right away.\nPlease follow the user's instructions as precisely as you can.\n### User:\nYour task will be to classify a text document into one\nof the following classes: {classes}.\nPlease respond with a single label that you think fits\nthe document best.\nHere are a couple of examples of labels assigned by experts:\n{examples}\nClassify the following piece of text:\n'{X}'\n### Assistant:\n", max_new_tokens: int = 256, max_requests_per_minute: int = 3500, max_tokens_per_minute: int = 90000, max_attempts_per_request: int = 5, fuzzy_match: bool = True)#
    +

    Scikit-learn compatible few shot classification +with OpenAI’s chat language models.

    +
    +
    Parameters:
    +
      +
    • model_name (str, default 'gpt-3.5-turbo') – Name of OpenAI chat model.

    • +
    • temperature (float, default 1.0) – What sampling temperature to use, between 0 and 2. +Higher values like 0.8 will make the output more random, +while lower values like 0.2 will make it +more focused and deterministic.

    • +
    • prompt (str, optional) – You can specify the prompt which will be used to prompt the model. +Use placeholders to indicate where the class labels and the +data should be placed in the prompt.

    • +
    • max_new_tokens (int, default 256) – Maximum number of tokens the model should generate.

    • +
    • max_requests_per_minute (int, default 3500) – Maximum number of requests to send per minute.

    • +
    • max_tokens_per_minute (int, default 90_000) – Maximum number of tokens per minute.

    • +
    • max_attempts_per_request (int, default 5) – Maximum number of times a request should be attempted if it fails +for the first time.

    • +
    • fuzzy_match (bool, default True) – Indicates whether the output labels should be fuzzy matched +to the learnt class labels. +This is useful when the model isn’t giving specific enough answers.

    • +
    +
    +
    +
    +
    +classes_#
    +

    Class names learned from the labels.

    +
    +
    Type:
    +

    array of str

    +
    +
    +
    + +
    +
    +fit(X: Iterable[str], y: Iterable[str])#
    +

    Learns class labels.

    +
    +
    Parameters:
    +
      +
    • X (iterable of str) – Examples to pass into the few-shot prompt.

    • +
    • y (iterable of str) – Iterable of class labels. +Should at least contain a representative sample +of potential labels.

    • +
    +
    +
    Returns:
    +

    Fitted model.

    +
    +
    Return type:
    +

    self

    +
    +
    +
    + +
    +
    +partial_fit(X: Iterable[str], y: Iterable[str])#
    +

    Learns class labels. +Can learn new labels if new ones are encountered in the data.

    +
    +
    Parameters:
    +
      +
    • X (iterable of str) – Examples to pass into the few-shot prompt.

    • +
    • y (iterable of str) – Iterable of class labels.

    • +
    +
    +
    Returns:
    +

    Fitted model.

    +
    +
    Return type:
    +

    self

    +
    +
    +
    + +
    +
    +predict(X: Iterable[str]) ndarray#
    +

    Predicts most probable class label for given texts.

    +
    +
    Parameters:
    +

    X (iterable of str) – Texts to label.

    +
    +
    Returns:
    +

    Array of string class labels.

    +
    +
    Return type:
    +

    array of shape (n_texts)

    +
    +
    +
    + +
    +
    +set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') OpenAIFewShotClassifier#
    +

    Request metadata passed to the score method.

    +

    Note that this method is only relevant if +enable_metadata_routing=True (see sklearn.set_config()). +Please see User Guide on how the routing +mechanism works.

    +

    The options for each parameter are:

    +
      +
    • True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.

    • +
    • False: metadata is not requested and the meta-estimator will not pass it to score.

    • +
    • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

    • +
    • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

    • +
    +

    The default (sklearn.utils.metadata_routing.UNCHANGED) retains the +existing request. This allows you to change the request for some +parameters and not others.

    +
    +

    New in version 1.3.

    +
    +
    +

    Note

    +

    This method is only relevant if this estimator is used as a +sub-estimator of a meta-estimator, e.g. used inside a +pipeline.Pipeline. Otherwise it has no effect.

    +
    +
    +
    Parameters:
    +

    sample_weight (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for sample_weight parameter in score.

    +
    +
    Returns:
    +

    self – The updated object.

    +
    +
    Return type:
    +

    object

    +
    +
    +
    + +
    + +
    +
    + +
    +
    + +
    + +
    +
    + + + + + \ No newline at end of file diff --git a/docs/_build/html/prompting.html b/docs/_build/html/prompting.html index 9fc1849..03f26e7 100644 --- a/docs/_build/html/prompting.html +++ b/docs/_build/html/prompting.html @@ -3,9 +3,9 @@ - + - + Prompting - stormtrooper @@ -170,6 +170,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • @@ -211,7 +212,7 @@

    Prompting#

    -

    Text2Text and Generative models use a prompting approach for classification. +

    Text2Text, Generative, and OpenAI models use a prompting approach for classification. stormtrooper comes with default prompts, but these might not suit the model you want to use, or your use case might require a different prompting strategy from the default. stormtrooper allows you to specify custom prompts in these cases.

    @@ -219,7 +220,7 @@

    Prompting#

    Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to insert labels and data.

    -

    A zero-shot prompt for an instruct Llama model like Stable Beluga would look something like this (this is the default):

    +

    A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default):

    prompt = """
     ### System:
     You are a classification model that is really good at following
    @@ -270,14 +271,14 @@ 

    Templates - +
    Previous
    -
    SetFit models
    +
    OpenAI models
    diff --git a/docs/_build/html/search.html b/docs/_build/html/search.html index 901e128..498e59b 100644 --- a/docs/_build/html/search.html +++ b/docs/_build/html/search.html @@ -4,7 +4,7 @@ - Search - stormtrooper + Search - stormtrooper @@ -167,6 +167,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • diff --git a/docs/_build/html/searchindex.js b/docs/_build/html/searchindex.js index e40cfeb..6308284 100644 --- a/docs/_build/html/searchindex.js +++ b/docs/_build/html/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["generative", "index", "inference_on_gpu", "prompting", "setfit", "text2text", "zeroshot"], "filenames": ["generative.rst", "index.rst", "inference_on_gpu.rst", "prompting.rst", "setfit.rst", "text2text.rst", "zeroshot.rst"], "titles": ["Generative models", "Getting Started", "Inference on GPU", "Prompting", "SetFit models", "Text2Text models", "Zero-shot models"], "terms": {"stormtroop": [0, 1, 2, 3, 4, 5, 6], "also": 0, "support": 0, "fulli": [0, 1], "architectur": 0, "both": 0, "few": [0, 1, 3, 4, 5], "shot": [0, 1, 3, 4, 5], "zero": [0, 1, 3, 4, 5], "learn": [0, 1, 3, 4, 5, 6], "It": [0, 4, 5, 6], "": [0, 1, 3, 4, 5, 6], "worth": 0, "note": [0, 4, 5, 6], "most": [0, 4, 5, 6], "instruct": [0, 3, 5], "finetun": 0, "ar": [0, 1, 3, 4, 5, 6], "quit": 0, "hefti": 0, "take": 0, "lot": 0, "resourc": 0, "run": [0, 2, 4, 5, 6], "we": [0, 4], "recommend": [0, 4], "you": [0, 1, 2, 3, 4, 5, 6], "exhaust": 0, "all": [0, 1, 4], "other": [0, 4, 5, 6], "option": [0, 1, 4, 5, 6], "befor": [0, 2], "turn": 0, "from": [0, 1, 2, 3, 4, 5, 6], "import": [0, 1, 2, 4, 5, 6], "generativezeroshotclassifi": [0, 3], "generativefewshotclassifi": 0, "sample_text": [0, 4, 5, 6], "i": [0, 1, 3, 4, 5, 6], "elector": [0, 4, 5, 6], "colleg": [0, 4, 5, 6], "respons": [0, 4, 5, 6], "elect": [0, 4, 5, 6], "presid": [0, 4, 5, 6], "label": [0, 3, 4, 5, 6], "polit": [0, 4, 5, 6], "scienc": [0, 4, 5, 6], "here": [0, 3, 4, 5], "exampl": [0, 1, 3, 4, 5], "stabilityai": [0, 3], "stablebeluga": [0, 3], "13b": [0, 3], "fit": [0, 1, 3, 4, 5, 6], "none": [0, 1, 4, 5, 6], "predict": [0, 1, 4, 5, 6], "assert": [0, 4, 5, 6], "list": [0, 4, 5, 6], "And": [0, 4, 5], "few_shot_exampl": [0, 4, 5], "joe": [0, 4, 5], "biden": [0, 4, 5], "liquid": [0, 4, 5], "water": [0, 4, 5], "wa": [0, 1, 4, 
5], "found": [0, 4, 5], "moon": [0, 4, 5], "jerri": [0, 4, 5], "like": [0, 3, 4, 5], "footbal": [0, 4, 5], "class": [0, 3, 4, 5, 6], "model_nam": [0, 4, 5, 6], "str": [0, 4, 5, 6], "7b": 0, "prompt": [0, 1, 4, 5], "n": [0, 5], "system": [0, 3], "nyou": 0, "classif": [0, 1, 3, 4, 5, 6], "realli": [0, 3], "good": [0, 3], "follow": [0, 3], "ninstruct": 0, "produc": [0, 3], "brief": [0, 3], "answer": [0, 3, 5], "nthat": 0, "user": [0, 3, 4, 5, 6], "can": [0, 1, 2, 3, 4, 5, 6], "us": [0, 1, 3, 4, 5, 6], "data": [0, 3, 4, 5, 6], "right": [0, 3], "awai": [0, 3], "npleas": 0, "precis": [0, 3], "nyour": 0, "task": [0, 3], "classifi": [0, 3, 5], "text": [0, 3, 4, 5, 6], "document": [0, 3], "one": [0, 3, 5], "nof": 0, "respond": [0, 3, 5], "singl": [0, 3], "think": [0, 3], "nthe": 0, "best": [0, 3], "nclassifi": 0, "piec": [0, 3, 5], "x": [0, 3, 4, 5, 6], "assist": [0, 3], "max_new_token": [0, 5], "int": [0, 4, 5], "256": [0, 5], "fuzzy_match": [0, 5], "bool": [0, 4, 5, 6], "true": [0, 4, 5, 6], "progress_bar": [0, 5, 6], "devic": [0, 2, 4, 5, 6], "cpu": [0, 2, 4, 5, 6], "scikit": [0, 1, 4, 5, 6], "compat": [0, 1, 4, 5, 6], "languag": [0, 5], "paramet": [0, 4, 5, 6], "default": [0, 2, 3, 4, 5, 6], "huggingfac": [0, 1, 4, 5, 6], "specifi": [0, 2, 3, 5], "which": [0, 4, 5, 6], "placehold": [0, 5], "indic": [0, 4, 5, 6], "where": [0, 3, 5], "should": [0, 4, 5, 6], "place": [0, 5], "maximum": [0, 5], "number": [0, 4, 5], "token": [0, 5], "whether": [0, 2, 5, 6], "output": [0, 5, 6], "labl": [0, 5], "fuzzi": [0, 5], "match": [0, 5], "learnt": [0, 5], "thi": [0, 1, 2, 3, 4, 5, 6], "when": [0, 2, 5], "isn": [0, 5], "t": [0, 5], "give": [0, 3, 5], "specif": [0, 5, 6], "enough": [0, 5], "progress": [0, 5, 6], "bar": [0, 5, 6], "shown": [0, 5, 6], "classes_": [0, 4, 5, 6], "name": [0, 4, 5, 6], "type": [0, 4, 5, 6], "arrai": [0, 4, 5, 6], "y": [0, 4, 5, 6], "iter": [0, 4, 5, 6], "ani": [0, 3, 4, 5, 6], "ignor": [0, 4, 5, 6], "least": [0, 4, 5, 6], "contain": [0, 4, 5, 6], "repres": [0, 
3, 4, 5, 6], "sampl": [0, 4, 5, 6], "potenti": [0, 4, 5, 6], "return": [0, 4, 5, 6], "self": [0, 4, 5, 6], "partial_fit": [0, 4, 5, 6], "new": [0, 1, 4, 5, 6], "encount": [0, 4, 5, 6], "ndarrai": [0, 4, 5, 6], "probabl": [0, 4, 5, 6], "given": [0, 4, 5, 6], "string": [0, 4, 5, 6], "shape": [0, 4, 5, 6], "n_text": [0, 4, 5, 6], "set_score_request": [0, 4, 5, 6], "sample_weight": [0, 4, 5, 6], "unchang": [0, 4, 5, 6], "request": [0, 4, 5, 6], "metadata": [0, 4, 5, 6], "pass": [0, 4, 5, 6], "score": [0, 4, 5, 6], "method": [0, 3, 4, 5, 6], "onli": [0, 3, 4, 5, 6], "relev": [0, 4, 5, 6], "enable_metadata_rout": [0, 4, 5, 6], "see": [0, 4, 5, 6], "sklearn": [0, 4, 5, 6], "set_config": [0, 4, 5, 6], "pleas": [0, 3, 4, 5, 6], "guid": [0, 4, 5, 6], "how": [0, 4, 5, 6], "rout": [0, 4, 5, 6], "mechan": [0, 4, 5, 6], "work": [0, 4, 5, 6], "The": [0, 4, 5, 6], "each": [0, 4, 5, 6], "provid": [0, 3, 4, 5, 6], "fals": [0, 4, 5, 6], "meta": [0, 4, 5, 6], "estim": [0, 4, 5, 6], "rais": [0, 4, 5, 6], "an": [0, 3, 4, 5, 6], "error": [0, 4, 5, 6], "alia": [0, 4, 5, 6], "instead": [0, 4, 5, 6], "origin": [0, 4, 5, 6], "util": [0, 4, 5, 6], "metadata_rout": [0, 4, 5, 6], "retain": [0, 4, 5, 6], "exist": [0, 4, 5, 6], "allow": [0, 3, 4, 5, 6], "chang": [0, 4, 5, 6], "some": [0, 3, 4, 5, 6], "version": [0, 4, 5, 6], "1": [0, 4, 5, 6], "3": [0, 4, 5, 6], "sub": [0, 4, 5, 6], "e": [0, 4, 5, 6], "g": [0, 4, 5, 6], "insid": [0, 4, 5, 6], "pipelin": [0, 1, 4, 5, 6], "otherwis": [0, 4, 5, 6], "ha": [0, 4, 5, 6], "effect": [0, 4, 5, 6], "updat": [0, 4, 5, 6], "object": [0, 4, 5, 6], "examples_": [0, 5], "dict": [0, 5], "run_prompt": 0, "result": 0, "generate_prompt": [0, 5], "base": [0, 1, 5], "lightweight": 1, "python": 1, "librari": 1, "transform": [1, 4, 6], "model": [1, 2, 3], "compon": 1, "therebi": 1, "make": [1, 2], "easier": 1, "integr": 1, "them": 1, "your": [1, 3], "workflow": 1, "pypi": 1, "pip": [1, 4], "torch": [1, 2], "If": 1, "intend": 1, "setfit": 1, "well": 1, "depend": [1, 4], 
"To": 1, "load": [1, 6], "hub": [1, 4], "In": [1, 6], "am": 1, "go": 1, "googl": [1, 5], "flan": [1, 5], "t5": [1, 3, 5], "text2textzeroshotclassifi": [1, 2, 5], "class_label": 1, "atheism": 1, "christian": 1, "astronomi": 1, "space": 1, "example_text": 1, "god": 1, "came": 1, "down": 1, "earth": 1, "save": 1, "u": 1, "A": [1, 3], "nebula": 1, "recent": 1, "discov": 1, "proxim": 1, "oort": 1, "cloud": 1, "text2text": [1, 3], "gener": [1, 3, 4, 5], "infer": [1, 4], "gpu": 1, "github": 1, "repositori": 1, "influenc": 2, "behaviour": 2, "initi": 2, "sure": 2, "check": 2, "have": 2, "cuda": 2, "avail": 2, "try": 2, "print": 2, "is_avail": 2, "0": [2, 6], "approach": [3, 4], "come": 3, "might": 3, "suit": 3, "want": 3, "case": 3, "requir": [3, 4], "differ": 3, "strategi": 3, "custom": 3, "format": [3, 5], "call": 3, "insert": 3, "llama": 3, "stabl": 3, "beluga": 3, "would": 3, "look": 3, "someth": 3, "current": 3, "question": 3, "while": 3, "let": 3, "sai": 3, "fewshot_prompt": 3, "same": [3, 5], "expert": 3, "stage": 3, "emit": 3, "variabl": 3, "oper": 3, "definit": 3, "train": 4, "effici": 4, "sentenc": 4, "free": 4, "need": 4, "wai": 4, "smaller": 4, "thu": 4, "faster": 4, "more": 4, "employ": 4, "high": 4, "perform": 4, "set": [4, 6], "sinc": 4, "packag": 4, "instal": 4, "its": 4, "setfitzeroshotclassifi": 4, "setfitfewshotclassifi": 4, "minilm": 4, "l6": 4, "v2": 4, "sample_s": 4, "8": 4, "easili": 5, "emploi": 5, "By": 5, "text2textfewshotclassifi": 5, "ni": 5, "na": 5, "nwith": 5, "seq2seq": 5, "design": 6, "tune": 6, "These": 6, "extens": 6, "function": 6, "includ": 6, "certainti": 6, "zeroshotclassifi": 6, "facebook": 6, "bart": 6, "larg": 6, "mnli": 6, "set_output": 6, "panda": 6, "924671": 6, "006629": 6, "0687": 6, "predict_proba": 6, "n_class": 6, "datafram": 6, "matrix": 6, "disabl": 6}, "objects": {"stormtrooper": [[0, 0, 1, "", "GenerativeFewShotClassifier"], [0, 0, 1, "", "GenerativeZeroShotClassifier"], [4, 0, 1, "", "SetFitFewShotClassifier"], [4, 0, 
1, "", "SetFitZeroShotClassifier"], [5, 0, 1, "", "Text2TextFewShotClassifier"], [5, 0, 1, "", "Text2TextZeroShotClassifier"], [6, 0, 1, "", "ZeroShotClassifier"]], "stormtrooper.GenerativeFewShotClassifier": [[0, 1, 1, "", "classes_"], [0, 1, 1, "", "examples_"], [0, 2, 1, "", "fit"], [0, 2, 1, "", "generate_prompt"], [0, 2, 1, "", "partial_fit"], [0, 2, 1, "", "predict"], [0, 2, 1, "", "run_prompt"], [0, 2, 1, "", "set_score_request"]], "stormtrooper.GenerativeZeroShotClassifier": [[0, 1, 1, "", "classes_"], [0, 2, 1, "", "fit"], [0, 2, 1, "", "partial_fit"], [0, 2, 1, "", "predict"], [0, 2, 1, "", "set_score_request"]], "stormtrooper.SetFitFewShotClassifier": [[4, 1, 1, "", "classes_"], [4, 2, 1, "", "fit"], [4, 2, 1, "", "predict"], [4, 2, 1, "", "set_score_request"]], "stormtrooper.SetFitZeroShotClassifier": [[4, 1, 1, "", "classes_"], [4, 2, 1, "", "fit"], [4, 2, 1, "", "partial_fit"], [4, 2, 1, "", "predict"], [4, 2, 1, "", "set_score_request"]], "stormtrooper.Text2TextFewShotClassifier": [[5, 1, 1, "", "classes_"], [5, 1, 1, "", "examples_"], [5, 2, 1, "", "fit"], [5, 2, 1, "", "generate_prompt"], [5, 2, 1, "", "partial_fit"], [5, 2, 1, "", "predict"], [5, 2, 1, "", "set_score_request"]], "stormtrooper.Text2TextZeroShotClassifier": [[5, 1, 1, "", "classes_"], [5, 2, 1, "", "fit"], [5, 2, 1, "", "partial_fit"], [5, 2, 1, "", "predict"], [5, 2, 1, "", "set_score_request"]], "stormtrooper.ZeroShotClassifier": [[6, 1, 1, "", "classes_"], [6, 2, 1, "", "fit"], [6, 2, 1, "", "partial_fit"], [6, 2, 1, "", "predict"], [6, 2, 1, "", "predict_proba"], [6, 2, 1, "", "set_output"], [6, 2, 1, "", "set_score_request"], [6, 2, 1, "", "transform"]]}, "objtypes": {"0": "py:class", "1": "py:attribute", "2": "py:method"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "attribute", "Python attribute"], "2": ["py", "method", "Python method"]}, "titleterms": {"gener": 0, "model": [0, 4, 5, 6], "api": [0, 4, 5, 6], "refer": [0, 4, 5, 6], "get": 1, "start": 1, 
"instal": 1, "usag": 1, "user": 1, "guid": 1, "infer": 2, "gpu": 2, "prompt": 3, "templat": 3, "setfit": 4, "text2text": 5, "zero": 6, "shot": 6}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 60}, "alltitles": {"Generative models": [[0, "generative-models"]], "API reference": [[0, "api-reference"], [4, "api-reference"], [5, "api-reference"], [6, "api-reference"]], "Getting Started": [[1, "getting-started"]], "Installation": [[1, "installation"]], "Usage": [[1, "usage"]], "User guide": [[1, null]], "Inference on GPU": [[2, "inference-on-gpu"]], "Prompting": [[3, "prompting"]], "Templates": [[3, "templates"]], "SetFit models": [[4, "setfit-models"]], "Text2Text models": [[5, "text2text-models"]], "Zero-shot models": [[6, "zero-shot-models"]]}, "indexentries": {"generativefewshotclassifier (class in stormtrooper)": [[0, "stormtrooper.GenerativeFewShotClassifier"]], "generativezeroshotclassifier (class in stormtrooper)": [[0, "stormtrooper.GenerativeZeroShotClassifier"]], "classes_ (stormtrooper.generativefewshotclassifier attribute)": [[0, "stormtrooper.GenerativeFewShotClassifier.classes_"]], "classes_ (stormtrooper.generativezeroshotclassifier attribute)": [[0, "stormtrooper.GenerativeZeroShotClassifier.classes_"]], "examples_ (stormtrooper.generativefewshotclassifier attribute)": [[0, "stormtrooper.GenerativeFewShotClassifier.examples_"]], "fit() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.fit"]], "fit() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.fit"]], "generate_prompt() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.generate_prompt"]], "partial_fit() 
(stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.partial_fit"]], "partial_fit() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.predict"]], "predict() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.predict"]], "run_prompt() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.run_prompt"]], "set_score_request() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.set_score_request"]], "setfitfewshotclassifier (class in stormtrooper)": [[4, "stormtrooper.SetFitFewShotClassifier"]], "setfitzeroshotclassifier (class in stormtrooper)": [[4, "stormtrooper.SetFitZeroShotClassifier"]], "classes_ (stormtrooper.setfitfewshotclassifier attribute)": [[4, "stormtrooper.SetFitFewShotClassifier.classes_"]], "classes_ (stormtrooper.setfitzeroshotclassifier attribute)": [[4, "stormtrooper.SetFitZeroShotClassifier.classes_"]], "fit() (stormtrooper.setfitfewshotclassifier method)": [[4, "stormtrooper.SetFitFewShotClassifier.fit"]], "fit() (stormtrooper.setfitzeroshotclassifier method)": [[4, "stormtrooper.SetFitZeroShotClassifier.fit"]], "partial_fit() (stormtrooper.setfitzeroshotclassifier method)": [[4, "stormtrooper.SetFitZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.setfitfewshotclassifier method)": [[4, "stormtrooper.SetFitFewShotClassifier.predict"]], "predict() (stormtrooper.setfitzeroshotclassifier method)": [[4, "stormtrooper.SetFitZeroShotClassifier.predict"]], "set_score_request() (stormtrooper.setfitfewshotclassifier method)": [[4, 
"stormtrooper.SetFitFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.setfitzeroshotclassifier method)": [[4, "stormtrooper.SetFitZeroShotClassifier.set_score_request"]], "text2textfewshotclassifier (class in stormtrooper)": [[5, "stormtrooper.Text2TextFewShotClassifier"]], "text2textzeroshotclassifier (class in stormtrooper)": [[5, "stormtrooper.Text2TextZeroShotClassifier"]], "classes_ (stormtrooper.text2textfewshotclassifier attribute)": [[5, "stormtrooper.Text2TextFewShotClassifier.classes_"]], "classes_ (stormtrooper.text2textzeroshotclassifier attribute)": [[5, "stormtrooper.Text2TextZeroShotClassifier.classes_"]], "examples_ (stormtrooper.text2textfewshotclassifier attribute)": [[5, "stormtrooper.Text2TextFewShotClassifier.examples_"]], "fit() (stormtrooper.text2textfewshotclassifier method)": [[5, "stormtrooper.Text2TextFewShotClassifier.fit"]], "fit() (stormtrooper.text2textzeroshotclassifier method)": [[5, "stormtrooper.Text2TextZeroShotClassifier.fit"]], "generate_prompt() (stormtrooper.text2textfewshotclassifier method)": [[5, "stormtrooper.Text2TextFewShotClassifier.generate_prompt"]], "partial_fit() (stormtrooper.text2textfewshotclassifier method)": [[5, "stormtrooper.Text2TextFewShotClassifier.partial_fit"]], "partial_fit() (stormtrooper.text2textzeroshotclassifier method)": [[5, "stormtrooper.Text2TextZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.text2textfewshotclassifier method)": [[5, "stormtrooper.Text2TextFewShotClassifier.predict"]], "predict() (stormtrooper.text2textzeroshotclassifier method)": [[5, "stormtrooper.Text2TextZeroShotClassifier.predict"]], "set_score_request() (stormtrooper.text2textfewshotclassifier method)": [[5, "stormtrooper.Text2TextFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.text2textzeroshotclassifier method)": [[5, "stormtrooper.Text2TextZeroShotClassifier.set_score_request"]], "zeroshotclassifier (class in stormtrooper)": [[6, 
"stormtrooper.ZeroShotClassifier"]], "classes_ (stormtrooper.zeroshotclassifier attribute)": [[6, "stormtrooper.ZeroShotClassifier.classes_"]], "fit() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.fit"]], "partial_fit() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.predict"]], "predict_proba() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.predict_proba"]], "set_output() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.set_output"]], "set_score_request() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.set_score_request"]], "transform() (stormtrooper.zeroshotclassifier method)": [[6, "stormtrooper.ZeroShotClassifier.transform"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["generative", "index", "inference_on_gpu", "openai", "prompting", "setfit", "text2text", "zeroshot"], "filenames": ["generative.rst", "index.rst", "inference_on_gpu.rst", "openai.rst", "prompting.rst", "setfit.rst", "text2text.rst", "zeroshot.rst"], "titles": ["Generative models", "Getting Started", "Inference on GPU", "OpenAI models", "Prompting", "SetFit models", "Text2Text models", "Zero-shot models"], "terms": {"stormtroop": [0, 1, 2, 3, 4, 5, 6, 7], "also": [0, 1, 3], "support": 0, "fulli": [0, 1], "architectur": 0, "both": 0, "few": [0, 1, 3, 4, 5, 6], "shot": [0, 1, 3, 4, 5, 6], "zero": [0, 1, 3, 4, 5, 6], "learn": [0, 1, 3, 4, 5, 6, 7], "It": [0, 3, 5, 6, 7], "": [0, 1, 3, 4, 5, 6, 7], "worth": 0, "note": [0, 3, 5, 6, 7], "most": [0, 3, 5, 6, 7], "instruct": [0, 3, 4, 6], "finetun": 0, "ar": [0, 1, 3, 4, 5, 6, 7], "quit": 0, "hefti": 0, "take": 0, "lot": 0, "resourc": 0, "run": [0, 2, 5, 6, 7], "we": [0, 5], "recommend": [0, 5], "you": [0, 1, 2, 3, 4, 5, 6, 7], "exhaust": 0, "all": [0, 1, 5], "other": 
[0, 3, 5, 6, 7], "option": [0, 1, 3, 5, 6, 7], "befor": [0, 2], "turn": 0, "from": [0, 1, 2, 3, 4, 5, 6, 7], "import": [0, 1, 2, 3, 5, 6, 7], "generativezeroshotclassifi": [0, 4], "generativefewshotclassifi": 0, "sample_text": [0, 3, 5, 6, 7], "i": [0, 1, 3, 4, 5, 6, 7], "elector": [0, 3, 5, 6, 7], "colleg": [0, 3, 5, 6, 7], "respons": [0, 3, 5, 6, 7], "elect": [0, 3, 5, 6, 7], "presid": [0, 3, 5, 6, 7], "label": [0, 3, 4, 5, 6, 7], "polit": [0, 3, 5, 6, 7], "scienc": [0, 3, 5, 6, 7], "here": [0, 3, 4, 5, 6], "exampl": [0, 1, 3, 4, 5, 6], "stabilityai": [0, 4], "stablebeluga": [0, 3, 4], "13b": [0, 4], "fit": [0, 1, 3, 4, 5, 6, 7], "none": [0, 1, 3, 5, 6, 7], "predict": [0, 1, 3, 5, 6, 7], "assert": [0, 3, 5, 6, 7], "list": [0, 3, 5, 6, 7], "And": [0, 3, 5, 6], "few_shot_exampl": [0, 3, 5, 6], "joe": [0, 3, 5, 6], "biden": [0, 3, 5, 6], "liquid": [0, 3, 5, 6], "water": [0, 3, 5, 6], "wa": [0, 1, 3, 5, 6], "found": [0, 3, 5, 6], "moon": [0, 3, 5, 6], "jerri": [0, 3, 5, 6], "like": [0, 3, 4, 5, 6], "footbal": [0, 3, 5, 6], "class": [0, 3, 4, 5, 6, 7], "model_nam": [0, 3, 5, 6, 7], "str": [0, 3, 5, 6, 7], "7b": 0, "prompt": [0, 1, 3, 5, 6], "n": [0, 3, 6], "system": [0, 3, 4], "nyou": [0, 3], "classif": [0, 1, 3, 4, 5, 6, 7], "realli": [0, 3, 4], "good": [0, 3, 4], "follow": [0, 3, 4], "ninstruct": [0, 3], "produc": [0, 3, 4], "brief": [0, 3, 4], "answer": [0, 3, 4, 6], "nthat": [0, 3], "user": [0, 3, 4, 5, 6, 7], "can": [0, 1, 2, 3, 4, 5, 6, 7], "us": [0, 1, 3, 4, 5, 6, 7], "data": [0, 3, 4, 5, 6, 7], "right": [0, 3, 4], "awai": [0, 3, 4], "npleas": [0, 3], "precis": [0, 3, 4], "nyour": [0, 3], "task": [0, 3, 4], "classifi": [0, 3, 4, 6], "text": [0, 3, 4, 5, 6, 7], "document": [0, 3, 4], "one": [0, 3, 4, 6], "nof": [0, 3], "respond": [0, 3, 4, 6], "singl": [0, 3, 4], "think": [0, 3, 4], "nthe": [0, 3], "best": [0, 3, 4], "nclassifi": [0, 3], "piec": [0, 3, 4, 6], "x": [0, 3, 4, 5, 6, 7], "assist": [0, 3, 4], "max_new_token": [0, 3, 6], "int": [0, 3, 5, 6], "256": 
[0, 3, 6], "fuzzy_match": [0, 3, 6], "bool": [0, 3, 5, 6, 7], "true": [0, 3, 5, 6, 7], "progress_bar": [0, 6, 7], "devic": [0, 2, 5, 6, 7], "cpu": [0, 2, 5, 6, 7], "scikit": [0, 1, 3, 5, 6, 7], "compat": [0, 1, 3, 5, 6, 7], "languag": [0, 3, 6], "paramet": [0, 3, 5, 6, 7], "default": [0, 2, 3, 4, 5, 6, 7], "huggingfac": [0, 1, 5, 6, 7], "specifi": [0, 2, 3, 4, 6], "which": [0, 3, 5, 6, 7], "placehold": [0, 3, 6], "indic": [0, 3, 5, 6, 7], "where": [0, 3, 4, 6], "should": [0, 3, 5, 6, 7], "place": [0, 3, 6], "maximum": [0, 3, 6], "number": [0, 3, 5, 6], "token": [0, 3, 6], "whether": [0, 2, 3, 6, 7], "output": [0, 3, 6, 7], "labl": [0, 3, 6], "fuzzi": [0, 3, 6], "match": [0, 3, 6], "learnt": [0, 3, 6], "thi": [0, 1, 2, 3, 4, 5, 6, 7], "when": [0, 2, 3, 6], "isn": [0, 3, 6], "t": [0, 3, 6], "give": [0, 3, 4, 6], "specif": [0, 3, 6, 7], "enough": [0, 3, 6], "progress": [0, 6, 7], "bar": [0, 6, 7], "shown": [0, 6, 7], "classes_": [0, 3, 5, 6, 7], "name": [0, 3, 5, 6, 7], "type": [0, 3, 5, 6, 7], "arrai": [0, 3, 5, 6, 7], "y": [0, 3, 5, 6, 7], "iter": [0, 3, 5, 6, 7], "ani": [0, 3, 4, 5, 6, 7], "ignor": [0, 3, 5, 6, 7], "least": [0, 3, 5, 6, 7], "contain": [0, 3, 5, 6, 7], "repres": [0, 3, 4, 5, 6, 7], "sampl": [0, 3, 5, 6, 7], "potenti": [0, 3, 5, 6, 7], "return": [0, 3, 5, 6, 7], "self": [0, 3, 5, 6, 7], "partial_fit": [0, 3, 5, 6, 7], "new": [0, 1, 3, 5, 6, 7], "encount": [0, 3, 5, 6, 7], "ndarrai": [0, 3, 5, 6, 7], "probabl": [0, 3, 5, 6, 7], "given": [0, 3, 5, 6, 7], "string": [0, 3, 5, 6, 7], "shape": [0, 3, 5, 6, 7], "n_text": [0, 3, 5, 6, 7], "set_score_request": [0, 3, 5, 6, 7], "sample_weight": [0, 3, 5, 6, 7], "unchang": [0, 3, 5, 6, 7], "request": [0, 3, 5, 6, 7], "metadata": [0, 3, 5, 6, 7], "pass": [0, 3, 5, 6, 7], "score": [0, 3, 5, 6, 7], "method": [0, 3, 4, 5, 6, 7], "onli": [0, 3, 4, 5, 6, 7], "relev": [0, 3, 5, 6, 7], "enable_metadata_rout": [0, 3, 5, 6, 7], "see": [0, 3, 5, 6, 7], "sklearn": [0, 3, 5, 6, 7], "set_config": [0, 3, 5, 6, 7], "pleas": 
[0, 3, 4, 5, 6, 7], "guid": [0, 3, 5, 6, 7], "how": [0, 3, 5, 6, 7], "rout": [0, 3, 5, 6, 7], "mechan": [0, 3, 5, 6, 7], "work": [0, 3, 5, 6, 7], "The": [0, 3, 5, 6, 7], "each": [0, 3, 5, 6, 7], "provid": [0, 3, 4, 5, 6, 7], "fals": [0, 3, 5, 6, 7], "meta": [0, 3, 5, 6, 7], "estim": [0, 3, 5, 6, 7], "rais": [0, 3, 5, 6, 7], "an": [0, 3, 4, 5, 6, 7], "error": [0, 3, 5, 6, 7], "alia": [0, 3, 5, 6, 7], "instead": [0, 3, 5, 6, 7], "origin": [0, 3, 5, 6, 7], "util": [0, 3, 5, 6, 7], "metadata_rout": [0, 3, 5, 6, 7], "retain": [0, 3, 5, 6, 7], "exist": [0, 3, 5, 6, 7], "allow": [0, 3, 4, 5, 6, 7], "chang": [0, 3, 5, 6, 7], "some": [0, 3, 4, 5, 6, 7], "version": [0, 1, 3, 5, 6, 7], "1": [0, 3, 5, 6, 7], "3": [0, 3, 5, 6, 7], "sub": [0, 3, 5, 6, 7], "e": [0, 3, 5, 6, 7], "g": [0, 3, 5, 6, 7], "insid": [0, 3, 5, 6, 7], "pipelin": [0, 1, 3, 5, 6, 7], "otherwis": [0, 3, 5, 6, 7], "ha": [0, 3, 5, 6, 7], "effect": [0, 3, 5, 6, 7], "updat": [0, 3, 5, 6, 7], "object": [0, 3, 5, 6, 7], "examples_": [0, 6], "dict": [0, 6], "run_prompt": 0, "result": 0, "generate_prompt": [0, 6], "base": [0, 1, 6], "lightweight": 1, "python": [1, 3], "librari": 1, "transform": [1, 5, 7], "model": [1, 2, 4], "compon": 1, "therebi": 1, "make": [1, 2, 3], "easier": 1, "integr": 1, "them": 1, "your": [1, 3, 4], "workflow": 1, "pypi": 1, "pip": [1, 3, 5], "torch": [1, 2], "If": 1, "intend": 1, "setfit": 1, "well": 1, "depend": [1, 3, 5], "0": [1, 2, 3, 7], "4": [1, 3], "openai": [1, 4], "export": [1, 3], "openai_api_kei": [1, 3], "sk": [1, 3], "To": 1, "load": [1, 7], "hub": [1, 5], "In": [1, 3, 7], "am": 1, "go": 1, "googl": [1, 6], "flan": [1, 6], "t5": [1, 4, 6], "text2textzeroshotclassifi": [1, 2, 6], "class_label": 1, "atheism": 1, "christian": 1, "astronomi": 1, "space": 1, "example_text": 1, "god": 1, "came": 1, "down": 1, "earth": 1, "save": 1, "u": 1, "A": [1, 4], "nebula": 1, "recent": 1, "discov": 1, "proxim": 1, "oort": 1, "cloud": 1, "text2text": [1, 4], "gener": [1, 3, 4, 5, 6], "infer": 
[1, 5], "gpu": 1, "github": 1, "repositori": 1, "influenc": 2, "behaviour": 2, "initi": 2, "sure": 2, "check": 2, "have": 2, "cuda": 2, "avail": 2, "try": 2, "print": 2, "is_avail": 2, "access": 3, "chat": 3, "get": 3, "full": 3, "control": 3, "over": 3, "temperatur": 3, "set": [3, 5, 7], "contrast": 3, "packag": [3, 5], "llm": 3, "asyncio": 3, "concurr": 3, "interact": 3, "multipl": 3, "time": 3, "speedup": 3, "sever": 3, "upper": 3, "limit": 3, "per": 3, "minut": 3, "so": 3, "don": 3, "exce": 3, "quota": 3, "pai": 3, "tier": 3, "need": [3, 5], "instal": [3, 5], "addition": 3, "kei": 3, "environ": 3, "variabl": [3, 4], "organ": 3, "openai_org": 3, "org": 3, "openaizeroshotclassifi": 3, "openaifewshotclassifi": 3, "chatgpt": [3, 4], "5": 3, "gpt": 3, "turbo": 3, "2": 3, "format": [3, 4, 6], "same": [3, 4, 6], "doe": 3, "help": 3, "float": 3, "max_requests_per_minut": 3, "3500": 3, "max_tokens_per_minut": 3, "90000": 3, "max_attempts_per_request": 3, "what": 3, "between": 3, "higher": 3, "valu": 3, "8": [3, 5], "more": [3, 5], "random": 3, "while": [3, 4], "lower": 3, "focus": 3, "determinist": 3, "send": 3, "90_000": 3, "shoulb": 3, "attempt": 3, "fail": 3, "first": 3, "nhere": 3, "coupl": 3, "assign": 3, "expert": [3, 4], "approach": [4, 5], "come": 4, "might": 4, "suit": 4, "want": 4, "case": 4, "requir": [4, 5], "differ": 4, "strategi": 4, "custom": 4, "call": 4, "insert": 4, "llama": 4, "stabl": 4, "beluga": 4, "would": 4, "look": 4, "someth": 4, "current": 4, "question": 4, "let": 4, "sai": 4, "fewshot_prompt": 4, "stage": 4, "emit": 4, "oper": 4, "definit": 4, "train": 5, "effici": 5, "sentenc": 5, "free": 5, "wai": 5, "smaller": 5, "thu": 5, "faster": 5, "employ": 5, "high": 5, "perform": 5, "sinc": 5, "its": 5, "setfitzeroshotclassifi": 5, "setfitfewshotclassifi": 5, "minilm": 5, "l6": 5, "v2": 5, "sample_s": 5, "easili": 6, "emploi": 6, "By": 6, "text2textfewshotclassifi": 6, "ni": 6, "na": 6, "nwith": 6, "seq2seq": 6, "design": 7, "tune": 7, "These": 7, 
"extens": 7, "function": 7, "includ": 7, "certainti": 7, "zeroshotclassifi": 7, "facebook": 7, "bart": 7, "larg": 7, "mnli": 7, "set_output": 7, "panda": 7, "924671": 7, "006629": 7, "0687": 7, "predict_proba": 7, "n_class": 7, "datafram": 7, "matrix": 7, "disabl": 7}, "objects": {"stormtrooper": [[0, 0, 1, "", "GenerativeFewShotClassifier"], [0, 0, 1, "", "GenerativeZeroShotClassifier"], [3, 0, 1, "", "OpenAIFewShotClassifier"], [3, 0, 1, "", "OpenAIZeroShotClassifier"], [5, 0, 1, "", "SetFitFewShotClassifier"], [5, 0, 1, "", "SetFitZeroShotClassifier"], [6, 0, 1, "", "Text2TextFewShotClassifier"], [6, 0, 1, "", "Text2TextZeroShotClassifier"], [7, 0, 1, "", "ZeroShotClassifier"]], "stormtrooper.GenerativeFewShotClassifier": [[0, 1, 1, "", "classes_"], [0, 1, 1, "", "examples_"], [0, 2, 1, "", "fit"], [0, 2, 1, "", "generate_prompt"], [0, 2, 1, "", "partial_fit"], [0, 2, 1, "", "predict"], [0, 2, 1, "", "run_prompt"], [0, 2, 1, "", "set_score_request"]], "stormtrooper.GenerativeZeroShotClassifier": [[0, 1, 1, "", "classes_"], [0, 2, 1, "", "fit"], [0, 2, 1, "", "partial_fit"], [0, 2, 1, "", "predict"], [0, 2, 1, "", "set_score_request"]], "stormtrooper.OpenAIFewShotClassifier": [[3, 1, 1, "", "classes_"], [3, 2, 1, "", "fit"], [3, 2, 1, "", "partial_fit"], [3, 2, 1, "", "predict"], [3, 2, 1, "", "set_score_request"]], "stormtrooper.OpenAIZeroShotClassifier": [[3, 1, 1, "", "classes_"], [3, 2, 1, "", "fit"], [3, 2, 1, "", "partial_fit"], [3, 2, 1, "", "predict"], [3, 2, 1, "", "set_score_request"]], "stormtrooper.SetFitFewShotClassifier": [[5, 1, 1, "", "classes_"], [5, 2, 1, "", "fit"], [5, 2, 1, "", "predict"], [5, 2, 1, "", "set_score_request"]], "stormtrooper.SetFitZeroShotClassifier": [[5, 1, 1, "", "classes_"], [5, 2, 1, "", "fit"], [5, 2, 1, "", "partial_fit"], [5, 2, 1, "", "predict"], [5, 2, 1, "", "set_score_request"]], "stormtrooper.Text2TextFewShotClassifier": [[6, 1, 1, "", "classes_"], [6, 1, 1, "", "examples_"], [6, 2, 1, "", "fit"], [6, 2, 1, "", 
"generate_prompt"], [6, 2, 1, "", "partial_fit"], [6, 2, 1, "", "predict"], [6, 2, 1, "", "set_score_request"]], "stormtrooper.Text2TextZeroShotClassifier": [[6, 1, 1, "", "classes_"], [6, 2, 1, "", "fit"], [6, 2, 1, "", "partial_fit"], [6, 2, 1, "", "predict"], [6, 2, 1, "", "set_score_request"]], "stormtrooper.ZeroShotClassifier": [[7, 1, 1, "", "classes_"], [7, 2, 1, "", "fit"], [7, 2, 1, "", "partial_fit"], [7, 2, 1, "", "predict"], [7, 2, 1, "", "predict_proba"], [7, 2, 1, "", "set_output"], [7, 2, 1, "", "set_score_request"], [7, 2, 1, "", "transform"]]}, "objtypes": {"0": "py:class", "1": "py:attribute", "2": "py:method"}, "objnames": {"0": ["py", "class", "Python class"], "1": ["py", "attribute", "Python attribute"], "2": ["py", "method", "Python method"]}, "titleterms": {"gener": 0, "model": [0, 3, 5, 6, 7], "api": [0, 3, 5, 6, 7], "refer": [0, 3, 5, 6, 7], "get": 1, "start": 1, "instal": 1, "usag": 1, "user": 1, "guid": 1, "infer": 2, "gpu": 2, "openai": 3, "prompt": 4, "templat": 4, "setfit": 5, "text2text": 6, "zero": 7, "shot": 7}, "envversion": {"sphinx.domains.c": 3, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 9, "sphinx.domains.index": 1, "sphinx.domains.javascript": 3, "sphinx.domains.math": 2, "sphinx.domains.python": 4, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx": 60}, "alltitles": {"Generative models": [[0, "generative-models"]], "API reference": [[0, "api-reference"], [3, "api-reference"], [5, "api-reference"], [6, "api-reference"], [7, "api-reference"]], "Getting Started": [[1, "getting-started"]], "Installation": [[1, "installation"]], "Usage": [[1, "usage"]], "User guide": [[1, null]], "Inference on GPU": [[2, "inference-on-gpu"]], "OpenAI models": [[3, "openai-models"]], "Prompting": [[4, "prompting"]], "Templates": [[4, "templates"]], "SetFit models": [[5, "setfit-models"]], "Text2Text models": [[6, "text2text-models"]], "Zero-shot models": [[7, "zero-shot-models"]]}, "indexentries": 
{"generativefewshotclassifier (class in stormtrooper)": [[0, "stormtrooper.GenerativeFewShotClassifier"]], "generativezeroshotclassifier (class in stormtrooper)": [[0, "stormtrooper.GenerativeZeroShotClassifier"]], "classes_ (stormtrooper.generativefewshotclassifier attribute)": [[0, "stormtrooper.GenerativeFewShotClassifier.classes_"]], "classes_ (stormtrooper.generativezeroshotclassifier attribute)": [[0, "stormtrooper.GenerativeZeroShotClassifier.classes_"]], "examples_ (stormtrooper.generativefewshotclassifier attribute)": [[0, "stormtrooper.GenerativeFewShotClassifier.examples_"]], "fit() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.fit"]], "fit() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.fit"]], "generate_prompt() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.generate_prompt"]], "partial_fit() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.partial_fit"]], "partial_fit() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.predict"]], "predict() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.predict"]], "run_prompt() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.run_prompt"]], "set_score_request() (stormtrooper.generativefewshotclassifier method)": [[0, "stormtrooper.GenerativeFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.generativezeroshotclassifier method)": [[0, "stormtrooper.GenerativeZeroShotClassifier.set_score_request"]], "openaifewshotclassifier (class in stormtrooper)": [[3, "stormtrooper.OpenAIFewShotClassifier"]], 
"openaizeroshotclassifier (class in stormtrooper)": [[3, "stormtrooper.OpenAIZeroShotClassifier"]], "classes_ (stormtrooper.openaifewshotclassifier attribute)": [[3, "stormtrooper.OpenAIFewShotClassifier.classes_"]], "classes_ (stormtrooper.openaizeroshotclassifier attribute)": [[3, "stormtrooper.OpenAIZeroShotClassifier.classes_"]], "fit() (stormtrooper.openaifewshotclassifier method)": [[3, "stormtrooper.OpenAIFewShotClassifier.fit"]], "fit() (stormtrooper.openaizeroshotclassifier method)": [[3, "stormtrooper.OpenAIZeroShotClassifier.fit"]], "partial_fit() (stormtrooper.openaifewshotclassifier method)": [[3, "stormtrooper.OpenAIFewShotClassifier.partial_fit"]], "partial_fit() (stormtrooper.openaizeroshotclassifier method)": [[3, "stormtrooper.OpenAIZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.openaifewshotclassifier method)": [[3, "stormtrooper.OpenAIFewShotClassifier.predict"]], "predict() (stormtrooper.openaizeroshotclassifier method)": [[3, "stormtrooper.OpenAIZeroShotClassifier.predict"]], "set_score_request() (stormtrooper.openaifewshotclassifier method)": [[3, "stormtrooper.OpenAIFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.openaizeroshotclassifier method)": [[3, "stormtrooper.OpenAIZeroShotClassifier.set_score_request"]], "setfitfewshotclassifier (class in stormtrooper)": [[5, "stormtrooper.SetFitFewShotClassifier"]], "setfitzeroshotclassifier (class in stormtrooper)": [[5, "stormtrooper.SetFitZeroShotClassifier"]], "classes_ (stormtrooper.setfitfewshotclassifier attribute)": [[5, "stormtrooper.SetFitFewShotClassifier.classes_"]], "classes_ (stormtrooper.setfitzeroshotclassifier attribute)": [[5, "stormtrooper.SetFitZeroShotClassifier.classes_"]], "fit() (stormtrooper.setfitfewshotclassifier method)": [[5, "stormtrooper.SetFitFewShotClassifier.fit"]], "fit() (stormtrooper.setfitzeroshotclassifier method)": [[5, "stormtrooper.SetFitZeroShotClassifier.fit"]], "partial_fit() (stormtrooper.setfitzeroshotclassifier 
method)": [[5, "stormtrooper.SetFitZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.setfitfewshotclassifier method)": [[5, "stormtrooper.SetFitFewShotClassifier.predict"]], "predict() (stormtrooper.setfitzeroshotclassifier method)": [[5, "stormtrooper.SetFitZeroShotClassifier.predict"]], "set_score_request() (stormtrooper.setfitfewshotclassifier method)": [[5, "stormtrooper.SetFitFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.setfitzeroshotclassifier method)": [[5, "stormtrooper.SetFitZeroShotClassifier.set_score_request"]], "text2textfewshotclassifier (class in stormtrooper)": [[6, "stormtrooper.Text2TextFewShotClassifier"]], "text2textzeroshotclassifier (class in stormtrooper)": [[6, "stormtrooper.Text2TextZeroShotClassifier"]], "classes_ (stormtrooper.text2textfewshotclassifier attribute)": [[6, "stormtrooper.Text2TextFewShotClassifier.classes_"]], "classes_ (stormtrooper.text2textzeroshotclassifier attribute)": [[6, "stormtrooper.Text2TextZeroShotClassifier.classes_"]], "examples_ (stormtrooper.text2textfewshotclassifier attribute)": [[6, "stormtrooper.Text2TextFewShotClassifier.examples_"]], "fit() (stormtrooper.text2textfewshotclassifier method)": [[6, "stormtrooper.Text2TextFewShotClassifier.fit"]], "fit() (stormtrooper.text2textzeroshotclassifier method)": [[6, "stormtrooper.Text2TextZeroShotClassifier.fit"]], "generate_prompt() (stormtrooper.text2textfewshotclassifier method)": [[6, "stormtrooper.Text2TextFewShotClassifier.generate_prompt"]], "partial_fit() (stormtrooper.text2textfewshotclassifier method)": [[6, "stormtrooper.Text2TextFewShotClassifier.partial_fit"]], "partial_fit() (stormtrooper.text2textzeroshotclassifier method)": [[6, "stormtrooper.Text2TextZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.text2textfewshotclassifier method)": [[6, "stormtrooper.Text2TextFewShotClassifier.predict"]], "predict() (stormtrooper.text2textzeroshotclassifier method)": [[6, 
"stormtrooper.Text2TextZeroShotClassifier.predict"]], "set_score_request() (stormtrooper.text2textfewshotclassifier method)": [[6, "stormtrooper.Text2TextFewShotClassifier.set_score_request"]], "set_score_request() (stormtrooper.text2textzeroshotclassifier method)": [[6, "stormtrooper.Text2TextZeroShotClassifier.set_score_request"]], "zeroshotclassifier (class in stormtrooper)": [[7, "stormtrooper.ZeroShotClassifier"]], "classes_ (stormtrooper.zeroshotclassifier attribute)": [[7, "stormtrooper.ZeroShotClassifier.classes_"]], "fit() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.fit"]], "partial_fit() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.partial_fit"]], "predict() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.predict"]], "predict_proba() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.predict_proba"]], "set_output() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.set_output"]], "set_score_request() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.set_score_request"]], "transform() (stormtrooper.zeroshotclassifier method)": [[7, "stormtrooper.ZeroShotClassifier.transform"]]}}) \ No newline at end of file diff --git a/docs/_build/html/setfit.html b/docs/_build/html/setfit.html index 194977f..fd7bace 100644 --- a/docs/_build/html/setfit.html +++ b/docs/_build/html/setfit.html @@ -3,9 +3,9 @@ - + - + SetFit models - stormtrooper @@ -170,6 +170,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • @@ -485,12 +486,12 @@

    API reference - +
    Next
    -
    Prompting
    +
    OpenAI models
    diff --git a/docs/_build/html/text2text.html b/docs/_build/html/text2text.html index c0ae150..03f6d41 100644 --- a/docs/_build/html/text2text.html +++ b/docs/_build/html/text2text.html @@ -5,7 +5,7 @@ - + Text2Text models - stormtrooper @@ -170,6 +170,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • diff --git a/docs/_build/html/zeroshot.html b/docs/_build/html/zeroshot.html index c140120..130b0a8 100644 --- a/docs/_build/html/zeroshot.html +++ b/docs/_build/html/zeroshot.html @@ -5,7 +5,7 @@ - + Zero-shot models - stormtrooper @@ -170,6 +170,7 @@
  • Text2Text models
  • Generative models
  • SetFit models
  • +
  • OpenAI models
  • Prompting
  • Inference on GPU
  • diff --git a/docs/index.rst b/docs/index.rst index 20a0d54..a8dd978 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,14 @@ If you intend to use SetFit models as well, install stormtrooper with optional d pip install stormtrooper[setfit] +From version 0.4.0 you can also use OpenAI models in stormtrooper. + +.. code-block:: + + pip install stormtrooper[openai] + export OPENAI_API_KEY="sk-..." + + Usage ^^^^^^^^^ @@ -45,6 +53,7 @@ In this example I am going to use Google's FLAN-T5. text2text generative setfit + openai prompting inference_on_gpu diff --git a/docs/openai.rst b/docs/openai.rst new file mode 100644 index 0000000..1ec22f4 --- /dev/null +++ b/docs/openai.rst @@ -0,0 +1,85 @@ +OpenAI models +================= + +Stormtrooper gives you access to OpenAI's chat models for zero and few-shot classification. +You get full control over temperature settings, system and user prompts. +In contrast to other packages, like scikit-llm, stormtrooper also uses Python's asyncio to concurrently +interact with OpenAI's API. This can give a several-fold speedup on several tasks. + +You can also set upper limits for the number of requests and tokens per minute, so you don't exceed your quota. +By default these are set to the limits of the paid tier on OpenAI's API. + +You need to install stormtrooper with optional dependencies. + +.. code-block:: bash + + pip install stormtrooper[openai] + +You additionally need to set the OpenAI API key as an environment variable. + +.. code-block:: bash + + export OPENAI_API_KEY="sk-..." + # Setting organization is optional + export OPENAI_ORG="org-..." + +.. code-block:: python + + from stormtrooper import OpenAIZeroShotClassifier, OpenAIFewShotClassifier + + sample_text = "It is the Electoral College's responsibility to elect the president." + + labels = ["politics", "science", "other"] + +Here's a zero shot example with ChatGPT 3.5: + +.. 
code-block:: python + + model = OpenAIZeroShotClassifier("gpt-3.5-turbo").fit(None, labels) + predictions = model.predict([sample_text]) + assert list(predictions) == ["politics"] + +And a few shot example with ChatGPT 4: + +.. code-block:: python + + few_shot_examples = [ + "Joe Biden is the president.", + "Liquid water was found on the moon.", + "Jerry likes football." + ] + + model = OpenAIFewShotClassifier("gpt-4", temperature=0.2).fit(few_shot_examples, labels) + predictions = model.predict([sample_text]) + assert list(predictions) == ["politics"] + + +The format of the prompts is the same as with StableBeluga instruct models, and an error is raised if your prompt does not follow +this format. + +.. code-block:: python + + prompt = """ + ### System: + You are a helpful assistant + ### User: + Your task will be to classify a text document into one + of the following classes: {classes}. + Please respond with a single label that you think fits + the document best. + Classify the following piece of text: + '{X}' + ### Assistant: + """ + + model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt) + + +API reference +^^^^^^^^^^^^^ + +.. autoclass:: stormtrooper.OpenAIZeroShotClassifier + :members: + +.. autoclass:: stormtrooper.OpenAIFewShotClassifier + :members: diff --git a/docs/prompting.rst b/docs/prompting.rst index 865f007..d0b33c0 100644 --- a/docs/prompting.rst +++ b/docs/prompting.rst @@ -1,7 +1,7 @@ Prompting ========= -Text2Text and Generative models use a prompting approach for classification. +Text2Text, Generative, and OpenAI models use a prompting approach for classification. stormtrooper comes with default prompts, but these might not suit the model you want to use, or your use case might require a different prompting strategy from the default. stormtrooper allows you to specify custom prompts in these cases. 
@@ -12,7 +12,7 @@ Templates Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to insert labels and data. -A zero-shot prompt for an instruct Llama model like Stable Beluga would look something like this (this is the default): +A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default): .. code-block:: python