"""Parity (inversion) symmetry helpers in symmetry-reduced bases.
The module builds parity permutations and parity operators directly in a
configuration-based sector basis, and provides a lightweight routine to apply
the resulting signed permutation to a state vector.
"""
import numpy as np
from numba import njit, prange
from edlgt.tools.config_encoding import config_to_index_binarysearch
# Public API of the module: only the operator builder and the routine that
# applies the resulting triplets are exported. The permutation and
# single-configuration helpers defined below are not listed here.
__all__ = [
"apply_parity_to_state",
"build_parity_operator",
]
@njit(cache=True)
def parity_perm_site(n_sites: int, j0: int) -> np.ndarray:
    """Return the site permutation of an inversion centered on a site.

    Site indices are taken modulo ``n_sites``, so the reflection about site
    ``j0`` acts as ``j -> (2*j0 - j) mod n_sites``.

    Parameters
    ----------
    n_sites : int
        Number of lattice sites (L).
    j0 : int
        Site used as the inversion center. Any integer is accepted; it is
        reduced modulo ``n_sites`` before use.

    Returns
    -------
    ndarray
        ``int32`` array of length ``n_sites`` in which entry ``j`` is the new
        site index of the degree of freedom originally at site ``j``.
    """
    # Site j is sent to (mirror_sum - j), wrapped back into [0, n_sites).
    mirror_sum = 2 * (j0 % n_sites)
    images = (mirror_sum - np.arange(n_sites)) % n_sites
    return images.astype(np.int32)
@njit(cache=True)
def parity_perm_bond(n_sites: int, j0: int) -> np.ndarray:
    """Return the site permutation of an inversion centered on a bond.

    The center is the midpoint ``j0 + 0.5`` of the bond ``(j0, j0 + 1)``.
    With site indices taken modulo ``n_sites`` the reflection acts as
    ``j -> (2*j0 + 1 - j) mod n_sites``.

    Parameters
    ----------
    n_sites : int
        Number of lattice sites (L).
    j0 : int
        Left site of the bond ``(j0, j0 + 1)``. Any integer is accepted; it
        is reduced modulo ``n_sites`` before use.

    Returns
    -------
    ndarray
        ``int32`` array of length ``n_sites`` in which entry ``j`` is the new
        site index of the degree of freedom originally at site ``j``.
    """
    # Twice the bond midpoint (j0 + 1/2) is the integer 2*j0 + 1.
    mirror_sum = 2 * (j0 % n_sites) + 1
    images = (mirror_sum - np.arange(n_sites)) % n_sites
    return images.astype(np.int32)
@njit(cache=True)
def parity_image_config(
    config: np.ndarray,
    site_perm: np.ndarray,
    loc_perm: np.ndarray,
    loc_phase: np.ndarray,
):
    """Map one basis configuration to its parity image.

    Every site is relabelled through ``loc_perm`` and moved to the site given
    by ``site_perm``; the phases of all original local labels are multiplied
    together.

    Parameters
    ----------
    config : (n_sites,)
        Local basis indices of the original configuration.
    site_perm : (n_sites,)
        ``site_perm[j]`` is the new site index of the degree of freedom
        originally at site ``j``.
    loc_perm : ndarray
        Local basis-label permutation under parity.
    loc_phase : ndarray
        Local parity phase factors, indexed by the original local label.

    Returns
    -------
    tuple
        ``(new_config, total_phase)``: the transformed configuration as an
        ``int32`` array and the product of the local phase factors.
    """
    length = len(config)
    image = np.empty(length, dtype=np.int32)
    sign = 1
    for src in range(length):
        label = config[src]
        # The relabelled local state lands on the mirrored site.
        image[site_perm[src]] = loc_perm[label]
        # The phase is looked up with the label *before* relabelling.
        sign *= loc_phase[label]
    return image, sign
