Source code for edlgt.tools.checks

"""Validation, diagnostics, and matrix consistency checks.

This module collects small utility functions used across the library for:

- validating common input parameters,
- pausing or logging debug messages during script execution,
- checking commutation relations, matrix equality, and Hermiticity,
- timing function calls with a lightweight decorator.

Most matrix checks operate on SciPy sparse matrices.
"""

import logging
from functools import wraps
from math import prod
from time import perf_counter

import numpy as np
from scipy.sparse import csr_matrix, isspmatrix
from scipy.sparse.linalg import norm

# Module-level logger, named after this module via ``__name__``; every
# function below emits its messages through it.
logger = logging.getLogger(__name__)

# Names exported by ``from <module> import *``; each entry is a function
# defined in this module.
__all__ = [
    "validate_parameters",
    "pause",
    "alert",
    "commutator",
    "anti_commutator",
    "check_commutator",
    "check_matrix",
    "check_hermitian",
    "get_time",
]


def get_time(func):
    """Return a wrapper around ``func`` that logs how long each call takes.

    The elapsed time is measured with :func:`time.perf_counter` and emitted
    through the module logger at debug level as ``"TIME <name> <seconds>"``.
    Nothing is logged when ``func`` raises, because the exception propagates
    before the log statement is reached.

    Parameters
    ----------
    func : callable
        Function to be timed.

    Returns
    -------
    callable
        Wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged. Metadata of ``func`` is
        preserved through :func:`functools.wraps`.
    """

    @wraps(func)
    def timed_call(*args, **kwargs):
        started = perf_counter()
        outcome = func(*args, **kwargs)
        elapsed = perf_counter() - started
        logger.debug("TIME %s %.5f", func.__name__, elapsed)
        return outcome

    return timed_call
def _is_list_of(value, item_type):
    """Return ``True`` when ``value`` is a list whose items are all ``item_type``."""
    return isinstance(value, list) and all(
        isinstance(item, item_type) for item in value
    )


