"""Validation, diagnostics, and matrix consistency checks.
This module collects small utility functions used across the library for:
- validating common input parameters,
- pausing or logging debug messages during script execution,
- checking commutation relations, matrix equality, and Hermiticity,
- timing function calls with a lightweight decorator.
Most matrix checks operate on SciPy sparse matrices.
"""
import logging
from functools import wraps
from math import prod
from time import perf_counter
import numpy as np
from scipy.sparse import csr_matrix, isspmatrix
from scipy.sparse.linalg import norm
# Module-level logger shared by every function below (execution timing,
# pause/alert debug messages and the matrix-check diagnostics).
logger = logging.getLogger(__name__)
# Public API of this module, i.e. the names exported by ``from ... import *``.
__all__ = [
    "validate_parameters",
    "pause",
    "alert",
    "commutator",
    "anti_commutator",
    "check_commutator",
    "check_matrix",
    "check_hermitian",
    "get_time",
]
[docs]
def get_time(func):
    """Wrap ``func`` so that the duration of every call is logged at DEBUG level.

    The elapsed time is measured with :func:`time.perf_counter` and emitted
    through the module logger as ``"TIME <function name> <seconds>"``.

    Parameters
    ----------
    func : callable
        Function whose calls should be timed.

    Returns
    -------
    callable
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged. Metadata such as
        ``__name__`` and ``__doc__`` is copied from ``func`` via
        :func:`functools.wraps`.
    """

    @wraps(func)
    def timed(*args, **kwargs):
        started = perf_counter()
        outcome = func(*args, **kwargs)
        # Stop the clock right after the call, before any logging work.
        logger.debug("TIME %s %.5f", func.__name__, perf_counter() - started)
        return outcome

    return timed
[docs]
def _is_list_of(value, item_type):
    """Return ``True`` when ``value`` is a ``list`` whose items are all ``item_type``.

    An empty list is accepted, because ``all`` over no items is ``True``.
    """
    return isinstance(value, list) and all(
        isinstance(item, item_type) for item in value
    )


