"""Subsystem/environment matrix builders and cache-format helpers."""
import numpy as np
from numba import njit, prange
from scipy.sparse import csr_matrix
# Names exported by ``from <module> import *``; everything else is internal.
__all__ = [
    "build_psi_matrix",
    "build_sparse_psi_matrix",
    "should_cache_psi_matrix_as_sparse",
]
@njit(cache=True, parallel=True)
def build_psi_matrix(
    psi: np.ndarray,
    subsys_config_index: np.ndarray,
    env_config_index: np.ndarray,
    subsys_dim: int,
    env_dim: int,
):
    """Build the dense subsystem/environment matrix of a state vector.

    Entry ``ii`` of ``psi`` is scattered to row ``subsys_config_index[ii]``
    and column ``env_config_index[ii]`` of a zero-initialised
    ``(subsys_dim, env_dim)`` complex128 matrix.

    Parameters
    ----------
    psi : np.ndarray
        1-D array of amplitudes; values are stored as complex128.
    subsys_config_index, env_config_index : np.ndarray
        Integer row / column index for each entry of ``psi``. Each must
        have at least ``psi.shape[0]`` elements; only the first
        ``psi.shape[0]`` are read.
    subsys_dim, env_dim : int
        Number of rows / columns of the output matrix.

    Returns
    -------
    np.ndarray
        Dense complex128 array of shape ``(subsys_dim, env_dim)``.

    Raises
    ------
    ValueError
        If either index map is shorter than ``psi``.

    Notes
    -----
    The scatter runs under ``prange``. If two entries map to the same
    (row, column) pair the writes race and the surviving value is
    unspecified; presumably the index maps are injective -- confirm with
    callers. Index *values* are not range-checked here.
    """
    n_entries = psi.shape[0]
    # Compiled code does not bounds-check array reads by default, so a short
    # index map would read past its end instead of raising; check up front.
    if (
        subsys_config_index.shape[0] < n_entries
        or env_config_index.shape[0] < n_entries
    ):
        raise ValueError("index maps must have at least as many entries as psi")
    psi_matrix = np.zeros((subsys_dim, env_dim), dtype=np.complex128)
    for ii in prange(n_entries):
        psi_matrix[subsys_config_index[ii], env_config_index[ii]] = psi[ii]
    return psi_matrix
def build_sparse_psi_matrix(
    psi: np.ndarray,
    subsys_config_index: np.ndarray,
    env_config_index: np.ndarray,
    subsys_dim: int,
    env_dim: int,
) -> csr_matrix:
    """Build a CSR subsystem/environment matrix directly from index maps.

    Uses SciPy's ``(data, (row, col))`` triplet constructor, so entry ``ii``
    of ``psi`` lands at ``(subsys_config_index[ii], env_config_index[ii])``.

    Parameters
    ----------
    psi : np.ndarray
        Amplitudes; converted to complex128.
    subsys_config_index, env_config_index : np.ndarray
        Row / column index for each entry of ``psi``; converted to int64.
        Must have the same length as ``psi``.
    subsys_dim, env_dim : int
        Shape of the resulting matrix.

    Returns
    -------
    csr_matrix
        Complex128 CSR matrix of shape ``(subsys_dim, env_dim)``.

    Raises
    ------
    ValueError
        Raised by SciPy when the three arrays differ in length or an index
        falls outside the requested shape.

    Notes
    -----
    Unlike the dense builder, duplicate (row, column) pairs are *summed*
    during the COO -> CSR conversion rather than overwritten.
    """
    data = np.asarray(psi, dtype=np.complex128)
    rows = np.asarray(subsys_config_index, dtype=np.int64)
    cols = np.asarray(env_config_index, dtype=np.int64)
    return csr_matrix(
        (data, (rows, cols)),
        shape=(int(subsys_dim), int(env_dim)),
    )
def should_cache_psi_matrix_as_sparse(
    subsys_dim: int,
    env_dim: int,
    nnz: int,
    size_thresh: int,
    density_thresh: float,
) -> bool:
    """Decide whether a subsystem/environment matrix should be cached as CSR.

    Parameters
    ----------
    subsys_dim, env_dim : int
        Shape of the (dense-equivalent) matrix.
    nnz : int
        Number of stored non-zero entries.
    size_thresh : int
        The dense entry count must be strictly greater than this.
    density_thresh : float
        The fill fraction ``nnz / (subsys_dim * env_dim)`` must be strictly
        below this.

    Returns
    -------
    bool
        True only when both the size and density conditions hold; False for
        matrices with no entries (or a non-positive entry count).
    """
    total_entries = int(subsys_dim) * int(env_dim)
    # An empty (or nonsensical negative-size) matrix gains nothing from a
    # sparse representation; this also avoids dividing by zero below.
    if total_entries <= 0:
        return False
    density = float(nnz) / float(total_entries)
    return total_entries > size_thresh and density < density_thresh