Source code for mldft.utils.plotting.scatter

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection


def comparison_scatter(
    ax: plt.Axes, x: np.ndarray, y: np.ndarray, **scatter_kwargs
) -> PathCollection:
    """Scatter plot with a diagonal reference line drawn behind it.

    The points are drawn in a random order so that no part of the data ends up
    on top of the rest purely because of its position in the input. Both axes
    get the same limits and an equal aspect ratio, so the diagonal is the
    ``x == y`` line.

    Args:
        ax (plt.Axes): Axes object to plot on.
        x (np.ndarray): x values (any 1D array-like is accepted).
        y (np.ndarray): y values, same length as ``x``.
        **scatter_kwargs: Keyword arguments forwarded to ``ax.scatter``.
            Defaults to ``marker="."`` and ``s=1``. Values of ``c``, ``s`` and
            ``alpha`` that hold one entry per point (arrays, lists or tuples
            whose length equals ``len(x)``) are reordered together with the
            data; anything else is passed through unchanged.

    Returns:
        PathCollection: The PathCollection of the scatter plot.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n_points = len(x)
    if len(y) != n_points:
        # Without this check a longer ``y`` would be silently truncated by the
        # permutation indexing below.
        raise ValueError(
            f"x and y must have the same length, got {n_points} and {len(y)}"
        )

    # ``scatter_kwargs`` is already a fresh dict created by ``**``, so the
    # defaults can be filled in without affecting the caller.
    scatter_kwargs.setdefault("marker", ".")
    scatter_kwargs.setdefault("s", 1)

    # Shuffle the data to avoid overplotting.
    order = np.random.permutation(n_points)
    x = x[order]
    y = y[order]

    # Keep per-point styling aligned with the shuffled data. Only values with
    # exactly one entry per point are reordered: a single colour such as
    # ``np.array([1, 0, 0])`` must be passed through untouched. ``marker`` is
    # deliberately not handled here because matplotlib markers are never
    # per-point (an array-valued marker is a vertex path).
    for key in ("c", "s", "alpha"):
        value = scatter_kwargs.get(key)
        if isinstance(value, np.ndarray):
            if value.ndim >= 1 and value.shape[0] == n_points:
                scatter_kwargs[key] = value[order]
        elif isinstance(value, (list, tuple)) and len(value) == n_points:
            scatter_kwargs[key] = [value[i] for i in order]

    scatter = ax.scatter(x, y, **scatter_kwargs)

    # Make the limits of the x and y axis equal.
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    lim = [min(x_min, y_min), max(x_max, y_max)]
    ax.set_xlim(lim)
    ax.set_ylim(lim)

    # Draw the diagonal line behind the scatter plot (negative zorder).
    ax.plot(lim, lim, color="k", alpha=1, zorder=-1, linewidth=0.5)
    ax.set_aspect("equal")
    return scatter