def validate_parameters(
    lvals=None,
    loc_dims=None,
    lattice_dim=None,
    has_obc=None,
    axes=None,
    site_label=None,
    coords=None,
    ops_dict=None,
    op_list=None,
    op_names_list=None,
    op_sites_list=None,
    add_dagger=None,
    get_real=None,
    get_imag=None,
    staggered_basis=None,
    stag_label=None,
    all_sites_equal=None,
    gauge_basis=None,
    dictionary=None,
    filename=None,
    phrase=None,
    debug=None,
    psi=None,
    spmatrix=None,
    index=None,
    threshold=None,
    print_plaq=None,
    spin_list=None,
    int_list=None,
    sz_list=None,
    pure_theory=None,
    matter=None,
    psi_vacuum=None,
    get_singlet=None,
    array=None,
):  # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements
    """Validate commonly used library arguments by type and basic value rules.

    Only parameters passed with a value different from ``None`` are checked, so
    each caller validates just the arguments it uses. The function is
    intentionally broad and centralizes validation logic shared by multiple
    modules. It does not convert or return its arguments.

    Parameters
    ----------
    lvals, op_sites_list, int_list : list of int, optional
    lattice_dim, index : int, optional
    loc_dims : int, list or numpy.ndarray, optional
    has_obc : list of bool, optional
    axes, op_names_list : list of str, optional
    site_label, filename, phrase : str, optional
    coords : tuple or list of int, optional
    ops_dict, gauge_basis, dictionary : dict, optional
    op_list : list or tuple, optional
        Items must be either all SciPy sparse matrices or all ``numpy.ndarray``.
    add_dagger, get_real, get_imag, staggered_basis, all_sites_equal, debug,
    print_plaq, pure_theory, matter, psi_vacuum, get_singlet : bool, optional
    stag_label : {"even", "odd"}, optional
    psi, array : numpy.ndarray, optional
    spmatrix : scipy.sparse.spmatrix, optional
        Checked with :func:`scipy.sparse.isspmatrix`.
    threshold : float, optional
    spin_list : list, optional
        Items must be non-negative integers or half-integers.
    sz_list : list, optional
        Items must be integers or half-integers.

    Note that ``bool`` values satisfy the ``int`` checks, since ``bool`` is a
    subclass of ``int`` in Python.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If an argument has an invalid type. Invalid entries of ``spin_list`` or
        ``sz_list`` also raise ``TypeError``.
    ValueError
        If ``stag_label`` is neither ``"even"`` nor ``"odd"``.
    """
    # --- lattice geometry -------------------------------------------------
    if lvals is not None and not _is_list_of(lvals, int):
        raise TypeError(f"lvals should be a LIST of INTs, not {type(lvals)}")
    if lattice_dim is not None and not isinstance(lattice_dim, int):
        raise TypeError(f"lattice_dim should be INT, not {type(lattice_dim)}")
    # Only the container type is checked. Nothing is converted here, because
    # this function returns nothing, so an INT loc_dims no longer needs lvals.
    if loc_dims is not None and not isinstance(loc_dims, (int, list, np.ndarray)):
        raise TypeError(
            f"loc_dims must be INT, LIST, or np.ndarray, not {type(loc_dims)}"
        )
    if has_obc is not None and not _is_list_of(has_obc, bool):
        raise TypeError(f"has_obc should be a LIST of BOOLs, not {type(has_obc)}")
    if axes is not None and not _is_list_of(axes, str):
        raise TypeError(f"axes should be a LIST of STRs, not {type(axes)}")
    if site_label is not None and not isinstance(site_label, str):
        raise TypeError(f"site_label should be a STRING, not {type(site_label)}")
    if coords is not None and not (
        isinstance(coords, (tuple, list))
        and all(isinstance(coord_value, int) for coord_value in coords)
    ):
        raise TypeError(f"coords must be a TUPLE or LIST of INTs, not {type(coords)}")
    # --- operators --------------------------------------------------------
    if ops_dict is not None and not isinstance(ops_dict, dict):
        raise TypeError(f"ops_dict must be a DICT, not {type(ops_dict)}")
    if op_list is not None:
        # op_list is judged on its own: it must be a sequence whose operators
        # are all sparse or all dense. No other argument (e.g. ``axes``) can
        # bypass this check. The container test runs first, so a
        # non-iterable op_list never reaches the generator expressions.
        valid_op_list = isinstance(op_list, (list, tuple)) and (
            all(isspmatrix(op) for op in op_list)
            or all(isinstance(op, np.ndarray) for op in op_list)
        )
        if not valid_op_list:
            raise TypeError(
                f"op_list must be a LIST of SPARSE/Numpy matrices, not {type(op_list)}"
            )
    if op_sites_list is not None and not _is_list_of(op_sites_list, int):
        raise TypeError(
            f"op_sites_list must be a LIST of INTs, not {type(op_sites_list)}"
        )
    if op_names_list is not None and not _is_list_of(op_names_list, str):
        raise TypeError(
            f"op_names_list must be a LIST of STRs, not {type(op_names_list)}"
        )
    # --- operator post-processing flags -----------------------------------
    if add_dagger is not None and not isinstance(add_dagger, bool):
        raise TypeError(f"add_dagger should be a BOOL, not {type(add_dagger)}")
    if get_real is not None and not isinstance(get_real, bool):
        raise TypeError(f"get_real should be a BOOL, not {type(get_real)}")
    if get_imag is not None and not isinstance(get_imag, bool):
        raise TypeError(f"get_imag should be a BOOL, not {type(get_imag)}")
    # --- basis options ----------------------------------------------------
    if staggered_basis is not None and not isinstance(staggered_basis, bool):
        raise TypeError(f"staggered_basis must be a BOOL, not {type(staggered_basis)}")
    if stag_label is not None and stag_label not in ("even", "odd"):
        raise ValueError(f"stag_label must be 'even' or 'odd', not {stag_label}")
    if all_sites_equal is not None and not isinstance(all_sites_equal, bool):
        raise TypeError(
            f"all_sites_equal should be a BOOL, not {type(all_sites_equal)}"
        )
    if gauge_basis is not None and not isinstance(gauge_basis, dict):
        raise TypeError(f"gauge_basis must be a DICT, not {type(gauge_basis)}")
    # --- generic containers / files ---------------------------------------
    if dictionary is not None and not isinstance(dictionary, dict):
        raise TypeError(f"dictionary should be a DICT, not {type(dictionary)}")
    if filename is not None and not isinstance(filename, str):
        raise TypeError(f"filename should be a STRING, not {type(filename)}")
    # --- debugging helpers ------------------------------------------------
    if phrase is not None and not isinstance(phrase, str):
        raise TypeError(f"phrase should be a STRING, not {type(phrase)}")
    if debug is not None and not isinstance(debug, bool):
        raise TypeError(f"debug should be a BOOL, not {type(debug)}")
    # --- dense arrays -----------------------------------------------------
    if psi is not None and not isinstance(psi, np.ndarray):
        raise TypeError(f"psi should be an ndarray, not a {type(psi)}")
    if array is not None and not isinstance(array, np.ndarray):
        raise TypeError(f"array must be np.array, not {type(array)}")
    # --- sparse matrices and scalars --------------------------------------
    if spmatrix is not None and not isspmatrix(spmatrix):
        raise TypeError(f"spmatrix should be sparse, not {type(spmatrix)}")
    if index is not None and not isinstance(index, int):
        raise TypeError(f"index should be a SCALAR INT, not {type(index)}")
    if threshold is not None and not isinstance(threshold, float):
        raise TypeError(f"threshold should be a SCALAR FLOAT, not {type(threshold)}")
    # --- printing ---------------------------------------------------------
    if print_plaq is not None and not isinstance(print_plaq, bool):
        raise TypeError(f"print_plaq must be a BOOL, not a {type(print_plaq)}")
    # --- group / spin data ------------------------------------------------
    # List of spin irreps: each entry must be a non-negative (half-)integer.
    if spin_list is not None:
        if not isinstance(spin_list, list):
            raise TypeError(f"spin_list must be a list, not {type(spin_list)}")
        for ii, spin in enumerate(spin_list):
            if not float(2 * spin).is_integer() or spin < 0:
                raise TypeError(
                    f"The {ii} spin must be positive (half-)integer, not {spin}"
                )
    # n values for the Zn group
    if int_list is not None and not _is_list_of(int_list, int):
        raise TypeError(f"int_list must be a list of integers, not {int_list}")
    # 3rd components of spins: each entry must be a (half-)integer.
    if sz_list is not None:
        if not isinstance(sz_list, list):
            raise TypeError(f"sz_list must be a list, not {type(sz_list)}")
        for ii, sz in enumerate(sz_list):
            if not float(2 * sz).is_integer():
                raise TypeError(
                    f"The {ii} z-component must be (half-)integer, not {sz}"
                )
    # --- model flags ------------------------------------------------------
    if pure_theory is not None and not isinstance(pure_theory, bool):
        raise TypeError(f"pure_theory must be BOOL, not {type(pure_theory)}")
    if matter is not None and not isinstance(matter, bool):
        raise TypeError(f"matter must be BOOL, not {type(matter)}")
    if psi_vacuum is not None and not isinstance(psi_vacuum, bool):
        raise TypeError(f"psi_vacuum must be bool, not {type(psi_vacuum)}")
    if get_singlet is not None and not isinstance(get_singlet, bool):
        raise TypeError(f"get_singlet must be bool, not {type(get_singlet)}")
