"""Configuration-table and basis-expansion helpers for symmetry sectors.
This module keeps the symmetry-facing helpers used to enumerate product-basis
configurations and expand symmetry-reduced bases back into the full Hilbert
space.
"""
import numpy as np
from numba import njit, prange
from edlgt.tools.config_encoding import compute_strides, config_to_index
__all__ = [
"get_state_configs",
"build_sector_expansion_projector",
]
[docs]
@njit(parallel=True, cache=True)
def get_state_configs(loc_dims):
    """Tabulate every product-basis configuration for a set of local dimensions.

    Parameters
    ----------
    loc_dims : ndarray
        One-dimensional array holding the local Hilbert-space dimension of
        each site.

    Returns
    -------
    ndarray
        Array of dtype ``np.uint8`` and shape
        ``(prod(loc_dims), len(loc_dims))``. Row ``i`` contains the site
        digits obtained by decoding the integer ``i`` with the strides
        returned by ``compute_strides``.
    """
    n_sites = len(loc_dims)

    # Size of the full product basis: the product of all local dimensions.
    n_configs = 1
    for local_dim in loc_dims:
        n_configs *= local_dim

    # The strides are computed once up front so that decoding a digit inside
    # the hot loop is a single floor-division followed by a modulo, rather
    # than a fresh product over local dimensions for every (row, site) pair.
    site_strides = compute_strides(loc_dims)

    table = np.empty((n_configs, n_sites), dtype=np.uint8)

    # Rows are independent of one another, so the outer loop runs in parallel.
    for state in prange(n_configs):
        for site in range(n_sites):
            digit = (state // site_strides[site]) % loc_dims[site]
            table[state, site] = digit

    return table
[docs]
@njit(cache=True, parallel=True)
def build_sector_expansion_projector(
    sector_configs: np.ndarray, local_dims: np.ndarray
) -> np.ndarray:
    """Assemble the dense matrix that embeds a symmetry sector in the full basis.

    Parameters
    ----------
    sector_configs : ndarray
        Configurations allowed in the reduced sector, one row per sector
        basis state.
    local_dims : ndarray
        Local Hilbert-space dimensions, given in the same site order as the
        columns of ``sector_configs``.

    Returns
    -------
    ndarray
        Dense ``np.uint8`` matrix of shape ``(prod(local_dims), sector_dim)``.
        Column ``j`` has a single ``1`` in the row given by
        ``config_to_index(sector_configs[j], local_dims)`` and zeros elsewhere.
    """
    n_full = np.prod(local_dims)
    n_sector = sector_configs.shape[0]

    expansion = np.zeros((n_full, n_sector), dtype=np.uint8)

    # Each iteration touches only its own column, so the columns can be
    # filled in parallel without any write conflicts.
    for col in prange(n_sector):
        full_row = config_to_index(sector_configs[col], local_dims)
        expansion[full_row, col] = 1

    return expansion