"""Translation symmetry and momentum-basis construction utilities.
This module builds translation orbits and sparse momentum-basis projectors for
symmetry-sector configuration tables, and provides momentum-projected sparse
operator kernels for generic factorized n-body operators.
"""
import numpy as np
from numba import njit, prange
from edlgt.tools.config_encoding import config_to_index_binarysearch
# Magnitude cutoff for local operator entries: matrix elements whose absolute
# value does not exceed this threshold are treated as zero by the kernels.
OP_EPS = 1e-12  # small cutoff for local operator entries
# Public names exported by ``from <module> import *``.
__all__ = [
    "check_normalization",
    "check_orthogonality",
    "get_translated_state_indices",
    "get_reference_indices",
    "get_momentum_basis",
    "nbody_data_momentum",
    "nbody_data_momentum_4sites",
    "nbody_data_momentum_2sites",
    "nbody_data_momentum_1site",
]
[docs]
@njit(cache=True)
def get_translated_state_indices(config, sector_configs, logical_unit_size=1):
    """Return the basis indices of every 1D cyclic translation of ``config``.

    Parameters
    ----------
    config : ndarray
        One configuration (one entry per lattice site).
    sector_configs : ndarray
        Table of configurations, one per row; its second dimension must match
        ``len(config)``.
    logical_unit_size : int, optional
        Number of sites moved by one elementary translation.

    Returns
    -------
    ndarray
        ``int32`` array of length ``len(config) // logical_unit_size``. Entry
        ``t`` is the value returned by ``config_to_index_binarysearch`` for
        ``np.roll(config, -t * logical_unit_size)``.

    Raises
    ------
    ValueError
        If the number of sites is not a multiple of ``logical_unit_size`` or
        does not match ``sector_configs.shape[1]``.

    Notes
    -----
    NOTE(review): the lookup helper is presumably a binary search over a
    sorted ``sector_configs``; its result for a missing configuration is not
    visible here.
    """
    n_sites = len(config)
    if n_sites % logical_unit_size != 0:
        raise ValueError("Number of sites is not a multiple of the logical unit size.")
    if n_sites != sector_configs.shape[1]:
        raise ValueError(
            f"config.shape[0]={n_sites} must be equal to "
            f"sector_configs.shape[1]={sector_configs.shape[1]}"
        )
    num_translations = n_sites // logical_unit_size
    trans_indices = np.zeros(num_translations, dtype=np.int32)
    # Track the cumulative roll instead of recomputing the product each step.
    roll_steps = 0
    for translation_idx in range(num_translations):
        shifted = np.roll(config, -roll_steps)
        trans_indices[translation_idx] = config_to_index_binarysearch(
            shifted, sector_configs
        )
        roll_steps += logical_unit_size
    return trans_indices
[docs]
@njit(cache=True)
def get_reference_indices(sector_configs):
    """Select translation-inequivalent reference configurations in 1D.

    Configurations are visited in row order. A configuration becomes a
    reference unless one of its single-site translations is an earlier row
    that was already accepted as a reference.

    Parameters
    ----------
    sector_configs : ndarray
        Table of configurations, one per row.

    Returns
    -------
    ref_indices : ndarray
        Row indices of the accepted reference configurations, ascending.
    norm : ndarray
        ``int32`` array; for each reference, the number of distinct indices
        among its translations.
    """
    sector_dim = sector_configs.shape[0]
    normalization = np.zeros(sector_dim, dtype=np.int32)
    independent_indices = np.zeros(sector_dim, dtype=np.bool_)
    for cfg_idx in range(sector_dim):
        trans_indices = get_translated_state_indices(
            sector_configs[cfg_idx], sector_configs
        )
        # Scan only the orbit images (O(n_sites)) rather than every earlier
        # row (O(cfg_idx * n_sites)): an earlier accepted reference lies in
        # this orbit iff one of the images is such a row.
        is_independent = True
        for image_idx in trans_indices:
            if (
                image_idx >= 0
                and image_idx < cfg_idx
                and independent_indices[image_idx]
            ):
                is_independent = False
                break
        if is_independent:
            independent_indices[cfg_idx] = True
            normalization[cfg_idx] = len(np.unique(trans_indices))
    ref_indices = np.flatnonzero(independent_indices)
    norm = normalization[ref_indices]
    return ref_indices, norm
