import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection


def comparison_scatter(
    ax: plt.Axes, x: np.ndarray, y: np.ndarray, **scatter_kwargs
) -> PathCollection:
    """Scatter ``y`` against ``x`` with an identity (diagonal) line behind it.

    The points are drawn in a random order so that, when many points overlap,
    the visible ones are not systematically those that come last in the input.
    Afterwards both axes are given the same limits and an equal aspect ratio.

    Args:
        ax (plt.Axes): Axes object to plot on.
        x (np.ndarray): x values (any array-like is accepted).
        y (np.ndarray): y values; must have the same length as ``x``.
        **scatter_kwargs: Keyword arguments forwarded to ``ax.scatter``.
            Defaults to marker='.', s=1. NumPy arrays passed as ``c``,
            ``marker``, ``alpha`` or ``s`` whose first axis has the same
            length as ``x`` are treated as per-point values and reordered
            together with the data; any other value (scalars, a single
            RGB(A) colour array, lists, ...) is passed through unchanged.

    Returns:
        PathCollection: The PathCollection of the scatter plot.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """
    # atleast_1d converts lists/scalars while keeping ndarray subclasses
    # (e.g. masked arrays), so positional fancy indexing below always works.
    x = np.atleast_1d(x)
    y = np.atleast_1d(y)
    n_points = len(x)
    if len(y) != n_points:
        # Without this check a longer ``y`` would be silently truncated by
        # the permutation indexing below.
        raise ValueError(
            f"x and y must have the same length, got {n_points} and {len(y)}"
        )

    # ``**scatter_kwargs`` is a fresh dict for every call, so filling in the
    # defaults here cannot leak back to the caller.
    scatter_kwargs.setdefault("marker", ".")
    scatter_kwargs.setdefault("s", 1)

    # shuffle the data to avoid overplotting; the global NumPy RNG is used so
    # that callers can make the order reproducible with ``np.random.seed``
    shuffle = np.random.permutation(n_points)
    x = x[shuffle]
    y = y[shuffle]
    for key in ("c", "marker", "alpha", "s"):
        value = scatter_kwargs.get(key)
        # Only reorder arrays that hold one entry per point. A single colour
        # such as ``np.array([1, 0, 0])`` or a 0-d array must be left alone,
        # otherwise indexing it with the permutation raises IndexError.
        if (
            isinstance(value, np.ndarray)
            and value.ndim >= 1
            and len(value) == n_points
        ):
            scatter_kwargs[key] = value[shuffle]

    scatter = ax.scatter(x, y, **scatter_kwargs)

    # make the limits of x and y axis equal
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    lim = [min(x_min, y_min), max(x_max, y_max)]
    ax.set_xlim(lim)
    ax.set_ylim(lim)

    # draw diagonal line behind scatter plot (negative zorder)
    ax.plot(lim, lim, color="k", alpha=1, zorder=-1, linewidth=0.5)
    ax.set_aspect("equal")
    return scatter