diff --git a/README.md b/README.md
index a0079a7e..adf9fa92 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ Corresponding medium post can be found [here](https://towardsdatascience.com/key
2.3. [Max Sum Distance](#maxsum)
2.4. [Maximal Marginal Relevance](#maximal)
2.5. [Embedding Models](#embeddings)
+ 3. [Large Language Models](#llms)
@@ -226,6 +227,55 @@ kw_model = KeyBERT(model=roberta)
You can select any 🤗 transformers model [here](https://huggingface.co/models).
+
+## 3. Large Language Models
+[Back to ToC](#toc)
+
+With `KeyLLM` you can new perform keyword extraction with Large Language Models (LLM). You can find the full documentation [here](https://maartengr.github.io/KeyBERT/guides/keyllm.html) but there are two examples that are common with this new method. Make sure to install the OpenAI package through `pip install openai` before you start.
+
+First, we can ask OpenAI directly to extract keywords:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+```
+
+This will query any ChatGPT model and ask it to extract keywords from text.
+
+Second, we can find documents that are likely to have the same keywords and only extract keywords for those.
+This is much more efficient then asking the keywords for every single documents. There are likely documents that
+have the exact same keywords. Doing so is straightforward:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+from sentence_transformers import SentenceTransformer
+
+# Extract embeddings
+model = SentenceTransformer('all-MiniLM-L6-v2')
+embeddings = model.encode(MY_DOCUMENTS, convert_to_tensor=True)
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(MY_DOCUMENTS, embeddings=embeddings, threshold=.75)
+```
+
+You can use the `threshold` parameter to decide how similar documents need to be in order to receive the same keywords.
## Citation
To cite KeyBERT in your work, please use the following bibtex reference:
diff --git a/docs/api/cohere.md b/docs/api/cohere.md
new file mode 100644
index 00000000..d9eb8f2a
--- /dev/null
+++ b/docs/api/cohere.md
@@ -0,0 +1,3 @@
+# `Cohere`
+
+::: keybert.llm._cohere.Cohere
diff --git a/docs/api/keyllm.md b/docs/api/keyllm.md
new file mode 100644
index 00000000..24f8ce5e
--- /dev/null
+++ b/docs/api/keyllm.md
@@ -0,0 +1,3 @@
+# `KeyLLM`
+
+::: keybert._llm.KeyLLM
diff --git a/docs/api/langchain.md b/docs/api/langchain.md
new file mode 100644
index 00000000..d0087f05
--- /dev/null
+++ b/docs/api/langchain.md
@@ -0,0 +1,3 @@
+# `LangChain`
+
+::: keybert.llm._langchain.LangChain
diff --git a/docs/api/litellm.md b/docs/api/litellm.md
new file mode 100644
index 00000000..e3608f78
--- /dev/null
+++ b/docs/api/litellm.md
@@ -0,0 +1,3 @@
+# `LiteLLM`
+
+::: keybert.llm._litellm.LiteLLM
diff --git a/docs/api/openai.md b/docs/api/openai.md
new file mode 100644
index 00000000..9b6f36d0
--- /dev/null
+++ b/docs/api/openai.md
@@ -0,0 +1,3 @@
+# `OpenAI`
+
+::: keybert.llm._openai.OpenAI
diff --git a/docs/api/textgeneration.md b/docs/api/textgeneration.md
new file mode 100644
index 00000000..a18639f3
--- /dev/null
+++ b/docs/api/textgeneration.md
@@ -0,0 +1,3 @@
+# `TextGeneration`
+
+::: keybert.llm._textgeneration.TextGeneration
diff --git a/docs/changelog.md b/docs/changelog.md
index a6ae4948..5a7c8ad9 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -3,6 +3,34 @@ hide:
- navigation
---
+## **Version 0.8.0**
+*Release date: 27 September, 2023*
+
+**Highlights**:
+
+* Use `KeyLLM` to leverage LLMs for extracting keywords
+ * Use it either with or without candidate keywords generated through `KeyBERT`
+ * Multiple LLMs are integrated: OpenAI, Cohere, LangChain, HF, and LiteLLM
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+```
+
+See [here](https://maartengr.github.io/KeyBERT/guides/keyllm.html) for full documentation on use cases of `KeyLLM` and [here](https://maartengr.github.io/KeyBERT/guides/llms.html) for the implemented Large Language Models.
+
+**Fixes**:
+
+* Enable Guided KeyBERT for seed keywords differing among docs by [@shengbo-ma](https://github.com/shengbo-ma) in [#152](https://github.com/MaartenGr/KeyBERT/pull/152)
+
## **Version 0.7.0**
*Release date: 3 November, 2022*
diff --git a/docs/guides/keyllm.md b/docs/guides/keyllm.md
new file mode 100644
index 00000000..94f36526
--- /dev/null
+++ b/docs/guides/keyllm.md
@@ -0,0 +1,255 @@
+A minimal method for keyword extraction with Large Language Models (LLM). There are a number of implementations that allow you to mix and match `KeyBERT` with `KeyLLM`. You could also choose to use `KeyLLM` without `KeyBERT`.
+
+
+--8<-- "docs/images/keyllm.svg"
+
+
+We start with an example of some data:
+
+```python
+documents = [
+"The website mentions that it only takes a couple of days to deliver but I still have not received mine.",
+"I received my package!",
+"Whereas the most powerful LLMs have generally been accessible only through limited APIs (if at all), Meta released LLaMA's model weights to the research community under a noncommercial license."
+]
+```
+
+This data was chosen to show the different use cases and techniques. As you might have noticed documents 1 and 2 are quite similar whereas document 3 is about an entirely different subject. This similarity will be taken into account when using `KeyBERT` together with `KeyLLM`
+
+Let's start with `KeyLLM` only.
+
+# Use Cases
+
+If you want the full performance and easiest method, you can skip the use cases below and go straight to number 5 where you will combine `KeyBERT` with `KeyLLM`.
+
+!!! Tip
+ If you want to use KeyLLM without any of the HuggingFace packages, you can install it as follows:
+ `pip install keybert --no-deps`
+ `pip install scikit-learn numpy rich tqdm`
+ This will make the installation much smaller and the import much quicker.
+
+## 1. **Create** Keywords with `KeyLLM`
+
+We start by creating keywords for each document. This creation process is simply asking the LLM to come up with a bunch of keywords for each document. The focus here is on **creating** keywords which refers to the idea that the keywords do not necessarily need to appear in the input documents.
+
+Install the relevant LLM first:
+
+```bash
+pip install openai
+```
+
+Then we can use any OpenAI model, such as ChatGPT, as follows:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(documents)
+```
+
+This creates the following keywords:
+
+```python
+[['Website',
+ 'Delivery',
+ 'Mention',
+ 'Timeframe',
+ 'Not received',
+ 'Order fulfillment'],
+ ['Package', 'Received', 'Delivery', 'Order fulfillment'],
+ ['Powerful LLMs',
+ 'Limited APIs',
+ 'Meta',
+ 'Model weights',
+ 'Research community',
+ '']]
+```
+
+## 2. **Extract** Keywords with `KeyLLM`
+
+Instead of creating keywords out of thin air, we ask the LLM to check whether they actually appear in the text and limit the keywords to those that are found in the documents. We do this by using a custom prompt together with `check_vocab=True`:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+
+prompt = """
+I have the following document:
+[DOCUMENT]
+
+Based on the information above, extract the keywords that best describe the topic of the text.
+Make sure to only extract keywords that appear in the text.
+Use the following format separated by commas:
+
+"""
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(documents, check_vocab=True); keywords
+```
+
+This creates the following keywords:
+
+```python
+[['website', 'couple of days', 'deliver', 'received'],
+ ['package', 'received'],
+ ['LLMs',
+ 'APIs',
+ 'Meta',
+ 'LLaMA',
+ 'model weights',
+ 'research community',
+ 'noncommercial license']]
+```
+
+## 3. **Fine-tune** Candidate Keywords
+
+If you already have a list of keywords, you could fine-tune them by asking the LLM to come up with nicer tags or names that we could use. We can use the `[CANDIDATES]` tag in the prompt to assign where they should go.
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+
+prompt = """
+I have the following document:
+[DOCUMENT]
+
+With the following candidate keywords:
+[CANDIDATES]
+
+Based on the information above, improve the candidate keywords to best describe the topic of the document.
+
+Use the following format separated by commas:
+
+"""
+llm = OpenAI(model="gpt-3.5-turbo", prompt=prompt, chat=True)
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+candidate_keywords = [['website', 'couple of days', 'deliver', 'received'],
+ ['received', 'package'],
+ ['most powerful LLMs',
+ 'limited APIs',
+ 'Meta',
+ "LLaMA's model weights",
+ 'research community',
+ 'noncommercial license']]
+keywords = kw_model.extract_keywords(documents, candidate_keywords=candidate_keywords); keywords
+```
+
+This creates the following keywords:
+
+```python
+[['delivery timeframe', 'discrepancy', 'website', 'order status'],
+ ['received package'],
+ ['most powerful language models',
+ 'API limitations',
+ "Meta's release",
+ "LLaMA's model weights",
+ 'research community access',
+ 'noncommercial licensing']]
+```
+
+## 4. **Efficient** `KeyLLM`
+
+If you have embeddings of your documents, you could use those to find documents that are most similar to one another. Those documents could then all receive the same keywords and only one of these documents will need to be passed to the LLM. This can make computation much faster as only a subset of documents will need to receive keywords.
+
+
+--8<-- "docs/images/efficient.svg"
+
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+from sentence_transformers import SentenceTransformer
+
+# Extract embeddings
+model = SentenceTransformer('all-MiniLM-L6-v2')
+embeddings = model.encode(documents, convert_to_tensor=True)
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(documents, embeddings=embeddings, threshold=.75)
+```
+
+This creates the following keywords:
+
+```python
+[['Website',
+ 'Delivery',
+ 'Mention',
+ 'Timeframe',
+ 'Not received',
+ 'Waiting',
+ 'Order fulfillment'],
+ ['Received', 'Package', 'Delivery', 'Order fulfillment'],
+ ['Powerful LLMs', 'Limited APIs', 'Meta', 'LLaMA', 'Model weights']]
+```
+
+
+## 5. **Efficient** `KeyLLM` + `KeyBERT`
+
+This is the best of both worlds. We use `KeyBERT` to generate a first pass of keywords and embeddings and give those to `KeyLLM` for a final pass. Again, the most similar documents will be clustered and they will all receive the same keywords. You can change this behavior with `threshold`. A higher value will reduce the number of documents that are clustered and a lower value will increase the number of documents that are clustered.
+
+
+--8<-- "docs/images/keybert_keyllm.svg"
+
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM, KeyBERT
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyBERT(llm=llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(documents); keywords
+```
+
+This creates the following keywords:
+
+```python
+[['Website',
+ 'Delivery',
+ 'Timeframe',
+ 'Mention',
+ 'Order fulfillment',
+ 'Not received',
+ 'Waiting'],
+ ['Package', 'Received', 'Confirmation', 'Delivery', 'Order fulfillment'],
+ ['LLMs', 'Limited APIs', 'Meta', 'LLaMA', 'Model weights', '']]
+```
diff --git a/docs/guides/llms.md b/docs/guides/llms.md
new file mode 100644
index 00000000..b57cf7d4
--- /dev/null
+++ b/docs/guides/llms.md
@@ -0,0 +1,203 @@
+# Large Language Models (LLM)
+In this tutorial we will be going through the Large Language Models (LLM) that can be used in KeyLLM.
+Having the option to choose the LLM allow you to leverage the model that suit your use-case.
+
+### **OpenAI**
+To use OpenAI's external API, we need to define our key and use the `keybert.llm.OpenAI` model.
+
+We install the package first:
+
+```bash
+pip install openai
+```
+
+Then we run OpenAI as follows:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your OpenAI LLM
+openai.api_key = "sk-..."
+llm = OpenAI()
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(MY_DOCUMENTS)
+```
+
+If you want to use a chat-based model, please run the following instead:
+
+```python
+import openai
+from keybert.llm import OpenAI
+from keybert import KeyLLM
+
+# Create your LLM
+openai.api_key = "sk-..."
+llm = OpenAI(model="gpt-3.5-turbo", chat=True)
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+```
+
+### **Cohere**
+To use Cohere's external API, we need to define our key and use the `keybert.llm.Cohere` model.
+
+We install the package first:
+
+```bash
+pip install cohere
+```
+
+Then we run Cohere as follows:
+
+
+```python
+import cohere
+from keybert.llm import Cohere
+from keybert import KeyLLM
+
+# Create your OpenAI LLM
+co = cohere.Client(my_api_key)
+llm = Cohere(co)
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+
+# Extract keywords
+keywords = kw_model.extract_keywords(MY_DOCUMENTS)
+```
+
+### **LiteLLM**
+[LiteLLM](https://github.com/BerriAI/litellm) allows you to use any closed-source LLM with KeyLLM
+
+We install the package first:
+
+```bash
+pip install litellm
+```
+
+
+Let's use OpenAI as an example:
+
+```python
+import os
+from keybert.llm import LiteLLM
+from keybert import KeyLLM
+
+# Select LLM
+os.environ["OPENAI_API_KEY"] = "sk-..."
+llm = LiteLLM("gpt-3.5-turbo")
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+```
+
+### 🤗 **Hugging Face Transformers**
+To use a Hugging Face transformers model, load in a pipeline and point
+to any model found on their model hub (https://huggingface.co/models). Let's use Llama 2 as an example:
+
+```python
+from torch import cuda, bfloat16
+import transformers
+
+model_id = 'meta-llama/Llama-2-7b-chat-hf'
+
+# 4-bit Quantization to load Llama 2 with less GPU memory
+bnb_config = transformers.BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type='nf4',
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_compute_dtype=bfloat16
+)
+
+# Llama 2 Model & Tokenizer
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+model = transformers.AutoModelForCausalLM.from_pretrained(
+ model_id,
+ trust_remote_code=True,
+ quantization_config=bnb_config,
+ device_map='auto',
+)
+model.eval()
+
+# Our text generator
+generator = transformers.pipeline(
+ model=model, tokenizer=tokenizer,
+ task='text-generation',
+ temperature=0.1,
+ max_new_tokens=500,
+ repetition_penalty=1.1
+)
+```
+
+Then, we load the `generator` in `KeyLLM` with a custom prompt:
+
+```python
+from keybert.llm import TextGeneration
+from keybert import KeyLLM
+
+prompt = """
+[INST] <>
+
+You are a helpful assistant specialized in extracting comma-separated keywords.
+You are to the point and only give the answer in isolation without any chat-based fluff.
+
+<>
+I have the following document:
+- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
+
+Please give me the keywords that are present in this document and separate them with commas.
+Make sure you to only return the keywords and say nothing else. For example, don't say:
+"Here are the keywords present in the document"
+[/INST] meat, beef, eat, eating, emissions, steak, food, health, processed, chicken [INST]
+
+I have the following document:
+- [DOCUMENT]
+
+Please give me the keywords that are present in this document and separate them with commas.
+Make sure you to only return the keywords and say nothing else. For example, don't say:
+"Here are the keywords present in the document"
+[/INST]
+"""
+
+# Load it in KeyLLM
+llm = TextGeneration(generator, prompt=prompt)
+kw_model = KeyLLM(llm)
+```
+
+### **LangChain**
+
+To use LangChain, we can simply load in any LLM and pass that as a QA-chain to KeyLLM.
+
+We install the package first:
+
+```bash
+pip install langchain
+```
+
+Then we run LangChain as follows:
+
+
+```python
+from langchain.chains.question_answering import load_qa_chain
+from langchain.llms import OpenAI
+chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
+```
+
+Finally, you can pass the chain to KeyBERT as follows:
+
+```python
+from keybert.llm import LangChain
+from keybert import KeyLLM
+
+# Create your LLM
+llm = LangChain(chain)
+
+# Load it in KeyLLM
+kw_model = KeyLLM(llm)
+```
\ No newline at end of file
diff --git a/docs/images/efficient.svg b/docs/images/efficient.svg
new file mode 100644
index 00000000..3778e14a
--- /dev/null
+++ b/docs/images/efficient.svg
@@ -0,0 +1,157 @@
+
diff --git a/docs/images/keybert_keyllm.svg b/docs/images/keybert_keyllm.svg
new file mode 100644
index 00000000..feb02a9a
--- /dev/null
+++ b/docs/images/keybert_keyllm.svg
@@ -0,0 +1,15 @@
+
diff --git a/docs/images/keyllm.svg b/docs/images/keyllm.svg
new file mode 100644
index 00000000..57eb9a4a
--- /dev/null
+++ b/docs/images/keyllm.svg
@@ -0,0 +1,14 @@
+
diff --git a/keybert/__init__.py b/keybert/__init__.py
index 048b901d..fece795c 100644
--- a/keybert/__init__.py
+++ b/keybert/__init__.py
@@ -1,3 +1,4 @@
+from keybert._llm import KeyLLM
from keybert._model import KeyBERT
-__version__ = "0.7.0"
+__version__ = "0.8.0"
diff --git a/keybert/_llm.py b/keybert/_llm.py
new file mode 100644
index 00000000..ca8ff22e
--- /dev/null
+++ b/keybert/_llm.py
@@ -0,0 +1,144 @@
+from typing import List, Union
+
+try:
+ from sentence_transformers import util
+ HAS_SBERT = True
+except ModuleNotFoundError:
+ HAS_SBERT = False
+
+
+class KeyLLM:
+ """
+ A minimal method for keyword extraction with Large Language Models (LLM)
+
+ The keyword extraction is done by simply asking the LLM to extract a
+ number of keywords from a single piece of text.
+ """
+
+ def __init__(self, llm):
+ """KeyBERT initialization
+
+ Arguments:
+ llm: The Large Language Model to use
+ """
+ self.llm = llm
+
+ def extract_keywords(
+ self,
+ docs: Union[str, List[str]],
+ check_vocab: bool = False,
+ candidate_keywords: List[List[str]] = None,
+ threshold: float = None,
+ embeddings=None
+ ) -> Union[List[str], List[List[str]]]:
+ """Extract keywords and/or keyphrases
+
+ To get the biggest speed-up, make sure to pass multiple documents
+ at once instead of iterating over a single document.
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ docs: The document(s) for which to extract keywords/keyphrases
+ check_vocab: Only return keywords that appear exactly in the documents
+ candidate_keywords: Candidate keywords for each document
+
+ Returns:
+ keywords: The top n keywords for a document with their respective distances
+ to the input document.
+
+ Usage:
+
+ To extract keywords from a single document:
+
+ ```python
+ import openai
+ from keybert.llm import OpenAI
+ from keybert import KeyLLM
+
+ # Create your LLM
+ openai.api_key = "sk-..."
+ llm = OpenAI()
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+ """
+ # Check for a single, empty document
+ if isinstance(docs, str):
+ if docs:
+ docs = [docs]
+ else:
+ return []
+
+ if HAS_SBERT and threshold is not None and embeddings is not None:
+
+ # Find similar documents
+ clusters = util.community_detection(embeddings, min_community_size=2, threshold=threshold)
+ in_cluster = set([cluster for cluster_set in clusters for cluster in cluster_set])
+ out_cluster = set(list(range(len(docs)))).difference(in_cluster)
+
+ # Extract keywords for all documents not in a cluster
+ if out_cluster:
+ selected_docs = [docs[index] for index in out_cluster]
+ print(out_cluster, selected_docs)
+ if candidate_keywords is not None:
+ selected_keywords = [candidate_keywords[index] for index in out_cluster]
+ else:
+ selected_keywords = None
+ print(f"Call LLM with {len(selected_docs)} docs; out-cluster")
+ out_cluster_keywords = self.llm.extract_keywords(
+ selected_docs,
+ selected_keywords,
+ )
+ out_cluster_keywords = {index: words for words, index in zip(out_cluster_keywords, out_cluster)}
+
+ # Extract keywords for only the first document in a cluster
+ if in_cluster:
+ selected_docs = [docs[cluster[0]] for cluster in clusters]
+ print(in_cluster, selected_docs)
+ if candidate_keywords is not None:
+ selected_keywords = [candidate_keywords[cluster[0]] for cluster in in_cluster]
+ else:
+ selected_keywords = None
+ print(f"Call LLM with {len(selected_docs)} docs; in-cluster")
+ in_cluster_keywords = self.llm.extract_keywords(
+ selected_docs,
+ selected_keywords
+ )
+ in_cluster_keywords = {
+ doc_id: in_cluster_keywords[index]
+ for index, cluster in enumerate(clusters)
+ for doc_id in cluster
+ }
+
+ # Update out cluster keywords with in cluster keywords
+ if out_cluster:
+ if in_cluster:
+ out_cluster_keywords.update(in_cluster_keywords)
+ print(out_cluster_keywords)
+ keywords = [out_cluster_keywords[index] for index in range(len(docs))]
+ else:
+ keywords = [in_cluster_keywords[index] for index in range(len(docs))]
+ else:
+ # Extract keywords using a Large Language Model (LLM)
+ keywords = self.llm.extract_keywords(docs, candidate_keywords)
+
+ # Only extract keywords that appear in the input document
+ if check_vocab:
+ updated_keywords = []
+ for keyword_set, document in zip(keywords, docs):
+ updated_keyword_set = []
+ for keyword in keyword_set:
+ if keyword in document:
+ updated_keyword_set.append(keyword)
+ updated_keywords.append(updated_keyword_set)
+ return updated_keywords
+
+ return keywords
diff --git a/keybert/_model.py b/keybert/_model.py
index 475bdbef..564a994e 100644
--- a/keybert/_model.py
+++ b/keybert/_model.py
@@ -14,6 +14,8 @@
from keybert._maxsum import max_sum_distance
from keybert._highlight import highlight_document
from keybert.backend._utils import select_backend
+from keybert.llm._base import BaseLLM
+from keybert import KeyLLM
class KeyBERT:
@@ -36,7 +38,7 @@ class KeyBERT:
"""
- def __init__(self, model="all-MiniLM-L6-v2"):
+ def __init__(self, model="all-MiniLM-L6-v2", llm: BaseLLM = None):
"""KeyBERT initialization
Arguments:
@@ -54,6 +56,11 @@ def __init__(self, model="all-MiniLM-L6-v2"):
"""
self.model = select_backend(model)
+ if isinstance(llm, BaseLLM):
+ self.llm = KeyLLM(llm)
+ else:
+ self.llm = llm
+
def extract_keywords(
self,
docs: Union[str, List[str]],
@@ -71,6 +78,7 @@ def extract_keywords(
seed_keywords: Union[List[str], List[List[str]]] = None,
doc_embeddings: np.array = None,
word_embeddings: np.array = None,
+ threshold: float = None
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
"""Extract keywords and/or keyphrases
@@ -245,6 +253,19 @@ def extract_keywords(
highlight_document(docs[0], all_keywords[0], count)
all_keywords = all_keywords[0]
+ # Fine-tune keywords using an LLM
+ if self.llm is not None:
+ if isinstance(all_keywords[0], tuple):
+ candidate_keywords = [[keyword[0] for keyword in all_keywords]]
+ else:
+ candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
+ keywords = self.llm.extract_keywords(
+ docs,
+ embeddings=doc_embeddings,
+ candidate_keywords=candidate_keywords,
+ threshold=threshold
+ )
+ return keywords
return all_keywords
def extract_embeddings(
diff --git a/keybert/_utils.py b/keybert/_utils.py
new file mode 100644
index 00000000..4ba45741
--- /dev/null
+++ b/keybert/_utils.py
@@ -0,0 +1,22 @@
+class NotInstalled:
+ """
+ This object is used to notify the user that additional dependencies need to be
+ installed in order to use the string matching model.
+ """
+
+ def __init__(self, tool, dep, custom_msg=None):
+ self.tool = tool
+ self.dep = dep
+
+ msg = f"In order to use {self.tool} you will need to install via;\n\n"
+ if custom_msg is not None:
+ msg += custom_msg
+ else:
+ msg += f"pip install bertopic[{self.dep}]\n\n"
+ self.msg = msg
+
+ def __getattr__(self, *args, **kwargs):
+ raise ModuleNotFoundError(self.msg)
+
+ def __call__(self, *args, **kwargs):
+ raise ModuleNotFoundError(self.msg)
\ No newline at end of file
diff --git a/keybert/backend/_utils.py b/keybert/backend/_utils.py
index 69217c6a..ba1eabe6 100644
--- a/keybert/backend/_utils.py
+++ b/keybert/backend/_utils.py
@@ -1,7 +1,4 @@
from ._base import BaseEmbedder
-from ._sentencetransformers import SentenceTransformerBackend
-from ._hftransformers import HFTransformerBackend
-from transformers.pipelines import Pipeline
def select_backend(embedding_model) -> BaseEmbedder:
@@ -42,14 +39,22 @@ def select_backend(embedding_model) -> BaseEmbedder:
# Sentence Transformer embeddings
if "sentence_transformers" in str(type(embedding_model)):
+ from ._sentencetransformers import SentenceTransformerBackend
return SentenceTransformerBackend(embedding_model)
# Create a Sentence Transformer model based on a string
if isinstance(embedding_model, str):
+ from ._sentencetransformers import SentenceTransformerBackend
return SentenceTransformerBackend(embedding_model)
# Hugging Face embeddings
- if isinstance(embedding_model, Pipeline):
- return HFTransformerBackend(embedding_model)
-
+ try:
+ from transformers.pipelines import Pipeline
+ if isinstance(embedding_model, Pipeline):
+ from ._hftransformers import HFTransformerBackend
+ return HFTransformerBackend(embedding_model)
+ except ImportError:
+ pass
+
+ from ._sentencetransformers import SentenceTransformerBackend
return SentenceTransformerBackend("paraphrase-multilingual-MiniLM-L12-v2")
diff --git a/keybert/llm/__init__.py b/keybert/llm/__init__.py
new file mode 100644
index 00000000..e7ec4ee4
--- /dev/null
+++ b/keybert/llm/__init__.py
@@ -0,0 +1,48 @@
+from keybert._utils import NotInstalled
+from keybert.llm._base import BaseLLM
+
+
+# TextGeneration
+try:
+ from keybert.llm._textgeneration import TextGeneration
+except ModuleNotFoundError:
+ msg = "`pip install keybert` \n\n"
+ TextGeneration = NotInstalled("TextGeneration", "keybert", custom_msg=msg)
+
+# OpenAI Generator
+try:
+ from keybert.llm._openai import OpenAI
+except ModuleNotFoundError:
+ msg = "`pip install openai` \n\n"
+ OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)
+
+# Cohere Generator
+try:
+ from keybert.llm._cohere import Cohere
+except ModuleNotFoundError:
+ msg = "`pip install cohere` \n\n"
+ Cohere = NotInstalled("Cohere", "cohere", custom_msg=msg)
+
+# LangChain Generator
+try:
+ from keybert.llm._langchain import LangChain
+except ModuleNotFoundError:
+ msg = "`pip install langchain` \n\n"
+ LangChain = NotInstalled("langchain", "langchain", custom_msg=msg)
+
+# LiteLLM
+try:
+ from keybert.llm._litellm import LiteLLM
+except ModuleNotFoundError:
+ msg = "`pip install litellm` \n\n"
+ LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)
+
+
+__all__ = [
+ "BaseLLM",
+ "Cohere",
+ "OpenAI",
+ "TextGeneration",
+ "LangChain",
+ "LiteLLM"
+]
diff --git a/keybert/llm/_base.py b/keybert/llm/_base.py
new file mode 100644
index 00000000..359d1df0
--- /dev/null
+++ b/keybert/llm/_base.py
@@ -0,0 +1,20 @@
+from sklearn.base import BaseEstimator
+from typing import List
+
+
+class BaseLLM(BaseEstimator):
+ """ The base representation model for fine-tuning topic representations """
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ return [None for document in documents]
diff --git a/keybert/llm/_cohere.py b/keybert/llm/_cohere.py
new file mode 100644
index 00000000..fdf0d3d6
--- /dev/null
+++ b/keybert/llm/_cohere.py
@@ -0,0 +1,131 @@
+import time
+from tqdm import tqdm
+from typing import List
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import process_candidate_keywords
+
+
+DEFAULT_PROMPT = """
+The following is a list of documents. Please extract the top keywords, separated by a comma, that describe the topic of the texts.
+
+Document:
+- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
+
+Keywords: Traditional diets, Plant-based, Meat, Industrial style meat production, Factory farming, Staple food, Cultural dietary practices
+
+Document:
+- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
+
+Keywords: Website, Delivery, Mention, Timeframe, Not received, Waiting, Order fulfillment
+
+Document:
+- [DOCUMENT]
+
+Keywords:"""
+
+
+class Cohere(BaseLLM):
+ """ Use the Cohere API to generate topic labels based on their
+ generative model.
+
+ Find more about their models here:
+ https://docs.cohere.ai/docs
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ client: A cohere.Client
+ model: Model to use within Cohere, defaults to `"xlarge"`.
+ prompt: The prompt to be used in the model. If no prompt is given,
+ `self.default_prompt_` is used instead.
+ NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
+ to decide where the keywords and documents need to be
+ inserted.
+ delay_in_seconds: The delay in seconds between consecutive prompts
+ in order to prevent RateLimitErrors.
+ verbose: Set this to True if you want to see a progress bar for the
+ keyword extraction.
+
+ Usage:
+
+ To use this, you will need to install cohere first:
+
+ `pip install cohere`
+
+ Then, get yourself an API key and use Cohere's API as follows:
+
+ ```python
+ import cohere
+ from keybert.llm import Cohere
+ from keybert import KeyLLM
+
+ # Create your LLM
+ co = cohere.Client(my_api_key)
+ llm = Cohere(co)
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+
+ You can also use a custom prompt:
+
+ ```python
+ prompt = "I have the following document: [DOCUMENT]. What keywords does it contain? Make sure to separate the keywords with commas."
+ llm = Cohere(co, prompt=prompt)
+ ```
+ """
+ def __init__(self,
+ client,
+ model: str = "xlarge",
+ prompt: str = None,
+ delay_in_seconds: float = None,
+ verbose: bool = False
+ ):
+ self.client = client
+ self.model = model
+ self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
+ self.default_prompt_ = DEFAULT_PROMPT
+ self.delay_in_seconds = delay_in_seconds
+ self.verbose = verbose
+
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ all_keywords = []
+ candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+
+ for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+ prompt = self.prompt.replace("[DOCUMENT]", document)
+ if candidates is not None:
+ prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+
+ # Delay
+ if self.delay_in_seconds:
+ time.sleep(self.delay_in_seconds)
+
+ request = self.client.generate(model=self.model,
+ prompt=prompt,
+ max_tokens=50,
+ num_generations=1,
+ stop_sequences=["\n"])
+ keywords = request.generations[0].text.strip()
+ keywords = [keyword.strip() for keyword in keywords.split(",")]
+ all_keywords.append(keywords)
+
+ return all_keywords
diff --git a/keybert/llm/_langchain.py b/keybert/llm/_langchain.py
new file mode 100644
index 00000000..4741c8ce
--- /dev/null
+++ b/keybert/llm/_langchain.py
@@ -0,0 +1,104 @@
+from tqdm import tqdm
+from typing import List
+from langchain.docstore.document import Document
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import process_candidate_keywords
+
+
+DEFAULT_PROMPT = "What is this document about? Please provide keywords separated by commas."
+
+
+class LangChain(BaseLLM):
+ """ Using chains in langchain to generate keywords.
+
+ Currently, only chains from question answering is implemented. See:
+ https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ chain: A langchain chain that has two input parameters, `input_documents` and `query`.
+ prompt: The prompt to be used in the model. If no prompt is given,
+ `self.default_prompt_` is used instead.
+ verbose: Set this to True if you want to see a progress bar for the
+ keyword extraction.
+
+ Usage:
+
+ To use this, you will need to install the langchain package first.
+ Additionally, you will need an underlying LLM to support langchain,
+ like openai:
+
+ `pip install langchain`
+ `pip install openai`
+
+ Then, you can create your chain as follows:
+
+ ```python
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.llms import OpenAI
+ chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
+ ```
+
+ Finally, you can pass the chain to KeyBERT as follows:
+
+ ```python
+ from keybert.llm import LangChain
+ from keybert import KeyLLM
+
+ # Create your LLM
+ llm = LangChain(chain)
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+
+ You can also use a custom prompt:
+
+ ```python
+ prompt = "What are these documents about? Please give a single label."
+ llm = LangChain(chain, prompt=prompt)
+ ```
+ """
+ def __init__(self,
+ chain,
+ prompt: str = None,
+ verbose: bool = False,
+ ):
+ self.chain = chain
+ self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
+ self.default_prompt_ = DEFAULT_PROMPT
+ self.verbose = verbose
+
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ all_keywords = []
+ candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+
+ for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+ prompt = self.prompt.replace("[DOCUMENT]", document)
+ if candidates is not None:
+ prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+ input_document = Document(page_content=document)
+ keywords = self.chain.run(input_documents=input_document, question=self.prompt).strip()
+ keywords = [keyword.strip() for keyword in keywords.split(",")]
+ all_keywords.append(keywords)
+
+ return all_keywords
diff --git a/keybert/llm/_litellm.py b/keybert/llm/_litellm.py
new file mode 100644
index 00000000..f0e55469
--- /dev/null
+++ b/keybert/llm/_litellm.py
@@ -0,0 +1,129 @@
+import time
+from tqdm import tqdm
+from litellm import completion
+from typing import Mapping, Any, List
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import process_candidate_keywords
+
+
+DEFAULT_PROMPT = """
+I have the following document:
+[DOCUMENT]
+
+Based on the information above, extract the keywords that best describe the topic of the text.
+Use the following format separated by commas:
+
+"""
+
+
+class LiteLLM(BaseLLM):
+ """ Extract keywords using LiteLLM to call any LLM API using OpenAI format
+ such as Anthropic, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ model: Model to use within LiteLLM, defaults to OpenAI's `"gpt-3.5-turbo"`.
+ generator_kwargs: Kwargs passed to `litellm.completion`
+ for fine-tuning the output.
+ prompt: The prompt to be used in the model. If no prompt is given,
+ `self.default_prompt_` is used instead.
+ NOTE: Use `"[DOCUMENT]"` in the prompt
+ to decide where the document needs to be inserted
+ delay_in_seconds: The delay in seconds between consecutive prompts
+ in order to prevent RateLimitErrors.
+ verbose: Set this to True if you want to see a progress bar for the
+ keyword extraction.
+
+ Usage:
+
+ Let's use OpenAI as an example:
+
+ ```python
+ import os
+ from keybert.llm import LiteLLM
+ from keybert import KeyLLM
+
+ # Select LLM
+ os.environ["OPENAI_API_KEY"] = "sk-..."
+ llm = LiteLLM("gpt-3.5-turbo")
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+
+ You can also use a custom prompt:
+
+ ```python
+ prompt = "I have the following document: [DOCUMENT] \nThis document contains the following keywords separated by commas: '"
+ llm = LiteLLM("gpt-3.5-turbo", prompt=prompt)
+ ```
+ """
+ def __init__(self,
+ model: str = "gpt-3.5-turbo",
+ prompt: str = None,
+ generator_kwargs: Mapping[str, Any] = {},
+ delay_in_seconds: float = None,
+ verbose: bool = False
+ ):
+ self.model = model
+
+ if prompt is None:
+ self.prompt = DEFAULT_PROMPT
+ else:
+ self.prompt = prompt
+
+ self.default_prompt_ = DEFAULT_PROMPT
+ self.delay_in_seconds = delay_in_seconds
+ self.verbose = verbose
+
+ self.generator_kwargs = generator_kwargs
+ if self.generator_kwargs.get("model"):
+ self.model = generator_kwargs.get("model")
+ if self.generator_kwargs.get("prompt"):
+ del self.generator_kwargs["prompt"]
+
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ all_keywords = []
+ candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+
+ for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+ prompt = self.prompt.replace("[DOCUMENT]", document)
+ if candidates is not None:
+ prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+
+ # Delay
+ if self.delay_in_seconds:
+ time.sleep(self.delay_in_seconds)
+
+ # Use a chat model
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ]
+ kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
+
+ response = completion(**kwargs)
+ keywords = response["choices"][0]["message"]["content"].strip()
+ keywords = [keyword.strip() for keyword in keywords.split(",")]
+ all_keywords.append(keywords)
+
+ return all_keywords
diff --git a/keybert/llm/_openai.py b/keybert/llm/_openai.py
new file mode 100644
index 00000000..178abee6
--- /dev/null
+++ b/keybert/llm/_openai.py
@@ -0,0 +1,210 @@
+import time
+import openai
+from tqdm import tqdm
+from typing import Mapping, Any, List
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import retry_with_exponential_backoff, process_candidate_keywords
+
+
+DEFAULT_PROMPT = """
+The following is a list of documents. Please extract the top keywords, separated by a comma, that describe the topic of the texts.
+
+Document:
+- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
+
+Keywords: Traditional diets, Plant-based, Meat, Industrial style meat production, Factory farming, Staple food, Cultural dietary practices
+
+Document:
+- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
+
+Keywords: Website, Delivery, Mention, Timeframe, Not received, Waiting, Order fulfillment
+
+Document:
+- [DOCUMENT]
+
+Keywords:"""
+
+DEFAULT_CHAT_PROMPT = """
+I have the following document:
+[DOCUMENT]
+
+Based on the information above, extract the keywords that best describe the topic of the text.
+Use the following format separated by commas:
+
+"""
+
+
+class OpenAI(BaseLLM):
+ """ Using the OpenAI API to extract keywords
+
+ The default method is `openai.Completion` if `chat=False`.
+ The prompts will also need to follow a completion task. If you
+ are looking for a more interactive chats, use `chat=True`
+ with `model=gpt-3.5-turbo`.
+
+ For an overview see:
+ https://platform.openai.com/docs/models
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ model: Model to use within OpenAI, defaults to `"text-ada-001"`.
+ NOTE: If a `gpt-3.5-turbo` model is used, make sure to set
+ `chat` to True.
+ generator_kwargs: Kwargs passed to `openai.Completion.create`
+ for fine-tuning the output.
+ prompt: The prompt to be used in the model. If no prompt is given,
+ `self.default_prompt_` is used instead.
+ NOTE: Use `"[DOCUMENT]"` in the prompt
+ to decide where the document needs to be inserted
+ delay_in_seconds: The delay in seconds between consecutive prompts
+ in order to prevent RateLimitErrors.
+ exponential_backoff: Retry requests with a random exponential backoff.
+ A short sleep is used when a rate limit error is hit,
+ then the requests is retried. Increase the sleep length
+ if errors are hit until 10 unsuccesfull requests.
+ If True, overrides `delay_in_seconds`.
+ chat: Set this to True if a chat model is used. Generally, this GPT 3.5 or higher
+ See: https://platform.openai.com/docs/models/gpt-3-5
+ verbose: Set this to True if you want to see a progress bar for the
+ keyword extraction.
+
+ Usage:
+
+ To use this, you will need to install the openai package first:
+
+ `pip install openai`
+
+ Then, get yourself an API key and use OpenAI's API as follows:
+
+ ```python
+ import openai
+ from keybert.llm import OpenAI
+ from keybert import KeyLLM
+
+ # Create your LLM
+ openai.api_key = "sk-..."
+ llm = OpenAI()
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+
+ You can also use a custom prompt:
+
+ ```python
+ prompt = "I have the following document: [DOCUMENT] \nThis document contains the following keywords separated by commas: '"
+ llm = OpenAI(prompt=prompt, delay_in_seconds=5)
+ ```
+
+ If you want to use OpenAI's ChatGPT model:
+
+ ```python
+ llm = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=10, chat=True)
+ ```
+ """
+ def __init__(self,
+ model: str = "gpt-3.5-turbo-instruct",
+ prompt: str = None,
+ generator_kwargs: Mapping[str, Any] = {},
+ delay_in_seconds: float = None,
+ exponential_backoff: bool = False,
+ chat: bool = False,
+ verbose: bool = False
+ ):
+ self.model = model
+
+ if prompt is None:
+ self.prompt = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
+ else:
+ self.prompt = prompt
+
+ self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
+ self.delay_in_seconds = delay_in_seconds
+ self.exponential_backoff = exponential_backoff
+ self.chat = chat
+ self.verbose = verbose
+
+ self.generator_kwargs = generator_kwargs
+ if self.generator_kwargs.get("model"):
+ self.model = generator_kwargs.get("model")
+ if self.generator_kwargs.get("prompt"):
+ del self.generator_kwargs["prompt"]
+ if not self.generator_kwargs.get("stop") and not chat:
+ self.generator_kwargs["stop"] = "\n"
+
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ all_keywords = []
+ candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+
+ for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+ prompt = self.prompt.replace("[DOCUMENT]", document)
+ if candidates is not None:
+ prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+
+ # Delay
+ if self.delay_in_seconds:
+ time.sleep(self.delay_in_seconds)
+
+ # Use a chat model
+ if self.chat:
+ messages = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ]
+ kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
+ if self.exponential_backoff:
+ response = chat_completions_with_backoff(**kwargs)
+ else:
+ response = openai.ChatCompletion.create(**kwargs)
+ keywords = response["choices"][0]["message"]["content"].strip()
+
+ # Use a non-chat model
+ else:
+ if self.exponential_backoff:
+ response = completions_with_backoff(model=self.model, prompt=prompt, **self.generator_kwargs)
+ else:
+ response = openai.Completion.create(model=self.model, prompt=prompt, **self.generator_kwargs)
+ keywords = response["choices"][0]["text"].strip()
+ keywords = [keyword.strip() for keyword in keywords.split(",")]
+ all_keywords.append(keywords)
+
+ return all_keywords
+
+
+def completions_with_backoff(**kwargs):
+ return retry_with_exponential_backoff(
+ openai.Completion.create,
+ errors=(
+ openai.error.RateLimitError,
+ openai.error.ServiceUnavailableError,
+ ),
+ )(**kwargs)
+
+
+def chat_completions_with_backoff(**kwargs):
+ return retry_with_exponential_backoff(
+ openai.ChatCompletion.create,
+ errors=(
+ openai.error.RateLimitError,
+ openai.error.ServiceUnavailableError,
+ ),
+ )(**kwargs)
diff --git a/keybert/llm/_textgeneration.py b/keybert/llm/_textgeneration.py
new file mode 100644
index 00000000..f505caf4
--- /dev/null
+++ b/keybert/llm/_textgeneration.py
@@ -0,0 +1,120 @@
+from tqdm import tqdm
+from transformers import pipeline, set_seed
+from transformers.pipelines.base import Pipeline
+from typing import Mapping, List, Any, Union
+from keybert.llm._base import BaseLLM
+from keybert.llm._utils import process_candidate_keywords
+
+
+DEFAULT_PROMPT = """
+I have the following document:
+* [DOCUMENT]
+
+Please give me the keywords that are present in this document and separate them with commas:
+"""
+
+
+class TextGeneration(BaseLLM):
+ """ Text2Text or text generation with transformers
+
+ NOTE: The resulting keywords are expected to be separated by commas so
+ any changes to the prompt will have to make sure that the resulting
+ keywords are comma-separated.
+
+ Arguments:
+ model: A transformers pipeline that should be initialized as "text-generation"
+ for gpt-like models or "text2text-generation" for T5-like models.
+ For example, `pipeline('text-generation', model='gpt2')`. If a string
+ is passed, "text-generation" will be selected by default.
+ prompt: The prompt to be used in the model. If no prompt is given,
+ `self.default_prompt_` is used instead.
+ NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
+ to decide where the keywords and documents need to be
+ inserted.
+ pipeline_kwargs: Kwargs that you can pass to the transformers.pipeline
+ when it is called.
+ random_state: A random state to be passed to `transformers.set_seed`
+ verbose: Set this to True if you want to see a progress bar for the
+ keyword extraction.
+
+ Usage:
+
+ To use a gpt-like model:
+
+ ```python
+ from keybert.llm import TextGeneration
+ from keybert import KeyLLM
+
+ # Create your LLM
+ generator = pipeline('text-generation', model='gpt2')
+ llm = TextGeneration(generator)
+
+ # Load it in KeyLLM
+ kw_model = KeyLLM(llm)
+
+ # Extract keywords
+ document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
+ keywords = kw_model.extract_keywords(document)
+ ```
+
+ You can use a custom prompt and decide where the document should
+ be inserted with the `[DOCUMENT]` tag:
+
+ ```python
+ from keybert.llm import TextGeneration
+
+ prompt = "I have the following documents '[DOCUMENT]'. Please give me the keywords that are present in this document and separate them with commas:"
+
+ # Create your representation model
+ generator = pipeline('text2text-generation', model='google/flan-t5-base')
+ llm = TextGeneration(generator)
+ ```
+ """
+ def __init__(self,
+ model: Union[str, pipeline],
+ prompt: str = None,
+ pipeline_kwargs: Mapping[str, Any] = {},
+ random_state: int = 42,
+ verbose: bool = False
+ ):
+ set_seed(random_state)
+ if isinstance(model, str):
+ self.model = pipeline("text-generation", model=model)
+ elif isinstance(model, Pipeline):
+ self.model = model
+ else:
+ raise ValueError("Make sure that the HF model that you"
+ "pass is either a string referring to a"
+ "HF model or a `transformers.pipeline` object.")
+ self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
+ self.default_prompt_ = DEFAULT_PROMPT
+ self.pipeline_kwargs = pipeline_kwargs
+ self.verbose = verbose
+
+ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+ """ Extract topics
+
+ Arguments:
+ documents: The documents to extract keywords from
+ candidate_keywords: A list of candidate keywords that the LLM will fine-tune
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
+ shorten them depending on the input prompt.
+
+ Returns:
+ all_keywords: All keywords for each document
+ """
+ all_keywords = []
+ candidate_keywords = process_candidate_keywords(documents, candidate_keywords)
+
+ for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
+ prompt = self.prompt.replace("[DOCUMENT]", document)
+ if candidates is not None:
+ prompt = prompt.replace("[CANDIDATES]", ", ".join(candidates))
+
+ # Extract result from generator and use that as label
+ keywords = self.model(prompt, **self.pipeline_kwargs)[0]["generated_text"].replace(prompt, "")
+ keywords = [keyword.strip() for keyword in keywords.split(",")]
+ all_keywords.append(keywords)
+
+ return all_keywords
diff --git a/keybert/llm/_utils.py b/keybert/llm/_utils.py
new file mode 100644
index 00000000..6ca8bdd3
--- /dev/null
+++ b/keybert/llm/_utils.py
@@ -0,0 +1,57 @@
+import random
+import time
+
+
+def process_candidate_keywords(documents, candidate_keywords):
+ """Create a common format for candidate keywords."""
+ if candidate_keywords is None:
+ candidate_keywords = [None for _ in documents]
+ elif isinstance(candidate_keywords[0][0], str) and not isinstance(candidate_keywords[0], list):
+ candidate_keywords = [[keyword for keyword, _ in candidate_keywords]]
+ elif isinstance(candidate_keywords[0][0], tuple):
+ candidate_keywords = [[keyword for keyword, _ in keywords] for keywords in candidate_keywords]
+ return candidate_keywords
+
+
+def retry_with_exponential_backoff(
+ func,
+ initial_delay: float = 1,
+ exponential_base: float = 2,
+ jitter: bool = True,
+ max_retries: int = 10,
+ errors: tuple = None,
+):
+ """Retry a function with exponential backoff."""
+
+ def wrapper(*args, **kwargs):
+ # Initialize variables
+ num_retries = 0
+ delay = initial_delay
+
+ # Loop until a successful response or max_retries is hit or an exception is raised
+ while True:
+ try:
+ return func(*args, **kwargs)
+
+ # Retry on specific errors
+ except errors as e:
+ # Increment retries
+ num_retries += 1
+
+ # Check if max retries has been reached
+ if num_retries > max_retries:
+ raise Exception(
+ f"Maximum number of retries ({max_retries}) exceeded."
+ )
+
+ # Increment the delay
+ delay *= exponential_base * (1 + jitter * random.random())
+
+ # Sleep for the delay
+ time.sleep(delay)
+
+ # Raise exceptions for any errors not specified
+ except Exception as e:
+ raise e
+
+ return wrapper
diff --git a/mkdocs.yml b/mkdocs.yml
index 9135d529..029418ec 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -14,10 +14,19 @@ nav:
- Quickstart: guides/quickstart.md
- Embedding Models: guides/embeddings.md
- CountVectorizer: guides/countvectorizer.md
+ - KeyLLM: guides/keyllm.md
+ - LLMs: guides/llms.md
- API:
- KeyBERT: api/keybert.md
- MMR: api/mmr.md
- MaxSum: api/maxsum.md
+ - KeyLLM: api/keyllm.md
+ - LLM:
+ - OpenAI: api/openai.md
+ - Cohere: api/cohere.md
+ - LangChain: api/langchain.md
+ - TextGeneration: api/textgeneration.md
+ - LiteLLM: api/litellm.md
- FAQ: faq.md
- Changelog: changelog.md
diff --git a/setup.py b/setup.py
index 3b6789e0..c1478967 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@
setup(
name="keybert",
packages=find_packages(exclude=["notebooks", "docs"]),
- version="0.7.0",
+ version="0.8.0",
author="Maarten Grootendorst",
author_email="maartengrootendorst@gmail.com",
description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",