Source code for edlgt.symmetries.translational_sym

"""Translation symmetry and momentum-basis construction utilities.

This module builds translation orbits and sparse momentum-basis projectors for
symmetry-sector configuration tables, and provides momentum-projected sparse
operator kernels for generic factorized n-body operators.
"""

import numpy as np
from numba import njit, prange

from edlgt.tools.config_encoding import config_to_index_binarysearch

# Entries of local operator matrices with magnitude <= OP_EPS are treated as
# structural zeros by _prepare_momentum_local_transition_data.
OP_EPS = 1e-12  # small cutoff for local operator entries

# Public API of this module for ``from ... import *``.
# NOTE(review): several names listed here (get_momentum_basis and the
# nbody_data_momentum* kernels) are not defined in the visible part of this
# module — presumably they are defined further down; confirm.
__all__ = [
    "check_normalization",
    "check_orthogonality",
    "get_translated_state_indices",
    "get_reference_indices",
    "get_momentum_basis",
    "nbody_data_momentum",
    "nbody_data_momentum_4sites",
    "nbody_data_momentum_2sites",
    "nbody_data_momentum_1site",
]


@njit(cache=True)
def get_translated_state_indices(config, sector_configs, logical_unit_size=1):
    """Generate all 1D translations of a configuration in a sorted sector basis.

    Parameters
    ----------
    config : ndarray
        One configuration, one local-state label per site.
    sector_configs : ndarray
        Sector configuration table searched with
        ``config_to_index_binarysearch``; ``sector_configs.shape[1]`` must
        equal ``len(config)``.
    logical_unit_size : int, optional
        Number of sites moved by one translation step (default ``1``).

    Returns
    -------
    ndarray of int32
        Entry ``t`` holds the lookup result for ``config`` rolled left by
        ``t * logical_unit_size`` sites, for ``t`` in
        ``range(len(config) // logical_unit_size)``.

    Raises
    ------
    ValueError
        If ``len(config)`` is not a multiple of ``logical_unit_size`` or does
        not match ``sector_configs.shape[1]``.
    """
    n_sites = len(config)
    if n_sites % logical_unit_size != 0:
        raise ValueError("Number of sites is not a multiple of the logical unit size.")
    if n_sites != sector_configs.shape[1]:
        raise ValueError(
            f"config.shape[0]={n_sites} must be equal to "
            f"sector_configs.shape[1]={sector_configs.shape[1]}"
        )
    n_steps = n_sites // logical_unit_size
    images = np.zeros(n_steps, dtype=np.int32)
    step = 0
    while step < n_steps:
        # Left roll by whole logical units: site j receives site j + shift.
        shifted = np.roll(config, -(step * logical_unit_size))
        images[step] = config_to_index_binarysearch(shifted, sector_configs)
        step += 1
    return images
@njit(cache=True)
def get_reference_indices(sector_configs, logical_unit_size=1):
    """Select translation-inequivalent reference configurations in 1D.

    The lowest-index member of every translation orbit is chosen as its
    reference. Once a reference is chosen, every configuration in its orbit is
    marked, so each orbit's translations are computed only once. Previously,
    each configuration was compared against all earlier references, which is
    O(N^2).

    Parameters
    ----------
    sector_configs : ndarray
        (n_configs, n_sites) sector configuration table, in the order expected
        by ``config_to_index_binarysearch``.
    logical_unit_size : int, optional
        Number of sites moved by one translation step (default ``1``). It is
        forwarded to :func:`get_translated_state_indices`.

    Returns
    -------
    ref_indices : ndarray
        Indices of the reference configurations, in increasing order.
    norm : ndarray of int32
        For each reference, the number of distinct entries returned by
        :func:`get_translated_state_indices` (the orbit size).

    Notes
    -----
    Orbit membership is symmetric (``j`` is a translate of ``i`` iff ``i`` is
    a translate of ``j``), so marking orbits selects exactly the same
    references as comparing against every earlier reference.
    NOTE(review): lookup results outside ``[0, n_configs)`` (presumably
    "not found" markers from ``config_to_index_binarysearch``) are not used
    for marking; confirm the helper's miss convention.
    """
    sector_dim = sector_configs.shape[0]
    normalization = np.zeros(sector_dim, dtype=np.int32)
    independent_indices = np.zeros(sector_dim, dtype=np.bool_)
    # assigned[j] is True once j belongs to an orbit that already has a reference.
    assigned = np.zeros(sector_dim, dtype=np.bool_)
    for cfg_idx in range(sector_dim):
        if assigned[cfg_idx]:
            continue
        trans_indices = get_translated_state_indices(
            sector_configs[cfg_idx], sector_configs, logical_unit_size
        )
        independent_indices[cfg_idx] = True
        normalization[cfg_idx] = len(np.unique(trans_indices))
        # Mark the whole orbit so later members are skipped in O(1).
        for image_idx in trans_indices:
            if 0 <= image_idx < sector_dim:
                assigned[image_idx] = True
    ref_indices = np.flatnonzero(independent_indices)
    norm = normalization[ref_indices]
    return ref_indices, norm
