Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TF-IGM (Inverse Gravity Moment) weighting #45

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ nosetests.xml
coverage.xml
*,cover
.hypothesis/
*.swp

# Translations
*.mo
Expand Down
65 changes: 65 additions & 0 deletions examples/feature_weighting/plot_tfigm_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# License: BSD 3 clause
#
# Authors: Roman Yurchak <[email protected]>

import pandas as pd

from sklearn.svm import LinearSVC
from sklearn.preprocessing import Normalizer, FunctionTransformer
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import cross_validate
from sklearn.metrics import f1_score

from sklearn_extra.feature_weighting import TfigmTransformer

if "CI" in os.environ:
# make this example run faster in CI
categories = ["sci.crypt", "comp.graphics", "comp.sys.mac.hardware"]
else:
categories = None

docs, y = fetch_20newsgroups(return_X_y=True, categories=categories)


vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 1))
X = vect.fit_transform(docs)

res = []

for scaler_label, scaler in [
("TF", FunctionTransformer(lambda x: x)),
("TF-IDF(sublinear_tf=False)", TfidfTransformer()),
("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)),
("TF-IGM(tf_scale=None)", TfigmTransformer()),
("TF-IGM(tf_scale='sqrt')", TfigmTransformer(tf_scale="sqrt"),),
("TF-IGM(tf_scale='log1p')", TfigmTransformer(tf_scale="log1p"),),
]:
pipe = make_pipeline(scaler, Normalizer())
X_tr = pipe.fit_transform(X, y)
est = LinearSVC()
scoring = {
"F1-macro": lambda est, X, y: f1_score(
y, est.predict(X), average="macro"
),
"balanced_accuracy": "balanced_accuracy",
}
scores = cross_validate(est, X_tr, y, scoring=scoring,)
for key, val in scores.items():
if not key.endswith("_time"):
res.append(
{
"metric": "_".join(key.split("_")[1:]),
"subset": key.split("_")[0],
"preprocessing": scaler_label,
"score": f"{val.mean():.3f}±{val.std():.3f}",
}
)
scores = (
pd.DataFrame(res)
.set_index(["preprocessing", "metric", "subset"])["score"]
.unstack(-1)
)
scores = scores["test"].unstack(-1).sort_values("F1-macro", ascending=False)
print(scores)
5 changes: 5 additions & 0 deletions sklearn_extra/feature_weighting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# License: BSD 3 clause

from ._text import TfigmTransformer

__all__ = ["TfigmTransformer"]
190 changes: 190 additions & 0 deletions sklearn_extra/feature_weighting/_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# License: BSD 3 clause
#
# Authors: Roman Yurchak <[email protected]>

import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_X_y
from sklearn.preprocessing import LabelEncoder


class TfigmTransformer(BaseEstimator, TransformerMixin):
"""TF-IGM feature weighting.

TF-IGM (Inverse Gravity Momentum) is a supervised
feature weighting scheme for classification tasks that measures
class distinguishing power.

See User Guide for mode details.

Parameters
----------
alpha : float, default=0.15
regularization parameter. Known good default values are 0.14 - 0.20.
tf_scale : {"sqrt", "log1p"}, default=None
if not None, scaling applied to term frequency. Possible scaling values are,
- "sqrt": square root scaling
- "log1p": ``log(1 + tf)`` scaling. This option corresponds to
``sublinear_tf=True`` parameter in
:class:`~sklearn.feature_extraction.text.TfidfTransformer`.

Attributes
----------
igm_ : array of shape (n_features,)
The Inverse Gravity Moment (IGM) weight.
coef_ : array of shape (n_features,)
Regularized IGM weight corresponding to ``alpha + igm_``

Examples
--------
>>> from sklearn.feature_extraction.text import CountVectorizer
>>> from sklearn.pipeline import Pipeline
>>> from sklearn_extra.feature_weighting import TfigmTransformer
>>> corpus = ['this is the first document',
... 'this document is the second document',
... 'and this is the third one',
... 'is this the first document']
>>> y = ['1', '2', '1', '2']
>>> pipe = Pipeline([('count', CountVectorizer()),
... ('tfigm', TfigmTransformer())]).fit(corpus, y)
>>> pipe['count'].transform(corpus).toarray()
array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
[0, 2, 0, 1, 0, 1, 1, 0, 1],
[1, 0, 0, 1, 1, 0, 1, 1, 1],
[0, 1, 1, 1, 0, 0, 1, 0, 1]])
>>> pipe['tfigm'].igm_
array([1. , 0.25, 0. , 0. , 1. , 1. , 0. , 1. , 0. ])
>>> pipe['tfigm'].coef_
array([1.15, 0.4 , 0.15, 0.15, 1.15, 1.15, 0.15, 1.15, 0.15])
>>> pipe.transform(corpus).shape
(4, 9)

References
----------
Chen, Kewen, et al. "Turning from TF-IDF to TF-IGM for term weighting
in text classification." Expert Systems with Applications 66 (2016):
245-260.
"""

def __init__(self, alpha=0.15, tf_scale=None):
self.alpha = alpha
self.tf_scale = tf_scale

def _fit(self, X, y):
"""Learn the igm vector (global term weights)

Parameters
----------
X : {array-like, sparse matrix} of (n_samples, n_features)
a matrix of term/token counts
y : array-like of shape (n_samples,)
target classes
"""
self._le = LabelEncoder().fit(y)
n_class = len(self._le.classes_)
class_freq = np.zeros((n_class, X.shape[1]))

X_nz = X != 0
if sp.issparse(X_nz):
X_nz = X_nz.asformat("csr", copy=False)

for idx, class_label in enumerate(self._le.classes_):
y_mask = y == class_label
n_samples = y_mask.sum()
class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples

self._class_freq = class_freq
class_freq_sort = np.sort(self._class_freq, axis=0)
f1 = class_freq_sort[-1, :]

fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0)
# avoid division by zero
igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0))
if n_class > 1:
# scale weights to [0, 1]
self.igm_ = ((1 + n_class) * n_class * igm - 2) / (
(1 + n_class) * n_class - 2
)
else:
self.igm_ = igm
self.coef_ = self.alpha + self.igm_
return self