@njit(cache=True)
def _prepare_momentum_local_transition_data(
    op_list: np.ndarray, op_sites_list: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Tabulate the nonzero entries of each local operator factor, row by row.

    Parameters
    ----------
    op_list : ndarray
        Site-resolved operator matrices, indexed as
        ``op_list[op_idx, site_idx, bra_loc, ket_loc]``.
    op_sites_list : ndarray
        For each operator factor, the site index on which it acts.

    Returns
    -------
    tuple
        ``(transition_counts, ket_local_states, transition_values)``.
        ``transition_counts[op_idx, bra_loc]`` is the number of entries in row
        ``bra_loc`` of factor ``op_idx`` whose magnitude exceeds ``OP_EPS``.
        The first ``transition_counts[op_idx, bra_loc]`` slots of
        ``ket_local_states[op_idx, bra_loc]`` and
        ``transition_values[op_idx, bra_loc]`` hold the corresponding column
        indices and matrix elements, in ascending column order. The remaining
        slots are left uninitialized.

    Notes
    -----
    ``local_dim`` is taken from ``op_list.shape[2]`` and is used for both the
    row and the column range of every factor.
    """
    n_ops = len(op_sites_list)
    local_dim = np.int32(op_list.shape[2])
    table_shape = (n_ops, local_dim, local_dim)
    transition_counts = np.zeros((n_ops, local_dim), dtype=np.int32)
    ket_local_states = np.empty(table_shape, dtype=np.int32)
    transition_values = np.empty(table_shape, dtype=np.complex128)
    for op_idx in range(n_ops):
        # Matrix of this factor on the site it acts on.
        site_op = op_list[op_idx, op_sites_list[op_idx]]
        for bra_loc in range(local_dim):
            write_pos = 0
            for ket_loc in range(local_dim):
                elem = site_op[bra_loc, ket_loc]
                if not (np.abs(elem) > OP_EPS):
                    # Negligible entry: no transition recorded.
                    continue
                # Pack surviving entries contiguously at the front of the row.
                ket_local_states[op_idx, bra_loc, write_pos] = ket_loc
                transition_values[op_idx, bra_loc, write_pos] = elem
                write_pos += 1
            transition_counts[op_idx, bra_loc] = write_pos
    return transition_counts, ket_local_states, transition_values
@njit(cache=True)
def _prefix_sum_counts(nnz_per_col: np.ndarray) -> np.ndarray:
    """Convert per-column nonzero counts into a CSC column-pointer array.

    The result has ``nnz_per_col.size + 1`` ``int32`` entries, starts at 0,
    and its last entry equals the total number of nonzeros.
    """
    n_cols = nnz_per_col.size
    col_ptr = np.empty(n_cols + 1, np.int32)
    col_ptr[0] = 0
    accumulated = 0
    for col_idx, count in enumerate(nnz_per_col):
        accumulated += count
        col_ptr[col_idx + 1] = accumulated
    return col_ptr
@njit(inline="always")
def _insertion_sort_by_row(rows: np.ndarray, vals: np.ndarray, length: int) -> None:
    """Stably sort the first ``length`` ``(rows, vals)`` pairs by ``rows``, in place.

    Only entries strictly greater than the pending key are shifted, so equal
    keys keep their relative order. Entries at positions ``>= length`` are
    left untouched.
    """
    for entry_idx in range(1, length):
        pending_row = rows[entry_idx]
        pending_val = vals[entry_idx]
        # ``hole`` is the slot where the pending pair will finally land.
        hole = entry_idx
        while hole > 0 and rows[hole - 1] > pending_row:
            rows[hole] = rows[hole - 1]
            vals[hole] = vals[hole - 1]
            hole -= 1
        rows[hole] = pending_row
        vals[hole] = pending_val
@njit(cache=True, parallel=True)
def precompute_C_sign_per_config(
    sector_configs: np.ndarray,  # (N, L) local-state labels
    C_label_sign: np.ndarray,  # (d_loc,) sign attached to each local label
) -> np.ndarray:
    """Return, per configuration, the product of the per-site label signs.

    ``out[i]`` is the product over sites ``j`` of
    ``C_label_sign[sector_configs[i, j]]``, stored as ``float64``.
    Configurations are processed independently inside a ``prange`` loop.
    """
    n_configs = sector_configs.shape[0]
    n_sites = sector_configs.shape[1]
    out = np.ones(n_configs, np.float64)
    for cfg_idx in prange(n_configs):
        labels = sector_configs[cfg_idx]
        sign = 1.0
        for site_idx in range(n_sites):
            sign *= C_label_sign[labels[site_idx]]
        out[cfg_idx] = sign
    return out
# ---------- linear index <-> coords (ROW-MAJOR)----------
@njit(inline="always")
def linear_to_coords_rowmajor(
    site_index: int, lvals: np.ndarray, out_coords: np.ndarray
) -> None:
    """Decode a row-major linear site index into per-axis coordinates.

    The last axis is the fastest-varying one. Results are written into the
    preallocated ``out_coords`` (one entry per axis of ``lvals``).
    """
    remainder = site_index
    ax = lvals.size - 1
    while ax >= 0:
        extent = lvals[ax]
        out_coords[ax] = remainder % extent
        remainder = remainder // extent
        ax -= 1
@njit(inline="always")
def coords_to_linear_rowmajor(coords: np.ndarray, lvals: np.ndarray) -> int:
    """Encode per-axis coordinates as a row-major linear site index.

    Axis 0 is the most significant digit and the last axis the least
    significant one.
    """
    linear = 0
    for ax in range(lvals.size):
        # Horner-style accumulation of the mixed-radix number.
        linear = coords[ax] + linear * lvals[ax]
    return linear
# ---------- mixed-radix encode/decode of per-axis block-shifts ----------
@njit(inline="always")
def encode_shift(axis_shifts: np.ndarray, shifts_per_dir: np.ndarray) -> int:
    """Flatten a vector of per-axis block shifts into one row-major index.

    Each shift is first reduced modulo the number of shifts available on its
    axis (``shifts_per_dir[ax]``). Axis 0 is the most significant digit.
    """
    flat = 0
    for ax in range(shifts_per_dir.size):
        radix = shifts_per_dir[ax]
        digit = axis_shifts[ax] % radix
        flat = flat * radix + digit
    return flat
@njit(inline="always")
def decode_shift(
    flat_index: int, shifts_per_dir: np.ndarray, out_axis_shifts: np.ndarray
) -> None:
    """Expand a flat shift index into per-axis block shifts (inverse of ``encode_shift``).

    Digits are peeled off starting from the last (least significant) axis and
    written into the preallocated ``out_axis_shifts``.
    """
    remainder = flat_index
    ax = shifts_per_dir.size - 1
    while ax >= 0:
        radix = shifts_per_dir[ax]
        out_axis_shifts[ax] = remainder % radix
        remainder = remainder // radix
        ax -= 1
@njit(inline="always")
def decode_mixed_index(idx: int, bases: np.ndarray, out: np.ndarray) -> None:
    """Decode a mixed-radix index into its digits, written into ``out``.

    Uses the same row-major convention as ``encode_shift``: axis 0 is the most
    significant digit and the last axis the least significant one.
    """
    remainder = idx
    n_digits = bases.size
    for offset in range(n_digits):
        # Walk the axes from the last one back to axis 0.
        ax = n_digits - 1 - offset
        base = bases[ax]
        out[ax] = remainder % base
        remainder = remainder // base
@njit(cache=True)
def _compute_rowmajor_strides(bases: np.ndarray) -> np.ndarray:
    """Return the row-major strides matching :func:`encode_shift`.

    ``strides[ax]`` is the product of all ``bases`` to the right of ``ax``
    (1 for the last axis), stored as ``int32``.
    """
    n_axes = bases.size
    strides = np.empty(n_axes, np.int32)
    running_product = 1
    ax = n_axes - 1
    while ax >= 0:
        strides[ax] = running_product
        running_product *= bases[ax]
        ax -= 1
    return strides
@njit(inline="always")
def _next_power_of_two_at_least(min_size: int) -> int:
    """Return the smallest power of two that is ``>= min_size`` (at least 1)."""
    power = 1
    while power < min_size:
        power *= 2
    return power
@njit(inline="always")
def _advance_period_box_counter(
    t_local: np.ndarray, pvec: np.ndarray, shift_strides: np.ndarray, full_shift: int
) -> tuple[int, bool]:
    """Step the period-box counter once and keep ``full_shift`` in sync.

    ``t_local`` is a mixed-radix counter with per-axis bases ``pvec``; it is
    incremented in place, last axis first. ``full_shift`` is the matching flat
    index, moved by ``shift_strides[ax]`` for every unit step on axis ``ax``.

    Returns
    -------
    tuple
        ``(full_shift, finished)``. ``finished`` is ``True`` once every digit
        has overflowed, at which point ``t_local`` is all zeros again.
    """
    ax = pvec.size - 1
    while ax >= 0:
        stride = shift_strides[ax]
        t_local[ax] += 1
        full_shift += stride
        if t_local[ax] < pvec[ax]:
            # Digit still in range: no carry needed.
            return full_shift, False
        # Digit overflowed: rewind it to zero and carry into the next axis.
        t_local[ax] = 0
        full_shift -= pvec[ax] * stride
        ax -= 1
    return full_shift, True
@njit(inline="always")
def _insert_or_accumulate_real(
    cfg_row: int,
    incr: float,
    used_indices: np.ndarray,
    used_values: np.ndarray,
    hash_keys: np.ndarray,
    hash_positions: np.ndarray,
    hash_mask: int,
    used_len: int,
) -> int:
    """Add ``incr`` to the entry for ``cfg_row``, creating that entry if needed.

    ``hash_keys`` is an open-addressing table (linear probing, ``-1`` marks an
    empty slot) mapping a row index to its position in the compact
    ``used_indices`` / ``used_values`` arrays.

    Returns
    -------
    int
        The updated number of compact entries: ``used_len + 1`` after an
        insertion, ``used_len`` after an accumulation.
    """
    slot = cfg_row & hash_mask
    key = hash_keys[slot]
    # Probe until we hit either an empty slot or the slot owning ``cfg_row``.
    while key != -1 and key != cfg_row:
        slot = (slot + 1) & hash_mask
        key = hash_keys[slot]
    if key == -1:
        # New row: append it to the compact arrays and register it.
        hash_keys[slot] = cfg_row
        hash_positions[slot] = used_len
        used_indices[used_len] = cfg_row
        used_values[used_len] = incr
        return used_len + 1
    # Row already present: accumulate its weight.
    used_values[hash_positions[slot]] += incr
    return used_len
@njit(inline="always")
def _insert_or_accumulate_complex(
    cfg_row: int,
    incr: complex,
    used_indices: np.ndarray,
    used_values: np.ndarray,
    hash_keys: np.ndarray,
    hash_positions: np.ndarray,
    hash_mask: int,
    used_len: int,
) -> int:
    """Add the complex ``incr`` to the entry for ``cfg_row``, creating it if needed.

    ``hash_keys`` is an open-addressing table (linear probing, ``-1`` marks an
    empty slot) mapping a row index to its position in the compact
    ``used_indices`` / ``used_values`` arrays.

    Returns
    -------
    int
        The updated number of compact entries: ``used_len + 1`` after an
        insertion, ``used_len`` after an accumulation.
    """
    slot = cfg_row & hash_mask
    key = hash_keys[slot]
    # Probe until we hit either an empty slot or the slot owning ``cfg_row``.
    while key != -1 and key != cfg_row:
        slot = (slot + 1) & hash_mask
        key = hash_keys[slot]
    if key == -1:
        # New row: append it to the compact arrays and register it.
        hash_keys[slot] = cfg_row
        hash_positions[slot] = used_len
        used_indices[used_len] = cfg_row
        used_values[used_len] = incr
        return used_len + 1
    # Row already present: accumulate its phase contribution.
    used_values[hash_positions[slot]] += incr
    return used_len
@njit(cache=True)
def _accumulate_zero_k_orbit(
    translations_row: np.ndarray, pvec: np.ndarray, shift_strides: np.ndarray
) -> tuple[np.ndarray, np.ndarray, int]:
    """Collect the zero-momentum orbit column of one reference configuration.

    Every point of the period box (per-axis extents ``pvec``) contributes
    weight ``1.0`` to the basis row found in ``translations_row`` at the
    matching flat shift. Repeated rows are merged through a hash table.

    Returns
    -------
    tuple
        ``(used_indices, used_values, used_len)``. Only the first ``used_len``
        entries of the two arrays are meaningful: distinct rows in first-seen
        order and their accumulated ``float64`` weights.
    """
    lattice_dim = pvec.size
    box_volume = 1
    for ax in range(lattice_dim):
        box_volume *= pvec[ax]
    # Worst case: every point of the period box maps to a distinct row.
    used_indices = np.empty(box_volume, np.int32)
    used_values = np.zeros(box_volume, np.float64)
    # Power-of-two table at most half full, so probing can wrap with a bitmask.
    hash_size = _next_power_of_two_at_least(2 * box_volume)
    hash_mask = hash_size - 1
    hash_keys = np.full(hash_size, -1, np.int32)
    hash_positions = np.empty(hash_size, np.int32)
    # ``t_local`` enumerates the period box; ``full_shift`` is the matching
    # column inside ``translations_row``.
    t_local = np.zeros(lattice_dim, np.int32)
    full_shift = 0
    used_len = 0
    while True:
        used_len = _insert_or_accumulate_real(
            translations_row[full_shift],
            1.0,
            used_indices,
            used_values,
            hash_keys,
            hash_positions,
            hash_mask,
            used_len,
        )
        full_shift, finished = _advance_period_box_counter(
            t_local, pvec, shift_strides, full_shift
        )
        if finished:
            break
    return used_indices, used_values, used_len
@njit(cache=True)
def _prepare_phase_table(k_vals: np.ndarray, shifts_per_dir: np.ndarray) -> np.ndarray:
    """Tabulate the per-axis Bloch phases ``exp(-2πi k_d t / R_d)``.

    Returns a ``complex128`` array of shape ``(D, max(R_d))``. Row ``d`` is
    filled for ``t = 0 .. R_d - 1``; entries beyond ``R_d - 1`` are left
    uninitialized. Each phase is obtained from the previous one by multiplying
    with the single-step phase.
    """
    lattice_dim = shifts_per_dir.size
    widest = 0
    for ax in range(lattice_dim):
        if shifts_per_dir[ax] > widest:
            widest = shifts_per_dir[ax]
    phase_table = np.empty((lattice_dim, widest), np.complex128)
    for ax in range(lattice_dim):
        n_shifts = shifts_per_dir[ax]
        # The zero shift always carries unit phase.
        phase_table[ax, 0] = 1.0 + 0.0j
        if n_shifts > 1:
            kd = k_vals[ax] % n_shifts
            step = np.exp(-1j * 2.0 * np.pi * kd / float(n_shifts))
            for shift_idx in range(1, n_shifts):
                # Running power of ``step``; avoids one ``exp`` per entry.
                phase_table[ax, shift_idx] = phase_table[ax, shift_idx - 1] * step
    return phase_table
@njit(cache=True)
def _accumulate_finite_k_orbit(
    translations_row: np.ndarray,
    pvec: np.ndarray,
    shift_strides: np.ndarray,
    phase_table: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Collect the finite-momentum orbit column of one reference configuration.

    Every point ``t`` of the period box (per-axis extents ``pvec``) contributes
    the product over axes of ``phase_table[ax, t[ax]]`` to the basis row found
    in ``translations_row`` at the matching flat shift. Repeated rows are
    merged through a hash table, so their phases add up (and may cancel).

    Returns
    -------
    tuple
        ``(used_indices, used_values, used_len)``. Only the first ``used_len``
        entries of the two arrays are meaningful: distinct rows in first-seen
        order and their accumulated ``complex128`` amplitudes.
    """
    lattice_dim = pvec.size
    box_volume = 1
    for ax in range(lattice_dim):
        box_volume *= pvec[ax]
    # Worst case: every point of the period box maps to a distinct row.
    used_indices = np.empty(box_volume, np.int32)
    used_values = np.zeros(box_volume, np.complex128)
    # Power-of-two table at most half full, so probing can wrap with a bitmask.
    hash_size = _next_power_of_two_at_least(2 * box_volume)
    hash_mask = hash_size - 1
    hash_keys = np.full(hash_size, -1, np.int32)
    hash_positions = np.empty(hash_size, np.int32)
    # ``t_local`` enumerates the period box; ``full_shift`` is the matching
    # column inside ``translations_row``.
    t_local = np.zeros(lattice_dim, np.int32)
    full_shift = 0
    used_len = 0
    while True:
        # The Bloch phase of a combined shift is the product of axis phases.
        bloch_phase = 1.0 + 0.0j
        for ax in range(lattice_dim):
            bloch_phase *= phase_table[ax, t_local[ax]]
        used_len = _insert_or_accumulate_complex(
            translations_row[full_shift],
            bloch_phase,
            used_indices,
            used_values,
            hash_keys,
            hash_positions,
            hash_mask,
            used_len,
        )
        full_shift, finished = _advance_period_box_counter(
            t_local, pvec, shift_strides, full_shift
        )
        if finished:
            break
    return used_indices, used_values, used_len
[docs]
@njit(cache=True)
def check_normalization(basis: np.ndarray) -> bool:
    """Tell whether every column of ``basis`` has unit norm.

    Parameters
    ----------
    basis : ndarray
        Matrix whose columns are the basis vectors.

    Returns
    -------
    bool
        ``False`` as soon as one column norm is not close to 1 (as judged by
        ``np.isclose`` with its default tolerances), ``True`` otherwise.
    """
    n_cols = basis.shape[1]
    col_idx = 0
    while col_idx < n_cols:
        col_norm = np.linalg.norm(basis[:, col_idx])
        if not np.isclose(col_norm, 1):
            return False
        col_idx += 1
    return True
[docs]
@njit(cache=True)
def check_orthogonality(basis: np.ndarray) -> bool:
    """Tell whether the columns of ``basis`` are pairwise orthogonal.

    Parameters
    ----------
    basis : ndarray
        Matrix whose columns are the basis vectors.

    Returns
    -------
    bool
        ``False`` as soon as the ``np.vdot`` overlap of two distinct columns
        is not close to zero (absolute tolerance ``1e-10``), ``True``
        otherwise.
    """
    n_cols = basis.shape[1]
    for left_col_idx in range(n_cols):
        left_col = basis[:, left_col_idx]
        # Only pairs with right > left are needed.
        for right_col_idx in range(left_col_idx + 1, n_cols):
            overlap = np.vdot(left_col, basis[:, right_col_idx])
            if not np.isclose(overlap, 0, atol=1e-10):
                return False
    return True
@njit(cache=True, parallel=True)
def build_TC_translations(
    sector_configs: np.ndarray,  # (N, L) dressed local-state ids
    C_map: np.ndarray,  # (d_loc,) local relabelling applied on every site
):
    """Build the orbit table of the combined generator X = T ∘ C on a 1D ring.

    One application of X first relabels every site through ``C_map`` and then
    translates the configuration by one site (site ``j`` moves to
    ``(j + 1) % L``). The same ``C_map`` is used on even and odd sites.

    Parameters
    ----------
    sector_configs : ndarray
        Table of configurations, one per row, shape ``(N, L)``.
    C_map : ndarray
        Lookup table sending a local-state label to its image under C.

    Returns
    -------
    translations_table : ndarray
        ``int32`` array of shape ``(N, L)``. Column 0 is the row index itself
        (X^0 is the identity); column ``t`` is the value returned by
        ``config_to_index_binarysearch`` for X^t applied to row ``i``.
    shifts_per_dir : ndarray
        ``int32`` array of shape ``(1,)`` holding ``L``.
    """
    n_configs, n_sites = sector_configs.shape
    num_tc_steps = n_sites
    translations_table = np.empty((n_configs, num_tc_steps), np.int32)
    for cfg_idx in prange(n_configs):
        base = sector_configs[cfg_idx]
        # Per-iteration scratch buffers (each prange iteration owns its own).
        current_config = np.empty(n_sites, np.int32)
        mapped_config = np.empty(n_sites, np.int32)
        for site_idx in range(n_sites):
            current_config[site_idx] = base[site_idx]
        translations_table[cfg_idx, 0] = cfg_idx  # X^0 = identity
        for step_idx in range(1, num_tc_steps):
            # 1) Apply C sitewise before translating.
            for site_idx in range(n_sites):
                mapped_config[site_idx] = C_map[current_config[site_idx]]
            # 2) Translate by +1: dest[(j + 1) % L] = mapped[j].
            for site_idx in range(n_sites):
                current_config[(site_idx + 1) % n_sites] = mapped_config[site_idx]
            # 3) Look up the resulting configuration in the sector table.
            translations_table[cfg_idx, step_idx] = config_to_index_binarysearch(
                current_config, sector_configs
            )
    shifts_per_dir = np.empty(1, np.int32)
    shifts_per_dir[0] = num_tc_steps
    return translations_table, shifts_per_dir
# ─────────────────────────────────────────────────────────────────────────────
@njit(cache=True)
def _prepare_translation_source_sites(
    lvals: np.ndarray, unit_cell_size: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Tabulate the site permutation induced by every allowed block translation.

    Parameters
    ----------
    lvals : ndarray
        Lattice length along each axis.
    unit_cell_size : ndarray
        Translation step along each axis, in lattice sites.

    Returns
    -------
    tuple
        ``(source_sites, shifts_per_dir)``. ``shifts_per_dir`` is
        ``lvals // unit_cell_size``. ``source_sites`` is an ``int32`` array of
        shape ``(prod(shifts_per_dir), prod(lvals))`` where
        ``source_sites[flat_shift, dest_site_idx]`` is the site whose content
        lands on ``dest_site_idx`` under the translation decoded from
        ``flat_shift``.

    Notes
    -----
    With this table, translating a configuration is a plain gather:
    ``rolled_cfg[dest] = cfg[source_sites[flat_shift, dest]]``.
    """
    lattice_dim = lvals.size
    shifts_per_dir = lvals // unit_cell_size
    total_shifts = 1
    for ax in range(lattice_dim):
        total_shifts *= shifts_per_dir[ax]
    n_sites = 1
    for ax in range(lattice_dim):
        n_sites *= lvals[ax]
    source_sites = np.empty((total_shifts, n_sites), np.int32)
    # Scratch vectors reused for every (shift, site) pair.
    axis_shifts = np.empty(lattice_dim, np.int32)
    site_offsets = np.empty(lattice_dim, np.int64)
    coords = np.empty(lattice_dim, np.int32)
    shifted_coords = np.empty(lattice_dim, np.int32)
    for flat_shift in range(total_shifts):
        decode_shift(flat_shift, shifts_per_dir, axis_shifts)
        # Displacement in sites along each axis for this combined shift.
        for ax in range(lattice_dim):
            site_offsets[ax] = axis_shifts[ax] * unit_cell_size[ax]
        for source_site_idx in range(n_sites):
            linear_to_coords_rowmajor(source_site_idx, lvals, coords)
            for ax in range(lattice_dim):
                # Periodic boundary conditions along every axis.
                shifted_coords[ax] = (coords[ax] + site_offsets[ax]) % lvals[ax]
            dest_site_idx = coords_to_linear_rowmajor(shifted_coords, lvals)
            source_sites[flat_shift, dest_site_idx] = source_site_idx
    return source_sites, shifts_per_dir
# ─────────────────────────────────────────────────────────────────────────────
@njit(cache=True, parallel=True)
def build_all_translations(
    sector_configs: np.ndarray,
    lvals: np.ndarray,
    unit_cell_size: np.ndarray,
):
    """Tabulate the image of every configuration under every block translation.

    Along axis ``d`` the allowed translations are multiples of
    ``unit_cell_size[d]`` sites, giving
    ``shifts_per_dir[d] = lvals[d] // unit_cell_size[d]`` distinct shifts.
    Combined shifts are flattened row-major (see ``decode_shift``).

    Parameters
    ----------
    sector_configs : ndarray
        Table of configurations, one per row.
    lvals : ndarray
        Lattice length along each axis.
    unit_cell_size : ndarray
        Translation step along each axis, in lattice sites.

    Returns
    -------
    translations_array : ndarray
        ``int32`` array of shape ``(n_configs, prod(shifts_per_dir))``.
        ``translations_array[i, s]`` is the value returned by
        ``config_to_index_binarysearch`` for configuration ``i`` translated by
        the shift decoded from ``s``.
    shifts_per_dir : ndarray
        Number of allowed shifts per axis.
    """
    n_configs = sector_configs.shape[0]
    n_sites = sector_configs.shape[1]
    # The site permutations depend only on the geometry: compute them once.
    source_sites, shifts_per_dir = _prepare_translation_source_sites(
        lvals, unit_cell_size
    )
    total_shifts = source_sites.shape[0]
    translations_array = np.empty((n_configs, total_shifts), np.int32)
    for cfg_idx in prange(n_configs):
        cfg = sector_configs[cfg_idx]
        # Scratch buffer private to this iteration, reused for every shift.
        rolled_cfg = np.empty(n_sites, sector_configs.dtype)
        for flat_shift in range(total_shifts):
            site_map = source_sites[flat_shift]
            for dest_site_idx in range(n_sites):
                rolled_cfg[dest_site_idx] = cfg[site_map[dest_site_idx]]
            # Locate the translated configuration inside the sector table.
            translations_array[cfg_idx, flat_shift] = config_to_index_binarysearch(
                rolled_cfg, sector_configs
            )
    return translations_array, shifts_per_dir
@njit(cache=True)
def select_references(
    translations_array: np.ndarray,  # (n_configs, total_shifts) orbit table
    shifts_per_dir: np.ndarray,  # (D,) number of block shifts per axis, R_d
    k_vals: np.ndarray,  # (D,) momentum labels k_d
):
    """Pick one representative per translation orbit, filtered by momentum.

    Configurations are visited in index order. The first unvisited member of
    an orbit becomes its candidate representative. For it, the axis period
    ``p_d`` is the smallest ``p`` in ``1..R_d`` such that shifting by ``p``
    blocks along axis ``d`` alone maps the configuration onto itself. The
    orbit is kept iff ``(k_d * p_d) % R_d == 0`` on every axis.

    Returns
    -------
    references : ndarray
        ``int32`` indices of the kept representatives, ascending.
    period_vectors : ndarray
        ``int32`` array of shape ``(n_refs, D)`` with the axis periods of each
        kept representative.
    """
    n_configs, total_shifts = translations_array.shape
    lattice_dim = shifts_per_dir.size
    # Outputs are sized for the worst case and trimmed on return.
    references = np.empty(n_configs, np.int32)
    period_vectors = np.empty((n_configs, lattice_dim), np.int32)
    # Flags configurations that belong to an orbit already visited.
    assigned = np.zeros(n_configs, np.uint8)
    # A pure shift of ``p`` blocks on axis ``ax`` has flat index p * stride[ax].
    shift_strides = _compute_rowmajor_strides(shifts_per_dir)
    n_refs = 0
    for cfg_idx in range(n_configs):
        if assigned[cfg_idx]:
            continue
        orbit_row = translations_array[cfg_idx]
        # 1) Minimal period along each axis.
        pvec = np.empty(lattice_dim, np.int32)
        for ax in range(lattice_dim):
            n_shifts = shifts_per_dir[ax]
            # Default to the full cycle; overwritten by the first hit below.
            pvec[ax] = n_shifts
            for period_candidate in range(1, n_shifts + 1):
                # ``p = R_d`` wraps around to the zero shift via the modulo.
                flat_shift_index = (period_candidate % n_shifts) * shift_strides[ax]
                if orbit_row[flat_shift_index] == cfg_idx:
                    pvec[ax] = period_candidate
                    break
        # 2) Mark the whole orbit before filtering, so a rejected orbit is
        #    never revisited through another of its members.
        for shift_idx in range(total_shifts):
            assigned[orbit_row[shift_idx]] = 1
        # 3) Momentum compatibility on every axis.
        compatible = True
        for ax in range(lattice_dim):
            if (k_vals[ax] * pvec[ax]) % shifts_per_dir[ax] != 0:
                compatible = False
                break
        # 4) Record the representative and its periods.
        if compatible:
            references[n_refs] = cfg_idx
            for ax in range(lattice_dim):
                period_vectors[n_refs, ax] = pvec[ax]
            n_refs += 1
    return references[:n_refs], period_vectors[:n_refs, :]
@njit(cache=True, parallel=True)
def momentum_basis_zero_k(
    sector_configs: np.ndarray,  # (n_configs, n_sites) int32
    lvals: np.ndarray,  # (D,) int32 lattice lengths
    unit_cell_size: np.ndarray,  # (D,) int32 block sizes s_d (must divide L_d)
):
    """
    Build the Γ (k=0) momentum projector B in sparse form, as both CSC and CSR.

    Mathematical content (Γ sector):
    --------------------------------
    For each translation orbit (represented by a reference configuration with
    axis period-vector p), the column is the sum of T(t)|ref> over the period
    box 0 <= t_d < p_d. Images that coincide are merged and their unit weights
    added; the column is then divided by its 2-norm. All phases are 1, so the
    basis is real.

    Implementation overview:
    ------------------------
    1) Build the translation table and per-axis shift counts R_d.
    2) Choose one representative per orbit (``select_references`` with k=0).
    3) Two-pass CSC build:
       PASS 1: per column, deduplicate the images over the period box and
               count the distinct rows (``nnz_per_col``).
       PASS 2: rebuild the same column, normalize it, sort it by row index,
               and write it into the CSC arrays.
    4) Build the CSR arrays from the CSC ones (row counts, prefix sum, scatter).

    Returns
    -------
    L_col_ptr : (n_cols+1,) int32
    L_row_idx : (nnz,) int32
    L_data : (nnz,) float64
    R_row_ptr : (n_rows+1,) int32
    R_col_idx : (nnz,) int32
    R_data : (nnz,) float64

    Note
    ----
    ``n_rows = sector_configs.shape[0]`` and ``n_cols = L_col_ptr.shape[0] - 1``.
    """
    # ---------- 0) Sizes & translations ----------
    n_rows = sector_configs.shape[0]
    translations_array, shifts_per_dir = build_all_translations(
        sector_configs, lvals, unit_cell_size
    )
    lattice_dim = shifts_per_dir.size
    # These strides let the orbit helper update the full translation index
    # incrementally while it walks the period box.
    shift_strides = _compute_rowmajor_strides(shifts_per_dir)
    # ---------- 1) Orbit representatives and their period-vectors ----------
    k_zero = np.zeros(lattice_dim, np.int32)
    references, period_vectors = select_references(
        translations_array, shifts_per_dir, k_zero
    )
    n_cols = references.shape[0]
    # ---------- 2) PASS 1: count per-column nonzeros (after dedup) ----------
    nnz_per_col = np.zeros(n_cols, np.int32)
    # Columns are independent, so each prange iteration writes only its own
    # ``nnz_per_col`` entry.
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        pvec = period_vectors[col_idx, :]  # (D,)
        # Build the deduplicated orbit content once, then read only its length.
        used_indices, used_values, used_len = _accumulate_zero_k_orbit(
            translations_array[ref_cfg_index], pvec, shift_strides
        )
        nnz_per_col[col_idx] = used_len
    # CSC structure: column pointers from the counts, then value storage.
    L_col_ptr = _prefix_sum_counts(nnz_per_col)
    total_nnz = L_col_ptr[-1]
    L_row_idx = np.empty(total_nnz, np.int32)
    L_data = np.empty(total_nnz, np.float64)
    # ---------- 3) PASS 2: fill CSC (dedup + normalize + sort) ----------
    # Each column writes only the slice [L_col_ptr[col], L_col_ptr[col + 1]).
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        pvec = period_vectors[col_idx, :]
        # Rebuild the same orbit content, now to normalize and materialize it.
        used_indices, used_values, used_len = _accumulate_zero_k_orbit(
            translations_array[ref_cfg_index], pvec, shift_strides
        )
        # Normalize the column by its 2-norm (the values are real counts here).
        norm_sq = 0.0
        for entry_idx in range(used_len):
            norm_sq += used_values[entry_idx] * used_values[entry_idx]
        # Guard against dividing by zero; a skipped column leaves its slice of
        # the CSC arrays unwritten.
        if np.isclose(norm_sq, 0.0):
            continue
        inv_norm = 1.0 / np.sqrt(norm_sq)
        # Sort by row index so every CSC column has ascending rows.
        _insertion_sort_by_row(used_indices, used_values, used_len)
        # Write this column's entries into the CSC arrays.
        write = L_col_ptr[col_idx]
        for entry_idx in range(used_len):
            L_row_idx[write] = used_indices[entry_idx]
            L_data[write] = used_values[entry_idx] * inv_norm
            write += 1
    # ---------- 4) Build CSR arrays from CSC (no SciPy) ----------
    R_row_ptr = np.zeros(n_rows + 1, np.int32)
    # Row counts, stored one slot to the right so the prefix sum yields pointers.
    for nnz_idx in range(total_nnz):
        row_idx = L_row_idx[nnz_idx]
        R_row_ptr[row_idx + 1] += 1
    # In-place prefix sum.
    for row_idx in range(n_rows):
        R_row_ptr[row_idx + 1] += R_row_ptr[row_idx]
    # Scatter targets.
    R_col_idx = np.empty(total_nnz, np.int32)
    R_data = np.empty(total_nnz, np.float64)
    # Per-row write cursors, initialized to each row's start offset.
    heads = np.empty(n_rows, np.int32)
    for row_idx in range(n_rows):
        heads[row_idx] = R_row_ptr[row_idx]
    # Columns are scanned in ascending order, so the column indices inside
    # every CSR row end up ascending as well.
    for col_idx in range(n_cols):
        start = L_col_ptr[col_idx]
        stop = L_col_ptr[col_idx + 1]
        for nnz_idx in range(start, stop):
            row_idx = L_row_idx[nnz_idx]
            write_idx = heads[row_idx]
            R_col_idx[write_idx] = col_idx
            R_data[write_idx] = L_data[nnz_idx]
            heads[row_idx] += 1
    return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)
@njit(cache=True, parallel=True)
def momentum_basis_finite_k(
    sector_configs: np.ndarray,  # (n_configs, n_sites) int32
    lvals: np.ndarray,  # (D,) int32
    unit_cell_size: np.ndarray,  # (D,) int32 (must divide L_d)
    k_vals: np.ndarray,  # (D,) int32 momenta mod R_d
):
    """
    Build the finite-k momentum projector B in sparse form, as both CSC and CSR.

    Math:
        R_d = L_d / s_d, group G = Z_{R0} × ... × Z_{R_{D-1}}.
        For an orbit representative ``ref`` with axis periods p_d, the Bloch sum is
            |k; ref> = (1/√N) ∑_{0≤t_d<p_d} exp[-2πi Σ_d (k_d t_d / R_d)] T(t) |ref>,
        where N = ∑_{distinct images j} |amplitude_j|^2 after deduplication.
        If all amplitudes of a column cancel, the column is left empty (it
        still counts towards ``n_cols``).

    Determinism:
        Each column is built independently with its own scratch arrays, and
        its (row, value) pairs are sorted by row index before being written.

    Returns
    -------
    L_col_ptr : (n_cols+1,) int32
    L_row_idx : (nnz,) int32
    L_data : (nnz,) complex128
    R_row_ptr : (n_rows+1,) int32
    R_col_idx : (nnz,) int32
    R_data : (nnz,) complex128

    ``n_rows = sector_configs.shape[0]`` and ``n_cols = L_col_ptr.size - 1``.
    """
    # Tolerances for pruning entries that cancelled and near-empty columns.
    TOL_ZERO = 1e-14  # per-entry amplitude threshold
    TOL_COLNORM = 1e-30  # column norm^2 threshold
    # ---------- 0) Sizes & translations ----------
    n_rows = sector_configs.shape[0]
    translations_array, shifts_per_dir = build_all_translations(
        sector_configs, lvals, unit_cell_size
    )
    # These strides let the orbit helper update the full translation index
    # incrementally while it walks the period box.
    shift_strides = _compute_rowmajor_strides(shifts_per_dir)
    # Precompute all one-axis phase factors once and reuse them across columns.
    phase_table = _prepare_phase_table(k_vals, shifts_per_dir)
    # ---------- 1) Orbit representatives & periods (filtered by k_vals) ----------
    references, period_vectors = select_references(
        translations_array, shifts_per_dir, k_vals
    )
    n_cols = references.shape[0]
    # ---------- 2) PASS 1: count nonzeros per column (after cancellations) ----------
    nnz_per_col = np.zeros(n_cols, np.int32)
    # Columns are independent, so each prange iteration writes only its own
    # ``nnz_per_col`` entry.
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        pvec = period_vectors[col_idx, :]  # (D,)
        # Build the deduplicated orbit content once, then only count the
        # entries that survive Bloch-phase cancellations.
        used_indices, used_values, used_len = _accumulate_finite_k_orbit(
            translations_array[ref_cfg_index], pvec, shift_strides, phase_table
        )
        # The threshold here must match the one used in PASS 2, otherwise the
        # column slices allocated below would not fit what PASS 2 writes.
        surviving_entries = 0
        for entry_idx in range(used_len):
            if np.abs(used_values[entry_idx]) > TOL_ZERO:
                surviving_entries += 1
        nnz_per_col[col_idx] = surviving_entries
    # Allocate CSC: column pointers from the counts, then value storage.
    L_col_ptr = _prefix_sum_counts(nnz_per_col)
    total_nnz = L_col_ptr[-1]
    L_row_idx = np.empty(total_nnz, np.int32)
    L_data = np.empty(total_nnz, np.complex128)
    # ---------- 3) PASS 2: fill CSC (dedup + prune zeros + normalize + sort) ----------
    # Each column writes only the slice [L_col_ptr[col], L_col_ptr[col + 1]).
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        pvec = period_vectors[col_idx, :]
        # Rebuild the same orbit content, now to prune zeros, normalize, and
        # write the canonical CSC column.
        used_indices, used_values, used_len = _accumulate_finite_k_orbit(
            translations_array[ref_cfg_index], pvec, shift_strides, phase_table
        )
        # Compact the surviving entries to the front of the scratch arrays and
        # accumulate the squared column norm along the way.
        kept_len = 0
        norm_sq = 0.0
        for entry_idx in range(used_len):
            amplitude = used_values[entry_idx]
            if np.abs(amplitude) > TOL_ZERO:
                used_indices[kept_len] = used_indices[entry_idx]
                used_values[kept_len] = amplitude
                norm_sq += (
                    amplitude.real * amplitude.real + amplitude.imag * amplitude.imag
                )
                kept_len += 1
        # If the column fully cancels, write nothing.
        if norm_sq <= TOL_COLNORM or kept_len == 0:
            continue
        inv_norm = 1.0 / np.sqrt(norm_sq)
        # Sort by row index so every CSC column has ascending rows.
        _insertion_sort_by_row(used_indices, used_values, kept_len)
        # Materialize the normalized column.
        write = L_col_ptr[col_idx]
        for entry_idx in range(kept_len):
            L_row_idx[write] = used_indices[entry_idx]
            L_data[write] = used_values[entry_idx] * inv_norm
            write += 1
    # ---------- 4) Build CSR from CSC ----------
    R_row_ptr = np.zeros(n_rows + 1, np.int32)
    # Row counts, stored one slot to the right so the prefix sum yields pointers.
    for nnz_idx in range(total_nnz):
        row_idx = L_row_idx[nnz_idx]
        R_row_ptr[row_idx + 1] += 1
    # In-place prefix sum.
    for row_idx in range(n_rows):
        R_row_ptr[row_idx + 1] += R_row_ptr[row_idx]
    R_col_idx = np.empty(total_nnz, np.int32)
    R_data = np.empty(total_nnz, np.complex128)
    # Per-row write cursors, initialized to each row's start offset.
    heads = np.empty(n_rows, np.int32)
    for row_idx in range(n_rows):
        heads[row_idx] = R_row_ptr[row_idx]
    # Columns are scanned in ascending order, so the column indices inside
    # every CSR row end up ascending as well.
    for col_idx in range(n_cols):
        start = L_col_ptr[col_idx]
        stop = L_col_ptr[col_idx + 1]
        for nnz_idx in range(start, stop):
            row_idx = L_row_idx[nnz_idx]
            write_idx = heads[row_idx]
            R_col_idx[write_idx] = col_idx
            R_data[write_idx] = L_data[nnz_idx]
            heads[row_idx] += 1
    return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)
@njit(cache=True)
def _zero_k_TC_orbit_amplitudes(ref_translations, ref_sign, pvec, shifts_per_dir):
    """Accumulate and prune the signed k=0 amplitudes of one TC orbit.

    The period box ``0 <= t_d < pvec[d]`` is scanned; every shift is mapped to
    its image row through ``ref_translations`` and a weight is summed per
    distinct image row: ``+1`` when the flattened shift index is even and
    ``ref_sign`` when it is odd. Entries whose accumulated weight cancels to
    (numerically) zero are then removed in place.

    Parameters
    ----------
    ref_translations : ndarray
        Row of the translation table belonging to the reference configuration
        (image row index for each flattened shift).
    ref_sign : float
        Sign attached to the reference configuration for odd shifts.
    pvec : ndarray
        Per-axis periods of the orbit.
    shifts_per_dir : ndarray
        Per-axis number of shifts used to flatten a shift vector.

    Returns
    -------
    rows : ndarray of int32
        Image row indices; only the first ``kept_len`` entries are valid.
    amps : ndarray of float64
        Un-normalized amplitudes matching ``rows``.
    kept_len : int
        Number of entries that survived the pruning.
    norm_sq : float
        Squared 2-norm of the surviving amplitudes.
    """
    tol_zero = 1e-14  # per-entry threshold below which an amplitude is dropped
    lattice_dim = shifts_per_dir.size
    # Upper bound on the number of distinct images in the period box.
    orbit_size = 1
    for ax in range(lattice_dim):
        orbit_size *= pvec[ax]
    rows = np.empty(orbit_size, np.int32)
    amps = np.zeros(orbit_size, np.float64)
    used_len = 0
    t_local = np.zeros(lattice_dim, np.int32)
    for local_shift_idx in range(orbit_size):
        decode_mixed_index(local_shift_idx, pvec, t_local)
        flat_full = encode_shift(t_local, shifts_per_dir)
        cfg_row = ref_translations[flat_full]
        # Odd powers of the TC generator contribute the sign of the reference
        # configuration instead of a plain +1 factor.
        incr = 1.0
        if (flat_full & 1) == 1:
            incr = ref_sign
        # Deduplicate by image row (linear scan; orbit boxes are small).
        pos = -1
        for entry_idx in range(used_len):
            if rows[entry_idx] == cfg_row:
                pos = entry_idx
                break
        if pos == -1:
            pos = used_len
            rows[pos] = cfg_row
            used_len += 1
        amps[pos] += incr
    # Compact away the entries that cancelled and accumulate the column norm.
    kept_len = 0
    norm_sq = 0.0
    for entry_idx in range(used_len):
        amplitude = amps[entry_idx]
        if np.abs(amplitude) > tol_zero:
            rows[kept_len] = rows[entry_idx]
            amps[kept_len] = amplitude
            norm_sq += amplitude * amplitude
            kept_len += 1
    return rows, amps, kept_len, norm_sq


@njit(cache=True, parallel=True)
def momentum_basis_zero_k_TC(sector_configs: np.ndarray):
    """
    Build the Γ (k=0) momentum projector B of the TC symmetry in sparse form:
    - CSC arrays: (L_col_ptr, L_row_idx, L_data) — float64
    - CSR arrays: (R_row_ptr, R_col_idx, R_data) — float64

    Mathematical content (Γ sector):
    --------------------------------
    For each orbit (represented by a 'reference' config with period-vector p)
    the column is the sum over the period box ``0 <= t_d < p_d`` of the images
    ``T(t)|ref>``, where an odd flattened shift carries the sign
    ``C_phase_per_config[ref]`` instead of ``+1``. Images that coincide are
    merged, amplitudes that cancel are dropped, and the column is divided by
    its 2-norm. All weights are real, so the basis is real.

    Implementation overview:
    ------------------------
    1) Build the TC translation table (``build_TC_translations``) and the
       per-configuration signs (``precompute_C_sign_per_config``).
    2) Choose orbit representatives with ``select_references`` at k=0.
    3) Two-pass CSC build. Both passes call the same accumulation helper, so
       the number of slots reserved in PASS 1 always equals the number of
       entries written in PASS 2. A column whose amplitudes cancel completely
       reserves no slots and stays empty.
    4) Build a CSR view from CSC (row-wise prefix sum + scatter).

    Parameters
    ----------
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.

    Returns
    -------
    L_col_ptr : (n_cols+1,) int32
    L_row_idx : (nnz,) int32
    L_data : (nnz,) float64
    R_row_ptr : (n_rows+1,) int32
    R_col_idx : (nnz,) int32
    R_data : (nnz,) float64

    Note
    ----
    You can get:
    n_rows = sector_configs.shape[0]
    n_cols = L_col_ptr.shape[0] - 1
    """
    # ---------- 0) Sizes & translations ----------
    n_rows = sector_configs.shape[0]
    # Special 1D case with TC+inversion symmetry: hard-coded local-state map
    # and sign labels (NOTE(review): presumably tied to a 6-state local basis;
    # confirm against build_TC_translations / precompute_C_sign_per_config).
    C_map = np.array([4, 5, 2, 3, 0, 1], dtype=np.int32)
    C_phase_label = np.array([+1, +1, +1, -1, +1, +1], dtype=np.float64)
    C_phase_per_config = precompute_C_sign_per_config(sector_configs, C_phase_label)
    translations_array, shifts_per_dir = build_TC_translations(sector_configs, C_map)
    lattice_dim = shifts_per_dir.size
    # ---------- 1) Orbit representatives and their period-vectors ----------
    k_zero = np.zeros(lattice_dim, np.int32)
    references, period_vectors = select_references(
        translations_array, shifts_per_dir, k_zero
    )
    n_cols = references.shape[0]
    # ---------- 2) PASS 1: count surviving nonzeros per column ----------
    nnz_per_col = np.zeros(n_cols, np.int32)
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        used_indices, used_values, kept_len, norm_sq = _zero_k_TC_orbit_amplitudes(
            translations_array[ref_cfg_index],
            C_phase_per_config[ref_cfg_index],
            period_vectors[col_idx, :],
            shifts_per_dir,
        )
        # Only entries that survive cancellation get a slot, so no CSC slot is
        # ever left uninitialised.
        nnz_per_col[col_idx] = kept_len
    # CSC structure
    L_col_ptr = _prefix_sum_counts(nnz_per_col)
    total_nnz = L_col_ptr[-1]
    L_row_idx = np.empty(total_nnz, np.int32)
    L_data = np.empty(total_nnz, np.float64)
    # ---------- 3) PASS 2: fill CSC (dedup + prune + normalize + sort) ----------
    for col_idx in prange(n_cols):
        ref_cfg_index = references[col_idx]
        used_indices, used_values, kept_len, norm_sq = _zero_k_TC_orbit_amplitudes(
            translations_array[ref_cfg_index],
            C_phase_per_config[ref_cfg_index],
            period_vectors[col_idx, :],
            shifts_per_dir,
        )
        # Fully cancelled column: PASS 1 reserved nothing, so write nothing.
        if kept_len == 0:
            continue
        inv_norm = 1.0 / np.sqrt(norm_sq)
        # Sort by row index for canonical CSC
        _insertion_sort_by_row(used_indices, used_values, kept_len)
        write = L_col_ptr[col_idx]
        for entry_idx in range(kept_len):
            L_row_idx[write] = used_indices[entry_idx]
            L_data[write] = used_values[entry_idx] * inv_norm
            write += 1
    # ---------- 4) Build CSR view from CSC (no SciPy) ----------
    R_row_ptr = np.zeros(n_rows + 1, np.int32)
    # row counts
    for nnz_idx in range(total_nnz):
        row_idx = L_row_idx[nnz_idx]
        R_row_ptr[row_idx + 1] += 1
    # prefix sum
    for row_idx in range(n_rows):
        R_row_ptr[row_idx + 1] += R_row_ptr[row_idx]
    # scatter
    R_col_idx = np.empty(total_nnz, np.int32)
    R_data = np.empty(total_nnz, np.float64)
    # per-row write heads (copy of the row starts)
    heads = np.empty(n_rows, np.int32)
    for row_idx in range(n_rows):
        heads[row_idx] = R_row_ptr[row_idx]
    for col_idx in range(n_cols):
        start = L_col_ptr[col_idx]
        stop = L_col_ptr[col_idx + 1]
        for nnz_idx in range(start, stop):
            row_idx = L_row_idx[nnz_idx]
            write_idx = heads[row_idx]
            R_col_idx[write_idx] = col_idx
            R_data[write_idx] = L_data[nnz_idx]
            heads[row_idx] += 1
    return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)
@njit(cache=True)
def _finite_k_TC_orbit_amplitudes(
    ref_translations, ref_sign, pvec, shifts_per_dir, k_vals
):
    """Accumulate the complex Bloch amplitudes of one TC orbit (no pruning).

    Scans the period box ``0 <= t_d < pvec[d]`` and, per distinct image row
    found through ``ref_translations``, sums ``incr * phase`` where
    ``phase = exp(-2πi Σ_d k_d t_d / R_d)`` and ``incr`` is ``+1`` for an even
    flattened shift index and ``ref_sign`` for an odd one.

    Returns ``(rows, amps, used_len)``; only the first ``used_len`` entries of
    ``rows`` (int32) and ``amps`` (complex128) are meaningful.
    """
    n_axes = shifts_per_dir.size
    # Orbit-box upper bound on the number of distinct images.
    box_size = 1
    for ax in range(n_axes):
        box_size *= pvec[ax]
    rows = np.empty(box_size, np.int32)
    amps = np.zeros(box_size, np.complex128)
    used_len = 0
    shift_vec = np.zeros(n_axes, np.int32)
    for box_idx in range(box_size):
        decode_mixed_index(box_idx, pvec, shift_vec)
        # phase = exp(-2πi Σ_d (k_d * t_d / R_d))
        phase_arg = 0.0
        for ax in range(n_axes):
            kd = k_vals[ax] % shifts_per_dir[ax]
            phase_arg += (kd * shift_vec[ax]) / float(shifts_per_dir[ax])
        phase = np.exp(-1j * 2.0 * np.pi * phase_arg)
        flat_shift = encode_shift(shift_vec, shifts_per_dir)
        image_row = ref_translations[flat_shift]
        # Odd flattened shifts carry the reference sign instead of +1.
        incr = 1.0
        if (flat_shift & 1) == 1:
            incr = ref_sign
        # Locate the image among those already seen (linear scan).
        slot = -1
        for seen in range(used_len):
            if rows[seen] == image_row:
                slot = seen
                break
        if slot == -1:
            slot = used_len
            rows[slot] = image_row
            used_len += 1
        amps[slot] += incr * phase
    return rows, amps, used_len


@njit(cache=True, parallel=True)
def momentum_basis_finite_k_TC(
    sector_configs: np.ndarray,  # (n_configs, n_sites) int32
    k_vals: np.ndarray,  # (D,) int32 momenta mod R_d
):
    """
    Build the finite-k momentum projector B of the TC symmetry in sparse form:
    - CSC arrays: (L_col_ptr, L_row_idx, L_data) — complex128
    - CSR arrays: (R_row_ptr, R_col_idx, R_data) — complex128

    Math:
    For an orbit representative ``ref`` with axis-periods p_d, the column is
    (1/√N) ∑_{0≤t_d<p_d} s(t) exp[-2πi Σ_d (k_d t_d / R_d)] T(t) |ref>,
    where R_d = shifts_per_dir[d], s(t) is ``+1`` for an even flattened shift
    and ``C_phase_per_config[ref]`` for an odd one, and
    N = ∑_{distinct images j} |amplitude_j|^2 after deduplication and pruning.
    If all amplitudes cancel, the column is left empty.

    Determinism:
    Every column is built independently with its own scratch arrays, and its
    (row, value) pairs are sorted by row index before being written to CSC.

    Returns
    -------
    L_col_ptr : (n_cols+1,) int32
    L_row_idx : (nnz,) int32
    L_data : (nnz,) complex128
    R_row_ptr : (n_rows+1,) int32
    R_col_idx : (nnz,) int32
    R_data : (nnz,) complex128
    (As usual: n_rows = sector_configs.shape[0], n_cols = L_col_ptr.size - 1.)
    """
    # Thresholds used to discard cancelled amplitudes / empty columns.
    TOL_ZERO = 1e-14  # per-entry amplitude threshold
    TOL_COLNORM = 1e-30  # column norm^2 threshold
    # ---------- 0) Sizes & translation table ----------
    n_rows = sector_configs.shape[0]
    # Hard-coded local-state map and sign labels of the special 1D
    # TC+inversion case.
    C_map = np.array([4, 5, 2, 3, 0, 1], dtype=np.int32)
    C_phase_label = np.array([+1, +1, +1, -1, +1, +1], dtype=np.float64)
    C_phase_per_config = precompute_C_sign_per_config(sector_configs, C_phase_label)
    translations_array, shifts_per_dir = build_TC_translations(sector_configs, C_map)
    # ---------- 1) Orbit representatives & periods for this k ----------
    references, period_vectors = select_references(
        translations_array, shifts_per_dir, k_vals
    )
    n_cols = references.shape[0]
    # ---------- 2) PASS 1: surviving entries per column ----------
    nnz_per_col = np.zeros(n_cols, np.int32)
    for col in prange(n_cols):
        ref_row = references[col]
        rows, amps, used_len = _finite_k_TC_orbit_amplitudes(
            translations_array[ref_row],
            C_phase_per_config[ref_row],
            period_vectors[col, :],
            shifts_per_dir,
            k_vals,
        )
        # Only amplitudes that did not cancel get a CSC slot.
        survivors = 0
        for slot in range(used_len):
            if np.abs(amps[slot]) > TOL_ZERO:
                survivors += 1
        nnz_per_col[col] = survivors
    # CSC storage
    L_col_ptr = _prefix_sum_counts(nnz_per_col)
    total_nnz = L_col_ptr[-1]
    L_row_idx = np.empty(total_nnz, np.int32)
    L_data = np.empty(total_nnz, np.complex128)
    # ---------- 3) PASS 2: prune, normalize, sort and write ----------
    for col in prange(n_cols):
        ref_row = references[col]
        rows, amps, used_len = _finite_k_TC_orbit_amplitudes(
            translations_array[ref_row],
            C_phase_per_config[ref_row],
            period_vectors[col, :],
            shifts_per_dir,
            k_vals,
        )
        # Compact the surviving entries to the front and sum |amplitude|^2.
        kept_len = 0
        norm_sq = 0.0
        for slot in range(used_len):
            amplitude = amps[slot]
            if np.abs(amplitude) > TOL_ZERO:
                rows[kept_len] = rows[slot]
                amps[kept_len] = amplitude
                norm_sq += (
                    amplitude.real * amplitude.real + amplitude.imag * amplitude.imag
                )
                kept_len += 1
        # A fully cancelled column writes nothing.
        if norm_sq <= TOL_COLNORM or kept_len == 0:
            continue
        inv_norm = 1.0 / np.sqrt(norm_sq)
        # Canonical CSC: rows ascending inside the column.
        _insertion_sort_by_row(rows, amps, kept_len)
        cursor = L_col_ptr[col]
        for slot in range(kept_len):
            L_row_idx[cursor] = rows[slot]
            L_data[cursor] = amps[slot] * inv_norm
            cursor += 1
    # ---------- 4) CSR view derived from the CSC arrays ----------
    R_row_ptr = np.zeros(n_rows + 1, np.int32)
    # Histogram of entries per row, shifted by one ...
    for entry in range(total_nnz):
        R_row_ptr[L_row_idx[entry] + 1] += 1
    # ... turned into row start offsets by an inclusive running sum.
    for row in range(n_rows):
        R_row_ptr[row + 1] += R_row_ptr[row]
    R_col_idx = np.empty(total_nnz, np.int32)
    R_data = np.empty(total_nnz, np.complex128)
    # Next free position of every row while scattering.
    heads = np.empty(n_rows, np.int32)
    for row in range(n_rows):
        heads[row] = R_row_ptr[row]
    for col in range(n_cols):
        for entry in range(L_col_ptr[col], L_col_ptr[col + 1]):
            row = L_row_idx[entry]
            target = heads[row]
            R_col_idx[target] = col
            R_data[target] = L_data[entry]
            heads[row] += 1
    return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)
# ─────────────────────────────────────────────────────────────────────────────
[docs]
def get_momentum_basis(
    sector_configs: np.ndarray,
    lvals: list[int],
    unit_cell_size: np.ndarray,
    k_vals: np.ndarray,
    TC_symmetry: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build the sparse momentum-basis projector ``B`` for one momentum sector.

    The function only normalizes its array arguments and forwards them to one
    of four kernels, chosen by whether any momentum component is non-zero and
    by the ``TC_symmetry`` flag.

    Parameters
    ----------
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    lvals : list
        Lattice lengths along each spatial direction.
    unit_cell_size : ndarray
        Translation step (logical unit-cell size) per spatial direction.
    k_vals : ndarray
        Momentum quantum numbers (one per spatial direction, or the effective
        translation-combined symmetry momentum in ``TC_symmetry`` mode).
    TC_symmetry : bool, optional
        If ``True``, use the translation-combined (TC) symmetry construction.
        In that mode ``lvals`` and ``unit_cell_size`` are not forwarded.

    Returns
    -------
    tuple
        ``(L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)``:
        the first three arrays are the CSC representation of ``B`` and the
        last three arrays are the CSR representation of the same matrix.
    """
    # The compiled kernels receive contiguous int32 arrays.
    lvals = np.ascontiguousarray(lvals, dtype=np.int32)
    unit_cell_size = np.ascontiguousarray(unit_cell_size, dtype=np.int32)
    k_vals = np.ascontiguousarray(k_vals, dtype=np.int32)
    has_finite_k = np.any(k_vals != 0)
    if TC_symmetry:
        # Special TC symmetry path: the kernels take only the configurations
        # (and the momentum in the finite-k case).
        if has_finite_k:
            return momentum_basis_finite_k_TC(sector_configs, k_vals)
        return momentum_basis_zero_k_TC(sector_configs)
    # Standard translations.
    if has_finite_k:
        return momentum_basis_finite_k(sector_configs, lvals, unit_cell_size, k_vals)
    return momentum_basis_zero_k(sector_configs, lvals, unit_cell_size)