[docs]
@njit(cache=True, parallel=True)
def build_parity_operator(
    sector_configs: np.ndarray,
    loc_perm: np.ndarray,
    loc_phase: np.ndarray,
    wrt_site: np.uint8 = 0,
):
    """Build the parity operator in triplet form for a sector basis.

    Each basis configuration is mapped to its parity image, and the image is
    looked up in ``sector_configs`` to obtain the row index of the single
    nonzero entry in that configuration's column.

    Parameters
    ----------
    sector_configs : ndarray
        Symmetry-sector configurations, shape ``(n_configs, n_sites)``, one
        row per basis state, lexicographically sorted.
    loc_perm : ndarray
        Local basis-label permutation under parity.
    loc_phase : ndarray
        Local parity phase factors associated with the basis labels.
    wrt_site : int, optional
        If ``0`` (default), build a site-centered inversion about site
        ``n_sites // 2 - 1``; otherwise build a bond-centered inversion about
        the bond whose left site is ``n_sites // 2 - 1``.

    Returns
    -------
    tuple
        ``(row, col, data)`` triplet representation of the parity operator:
        ``row`` and ``col`` are ``int32`` arrays, ``data`` is ``float64``, and
        there is exactly one nonzero entry per column.

    Raises
    ------
    ValueError
        If the parity image of some configuration is not found in
        ``sector_configs``, i.e. the lookup returned a negative index
        (``-1`` according to the original in-code note; confirm against
        ``config_to_index_binarysearch``). Such a sector is not closed under
        the requested inversion.
    """
    n_configs, n_sites = sector_configs.shape
    if wrt_site == 0:
        site_perm = parity_perm_site(n_sites, n_sites // 2 - 1)
    else:
        site_perm = parity_perm_bond(n_sites, n_sites // 2 - 1)
    row = np.empty(n_configs, dtype=np.int32)
    col = np.empty(n_configs, dtype=np.int32)
    data = np.empty(n_configs, dtype=np.float64)
    for cidx in prange(n_configs):
        cfg = sector_configs[cidx]
        new_cfg, phase = parity_image_config(cfg, site_perm, loc_perm, loc_phase)
        new_cfg_idx = config_to_index_binarysearch(new_cfg, sector_configs)
        row[cidx] = new_cfg_idx
        col[cidx] = cidx
        data[cidx] = np.float64(phase)
    # A negative row index means a parity image was not found in the sector.
    # Left unchecked it would later be used as an array index and wrap around
    # to the end of the state vector. The check runs serially after the
    # parallel loop so that no exception is raised inside the prange body.
    for cidx in range(n_configs):
        if row[cidx] < 0:
            raise ValueError(
                "parity image of a configuration was not found in sector_configs"
            )
    return row, col, data
[docs]
@njit(cache=True, parallel=True)
def apply_parity_to_state(
    psi: np.ndarray,  # (n_configs,), complex128 expected
    row: np.ndarray,  # (n_configs,), int32
    col: np.ndarray,  # (n_configs,), int32
    data: np.ndarray,  # (n_configs,), float64 in {+1, -1}
) -> np.ndarray:
    """Apply a parity operator stored as signed-permutation triplets.

    For every triplet ``(row[k], col[k], data[k])`` the amplitude
    ``psi[col[k]]`` is multiplied by ``data[k]`` and accumulated into entry
    ``row[k]`` of the output.

    Parameters
    ----------
    psi : ndarray
        State vector in the sector basis.
    row, col, data : ndarray
        Triplet representation returned by :func:`build_parity_operator`.
        Each must have the same number of elements as ``psi``.

    Returns
    -------
    ndarray
        Parity-transformed state vector, a new array with the dtype of
        ``psi``.

    Raises
    ------
    ValueError
        If ``row``, ``col`` or ``data`` does not have ``psi.size`` elements.

    Notes
    -----
    The routine assumes exactly one nonzero entry per column, which holds for a
    permutation-with-signs representation of parity. The parallel loop also
    relies on the row indices being distinct, so that no two iterations
    accumulate into the same output element.
    """
    n_configs = psi.size
    # The loop below indexes row/col/data up to psi.size. Numba does not
    # bounds-check array accesses by default, so a length mismatch would read
    # past the end of the triplet arrays instead of failing.
    if row.size != n_configs or col.size != n_configs or data.size != n_configs:
        raise ValueError("row, col and data must have the same length as psi")
    psi_out = np.zeros(n_configs, dtype=psi.dtype)
    for cfg_idx in prange(n_configs):
        row_idx = row[cfg_idx]
        col_idx = col[cfg_idx]
        # The sign is cast to psi's dtype before multiplying.
        psi_out[row_idx] += psi[col_idx] * np.array(data[cfg_idx], dtype=psi.dtype)
    return psi_out