def validate_parameters(
    lvals=None,
    loc_dims=None,
    lattice_dim=None,
    has_obc=None,
    axes=None,
    site_label=None,
    coords=None,
    ops_dict=None,
    op_list=None,
    op_names_list=None,
    op_sites_list=None,
    add_dagger=None,
    get_real=None,
    get_imag=None,
    staggered_basis=None,
    stag_label=None,
    all_sites_equal=None,
    gauge_basis=None,
    dictionary=None,
    filename=None,
    phrase=None,
    debug=None,
    psi=None,
    spmatrix=None,
    index=None,
    threshold=None,
    print_plaq=None,
    spin_list=None,
    int_list=None,
    sz_list=None,
    pure_theory=None,
    matter=None,
    psi_vacuum=None,
    get_singlet=None,
    array=None,
):  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements
    """Validate commonly used library arguments by type and basic value rules.

    Only parameters passed with a value different from ``None`` are checked.
    The function is intentionally broad and centralizes validation logic
    shared by multiple modules. Checks run in a fixed order and the first
    failing one raises.

    Parameters
    ----------
    ... : object, optional
        Any named argument in the function signature. Each argument provided
        with a value different from ``None`` is validated against the
        expected type (and, where implemented, basic value constraints):

        - ``lvals``, ``op_sites_list``, ``int_list``: list of ``int``.
        - ``has_obc``: list of ``bool``; ``axes``, ``op_names_list``: list of
          ``str``; ``coords``: tuple or list of ``int``.
        - ``loc_dims``: ``int``, ``list`` or ``numpy.ndarray`` (only the
          container type is checked, independently of ``lvals``).
        - ``op_list``: list whose items are all SciPy sparse matrices or all
          ``numpy.ndarray`` (checked independently of ``axes``).
        - ``ops_dict``, ``gauge_basis``, ``dictionary``: ``dict``.
        - ``site_label``, ``filename``, ``phrase``: ``str``.
        - ``lattice_dim``, ``index``: ``int``; ``threshold``: ``float``.
        - ``psi``, ``array``: ``numpy.ndarray``; ``spmatrix``: SciPy sparse
          matrix.
        - ``stag_label``: either ``"even"`` or ``"odd"``.
        - ``spin_list``: list of non-negative (half-)integers;
          ``sz_list``: list of (half-)integers.
        - every remaining flag: ``bool``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If an argument has an invalid type (also used for invalid entries of
        ``spin_list`` and ``sz_list``).
    ValueError
        If ``stag_label`` is neither ``"even"`` nor ``"odd"``.
    """
    # -------------------------------------------------------------------------
    # Lattice geometry
    if lvals is not None and not _is_list_of(lvals, int):
        raise TypeError(f"lvals should be a LIST of INTs, not {type(lvals)}")
    if lattice_dim is not None and not isinstance(lattice_dim, int):
        raise TypeError(f"lattice_dim should be INT, not {type(lattice_dim)}")
    # Only the container type is validated. The previous implementation built
    # an array with prod(lvals) that was then discarded, which crashed with an
    # unrelated TypeError whenever an int loc_dims was given without lvals.
    if loc_dims is not None and not isinstance(loc_dims, (int, list, np.ndarray)):
        raise TypeError(
            f"loc_dims must be INT, LIST, or np.ndarray, not {type(loc_dims)}"
        )
    if has_obc is not None and not _is_list_of(has_obc, bool):
        raise TypeError(f"has_obc should be a LIST of BOOLs, not {type(has_obc)}")
    if axes is not None and not _is_list_of(axes, str):
        raise TypeError(f"axes should be a LIST of STRs, not {type(axes)}")
    if site_label is not None and not isinstance(site_label, str):
        raise TypeError(f"site_label should be a STRING, not {type(site_label)}")
    if coords is not None and not (
        isinstance(coords, (tuple, list))
        and all(isinstance(coord_value, int) for coord_value in coords)
    ):
        raise TypeError(f"coords must be a TUPLE or LIST of INTs, not {type(coords)}")
    # -------------------------------------------------------------------------
    # Operators
    if ops_dict is not None and not isinstance(ops_dict, dict):
        raise TypeError(f"ops_dict must be a DICT, not {type(ops_dict)}")
    # op_list itself must be a list AND homogeneous in sparse/dense matrices.
    # (The container test used to look at `axes`, so a valid `axes` argument
    # silently disabled this check.)
    if op_list is not None and not (
        isinstance(op_list, list)
        and (
            all(isspmatrix(op) for op in op_list)
            or all(isinstance(op, np.ndarray) for op in op_list)
        )
    ):
        raise TypeError(
            f"op_list must be a LIST of SPARSE/Numpy matrices, not {type(op_list)}"
        )
    if op_sites_list is not None and not _is_list_of(op_sites_list, int):
        raise TypeError(
            f"op_sites_list must be a LIST of INTs, not {type(op_sites_list)}"
        )
    if op_names_list is not None and not _is_list_of(op_names_list, str):
        raise TypeError(
            f"op_names_list must be a LIST of STRs, not {type(op_names_list)}"
        )
    # -------------------------------------------------------------------------
    if add_dagger is not None and not isinstance(add_dagger, bool):
        raise TypeError(f"add_dagger should be a BOOL, not {type(add_dagger)}")
    if get_real is not None and not isinstance(get_real, bool):
        raise TypeError(f"get_real should be a BOOL, not {type(get_real)}")
    if get_imag is not None and not isinstance(get_imag, bool):
        raise TypeError(f"get_imag should be a BOOL, not {type(get_imag)}")
    # -------------------------------------------------------------------------
    if staggered_basis is not None and not isinstance(staggered_basis, bool):
        raise TypeError(f"staggered_basis must be a BOOL, not {type(staggered_basis)}")
    if stag_label is not None and stag_label not in ("even", "odd"):
        raise ValueError(f"stag_label must be 'even' or 'odd', not {stag_label}")
    if all_sites_equal is not None and not isinstance(all_sites_equal, bool):
        raise TypeError(
            f"all_sites_equal should be a BOOL, not {type(all_sites_equal)}"
        )
    if gauge_basis is not None and not isinstance(gauge_basis, dict):
        raise TypeError(f"gauge_basis must be a DICT, not {type(gauge_basis)}")
    # -------------------------------------------------------------------------
    if dictionary is not None and not isinstance(dictionary, dict):
        raise TypeError(f"dictionary should be a DICT, not {type(dictionary)}")
    if filename is not None and not isinstance(filename, str):
        raise TypeError(f"filename should be a STRING, not {type(filename)}")
    # -------------------------------------------------------------------------
    if phrase is not None and not isinstance(phrase, str):
        raise TypeError(f"phrase should be a STRING, not {type(phrase)}")
    if debug is not None and not isinstance(debug, bool):
        raise TypeError(f"debug should be a BOOL, not {type(debug)}")
    # -------------------------------------------------------------------------
    if psi is not None and not isinstance(psi, np.ndarray):
        raise TypeError(f"psi should be an ndarray, not a {type(psi)}")
    if array is not None and not isinstance(array, np.ndarray):
        raise TypeError(f"array must be np.array, not {type(array)}")
    # -------------------------------------------------------------------------
    if spmatrix is not None and not isspmatrix(spmatrix):
        raise TypeError(f"spmatrix should be sparse, not {type(spmatrix)}")
    if index is not None and not isinstance(index, int):
        raise TypeError(f"index should be a SCALAR INT, not {type(index)}")
    if threshold is not None and not isinstance(threshold, float):
        raise TypeError(f"threshold should be a SCALAR FLOAT, not {type(threshold)}")
    # -------------------------------------------------------------------------
    if print_plaq is not None and not isinstance(print_plaq, bool):
        raise TypeError(f"print_plaq must be a BOOL, not a {type(print_plaq)}")
    # -------------------------------------------------------------------------
    # List of spin irreps: 2*spin must be a non-negative integer
    if spin_list is not None:
        if not isinstance(spin_list, list):
            raise TypeError(f"spin_list must be a list, not {type(spin_list)}")
        for ii, spin in enumerate(spin_list):
            if not float(2 * spin).is_integer() or spin < 0:
                raise TypeError(
                    f"The {ii} spin must be positive (half-)integer, not {spin}"
                )
    # n values for the Zn group
    if int_list is not None and not _is_list_of(int_list, int):
        raise TypeError(f"int_list must be a list of integers, not {int_list}")
    # 3rd components of spins: 2*sz must be an integer (sign is free)
    if sz_list is not None:
        if not isinstance(sz_list, list):
            raise TypeError(f"sz_list must be a list, not {type(sz_list)}")
        for ii, sz in enumerate(sz_list):
            if not float(2 * sz).is_integer():
                raise TypeError(
                    f"The {ii} z-component must be (half-)integer, not {sz}"
                )
    # -------------------------------------------------------------------------
    if pure_theory is not None and not isinstance(pure_theory, bool):
        raise TypeError(f"pure_theory must be BOOL, not {type(pure_theory)}")
    if matter is not None and not isinstance(matter, bool):
        raise TypeError(f"matter must be BOOL, not {type(matter)}")
    if psi_vacuum is not None and not isinstance(psi_vacuum, bool):
        raise TypeError(f"psi_vacuum must be bool, not {type(psi_vacuum)}")
    if get_singlet is not None and not isinstance(get_singlet, bool):
        raise TypeError(f"get_singlet must be bool, not {type(get_singlet)}")
