Source code for edlgt.symmetries.generate_configs

"""Configuration-table and basis-expansion helpers for symmetry sectors.

This module keeps the symmetry-facing helpers used to enumerate product-basis
configurations and expand symmetry-reduced bases back into the full Hilbert
space.
"""

import numpy as np
from numba import njit, prange

from edlgt.tools.config_encoding import compute_strides, config_to_index

# Public API of this module: the product-basis configuration enumerator and
# the builder of the sector-to-full-basis expansion projector.
__all__ = [
    "get_state_configs",
    "build_sector_expansion_projector",
]


@njit(parallel=True, cache=True)
def get_state_configs(loc_dims):
    """Enumerate all product-basis configurations for a set of local dimensions.

    Parameters
    ----------
    loc_dims : ndarray
        One-dimensional array of local Hilbert-space dimensions. Every entry
        must be at most 256, because site digits are stored as ``np.uint8``.

    Returns
    -------
    ndarray
        Array of shape ``(prod(loc_dims), len(loc_dims))`` with dtype
        ``np.uint8``. Each row is one many-body configuration: row ``i`` holds
        ``(i // strides[d]) % loc_dims[d]`` at site ``d``, where ``strides``
        is ``compute_strides(loc_dims)``.

    Raises
    ------
    ValueError
        If any local dimension exceeds 256. Such a site would produce digits
        of 256 or more, which would be wrapped on assignment into the
        ``np.uint8`` table and silently corrupt it.
    """
    # Total number of configs = product of all local dimensions. While visiting
    # every dimension anyway, guard the uint8 storage used below: a digit of
    # 256 or more would be truncated (wrapped) when written into the table.
    total_configs = 1
    for dim in loc_dims:
        if dim > 256:
            raise ValueError(
                "get_state_configs: local dimensions larger than 256 "
                "do not fit in the uint8 configuration table"
            )
        total_configs *= dim
    # Len of each config
    num_dims = len(loc_dims)
    configs = np.empty((total_configs, num_dims), dtype=np.uint8)
    # Precompute the mixed-radix strides once so that the inner loop can decode
    # each site digit with simple integer arithmetic instead of recomputing the
    # product of the trailing local dimensions for every configuration/site pair.
    strides = compute_strides(loc_dims)
    # Each row is written only by the iteration that owns its index, so the
    # rows can be filled in parallel without shared writes.
    for config_idx in prange(total_configs):
        for dim_index in range(num_dims):
            configs[config_idx, dim_index] = (
                config_idx // strides[dim_index]
            ) % loc_dims[dim_index]
    return configs
@njit(cache=True, parallel=True)
def build_sector_expansion_projector(
    sector_configs: np.ndarray, local_dims: np.ndarray
) -> np.ndarray:
    """Build a dense 0/1 matrix that embeds a symmetry sector in the full basis.

    Parameters
    ----------
    sector_configs : ndarray
        Configurations allowed in the reduced sector, one row per sector
        basis state.
    local_dims : ndarray
        Local Hilbert-space dimensions, listed in the same site order as the
        columns of ``sector_configs``.

    Returns
    -------
    ndarray
        ``np.uint8`` matrix of shape ``(prod(local_dims), sector_dim)``.
        Column ``j`` contains a single 1, at row
        ``config_to_index(sector_configs[j], local_dims)``; all other entries
        are 0.
    """
    n_sector_states = sector_configs.shape[0]
    n_full_states = np.prod(local_dims)
    expansion = np.zeros((n_full_states, n_sector_states), dtype=np.uint8)
    # Every iteration touches only its own column, so the columns can be
    # filled in parallel. NOTE(review): config_to_index presumably uses the
    # same mixed-radix ordering as get_state_configs — confirm.
    for col in prange(n_sector_states):
        expansion[config_to_index(sector_configs[col], local_dims), col] = 1
    return expansion