Source code for edlgt.tools.plotting

"""Plotting and time-series post-processing helpers for analysis scripts.

This module combines small visualization utilities (figure sizing, tick
formatters) with helpers for running time averages and windowed smoothing used
in dynamics post-processing.

Only the functions listed in ``__all__`` are considered the public API here.
"""

import colorsys
import logging
import math
import os
from pathlib import Path

import numpy as np

# Directory two levels above the folder containing this file; the rest of the
# module treats it as the repository root.
_REPO_ROOT = Path(__file__).resolve().parents[2]
# Project-local cache, Matplotlib config, and figure output directories.
CACHE_DIR = _REPO_ROOT / ".cache"
MPLCONFIG_DIR = CACHE_DIR / "matplotlib"
FIGURES_DIR = _REPO_ROOT / "figures"

# Import-time side effect: create the cache directories if they are missing.
# FIGURES_DIR is not created here.
CACHE_DIR.mkdir(exist_ok=True)
MPLCONFIG_DIR.mkdir(exist_ok=True)
# ``setdefault`` leaves values already present in the environment untouched.
# These are set before ``matplotlib`` is imported just below — presumably so
# Matplotlib picks up the project-local cache/config locations (TODO confirm).
os.environ.setdefault("XDG_CACHE_HOME", str(CACHE_DIR))
os.environ.setdefault("MPLCONFIGDIR", str(MPLCONFIG_DIR))

import matplotlib
from matplotlib.ticker import LogLocator


def _in_ipykernel() -> bool:
    try:
        from IPython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()
    return shell is not None and hasattr(shell, "kernel")


# Outside an IPython kernel, and only when the user has not chosen a backend
# through the MPLBACKEND environment variable, select the "Agg" backend.
# This runs before ``matplotlib.pyplot`` is imported below.
if not _in_ipykernel() and "MPLBACKEND" not in os.environ:
    matplotlib.use("Agg")

import matplotlib.colors as mc
import matplotlib.pyplot as plt

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Reference document widths in points — presumably the LaTeX text and column
# widths of the target document (TODO confirm).
TEXTWIDTH_PT = 510.0
COLUMNWIDTH_PT = 246.0
# The same widths converted to inches at 72.27 points per inch.
TEXTWIDTH_IN = TEXTWIDTH_PT / 72.27
COLUMNWIDTH_IN = COLUMNWIDTH_PT / 72.27
# Base-10 log tick locator with ``subs`` 0.1, 0.2, ..., 0.9; judging by the
# name it is meant for minor ticks on a logarithmic x axis.
X_MINOR = LogLocator(base=10.0, subs=np.arange(1.0, 10.0) * 0.1, numticks=10)
# Default keyword arguments that ``save_figure`` forwards to ``fig.savefig``;
# callers can override individual entries through ``save_kwargs``.
SAVE_PLOT_KWARGS = {"bbox_inches": "tight", "transparent": False}

# Names exported by ``from <module> import *``; per the module docstring these
# are the only names considered public API.
__all__ = [
    "CACHE_DIR",
    "COLUMNWIDTH_IN",
    "COLUMNWIDTH_PT",
    "FIGURES_DIR",
    "MPLCONFIG_DIR",
    "SAVE_PLOT_KWARGS",
    "TEXTWIDTH_IN",
    "TEXTWIDTH_PT",
    "X_MINOR",
    "configure_matplotlib",
    "fake_log",
    "fit_log_growth",
    "fit_saturation_power_law",
    "get_tline",
    "time_integral",
    "custom_average",
    "moving_time_integral",
    "gaussian_time_integral",
    "moving_time_integral_centered",
    "lighten_color",
    "bz_axis",
    "save_figure",
    "set_size",
]


def _configure_plotting_loggers() -> None:
    """Silence verbose third-party loggers used during figure export."""
    logging.getLogger("fontTools").setLevel(logging.WARNING)
    logging.getLogger("fontTools.subset").setLevel(logging.WARNING)


