Skip to content

Commit

Permalink
Add evaluate_model and plot_actual_vs_predicted functions
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Nov 27, 2024
1 parent ef0234c commit d01ed18
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions leafmap/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15155,3 +15155,169 @@ def line_to_points(data: str) -> "GeoDataFrame":
points_gdf = gpd.GeoDataFrame(geometry=points, crs=line_gdf.crs)

return points_gdf


def evaluate_model(
    df: pd.DataFrame,
    y_col: str = "y",
    y_pred_col: str = "y_pred",
    metrics: list = None,
    drop_na: bool = True,
    filter_nonzero: bool = True,
) -> dict:
    """
    Evaluates the model performance on the given dataframe with customizable options.

    Args:
        df: A pandas DataFrame with columns for actual and predicted values.
        y_col: Column name for the actual values.
        y_pred_col: Column name for the predicted values.
        metrics: A list of metrics to calculate. Available options:
            - 'r2': R-squared
            - 'r': Pearson correlation coefficient
            - 'rmse': Root Mean Squared Error
            - 'mae': Mean Absolute Error
            - 'mape': Mean Absolute Percentage Error (returned by
              scikit-learn as a fraction, i.e. 0.25 means 25%)
            Defaults to all metrics if None. Names that are not listed
            above are ignored.
        drop_na: Whether to drop rows with NaN in the actual or predicted
            values column, so that every metric is computed on the same
            set of complete rows.
        filter_nonzero: Whether to filter out rows where actual values are zero.

    Returns:
        A dictionary of the selected performance metrics, keyed by metric name.

    Raises:
        ImportError: If a metric other than 'r' is requested and
            scikit-learn is not installed.
    """

    import math

    # Default metrics if none are provided
    if metrics is None:
        metrics = ["r2", "r", "rmse", "mae", "mape"]

    # Only 'r' is computed by pandas itself; import scikit-learn lazily so
    # that it is required only when one of its metrics is actually requested.
    needs_sklearn = any(
        name in metrics for name in ("r2", "rmse", "mae", "mape")
    )
    if needs_sklearn:
        try:
            from sklearn import metrics as skmetrics
        except ImportError as err:
            raise ImportError(
                "The scikit-learn package is required for this function. Install it using 'pip install scikit-learn'."
            ) from err

    # Data preprocessing
    if drop_na:
        # scikit-learn rejects NaN in either input, so clean both columns.
        df = df.dropna(subset=[y_col, y_pred_col])
    if filter_nonzero:
        # Zero actual values make percentage-based errors blow up.
        df = df[df[y_col] != 0]

    y_true = df[y_col]
    y_pred = df[y_pred_col]

    # Metric calculations
    results = {}
    if "r2" in metrics:
        results["r2"] = skmetrics.r2_score(y_true, y_pred)
    if "r" in metrics:
        results["r"] = y_true.corr(y_pred)
    if "rmse" in metrics:
        # sqrt of the MSE works across scikit-learn versions.
        results["rmse"] = math.sqrt(skmetrics.mean_squared_error(y_true, y_pred))
    if "mae" in metrics:
        results["mae"] = skmetrics.mean_absolute_error(y_true, y_pred)
    if "mape" in metrics:
        results["mape"] = skmetrics.mean_absolute_percentage_error(y_true, y_pred)

    return results


def plot_actual_vs_predicted(
    df: pd.DataFrame,
    x_col: str = "y",
    y_col: str = "y_pred",
    xlim: tuple = None,
    ylim: tuple = None,
    title: str = "Actual vs. Predicted Values",
    x_label: str = "Actual Values",
    y_label: str = "Predicted Values",
    marker_size: int = 6,
    marker_opacity: float = 0.7,
    marker_color: str = "blue",
    line_color: str = "red",
    line_dash: str = "dash",
    width: int = 800,
    height: int = 800,
    showlegend: bool = True,
    template: str = "plotly_white",
    square_aspect: bool = True,
    return_figure: bool = False,
    **kwargs,
):
    """
    Plots a customizable scatter plot with a reference line for actual vs. predicted values.

    Args:
        df: A pandas DataFrame with columns for actual and predicted values.
        x_col: Column name for the x-axis (actual values).
        y_col: Column name for the y-axis (predicted values).
        xlim: A tuple (min, max) for x-axis limits. Defaults to the data
            range of x_col.
        ylim: A tuple (min, max) for y-axis limits. Defaults to the data
            range of y_col.
        title: Title of the plot.
        x_label: Label for the x-axis.
        y_label: Label for the y-axis.
        marker_size: Size of the scatter plot markers.
        marker_opacity: Opacity of the scatter plot markers.
        marker_color: Color of the scatter plot markers.
        line_color: Color of the reference line.
        line_dash: Dash style of the reference line ('dash', 'dot', etc.).
        width: Width of the plot in pixels.
        height: Height of the plot in pixels.
        showlegend: Whether to show the legend.
        template: Plotly template for styling.
        square_aspect: Whether to enforce a square aspect ratio.
        return_figure: Whether to return the Plotly figure object instead
            of displaying it.
        **kwargs: Additional keyword arguments for Plotly figure.

    Returns:
        The Plotly figure object if return_figure is True; otherwise the
        figure is shown and None is returned.

    """
    import plotly.graph_objects as go

    # Default x and y limits if not provided
    if xlim:
        x_min, x_max = xlim
    else:
        x_min, x_max = df[x_col].min(), df[x_col].max()
    if ylim:
        y_min, y_max = ylim
    else:
        y_min, y_max = df[y_col].min(), df[y_col].max()

    # Scatter plot for actual vs predicted
    scatter = go.Scatter(
        x=df[x_col],
        y=df[y_col],
        mode="markers",
        marker=dict(size=marker_size, opacity=marker_opacity, color=marker_color),
        name="Predicted vs Actual",
    )

    # Reference line y = x. Span the union of both axis ranges so the line
    # is not cut short when the y range extends beyond the x range; the
    # explicit axis ranges below clip whatever falls outside the view.
    line_min = min(x_min, y_min)
    line_max = max(x_max, y_max)
    ref_line = go.Scatter(
        x=[line_min, line_max],
        y=[line_min, line_max],
        mode="lines",
        line=dict(color=line_color, dash=line_dash),
        name="Reference Line",
    )

    # Axis settings; anchoring the x-axis scale to the y-axis keeps one data
    # unit the same length on both axes.
    xaxis = dict(title=x_label, range=[x_min, x_max])
    if square_aspect:
        xaxis["scaleanchor"] = "y"
    yaxis = dict(title=y_label, range=[y_min, y_max])

    # Layout settings
    layout = go.Layout(
        title=title,
        xaxis=xaxis,
        yaxis=yaxis,
        template=template,
        showlegend=showlegend,
        height=height,
        width=width,
    )

    # Create the figure
    fig = go.Figure(data=[scatter, ref_line], layout=layout, **kwargs)
    if return_figure:
        return fig
    fig.show()

0 comments on commit d01ed18

Please sign in to comment.