# -----------------------------------------------------------------------------
def pause(phrase, debug):
    """Block on user input at this point of the script when debugging.

    Parameters
    ----------
    phrase : str
        Text passed to :func:`input` as the prompt.
    debug : bool
        Only the value ``True`` triggers the pause; otherwise the function
        returns right after validating its arguments.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``phrase`` or ``debug`` has an invalid type.
    """
    # Validate type of parameters
    validate_parameters(phrase=phrase, debug=debug)
    if debug is not True:
        return
    separator = "----------------------------------------------------"
    logger.debug(separator)
    # Execution resumes once the user presses <ENTER>; the typed text is
    # discarded.
    _ = input(phrase)
    logger.debug(separator)
    logger.debug("")
def alert(phrase, debug):
    """Emit ``phrase`` through the module logger when debugging is on.

    Parameters
    ----------
    phrase : str
        Message to log.
    debug : bool
        Only the value ``True`` triggers the logging; otherwise the function
        returns right after validating its arguments.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``phrase`` or ``debug`` has an invalid type.
    """
    # Validate type of parameters
    validate_parameters(phrase=phrase, debug=debug)
    if debug is not True:
        return
    # An empty record is logged first, then the message itself.
    for message in ("", phrase):
        logger.debug(message)
def commutator(matrix_a, matrix_b):
    """Return the commutator ``[A, B] = AB - BA`` of two sparse matrices.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.

    Returns
    -------
    scipy.sparse.spmatrix
        Sparse matrix representing ``AB - BA``.

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    # Work on CSR copies whenever the operand offers a conversion method.
    left, right = (
        op.tocsr() if hasattr(op, "tocsr") else op for op in (matrix_a, matrix_b)
    )
    forward = left @ right
    backward = right @ left
    return forward - backward
def anti_commutator(matrix_a, matrix_b):
    """Return the anti-commutator ``{A, B} = AB + BA`` of two sparse matrices.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.

    Returns
    -------
    scipy.sparse.spmatrix
        Sparse matrix representing ``AB + BA``.

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    # Work on CSR copies whenever the operand offers a conversion method.
    left, right = (
        op.tocsr() if hasattr(op, "tocsr") else op for op in (matrix_a, matrix_b)
    )
    forward = left @ right
    backward = right @ left
    return forward + backward