@njit(cache=True)
def _prepare_momentum_local_transition_data(
    op_list: np.ndarray, op_sites_list: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Precompute nonzero local transitions for the momentum-space kernels.

    Parameters
    ----------
    op_list : ndarray
        Site-resolved operator matrices, indexed as
        ``op_list[op_idx, site_idx, bra_loc, ket_loc]``.
    op_sites_list : ndarray
        Site index on which each operator factor acts.

    Returns
    -------
    tuple
        ``(transition_counts, ket_local_states, transition_values)``.
        ``transition_counts[op_idx, bra_loc]`` stores how many outgoing local
        transitions are available for operator factor ``op_idx`` from local
        bra state ``bra_loc``. The first ``transition_counts[op_idx, bra_loc]``
        entries of ``ket_local_states[op_idx, bra_loc]`` and
        ``transition_values[op_idx, bra_loc]`` hold the target ket labels and
        the matrix elements. Entries beyond that count are uninitialized.

    Notes
    -----
    ``local_dim`` is the common padded local dimension stored by the
    rectangular operator tensor ``op_list``. Sites with smaller physical
    local Hilbert spaces are represented by matrices whose unused rows and
    columns are zero padded. Elements with ``abs(elem) <= OP_EPS`` are
    dropped.
    """
    n_ops = len(op_sites_list)
    local_dim = np.int32(op_list.shape[2])
    # For each operator factor and each local bra state, store the list of
    # allowed local ket states together with the corresponding matrix elements.
    transition_counts = np.zeros((n_ops, local_dim), dtype=np.int32)
    ket_local_states = np.empty((n_ops, local_dim, local_dim), dtype=np.int32)
    transition_values = np.empty((n_ops, local_dim, local_dim), dtype=np.complex128)
    for op_idx in range(n_ops):
        site_idx = op_sites_list[op_idx]
        site_op = op_list[op_idx, site_idx]
        for bra_loc in range(local_dim):
            n_transitions = 0
            for ket_loc in range(local_dim):
                elem = site_op[bra_loc, ket_loc]
                if np.abs(elem) > OP_EPS:
                    # Compactify the nonzero row entries so later kernels can
                    # iterate only over the allowed local transitions.
                    ket_local_states[op_idx, bra_loc, n_transitions] = ket_loc
                    transition_values[op_idx, bra_loc, n_transitions] = elem
                    n_transitions += 1
            transition_counts[op_idx, bra_loc] = n_transitions
    return transition_counts, ket_local_states, transition_values


@njit(cache=True)
def _prefix_sum_counts(nnz_per_col: np.ndarray) -> np.ndarray:
    """
    Turn column nnz counts into a CSC col_ptr with a prefix sum.

    col_ptr has length (n_cols + 1), col_ptr[0] = 0 and
    col_ptr[-1] = total_nnz. It is stored as int32.
    NOTE(review): totals above the int32 range are not representable.
    """
    n_cols = nnz_per_col.size
    col_ptr = np.empty(n_cols + 1, np.int32)
    running_total = 0
    col_ptr[0] = 0
    for col_idx in range(n_cols):
        running_total += nnz_per_col[col_idx]
        col_ptr[col_idx + 1] = running_total
    return col_ptr


@njit(inline="always")
def _insertion_sort_by_row(rows: np.ndarray, vals: np.ndarray, length: int) -> None:
    """
    Stable, in-place insertion sort of the first ``length`` (row, value) pairs
    by row.

    It keeps the rows of each CSC column in ascending order. The cost is
    quadratic in ``length``, which is intended for short per-column lists.
    """
    for entry_idx in range(1, length):
        row_key = rows[entry_idx]
        value_key = vals[entry_idx]
        insert_idx = entry_idx - 1
        # Strict ">" keeps equal keys in their original order (stability).
        while insert_idx >= 0 and rows[insert_idx] > row_key:
            rows[insert_idx + 1] = rows[insert_idx]
            vals[insert_idx + 1] = vals[insert_idx]
            insert_idx -= 1
        rows[insert_idx + 1] = row_key
        vals[insert_idx + 1] = value_key


@njit(cache=True, parallel=True)
def precompute_C_sign_per_config(
    sector_configs: np.ndarray,  # (N, L) int32
    C_label_sign: np.ndarray,  # (d_loc,) float64, entries ∈ {+1.0, -1.0}
) -> np.ndarray:
    """
    Return S[i] = ∏_j C_label_sign[ sector_configs[i, j] ] as float64.

    With ±1 inputs, every S[i] is ±1. The result is used by both the k=0 and
    the finite-k paths.
    """
    n_configs, n_sites = sector_configs.shape
    out = np.ones(n_configs, np.float64)
    for cfg_idx in prange(n_configs):
        parity_sign = 1.0
        for site_idx in range(n_sites):
            parity_sign *= C_label_sign[sector_configs[cfg_idx, site_idx]]
        out[cfg_idx] = parity_sign
    return out


# ---------- linear index <-> coords (ROW-MAJOR)----------
@njit(inline="always")
def linear_to_coords_rowmajor(
    site_index: int, lvals: np.ndarray, out_coords: np.ndarray
) -> None:
    """
    Convert a linear site index [0..prod(L)-1] into D coords in row-major order.

    The last axis varies fastest. out_coords is preallocated (length D) and is
    overwritten.
    """
    lattice_dim = lvals.size
    for ax in range(lattice_dim - 1, -1, -1):
        axis_length = lvals[ax]
        out_coords[ax] = site_index % axis_length
        site_index //= axis_length


@njit(inline="always")
def coords_to_linear_rowmajor(coords: np.ndarray, lvals: np.ndarray) -> int:
    """
    Convert D coords into a linear site index in row-major order.

    This is the inverse of :func:`linear_to_coords_rowmajor`.
    """
    lattice_dim = lvals.size
    idx = 0
    for ax in range(lattice_dim):
        idx = idx * lvals[ax] + coords[ax]
    return idx


# ---------- mixed-radix encode/decode of per-axis block-shifts ----------
@njit(inline="always")
def encode_shift(axis_shifts: np.ndarray, shifts_per_dir: np.ndarray) -> int:
    """
    Map a D-vector of per-axis block shifts t=(t0,...,t_{D-1}) with ranges
    [0..R_d-1] to a single flat index in [0..prod(R)-1], using row-major order.

    Each t_d is reduced modulo R_d first, so out-of-range or negative shifts
    wrap around.
    """
    flat = 0
    lattice_dim = shifts_per_dir.size
    for ax in range(lattice_dim):
        flat = flat * shifts_per_dir[ax] + (axis_shifts[ax] % shifts_per_dir[ax])
    return flat


@njit(inline="always")
def decode_shift(
    flat_index: int, shifts_per_dir: np.ndarray, out_axis_shifts: np.ndarray
) -> None:
    """
    Inverse of encode_shift: map a flat index to the D-vector of per-axis
    shifts, written into the preallocated ``out_axis_shifts``.
    """
    lattice_dim = shifts_per_dir.size
    for ax in range(lattice_dim - 1, -1, -1):
        out_axis_shifts[ax] = flat_index % shifts_per_dir[ax]
        flat_index //= shifts_per_dir[ax]


@njit(inline="always")
def decode_mixed_index(idx: int, bases: np.ndarray, out: np.ndarray) -> None:
    """
    Mixed-radix decode matching encode_shift's row-major convention: axis 0
    is most significant, axis D-1 least significant.

    The arithmetic is the same as :func:`decode_shift`, with arbitrary
    ``bases``.
    """
    lattice_dim = bases.size
    for ax in range(lattice_dim - 1, -1, -1):
        out[ax] = idx % bases[ax]
        idx //= bases[ax]


@njit(cache=True)
def _compute_rowmajor_strides(bases: np.ndarray) -> np.ndarray:
    """Return row-major strides matching :func:`encode_shift`.

    ``strides[ax]`` is the product of all bases to the right of ``ax``, so
    ``strides[-1] == 1``.
    """
    lattice_dim = bases.size
    strides = np.empty(lattice_dim, np.int32)
    stride = 1
    for ax in range(lattice_dim - 1, -1, -1):
        # In row-major mixed-radix order, axis ``ax`` advances by the product
        # of all bases to its right.
        strides[ax] = stride
        stride *= bases[ax]
    return strides


@njit(inline="always")
def _next_power_of_two_at_least(min_size: int) -> int:
    """Return the smallest power of two greater than or equal to ``min_size``.

    Returns 1 for ``min_size <= 1``.
    """
    size = 1
    while size < min_size:
        size <<= 1
    return size


@njit(inline="always")
def _advance_period_box_counter(
    t_local: np.ndarray, pvec: np.ndarray, shift_strides: np.ndarray, full_shift: int
) -> tuple[int, bool]:
    """Advance the local mixed-radix counter and update the full shift index.

    ``t_local`` enumerates the period box with per-axis bases ``pvec``.
    ``full_shift`` stores the corresponding flat shift index in the full
    translation table, using the row-major strides of ``shifts_per_dir``.

    Returns
    -------
    tuple
        ``(full_shift, finished)``. ``finished`` is True once every digit has
        wrapped, i.e. the whole box has been visited. At that point
        ``t_local`` is all zeros and ``full_shift`` is back at its starting
        value.
    """
    lattice_dim = pvec.size
    for ax in range(lattice_dim - 1, -1, -1):
        # Try to advance the least-significant still-active axis first.
        t_local[ax] += 1
        full_shift += shift_strides[ax]
        if t_local[ax] < pvec[ax]:
            return full_shift, False
        # This digit overflowed: reset it and remove the completed block.
        full_shift -= pvec[ax] * shift_strides[ax]
        t_local[ax] = 0
    return full_shift, True


@njit(inline="always")
def _insert_or_accumulate_real(
    cfg_row: int,
    incr: float,
    used_indices: np.ndarray,
    used_values: np.ndarray,
    hash_keys: np.ndarray,
    hash_positions: np.ndarray,
    hash_mask: int,
    used_len: int,
) -> int:
    """Insert or accumulate one real-valued orbit contribution.

    The hash table uses open addressing with linear probing. ``-1`` marks an
    empty slot, so ``cfg_row`` is expected to be non-negative. The callers
    size the table to at least twice the number of insertions, so a free
    slot always exists. Returns the updated ``used_len``.
    """
    slot = cfg_row & hash_mask
    while True:
        key = hash_keys[slot]
        if key == -1:
            # First time this translated row appears in the orbit column.
            hash_keys[slot] = cfg_row
            hash_positions[slot] = used_len
            used_indices[used_len] = cfg_row
            used_values[used_len] = incr
            return used_len + 1
        if key == cfg_row:
            # Repeated image of the same basis row: accumulate its weight.
            used_values[hash_positions[slot]] += incr
            return used_len
        slot = (slot + 1) & hash_mask


@njit(inline="always")
def _insert_or_accumulate_complex(
    cfg_row: int,
    incr: complex,
    used_indices: np.ndarray,
    used_values: np.ndarray,
    hash_keys: np.ndarray,
    hash_positions: np.ndarray,
    hash_mask: int,
    used_len: int,
) -> int:
    """Insert or accumulate one complex-valued orbit contribution.

    This is the complex-valued twin of :func:`_insert_or_accumulate_real`,
    with the same hashing contract. Returns the updated ``used_len``.
    """
    slot = cfg_row & hash_mask
    while True:
        key = hash_keys[slot]
        if key == -1:
            # First time this translated row appears in the orbit column.
            hash_keys[slot] = cfg_row
            hash_positions[slot] = used_len
            used_indices[used_len] = cfg_row
            used_values[used_len] = incr
            return used_len + 1
        if key == cfg_row:
            # Repeated image of the same basis row: accumulate its Bloch phase.
            used_values[hash_positions[slot]] += incr
            return used_len
        slot = (slot + 1) & hash_mask


@njit(cache=True)
def _accumulate_zero_k_orbit(
    translations_row: np.ndarray, pvec: np.ndarray, shift_strides: np.ndarray
) -> tuple[np.ndarray, np.ndarray, int]:
    """Accumulate one zero-momentum orbit column with hashed deduplication.

    Every point of the period box ``0 <= t_d < pvec[d]`` contributes weight
    ``1.0`` to the row ``translations_row[flat(t)]``. Images that coincide
    are summed.

    Returns
    -------
    tuple
        ``(used_indices, used_values, used_len)``. Only the first
        ``used_len`` entries are meaningful; they appear in first-visit order.
    """
    lattice_dim = pvec.size
    orbit_size = 1
    for ax in range(lattice_dim):
        orbit_size *= pvec[ax]
    # Over-allocate one slot per point in the period box. After deduplication,
    # only the first ``used_len`` entries are meaningful.
    used_indices = np.empty(orbit_size, np.int32)
    used_values = np.zeros(orbit_size, np.float64)
    # The hash table maps translated basis rows -> position inside the compact
    # ``used_*`` arrays. A power-of-two size keeps linear probing simple.
    hash_size = _next_power_of_two_at_least(2 * orbit_size)
    hash_keys = np.full(hash_size, -1, np.int32)
    hash_positions = np.empty(hash_size, np.int32)
    hash_mask = hash_size - 1
    # ``t_local`` walks the local period box, while ``full_shift`` tracks the
    # corresponding flat index inside the full translation table row.
    t_local = np.zeros(lattice_dim, np.int32)
    full_shift = 0
    used_len = 0
    finished = False
    while not finished:
        cfg_row = translations_row[full_shift]
        used_len = _insert_or_accumulate_real(
            cfg_row,
            1.0,
            used_indices,
            used_values,
            hash_keys,
            hash_positions,
            hash_mask,
            used_len,
        )
        full_shift, finished = _advance_period_box_counter(
            t_local, pvec, shift_strides, full_shift
        )
    return used_indices, used_values, used_len


@njit(cache=True)
def _prepare_phase_table(k_vals: np.ndarray, shifts_per_dir: np.ndarray) -> np.ndarray:
    """Precompute per-axis Bloch phases ``exp(-2πi k_d t / R_d)``.

    Parameters
    ----------
    k_vals : ndarray
        (D,) integer momentum labels; they are reduced modulo ``R_d``.
    shifts_per_dir : ndarray
        (D,) number of block translations ``R_d`` along each axis.

    Returns
    -------
    ndarray of complex128, shape (D, max_d R_d)
        ``phase_table[ax, t]`` for ``0 <= t < R_ax``. Entries with
        ``t >= R_ax`` are left uninitialized and must not be read.

    Notes
    -----
    Each phase is evaluated directly from the reduced integer exponent
    ``(k_d * t) mod R_d``. Building the row by repeated multiplication would
    accumulate rounding error linearly in ``t``. That error matters because
    the finite-k kernels prune amplitudes below an absolute tolerance after
    Bloch-phase cancellations. The table is only D x max(R_d), so evaluating
    ``exp`` per entry costs nothing measurable.
    """
    lattice_dim = shifts_per_dir.size
    max_shift = 0
    for ax in range(lattice_dim):
        if shifts_per_dir[ax] > max_shift:
            max_shift = shifts_per_dir[ax]
    phase_table = np.empty((lattice_dim, max_shift), np.complex128)
    for ax in range(lattice_dim):
        n_shifts = shifts_per_dir[ax]
        if n_shifts <= 0:
            continue
        kd = k_vals[ax] % n_shifts
        for shift_idx in range(n_shifts):
            # Reduce the exponent exactly in integers; shift 0 gives unit phase.
            reduced = (kd * shift_idx) % n_shifts
            phase_table[ax, shift_idx] = np.exp(
                -1j * 2.0 * np.pi * reduced / float(n_shifts)
            )
    return phase_table


@njit(cache=True)
def _accumulate_finite_k_orbit(
    translations_row: np.ndarray,
    pvec: np.ndarray,
    shift_strides: np.ndarray,
    phase_table: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, int]:
    """Accumulate one finite-momentum orbit column with hashed deduplication.

    Every point ``t`` of the period box contributes the Bloch phase
    ``∏_d phase_table[d, t_d]`` to the row ``translations_row[flat(t)]``.
    Images that coincide are summed, so amplitudes may cancel.

    Returns
    -------
    tuple
        ``(used_indices, used_values, used_len)``. Only the first
        ``used_len`` entries are meaningful. Cancelled amplitudes are still
        present and have to be pruned by the caller.
    """
    lattice_dim = pvec.size
    orbit_size = 1
    for ax in range(lattice_dim):
        orbit_size *= pvec[ax]
    # Over-allocate one slot per point in the period box. After deduplication,
    # only the first ``used_len`` entries are meaningful.
    used_indices = np.empty(orbit_size, np.int32)
    used_values = np.zeros(orbit_size, np.complex128)
    hash_size = _next_power_of_two_at_least(2 * orbit_size)
    hash_keys = np.full(hash_size, -1, np.int32)
    hash_positions = np.empty(hash_size, np.int32)
    hash_mask = hash_size - 1
    # ``t_local`` walks the local period box, while ``full_shift`` tracks the
    # corresponding flat index inside the full translation table row.
    t_local = np.zeros(lattice_dim, np.int32)
    full_shift = 0
    used_len = 0
    finished = False
    while not finished:
        phase = 1.0 + 0.0j
        for ax in range(lattice_dim):
            # The total Bloch phase factorizes axis by axis.
            phase *= phase_table[ax, t_local[ax]]
        cfg_row = translations_row[full_shift]
        used_len = _insert_or_accumulate_complex(
            cfg_row,
            phase,
            used_indices,
            used_values,
            hash_keys,
            hash_positions,
            hash_mask,
            used_len,
        )
        full_shift, finished = _advance_period_box_counter(
            t_local, pvec, shift_strides, full_shift
        )
    return used_indices, used_values, used_len
@njit(cache=True)
def check_normalization(basis: np.ndarray) -> bool:
    """Report whether every column of ``basis`` has unit Euclidean norm.

    Parameters
    ----------
    basis : ndarray
        Matrix whose columns are the basis vectors.

    Returns
    -------
    bool
        ``False`` as soon as a column norm is not ``np.isclose`` to 1,
        otherwise ``True``. A matrix with no columns yields ``True``.
    """
    n_columns = basis.shape[1]
    col = 0
    while col < n_columns:
        column_norm = np.linalg.norm(basis[:, col])
        if not np.isclose(column_norm, 1):
            return False
        col += 1
    return True
@njit(cache=True)
def check_orthogonality(basis: np.ndarray) -> bool:
    """Report whether all distinct columns of ``basis`` are mutually orthogonal.

    Parameters
    ----------
    basis : ndarray
        Matrix whose columns are the basis vectors.

    Returns
    -------
    bool
        ``False`` as soon as the conjugated inner product (``np.vdot``) of
        two distinct columns is not within ``atol=1e-10`` of zero, otherwise
        ``True``.
    """
    n_columns = basis.shape[1]
    for left_idx in range(n_columns - 1):
        left = basis[:, left_idx]
        for right_idx in range(left_idx + 1, n_columns):
            overlap = np.vdot(left, basis[:, right_idx])
            if not np.isclose(overlap, 0, atol=1e-10):
                return False
    return True
@njit(cache=True, parallel=True) def build_TC_translations( sector_configs: np.ndarray, # (N, L) dressed local-state ids C_map: np.ndarray, # (d_loc,) local map for even sites ): """ Build the orbit table of the combined generator X = T ∘ C on a 1D ring. IMPORTANT: At each step apply C (sitewise, even/odd possibly different), then translate by +1. Repeat this t times to get X^t. Returns: Ttab : (N, L) int32, Ttab[i, t] = index of (X^t)|config_i> shifts_per_dir : (1,) int32, with shifts_per_dir[0] = L """ n_configs, n_sites = sector_configs.shape num_tc_steps = n_sites translations_table = np.empty((n_configs, num_tc_steps), np.int32) for cfg_idx in prange(n_configs): base = sector_configs[cfg_idx] # current config after t applications (start at t=0) current_config = np.empty(n_sites, np.int32) for site_idx in range(n_sites): current_config[site_idx] = base[site_idx] translations_table[cfg_idx, 0] = cfg_idx # X^0 = identity mapped_config = np.empty(n_sites, np.int32) # scratch for C action for step_idx in range(1, num_tc_steps): # 1) apply C at this step in the LAB frame (before translating) for site_idx in range(n_sites): local_state = current_config[site_idx] if (site_idx & 1) == 0: # even site index mapped_config[site_idx] = C_map[local_state] else: # odd site index mapped_config[site_idx] = C_map[local_state] # 2) translate by +1: dest[(j+1) % L] = work[j] for site_idx in range(n_sites): current_config[(site_idx + 1) % n_sites] = mapped_config[site_idx] # 3) lookup translations_table[cfg_idx, step_idx] = config_to_index_binarysearch( current_config, sector_configs ) shifts_per_dir = np.empty(1, np.int32) shifts_per_dir[0] = num_tc_steps return translations_table, shifts_per_dir # ───────────────────────────────────────────────────────────────────────────── @njit(cache=True) def _prepare_translation_source_sites( lvals: np.ndarray, unit_cell_size: np.ndarray ) -> tuple[np.ndarray, np.ndarray]: """Precompute how each allowed block-translation permutes lattice 
sites. Parameters ---------- lvals : ndarray Lattice lengths along each spatial direction. unit_cell_size : ndarray Translation step along each axis, expressed in lattice sites. Returns ------- tuple ``(source_sites, shifts_per_dir)`` where ``source_sites[flat_shift, dest_site_idx]`` gives the source lattice site that lands on ``dest_site_idx`` after the decoded block-translation ``flat_shift`` is applied. Notes ----- This helper moves all coordinate algebra out of the hot ``cfg_idx`` loop of :func:`build_all_translations`. Once the source table is known, translating a configuration reduces to a simple gather: ``rolled_cfg[dest_site_idx] = cfg[source_sites[flat_shift, dest_site_idx]]``. """ lattice_dim = lvals.size shifts_per_dir = lvals // unit_cell_size total_shifts = 1 n_sites = 1 for ax in range(lattice_dim): total_shifts *= shifts_per_dir[ax] n_sites *= lvals[ax] source_sites = np.empty((total_shifts, n_sites), np.int32) axis_shifts = np.empty(lattice_dim, np.int32) coords = np.empty(lattice_dim, np.int32) shifted_coords = np.empty(lattice_dim, np.int32) # Build the permutation table once for every allowed combined shift. for flat_shift in range(total_shifts): decode_shift(flat_shift, shifts_per_dir, axis_shifts) for source_site_idx in range(n_sites): # Convert the source site to lattice coordinates. linear_to_coords_rowmajor(source_site_idx, lvals, coords) # Apply the block translation in coordinate space. for ax in range(lattice_dim): shifted_coords[ax] = ( coords[ax] + axis_shifts[ax] * unit_cell_size[ax] ) % lvals[ax] # Record which source site populates the translated destination site. 
dest_site_idx = coords_to_linear_rowmajor(shifted_coords, lvals) source_sites[flat_shift, dest_site_idx] = source_site_idx return source_sites, shifts_per_dir # ───────────────────────────────────────────────────────────────────────────── @njit(cache=True, parallel=True) def build_all_translations( sector_configs: np.ndarray, lvals: np.ndarray, unit_cell_size: np.ndarray, ): """ Build the table of all allowed block-translations for every configuration in sector_config. For each direction ax the number of possible translations or shifts are shifts_per_dir[ax] = lvals[ax] // unit_cell_size[ax] Given a fixed config, roll it along ax by multiples of unit_cell_size[ax] and record the index within sector_configs of each translation. Theory: ------- Allowed translations are t_d * s_d along axis d, with t_d in [0..R_d-1], where R_d = L_d / s_d. The full translation group is the product of those cyclic groups. For each config c, its orbit points are T(t) c for all t. Returns: translations_array: int32 array of shape (N_configs, prod(R_d)) translations_array[i, s] gives the index (row in sector_configs) of the configuration obtained by applying the block-shift decoded from 's' to configuration i. shifts_per_dir: int32 array (D,) with R_d = L_d / s_d. """ n_configs, n_sites = sector_configs.shape # Precompute the geometry-dependent site permutations once, before touching # the potentially huge number of sector configurations. source_sites, shifts_per_dir = _prepare_translation_source_sites( lvals, unit_cell_size ) total_shifts = source_sites.shape[0] # Allocate the memory for the array with all the translations translations_array = np.empty((n_configs, total_shifts), np.int32) # Run over all the configs for cfg_idx in prange(n_configs): cfg = sector_configs[cfg_idx] # Reuse one scratch array for every shift of this configuration. rolled_cfg = np.empty(n_sites, sector_configs.dtype) # Consider all the possible combined block-shifts. 
for flat_shift in range(total_shifts): # Gather the translated configuration via the precomputed site map. for dest_site_idx in range(n_sites): rolled_cfg[dest_site_idx] = cfg[source_sites[flat_shift, dest_site_idx]] # Find the index of the rolled config in sector_configs translations_array[cfg_idx, flat_shift] = config_to_index_binarysearch( rolled_cfg, sector_configs ) return translations_array, shifts_per_dir @njit(cache=True) def select_references( translations_array: np.ndarray, # shape (n_configs, total_shifts) # Each row stores the orbit obtained from all block shifts. shifts_per_dir: np.ndarray, # shape (D,) # R_d = L_d // s_d gives the number of block positions per axis. k_vals: np.ndarray, # shape (D,), momentum labels k_d in [0..R_d-1] ): """ Pick one reference per orbit, compute its axis-period vector p (minimal block shifts along each axis returning to itself), and keep the orbit iff (k_d * p_d) % R_d == 0 for ALL axes d. Theory: ------- - The orbit of i is the set { translations_array[i, s] for s in 0..total_shifts-1 }. - Axis period p_d is the smallest p>0 such that shifting by p blocks on axis d returns i to itself. In the table, that means: translations_array[i, encode_shift(p * e_d)] == i. - Momentum sector k is consistent with the orbit iff the character is trivial on the stabilizer: exp(-2π i k_d p_d / R_d) == 1, i.e. (k_d * p_d) % R_d == 0. """ n_configs, total_shifts = translations_array.shape lattice_dim = shifts_per_dir.size # Output buffers (over-allocated, trimmed at the end) references = np.empty(n_configs, np.int32) period_vectors = np.empty((n_configs, lattice_dim), np.int32) # Bookkeeping: which config indices already belong to an orbit we've processed? assigned = np.zeros(n_configs, np.uint8) # Convert axis-only shifts into flat indices without rebuilding temporary # vectors inside the innermost loop. 
shift_strides = _compute_rowmajor_strides(shifts_per_dir) n_refs = 0 # loop in index order for cfg_idx in range(n_configs): if assigned[cfg_idx]: # Already included in a previously discovered orbit continue # Row listing all T(t)|cfg_idx>, for all combined block shifts t orbit_row = translations_array[cfg_idx] # ---------- 1) Compute the axis period vector p (minimal periods) ---------- pvec = np.empty(lattice_dim, np.int32) for ax in range(lattice_dim): found = False # Scan p = 1..R_d until the shift along axis 'ax' returns to the reference for period_candidate in range(1, shifts_per_dir[ax] + 1): # ``p = R_d`` wraps back to zero shift, which is why the modulo # reproduces the full-cycle fallback without a separate branch. flat_shift_index = ( period_candidate % shifts_per_dir[ax] ) * shift_strides[ax] if orbit_row[flat_shift_index] == cfg_idx: # returned to itself pvec[ax] = period_candidate found = True break if not found: # Should not occur for well-formed tables; fall back to the full cycle pvec[ax] = shifts_per_dir[ax] # ---------- 2) Mark the ENTIRE orbit as assigned ---------- # Important: do this before the momentum filter, so we never duplicate orbits for shift_idx in range(total_shifts): assigned[orbit_row[shift_idx]] = 1 # ---------- 3) Momentum-compatibility filter ---------- # Keep the orbit iff (k_d * p_d) % R_d == 0 for all axes d keep = True for ax in range(lattice_dim): if (k_vals[ax] * pvec[ax]) % shifts_per_dir[ax] != 0: keep = False break if not keep: # Orbit incompatible with the requested momentum; skip it continue # ---------- 4) Accept this orbit representative ---------- references[n_refs] = cfg_idx for ax in range(lattice_dim): period_vectors[n_refs, ax] = pvec[ax] n_refs += 1 # Trim the over-allocated outputs return references[:n_refs], period_vectors[:n_refs, :] @njit(cache=True, parallel=True) def momentum_basis_zero_k( sector_configs: np.ndarray, # (n_configs, n_sites) int32 lvals: np.ndarray, # (D,) int32 lattice lengths 
unit_cell_size: np.ndarray, # (D,) int32 block sizes s_d (must divide L_d) ): """ Build the Γ (k=0) momentum projector B in **sparse form**: - CSC arrays: (L_col_ptr, L_row_idx, L_data) — float64 - CSR arrays: (R_row_ptr, R_col_idx, R_data) — float64 Mathematical content (Γ sector): -------------------------------- For each translation orbit (represented by a 'reference' config with axis period-vector p), the Γ vector is: |Γ; ref> = (1 / sqrt(∏_d p_d)) * sum_{0 <= t_d < p_d} T(t) |ref>. All phases are 1, so the basis is real. Implementation overview: ------------------------ 1) Precompute the translation table and per-axis block counts R_d. 2) Choose one orbit representative per orbit (select_references with k=0). 3) Two-pass CSC build: PASS 1: for each column (reference), deduplicate images over the “period box” and count how many unique rows it will write (nnz_per_col). PASS 2: repeat, but write the normalized values into CSC arrays. (We also sort row indices within each column for canonical CSC.) 4) Build a CSR view from CSC (row-wise prefix sum + scatter). Returns ------- L_col_ptr : (n_cols+1,) int32 L_row_idx : (nnz,) int32 L_data : (nnz,) float64 R_row_ptr : (n_rows+1,) int32 R_col_idx : (nnz,) int32 R_data : (nnz,) float64 Note ---- You can get: n_rows = sector_configs.shape[0] n_cols = L_col_ptr.shape[0] - 1 """ # ---------- 0) Sizes & translations ---------- n_rows = sector_configs.shape[0] translations_array, shifts_per_dir = build_all_translations( sector_configs, lvals, unit_cell_size ) lattice_dim = shifts_per_dir.size # These strides let the orbit helper update the full translation index # incrementally while it walks the period box. 
shift_strides = _compute_rowmajor_strides(shifts_per_dir) # ---------- 1) Orbit representatives and their period-vectors ---------- k_zero = np.zeros(lattice_dim, np.int32) references, period_vectors = select_references( translations_array, shifts_per_dir, k_zero ) n_cols = references.shape[0] # ---------- 2) PASS 1: count per-column nonzeros (after dedup) ---------- nnz_per_col = np.zeros(n_cols, np.int32) for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # (D,) # Build the deduplicated orbit content once, then read only its length. used_indices, used_values, used_len = _accumulate_zero_k_orbit( translations_array[ref_cfg_index], pvec, shift_strides ) nnz_per_col[col_idx] = used_len # CSC structure L_col_ptr = _prefix_sum_counts(nnz_per_col) total_nnz = L_col_ptr[-1] L_row_idx = np.empty(total_nnz, np.int32) L_data = np.empty(total_nnz, np.float64) # ---------- 3) PASS 2: fill CSC (dedup + normalize + (optional) sort) ---------- for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # Rebuild the same orbit content, now to normalize and materialize it. 
used_indices, used_values, used_len = _accumulate_zero_k_orbit( translations_array[ref_cfg_index], pvec, shift_strides ) # Normalize the column by its 2-norm (counts are real here) norm_sq = 0.0 for entry_idx in range(used_len): norm_sq += used_values[entry_idx] * used_values[entry_idx] if np.isclose(norm_sq, 0.0): continue inv_norm = 1.0 / np.sqrt(norm_sq) # (Optional but nice): sort by row index for canonical CSC _insertion_sort_by_row(used_indices, used_values, used_len) # Write this column’s entries into CSC arrays write = L_col_ptr[col_idx] for entry_idx in range(used_len): L_row_idx[write] = used_indices[entry_idx] L_data[write] = used_values[entry_idx] * inv_norm write += 1 # ---------- 4) Build CSR view from CSC (no SciPy) ---------- R_row_ptr = np.zeros(n_rows + 1, np.int32) # row counts for nnz_idx in range(total_nnz): row_idx = L_row_idx[nnz_idx] R_row_ptr[row_idx + 1] += 1 # prefix sum for row_idx in range(n_rows): R_row_ptr[row_idx + 1] += R_row_ptr[row_idx] # scatter R_col_idx = np.empty(total_nnz, np.int32) R_data = np.empty(total_nnz, np.float64) # work heads (copy) heads = np.empty(n_rows, np.int32) for row_idx in range(n_rows): heads[row_idx] = R_row_ptr[row_idx] for col_idx in range(n_cols): start = L_col_ptr[col_idx] stop = L_col_ptr[col_idx + 1] for nnz_idx in range(start, stop): row_idx = L_row_idx[nnz_idx] write_idx = heads[row_idx] R_col_idx[write_idx] = col_idx R_data[write_idx] = L_data[nnz_idx] heads[row_idx] += 1 return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data) @njit(cache=True, parallel=True) def momentum_basis_finite_k( sector_configs: np.ndarray, # (n_configs, n_sites) int32 lvals: np.ndarray, # (D,) int32 unit_cell_size: np.ndarray, # (D,) int32 (must divide L_d) k_vals: np.ndarray, # (D,) int32 momenta mod R_d ): """ Build the finite-k momentum projector B in sparse form: - CSC arrays: (L_col_ptr, L_row_idx, L_data) — complex128 - CSR arrays: (R_row_ptr, R_col_idx, R_data) — complex128 Math: R_d = L_d / s_d, group 
G = Z_{R0} × ... × Z_{R_{D-1}}. For an orbit representative `ref` with axis‐periods p_d, the finite-k Bloch sum is |k; ref> = (1/√N) ∑_{0≤t_d<p_d} exp[-2πi Σ_d (k_d t_d / R_d)] T(t) |ref>, where N = ∑_{distinct images j} |amplitude_j|^2 after deduplication. If the orbit is incompatible with k, all amplitudes cancel → empty column. Determinism: Each column is built independently (per-iteration scratch), then we sort its (row, value) pairs by row index before writing CSC. Returns ------- L_col_ptr : (n_cols+1,) int32 L_row_idx : (nnz,) int32 L_data : (nnz,) complex128 R_row_ptr : (n_rows+1,) int32 R_col_idx : (nnz,) int32 R_data : (nnz,) complex128 (As usual: n_rows = sector_configs.shape[0], n_cols = L_col_ptr.size - 1.) """ # Tolerances for pruning true zeros / near-incompatible columns TOL_ZERO = 1e-14 # per-entry amplitude threshold TOL_COLNORM = 1e-30 # column norm^2 threshold # ---------- 0) Sizes & translations ---------- n_rows = sector_configs.shape[0] translations_array, shifts_per_dir = build_all_translations( sector_configs, lvals, unit_cell_size ) # These strides let the orbit helper update the full translation index # incrementally while it walks the period box. shift_strides = _compute_rowmajor_strides(shifts_per_dir) # Precompute all one-axis phase factors once and reuse them across columns. phase_table = _prepare_phase_table(k_vals, shifts_per_dir) # ---------- 1) Orbit representatives & periods (filtered by k_vals) ---------- references, period_vectors = select_references( translations_array, shifts_per_dir, k_vals ) n_cols = references.shape[0] # ---------- 2) PASS 1: count nonzeros per column (after cancellations) ---------- nnz_per_col = np.zeros(n_cols, np.int32) for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # (D,) # Build the deduplicated orbit content once, then only count the # entries that survive Bloch-phase cancellations. 
used_indices, used_values, used_len = _accumulate_finite_k_orbit( translations_array[ref_cfg_index], pvec, shift_strides, phase_table ) # Count only truly non-zero amplitudes after cancellation surviving_entries = 0 for entry_idx in range(used_len): if np.abs(used_values[entry_idx]) > TOL_ZERO: surviving_entries += 1 nnz_per_col[col_idx] = surviving_entries # Allocate CSC L_col_ptr = _prefix_sum_counts(nnz_per_col) total_nnz = L_col_ptr[-1] L_row_idx = np.empty(total_nnz, np.int32) L_data = np.empty(total_nnz, np.complex128) # ---------- 3) PASS 2: fill CSC (dedup + prune zeros + normalize + sort) ---------- for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # Rebuild the same orbit content, now to prune zeros, normalize, and # write the canonical CSC column. used_indices, used_values, used_len = _accumulate_finite_k_orbit( translations_array[ref_cfg_index], pvec, shift_strides, phase_table ) # In-place prune entries that cancelled to ~0, and compute column norm kept_len = 0 norm_sq = 0.0 for entry_idx in range(used_len): amplitude = used_values[entry_idx] if np.abs(amplitude) > TOL_ZERO: used_indices[kept_len] = used_indices[entry_idx] used_values[kept_len] = amplitude norm_sq += ( amplitude.real * amplitude.real + amplitude.imag * amplitude.imag ) kept_len += 1 # If the column fully cancels, write nothing if norm_sq <= TOL_COLNORM or kept_len == 0: continue inv_norm = 1.0 / np.sqrt(norm_sq) # Sort by row index for canonical CSC _insertion_sort_by_row(used_indices, used_values, kept_len) # Materialize write = L_col_ptr[col_idx] for entry_idx in range(kept_len): L_row_idx[write] = used_indices[entry_idx] L_data[write] = used_values[entry_idx] * inv_norm write += 1 # ---------- 4) Build CSR from CSC ---------- R_row_ptr = np.zeros(n_rows + 1, np.int32) for nnz_idx in range(total_nnz): row_idx = L_row_idx[nnz_idx] R_row_ptr[row_idx + 1] += 1 for row_idx in range(n_rows): R_row_ptr[row_idx + 1] += R_row_ptr[row_idx] 
R_col_idx = np.empty(total_nnz, np.int32) R_data = np.empty(total_nnz, np.complex128) heads = np.empty(n_rows, np.int32) for row_idx in range(n_rows): heads[row_idx] = R_row_ptr[row_idx] for col_idx in range(n_cols): start = L_col_ptr[col_idx] stop = L_col_ptr[col_idx + 1] for nnz_idx in range(start, stop): row_idx = L_row_idx[nnz_idx] write_idx = heads[row_idx] R_col_idx[write_idx] = col_idx R_data[write_idx] = L_data[nnz_idx] heads[row_idx] += 1 return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data) @njit(cache=True, parallel=True) def momentum_basis_zero_k_TC(sector_configs: np.ndarray): """ Build the Γ (k=0) momentum projector B in **sparse form**: - CSC arrays: (L_col_ptr, L_row_idx, L_data) — float64 - CSR arrays: (R_row_ptr, R_col_idx, R_data) — float64 Mathematical content (Γ sector): -------------------------------- For each translation orbit (represented by a 'reference' config with axis period-vector p), the Γ vector is: |Γ; ref> = (1 / sqrt(∏_d p_d)) * sum_{0 <= t_d < p_d} T(t) |ref>. All phases are 1, so the basis is real. Implementation overview: ------------------------ 1) Precompute the translation table and per-axis block counts R_d. 2) Choose one orbit representative per orbit (select_references with k=0). 3) Two-pass CSC build: PASS 1: for each column (reference), deduplicate images over the “period box” and count how many unique rows it will write (nnz_per_col). PASS 2: repeat, but write the normalized values into CSC arrays. (We also sort row indices within each column for canonical CSC.) 4) Build a CSR view from CSC (row-wise prefix sum + scatter). 
Returns ------- L_col_ptr : (n_cols+1,) int32 L_row_idx : (nnz,) int32 L_data : (nnz,) float64 R_row_ptr : (n_rows+1,) int32 R_col_idx : (nnz,) int32 R_data : (nnz,) float64 Note ---- You can get: n_rows = sector_configs.shape[0] n_cols = L_col_ptr.shape[0] - 1 """ # ---------- 0) Sizes & translations ---------- n_rows = sector_configs.shape[0] # Special 1D case with TC+inversion symmetry C_map = np.array([4, 5, 2, 3, 0, 1], dtype=np.int32) C_phase_label = np.array([+1, +1, +1, -1, +1, +1], dtype=np.float64) C_phase_per_config = precompute_C_sign_per_config(sector_configs, C_phase_label) translations_array, shifts_per_dir = build_TC_translations(sector_configs, C_map) lattice_dim = shifts_per_dir.size # ---------- 1) Orbit representatives and their period-vectors ---------- k_zero = np.zeros(lattice_dim, np.int32) references, period_vectors = select_references( translations_array, shifts_per_dir, k_zero ) n_cols = references.shape[0] # ---------- 2) PASS 1: count per-column nonzeros (after dedup) ---------- nnz_per_col = np.zeros(n_cols, np.int32) for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # (D,) # Upper bound for distinct images when scanning the period box orbit_size = np.prod(pvec) # Per-column dedup accumulator (indices + counts) used_indices = np.empty(orbit_size, np.int32) used_values = np.zeros(orbit_size, np.float64) # counts (phase=1) used_len = 0 # local mixed-radix index in the period box t_local = np.zeros(lattice_dim, np.int32) # Enumerate all t in 0..p_d-1 and map via the precomputed table for local_shift_idx in range(orbit_size): decode_mixed_index(local_shift_idx, pvec, t_local) flat_full = encode_shift(t_local, shifts_per_dir) # base R_d cfg_row = translations_array[ref_cfg_index, flat_full] # ----------------------------------------------------- incr = 1.0 # Special 1D case with TC symmetry ref_sign = C_phase_per_config[ref_cfg_index] if (flat_full & 1) == 1: # t odd? 
incr = ref_sign # deduplicate: linear scan OK (orbit boxes are small) pos = -1 for entry_idx in range(used_len): if used_indices[entry_idx] == cfg_row: pos = entry_idx break if pos == -1: pos = used_len used_indices[pos] = cfg_row used_len += 1 # Odd powers of the TC generator contribute the parity sign of the # reference configuration instead of a plain +1 factor. used_values[pos] += incr nnz_per_col[col_idx] = used_len # CSC structure L_col_ptr = _prefix_sum_counts(nnz_per_col) total_nnz = L_col_ptr[-1] L_row_idx = np.empty(total_nnz, np.int32) L_data = np.empty(total_nnz, np.float64) # ---------- 3) PASS 2: fill CSC (dedup + normalize + (optional) sort) ---------- for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # same bound orbit_size = np.prod(pvec) used_indices = np.empty(orbit_size, np.int32) used_values = np.zeros(orbit_size, np.float64) used_len = 0 t_local = np.zeros(lattice_dim, np.int32) for local_shift_idx in range(orbit_size): decode_mixed_index(local_shift_idx, pvec, t_local) flat_full = encode_shift(t_local, shifts_per_dir) cfg_row = translations_array[ref_cfg_index, flat_full] # ----------------------------------------------------- incr = 1.0 # Special 1D case with TC symmetry ref_sign = C_phase_per_config[ref_cfg_index] if (flat_full & 1) == 1: # t odd? 
incr = ref_sign pos = -1 for entry_idx in range(used_len): if used_indices[entry_idx] == cfg_row: pos = entry_idx break if pos == -1: pos = used_len used_indices[pos] = cfg_row used_len += 1 used_values[pos] += incr # Normalize the column by its 2-norm (counts are real here) norm_sq = 0.0 for entry_idx in range(used_len): norm_sq += used_values[entry_idx] * used_values[entry_idx] if np.isclose(norm_sq, 0.0): continue inv_norm = 1.0 / np.sqrt(norm_sq) # (Optional but nice): sort by row index for canonical CSC _insertion_sort_by_row(used_indices, used_values, used_len) # Write this column’s entries into CSC arrays write = L_col_ptr[col_idx] for entry_idx in range(used_len): L_row_idx[write] = used_indices[entry_idx] L_data[write] = used_values[entry_idx] * inv_norm write += 1 # ---------- 4) Build CSR view from CSC (no SciPy) ---------- R_row_ptr = np.zeros(n_rows + 1, np.int32) # row counts for nnz_idx in range(total_nnz): row_idx = L_row_idx[nnz_idx] R_row_ptr[row_idx + 1] += 1 # prefix sum for row_idx in range(n_rows): R_row_ptr[row_idx + 1] += R_row_ptr[row_idx] # scatter R_col_idx = np.empty(total_nnz, np.int32) R_data = np.empty(total_nnz, np.float64) # work heads (copy) heads = np.empty(n_rows, np.int32) for row_idx in range(n_rows): heads[row_idx] = R_row_ptr[row_idx] for col_idx in range(n_cols): start = L_col_ptr[col_idx] stop = L_col_ptr[col_idx + 1] for nnz_idx in range(start, stop): row_idx = L_row_idx[nnz_idx] write_idx = heads[row_idx] R_col_idx[write_idx] = col_idx R_data[write_idx] = L_data[nnz_idx] heads[row_idx] += 1 return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data) @njit(cache=True, parallel=True) def momentum_basis_finite_k_TC( sector_configs: np.ndarray, # (n_configs, n_sites) int32 k_vals: np.ndarray, # (D,) int32 momenta mod R_d ): """ Build the finite-k momentum projector B in sparse form: - CSC arrays: (L_col_ptr, L_row_idx, L_data) — complex128 - CSR arrays: (R_row_ptr, R_col_idx, R_data) — complex128 Math: R_d = L_d / 
s_d, group G = Z_{R0} × ... × Z_{R_{D-1}}. For an orbit representative `ref` with axis‐periods p_d, the finite-k Bloch sum is |k; ref> = (1/√N) ∑_{0≤t_d<p_d} exp[-2πi Σ_d (k_d t_d / R_d)] T(t) |ref>, where N = ∑_{distinct images j} |amplitude_j|^2 after deduplication. If the orbit is incompatible with k, all amplitudes cancel → empty column. Determinism: Each column is built independently (per-iteration scratch), then we sort its (row, value) pairs by row index before writing CSC. Returns ------- L_col_ptr : (n_cols+1,) int32 L_row_idx : (nnz,) int32 L_data : (nnz,) complex128 R_row_ptr : (n_rows+1,) int32 R_col_idx : (nnz,) int32 R_data : (nnz,) complex128 (As usual: n_rows = sector_configs.shape[0], n_cols = L_col_ptr.size - 1.) """ # Tolerances for pruning true zeros / near-incompatible columns TOL_ZERO = 1e-14 # per-entry amplitude threshold TOL_COLNORM = 1e-30 # column norm^2 threshold # ---------- 0) Sizes & translations ---------- n_rows = sector_configs.shape[0] # Special 1D case with TC+inversion symmetry C_map = np.array([4, 5, 2, 3, 0, 1], dtype=np.int32) C_phase_label = np.array([+1, +1, +1, -1, +1, +1], dtype=np.float64) C_phase_per_config = precompute_C_sign_per_config(sector_configs, C_phase_label) translations_array, shifts_per_dir = build_TC_translations(sector_configs, C_map) lattice_dim = shifts_per_dir.size # ---------- 1) Orbit representatives & periods (filtered by k_vals) ---------- references, period_vectors = select_references( translations_array, shifts_per_dir, k_vals ) n_cols = references.shape[0] # ---------- 2) PASS 1: count nonzeros per column (after cancellations) ---------- nnz_per_col = np.zeros(n_cols, np.int32) for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] # (D,) # Orbit-box upper bound orbit_size = 1 for ax in range(lattice_dim): orbit_size *= pvec[ax] # Per-column accumulators (dedup by image row) used_indices = np.empty(orbit_size, np.int32) used_values = 
np.zeros(orbit_size, np.complex128) used_len = 0 t_local = np.zeros(lattice_dim, np.int32) # Enumerate period box, accumulate complex phase per distinct image for local_shift_idx in range(orbit_size): decode_mixed_index(local_shift_idx, pvec, t_local) # phase = exp(-2πi Σ_d (k_d * t_d / R_d)) phase_arg = 0.0 for ax in range(lattice_dim): kd = k_vals[ax] % shifts_per_dir[ax] phase_arg += (kd * t_local[ax]) / float(shifts_per_dir[ax]) phase = np.exp(-1j * 2.0 * np.pi * phase_arg) flat_full = encode_shift(t_local, shifts_per_dir) # base R_d cfg_row = translations_array[ref_cfg_index, flat_full] # ----------------------------------------------------- incr = 1.0 # In 1D SU(2) we can implement the TC symmetry ref_sign = C_phase_per_config[ref_cfg_index] if (flat_full & 1) == 1: # t odd? incr = ref_sign # deduplicate by row pos = -1 for entry_idx in range(used_len): if used_indices[entry_idx] == cfg_row: pos = entry_idx break if pos == -1: pos = used_len used_indices[pos] = cfg_row used_len += 1 used_values[pos] += incr * phase # Count only truly non-zero amplitudes after cancellation surviving_entries = 0 for entry_idx in range(used_len): if np.abs(used_values[entry_idx]) > TOL_ZERO: surviving_entries += 1 nnz_per_col[col_idx] = surviving_entries # Allocate CSC L_col_ptr = _prefix_sum_counts(nnz_per_col) total_nnz = L_col_ptr[-1] L_row_idx = np.empty(total_nnz, np.int32) L_data = np.empty(total_nnz, np.complex128) # ---------- 3) PASS 2: fill CSC (dedup + prune zeros + normalize + sort) ---------- for col_idx in prange(n_cols): ref_cfg_index = references[col_idx] pvec = period_vectors[col_idx, :] orbit_size = np.prod(pvec) used_indices = np.empty(orbit_size, np.int32) used_values = np.zeros(orbit_size, np.complex128) used_len = 0 t_local = np.zeros(lattice_dim, np.int32) # Deduplicate + accumulate complex phases for local_shift_idx in range(orbit_size): decode_mixed_index(local_shift_idx, pvec, t_local) phase_arg = 0.0 for ax in range(lattice_dim): kd = k_vals[ax] % 
shifts_per_dir[ax] phase_arg += (kd * t_local[ax]) / float(shifts_per_dir[ax]) phase = np.exp(-1j * 2.0 * np.pi * phase_arg) flat_full = encode_shift(t_local, shifts_per_dir) cfg_row = translations_array[ref_cfg_index, flat_full] # ----------------------------------------------------- incr = 1.0 # In 1D SU(2) we can implement the TC symmetry ref_sign = C_phase_per_config[ref_cfg_index] if (flat_full & 1) == 1: # t odd? incr = ref_sign # deduplicate by row pos = -1 for entry_idx in range(used_len): if used_indices[entry_idx] == cfg_row: pos = entry_idx break if pos == -1: pos = used_len used_indices[pos] = cfg_row used_len += 1 used_values[pos] += incr * phase # In-place prune entries that cancelled to ~0, and compute column norm kept_len = 0 norm_sq = 0.0 for entry_idx in range(used_len): amplitude = used_values[entry_idx] if np.abs(amplitude) > TOL_ZERO: used_indices[kept_len] = used_indices[entry_idx] used_values[kept_len] = amplitude norm_sq += ( amplitude.real * amplitude.real + amplitude.imag * amplitude.imag ) kept_len += 1 # If the column fully cancels, write nothing if norm_sq <= TOL_COLNORM or kept_len == 0: continue inv_norm = 1.0 / np.sqrt(norm_sq) # Sort by row index for canonical CSC _insertion_sort_by_row(used_indices, used_values, kept_len) # Materialize write = L_col_ptr[col_idx] for entry_idx in range(kept_len): L_row_idx[write] = used_indices[entry_idx] L_data[write] = used_values[entry_idx] * inv_norm write += 1 # ---------- 4) Build CSR from CSC ---------- R_row_ptr = np.zeros(n_rows + 1, np.int32) for nnz_idx in range(total_nnz): row_idx = L_row_idx[nnz_idx] R_row_ptr[row_idx + 1] += 1 for row_idx in range(n_rows): R_row_ptr[row_idx + 1] += R_row_ptr[row_idx] R_col_idx = np.empty(total_nnz, np.int32) R_data = np.empty(total_nnz, np.complex128) heads = np.empty(n_rows, np.int32) for row_idx in range(n_rows): heads[row_idx] = R_row_ptr[row_idx] for col_idx in range(n_cols): start = L_col_ptr[col_idx] stop = L_col_ptr[col_idx + 1] for nnz_idx in 
range(start, stop): row_idx = L_row_idx[nnz_idx] write_idx = heads[row_idx] R_col_idx[write_idx] = col_idx R_data[write_idx] = L_data[nnz_idx] heads[row_idx] += 1 return (L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data) # ─────────────────────────────────────────────────────────────────────────────
def get_momentum_basis(
    sector_configs: np.ndarray,
    lvals: list[int],
    unit_cell_size: np.ndarray,
    k_vals: np.ndarray,
    TC_symmetry: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return the sparse CSC + CSR arrays of the momentum-basis projector ``B``.

    The integer inputs are first normalized to contiguous ``int32`` arrays.
    One of four numba constructors is then chosen:

    * ``TC_symmetry`` and all ``k_vals == 0`` -> ``momentum_basis_zero_k_TC``
    * ``TC_symmetry`` and some ``k_vals != 0`` -> ``momentum_basis_finite_k_TC``
    * plain translations, all zero            -> ``momentum_basis_zero_k``
    * plain translations, some nonzero        -> ``momentum_basis_finite_k``

    The two TC constructors receive only ``sector_configs`` (and ``k_vals``).
    ``lvals`` and ``unit_cell_size`` are not used on that path.

    Parameters
    ----------
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    lvals : list
        Lattice lengths along each spatial direction.
    unit_cell_size : ndarray
        Translation step (logical unit-cell size) per spatial direction.
    k_vals : ndarray
        Momentum quantum numbers: one per spatial direction, or the effective
        translation-combined momentum when ``TC_symmetry`` is set.
    TC_symmetry : bool, optional
        Select the translation-combined (TC) construction.

    Returns
    -------
    tuple
        ``(L_col_ptr, L_row_idx, L_data, R_row_ptr, R_col_idx, R_data)``:
        the CSC arrays of ``B`` followed by its CSR arrays.
    """
    lattice_sizes = np.ascontiguousarray(lvals, dtype=np.int32)
    cell_steps = np.ascontiguousarray(unit_cell_size, dtype=np.int32)
    momenta = np.ascontiguousarray(k_vals, dtype=np.int32)

    at_gamma_point = not np.any(momenta != 0)

    if TC_symmetry:
        if at_gamma_point:
            return momentum_basis_zero_k_TC(sector_configs)
        return momentum_basis_finite_k_TC(sector_configs, momenta)

    if at_gamma_point:
        return momentum_basis_zero_k(sector_configs, lattice_sizes, cell_steps)
    return momentum_basis_finite_k(sector_configs, lattice_sizes, cell_steps, momenta)