def configure_matplotlib(fontsize: int = 10) -> None:
    """Install the shared Matplotlib rcParams used by the analysis plots.

    Also creates ``FIGURES_DIR`` if it is missing and quiets the font-related
    loggers.

    Parameters
    ----------
    fontsize : int, optional
        Font size applied uniformly to text, tick labels, legends, axis
        labels, and titles. Default is ``10``.
    """
    FIGURES_DIR.mkdir(exist_ok=True)
    _configure_plotting_loggers()

    # Every rcParam that simply follows the requested font size.
    size_driven_keys = (
        "font.size",
        "xtick.labelsize",
        "ytick.labelsize",
        "legend.fontsize",
        "legend.title_fontsize",
        "axes.labelsize",
        "axes.titlesize",
        "figure.titlesize",
    )
    rc_settings = dict.fromkeys(size_driven_keys, fontsize)

    # Settings that do not depend on the font size.
    rc_settings.update(
        {
            "font.family": "STIXGeneral",
            "mathtext.fontset": "stix",
            "text.usetex": False,
            "legend.borderpad": 0.3,
            "legend.handletextpad": 0.5,
            "legend.borderaxespad": 1.0,
            "legend.columnspacing": 1.0,
            "lines.linewidth": 2.0,
            "lines.color": "k",
            "savefig.format": "pdf",
            "savefig.dpi": 72.27,
            "pdf.fonttype": 42,
            "ps.fonttype": 42,
        }
    )
    plt.rcParams.update(rc_settings)


def save_figure(
    stem: str | Path,
    *,
    fig=None,
    show: bool | None = None,
    close: bool = True,
    directory: str | Path | None = None,
    save_kwargs: dict | None = None,
    return_path: bool = False,
) -> Path | None:
    """Write a figure to disk, then optionally display and/or close it.

    Parameters
    ----------
    stem : str or pathlib.Path
        Output file name or path. ``".pdf"`` is appended when it has no
        suffix. Relative paths are resolved against ``directory``.
    fig : optional
        Figure to save. Defaults to the current pyplot figure.
    show : bool, optional
        Whether to call ``plt.show()`` after saving. When ``None`` the figure
        is shown inside an IPython kernel or when pyplot is interactive.
    close : bool, optional
        Whether to close the figure afterwards. Default is ``True``.
    directory : str or pathlib.Path, optional
        Base directory for relative ``stem`` values. Defaults to
        ``FIGURES_DIR``. Ignored when ``stem`` is absolute.
    save_kwargs : dict, optional
        Extra keyword arguments for ``fig.savefig``; they override the
        defaults in ``SAVE_PLOT_KWARGS``.
    return_path : bool, optional
        When ``True`` return the path that was written.

    Returns
    -------
    pathlib.Path or None
        The output path if ``return_path`` is true, otherwise ``None``.
    """
    _configure_plotting_loggers()

    target_figure = plt.gcf() if fig is None else fig
    should_show = show
    if should_show is None:
        should_show = _in_ipykernel() or bool(plt.isinteractive())

    destination = Path(stem)
    if destination.suffix == "":
        destination = destination.with_suffix(".pdf")
    if not destination.is_absolute():
        root = Path(directory) if directory is not None else FIGURES_DIR
        destination = root / destination
    # Make sure the target folder exists, including any nested sub-folders.
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Start from the module defaults and let caller-supplied options win.
    savefig_options = dict(SAVE_PLOT_KWARGS)
    if save_kwargs:
        savefig_options.update(save_kwargs)
    target_figure.savefig(destination, **savefig_options)

    if should_show:
        plt.show()
    if close:
        plt.close(target_figure)
    return destination if return_path else None


def set_size(width_pt, fraction=1, subplots=(1, 1), height_factor=1.0):
    """Convert a document width in points into a figure size in inches.

    Parameters
    ----------
    width_pt : float
        Reference width in points.
    fraction : float, optional
        Fraction of the reference width the figure should span. Default is
        ``1``.
    subplots : tuple, optional
        ``(rows, columns)`` of the subplot grid. Default is ``(1, 1)``.
    height_factor : float, optional
        Extra multiplier applied to the computed height. Default is ``1.0``.

    Returns
    -------
    tuple
        ``(width_in, height_in)`` in inches.

    Notes
    -----
    The height is the width times the golden ratio, scaled by
    ``rows / columns`` and by ``height_factor``. The conversion uses
    72.27 points per inch.
    """
    n_rows = subplots[0]
    n_cols = subplots[1]

    inches_per_pt = 1 / 72.27
    golden_ratio = (5**0.5 - 1) / 2

    width_in = width_pt * fraction * inches_per_pt
    height_in = width_in * golden_ratio * (n_rows / n_cols)
    return (width_in, height_in * height_factor)