def fit(self, X, y):
"""Learn the igm vector (global term weights)

Parameters
----------
X : {array-like, sparse matrix} of (n_samples, n_features)
a matrix of term/token counts
y : array-like of shape (n_samples,)
target classes
"""
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
self._fit(X, y)
return self

def _transform(self, X):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
if self.tf_scale is None:
pass
elif self.tf_scale == "sqrt":
X = np.sqrt(X)
elif self.tf_scale == "log1p":
X = np.log1p(X)
else:
raise ValueError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A friendlier message here?


if sp.issparse(X):
X_tr = X @ sp.diags(self.coef_)
else:
X_tr = X * self.coef_[None, :]
return X_tr

def transform(self, X):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
X = check_array(X, accept_sparse=["csr", "csc"])
X_tr = self._transform(X)
return X_tr

def fit_transform(self, X, y):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
self._fit(X, y)
X_tr = self._transform(X)
return X_tr
83 changes: 83 additions & 0 deletions sklearn_extra/feature_weighting/tests/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# License: BSD 3 clause

import numpy as np
from numpy.testing import assert_allclose, assert_array_less
import scipy.sparse as sp

import pytest

from sklearn_extra.feature_weighting import TfigmTransformer
from sklearn.datasets import make_classification


@pytest.mark.parametrize("array_format", ["dense", "csr", "csc", "coo"])
def test_tfigm_transform(array_format):
X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
if array_format != "dense":
X = sp.csr_matrix(X).asformat(array_format)
y = np.array(["a", "b", "a", "c"])

alpha = 0.2
est = TfigmTransformer(alpha=alpha)
X_tr = est.fit_transform(X, y)

assert_allclose(est.igm_, [0.20, 0.40, 0.0])
assert_allclose(est.igm_ + alpha, est.coef_)

assert X_tr.shape == X.shape
assert sp.issparse(X_tr) is (array_format != "dense")

if array_format == "dense":
assert_allclose(X * est.coef_[None, :], X_tr)
else:
assert_allclose(X.A * est.coef_[None, :], X_tr.A)


def test_tfigm_synthetic():
X, y = make_classification(
n_samples=100,
n_features=10,
n_informative=5,
n_redundant=0,
random_state=0,
n_classes=5,
shuffle=False,
)
X = (X > 0).astype(np.float)

est = TfigmTransformer()
est.fit(X, y)
# informative features have higher IGM weights than noisy ones.
# (athough here we lose a lot of information due to thresholding of X).
assert est.igm_[:5].mean() / est.igm_[5:].mean() > 3


@pytest.mark.parametrize("n_class", [2, 5])
def test_tfigm_random_distribution(n_class):
rng = np.random.RandomState(0)
n_samples, n_features = 500, 4
X = rng.randint(2, size=(n_samples, n_features))
y = rng.randint(n_class, size=(n_samples,))

est = TfigmTransformer()
X_tr = est.fit_transform(X, y)

# all weighs are strictly positive
assert_array_less(0, est.igm_)
# and close to zero, since none of the features are discriminant
assert_array_less(est.igm_, 0.05)


def test_tfigm_valid_target():
X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
y = None

est = TfigmTransformer()
with pytest.raises(ValueError, match="y cannot be None"):
est.fit(X, y)

# check asymptotic behaviour for 1 class
y = [1, 1, 1, 1]
est = TfigmTransformer()
est.fit(X, y)
assert_allclose(est.igm_[0], np.ones(3))
9 changes: 8 additions & 1 deletion sklearn_extra/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
from sklearn_extra.kernel_approximation import Fastfood
from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor
from sklearn_extra.cluster import KMedoids
from sklearn_extra.feature_weighting import TfigmTransformer

ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor]
ALL_ESTIMATORS = [
Fastfood,
KMedoids,
EigenProClassifier,
EigenProRegressor,
TfigmTransformer,
]

if hasattr(estimator_checks, "parametrize_with_checks"):
# Common tests are only run on scikit-learn 0.22+
Expand Down