# ─────────────────────────────────────────────────────────────────────────────
@njit(cache=True, parallel=True)
def precompute_nonzero_csr(matrix: np.ndarray):
    """
    List, column by column, the rows where a dense matrix is nonzero.

    An entry counts as nonzero when its absolute value exceeds ``1e-10``.

    Returns (indptr, indices) such that
    ``indices[indptr[p]:indptr[p+1]]`` holds, in increasing order, the rows
    ``c`` with ``matrix[c, p]`` nonzero. Both arrays are int32.
    """
    n_rows, n_cols = matrix.shape
    # First sweep: how many entries of each column pass the threshold.
    per_column = np.zeros(n_cols, np.int32)
    for col in prange(n_cols):
        hits = 0
        for row in range(n_rows):
            if np.abs(matrix[row, col]) > 1e-10:
                hits += 1
        per_column[col] = hits
    # Column start offsets from a running sum of the counts.
    indptr = np.empty(n_cols + 1, np.int32)
    indptr[0] = 0
    for col in range(n_cols):
        indptr[col + 1] = indptr[col] + per_column[col]
    indices = np.empty(indptr[n_cols], np.int32)
    # Second sweep: every column fills its own disjoint slice of `indices`.
    for col in prange(n_cols):
        cursor = indptr[col]
        for row in range(n_rows):
            if np.abs(matrix[row, col]) > 1e-10:
                indices[cursor] = row
                cursor += 1
    return indptr, indices
