Skip to content

Commit

Permalink
Add custom countvectorizer (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Dec 2, 2020
1 parent 43fed75 commit 60bf9e5
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 32 deletions.
1 change: 1 addition & 0 deletions keybert/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from keybert.model import KeyBERT
__version__ = "0.1.3"
50 changes: 31 additions & 19 deletions keybert/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm
from typing import List, Union
from typing import List, Union, Tuple
import warnings
from .mmr import mmr
from .maxsum import max_sum_similarity
Expand Down Expand Up @@ -35,14 +35,15 @@ def __init__(self, model: str = 'distilbert-base-nli-mean-tokens'):

def extract_keywords(self,
docs: Union[str, List[str]],
keyphrase_length: int = 1,
keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = 'english',
top_n: int = 5,
min_df: int = 1,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5,
nr_candidates: int = 20) -> Union[List[str], List[List[str]]]:
nr_candidates: int = 20,
vectorizer: CountVectorizer = None) -> Union[List[str], List[List[str]]]:
""" Extract keywords/keyphrases
NOTE:
Expand All @@ -62,7 +63,7 @@ def extract_keywords(self,
Arguments:
docs: The document(s) for which to extract keywords/keyphrases
keyphrase_length: Length, in words, of the extracted keywords/keyphrases
keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
min_df: Minimum document frequency of a word across all documents
Expand All @@ -75,6 +76,7 @@ def extract_keywords(self,
is set to True
nr_candidates: The number of candidates to consider if use_maxsum is
set to True
vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: the top n keywords for a document
Expand All @@ -83,52 +85,58 @@ def extract_keywords(self,

if isinstance(docs, str):
return self._extract_keywords_single_doc(docs,
keyphrase_length,
keyphrase_ngram_range,
stop_words,
top_n,
use_maxsum,
use_mmr,
diversity,
nr_candidates)
nr_candidates,
vectorizer)
elif isinstance(docs, list):
warnings.warn("Although extracting keywords for multiple documents is faster "
"than iterating over single documents, it requires significant memory "
"than iterating over single documents, it requires significantly more memory "
"to hold all word embeddings. Use this at your own discretion!")
return self._extract_keywords_multiple_docs(docs,
keyphrase_length,
keyphrase_ngram_range,
stop_words,
top_n,
min_df=min_df)
min_df,
vectorizer)

def _extract_keywords_single_doc(self,
doc: str,
keyphrase_length: int = 1,
keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: Union[str, List[str]] = 'english',
top_n: int = 5,
use_maxsum: bool = False,
use_mmr: bool = False,
diversity: float = 0.5,
nr_candidates: int = 20) -> List[str]:
nr_candidates: int = 20,
vectorizer: CountVectorizer = None) -> List[str]:
""" Extract keywords/keyphrases for a single document
Arguments:
doc: The document for which to extract keywords/keyphrases
keyphrase_length: Length, in words, of the extracted keywords/keyphrases
keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
use_mmr: Whether to use Max Sum Similarity
use_mmr: Whether to use MMR
diversity: The diversity of results between 0 and 1 if use_mmr is True
nr_candidates: The number of candidates to consider if use_maxsum is set to True
vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: The top n keywords for a document
"""
try:
# Extract Words
n_gram_range = (keyphrase_length, keyphrase_length)
count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words).fit([doc])
if vectorizer:
count = vectorizer.fit([doc])
else:
count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words).fit([doc])
words = count.get_feature_names()

# Extract Embeddings
Expand All @@ -150,28 +158,32 @@ def _extract_keywords_single_doc(self,

def _extract_keywords_multiple_docs(self,
docs: List[str],
keyphrase_length: int = 1,
keyphrase_ngram_range: Tuple[int, int] = (1, 1),
stop_words: str = 'english',
top_n: int = 5,
min_df: int = 1):
min_df: int = 1,
vectorizer: CountVectorizer = None):
""" Extract keywords/keyphrases for a multiple documents
This currently does not use MMR as
Arguments:
docs: The document for which to extract keywords/keyphrases
keyphrase_length: Length, in words, of the extracted keywords/keyphrases
keyphrase_ngram_range: Length, in words, of the extracted keywords/keyphrases
stop_words: Stopwords to remove from the document
top_n: Return the top n keywords/keyphrases
min_df: The minimum frequency of words
vectorizer: Pass in your own CountVectorizer from scikit-learn
Returns:
keywords: The top n keywords for a document
"""
# Extract words
n_gram_range = (keyphrase_length, keyphrase_length)
count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words, min_df=min_df).fit(docs)
if vectorizer:
count = vectorizer.fit(docs)
else:
count = CountVectorizer(ngram_range=keyphrase_ngram_range, stop_words=stop_words, min_df=min_df).fit(docs)
words = count.get_feature_names()
df = count.transform(docs)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
setuptools.setup(
name="keybert",
packages=["keybert"],
version="0.1.2",
version="0.1.3",
author="Maarten Grootendorst",
author_email="[email protected]",
description="KeyBERT performs keyword extraction with state-of-the-art transformer models.",
Expand Down
33 changes: 21 additions & 12 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,56 @@
import pytest
from .utils import get_test_data
from sklearn.feature_extraction.text import CountVectorizer

doc_one, doc_two = get_test_data()


@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
def test_single_doc(keyphrase_length, base_keybert):
@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
def test_single_doc(keyphrase_length, vectorizer, base_keybert):
""" Test whether the keywords are correctly extracted """
top_n = 5
keywords = base_keybert.extract_keywords(doc_one, keyphrase_length=keyphrase_length, min_df=1, top_n=top_n)

keywords = base_keybert.extract_keywords(doc_one,
keyphrase_ngram_range=keyphrase_length,
min_df=1,
top_n=top_n,
vectorizer=vectorizer)
assert isinstance(keywords, list)
assert isinstance(keywords[0], str)
assert len(keywords) == top_n
for keyword in keywords:
assert len(keyword.split(" ")) == keyphrase_length
assert len(keyword.split(" ")) <= keyphrase_length[1]


@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [(i+1, truth, not truth)
@pytest.mark.parametrize("keyphrase_length, mmr, maxsum", [((1, i+1), truth, not truth)
for i in range(4)
for truth in [True, False]])
def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, base_keybert):
@pytest.mark.parametrize("vectorizer", [None, CountVectorizer(ngram_range=(1, 1), stop_words="english")])
def test_extract_keywords_single_doc(keyphrase_length, mmr, maxsum, vectorizer, base_keybert):
""" Test extraction of protected single document method """
top_n = 5
keywords = base_keybert._extract_keywords_single_doc(doc_one,
top_n=top_n,
keyphrase_length=keyphrase_length,
keyphrase_ngram_range=keyphrase_length,
use_mmr=mmr,
use_maxsum=maxsum,
diversity=0.5)
diversity=0.5,
vectorizer=vectorizer)
assert isinstance(keywords, list)
assert isinstance(keywords[0], str)
assert len(keywords) == top_n
for keyword in keywords:
assert len(keyword.split(" ")) == keyphrase_length
assert len(keyword.split(" ")) <= keyphrase_length[1]


@pytest.mark.parametrize("keyphrase_length", [i+1 for i in range(5)])
@pytest.mark.parametrize("keyphrase_length", [(1, i+1) for i in range(5)])
def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
""" Test extractino of protected multiple document method"""
top_n = 5
keywords_list = base_keybert._extract_keywords_multiple_docs([doc_one, doc_two],
top_n=top_n,
keyphrase_length=keyphrase_length)
keyphrase_ngram_range=keyphrase_length)
assert isinstance(keywords_list, list)
assert isinstance(keywords_list[0], list)
assert len(keywords_list) == 2
Expand All @@ -50,7 +59,7 @@ def test_extract_keywords_multiple_docs(keyphrase_length, base_keybert):
assert len(keywords) == top_n

for keyword in keywords:
assert len(keyword.split(" ")) == keyphrase_length
assert len(keyword.split(" ")) <= keyphrase_length[1]


def test_error(base_keybert):
Expand Down

0 comments on commit 60bf9e5

Please sign in to comment.