# To extract simulations use:
#   energy[ii][jj] = extract_dict(ugrid[ii][jj], key="res", glob="energy")
# or:
#   energy[ii][jj] = get_sim(ugrid[ii][jj]).res["energy"]
# To acquire the psi file:
#   sim = get_sim(ugrid[ii][jj])
#   sim.link("psi")
#   psi = sim.load("psi", cache=True)


@plt.FuncFormatter
def fake_log(tick_value, _pos):
    """Render a tick value as a power-of-ten label.

    Parameters
    ----------
    tick_value : float
        Tick value; it is truncated to an integer and used as the exponent.
    _pos : int
        Tick position (unused; required by the Matplotlib formatter API).

    Returns
    -------
    str
        Math-text label of the form ``$10^{n}$``.
    """
    return rf"$10^{{{int(tick_value)}}}$"


def lighten_color(color, amount=0.5):
    """Return a lighter RGB version of a Matplotlib-compatible color.

    Parameters
    ----------
    color : str or tuple
        Matplotlib color name, hex string, or RGB tuple.
    amount : float, optional
        Lightening factor applied to the lightness channel in HLS space.
        Default is ``0.5``.

    Returns
    -------
    tuple
        Lightened ``(r, g, b)`` tuple.
    """
    try:
        # Named colors are resolved through Matplotlib's name table first.
        resolved = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a known name (or not hashable): hand the value to ``to_rgb`` as is.
        resolved = color

    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(resolved))
    adjusted_lightness = 1 - amount * (1 - lightness)
    return colorsys.hls_to_rgb(hue, adjusted_lightness, saturation)
def gaussian_time_integral(time, observable_values, sigma=None):
    """Smooth a time series with a Gaussian-weighted local average.

    Parameters
    ----------
    time : array_like
        One-dimensional time grid (can be non-uniform).
    observable_values : array_like
        One-dimensional observable values sampled on ``time``.
    sigma : float, optional
        Width of the Gaussian window. If ``None``, one-tenth of the total
        time range is used.

    Returns
    -------
    numpy.ndarray
        Smoothed observable values with the same shape as
        ``observable_values``. Floating and complex inputs keep their dtype;
        integer and boolean inputs are promoted to ``float`` so the weighted
        averages are not truncated.
    """
    time = np.asarray(time, dtype=float)
    values = np.asarray(observable_values)
    # An integer output buffer would silently truncate every weighted average.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)

    # Choose a default sigma if none is provided.
    if sigma is None:
        sigma = (time[-1] - time[0]) / 10.0

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back on older NumPy releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    smoothed_values = np.zeros_like(values)
    for time_idx, time_value in enumerate(time):
        # Gaussian weights centred on the current time point.
        weights = np.exp(-0.5 * ((time - time_value) / sigma) ** 2)
        # Trapezoidal integration handles non-uniform grids; dividing by the
        # integrated weights normalises the average.
        weighted_sum = integrate(weights * values, time)
        weight_norm = integrate(weights, time)
        smoothed_values[time_idx] = weighted_sum / weight_norm
    return smoothed_values
def moving_time_integral(time, observable_values, max_points=100):
    """Compute a trailing moving-window time average by trapezoidal integration.

    Parameters
    ----------
    time : array_like
        One-dimensional time grid (can be non-uniform).
    observable_values : array_like
        One-dimensional observable values sampled on ``time``.
    max_points : int, optional
        Maximum number of samples in the averaging window. Default is ``100``.

    Returns
    -------
    numpy.ndarray
        Running average with the same shape as ``observable_values``.
        Floating and complex inputs keep their dtype; integer and boolean
        inputs are promoted to ``float`` so the averages are not truncated.

    Raises
    ------
    ValueError
        If ``max_points`` is smaller than 1.

    Notes
    -----
    At early times, when fewer than ``max_points`` samples are available, the
    window includes all samples from the start. When the window has zero
    duration (for example the very first sample) the first value in the
    window is returned unchanged.
    """
    if max_points < 1:
        raise ValueError("max_points must be at least 1")

    time = np.asarray(time)
    values = np.asarray(observable_values)
    # An integer output buffer would silently truncate every average.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back on older NumPy releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    averaged_values = np.zeros_like(values)
    for time_idx in range(len(time)):
        # Trailing window: at most ``max_points`` samples ending at time_idx.
        start = max(0, time_idx - max_points + 1)
        t_segment = time[start : time_idx + 1]
        observable_segment = values[start : time_idx + 1]

        # Integrate over the window and normalise by its duration.
        dt = t_segment[-1] - t_segment[0]
        if dt != 0:
            averaged_values[time_idx] = integrate(observable_segment, t_segment) / dt
        else:
            averaged_values[time_idx] = observable_segment[0]
    return averaged_values
