From 12cc01cd42abfb55e0a8e2c5098a8af0816db950 Mon Sep 17 00:00:00 2001 From: Peyman Mohseni Kiasari Date: Sat, 26 Aug 2023 00:13:23 +0200 Subject: [PATCH] Update plot.py ## Added Arrow Annotation Feature to `plot2D_samples_mat` Function ### Changes: 1. **New Parameter `draw_arrows`**: Introduced an optional boolean parameter `draw_arrows` to the `plot2D_samples_mat` function. When set to `True`, this parameter allows users to plot arrows in the middle of the lines connecting source and target samples. This helps to visually identify the direction from source to target. 2. **Arrow Placement**: The arrows are strategically placed at the midpoint of the lines for clear visualization. 3. **Arrow Properties**: The color and alpha (transparency) of the arrows match the lines they are associated with, ensuring visual consistency. ### Motivation: The addition of the arrow annotation feature enhances the visual representation of the connections between source and target samples. Especially in cases where directionality matters, these arrows provide a clearer understanding of the flow from source to target. ### Code: The main code changes involve: - Calculating the midpoint of the line segments. - Using the `plt.annotate` method to draw an arrow at the calculated midpoint. ### Example Usage: ```python plot2D_samples_mat(xs, xt, G, draw_arrows=True) ``` --- ot/plot.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/ot/plot.py b/ot/plot.py index 4b1bfb128..7a6ccb8ca 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -60,13 +60,12 @@ def plot1D_mat(a, b, M, title=''): pl.subplots_adjust(wspace=0., hspace=0.2) -def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): +def plot2D_samples_mat(xs, xt, G, draw_arrows: bool = True, thr=1e-8, **kwargs): r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color proportional to the value of the matrix :math:`\mathbf{G}` between samples. - Parameters ---------- xs : ndarray, shape (ns,2) @@ -75,6 +74,8 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): Target samples positions G : ndarray, shape (na,nb) OT matrix + draw_arrows : bool, optional + If True, draw directional arrows in the middle of the lines thr : float, optional threshold above which the line is drawn **kwargs : dict @@ -93,5 +94,17 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx * scale, **kwargs) + plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx * scale, **kwargs) + + if draw_arrows: + # Calculate the midpoint + mid_x = (xs[i, 0] + xt[j, 0]) / 2 + mid_y = (xs[i, 1] + xt[j, 1]) / 2 + + # Annotate with an arrowhead at the midpoint + plt.annotate('', + xy=(mid_x, mid_y), + xytext=(mid_x - 0.5 * (xt[j, 0] - xs[i, 0]), mid_y - 0.5 * (xt[j, 1] - xs[i, 1])), + arrowprops=dict(arrowstyle='-|>', color=kwargs['color'], alpha=G[i, j] / mx * scale) + )