[docs]
@njit(cache=True, parallel=True)
def nbody_data_momentum(
    op_list: np.ndarray,  # (n_ops, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # (n_ops,), int32
    sector_configs: np.ndarray,  # (n_states, n_sites), int32
    # --- sparse momentum basis B arrays ---
    L_col_ptr: np.ndarray,  # (proj_dim+1,) int32
    L_row_idx: np.ndarray,  # (nnz_B,) int32
    L_data: np.ndarray,  # (nnz_B,) float64/complex128
    R_row_ptr: np.ndarray,  # (n_states+1,) int32
    R_col_idx: np.ndarray,  # (nnz_B,) int32
    R_data: np.ndarray,  # (nnz_B,) float64/complex128
):
    """Project a factorized n-body operator onto the momentum basis.

    For every projected row, the real-space bras in that column of ``B`` are
    visited, all combinations of allowed one-site transitions on the acted
    sites are enumerated, and each resulting ket found in the sector is
    expanded over the momentum columns it belongs to. Duplicate ``(row, col)``
    pairs are emitted as separate triplets.

    Parameters
    ----------
    op_list : ndarray
        Site-resolved factorized operator data of shape
        ``(n_ops, n_sites, d_loc, d_loc)``.
    op_sites_list : ndarray
        Sites where each operator factor acts (shape ``(n_ops,)``).
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of momentum projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of momentum projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B`` (int32, int32, complex128).

    Raises
    ------
    ValueError
        If ``op_sites_list`` is empty.
    """
    n_sites = sector_configs.shape[1]
    n_ops = len(op_sites_list)
    proj_dim = L_col_ptr.shape[0] - 1
    if n_ops < 1:
        raise ValueError(f"nbody operator must act on at least one site, got {n_ops}")
    # Per-factor tables of allowed local transitions, shared by both passes:
    # counts[op, bra_loc], targets[op, bra_loc, t], values[op, bra_loc, t].
    counts_tbl, targets_tbl, values_tbl = _prepare_momentum_local_transition_data(
        op_list, op_sites_list
    )
    # ------------------------------------------------------------------
    # PASS 1: number of triplets produced by each projected row
    # ------------------------------------------------------------------
    row_counts = np.zeros(proj_dim, np.int32)
    for prow in prange(proj_dim):
        tally = 0
        # Column `prow` of B (CSC) lists the real-space bras of this row.
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            bra_locals = np.empty(n_ops, np.int32)
            radix = np.empty(n_ops, np.int32)
            n_combos = 1
            for k in range(n_ops):
                loc = bra_cfg[op_sites_list[k]]
                bra_locals[k] = loc
                radix[k] = counts_tbl[k, loc]
                n_combos *= radix[k]
            # Some factor has no allowed transition from this bra.
            if n_combos == 0:
                continue
            # The ket starts as a copy of the bra; only acted sites change.
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            digits = np.zeros(n_ops, np.int32)
            exhausted = False
            while not exhausted:
                for k in range(n_ops):
                    ket_cfg[op_sites_list[k]] = targets_tbl[
                        k, bra_locals[k], digits[k]
                    ]
                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                # A negative index is treated as "ket not in the sector".
                if ket_idx >= 0:
                    tally += R_row_ptr[ket_idx + 1] - R_row_ptr[ket_idx]
                # Mixed-radix increment, last factor fastest, with carry.
                pos = n_ops - 1
                while pos >= 0:
                    digits[pos] += 1
                    if digits[pos] < radix[pos]:
                        break
                    digits[pos] = 0
                    pos -= 1
                exhausted = pos < 0
        row_counts[prow] = tally
    # Exclusive prefix sum: first triplet slot of every projected row.
    row_start = np.empty(proj_dim, np.int32)
    running = 0
    for prow in range(proj_dim):
        row_start[prow] = running
        running += row_counts[prow]
    total_nnz = running
    # ------------------------------------------------------------------
    # PASS 2: same traversal, now writing the triplets
    # ------------------------------------------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)
    for prow in prange(proj_dim):
        cursor = row_start[prow]
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            # Bra-side amplitude: conjugate of B[bra, prow].
            amp_left = np.conj(L_data[left_ptr])
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            bra_locals = np.empty(n_ops, np.int32)
            radix = np.empty(n_ops, np.int32)
            n_combos = 1
            for k in range(n_ops):
                loc = bra_cfg[op_sites_list[k]]
                bra_locals[k] = loc
                radix[k] = counts_tbl[k, loc]
                n_combos *= radix[k]
            if n_combos == 0:
                continue
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            digits = np.zeros(n_ops, np.int32)
            exhausted = False
            while not exhausted:
                # Product of the selected local matrix elements, and the ket
                # obtained by applying the selected local transitions.
                amp_mid = 1.0 + 0.0j
                for k in range(n_ops):
                    loc = bra_locals[k]
                    t = digits[k]
                    amp_mid *= values_tbl[k, loc, t]
                    ket_cfg[op_sites_list[k]] = targets_tbl[k, loc, t]
                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                if ket_idx >= 0:
                    # Row `ket_idx` of B (CSR) lists the momentum columns the
                    # ket contributes to, with their amplitudes.
                    for right_ptr in range(R_row_ptr[ket_idx], R_row_ptr[ket_idx + 1]):
                        row_list[cursor] = prow
                        col_list[cursor] = R_col_idx[right_ptr]
                        value_list[cursor] = amp_left * amp_mid * R_data[right_ptr]
                        cursor += 1
                pos = n_ops - 1
                while pos >= 0:
                    digits[pos] += 1
                    if digits[pos] < radix[pos]:
                        break
                    digits[pos] = 0
                    pos -= 1
                exhausted = pos < 0
    return row_list, col_list, value_list