[docs]
def pause(phrase, debug):
    """Block on :func:`input` with ``phrase`` as prompt when ``debug`` is ``True``.

    The prompt is framed by separator lines logged at DEBUG level. When
    ``debug`` is anything other than ``True``, the function returns right
    after argument validation.

    Parameters
    ----------
    phrase : str
        Prompt shown to the user.
    debug : bool
        Enables the pause only when it is exactly ``True``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``phrase`` or ``debug`` has an invalid type.
    """
    validate_parameters(phrase=phrase, debug=debug)
    if debug is not True:
        return
    separator = "----------------------------------------------------"
    logger.debug(separator)
    # Wait here until the user presses <ENTER>; the typed text is discarded.
    input(phrase)
    logger.debug(separator)
    logger.debug("")
[docs]
def alert(phrase, debug):
    """Log ``phrase`` at DEBUG level, preceded by an empty line, when debugging.

    Parameters
    ----------
    phrase : str
        Message to log.
    debug : bool
        The message is emitted only when this is exactly ``True``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``phrase`` or ``debug`` has an invalid type.
    """
    validate_parameters(phrase=phrase, debug=debug)
    if debug is not True:
        return
    logger.debug("")
    logger.debug(phrase)
[docs]
def commutator(matrix_a, matrix_b):
    """Return the commutator ``[A, B] = AB - BA`` of two sparse matrices.

    Both operands are converted to CSR format (when they provide ``tocsr``)
    before the products are formed.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.

    Returns
    -------
    scipy.sparse.spmatrix
        Sparse matrix equal to ``AB - BA``.

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    """
    for operand in (matrix_a, matrix_b):
        validate_parameters(spmatrix=operand)

    def as_csr(mat):
        return mat.tocsr() if hasattr(mat, "tocsr") else mat

    left, right = as_csr(matrix_a), as_csr(matrix_b)
    return left @ right - right @ left
[docs]
def anti_commutator(matrix_a, matrix_b):
    """Return the anti-commutator ``{A, B} = AB + BA`` of two sparse matrices.

    Both operands are converted to CSR format (when they provide ``tocsr``)
    before the products are formed.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.

    Returns
    -------
    scipy.sparse.spmatrix
        Sparse matrix equal to ``AB + BA``.

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    """
    for operand in (matrix_a, matrix_b):
        validate_parameters(spmatrix=operand)

    def as_csr(mat):
        return mat.tocsr() if hasattr(mat, "tocsr") else mat

    left, right = as_csr(matrix_a), as_csr(matrix_b)
    return left @ right + right @ left