def time_integral(time, observable_values):
    """Compute a cumulative time average of an observable.

    Entry ``i`` of the result is the trapezoidal integral of the observable
    from ``time[0]`` to ``time[i]`` divided by the elapsed time
    ``time[i] - time[0]``.

    Parameters
    ----------
    time : array_like
        One-dimensional time grid (can be non-uniform).
    observable_values : array_like
        Observable values sampled on ``time``; the first axis must match
        ``time``.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``observable_values`` where entry ``i`` is the
        time-averaged value accumulated up to ``time[i]``. Floating and
        complex inputs keep their dtype; integer and boolean inputs are
        promoted to ``float`` so the averages are not truncated.

    Raises
    ------
    ValueError
        If ``time`` is not one-dimensional or its length differs from the
        first axis of ``observable_values``.

    Notes
    -----
    The average is normalised by the elapsed time since ``time[0]``, so grids
    that do not start at zero are handled correctly. For grids starting at
    zero this equals dividing by ``time[i]``. Entries with zero elapsed time
    (including the first one) return the observable value itself. The
    computation is a single cumulative sum rather than a nested loop.
    """
    time = np.asarray(time, dtype=float)
    values = np.asarray(observable_values)
    # An integer output buffer would silently truncate every average.
    if not np.issubdtype(values.dtype, np.inexact):
        values = values.astype(float)
    if time.ndim != 1 or values.shape[:1] != time.shape:
        raise ValueError(
            "time must be one-dimensional and match the first axis of "
            "observable_values"
        )

    # Start from the raw values: this is the answer for the first sample and
    # the fallback wherever the elapsed time is zero.
    averaged_values = values.copy()
    if time.size < 2:
        return averaged_values

    # Reshape 1D time quantities so they broadcast along the first axis.
    column = (-1,) + (1,) * (values.ndim - 1)
    segment_widths = np.diff(time).reshape(column)
    segment_areas = 0.5 * (values[1:] + values[:-1]) * segment_widths
    running_integral = np.cumsum(segment_areas, axis=0)

    elapsed = time[1:] - time[0]
    has_duration = elapsed != 0
    averaged_values[1:][has_duration] = (
        running_integral[has_duration] / elapsed[has_duration].reshape(column)
    )
    return averaged_values
def moving_time_integral_centered(time, observable_values, max_points=101):
    """Compute a centered moving time average with the largest symmetric window.

    - For each index ``i``, choose a radius ``r_i = min(half, i, N-1-i)``.
    - The window is ``[i - r_i, ..., i + r_i]``.
    - Near the edges the window shrinks symmetrically.
    - In the bulk it has full length ``max_points``.

    Parameters
    ----------
    time : 1D np.ndarray
        Time points (can be non-uniformly spaced).
    observable_values : 1D np.ndarray
        Observable values ``M(t_i)``.
    max_points : int
        Target maximum number of points in the window (will be made odd).

    Returns
    -------
    numpy.ndarray
        Time-averaged observable at each time point, as ``float``.

    Raises
    ------
    ValueError
        If ``time`` and ``observable_values`` differ in length, or if
        ``max_points`` is negative.
    """
    num_points = len(time)
    if num_points != len(observable_values):
        raise ValueError("time and observable_values must have the same length")
    if max_points < 0:
        # A negative window would yield empty slices and an opaque IndexError.
        raise ValueError("max_points must be non-negative")

    # Enforce odd max_points for a perfectly centered window.
    if max_points % 2 == 0:
        max_points += 1
    half = max_points // 2

    # ``np.trapezoid`` only exists in NumPy >= 2.0; fall back to ``np.trapz``
    # on older releases so this works across NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    averaged_values = np.zeros_like(observable_values, dtype=float)
    for time_idx in range(num_points):
        # Largest symmetric radius that fits both sides and the size limit.
        radius = min(half, time_idx, num_points - 1 - time_idx)
        start = time_idx - radius
        end = time_idx + radius + 1  # slice is [start, end)

        t_segment = time[start:end]
        observable_segment = observable_values[start:end]

        dt = t_segment[-1] - t_segment[0]
        if dt != 0:
            averaged_values[time_idx] = integrate(observable_segment, t_segment) / dt
        else:
            # Zero-width window (edges): the average is the sample itself.
            averaged_values[time_idx] = observable_segment[0]
    return averaged_values