[docs]
@njit(cache=True, parallel=True)
def nbody_data_momentum_1site(
    op_list: np.ndarray,  # (1, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # (1,), int32
    sector_configs: np.ndarray,  # (N, n_sites), int32
    # --- sparse momentum basis B arrays ---
    L_col_ptr: np.ndarray,  # (Ldim+1,) int32 -- columns of B
    L_row_idx: np.ndarray,  # (nnz_B,) int32 -- rows for each CSC entry
    L_data: np.ndarray,  # (nnz_B,) float64 or complex128 -- B[row, col]
    R_row_ptr: np.ndarray,  # (N+1,) int32 -- rows of B
    R_col_idx: np.ndarray,  # (nnz_B,) int32 -- cols for each CSR entry
    R_data: np.ndarray,  # (nnz_B,) float64 or complex128 -- B[row, col]
):
    """Project a one-site operator onto the momentum basis.

    Specialization of the generic n-body kernel to a single acted site: for
    each projected row the bras of that column of ``B`` are visited, each
    allowed local transition on the acted site produces a ket, and every ket
    found in the sector is expanded over its momentum columns.

    Parameters
    ----------
    op_list : ndarray
        One-site factorized operator data (shape ``(1, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Site index of the operator action (shape ``(1,)``).
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^† H B`` (int32, int32, complex128).
    """
    n_sites = sector_configs.shape[1]
    proj_dim = L_col_ptr.size - 1  # dimension of the projected space
    acted_site = op_sites_list[0]
    # Allowed local transitions: counts[0, bra_loc], targets[0, bra_loc, t],
    # values[0, bra_loc, t].
    counts_tbl, targets_tbl, values_tbl = _prepare_momentum_local_transition_data(
        op_list, op_sites_list
    )
    # ------------------------------------------------------------------
    # PASS 1: number of triplets produced by each projected row
    # ------------------------------------------------------------------
    row_counts = np.zeros(proj_dim, np.int32)
    for prow in prange(proj_dim):
        tally = 0
        # Real-space rows with a stored entry in column `prow` of B (CSC).
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            bra_loc = bra_cfg[acted_site]
            # Scratch ket: a copy of the bra with one site to be edited.
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            for t in range(counts_tbl[0, bra_loc]):
                ket_cfg[acted_site] = targets_tbl[0, bra_loc, t]
                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                # A negative index is treated as "ket not in the sector".
                if ket_idx >= 0:
                    # One triplet per stored entry in row `ket_idx` of B (CSR).
                    tally += R_row_ptr[ket_idx + 1] - R_row_ptr[ket_idx]
        row_counts[prow] = tally
    # Exclusive prefix sum: first triplet slot of every projected row.
    row_start = np.empty(proj_dim, np.int32)
    running = 0
    for prow in range(proj_dim):
        row_start[prow] = running
        running += row_counts[prow]
    total_nnz = running
    # ------------------------------------------------------------------
    # PASS 2: same traversal, now writing the triplets
    # ------------------------------------------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)
    for prow in prange(proj_dim):
        cursor = row_start[prow]
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            # Bra-side amplitude: conjugate of B[bra, prow].
            amp_left = np.conj(L_data[left_ptr])
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            bra_loc = bra_cfg[acted_site]
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            for t in range(counts_tbl[0, bra_loc]):
                local_value = values_tbl[0, bra_loc, t]
                ket_cfg[acted_site] = targets_tbl[0, bra_loc, t]
                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                if ket_idx >= 0:
                    # Expand the ket over its momentum columns (CSR row).
                    for right_ptr in range(R_row_ptr[ket_idx], R_row_ptr[ket_idx + 1]):
                        row_list[cursor] = prow
                        col_list[cursor] = R_col_idx[right_ptr]
                        value_list[cursor] = amp_left * local_value * R_data[right_ptr]
                        cursor += 1
    return row_list, col_list, value_list
