diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml
index 2f3ffbbc..40c1ee7c 100644
--- a/.github/workflows/testing.yml
+++ b/.github/workflows/testing.yml
@@ -11,6 +11,14 @@ on:
- dev
jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ # Ref: https://github.com/pre-commit/action
+ - uses: pre-commit/action@v3.0.1
+
build:
runs-on: ubuntu-latest
strategy:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0b0391e4..759fa5a6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -11,15 +11,9 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
-- repo: https://github.com/PyCQA/flake8
- rev: 7.1.0
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.5.1
hooks:
- - id: flake8
-- repo: https://github.com/psf/black
- rev: 24.4.2
- hooks:
- - id: black
- exclude: |
- (?x)^(
- README.md
- )$
+ - id: ruff
+ args: [--fix, --show-fixes, --exit-non-zero-on-fix]
+ - id: ruff-format
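For context on the swap above: the `ruff` hook with `--fix` auto-applies safe fixes for the selected rule families (configured under `[tool.ruff.lint]` in `pyproject.toml` later in this diff), while `ruff-format` takes over the role `black` played. A hedged before/after sketch of a typical auto-fix (rule F401, unused import):

```python
# Before the hook runs, F401 flags `os` as imported but unused:
import os
import numpy as np

print(np.zeros(3))
# After `ruff --fix`, the `import os` line is removed; `--show-fixes`
# prints the applied fixes and `--exit-non-zero-on-fix` fails the hook
# so the change is surfaced rather than silently committed.
```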
diff --git a/docs/changelog.md b/docs/changelog.md
index a2cd9909..d40fc2c2 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -59,7 +59,7 @@ kw_model = KeyLLM(llm)
* Use `KeyLLM` to leverage LLMs for extracting keywords
* Use it either with or without candidate keywords generated through `KeyBERT`
- * Multiple LLMs are integrated: OpenAI, Cohere, LangChain, HF, and LiteLLM
+ * Multiple LLMs are integrated: OpenAI, Cohere, LangChain, HF, and LiteLLM
```python
import openai
@@ -101,7 +101,7 @@ doc_embeddings, word_embeddings = kw_model.extract_embeddings(docs)
keywords = kw_model.extract_keywords(docs, doc_embeddings=doc_embeddings, word_embeddings=word_embeddings)
```
-Do note that the parameters passed to `.extract_embeddings` for creating the vectorizer should be exactly the same as those in `.extract_keywords`.
+Do note that the parameters passed to `.extract_embeddings` for creating the vectorizer should be exactly the same as those in `.extract_keywords`.
**Fixes**:
@@ -137,7 +137,7 @@ kw_model = KeyBERT(model=hf_model)
**NOTE**: Although highlighting for Chinese texts is improved, since I am not familiar with the Chinese language there is a good chance it is not yet as optimized as for other languages. Any feedback with respect to this is highly appreciated!
-**Fixes**:
+**Fixes**:
* Fix typo in ReadMe by [@priyanshul-govil](https://github.com/priyanshul-govil) in [#117](https://github.com/MaartenGr/KeyBERT/pull/117)
* Add missing optional dependencies (gensim, use, and spacy) by [@yusuke1997](https://github.com/yusuke1997)
diff --git a/docs/faq.md b/docs/faq.md
index 04814585..c6a26e70 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -21,11 +21,11 @@ topic modeling to HTML-code to extract topics of code, then it becomes important
## **How can I speed up the model?**
-Since KeyBERT uses large language models as its backend, a GPU is typically prefered when using this package.
+Since KeyBERT uses large language models as its backend, a GPU is typically preferred when using this package.
Although it is possible to use it without a dedicated GPU, the inference speed will be significantly slower.
-A second method for speeding up KeyBERT is by passing it multiple documents at once. By doing this, words
-need to only be embedded a single time, which can result in a major speed up.
+A second method for speeding up KeyBERT is by passing it multiple documents at once. By doing this, words
+only need to be embedded a single time, which can result in a major speed-up.
This is **faster**:
diff --git a/docs/guides/embeddings.md b/docs/guides/embeddings.md
index fc1728de..b2a155d5 100644
--- a/docs/guides/embeddings.md
+++ b/docs/guides/embeddings.md
@@ -21,7 +21,7 @@ kw_model = KeyBERT(model=sentence_model)
```
### 🤗 **Hugging Face Transformers**
-To use a Hugging Face transformers model, load in a pipeline and point
+To use a Hugging Face transformers model, load in a pipeline and point
to any model found on their model hub (https://huggingface.co/models):
```python
@@ -32,8 +32,8 @@ kw_model = KeyBERT(model=hf_model)
```
!!! tip "Tip!"
- These transformers also work quite well using `sentence-transformers` which has a number of
- optimizations tricks that make using it a bit faster.
+    These transformers also work quite well using `sentence-transformers`, which has a number of
+    optimization tricks that make using it a bit faster.
### **Flair**
[Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that
diff --git a/docs/guides/keyllm.md b/docs/guides/keyllm.md
index f1913fc4..e0b2f33c 100644
--- a/docs/guides/keyllm.md
+++ b/docs/guides/keyllm.md
@@ -16,7 +16,7 @@ documents = [
This data was chosen to show the different use cases and techniques. As you might have noticed documents 1 and 2 are quite similar whereas document 3 is about an entirely different subject. This similarity will be taken into account when using `KeyBERT` together with `KeyLLM`
-Let's start with `KeyLLM` only.
+Let's start with `KeyLLM` only.
# Use Cases
@@ -180,7 +180,7 @@ If you have embeddings of your documents, you could use those to find documents
!!! Tip
- Before you get started, it might be worthwhile to uninstall sentence-transformers and re-install it from the main branch.
+ Before you get started, it might be worthwhile to uninstall sentence-transformers and re-install it from the main branch.
There is an issue with community detection (cluster) that might make the model run without finishing. It is as straightforward as:
`pip uninstall sentence-transformers`
`pip install --upgrade git+https://github.com/UKPLab/sentence-transformers`
@@ -231,7 +231,7 @@ This is the best of both worlds. We use `KeyBERT` to generate a first pass of ke
!!! Tip
- Before you get started, it might be worthwhile to uninstall sentence-transformers and re-install it from the main branch.
+ Before you get started, it might be worthwhile to uninstall sentence-transformers and re-install it from the main branch.
There is an issue with community detection (cluster) that might make the model run without finishing. It is as straightforward as:
`pip uninstall sentence-transformers`
`pip install --upgrade git+https://github.com/UKPLab/sentence-transformers`
diff --git a/docs/guides/llms.md b/docs/guides/llms.md
index 10e65cab..b7b07f0d 100644
--- a/docs/guides/llms.md
+++ b/docs/guides/llms.md
@@ -3,7 +3,7 @@ In this tutorial we will be going through the Large Language Models (LLM) that c
Having the option to choose the LLM allow you to leverage the model that suit your use-case.
### **OpenAI**
-To use OpenAI's external API, we need to define our key and use the `keybert.llm.OpenAI` model.
+To use OpenAI's external API, we need to define our key and use the `keybert.llm.OpenAI` model.
We install the package first:
@@ -98,7 +98,7 @@ kw_model = KeyLLM(llm)
```
### 🤗 **Hugging Face Transformers**
-To use a Hugging Face transformers model, load in a pipeline and point
+To use a Hugging Face transformers model, load in a pipeline and point
to any model found on their model hub (https://huggingface.co/models). Let's use Llama 2 as an example:
```python
@@ -109,8 +109,8 @@ model_id = 'meta-llama/Llama-2-7b-chat-hf'
# 4-bit Quantization to load Llama 2 with less GPU memory
bnb_config = transformers.BitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type='nf4',
+ load_in_4bit=True,
+ bnb_4bit_quant_type='nf4',
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=bfloat16
)
@@ -152,7 +152,7 @@ I have the following document:
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
Please give me the keywords that are present in this document and separate them with commas.
-Make sure you to only return the keywords and say nothing else. For example, don't say:
+Make sure to only return the keywords and say nothing else. For example, don't say:
"Here are the keywords present in the document"
[/INST] meat, beef, eat, eating, emissions, steak, food, health, processed, chicken [INST]
@@ -160,7 +160,7 @@ I have the following document:
- [DOCUMENT]
Please give me the keywords that are present in this document and separate them with commas.
-Make sure you to only return the keywords and say nothing else. For example, don't say:
+Make sure to only return the keywords and say nothing else. For example, don't say:
"Here are the keywords present in the document"
[/INST]
"""
@@ -200,4 +200,4 @@ llm = LangChain(chain)
# Load it in KeyLLM
kw_model = KeyLLM(llm)
-```
\ No newline at end of file
+```
diff --git a/docs/guides/quickstart.md b/docs/guides/quickstart.md
index d6e2871e..615fdbf4 100644
--- a/docs/guides/quickstart.md
+++ b/docs/guides/quickstart.md
@@ -78,9 +78,9 @@ keywords = kw_model.extract_keywords(doc, highlight=True)
## **Fine-tuning**
-As a default, KeyBERT simply compares the documents and candidate keywords/keyphrases based on their cosine similarity. However, this might lead
-to very similar words ending up in the list of most accurate keywords/keyphrases. To make sure they are a bit more diversified, there are two
-approaches that we can take in order to fine-tune our output, **Max Sum Distance** and **Maximal Marginal Relevance**.
+By default, KeyBERT simply compares the documents and candidate keywords/keyphrases based on their cosine similarity. However, this might lead
+to very similar words ending up in the list of most accurate keywords/keyphrases. To make sure they are a bit more diversified, there are two
+approaches we can take to fine-tune our output: **Max Sum Distance** and **Maximal Marginal Relevance**.
### **Max Sum Distance**
@@ -165,8 +165,8 @@ keywords = kw_model.extract_keywords(doc, seed_keywords=seed_keywords)
## **Prepare embeddings**
-When you have a large dataset and you want to fine-tune parameters such as `diversity` it can take quite a while to re-calculate the document and
-word embeddings each time you change a parameter. Instead, we can pre-calculate these embeddings and pass them to `.extract_keywords` such that
+When you have a large dataset and you want to fine-tune parameters such as `diversity`, it can take quite a while to re-calculate the document and
+word embeddings each time you change a parameter. Instead, we can pre-calculate these embeddings and pass them to `.extract_keywords` such that
we only have to calculate it once:
@@ -183,15 +183,15 @@ You can then use these embeddings and pass them to `.extract_keywords` to speed
keywords = kw_model.extract_keywords(docs, doc_embeddings=doc_embeddings, word_embeddings=word_embeddings)
```
-There are several parameters in `.extract_embeddings` that define how the list of candidate keywords/keyphrases is generated:
+There are several parameters in `.extract_embeddings` that define how the list of candidate keywords/keyphrases is generated:
* `candidates`
* `keyphrase_ngram_range`
-* `stop_words`
+* `stop_words`
* `min_df`
* `vectorizer`
-The values of these parameters need to be exactly the same in `.extract_embeddings` as they are in `. extract_keywords`.
+The values of these parameters need to be exactly the same in `.extract_embeddings` as they are in `.extract_keywords`.
In other words, the following will work as they use the same parameter subset:
@@ -200,8 +200,8 @@ from keybert import KeyBERT
kw_model = KeyBERT()
doc_embeddings, word_embeddings = kw_model.extract_embeddings(docs, min_df=1, stop_words="english")
-keywords = kw_model.extract_keywords(docs, min_df=1, stop_words="english",
- doc_embeddings=doc_embeddings,
+keywords = kw_model.extract_keywords(docs, min_df=1, stop_words="english",
+ doc_embeddings=doc_embeddings,
word_embeddings=word_embeddings)
```
@@ -212,7 +212,7 @@ from keybert import KeyBERT
kw_model = KeyBERT()
doc_embeddings, word_embeddings = kw_model.extract_embeddings(docs, min_df=3, stop_words="dutch")
-keywords = kw_model.extract_keywords(docs, min_df=1, stop_words="english",
- doc_embeddings=doc_embeddings,
+keywords = kw_model.extract_keywords(docs, min_df=1, stop_words="english",
+ doc_embeddings=doc_embeddings,
word_embeddings=word_embeddings)
```
diff --git a/docs/images/guided.svg b/docs/images/guided.svg
index ee6e65b0..64757a87 100644
--- a/docs/images/guided.svg
+++ b/docs/images/guided.svg
@@ -1,6 +1,6 @@
\ No newline at end of file
+ Input DocumentTokenize WordsEmbed TokensExtract EmbeddingsAverage seed keyword and document embeddingsCalculateCosine SimilarityMost microbats use echolocationto navigate and find food.Most microbats...sonarmostmicrobatsuse echolocationtonavigate andfindfood0.110.550.320.28................0.720.960.490.34mostfoodMost microbats...mostfood.......08.73We use the CountVectorizer from Scikit-Learn to tokenize our document into candidate kewords/keyphrases.We embed the seeded keywords (e.g., the word “sonar”) and calculate a weighted average with the document embedding (1:3). We calculate the cosine similarity between all candidate keywords and the input document. The keywords that have the largest similarity to the document are extracted.
diff --git a/docs/images/pipeline.svg b/docs/images/pipeline.svg
index b93e4241..12edc447 100644
--- a/docs/images/pipeline.svg
+++ b/docs/images/pipeline.svg
@@ -1,6 +1,6 @@
\ No newline at end of file
+ Input DocumentTokenize WordsEmbed TokensExtract EmbeddingsEmbed DocumentCalculateCosine SimilarityMost microbats use echolocationto navigate and find food.Most microbats use echolocationto navigate and find food.mostmicrobatsuse echolocationtonavigate andfindfood0.110.550.28............0.720.960.34mostfoodMost microbats...mostfood.......08.73We use the CountVectorizer from Scikit-Learn to tokenize our document into candidate kewords/keyphrases.We can use any language model that can embed both documents and keywords, like sentence-transformers.We calculate the cosine similarity between all candidate keywords and the input document. The keywords that have the largest similarity to the document are extracted.
diff --git a/docs/index.md b/docs/index.md
index 0af611bc..b81c3342 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -99,4 +99,4 @@ of words you would like in the resulting keyphrases:
```
!!! note "NOTE"
- You can also pass multiple documents at once if you are looking for a major speed-up!
\ No newline at end of file
+ You can also pass multiple documents at once if you are looking for a major speed-up!
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
index 54661441..d038f2ce 100644
--- a/docs/stylesheets/extra.css
+++ b/docs/stylesheets/extra.css
@@ -6,7 +6,7 @@
--md-typeset-a-color: #0277BD;
}
-body[data-md-color-primary="black"] .excalidraw svg {
+body[data-md-color-primary="black"] .excalidraw svg {
filter: invert(100%) hue-rotate(180deg);
}
diff --git a/keybert/__init__.py b/keybert/__init__.py
index 01fa8378..d3ce0bfe 100644
--- a/keybert/__init__.py
+++ b/keybert/__init__.py
@@ -4,3 +4,8 @@
from keybert._model import KeyBERT
__version__ = version("keybert")
+
+__all__ = [
+ "KeyBERT",
+ "KeyLLM",
+]
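With `__all__` made explicit above, the package's public import surface is just these two classes. A minimal sketch (assuming the default dependencies are installed):

```python
from keybert import KeyBERT, KeyLLM

kw_model = KeyBERT()  # embedding-based extraction; KeyLLM wraps an LLM instead
```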
diff --git a/keybert/_highlight.py b/keybert/_highlight.py
index 7100b297..ae7bbbaf 100644
--- a/keybert/_highlight.py
+++ b/keybert/_highlight.py
@@ -11,10 +11,8 @@ class NullHighlighter(RegexHighlighter):
highlights = [r""]
-def highlight_document(
- doc: str, keywords: List[Tuple[str, float]], vectorizer: CountVectorizer
-):
- """Highlight keywords in a document
+def highlight_document(doc: str, keywords: List[Tuple[str, float]], vectorizer: CountVectorizer):
+ """Highlight keywords in a document.
Arguments:
doc: The document for which to extract keywords/keyphrases.
@@ -38,10 +36,8 @@ def highlight_document(
console.print(highlighted_text)
-def _highlight_one_gram(
- doc: str, keywords: List[str], vectorizer: CountVectorizer
-) -> str:
- """Highlight 1-gram keywords in a document
+def _highlight_one_gram(doc: str, keywords: List[str], vectorizer: CountVectorizer) -> str:
+ """Highlight 1-gram keywords in a document.
Arguments:
doc: The document for which to extract keywords/keyphrases.
@@ -57,18 +53,13 @@ def _highlight_one_gram(
separator = "" if "zh" in str(tokenizer) else " "
highlighted_text = separator.join(
- [
- f"[black on #FFFF00]{token}[/]" if token.lower() in keywords else f"{token}"
- for token in tokens
- ]
+ [f"[black on #FFFF00]{token}[/]" if token.lower() in keywords else f"{token}" for token in tokens]
).strip()
return highlighted_text
-def _highlight_n_gram(
- doc: str, keywords: List[str], vectorizer: CountVectorizer
-) -> str:
- """Highlight n-gram keywords in a document
+def _highlight_n_gram(doc: str, keywords: List[str], vectorizer: CountVectorizer) -> str:
+ """Highlight n-gram keywords in a document.
Arguments:
doc: The document for which to extract keywords/keyphrases.
@@ -85,8 +76,7 @@ def _highlight_n_gram(
separator = "" if "zh" in str(tokenizer) else " "
n_gram_tokens = [
- [separator.join(tokens[i : i + max_len][0 : j + 1]) for j in range(max_len)]
- for i, _ in enumerate(tokens)
+ [separator.join(tokens[i : i + max_len][0 : j + 1]) for j in range(max_len)] for i, _ in enumerate(tokens)
]
highlighted_text = []
skip = False
@@ -96,11 +86,8 @@ def _highlight_n_gram(
if not skip:
for index, n_gram in enumerate(n_grams):
-
if n_gram.lower() in keywords:
- candidate = (
- f"[black on #FFFF00]{n_gram}[/]" + n_grams[-1].split(n_gram)[-1]
- )
+ candidate = f"[black on #FFFF00]{n_gram}[/]" + n_grams[-1].split(n_gram)[-1]
skip = index + 1
if not candidate:
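As a usage note for the highlighting helpers reformatted above (a hedged sketch, assuming the default sentence-transformers backend): passing `highlight=True` to `extract_keywords` routes through `highlight_document`, which prints the document with matches highlighted via `rich`:

```python
from keybert import KeyBERT

doc = "Most microbats use echolocation to navigate and find food."
kw_model = KeyBERT()
# Prints the document with the extracted keywords highlighted in-place
keywords = kw_model.extract_keywords(doc, highlight=True)
```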
diff --git a/keybert/_llm.py b/keybert/_llm.py
index f3b04fab..8c65958f 100644
--- a/keybert/_llm.py
+++ b/keybert/_llm.py
@@ -2,21 +2,21 @@
try:
from sentence_transformers import util
+
HAS_SBERT = True
except ModuleNotFoundError:
HAS_SBERT = False
class KeyLLM:
- """
- A minimal method for keyword extraction with Large Language Models (LLM)
+ """A minimal method for keyword extraction with Large Language Models (LLM).
The keyword extraction is done by simply asking the LLM to extract a
number of keywords from a single piece of text.
"""
def __init__(self, llm):
- """KeyBERT initialization
+ """KeyBERT initialization.
Arguments:
llm: The Large Language Model to use
@@ -29,9 +29,9 @@ def extract_keywords(
check_vocab: bool = False,
candidate_keywords: List[List[str]] = None,
threshold: float = None,
- embeddings=None
+ embeddings=None,
) -> Union[List[str], List[List[str]]]:
- """Extract keywords and/or keyphrases
+ """Extract keywords and/or keyphrases.
To get the biggest speed-up, make sure to pass multiple documents
at once instead of iterating over a single document.
@@ -44,6 +44,8 @@ def extract_keywords(
docs: The document(s) for which to extract keywords/keyphrases
check_vocab: Only return keywords that appear exactly in the documents
candidate_keywords: Candidate keywords for each document
+            threshold: Minimum similarity value between 0 and 1 used to decide how similar documents need to be to receive the same keywords.
+ embeddings: The embeddings of each document.
Returns:
keywords: The top n keywords for a document with their respective distances
@@ -78,7 +80,6 @@ def extract_keywords(
return []
if HAS_SBERT and threshold is not None and embeddings is not None:
-
# Find similar documents
clusters = util.community_detection(embeddings, min_community_size=2, threshold=threshold)
in_cluster = set([cluster for cluster_set in clusters for cluster in cluster_set])
@@ -97,21 +98,16 @@ def extract_keywords(
)
out_cluster_keywords = {index: words for words, index in zip(out_cluster_keywords, out_cluster)}
- # Extract keywords for only the first document in a cluster
+ # Extract keywords for only the first document in a cluster
if in_cluster:
selected_docs = [docs[cluster[0]] for cluster in clusters]
if candidate_keywords is not None:
selected_keywords = [candidate_keywords[cluster[0]] for cluster in clusters]
else:
selected_keywords = None
- in_cluster_keywords = self.llm.extract_keywords(
- selected_docs,
- selected_keywords
- )
+ in_cluster_keywords = self.llm.extract_keywords(selected_docs, selected_keywords)
in_cluster_keywords = {
- doc_id: in_cluster_keywords[index]
- for index, cluster in enumerate(clusters)
- for doc_id in cluster
+ doc_id: in_cluster_keywords[index] for index, cluster in enumerate(clusters) for doc_id in cluster
}
# Update out cluster keywords with in cluster keywords
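A sketch of the clustering path cleaned up above, under stated assumptions (sentence-transformers installed, an OpenAI key available; the key below is a placeholder): documents whose embeddings clear `threshold` are grouped by `util.community_detection`, the LLM is called once per cluster, and that cluster's keywords are shared across its members:

```python
import openai
from sentence_transformers import SentenceTransformer
from keybert import KeyLLM
from keybert.llm import OpenAI

docs = [
    "The website mentions that it only takes a couple of days to deliver.",
    "The site said delivery would only take a couple of days.",
]
embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(docs, convert_to_tensor=True)

client = openai.OpenAI(api_key="sk-...")  # placeholder key
kw_model = KeyLLM(OpenAI(client))
# Near-duplicate documents receive identical keywords from a single LLM call
keywords = kw_model.extract_keywords(docs, embeddings=embeddings, threshold=0.75)
```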
diff --git a/keybert/_maxsum.py b/keybert/_maxsum.py
index 85de980a..02c0d97f 100644
--- a/keybert/_maxsum.py
+++ b/keybert/_maxsum.py
@@ -11,7 +11,7 @@ def max_sum_distance(
top_n: int,
nr_candidates: int,
) -> List[Tuple[str, float]]:
- """Calculate Max Sum Distance for extraction of keywords
+ """Calculate Max Sum Distance for extraction of keywords.
We take the 2 x top_n most similar words/phrases to the document.
Then, we take all top_n combinations from the 2 x top_n words and
@@ -31,10 +31,7 @@ def max_sum_distance(
List[Tuple[str, float]]: The selected keywords/keyphrases with their distances
"""
if nr_candidates < top_n:
- raise Exception(
- "Make sure that the number of candidates exceeds the number "
- "of keywords to return."
- )
+ raise Exception("Make sure that the number of candidates exceeds the number " "of keywords to return.")
elif top_n > len(words):
return []
@@ -51,14 +48,9 @@ def max_sum_distance(
min_sim = 100_000
candidate = None
for combination in itertools.combinations(range(len(words_idx)), top_n):
- sim = sum(
- [candidates[i][j] for i in combination for j in combination if i != j]
- )
+ sim = sum([candidates[i][j] for i in combination for j in combination if i != j])
if sim < min_sim:
candidate = combination
min_sim = sim
- return [
- (words_vals[idx], round(float(distances[0][words_idx[idx]]), 4))
- for idx in candidate
- ]
+ return [(words_vals[idx], round(float(distances[0][words_idx[idx]]), 4)) for idx in candidate]
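For reference, the public entry point to this function is `KeyBERT.extract_keywords` with `use_maxsum=True`; a hedged usage sketch with the candidate-pool size made explicit:

```python
from keybert import KeyBERT

doc = "Most microbats use echolocation to navigate and find food."
kw_model = KeyBERT()
# nr_candidates must exceed top_n, otherwise the Exception above is raised
keywords = kw_model.extract_keywords(doc, use_maxsum=True, nr_candidates=20, top_n=5)
```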
diff --git a/keybert/_mmr.py b/keybert/_mmr.py
index ae6b22eb..5d5791b0 100644
--- a/keybert/_mmr.py
+++ b/keybert/_mmr.py
@@ -33,7 +33,6 @@ def mmr(
List[Tuple[str, float]]: The selected keywords/keyphrases with their distances
"""
-
# Extract similarity within words, and between words and the document
word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding)
word_similarity = cosine_similarity(word_embeddings)
@@ -46,14 +45,10 @@ def mmr(
# Extract similarities within candidates and
# between candidates and selected keywords/phrases
candidate_similarities = word_doc_similarity[candidates_idx, :]
- target_similarities = np.max(
- word_similarity[candidates_idx][:, keywords_idx], axis=1
- )
+ target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1)
# Calculate MMR
- mmr = (
- 1 - diversity
- ) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
+ mmr = (1 - diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1)
mmr_idx = candidates_idx[np.argmax(mmr)]
# Update keywords & candidates
@@ -61,9 +56,6 @@ def mmr(
candidates_idx.remove(mmr_idx)
# Extract and sort keywords in descending similarity
- keywords = [
- (words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4))
- for idx in keywords_idx
- ]
+ keywords = [(words[idx], round(float(word_doc_similarity.reshape(1, -1)[0][idx]), 4)) for idx in keywords_idx]
keywords = sorted(keywords, key=itemgetter(1), reverse=True)
return keywords
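The reformatted expression above is the standard MMR trade-off, `(1 - diversity) * sim(candidate, doc) - diversity * max_sim(candidate, selected)`. A hedged sketch through the public API:

```python
from keybert import KeyBERT

doc = "Most microbats use echolocation to navigate and find food."
kw_model = KeyBERT()
# diversity in [0, 1]: higher values trade document relevance for variety
keywords = kw_model.extract_keywords(doc, use_mmr=True, diversity=0.7)
```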
diff --git a/keybert/_model.py b/keybert/_model.py
index 4e9990ad..a36e0158 100644
--- a/keybert/_model.py
+++ b/keybert/_model.py
@@ -1,3 +1,5 @@
+# ruff: noqa: E402
+
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
@@ -13,15 +15,13 @@
from keybert._mmr import mmr
from keybert._maxsum import max_sum_distance
from keybert._highlight import highlight_document
-from keybert.backend._base import BaseEmbedder
from keybert.backend._utils import select_backend
from keybert.llm._base import BaseLLM
from keybert import KeyLLM
class KeyBERT:
- """
- A minimal method for keyword extraction with BERT
+ """A minimal method for keyword extraction with BERT.
The keyword extraction is done by finding the sub-phrases in
a document that are the most similar to the document itself.
@@ -44,7 +44,7 @@ def __init__(
model="all-MiniLM-L6-v2",
llm: BaseLLM = None,
):
- """KeyBERT initialization
+ """KeyBERT initialization.
Arguments:
model: Use a custom embedding model or a specific KeyBERT Backend.
@@ -58,6 +58,7 @@ def __init__(
You can also pass in a string that points to one of the following
sentence-transformers models:
* https://www.sbert.net/docs/pretrained_models.html
+ llm: The Large Language Model used to extract keywords
"""
self.model = select_backend(model)
@@ -85,7 +86,7 @@ def extract_keywords(
word_embeddings: np.array = None,
threshold: float = None,
) -> Union[List[Tuple[str, float]], List[List[Tuple[str, float]]]]:
- """Extract keywords and/or keyphrases
+ """Extract keywords and/or keyphrases.
To get the biggest speed-up, make sure to pass multiple documents
at once instead of iterating over a single document.
@@ -126,6 +127,7 @@ def extract_keywords(
NOTE: The `word_embeddings` should be generated through
`.extract_embeddings` as the order of these embeddings depend
on the vectorizer that was used to generate its vocabulary.
+            threshold: Minimum similarity value between 0 and 1 used to decide how similar documents need to be to receive the same keywords.
Returns:
keywords: The top n keywords for a document with their respective distances
@@ -199,26 +201,18 @@ def extract_keywords(
# Guided KeyBERT either local (keywords shared among documents) or global (keywords per document)
if seed_keywords is not None:
if isinstance(seed_keywords[0], str):
- seed_embeddings = self.model.embed(seed_keywords).mean(
- axis=0, keepdims=True
- )
+ seed_embeddings = self.model.embed(seed_keywords).mean(axis=0, keepdims=True)
elif len(docs) != len(seed_keywords):
- raise ValueError(
- "The length of docs must match the length of seed_keywords"
- )
+ raise ValueError("The length of docs must match the length of seed_keywords")
else:
seed_embeddings = np.vstack(
- [
- self.model.embed(keywords).mean(axis=0, keepdims=True)
- for keywords in seed_keywords
- ]
+ [self.model.embed(keywords).mean(axis=0, keepdims=True) for keywords in seed_keywords]
)
doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4
# Find keywords
all_keywords = []
for index, _ in enumerate(docs):
-
try:
# Select embeddings
candidate_indices = df[index].nonzero()[1]
@@ -276,9 +270,7 @@ def extract_keywords(
if isinstance(all_keywords[0], tuple):
candidate_keywords = [[keyword[0] for keyword in all_keywords]]
else:
- candidate_keywords = [
- [keyword[0] for keyword in keywords] for keywords in all_keywords
- ]
+ candidate_keywords = [[keyword[0] for keyword in keywords] for keywords in all_keywords]
keywords = self.llm.extract_keywords(
docs,
embeddings=doc_embeddings,
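A sketch of the guided path reformatted above (assuming the default backend): the seed keyword embeddings are averaged and blended into the document embedding with the 3:1 weighting in `doc_embeddings = (doc_embeddings * 3 + seed_embeddings) / 4`:

```python
from keybert import KeyBERT

doc = "Most microbats use echolocation to navigate and find food."
kw_model = KeyBERT()
# Seed keywords shared across documents nudge extraction toward their theme
keywords = kw_model.extract_keywords(doc, seed_keywords=["sonar"])
```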
diff --git a/keybert/_utils.py b/keybert/_utils.py
index 4ba45741..e4152d9a 100644
--- a/keybert/_utils.py
+++ b/keybert/_utils.py
@@ -1,6 +1,5 @@
class NotInstalled:
- """
- This object is used to notify the user that additional dependencies need to be
+ """This object is used to notify the user that additional dependencies need to be
installed in order to use the string matching model.
"""
@@ -19,4 +18,4 @@ def __getattr__(self, *args, **kwargs):
raise ModuleNotFoundError(self.msg)
def __call__(self, *args, **kwargs):
- raise ModuleNotFoundError(self.msg)
\ No newline at end of file
+ raise ModuleNotFoundError(self.msg)
diff --git a/keybert/backend/_base.py b/keybert/backend/_base.py
index 6f542a6c..a0cbe75c 100644
--- a/keybert/backend/_base.py
+++ b/keybert/backend/_base.py
@@ -3,7 +3,7 @@
class BaseEmbedder:
- """The Base Embedder used for creating embedding models
+ """The Base Embedder used for creating embedding models.
Arguments:
embedding_model: The main embedding model to be used for extracting
@@ -19,8 +19,7 @@ def __init__(self, embedding_model=None, word_embedding_model=None):
self.word_embedding_model = word_embedding_model
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
- """Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ """Embed a list of n documents/words into an n-dimensional matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
diff --git a/keybert/backend/_flair.py b/keybert/backend/_flair.py
index bc4115ff..44f23344 100644
--- a/keybert/backend/_flair.py
+++ b/keybert/backend/_flair.py
@@ -8,9 +8,9 @@
class FlairBackend(BaseEmbedder):
- """Flair Embedding Model
- The Flair embedding model used for generating document and
- word embeddings.
+ """Flair Embedding Model.
+
+ The Flair embedding model used for generating document and word embeddings.
Arguments:
embedding_model: A Flair embedding model
@@ -60,14 +60,12 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
verbose: Controls the verbosity of the process
Returns:
Document/words embeddings with shape (n, m) with `n` documents/words
- that each have an embeddings size of `m`
+ that each have an embeddings size of `m`.
"""
embeddings = []
for index, document in tqdm(enumerate(documents), disable=not verbose):
try:
- sentence = (
- Sentence(document) if document else Sentence("an empty document")
- )
+ sentence = Sentence(document) if document else Sentence("an empty document")
self.embedding_model.embed(sentence)
except RuntimeError:
sentence = Sentence("an empty document")
diff --git a/keybert/backend/_gensim.py b/keybert/backend/_gensim.py
index 13e81dc8..723c5a0a 100644
--- a/keybert/backend/_gensim.py
+++ b/keybert/backend/_gensim.py
@@ -8,7 +8,7 @@
class GensimBackend(BaseEmbedder):
- """Gensim Embedding Model
+ """Gensim Embedding Model.
The Gensim embedding model is typically used for word embeddings with
GloVe, Word2Vec or FastText.
@@ -40,8 +40,7 @@ def __init__(self, embedding_model: Word2VecKeyedVectors):
)
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
- """Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ """Embed a list of n documents/words into an n-dimensional matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
diff --git a/keybert/backend/_hftransformers.py b/keybert/backend/_hftransformers.py
index b285ad08..23d2cfa7 100644
--- a/keybert/backend/_hftransformers.py
+++ b/keybert/backend/_hftransformers.py
@@ -10,7 +10,7 @@
class HFTransformerBackend(BaseEmbedder):
- """Hugging Face transformers model
+ """Hugging Face transformers model.
This uses the `transformers.pipelines.pipeline` to define and create
a feature generation pipeline from which embeddings can be extracted.
@@ -42,8 +42,7 @@ def __init__(self, embedding_model: Pipeline):
)
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
- """Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ """Embed a list of n documents/words into an n-dimensional matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
@@ -57,9 +56,7 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
embeddings = []
for document, features in tqdm(
- zip(
- documents, self.embedding_model(dataset, truncation=True, padding=True)
- ),
+ zip(documents, self.embedding_model(dataset, truncation=True, padding=True)),
total=len(dataset),
disable=not verbose,
):
@@ -68,7 +65,7 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
return np.array(embeddings)
def _embed(self, document: str, features: np.ndarray) -> np.ndarray:
- """Mean pooling
+ """Mean pooling.
Arguments:
document: The document for which to extract the attention mask
@@ -78,12 +75,10 @@ def _embed(self, document: str, features: np.ndarray) -> np.ndarray:
https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers
"""
token_embeddings = np.array(features)
- attention_mask = self.embedding_model.tokenizer(
- document, truncation=True, padding=True, return_tensors="np"
- )["attention_mask"]
- input_mask_expanded = np.broadcast_to(
- np.expand_dims(attention_mask, -1), token_embeddings.shape
- )
+ attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")[
+ "attention_mask"
+ ]
+ input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = np.clip(
input_mask_expanded.sum(1),
@@ -95,7 +90,7 @@ def _embed(self, document: str, features: np.ndarray) -> np.ndarray:
class MyDataset(Dataset):
- """Dataset to pass to `transformers.pipelines.pipeline`"""
+ """Dataset to pass to `transformers.pipelines.pipeline`."""
def __init__(self, docs):
self.docs = docs
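A standalone sketch of the mean pooling condensed above, with assumed toy shapes (token embeddings `(1, seq_len, dim)`, attention mask `(1, seq_len)`): padded positions are zeroed out before averaging, as in `_embed`:

```python
import numpy as np

token_embeddings = np.random.rand(1, 4, 8)  # (batch, seq_len, dim)
attention_mask = np.array([[1, 1, 1, 0]])   # final position is padding
mask = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
# Average only over unmasked tokens; the clip avoids division by zero
embedding = (token_embeddings * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1e-9, None)
print(embedding.shape)  # (1, 8)
```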
diff --git a/keybert/backend/_sentencetransformers.py b/keybert/backend/_sentencetransformers.py
index 47fd1e73..fb8acc46 100644
--- a/keybert/backend/_sentencetransformers.py
+++ b/keybert/backend/_sentencetransformers.py
@@ -6,9 +6,9 @@
class SentenceTransformerBackend(BaseEmbedder):
- """Sentence-transformers embedding model
- The sentence-transformers embedding model used for generating document and
- word embeddings.
+ """Sentence-transformers embedding model.
+
+ The sentence-transformers embedding model used for generating document and word embeddings.
Arguments:
embedding_model: A sentence-transformers embedding model
@@ -34,9 +34,7 @@ class SentenceTransformerBackend(BaseEmbedder):
```
"""
- def __init__(
- self, embedding_model: Union[str, SentenceTransformer], **encode_kwargs
- ):
+ def __init__(self, embedding_model: Union[str, SentenceTransformer], **encode_kwargs):
super().__init__()
if isinstance(embedding_model, SentenceTransformer):
@@ -53,7 +51,7 @@ def __init__(
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
"""Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
diff --git a/keybert/backend/_spacy.py b/keybert/backend/_spacy.py
index 1a1a27a1..ad88d401 100644
--- a/keybert/backend/_spacy.py
+++ b/keybert/backend/_spacy.py
@@ -5,7 +5,7 @@
class SpacyBackend(BaseEmbedder):
- """Spacy embedding model
+ """Spacy embedding model.
The Spacy embedding model used for generating document and
word embeddings.
@@ -63,8 +63,7 @@ def __init__(self, embedding_model):
)
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
- """Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ """Embed a list of n documents/words into an n-dimensional matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
@@ -74,21 +73,14 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
Document/words embeddings with shape (n, m) with `n` documents/words
that each have an embeddings size of `m`
"""
-
# Extract embeddings from a transformer model
if "transformer" in self.embedding_model.component_names:
embeddings = []
for doc in tqdm(documents, position=0, leave=True, disable=not verbose):
try:
- embedding = (
- self.embedding_model(doc)._.trf_data.tensors[-1][0].tolist()
- )
- except:
- embedding = (
- self.embedding_model("An empty document")
- ._.trf_data.tensors[-1][0]
- .tolist()
- )
+ embedding = self.embedding_model(doc)._.trf_data.tensors[-1][0].tolist()
+ except: # noqa: E722
+ embedding = self.embedding_model("An empty document")._.trf_data.tensors[-1][0].tolist()
embeddings.append(embedding)
embeddings = np.array(embeddings)
diff --git a/keybert/backend/_use.py b/keybert/backend/_use.py
index 9cb3af6e..9a82507a 100644
--- a/keybert/backend/_use.py
+++ b/keybert/backend/_use.py
@@ -6,7 +6,7 @@
class USEBackend(BaseEmbedder):
- """Universal Sentence Encoder
+ """Universal Sentence Encoder.
USE encodes text into high-dimensional vectors that
are used for semantic similarity in KeyBERT.
@@ -39,8 +39,7 @@ def __init__(self, embedding_model):
)
def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
- """Embed a list of n documents/words into an n-dimensional
- matrix of embeddings
+ """Embed a list of n documents/words into an n-dimensional matrix of embeddings.
Arguments:
documents: A list of documents or words to be embedded
@@ -51,9 +50,6 @@ def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
that each have an embeddings size of `m`
"""
embeddings = np.array(
- [
- self.embedding_model([doc]).cpu().numpy()[0]
- for doc in tqdm(documents, disable=not verbose)
- ]
+ [self.embedding_model([doc]).cpu().numpy()[0] for doc in tqdm(documents, disable=not verbose)]
)
return embeddings
diff --git a/keybert/backend/_utils.py b/keybert/backend/_utils.py
index ba1eabe6..b9979382 100644
--- a/keybert/backend/_utils.py
+++ b/keybert/backend/_utils.py
@@ -40,21 +40,26 @@ def select_backend(embedding_model) -> BaseEmbedder:
# Sentence Transformer embeddings
if "sentence_transformers" in str(type(embedding_model)):
from ._sentencetransformers import SentenceTransformerBackend
+
return SentenceTransformerBackend(embedding_model)
# Create a Sentence Transformer model based on a string
if isinstance(embedding_model, str):
from ._sentencetransformers import SentenceTransformerBackend
+
return SentenceTransformerBackend(embedding_model)
# Hugging Face embeddings
try:
from transformers.pipelines import Pipeline
+
if isinstance(embedding_model, Pipeline):
from ._hftransformers import HFTransformerBackend
+
return HFTransformerBackend(embedding_model)
except ImportError:
pass
-
+
from ._sentencetransformers import SentenceTransformerBackend
+
return SentenceTransformerBackend("paraphrase-multilingual-MiniLM-L12-v2")
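The dispatch above means `KeyBERT(model=...)` accepts several model types interchangeably; a hedged sketch assuming sentence-transformers is installed:

```python
from keybert import KeyBERT
from sentence_transformers import SentenceTransformer

kw_model = KeyBERT(model="all-MiniLM-L6-v2")  # string -> SentenceTransformerBackend
kw_model = KeyBERT(model=SentenceTransformer("all-MiniLM-L6-v2"))  # same backend, explicit model
```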
diff --git a/keybert/llm/__init__.py b/keybert/llm/__init__.py
index e24fe48a..abe7c7d6 100644
--- a/keybert/llm/__init__.py
+++ b/keybert/llm/__init__.py
@@ -44,12 +44,4 @@
LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)
-__all__ = [
- "BaseLLM",
- "Cohere",
- "OpenAI",
- "TextGeneration",
- "TextGenerationInference",
- "LangChain",
- "LiteLLM"
-]
+__all__ = ["BaseLLM", "Cohere", "OpenAI", "TextGeneration", "TextGenerationInference", "LangChain", "LiteLLM"]
diff --git a/keybert/llm/_base.py b/keybert/llm/_base.py
index 359d1df0..4e9964ed 100644
--- a/keybert/llm/_base.py
+++ b/keybert/llm/_base.py
@@ -3,15 +3,16 @@
class BaseLLM(BaseEstimator):
- """ The base representation model for fine-tuning topic representations """
+ """The base representation model for fine-tuning topic representations."""
+
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
candidate_keywords: A list of candidate keywords that the LLM will fine-tune
- For example, it will create a nicer representation of
- the candidate keywords, remove redundant keywords, or
+ For example, it will create a nicer representation of
+ the candidate keywords, remove redundant keywords, or
shorten them depending on the input prompt.
Returns:
diff --git a/keybert/llm/_cohere.py b/keybert/llm/_cohere.py
index 6a5dd5bf..7b6d23b1 100644
--- a/keybert/llm/_cohere.py
+++ b/keybert/llm/_cohere.py
@@ -25,7 +25,7 @@
class Cohere(BaseLLM):
- """ Use the Cohere API to generate topic labels based on their
+ """Use the Cohere API to generate topic labels based on their
generative model.
Find more about their models here:
@@ -80,13 +80,10 @@ class Cohere(BaseLLM):
llm = Cohere(co, prompt=prompt)
```
"""
- def __init__(self,
- client,
- model: str = "command",
- prompt: str = None,
- delay_in_seconds: float = None,
- verbose: bool = False
- ):
+
+ def __init__(
+ self, client, model: str = "command", prompt: str = None, delay_in_seconds: float = None, verbose: bool = False
+ ):
self.client = client
self.model = model
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
@@ -95,7 +92,7 @@ def __init__(self,
self.verbose = verbose
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
@@ -119,11 +116,9 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)
- request = self.client.generate(model=self.model,
- prompt=prompt,
- max_tokens=50,
- num_generations=1,
- stop_sequences=["\n"])
+ request = self.client.generate(
+ model=self.model, prompt=prompt, max_tokens=50, num_generations=1, stop_sequences=["\n"]
+ )
keywords = request.generations[0].text.strip()
keywords = [keyword.strip() for keyword in keywords.split(",")]
all_keywords.append(keywords)
diff --git a/keybert/llm/_langchain.py b/keybert/llm/_langchain.py
index db7a6654..f786109e 100644
--- a/keybert/llm/_langchain.py
+++ b/keybert/llm/_langchain.py
@@ -9,7 +9,7 @@
class LangChain(BaseLLM):
- """ Using chains in langchain to generate keywords.
+ """Using chains in langchain to generate keywords.
Currently, only chains from question answering is implemented. See:
https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html
@@ -66,18 +66,20 @@ class LangChain(BaseLLM):
llm = LangChain(chain, prompt=prompt)
```
"""
- def __init__(self,
- chain,
- prompt: str = None,
- verbose: bool = False,
- ):
+
+ def __init__(
+ self,
+ chain,
+ prompt: str = None,
+ verbose: bool = False,
+ ):
self.chain = chain
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.verbose = verbose
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
diff --git a/keybert/llm/_litellm.py b/keybert/llm/_litellm.py
index 5a1fe1c8..e9929376 100644
--- a/keybert/llm/_litellm.py
+++ b/keybert/llm/_litellm.py
@@ -17,7 +17,7 @@
class LiteLLM(BaseLLM):
- """ Extract keywords using LiteLLM to call any LLM API using OpenAI format
+ r"""Extract keywords using LiteLLM to call any LLM API using OpenAI format
such as Anthropic, Huggingface, Cohere, TogetherAI, Azure, OpenAI, etc.
NOTE: The resulting keywords are expected to be separated by commas so
@@ -32,8 +32,8 @@ class LiteLLM(BaseLLM):
`self.default_prompt_` is used instead.
NOTE: Use `"[DOCUMENT]"` in the prompt
to decide where the document needs to be inserted
- system_prompt: The message that sets the behavior of the assistant.
- It's typically used to provide high-level instructions
+ system_prompt: The message that sets the behavior of the assistant.
+ It's typically used to provide high-level instructions
for the conversation.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
@@ -68,14 +68,16 @@ class LiteLLM(BaseLLM):
llm = LiteLLM("gpt-3.5-turbo", prompt=prompt)
```
"""
- def __init__(self,
- model: str = "gpt-3.5-turbo",
- prompt: str = None,
- system_prompt: str = "You are a helpful assistant.",
- generator_kwargs: Mapping[str, Any] = {},
- delay_in_seconds: float = None,
- verbose: bool = False
- ):
+
+ def __init__(
+ self,
+ model: str = "gpt-3.5-turbo",
+ prompt: str = None,
+ system_prompt: str = "You are a helpful assistant.",
+ generator_kwargs: Mapping[str, Any] = {},
+ delay_in_seconds: float = None,
+ verbose: bool = False,
+ ):
self.model = model
if prompt is None:
@@ -95,7 +97,7 @@ def __init__(self,
del self.generator_kwargs["prompt"]
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
@@ -120,10 +122,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
time.sleep(self.delay_in_seconds)
# Use a chat model
- messages = [
- {"role": "system", "content": self.system_prompt},
- {"role": "user", "content": prompt}
- ]
+ messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
response = completion(**kwargs)
diff --git a/keybert/llm/_openai.py b/keybert/llm/_openai.py
index 43b605eb..cbbb2d6d 100644
--- a/keybert/llm/_openai.py
+++ b/keybert/llm/_openai.py
@@ -35,7 +35,7 @@
class OpenAI(BaseLLM):
- """ Using the OpenAI API to extract keywords
+ r"""Using the OpenAI API to extract keywords.
The default method is `openai.Completion` if `chat=False`.
The prompts will also need to follow a completion task. If you
@@ -60,8 +60,8 @@ class OpenAI(BaseLLM):
`self.default_prompt_` is used instead.
NOTE: Use `"[DOCUMENT]"` in the prompt
to decide where the document needs to be inserted
- system_prompt: The message that sets the behavior of the assistant.
- It's typically used to provide high-level instructions
+ system_prompt: The message that sets the behavior of the assistant.
+ It's typically used to provide high-level instructions
for the conversation.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
@@ -113,17 +113,19 @@ class OpenAI(BaseLLM):
llm = OpenAI(client, model="gpt-3.5-turbo", delay_in_seconds=10, chat=True)
```
"""
- def __init__(self,
- client,
- model: str = "gpt-3.5-turbo-instruct",
- prompt: str = None,
- system_prompt: str = "You are a helpful assistant.",
- generator_kwargs: Mapping[str, Any] = {},
- delay_in_seconds: float = None,
- exponential_backoff: bool = False,
- chat: bool = False,
- verbose: bool = False
- ):
+
+ def __init__(
+ self,
+ client,
+ model: str = "gpt-3.5-turbo-instruct",
+ prompt: str = None,
+ system_prompt: str = "You are a helpful assistant.",
+ generator_kwargs: Mapping[str, Any] = {},
+ delay_in_seconds: float = None,
+ exponential_backoff: bool = False,
+ chat: bool = False,
+ verbose: bool = False,
+ ):
self.client = client
self.model = model
@@ -148,7 +150,7 @@ def __init__(self,
self.generator_kwargs["stop"] = "\n"
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
@@ -174,10 +176,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
# Use a chat model
if self.chat:
- messages = [
- {"role": "system", "content": self.system_prompt},
- {"role": "user", "content": prompt}
- ]
+ messages = [{"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt}]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
if self.exponential_backoff:
response = chat_completions_with_backoff(self.client, **kwargs)
@@ -188,7 +187,9 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
# Use a non-chat model
else:
if self.exponential_backoff:
- response = completions_with_backoff(self.client, model=self.model, prompt=prompt, **self.generator_kwargs)
+ response = completions_with_backoff(
+ self.client, model=self.model, prompt=prompt, **self.generator_kwargs
+ )
else:
response = self.client.completions.create(model=self.model, prompt=prompt, **self.generator_kwargs)
keywords = response.choices[0].text.strip()
@@ -201,16 +202,12 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
def completions_with_backoff(client, **kwargs):
return retry_with_exponential_backoff(
client.completions.create,
- errors=(
- openai.RateLimitError,
- ),
+ errors=(openai.RateLimitError,),
)(**kwargs)
def chat_completions_with_backoff(client, **kwargs):
return retry_with_exponential_backoff(
client.chat.completions.create,
- errors=(
- openai.RateLimitError,
- ),
+ errors=(openai.RateLimitError,),
)(**kwargs)
diff --git a/keybert/llm/_textgeneration.py b/keybert/llm/_textgeneration.py
index f505caf4..5a9a2f5a 100644
--- a/keybert/llm/_textgeneration.py
+++ b/keybert/llm/_textgeneration.py
@@ -15,7 +15,7 @@
class TextGeneration(BaseLLM):
- """ Text2Text or text generation with transformers
+ """Text2Text or text generation with transformers.
NOTE: The resulting keywords are expected to be separated by commas so
any changes to the prompt will have to make sure that the resulting
@@ -70,29 +70,33 @@ class TextGeneration(BaseLLM):
llm = TextGeneration(generator)
```
"""
- def __init__(self,
- model: Union[str, pipeline],
- prompt: str = None,
- pipeline_kwargs: Mapping[str, Any] = {},
- random_state: int = 42,
- verbose: bool = False
- ):
+
+ def __init__(
+ self,
+ model: Union[str, pipeline],
+ prompt: str = None,
+ pipeline_kwargs: Mapping[str, Any] = {},
+ random_state: int = 42,
+ verbose: bool = False,
+ ):
set_seed(random_state)
if isinstance(model, str):
self.model = pipeline("text-generation", model=model)
elif isinstance(model, Pipeline):
self.model = model
else:
- raise ValueError("Make sure that the HF model that you"
- "pass is either a string referring to a"
- "HF model or a `transformers.pipeline` object.")
+ raise ValueError(
+ "Make sure that the HF model that you"
+ "pass is either a string referring to a"
+ "HF model or a `transformers.pipeline` object."
+ )
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.pipeline_kwargs = pipeline_kwargs
self.verbose = verbose
def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
diff --git a/keybert/llm/_textgenerationinference.py b/keybert/llm/_textgenerationinference.py
index f2b99a66..592456f1 100644
--- a/keybert/llm/_textgenerationinference.py
+++ b/keybert/llm/_textgenerationinference.py
@@ -24,7 +24,7 @@ class Keywords(BaseModel):
class TextGenerationInference(BaseLLM):
- """ Tex
+ """Tex.
Arguments:
client: InferenceClient from huggingface_hub.
@@ -79,22 +79,16 @@ class Keywords(BaseModel):
```
"""
- def __init__(self,
- client: InferenceClient,
- prompt: str = None,
- json_schema: BaseModel = Keywords
- ):
+ def __init__(self, client: InferenceClient, prompt: str = None, json_schema: BaseModel = Keywords):
self.client = client
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.json_schema = json_schema
def extract_keywords(
- self,
- documents: List[str], candidate_keywords: List[List[str]] = None,
- inference_kwargs: Mapping[str, Any] = {}
+ self, documents: List[str], candidate_keywords: List[List[str]] = None, inference_kwargs: Mapping[str, Any] = {}
):
- """ Extract topics
+ """Extract topics.
Arguments:
documents: The documents to extract keywords from
@@ -102,6 +96,7 @@ def extract_keywords(
For example, it will create a nicer representation of
the candidate keywords, remove redundant keywords, or
shorten them depending on the input prompt.
+ inference_kwargs: kwargs for `InferenceClient.text_generation`. See: https://huggingface.co/docs/huggingface_hub/package_reference/inference_client
Returns:
all_keywords: All keywords for each document
@@ -116,9 +111,7 @@ def extract_keywords(
# Extract result from generator and use that as label
response = self.client.text_generation(
- prompt=prompt,
- grammar={"type": "json", "value": self.json_schema.schema()},
- **inference_kwargs
+ prompt=prompt, grammar={"type": "json", "value": self.json_schema.schema()}, **inference_kwargs
)
all_keywords = json.loads(response)["keywords"]
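For reference, a sketch of the constrained-decoding call condensed above (hedged: assumes a running TGI endpoint; the URL is a placeholder): the pydantic model's JSON schema is passed as the TGI `grammar`, so the server must emit parseable `{"keywords": [...]}` output:

```python
from typing import List

from huggingface_hub import InferenceClient
from pydantic import BaseModel

class Keywords(BaseModel):
    keywords: List[str]

client = InferenceClient("http://localhost:8080")  # placeholder endpoint
response = client.text_generation(
    prompt="Extract keywords from this document: Most microbats use echolocation.",
    grammar={"type": "json", "value": Keywords.schema()},
)
```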
diff --git a/keybert/llm/_utils.py b/keybert/llm/_utils.py
index 6ca8bdd3..fa0862ec 100644
--- a/keybert/llm/_utils.py
+++ b/keybert/llm/_utils.py
@@ -34,15 +34,13 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
# Retry on specific errors
- except errors as e:
+ except errors:
# Increment retries
num_retries += 1
# Check if max retries has been reached
if num_retries > max_retries:
- raise Exception(
- f"Maximum number of retries ({max_retries}) exceeded."
- )
+ raise Exception(f"Maximum number of retries ({max_retries}) exceeded.")
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
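The delay update kept above grows multiplicatively with random jitter. A tiny sketch of the schedule it produces, assuming example values of `exponential_base=2` and `jitter=True` with an initial delay of one second:

```python
import random

delay, exponential_base, jitter = 1.0, 2, True
for attempt in range(1, 4):
    delay *= exponential_base * (1 + jitter * random.random())
    print(f"retry {attempt}: sleep ~{delay:.1f}s")  # roughly 2-4s, then 4-16s, then 8-64s
```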
diff --git a/mkdocs.yml b/mkdocs.yml
index 07d03c20..7253fbd3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -73,5 +73,5 @@ markdown_extensions:
- pymdownx.highlight
- pymdownx.superfences
- pymdownx.snippets
- - toc:
+ - toc:
permalink: true
diff --git a/pyproject.toml b/pyproject.toml
index 795b64f3..872d3327 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,3 +89,27 @@ Repository = "https://github.com/MaartenGr/KeyBERT.git"
[tool.setuptools.packages.find]
include = ["keybert*"]
exclude = ["tests"]
+
+[tool.ruff]
+line-length = 120
+
+[tool.ruff.lint]
+select = [
+ "E4", # Ruff Defaults
+ "E7",
+ "E9",
+ "F", # End Ruff Defaults,
+ "D"
+]
+ignore = [
+ "D100", # Missing docstring in public module
+ "D104", # Missing docstring in public package
+ "D205", # 1 blank line required between summary line and description
+ "E731", # Do not assign a lambda expression, use a def
+]
+
+[tool.ruff.lint.per-file-ignores]
+"**/tests/*" = ["D"] # Ignore all docstring errors in tests
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
diff --git a/tests/test_backend.py b/tests/test_backend.py
index 7ef86174..5ffe492c 100644
--- a/tests/test_backend.py
+++ b/tests/test_backend.py
@@ -11,9 +11,7 @@
@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
-@pytest.mark.parametrize(
- "vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
-)
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
def test_single_doc_sentence_transformer_backend(keyphrase_length, vectorizer):
"""Test whether the keywords are correctly extracted"""
top_n = 5
diff --git a/tests/test_model.py b/tests/test_model.py
index 84fd35ba..e539eb8e 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -7,14 +7,12 @@
doc_one, doc_two = get_test_data()
-docs = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))['data']
+docs = fetch_20newsgroups(subset="test", remove=("headers", "footers", "quotes"))["data"]
model = KeyBERT(model="all-MiniLM-L6-v2")
@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
-@pytest.mark.parametrize(
- "vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
-)
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
def test_single_doc(keyphrase_length, vectorizer):
"""Test whether the keywords are correctly extracted"""
top_n = 5
@@ -40,15 +38,9 @@ def test_single_doc(keyphrase_length, vectorizer):
"keyphrase_length, mmr, maxsum",
[((1, i + 1), truth, not truth) for i in range(4) for truth in [True, False]],
)
-@pytest.mark.parametrize(
- "vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")]
-)
-@pytest.mark.parametrize(
- "candidates", [None, ["praise"]]
-)
-@pytest.mark.parametrize(
- "seed_keywords", [None, ["time", "night", "day", "moment"]]
-)
+@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
+@pytest.mark.parametrize("candidates", [None, ["praise"]])
+@pytest.mark.parametrize("seed_keywords", [None, ["time", "night", "day", "moment"]])
def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer, candidates, seed_keywords):
"""Test extraction of protected single document method"""
top_n = 5
@@ -76,17 +68,12 @@ def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer,
@pytest.mark.parametrize("keyphrase_length", [(1, i + 1) for i in range(5)])
-@pytest.mark.parametrize(
- "candidates", [None, ["praise"]]
-)
+@pytest.mark.parametrize("candidates", [None, ["praise"]])
def test_extract_keywords_multiple_docs(keyphrase_length, candidates):
"""Test extraction of protected multiple document method"""
top_n = 5
keywords_list = model.extract_keywords(
- [doc_one, doc_two],
- top_n=top_n,
- keyphrase_ngram_range=keyphrase_length,
- candidates=candidates
+ [doc_one, doc_two], top_n=top_n, keyphrase_ngram_range=keyphrase_length, candidates=candidates
)
assert isinstance(keywords_list, list)
assert isinstance(keywords_list[0], list)
@@ -103,15 +90,14 @@ def test_extract_keywords_multiple_docs(keyphrase_length, candidates):
assert keywords_list[0][0][0] == candidates[0]
assert len(keywords_list[1]) == 0
+
def test_guided():
"""Test whether the keywords are correctly extracted"""
# single doc + a keywords list
top_n = 5
seed_keywords = ["time", "night", "day", "moment"]
- keywords = model.extract_keywords(
- doc_one, min_df=1, top_n=top_n, seed_keywords=seed_keywords
- )
+ keywords = model.extract_keywords(doc_one, min_df=1, top_n=top_n, seed_keywords=seed_keywords)
assert isinstance(keywords, list)
assert isinstance(keywords[0], tuple)
assert isinstance(keywords[0][0], str)
@@ -122,12 +108,7 @@ def test_guided():
top_n = 5
list_of_docs = [doc_one, doc_two]
list_of_seed_keywords = ["time", "night", "day", "moment"]
- keywords = model.extract_keywords(
- list_of_docs,
- min_df=1,
- top_n=top_n,
- seed_keywords=list_of_seed_keywords
- )
+ keywords = model.extract_keywords(list_of_docs, min_df=1, top_n=top_n, seed_keywords=list_of_seed_keywords)
print(keywords)
assert isinstance(keywords, list)
@@ -140,16 +121,8 @@ def test_guided():
# a bacth of docs, each of which has its own seed keywords
top_n = 5
list_of_docs = [doc_one, doc_two]
- list_of_seed_keywords = [
- ["time", "night", "day", "moment"],
- ["hockey", "games", "afternoon", "tv"]
- ]
- keywords = model.extract_keywords(
- list_of_docs,
- min_df=1,
- top_n=top_n,
- seed_keywords=list_of_seed_keywords
- )
+ list_of_seed_keywords = [["time", "night", "day", "moment"], ["hockey", "games", "afternoon", "tv"]]
+ keywords = model.extract_keywords(list_of_docs, min_df=1, top_n=top_n, seed_keywords=list_of_seed_keywords)
print(keywords)
assert isinstance(keywords, list)
@@ -159,6 +132,7 @@ def test_guided():
assert isinstance(keywords[0][0][1], float)
assert len(keywords[0]) == top_n
+
def test_empty_doc():
"""Test empty document"""
doc = ""
@@ -172,9 +146,7 @@ def test_extract_embeddings():
n_docs = 50
doc_embeddings, word_embeddings = model.extract_embeddings(docs[:n_docs])
keywords_fast = model.extract_keywords(
- docs[:n_docs],
- doc_embeddings=doc_embeddings,
- word_embeddings=word_embeddings
+ docs[:n_docs], doc_embeddings=doc_embeddings, word_embeddings=word_embeddings
)
keywords_slow = model.extract_keywords(docs[:n_docs])
@@ -182,12 +154,9 @@ def test_extract_embeddings():
assert doc_embeddings.shape[0] == n_docs
assert keywords_fast == keywords_slow
- # When we use `min_df=3` to extract the keywords, this should give an error since
- # this value was not used when extracting the embeddings and should be the same.
+    # When we use `min_df=3` to extract the keywords, this should give an error since
+    # this value was not used when extracting the embeddings, and the two values need to match.
with pytest.raises(ValueError):
_ = model.extract_keywords(
- docs[:n_docs],
- doc_embeddings=doc_embeddings,
- word_embeddings=word_embeddings,
- min_df=3
+ docs[:n_docs], doc_embeddings=doc_embeddings, word_embeddings=word_embeddings, min_df=3
)