# ─────────────────────────────────────────────────────────────────────────────
@njit(cache=True, parallel=True)
def precompute_nonzero_csr(matrix: np.ndarray, tol: float = 1e-10):
    """Index the nonzero rows of every column of a dense 2D ``matrix``.

    The result is the column-compressed (CSC-style) sparsity pattern of
    ``matrix``, equivalently the CSR pattern of ``matrix.T``.

    Parameters
    ----------
    matrix : ndarray
        Dense 2D array of shape ``(n_rows, n_cols)``.
    tol : float, optional
        An entry counts as nonzero when ``abs(matrix[i, j]) > tol``. The
        default ``1e-10`` reproduces the previously hard-coded threshold.

    Returns
    -------
    indptr : ndarray of int32, shape ``(n_cols + 1,)``
    indices : ndarray of int32
        ``indices[indptr[j]:indptr[j + 1]]`` are the rows ``i`` (ascending)
        with ``abs(matrix[i, j]) > tol``.
    """
    n_rows, n_cols = matrix.shape

    # Each prange iteration owns exactly one column's counter.
    col_counts = np.zeros(n_cols, np.int32)
    for col in prange(n_cols):
        count = 0
        for row in range(n_rows):
            if np.abs(matrix[row, col]) > tol:
                count += 1
        col_counts[col] = count

    indptr = np.empty(n_cols + 1, np.int32)
    indptr[0] = 0
    for col in range(n_cols):
        indptr[col + 1] = indptr[col] + col_counts[col]

    # Columns write disjoint slices [indptr[col], indptr[col+1]), so the
    # parallel fill is race-free.
    indices = np.empty(indptr[n_cols], np.int32)
    for col in prange(n_cols):
        pos = indptr[col]
        for row in range(n_rows):
            if np.abs(matrix[row, col]) > tol:
                indices[pos] = row
                pos += 1
    return indptr, indices