[docs]
@njit(cache=True, parallel=True)
def nbody_data_momentum_2sites(
    op_list: np.ndarray,  # shape (2, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # shape (2,), int32
    sector_configs: np.ndarray,  # shape (N, n_sites), int32
    # ---- momentum basis B in sparse form ----
    L_col_ptr: np.ndarray,  # (Ldim+1,), int32 -- columns of B
    L_row_idx: np.ndarray,  # (nnz_B,), int32 -- real-space rows j with B[j, prow] != 0
    L_data: np.ndarray,  # (nnz_B,), complex128/float64 -- B[j, prow]
    R_row_ptr: np.ndarray,  # (N+1,), int32 -- rows of B
    R_col_idx: np.ndarray,  # (nnz_B,), int32 -- projected cols pcol with B[j, pcol] != 0
    R_data: np.ndarray,  # (nnz_B,), complex128/float64 -- B[j, pcol]
):
    """Project a two-site operator onto the momentum basis.

    Specialization of the generic n-body kernel to two acted sites, with the
    enumeration of local transitions written as two explicit nested loops
    (first factor outer, second factor inner).

    Parameters
    ----------
    op_list : ndarray
        Two-site factorized operator data (shape ``(2, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Two site indices where the operator acts.
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B`` (int32, int32, complex128).
    """
    n_sites = sector_configs.shape[1]
    proj_dim = L_col_ptr.shape[0] - 1
    counts_tbl, targets_tbl, values_tbl = _prepare_momentum_local_transition_data(
        op_list, op_sites_list
    )
    # The two acted sites are the same for every bra.
    site_a = op_sites_list[0]
    site_b = op_sites_list[1]
    # ------------------------------------------------------------------
    # PASS 1: number of triplets produced by each projected row
    # ------------------------------------------------------------------
    row_counts = np.zeros(proj_dim, np.int32)
    for prow in prange(proj_dim):
        tally = 0
        # Real-space rows with a stored entry in column `prow` of B (CSC).
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            loc_a = bra_cfg[site_a]
            loc_b = bra_cfg[site_b]
            n_trans_a = counts_tbl[0, loc_a]
            n_trans_b = counts_tbl[1, loc_b]
            # Scratch ket, initialised to the bra for every bra.
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            for ta in range(n_trans_a):
                ket_cfg[site_a] = targets_tbl[0, loc_a, ta]
                for tb in range(n_trans_b):
                    ket_cfg[site_b] = targets_tbl[1, loc_b, tb]
                    ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                    # A negative index is treated as "ket not in the sector".
                    if ket_idx >= 0:
                        # One triplet per stored entry in that CSR row of B.
                        tally += R_row_ptr[ket_idx + 1] - R_row_ptr[ket_idx]
        row_counts[prow] = tally
    # Exclusive prefix sum: first triplet slot of every projected row.
    row_start = np.empty(proj_dim, np.int32)
    running = 0
    for prow in range(proj_dim):
        row_start[prow] = running
        running += row_counts[prow]
    total_nnz = running
    # ------------------------------------------------------------------
    # PASS 2: same traversal, now writing the triplets
    # ------------------------------------------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)
    for prow in prange(proj_dim):
        cursor = row_start[prow]
        for left_ptr in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            # Bra-side amplitude: conj(B[bra, prow]).
            amp_left = np.conj(L_data[left_ptr])
            loc_a = bra_cfg[site_a]
            loc_b = bra_cfg[site_b]
            n_trans_a = counts_tbl[0, loc_a]
            n_trans_b = counts_tbl[1, loc_b]
            ket_cfg = np.empty(n_sites, np.int32)
            for s in range(n_sites):
                ket_cfg[s] = bra_cfg[s]
            for ta in range(n_trans_a):
                value_a = values_tbl[0, loc_a, ta]
                ket_cfg[site_a] = targets_tbl[0, loc_a, ta]
                for tb in range(n_trans_b):
                    value_b = values_tbl[1, loc_b, tb]
                    ket_cfg[site_b] = targets_tbl[1, loc_b, tb]
                    ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                    if ket_idx >= 0:
                        # Operator amplitude of this pair of local transitions.
                        amp_mid = value_a * value_b
                        # Expand the ket over its momentum columns (CSR row).
                        for right_ptr in range(
                            R_row_ptr[ket_idx], R_row_ptr[ket_idx + 1]
                        ):
                            row_list[cursor] = prow
                            col_list[cursor] = R_col_idx[right_ptr]
                            value_list[cursor] = amp_left * amp_mid * R_data[right_ptr]
                            cursor += 1
    return row_list, col_list, value_list