def check_commutator(matrix_a, matrix_b):
    """Check whether two sparse operators commute within a fixed tolerance.

    The Frobenius norm of ``AB - BA`` is divided by the largest among the
    norms of ``AB + BA``, ``A`` and ``B``; the function raises if that ratio
    exceeds ``1e-15``. When every reference norm is zero (both operators
    vanish) the ratio is taken to be zero, so the check passes without a
    ``0/0`` division.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    ValueError
        If the normalized commutator norm is larger than the tolerance.
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    # Build each product once; `@` is the matrix product, consistently with
    # `commutator` and `anti_commutator`.
    prod_ab = matrix_a @ matrix_b
    prod_ba = matrix_b @ matrix_a
    norma = norm(prod_ab - prod_ba)
    norma_max = max(norm(prod_ab + prod_ba), norm(matrix_a), norm(matrix_b))
    # A vanishing reference norm means A = B = 0, which trivially commute.
    ratio = 0.0 if norma_max == 0 else norma / norma_max
    if ratio > 10 ** (-15):
        logger.info("ERROR: A and B do NOT COMMUTE")
        logger.info("NORM %s", norma)
        logger.info("RATIO %s", ratio)
        raise ValueError(f"A & B do not commute: NORM[A,B]={norma}, RATIO {ratio}")
    logger.info("")
def check_matrix(matrix_a: csr_matrix, matrix_b: csr_matrix):
    """Compare two sparse matrices using a normalized Frobenius-norm criterion.

    The norm of ``A - B`` is divided by the largest among the norms of
    ``A + B``, ``A`` and ``B``. When that reference norm is zero (both
    matrices vanish) the ratio is taken to be zero, so two zero matrices
    compare equal without a ``0/0`` division.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.csr_matrix
        Sparse matrices to compare.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    ValueError
        If shapes differ or the normalized difference is larger than
        ``1e-14``.
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    if matrix_a.shape != matrix_b.shape:
        raise ValueError(
            f"Shape mismatch between : A {matrix_a.shape} & B: {matrix_b.shape}"
        )
    norma = norm(matrix_a - matrix_b)
    norma_max = max(norm(matrix_a + matrix_b), norm(matrix_a), norm(matrix_b))
    # A vanishing reference norm means A = B = 0: identical by construction.
    ratio = 0.0 if norma_max == 0 else norma / norma_max
    logger.debug("NORM %s, norma max %s, RATIO %s", norma, norma_max, ratio)
    if ratio > 1e-14:
        logger.info("ERROR: A and B are DIFFERENT MATRICES")
        raise ValueError(f"NORM {norma}, RATIO {ratio}")
def check_hermitian(matrix_a):
    """Ensure a sparse matrix coincides with its Hermitian conjugate.

    Parameters
    ----------
    matrix_a : scipy.sparse.spmatrix
        Sparse matrix to test.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` is not a SciPy sparse matrix.
    ValueError
        If ``matrix_a`` differs from its Hermitian conjugate beyond the
        tolerance used by :func:`check_matrix`.
    """
    validate_parameters(spmatrix=matrix_a)
    # Compare the matrix against its conjugate transpose; check_matrix raises
    # on any mismatch, so reaching the log line means the test succeeded.
    check_matrix(matrix_a, matrix_a.getH())
    logger.info("HERMITICITY VALIDATED")