@njit(cache=True)
def _advance_transition_counters(counters, limits):
    """Advance a mixed-radix counter in place, last digit fastest.

    Returns ``True`` once the most significant digit wraps around, i.e.
    after every combination ``0 <= counters[i] < limits[i]`` has been
    visited. Wrapped digits are reset to 0, so a finished counter is all
    zeros.
    """
    for pos in range(counters.shape[0] - 1, -1, -1):
        counters[pos] += 1
        if counters[pos] < limits[pos]:
            return False
        counters[pos] = 0
    return True


@njit(cache=True, parallel=True)
def nbody_data_momentum(
    op_list: np.ndarray,  # (n_ops, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # (n_ops,), int32
    sector_configs: np.ndarray,  # (n_states, n_sites), int32
    # --- sparse momentum basis B arrays ---
    L_col_ptr: np.ndarray,  # (proj_dim+1,) int32
    L_row_idx: np.ndarray,  # (nnz_B,) int32
    L_data: np.ndarray,  # (nnz_B,) float64/complex128
    R_row_ptr: np.ndarray,  # (n_states+1,) int32
    R_col_idx: np.ndarray,  # (nnz_B,) int32
    R_data: np.ndarray,  # (nnz_B,) float64/complex128
):
    """Build sparse triplets for a generic n-body operator in momentum basis.

    Two passes are made over the projected rows:

    * PASS 1 counts the triplets of each row.
    * PASS 2 writes them at prefix-summed offsets.

    For every real-space bra in column ``prow`` of ``B``, all combinations
    of the per-site local transitions are enumerated with a mixed-radix
    counter. Each resulting ket is looked up in ``sector_configs``, and one
    triplet is emitted per nonzero of the ket's CSR row of ``B``.

    Parameters
    ----------
    op_list : ndarray
        Site-resolved factorized operator data of shape
        ``(n_ops, n_sites, d_loc, d_loc)``.
    op_sites_list : ndarray
        Sites where each operator factor acts (shape ``(n_ops,)``).
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of momentum projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of momentum projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B``. Duplicate ``(row, col)`` pairs are not summed.

    Raises
    ------
    ValueError
        If ``op_sites_list`` is empty.

    Notes
    -----
    NOTE(review): each factor reads its local state from the bra and writes
    it into the ket independently. A site repeated in ``op_sites_list`` is
    therefore not composed — confirm callers always pass distinct sites.
    """
    _, n_sites = sector_configs.shape
    n_ops = len(op_sites_list)
    proj_dim = L_col_ptr.shape[0] - 1
    if n_ops < 1:
        raise ValueError(f"nbody operator must act on at least one site, got {n_ops}")

    # Precompute all one-site transitions once and reuse them in both passes.
    transition_counts_all, ket_local_states_all, transition_values_all = (
        _prepare_momentum_local_transition_data(op_list, op_sites_list)
    )

    # ----------------------------
    # PASS 1: count nnz per projected row
    # ----------------------------
    nnz_per_row = np.zeros(proj_dim, np.int32)
    for proj_row_idx in prange(proj_dim):
        # Scratch buffers are private to this row and reused for every bra,
        # instead of being reallocated per bra.
        bra_local_states = np.empty(n_ops, np.int32)
        active_transition_counts = np.empty(n_ops, np.int32)
        transition_counters = np.zeros(n_ops, np.int32)
        ket_cfg = np.empty(n_sites, np.int32)

        row_nnz = 0
        # The left CSC projector lists the real-space bras of this momentum row.
        for left_ptr in range(L_col_ptr[proj_row_idx], L_col_ptr[proj_row_idx + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]

            combo_count = 1
            for op_idx in range(n_ops):
                bra_loc = bra_cfg[op_sites_list[op_idx]]
                bra_local_states[op_idx] = bra_loc
                active_transition_counts[op_idx] = transition_counts_all[op_idx, bra_loc]
                combo_count *= active_transition_counts[op_idx]
            if combo_count == 0:
                continue

            # Untouched sites stay equal to the bra during the whole walk.
            for site_idx in range(n_sites):
                ket_cfg[site_idx] = bra_cfg[site_idx]
            transition_counters[:] = 0

            finished = False
            while not finished:
                for op_idx in range(n_ops):
                    ket_cfg[op_sites_list[op_idx]] = ket_local_states_all[
                        op_idx, bra_local_states[op_idx], transition_counters[op_idx]
                    ]
                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                if ket_idx >= 0:
                    row_nnz += R_row_ptr[ket_idx + 1] - R_row_ptr[ket_idx]
                finished = _advance_transition_counters(
                    transition_counters, active_transition_counts
                )
        nnz_per_row[proj_row_idx] = row_nnz

    # Prefix-sum row offsets (nnz_per_row now holds each row's write start).
    offset = 0
    for proj_row_idx in range(proj_dim):
        tmp_nnz = nnz_per_row[proj_row_idx]
        nnz_per_row[proj_row_idx] = offset
        offset += tmp_nnz
    total_nnz = offset

    # ----------------------------
    # PASS 2: fill triplets
    # ----------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)

    for proj_row_idx in prange(proj_dim):
        bra_local_states = np.empty(n_ops, np.int32)
        active_transition_counts = np.empty(n_ops, np.int32)
        transition_counters = np.zeros(n_ops, np.int32)
        ket_cfg = np.empty(n_sites, np.int32)

        write_ptr = nnz_per_row[proj_row_idx]
        for left_ptr in range(L_col_ptr[proj_row_idx], L_col_ptr[proj_row_idx + 1]):
            bra_cfg = sector_configs[L_row_idx[left_ptr]]
            # Left projector contributes the bra-side amplitude <prow|bra>.
            amp_left = np.conj(L_data[left_ptr])

            combo_count = 1
            for op_idx in range(n_ops):
                bra_loc = bra_cfg[op_sites_list[op_idx]]
                bra_local_states[op_idx] = bra_loc
                active_transition_counts[op_idx] = transition_counts_all[op_idx, bra_loc]
                combo_count *= active_transition_counts[op_idx]
            if combo_count == 0:
                continue

            for site_idx in range(n_sites):
                ket_cfg[site_idx] = bra_cfg[site_idx]
            transition_counters[:] = 0

            finished = False
            while not finished:
                # Rebuild the operator amplitude from the selected local
                # transitions and update only the acted sites in the ket.
                amp_mid = 1.0 + 0.0j
                for op_idx in range(n_ops):
                    bra_loc = bra_local_states[op_idx]
                    trans_idx = transition_counters[op_idx]
                    amp_mid *= transition_values_all[op_idx, bra_loc, trans_idx]
                    ket_cfg[op_sites_list[op_idx]] = ket_local_states_all[
                        op_idx, bra_loc, trans_idx
                    ]

                ket_idx = config_to_index_binarysearch(ket_cfg, sector_configs)
                if ket_idx >= 0:
                    # The right CSR projector lists the momentum columns the
                    # real-space ket contributes to, with their amplitudes.
                    for right_ptr in range(R_row_ptr[ket_idx], R_row_ptr[ket_idx + 1]):
                        row_list[write_ptr] = proj_row_idx
                        col_list[write_ptr] = R_col_idx[right_ptr]
                        value_list[write_ptr] = amp_left * amp_mid * R_data[right_ptr]
                        write_ptr += 1

                finished = _advance_transition_counters(
                    transition_counters, active_transition_counts
                )

    return row_list, col_list, value_list
@njit(cache=True, parallel=True)
def nbody_data_momentum_1site(
    op_list: np.ndarray,  # (1, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # (1,), int32
    sector_configs: np.ndarray,  # (N, n_sites), int32
    # --- sparse momentum basis B arrays ---
    L_col_ptr: np.ndarray,  # (Ldim+1,) int32 -- columns of B
    L_row_idx: np.ndarray,  # (nnz_B,) int32 -- rows for each CSC entry
    L_data: np.ndarray,  # (nnz_B,) float64 or complex128 -- B[row, col]
    R_row_ptr: np.ndarray,  # (N+1,) int32 -- rows of B
    R_col_idx: np.ndarray,  # (nnz_B,) int32 -- cols for each CSR entry
    R_data: np.ndarray,  # (nnz_B,) float64 or complex128 -- B[row, col]
):
    """Build sparse triplets for a one-site operator in the momentum basis.

    For every projected row ``prow``, each real-space bra ``j1`` with
    ``B[j1, prow] != 0`` (CSC of ``B``) is taken in turn. Its local state at
    the acted site is changed through every allowed transition. The
    resulting ket ``j2``, if it exists in the sector, is projected onto all
    momentum columns of row ``j2`` (CSR of ``B``).

    PASS 1 counts triplets per row. PASS 2 writes them at prefix-summed
    offsets.

    Parameters
    ----------
    op_list : ndarray
        One-site factorized operator data (shape ``(1, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Site index of the operator action (shape ``(1,)``).
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^† H B``; each value is ``conj(B[j1, prow]) * h * B[j2, pcol]``.

    Raises
    ------
    ValueError
        If ``op_sites_list`` does not hold exactly one site. Previously any
        extra sites were silently ignored.
    """
    if op_sites_list.shape[0] != 1:
        raise ValueError("nbody_data_momentum_1site expects exactly one operator site")

    n_sites = sector_configs.shape[1]
    Ldim = L_col_ptr.size - 1  # number of momentum columns (dim of projected space)
    site = op_sites_list[0]

    transition_counts_all, ket_local_states_all, transition_values_all = (
        _prepare_momentum_local_transition_data(op_list, op_sites_list)
    )

    # ----------------------------
    # PASS 1: count nnz per momentum-row (prow)
    # ----------------------------
    nnz_per_row = np.zeros(Ldim, np.int32)
    for prow in prange(Ldim):
        # One scratch ket per projected row, refreshed from each bra.
        ket_cfg = np.empty(n_sites, np.int32)
        cnt = 0
        for p1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[p1]]
            bra_local_state = bra_cfg[site]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for local_transition_idx in range(transition_counts_all[0, bra_local_state]):
                ket_cfg[site] = ket_local_states_all[
                    0, bra_local_state, local_transition_idx
                ]
                j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                if j2 < 0:
                    continue
                # Number of nonzero momentum columns of row j2 (CSR of B).
                cnt += R_row_ptr[j2 + 1] - R_row_ptr[j2]
        nnz_per_row[prow] = cnt

    # prefix-sum offsets (nnz_per_row now holds each row's write start)
    offset = 0
    for prow in range(Ldim):
        tmp = nnz_per_row[prow]
        nnz_per_row[prow] = offset
        offset += tmp
    total_nnz = offset

    # ----------------------------
    # PASS 2: fill triplets
    # ----------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)

    for prow in prange(Ldim):
        ket_cfg = np.empty(n_sites, np.int32)
        ptr = nnz_per_row[prow]
        for p1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[p1]]
            B1 = L_data[p1]  # B[j1, prow]
            bra_local_state = bra_cfg[site]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for local_transition_idx in range(transition_counts_all[0, bra_local_state]):
                ket_cfg[site] = ket_local_states_all[
                    0, bra_local_state, local_transition_idx
                ]
                transition_value = transition_values_all[
                    0, bra_local_state, local_transition_idx
                ]
                j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                if j2 < 0:
                    continue
                # Project into the momentum columns of row j2 (CSR of B).
                for p2 in range(R_row_ptr[j2], R_row_ptr[j2 + 1]):
                    row_list[ptr] = prow
                    col_list[ptr] = R_col_idx[p2]
                    value_list[ptr] = np.conj(B1) * transition_value * R_data[p2]
                    ptr += 1

    return row_list, col_list, value_list