[docs] def get_tline(par: dict): """Build a uniform time grid from a parameter dictionary. Parameters ---------- par : dict Dictionary containing at least ``"start"``, ``"stop"``, and ``"delta_n"``. Returns ------- numpy.ndarray Uniform time grid starting at ``par["start"]`` with step ``par["delta_n"]`` and stopping before ``par["stop"]``. """ start = par["start"] stop = par["stop"] delta_n = par["delta_n"] n_steps = int((stop - start) / delta_n) return start + np.arange(n_steps) * delta_n
def custom_average(arr, staggered=None, norm=None):
    """Reduce each row of a 2D array to one averaged value.

    Parameters
    ----------
    arr : numpy.ndarray
        Two-dimensional array; every row is averaged over its columns.
    staggered : {"even", "odd"}, optional
        Restrict the plain average to even or odd column indices. Any other
        value averages over all columns.
    norm : numpy.ndarray, optional
        Weight vector. When given, each row's dot product with ``norm`` is
        divided by the number of columns, and ``staggered`` is ignored.

    Returns
    -------
    numpy.ndarray
        One-dimensional array containing one averaged value per row.

    Raises
    ------
    ValueError
        If ``norm`` is provided and its length does not match
        ``arr.shape[1]``.
    """
    n_columns = arr.shape[1]

    # The weighted branch takes precedence over any column selection.
    if norm is not None:
        if norm.shape[0] != n_columns:
            raise ValueError(
                "norm vector length "
                f"{norm.shape[0]} must match the number of columns in arr "
                f"{arr.shape[1]}"
            )
        return np.dot(arr, norm) / n_columns

    # Translate the staggering mode into a (first column, stride) pair.
    if staggered == "even":
        first_column, stride = 0, 2
    elif staggered == "odd":
        first_column, stride = 1, 2
    else:
        first_column, stride = 0, 1
    selected_columns = np.arange(first_column, n_columns, stride)
    return np.mean(arr[:, selected_columns], axis=1)