[docs]
@njit(cache=True, parallel=True)
def nbody_data_momentum_4sites(
    op_list: np.ndarray,  # shape (4, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # shape (4,), int32
    sector_configs: np.ndarray,  # shape (N, n_sites), int32
    # ---- momentum basis B in sparse form ----
    L_col_ptr: np.ndarray,  # (Ldim+1,), int32
    L_row_idx: np.ndarray,  # (nnz_B,), int32
    L_data: np.ndarray,  # (nnz_B,), complex128/float64
    R_row_ptr: np.ndarray,  # (N+1,), int32
    R_col_idx: np.ndarray,  # (nnz_B,), int32
    R_data: np.ndarray,  # (nnz_B,), complex128/float64
):
    """Build sparse triplets for a four-site operator in the momentum basis.

    The kernel runs in two passes over the projected rows. Pass 1 counts how
    many triplets each projected row produces; an exclusive prefix sum turns
    the counts into per-row write offsets. Pass 2 repeats the same traversal
    and writes the triplets, so that each parallel iteration fills its own
    disjoint slice of the output arrays.

    Parameters
    ----------
    op_list : ndarray
        Four-site factorized operator data (shape ``(4, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Four site indices where the operator acts.
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B``. ``row_list`` and ``col_list`` are ``int32`` and
        ``value_list`` is ``complex128``. Every contribution is written as a
        separate triplet; repeated ``(row, col)`` pairs are not merged here.

    Notes
    -----
    The per-row offsets are stored as ``int64``. With ``int32`` offsets the
    prefix sum would wrap around once the total number of triplets exceeds
    ``2**31 - 1`` and pass 2 would silently write to wrong positions.
    """
    _, n_sites = sector_configs.shape
    Ldim = L_col_ptr.shape[0] - 1
    transition_counts_all, ket_local_states_all, transition_values_all = (
        _prepare_momentum_local_transition_data(op_list, op_sites_list)
    )
    # The four sites the operator acts on are loop invariants.
    site0 = op_sites_list[0]
    site1 = op_sites_list[1]
    site2 = op_sites_list[2]
    site3 = op_sites_list[3]
    # -------------------------------
    # PASS 1: count nonzeros per projected row
    # -------------------------------
    # int64 on purpose: after the prefix sum this array holds global write
    # offsets, which can exceed the int32 range for large sectors.
    row_offsets = np.zeros(Ldim, np.int64)
    for prow in prange(Ldim):
        cnt = 0
        # scratch configuration, private to this parallel iteration
        ket_cfg = np.empty(n_sites, np.int32)
        # all real-space rows j1 with B[j1, prow] != 0 (CSC of B)
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            j1 = L_row_idx[idx1]
            bra_cfg = sector_configs[j1]
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            bra_loc2 = bra_cfg[site2]
            bra_loc3 = bra_cfg[site3]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            len2 = transition_counts_all[2, bra_loc2]
            len3 = transition_counts_all[3, bra_loc3]
            # start from the bra configuration; only the four operator
            # sites are overwritten below
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            # explicit 4-nested loops -> real-space j2
            for i0 in range(len0):
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    for i2 in range(len2):
                        ket_cfg[site2] = ket_local_states_all[2, bra_loc2, i2]
                        for i3 in range(len3):
                            ket_cfg[site3] = ket_local_states_all[3, bra_loc3, i3]
                            j2 = config_to_index_binarysearch(
                                ket_cfg, sector_configs
                            )
                            # a negative index is treated as "ket not in
                            # the sector" and contributes nothing
                            if j2 < 0:
                                continue
                            # number of projected columns from j2 (CSR of B)
                            cnt += R_row_ptr[j2 + 1] - R_row_ptr[j2]
        row_offsets[prow] = cnt
    # exclusive prefix sum: counts -> row offsets
    offset = 0
    for prow in range(Ldim):
        row_count = row_offsets[prow]
        row_offsets[prow] = offset
        offset += row_count
    total_nnz = offset
    # -------------------------------
    # PASS 2: fill triplets
    # -------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)
    for prow in prange(Ldim):
        # first output slot owned by this projected row
        ptr = row_offsets[prow]
        ket_cfg = np.empty(n_sites, np.int32)
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            j1 = L_row_idx[idx1]
            bra_cfg = sector_configs[j1]
            amp_L = np.conj(L_data[idx1])  # conj(B[j1, prow])
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            bra_loc2 = bra_cfg[site2]
            bra_loc3 = bra_cfg[site3]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            len2 = transition_counts_all[2, bra_loc2]
            len3 = transition_counts_all[3, bra_loc3]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            # explicit 4-loops + projection; the traversal must match pass 1
            # exactly so the counted and the written triplets agree
            for i0 in range(len0):
                v0 = transition_values_all[0, bra_loc0, i0]
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    v1 = transition_values_all[1, bra_loc1, i1]
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    for i2 in range(len2):
                        v2 = transition_values_all[2, bra_loc2, i2]
                        ket_cfg[site2] = ket_local_states_all[2, bra_loc2, i2]
                        for i3 in range(len3):
                            v3 = transition_values_all[3, bra_loc3, i3]
                            ket_cfg[site3] = ket_local_states_all[3, bra_loc3, i3]
                            j2 = config_to_index_binarysearch(
                                ket_cfg, sector_configs
                            )
                            if j2 < 0:
                                continue
                            # product of the four local matrix elements
                            amp_M = v0 * v1 * v2 * v3
                            # project into momentum columns (CSR of B)
                            for tt in range(R_row_ptr[j2], R_row_ptr[j2 + 1]):
                                amp_R = R_data[tt]  # B[j2, pcol]
                                row_list[ptr] = prow
                                col_list[ptr] = R_col_idx[tt]
                                value_list[ptr] = amp_L * amp_M * amp_R
                                ptr += 1
    return row_list, col_list, value_list