[docs]
def check_commutator(matrix_a, matrix_b, *, tol=1e-15):
    """Check whether two sparse operators commute within a relative tolerance.

    The norm of the commutator ``||AB - BA||`` (computed with
    :func:`scipy.sparse.linalg.norm`) is divided by
    ``max(||AB + BA||, ||A||, ||B||)``. The operators are considered to commute
    when this ratio does not exceed ``tol``.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.spmatrix
        Sparse matrices with compatible shapes.
    tol : float, optional, keyword-only
        Largest accepted normalized commutator norm. Defaults to ``1e-15``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    ValueError
        If the normalized commutator norm exceeds ``tol``, or is NaN
        (e.g. because an operator contains NaN/inf entries).
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    # Compute each product once and reuse it for both the commutator and the
    # anti-commutator used in the normalization.
    prod_ab = matrix_a @ matrix_b
    prod_ba = matrix_b @ matrix_a
    norma = norm(prod_ab - prod_ba)
    norma_max = max(norm(prod_ab + prod_ba), norm(matrix_a), norm(matrix_b))
    # If every normalizing norm is zero, then A = B = 0 and they trivially
    # commute. Handle this explicitly instead of going through 0/0 -> nan.
    ratio = 0.0 if norma_max == 0 else norma / norma_max
    # ``not ratio <= tol`` (rather than ``ratio > tol``) so that a NaN ratio
    # is reported as a failure instead of silently passing.
    if not ratio <= tol:
        logger.info("ERROR: A and B do NOT COMMUTE")
        logger.info("NORM %s", norma)
        logger.info("RATIO %s", ratio)
        raise ValueError(f"A & B do not commute: NORM[A,B]={norma}, RATIO {ratio}")
    logger.info("")
[docs]
def check_matrix(matrix_a: csr_matrix, matrix_b: csr_matrix, *, tol: float = 1e-14):
    """Compare two sparse matrices using a normalized Frobenius-norm criterion.

    The difference norm ``||A - B||`` (computed with
    :func:`scipy.sparse.linalg.norm`) is divided by
    ``max(||A + B||, ||A||, ||B||)``. The matrices are considered equal when
    this ratio does not exceed ``tol``.

    Parameters
    ----------
    matrix_a, matrix_b : scipy.sparse.csr_matrix
        Sparse matrices to compare. Any SciPy sparse matrix passes validation.
    tol : float, optional, keyword-only
        Largest accepted normalized difference. Defaults to ``1e-14``.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` or ``matrix_b`` is not a SciPy sparse matrix.
    ValueError
        If the shapes differ, or the normalized difference exceeds ``tol`` or
        is NaN (e.g. because an input contains NaN/inf entries).
    """
    validate_parameters(spmatrix=matrix_a)
    validate_parameters(spmatrix=matrix_b)
    if matrix_a.shape != matrix_b.shape:
        raise ValueError(
            f"Shape mismatch between : A {matrix_a.shape} & B: {matrix_b.shape}"
        )
    norma = norm(matrix_a - matrix_b)
    norma_max = max(norm(matrix_a + matrix_b), norm(matrix_a), norm(matrix_b))
    # Two all-zero matrices are equal. Avoid the 0/0 -> nan (RuntimeWarning) path.
    ratio = 0.0 if norma_max == 0 else norma / norma_max
    logger.debug("NORM %s, norma max %s, RATIO %s", norma, norma_max, ratio)
    # ``not ratio <= tol`` so that a NaN ratio is reported as a mismatch
    # instead of silently validating.
    if not ratio <= tol:
        logger.info("ERROR: A and B are DIFFERENT MATRICES")
        raise ValueError(f"NORM {norma}, RATIO {ratio}")
[docs]
def check_hermitian(matrix_a):
    """Verify that a sparse matrix equals its conjugate transpose.

    The comparison is delegated to :func:`check_matrix`, so its tolerance and
    normalization apply. On success, ``"HERMITICITY VALIDATED"`` is logged at
    INFO level.

    Parameters
    ----------
    matrix_a : scipy.sparse.spmatrix
        Sparse matrix to test.

    Returns
    -------
    None

    Raises
    ------
    TypeError
        If ``matrix_a`` is not a SciPy sparse matrix.
    ValueError
        If ``matrix_a`` differs from its Hermitian conjugate beyond the
        tolerance used by :func:`check_matrix`.
    """
    validate_parameters(spmatrix=matrix_a)
    # A^dagger = conj(A)^T. This is equivalent to ``getH()``, but it also
    # works on SciPy versions and sparse containers without that method.
    adjoint = matrix_a.conjugate().transpose()
    check_matrix(matrix_a, adjoint)
    logger.info("HERMITICITY VALIDATED")