Skip to content

Commit

Permalink
[Feature-selection] Replace matplotlib with plotly (#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
yonishelach authored Jun 18, 2024
1 parent 55a6023 commit 696da33
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 74 deletions.
52 changes: 12 additions & 40 deletions feature_selection/feature_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,15 @@
# limitations under the License.
#
import json
import os

import matplotlib.pyplot as plt
import mlrun
import mlrun.datastore
import mlrun.utils
import mlrun.feature_store as fs
import mlrun.utils
import numpy as np
import pandas as pd
import seaborn as sns
from mlrun.artifacts import PlotArtifact
import plotly.express as px
from mlrun.artifacts import PlotlyArtifact
from mlrun.datastore.targets import ParquetTarget
# MLRun utils
from mlrun.utils.helpers import create_class
Expand All @@ -42,15 +40,6 @@
}


def _clear_current_figure():
"""
Clear matplotlib current figure.
"""
plt.cla()
plt.clf()
plt.close()


def show_values_on_bars(axs, h_v="v", space=0.4):
def _show_on_single_plot(ax_):
if h_v == "v":
Expand All @@ -74,33 +63,18 @@ def _show_on_single_plot(ax_):


def plot_stat(context, stat_name, stat_df):
_clear_current_figure()

# Add chart
ax = plt.axes()
stat_chart = sns.barplot(
sorted_df = stat_df.sort_values(stat_name)
fig = px.bar(
data_frame=sorted_df,
x=stat_name,
y="index",
data=stat_df.sort_values(stat_name, ascending=False).reset_index(),
ax=ax,
y=sorted_df.index,
title=f"{stat_name} feature scores",
color=stat_name,
)
plt.tight_layout()

for p in stat_chart.patches:
width = p.get_width()
plt.text(
5 + p.get_width(),
p.get_y() + 0.55 * p.get_height(),
"{:1.2f}".format(width),
ha="center",
va="center",
)

context.log_artifact(
PlotArtifact(f"{stat_name}", body=plt.gcf()),
local_path=os.path.join("plots", "feature_selection", f"{stat_name}.html"),
item=PlotlyArtifact(key=stat_name, figure=fig),
local_path=f"{stat_name}.html",
)
_clear_current_figure()


def feature_selection(
Expand All @@ -115,7 +89,6 @@ def feature_selection(
sample_ratio: float = None,
output_vector_name: float = None,
ignore_type_errors: bool = False,
is_feature_vector: bool = False,
):
"""
Applies selected feature selection statistical functions or models on our 'df_artifact'.
Expand All @@ -138,10 +111,9 @@ def feature_selection(
model name (ex. LinearSVC), formalized json (contains 'CLASS',
'FIT', 'META') or a path to such json file.
:param max_scaled_scores: produce feature scores table scaled with max_scaler.
:param sample_ratio: percentage of the dataset the user whishes to compute the feature selection process on.
:param sample_ratio: percentage of the dataset the user wishes to compute the feature selection process on.
:param output_vector_name: creates a new feature vector containing only the identified features.
:param ignore_type_errors: skips datatypes that are neither float nor int within the feature vector.
:param is_feature_vector: bool stating if the data is passed as a feature vector.
"""
stat_filters = stat_filters or DEFAULT_STAT_FILTERS
model_filters = model_filters or DEFAULT_MODEL_FILTERS
Expand Down
Loading

0 comments on commit 696da33

Please sign in to comment.