# Source code for edlgt.tools.stabilizers

"""Stabilizer-support helpers built on mixed-radix encoded configurations."""

import numpy as np
from numba import njit, prange

from .config_encoding import (
    binary_search_sorted,
    compute_strides,
    decode_key_to_config,
    encode_all_configs,
)

# Names exported by a star-import of this module.
__all__ = [
    "decode_Xstrings",
    "extract_support",
    "unique_sorted_int64",
    "all_pairwise_pkeys_support",
    "stabilizer_renyi_sum",
]


@njit(parallel=True, cache=True)
def decode_xstrings(xp_keys: np.ndarray, loc_dims: np.ndarray) -> np.ndarray:
    """Decode a batch of encoded X-string keys, one output row per key.

    Parameters
    ----------
    xp_keys : numpy.ndarray
        Encoded keys; the first axis enumerates the strings.
    loc_dims : numpy.ndarray
        One entry per site; its length fixes the number of output columns.

    Returns
    -------
    numpy.ndarray
        ``uint16`` array of shape ``(xp_keys.shape[0], loc_dims.shape[0])``
        whose row ``r`` is ``decode_key_to_config(xp_keys[r], loc_dims)``.
    """
    num_keys = xp_keys.shape[0]
    num_sites = loc_dims.shape[0]
    decoded = np.empty((num_keys, num_sites), dtype=np.uint16)
    # Each row is written independently, so the loop is issued as a prange.
    for row in prange(num_keys):
        decoded[row, :] = decode_key_to_config(xp_keys[row], loc_dims)
    return decoded


# Alias under the capitalised name, which is the one listed in ``__all__``.
decode_Xstrings = decode_xstrings  # pylint: disable=invalid-name


def extract_support(
    psi: np.ndarray,
    loc_dims: np.ndarray,
    sector_configs: np.ndarray,
    prob_threshold: float = 1e-2,
    sort_for_encoding: bool = True,
):
    """Select the highest-probability basis states of ``psi``.

    Entries are ranked by ``|psi|**2`` in descending order (stable sort) and
    the shortest leading run whose cumulative probability reaches
    ``1 - prob_threshold`` is kept. If the cumulative sum never reaches that
    target, every entry is kept.

    Parameters
    ----------
    psi : numpy.ndarray
        State coefficients in the symmetry-sector basis.
    loc_dims : numpy.ndarray
        Local Hilbert-space dimensions, one per site; passed to
        ``compute_strides``.
    sector_configs : numpy.ndarray
        Basis configurations; row ``i`` corresponds to ``psi[i]``.
    prob_threshold : float, optional
        Probability mass that may be discarded. Must lie strictly between
        0 and 1.
    sort_for_encoding : bool, optional
        If true, reorder all returned arrays so that the encoded keys are
        ascending (stable sort).

    Returns
    -------
    tuple
        ``(support_indices, support_coeffs, support_configs, support_keys,
        discarded_weight)``. ``support_indices`` are ``int64`` positions into
        ``psi``, ``support_configs`` is cast to ``uint16``, ``support_keys``
        comes from ``encode_all_configs`` and ``discarded_weight`` is
        ``max(0, 1 - kept probability)``.

    Raises
    ------
    ValueError
        If ``prob_threshold`` is ``<= 0`` or ``>= 1``.
    """
    delta = float(prob_threshold)
    if delta >= 1.0 or delta <= 0.0:
        raise ValueError("prob_threshold must be 1> delta >0.")

    # Rank basis states from most to least probable.
    weights = np.abs(psi) ** 2
    ranking = np.argsort(-weights, kind="mergesort")
    running_total = np.cumsum(weights[ranking], dtype=np.float64)

    # First rank at which the cumulative probability reaches 1 - delta;
    # that rank itself is included in the support.
    cutoff = int(np.searchsorted(running_total, 1.0 - delta, side="left"))

    support_indices = ranking[: cutoff + 1].astype(np.int64)
    support_coeffs = psi[support_indices]
    support_configs = sector_configs[support_indices, :].astype(
        np.uint16, copy=False
    )

    # Discarded weight is measured against 1.0 and clipped at zero.
    retained = float(np.sum(np.abs(support_coeffs) ** 2))
    discarded_weight = max(0.0, 1.0 - retained)

    support_keys = encode_all_configs(support_configs, compute_strides(loc_dims))

    if not sort_for_encoding:
        return (
            support_indices,
            support_coeffs,
            support_configs,
            support_keys,
            discarded_weight,
        )

    key_order = np.argsort(support_keys, kind="mergesort")
    return (
        support_indices[key_order],
        support_coeffs[key_order],
        support_configs[key_order, :],
        support_keys[key_order],
        discarded_weight,
    )