def bz_axis(Nk: int, *, numeric_labels: bool = False): """Build a symmetric Brillouin-zone axis for ``Nk`` momentum points. Returns ------- tuple ``(k_path, order, tick_labels)`` where ``k_path`` spans ``[-pi, pi]``, ``order`` wraps the data with zero momentum in the middle, and ``tick_labels`` contains either integer or ``pi``-fraction labels. Notes ----- Assumes ``Nk`` is even. """ if Nk % 2 != 0: raise ValueError("Nk must be even for a symmetric Brillouin zone within ±π.") # 1) x-positions: include both endpoints so you can 'close' the band plot. # This is equivalent to m * (2π/Nk) for m in [-Nk/2, ..., +Nk/2] k_path = np.linspace(-np.pi, np.pi, Nk + 1) # 2) Reordering indices so momentum 0 sits in the middle and the curve is closed # Example (Nk=16): [8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7,8] order = list(range(Nk // 2, Nk)) + list(range(0, Nk // 2)) + [Nk // 2] # 3) Tick labels at every point by default (same length as k_path); # you can slice these arrays if you want fewer ticks. tick_vals = np.arange(-Nk // 2, Nk // 2 + 1) if numeric_labels: # Simple integer labels from -Nk//2 .. 0 .. 
+Nk//2 tick_labels = [ rf"${tick_value}$" if tick_value != 0 else r"$0$" for tick_value in tick_vals ] else: # Pretty π-fraction labels: k = (m / (Nk/2)) * π den_base = Nk // 2 tick_labels = [] for tick_value in tick_vals: if tick_value == 0: tick_labels.append(r"$0$") continue if abs(tick_value) == den_base: tick_labels.append(r"$-\pi$" if tick_value < 0 else r"$+\pi$") continue # reduce fraction |m| / (Nk/2) num, den = abs(tick_value), den_base gcd_value = math.gcd(num, den) num //= gcd_value den //= gcd_value sign = "-" if tick_value < 0 else "+" if den == 1: label = rf"${sign}\pi$" if num == 1 else rf"${sign}{num}\pi$" tick_labels.append(label) else: label = ( rf"${sign}\frac{{\pi}}{{{den}}}$" if num == 1 else rf"${sign}\frac{{{num}\pi}}{{{den}}}$" ) tick_labels.append(label) return k_path, order, tick_labels def _fit_quality_metrics(y_values, y_fit): """Return compact quality metrics for a one-dimensional fit.""" y_values = np.asarray(y_values, dtype=float) y_fit = np.asarray(y_fit, dtype=float) residuals = y_values - y_fit rms_error = float(np.sqrt(np.mean(residuals**2))) scale = np.maximum(np.abs(y_values), np.finfo(float).eps) relative_rms_error = float(np.sqrt(np.mean((residuals / scale) ** 2))) ss_res = float(np.sum(residuals**2)) ss_tot = float(np.sum((y_values - np.mean(y_values)) ** 2)) r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else 1.0 return { "rms_error": rms_error, "relative_rms_error": relative_rms_error, "r2": r2, "n_points": int(len(y_values)), } def fit_saturation_power_law(time, values, fit_window, saturation_window): """Fit ``F(t) = F_sat - C t^{-B}`` on user-selected windows.""" time = np.asarray(time, dtype=float) values = np.asarray(values, dtype=float) fit_mask = ( np.isfinite(time) & np.isfinite(values) & (time >= fit_window[0]) & (time <= fit_window[1]) & (time > 0) ) saturation_mask = ( np.isfinite(time) & np.isfinite(values) & (time >= saturation_window[0]) & (time <= saturation_window[1]) ) fit_time = time[fit_mask] fit_values 
= values[fit_mask] if len(fit_time) < 2: raise ValueError("not enough valid points in the selected fit window") if np.count_nonzero(saturation_mask) < 1: raise ValueError("not enough valid points in the selected saturation window") f_saturation = float(np.mean(values[saturation_mask])) valid_mask = (f_saturation - fit_values) > 0 if np.count_nonzero(valid_mask) < 2: raise ValueError("need at least two points with F_sat - F(t) > 0") reg_time = fit_time[valid_mask] reg_values = fit_values[valid_mask] log_time = np.log(reg_time) log_delta = np.log(f_saturation - reg_values) slope, intercept = np.polyfit(log_time, log_delta, 1) b_value = float(-slope) c_value = float(np.exp(intercept)) regression_curve = f_saturation - c_value * reg_time ** (-b_value) fit_curve = f_saturation - c_value * fit_time ** (-b_value) metrics = _fit_quality_metrics(reg_values, regression_curve) return { "F_saturation": f_saturation, "C": c_value, "B": b_value, "fit_time": fit_time, "fit_curve": fit_curve, **metrics, } def fit_log_growth(time, values, fit_window): """Fit ``F(t) = C log(t) + B`` on a user-selected window.""" time = np.asarray(time, dtype=float) values = np.asarray(values, dtype=float) fit_mask = ( np.isfinite(time) & np.isfinite(values) & (time >= fit_window[0]) & (time <= fit_window[1]) & (time > 0) ) fit_time = time[fit_mask] fit_values = values[fit_mask] if len(fit_time) < 2: raise ValueError("not enough valid points in the selected fit window") log_time = np.log(fit_time) c_value, b_value = np.polyfit(log_time, fit_values, 1) fit_curve = c_value * log_time + b_value metrics = _fit_quality_metrics(fit_values, fit_curve) return { "C": float(c_value), "B": float(b_value), "fit_time": fit_time, "fit_curve": fit_curve, **metrics, }