@njit(cache=True, parallel=True)
def nbody_data_momentum_2sites(
    op_list: np.ndarray,  # shape (2, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # shape (2,), int32
    sector_configs: np.ndarray,  # shape (N, n_sites), int32
    # ---- momentum basis B in sparse form ----
    L_col_ptr: np.ndarray,  # (Ldim+1,), int32 -- columns of B
    L_row_idx: np.ndarray,  # (nnz_B,), int32 -- real-space rows j with B[j, prow] != 0
    L_data: np.ndarray,  # (nnz_B,), complex128/float64 -- B[j, prow]
    R_row_ptr: np.ndarray,  # (N+1,), int32 -- rows of B
    R_col_idx: np.ndarray,  # (nnz_B,), int32 -- projected cols pcol with B[j, pcol] != 0
    R_data: np.ndarray,  # (nnz_B,), complex128/float64 -- B[j, pcol]
):
    """Build sparse triplets for a two-site operator in the momentum basis.

    The kernel runs two explicitly nested transition loops (one per acted
    site) over every bra of each projected row. Each ket found in the
    sector is projected onto all momentum columns of its CSR row of ``B``.

    PASS 1 counts triplets per row. PASS 2 fills them at prefix-summed
    offsets.

    Parameters
    ----------
    op_list : ndarray
        Two-site factorized operator data (shape ``(2, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Two site indices where the operator acts.
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B``.

    Raises
    ------
    ValueError
        If ``op_sites_list`` does not hold exactly two sites.
    """
    if op_sites_list.shape[0] != 2:
        raise ValueError("nbody_data_momentum_2sites expects exactly two operator sites")

    _, n_sites = sector_configs.shape
    Ldim = L_col_ptr.shape[0] - 1
    # Loop-invariant: the acted sites are the same for every bra.
    site0 = op_sites_list[0]
    site1 = op_sites_list[1]

    transition_counts_all, ket_local_states_all, transition_values_all = (
        _prepare_momentum_local_transition_data(op_list, op_sites_list)
    )

    # -------------------------------
    # PASS 1: count nonzeros per projected row
    # -------------------------------
    nnz_per_row = np.zeros(Ldim, np.int32)
    for prow in prange(Ldim):
        ket_cfg = np.empty(n_sites, np.int32)  # per-row scratch, reused per bra
        cnt = 0
        # all real-space rows j1 with B[j1, prow] != 0 (CSC of B)
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[idx1]]
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for i0 in range(len0):
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                    if j2 < 0:
                        continue
                    # number of projected columns reachable from j2 (CSR of B)
                    cnt += R_row_ptr[j2 + 1] - R_row_ptr[j2]
        nnz_per_row[prow] = cnt

    # prefix-sum → row offsets
    offset = 0
    for prow in range(Ldim):
        tmp = nnz_per_row[prow]
        nnz_per_row[prow] = offset
        offset += tmp
    total_nnz = offset

    # -------------------------------
    # PASS 2: fill triplets
    # -------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)

    for prow in prange(Ldim):
        ket_cfg = np.empty(n_sites, np.int32)
        ptr = nnz_per_row[prow]
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[idx1]]
            amp_L = np.conj(L_data[idx1])  # conj(B[j1, prow])
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for i0 in range(len0):
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                v0 = transition_values_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    v1 = transition_values_all[1, bra_loc1, i1]
                    j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                    if j2 < 0:
                        continue
                    amp_M = v0 * v1
                    # project into momentum columns (CSR of B)
                    for tt in range(R_row_ptr[j2], R_row_ptr[j2 + 1]):
                        row_list[ptr] = prow
                        col_list[ptr] = R_col_idx[tt]
                        value_list[ptr] = amp_L * amp_M * R_data[tt]
                        ptr += 1

    return row_list, col_list, value_list
