diff --git a/ot/plot.py b/ot/plot.py index 4b1bfb128..7a6ccb8ca 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -60,13 +60,12 @@ def plot1D_mat(a, b, M, title=''): pl.subplots_adjust(wspace=0., hspace=0.2) -def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): +def plot2D_samples_mat(xs, xt, G, draw_arrows: bool = True, thr=1e-8, **kwargs): r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color proportional to the value of the matrix :math:`\mathbf{G}` between samples. - Parameters ---------- xs : ndarray, shape (ns,2) @@ -75,6 +74,8 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): Target samples positions G : ndarray, shape (na,nb) OT matrix + draw_arrows : bool, optional + If True, draw directional arrows in the middle of the lines thr : float, optional threshold above which the line is drawn **kwargs : dict @@ -93,5 +94,17 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx * scale, **kwargs) + plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], + alpha=G[i, j] / mx * scale, **kwargs) + + if draw_arrows: + # Calculate the midpoint + mid_x = (xs[i, 0] + xt[j, 0]) / 2 + mid_y = (xs[i, 1] + xt[j, 1]) / 2 + + # Annotate with an arrowhead at the midpoint + plt.annotate('', + xy=(mid_x, mid_y), + xytext=(mid_x - 0.5 * (xt[j, 0] - xs[i, 0]), mid_y - 0.5 * (xt[j, 1] - xs[i, 1])), + arrowprops=dict(arrowstyle='-|>', color=kwargs['color'], alpha=G[i, j] / mx * scale) + )