@njit(cache=True)
def unique_sorted_int64(arr_sorted: np.ndarray) -> np.ndarray:
    """Collapse runs of equal neighbours in an already sorted array.

    Parameters
    ----------
    arr_sorted : numpy.ndarray
        One-dimensional array whose equal values are adjacent.

    Returns
    -------
    numpy.ndarray
        A new ``int64`` array holding the first element of every run, in
        order. An empty input is returned as-is (the same array, not a copy).
    """
    length = arr_sorted.shape[0]
    if length == 0:
        return arr_sorted

    # Pass 1: count the positions where the value differs from its predecessor.
    distinct = 1
    for pos in range(1, length):
        if arr_sorted[pos] != arr_sorted[pos - 1]:
            distinct += 1

    # Pass 2: copy the first element of every run into the exact-size output.
    out = np.empty(distinct, dtype=np.int64)
    out[0] = arr_sorted[0]
    cursor = 1
    for pos in range(1, length):
        if arr_sorted[pos] != arr_sorted[pos - 1]:
            out[cursor] = arr_sorted[pos]
            cursor += 1
    return out
@njit(parallel=True, cache=True)
def all_pairwise_pkeys_support(
    support_configs: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
) -> np.ndarray:
    """Encode the per-site shift between every ordered pair of configurations.

    Parameters
    ----------
    support_configs : numpy.ndarray
        Two-dimensional array, one configuration per row.
    loc_dims : numpy.ndarray
        Per-site modulus applied to the site-wise difference.
    strides : numpy.ndarray
        Per-site weight used to fold the shift vector into a single key.

    Returns
    -------
    numpy.ndarray
        ``int64`` array of length ``n * n`` (``n`` = number of rows). Entry
        ``row * n + col`` is
        ``sum_s ((cfg[col, s] - cfg[row, s]) % loc_dims[s]) * strides[s]``.
        Pairs with ``row == col`` are included and give key 0.
    """
    num_cfgs, num_sites = support_configs.shape
    keys = np.empty(num_cfgs * num_cfgs, dtype=np.int64)
    # Each source row fills its own contiguous slice of the output.
    for src in prange(num_cfgs):
        row_start = src * num_cfgs
        for dst in range(num_cfgs):
            acc = np.int64(0)
            for site in range(num_sites):
                modulus = np.int64(loc_dims[site])
                # Widen to int64 before subtracting so the difference may be
                # negative; the modulo then maps it into [0, modulus).
                diff = np.int64(support_configs[dst, site]) - np.int64(
                    support_configs[src, site]
                )
                acc += (diff % modulus) * strides[site]
            keys[row_start + dst] = acc
    return keys
@njit(cache=True, inline="always")
def encode_shifted_key(
    config_row: np.ndarray, pvec: np.ndarray, loc_dims: np.ndarray, strides: np.ndarray
):
    """Encode ``(config_row + pvec) mod loc_dims`` as a single integer key.

    Each site contributes ``((config_row[s] + pvec[s]) % loc_dims[s]) *
    strides[s]``; the contributions are summed starting from an ``int64``
    zero.
    """
    num_sites = loc_dims.shape[0]
    key = np.int64(0)
    for site in range(num_sites):
        modulus = np.int64(loc_dims[site])
        shifted = (np.int64(config_row[site]) + np.int64(pvec[site])) % modulus
        key += shifted * strides[site]
    return key


@njit(cache=True, inline="always")
def encode_configs_pair_key(
    cfg1: np.ndarray, cfg2: np.ndarray, loc_dims: np.ndarray, strides: np.ndarray
):
    """Encode the site-wise sum ``(cfg1 + cfg2) mod loc_dims`` as one key.

    Uses the same per-site ``value * strides[s]`` accumulation as
    ``encode_shifted_key``.
    """
    num_sites = loc_dims.shape[0]
    key = np.int64(0)
    for site in range(num_sites):
        modulus = np.int64(loc_dims[site])
        summed = (np.int64(cfg1[site]) + np.int64(cfg2[site])) % modulus
        key += summed * strides[site]
    return key