@njit(cache=True, parallel=True)
def nbody_data_momentum_4sites(
    op_list: np.ndarray,  # shape (4, n_sites, d_loc, d_loc)
    op_sites_list: np.ndarray,  # shape (4,), int32
    sector_configs: np.ndarray,  # shape (N, n_sites), int32
    # ---- momentum basis B in sparse form ----
    L_col_ptr: np.ndarray,  # (Ldim+1,), int32
    L_row_idx: np.ndarray,  # (nnz_B,), int32
    L_data: np.ndarray,  # (nnz_B,), complex128/float64
    R_row_ptr: np.ndarray,  # (N+1,), int32
    R_col_idx: np.ndarray,  # (nnz_B,), int32
    R_data: np.ndarray,  # (nnz_B,), complex128/float64
):
    """Build sparse triplets for a four-site operator in the momentum basis.

    The kernel runs four explicitly nested transition loops (one per acted
    site) over every bra of each projected row. Each ket found in the
    sector is projected onto all momentum columns of its CSR row of ``B``.

    PASS 1 counts triplets per row. PASS 2 fills them at prefix-summed
    offsets.

    Parameters
    ----------
    op_list : ndarray
        Four-site factorized operator data (shape ``(4, n_sites, d_loc, d_loc)``).
    op_sites_list : ndarray
        Four site indices where the operator acts.
    sector_configs : ndarray
        Symmetry-sector configurations, one row per basis state.
    L_col_ptr, L_row_idx, L_data : ndarray
        CSC representation of the momentum-basis projector ``B``.
    R_row_ptr, R_col_idx, R_data : ndarray
        CSR representation of the same projector ``B``.

    Returns
    -------
    tuple
        ``(row_list, col_list, value_list)`` triplets for the projected
        operator ``B^H H B``.

    Raises
    ------
    ValueError
        If ``op_sites_list`` does not hold exactly four sites.
    """
    if op_sites_list.shape[0] != 4:
        raise ValueError("nbody_data_momentum_4sites expects exactly four operator sites")

    _, n_sites = sector_configs.shape
    Ldim = L_col_ptr.shape[0] - 1
    # Loop-invariant: the acted sites are the same for every bra.
    site0 = op_sites_list[0]
    site1 = op_sites_list[1]
    site2 = op_sites_list[2]
    site3 = op_sites_list[3]

    transition_counts_all, ket_local_states_all, transition_values_all = (
        _prepare_momentum_local_transition_data(op_list, op_sites_list)
    )

    # -------------------------------
    # PASS 1: count nonzeros per projected row
    # -------------------------------
    nnz_per_row = np.zeros(Ldim, np.int32)
    for prow in prange(Ldim):
        ket_cfg = np.empty(n_sites, np.int32)  # per-row scratch, reused per bra
        cnt = 0
        # all real-space rows j1 with B[j1, prow] != 0 (CSC of B)
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[idx1]]
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            bra_loc2 = bra_cfg[site2]
            bra_loc3 = bra_cfg[site3]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            len2 = transition_counts_all[2, bra_loc2]
            len3 = transition_counts_all[3, bra_loc3]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for i0 in range(len0):
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    for i2 in range(len2):
                        ket_cfg[site2] = ket_local_states_all[2, bra_loc2, i2]
                        for i3 in range(len3):
                            ket_cfg[site3] = ket_local_states_all[3, bra_loc3, i3]
                            j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                            if j2 < 0:
                                continue
                            # number of projected columns from j2 (CSR of B)
                            cnt += R_row_ptr[j2 + 1] - R_row_ptr[j2]
        nnz_per_row[prow] = cnt

    # prefix-sum → row offsets
    offset = 0
    for prow in range(Ldim):
        tmp = nnz_per_row[prow]
        nnz_per_row[prow] = offset
        offset += tmp
    total_nnz = offset

    # -------------------------------
    # PASS 2: fill triplets
    # -------------------------------
    row_list = np.empty(total_nnz, np.int32)
    col_list = np.empty(total_nnz, np.int32)
    value_list = np.empty(total_nnz, np.complex128)

    for prow in prange(Ldim):
        ket_cfg = np.empty(n_sites, np.int32)
        ptr = nnz_per_row[prow]
        for idx1 in range(L_col_ptr[prow], L_col_ptr[prow + 1]):
            bra_cfg = sector_configs[L_row_idx[idx1]]
            amp_L = np.conj(L_data[idx1])  # conj(B[j1, prow])
            bra_loc0 = bra_cfg[site0]
            bra_loc1 = bra_cfg[site1]
            bra_loc2 = bra_cfg[site2]
            bra_loc3 = bra_cfg[site3]
            len0 = transition_counts_all[0, bra_loc0]
            len1 = transition_counts_all[1, bra_loc1]
            len2 = transition_counts_all[2, bra_loc2]
            len3 = transition_counts_all[3, bra_loc3]
            for jj in range(n_sites):
                ket_cfg[jj] = bra_cfg[jj]
            for i0 in range(len0):
                ket_cfg[site0] = ket_local_states_all[0, bra_loc0, i0]
                v0 = transition_values_all[0, bra_loc0, i0]
                for i1 in range(len1):
                    ket_cfg[site1] = ket_local_states_all[1, bra_loc1, i1]
                    v1 = transition_values_all[1, bra_loc1, i1]
                    for i2 in range(len2):
                        ket_cfg[site2] = ket_local_states_all[2, bra_loc2, i2]
                        v2 = transition_values_all[2, bra_loc2, i2]
                        for i3 in range(len3):
                            ket_cfg[site3] = ket_local_states_all[3, bra_loc3, i3]
                            v3 = transition_values_all[3, bra_loc3, i3]
                            j2 = config_to_index_binarysearch(ket_cfg, sector_configs)
                            if j2 < 0:
                                continue
                            amp_M = v0 * v1 * v2 * v3
                            # project into momentum columns (CSR of B)
                            for tt in range(R_row_ptr[j2], R_row_ptr[j2 + 1]):
                                row_list[ptr] = prow
                                col_list[ptr] = R_col_idx[tt]
                                value_list[ptr] = amp_L * amp_M * R_data[tt]
                                ptr += 1

    return row_list, col_list, value_list