
Commit

Merge pull request #2 from centre-for-humanities-computing/messages_api
Chat API
x-tabdeveloping authored Aug 7, 2024
2 parents 72c20cb + 591e3b6 commit ddd2430
Showing 50 changed files with 13,662 additions and 2,598 deletions.
39 changes: 18 additions & 21 deletions README.md
@@ -26,11 +24,24 @@ Other packages promise to provide at least similar functionality (scikit-llm), w
- We opted for as bare-bones of an implementation and little coupling as possible. The library works at the lowest level of abstraction possible, and we hope our code will be rather easy for others to understand and contribute to.


## News :fire:
## New in version 0.5.0

- Hugging Face's Text Generation Inference is now supported in stormtrooper and can be used to speed up inference with generative and text2text LLMs. (0.4.1)
- You can now use OpenAI's chat models with blazing fast :zap: async inference. (0.4.0)
- SetFit is now part of the library and can be used in scikit-learn workflows. (0.3.0)
stormtrooper now uses chat templates from Hugging Face transformers for generative models.
This means that you no longer have to pass model-specific prompt templates to these models, and can define system and user prompts separately.

```python
from stormtrooper import GenerativeZeroShotClassifier

system_prompt = "You're a helpful assistant."
user_prompt = """
Classify a text into one of the following categories: {classes}
Text to classify:
"{X}"
"""

model = GenerativeZeroShotClassifier(prompt=user_prompt, system_prompt=system_prompt).fit(None, ["political", "not political"])
model.predict("Joe Biden is no longer the candidate of the Democrats.")
```


## Examples
@@ -61,23 +74,7 @@ classifier = ZeroShotClassifier().fit(None, class_labels)
Generative models (GPT, Llama):
```python
from stormtrooper import GenerativeZeroShotClassifier
# You can hand-craft prompts if it suits you better, but
# a default prompt is already available
prompt = """
### System:
You are a literary expert tasked with labeling texts according to
their content.
Please follow the user's instructions as precisely as you can.
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
classifier = GenerativeZeroShotClassifier(prompt=prompt).fit(None, class_labels)
classifier = GenerativeZeroShotClassifier("meta-llama/Meta-Llama-3.1-8B-Instruct").fit(None, class_labels)
```

Text2Text models (T5):
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/_build/doctrees/generative.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/index.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/inference_on_gpu.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/openai.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/prompting.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/setfit.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/text2text.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/textgen.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/zeroshot.doctree
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/_build/html/.buildinfo
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: 603d7b972299f380c68753763fcac81d
config: e537640ea487a422d2d04661c45db399
tags: 645f666f9bcd5a90fca523b33c5a78b7
13 changes: 6 additions & 7 deletions docs/_build/html/_sources/openai.rst.txt
@@ -54,25 +54,24 @@ And a few shot example with ChatGPT 4:
assert list(predictions) == ["politics"]
The format of the prompts is the same as with StableBeluga instruct models, and an error is raised if your prompt does not follow
this format.
Prompts have to be specified the same way as with other generative models.

.. code-block:: python
system_prompt = """
You're a helpful assistant.
"""
prompt = """
### System:
You are a helpful assistant
### User:
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt)
model = OpenAIZeroShotClassifier("gpt-4", prompt=prompt, system_prompt=system_prompt)
API reference
17 changes: 8 additions & 9 deletions docs/_build/html/_sources/prompting.rst.txt
@@ -1,9 +1,8 @@
Prompting
=========

Text2Text, Generative, TGI and OpenAI models use a prompting approach for classification.
stormtrooper comes with default prompts, but these might not suit the model you want to use,
or your use case might require a different prompting strategy from the default.
Text2Text, Generative, and OpenAI models use a prompting approach for classification.
stormtrooper comes with default prompts, but these might not be the best for your model or use case.
stormtrooper allows you to specify custom prompts in these cases.

Templates
@@ -12,27 +11,27 @@ Templates
Prompting in stormtrooper uses a templating approach, where the .format() method is called on prompts to
insert labels and data.

A zero-shot prompt for an instruct Llama model like Stable Beluga or for ChatGPT would look something like this (this is the default):
You can specify both user and system prompts for Generative and OpenAI models:

.. code-block:: python
prompt = """
### System:
system_prompt = """
You are a classification model that is really good at following
instructions and produces brief answers
that users can use as data right away.
Please follow the user's instructions as precisely as you can.
### User:
"""
user_prompt = """
Your task will be to classify a text document into one
of the following classes: {classes}.
Please respond with a single label that you think fits
the document best.
Classify the following piece of text:
'{X}'
### Assistant:
"""
model = GenerativeZeroShotClassifier("stabilityai/StableBeluga-13B", prompt=prompt)
model = GenerativeZeroShotClassifier("stabilityai/StableBeluga-13B", prompt=prompt, system_prompt=system_prompt)
X represents the current text in question, while classes represents the classes learned from the data.
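The substitution described above can be sketched in plain Python. This is an illustrative sketch only: the class names and example text are made up, and stormtrooper performs the `.format()` call internally when you pass it a template.

```python
# Illustrative sketch of prompt templating: {classes} and {X} are
# placeholders that get filled via str.format(). The label names and
# example text below are hypothetical.
user_prompt = """
Your task will be to classify a text document into one
of the following classes: {classes}.
Classify the following piece of text:
'{X}'
"""

# Fill in the template the way the library would before sending
# the prompt to the model.
filled = user_prompt.format(
    classes=", ".join(["politics", "sport"]),
    X="The match ended in a 2-1 victory.",
)
print(filled)
```

Because this is ordinary `str.format()` templating, any literal curly braces in a custom prompt would need to be doubled (`{{` and `}}`).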

134 changes: 134 additions & 0 deletions docs/_build/html/_static/_sphinx_javascript_frameworks_compat.js
@@ -0,0 +1,134 @@
/*
* _sphinx_javascript_frameworks_compat.js
* ~~~~~~~~~~
*
* Compatability shim for jQuery and underscores.js.
*
* WILL BE REMOVED IN Sphinx 6.0
* xref RemovedInSphinx60Warning
*
*/

/**
* select a different prefix for underscore
*/
$u = _.noConflict();


/**
* small helper function to urldecode strings
*
* See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL
*/
jQuery.urldecode = function(x) {
if (!x) {
return x
}
return decodeURIComponent(x.replace(/\+/g, ' '));
};

/**
* small helper function to urlencode strings
*/
jQuery.urlencode = encodeURIComponent;

/**
* This function returns the parsed url parameters of the
* current request. Multiple values per key are supported,
* it will always return arrays of strings for the value parts.
*/
jQuery.getQueryParameters = function(s) {
if (typeof s === 'undefined')
s = document.location.search;
var parts = s.substr(s.indexOf('?') + 1).split('&');
var result = {};
for (var i = 0; i < parts.length; i++) {
var tmp = parts[i].split('=', 2);
var key = jQuery.urldecode(tmp[0]);
var value = jQuery.urldecode(tmp[1]);
if (key in result)
result[key].push(value);
else
result[key] = [value];
}
return result;
};

/**
* highlight a given string on a jquery object by wrapping it in
* span elements with the given class name.
*/
jQuery.fn.highlightText = function(text, className) {
function highlight(node, addItems) {
if (node.nodeType === 3) {
var val = node.nodeValue;
var pos = val.toLowerCase().indexOf(text);
if (pos >= 0 &&
!jQuery(node.parentNode).hasClass(className) &&
!jQuery(node.parentNode).hasClass("nohighlight")) {
var span;
var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg");
if (isInSVG) {
span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
} else {
span = document.createElement("span");
span.className = className;
}
span.appendChild(document.createTextNode(val.substr(pos, text.length)));
node.parentNode.insertBefore(span, node.parentNode.insertBefore(
document.createTextNode(val.substr(pos + text.length)),
node.nextSibling));
node.nodeValue = val.substr(0, pos);
if (isInSVG) {
var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect");
var bbox = node.parentElement.getBBox();
rect.x.baseVal.value = bbox.x;
rect.y.baseVal.value = bbox.y;
rect.width.baseVal.value = bbox.width;
rect.height.baseVal.value = bbox.height;
rect.setAttribute('class', className);
addItems.push({
"parent": node.parentNode,
"target": rect});
}
}
}
else if (!jQuery(node).is("button, select, textarea")) {
jQuery.each(node.childNodes, function() {
highlight(this, addItems);
});
}
}
var addItems = [];
var result = this.each(function() {
highlight(this, addItems);
});
for (var i = 0; i < addItems.length; ++i) {
jQuery(addItems[i].parent).before(addItems[i].target);
}
return result;
};

/*
* backward compatibility for jQuery.browser
* This will be supported until firefox bug is fixed.
*/
if (!jQuery.browser) {
jQuery.uaMatch = function(ua) {
ua = ua.toLowerCase();

var match = /(chrome)[ \/]([\w.]+)/.exec(ua) ||
/(webkit)[ \/]([\w.]+)/.exec(ua) ||
/(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) ||
/(msie) ([\w.]+)/.exec(ua) ||
ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) ||
[];

return {
browser: match[ 1 ] || "",
version: match[ 2 ] || "0"
};
};
jQuery.browser = {};
jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true;
}
27 changes: 1 addition & 26 deletions docs/_build/html/_static/basic.css
@@ -4,7 +4,7 @@
*
* Sphinx stylesheet -- basic theme.
*
* :copyright: Copyright 2007-2023 by the Sphinx team, see AUTHORS.
* :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS.
* :license: BSD, see LICENSE for details.
*
*/
@@ -237,10 +237,6 @@ a.headerlink {
visibility: hidden;
}

a:visited {
color: #551A8B;
}

h1:hover > a.headerlink,
h2:hover > a.headerlink,
h3:hover > a.headerlink,
@@ -328,15 +324,13 @@ aside.sidebar {
p.sidebar-title {
font-weight: bold;
}

nav.contents,
aside.topic,
div.admonition, div.topic, blockquote {
clear: left;
}

/* -- topics ---------------------------------------------------------------- */

nav.contents,
aside.topic,
div.topic {
@@ -612,7 +606,6 @@ ol.simple p,
ul.simple p {
margin-bottom: 0;
}

aside.footnote > span,
div.citation > span {
float: left;
@@ -674,16 +667,6 @@ dd {
margin-left: 30px;
}

.sig dd {
margin-top: 0px;
margin-bottom: 0px;
}

.sig dl {
margin-top: 0px;
margin-bottom: 0px;
}

dl > dd:last-child,
dl > dd:last-child > :last-child {
margin-bottom: 0;
@@ -752,14 +735,6 @@ abbr, acronym {
cursor: help;
}

.translated {
background-color: rgba(207, 255, 207, 0.2)
}

.untranslated {
background-color: rgba(255, 207, 207, 0.2)
}

/* -- code displays --------------------------------------------------------- */

pre {
2 changes: 1 addition & 1 deletion docs/_build/html/_static/debug.css
@@ -64,6 +64,6 @@ body {
.sb-footer__inner {
background: salmon;
}
.sb-article {
[role="main"] {
background: white;
}