@njit(cache=True)
def exact_xstring_from_support(  # pylint: disable=too-many-locals
    pkey,
    support_configs: np.ndarray,
    support_coeffs: np.ndarray,
    support_keys: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
):
    """Compute the contribution of one X-string restricted to the support.

    The computation has three stages:

    1. For every support configuration ``a``, shift it by the decoded
       X-string and look the shifted key up in ``support_keys``. When found
       at position ``b``, record ``A_a = c_a * conj(c_b)``.
    2. For every ordered pair of recorded entries ``(i, j)``, store the key of
       ``cfg_i + cfg_j`` (site-wise, modulo ``loc_dims``) together with
       ``A_i * conj(A_j)``.
    3. Sort the pairs by key, sum the values sharing a key, and add up the
       squared magnitudes of those sums.

    Parameters
    ----------
    pkey
        Encoded X-string key; cast to ``int64`` before decoding.
    support_configs : numpy.ndarray
        Support configurations, one per row.
    support_coeffs : numpy.ndarray
        Coefficient of each support configuration.
    support_keys : numpy.ndarray
        Encoded keys of the support, searched with ``binary_search_sorted``.
        NOTE(review): presumably must be sorted ascending and aligned with
        ``support_configs`` / ``support_coeffs`` -- confirm against callers.
    loc_dims : numpy.ndarray
        Local dimensions, one per site.
    strides : numpy.ndarray
        Per-site encoding weights.

    Returns
    -------
    numpy.float64
        Sum over distinct pair keys of ``|sum of pair values|**2``; ``0.0``
        when no shifted configuration falls inside the support.
    """
    num_support = support_configs.shape[0]
    shift_vec = decode_key_to_config(np.int64(pkey), loc_dims)

    # Stage 1: keep only configurations whose shifted image is in the support.
    amp = np.empty(num_support, dtype=np.complex128)
    amp_cfg = np.empty(num_support, dtype=np.int64)
    num_hits = 0
    for src in range(num_support):
        target_key = encode_shifted_key(
            support_configs[src], shift_vec, loc_dims, strides
        )
        dst = binary_search_sorted(support_keys, target_key)
        if dst < 0:
            # A negative index is treated as "not in the support".
            continue
        amp[num_hits] = support_coeffs[src] * np.conj(support_coeffs[dst])
        amp_cfg[num_hits] = src
        num_hits += 1

    if num_hits == 0:
        return np.float64(0.0)

    # Stage 2: tabulate every ordered pair (left, right) of retained entries.
    num_pairs = num_hits * num_hits
    pair_keys = np.empty(num_pairs, dtype=np.int64)
    pair_vals = np.empty(num_pairs, dtype=np.complex128)
    for left in range(num_hits):
        left_cfg = support_configs[amp_cfg[left]]
        left_amp = amp[left]
        offset = left * num_hits
        for right in range(num_hits):
            right_cfg = support_configs[amp_cfg[right]]
            slot = offset + right
            pair_keys[slot] = encode_configs_pair_key(
                left_cfg, right_cfg, loc_dims, strides
            )
            pair_vals[slot] = left_amp * np.conj(amp[right])

    # Stage 3: sort by key so equal keys are contiguous, then reduce per run.
    order = np.argsort(pair_keys)
    pair_keys = pair_keys[order]
    pair_vals = pair_vals[order]

    total = np.float64(0.0)
    active_key = pair_keys[0]
    group_sum = np.complex128(0.0 + 0.0j)
    for pos in range(num_pairs):
        if pair_keys[pos] == active_key:
            group_sum += pair_vals[pos]
        else:
            # Key changed: flush |group_sum|^2 and start the next run.
            total += group_sum.real * group_sum.real + group_sum.imag * group_sum.imag
            active_key = pair_keys[pos]
            group_sum = pair_vals[pos]
    # Flush the final run.
    total += group_sum.real * group_sum.real + group_sum.imag * group_sum.imag
    return total


# Alias under the capitalised name.
exact_Xstring_from_support = exact_xstring_from_support  # pylint: disable=invalid-name
@njit(parallel=True, cache=True)
def stabilizer_renyi_sum(
    pkeys_uniq: np.ndarray,
    support_configs: np.ndarray,
    support_coeffs: np.ndarray,
    support_keys: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
) -> np.float64:
    """Sum the per-X-string contributions over all supplied keys.

    Parameters
    ----------
    pkeys_uniq : numpy.ndarray
        Encoded X-string keys; one contribution is evaluated per entry.
    support_configs, support_coeffs, support_keys : numpy.ndarray
        Truncated support, forwarded unchanged to
        ``exact_xstring_from_support``.
    loc_dims, strides : numpy.ndarray
        Local dimensions and encoding weights, forwarded unchanged.

    Returns
    -------
    numpy.float64
        Sum of ``exact_xstring_from_support(key, ...)`` over every key in
        ``pkeys_uniq``; ``0.0`` when ``pkeys_uniq`` is empty.
    """
    num_strings = pkeys_uniq.shape[0]
    contributions = np.zeros(num_strings, dtype=np.float64)
    # Each X-string is evaluated independently and writes only its own slot.
    for string_idx in prange(num_strings):
        contributions[string_idx] = exact_xstring_from_support(
            pkeys_uniq[string_idx],
            support_configs,
            support_coeffs,
            support_keys,
            loc_dims,
            strides,
        )

    # Accumulate in index order with a plain range loop.
    total = np.float64(0.0)
    for slot_idx in range(num_strings):
        total += contributions[slot_idx]
    return total