"""Stabilizer-support helpers built on mixed-radix encoded configurations."""
import numpy as np
from numba import njit, prange
from .config_encoding import (
binary_search_sorted,
compute_strides,
decode_key_to_config,
encode_all_configs,
)
# Names exported by ``from <this module> import *``.
# NOTE(review): "extract_support" is listed here, but no definition or import
# of that name is visible in this part of the file. Confirm it exists
# elsewhere in the module; if it does not, a star-import of this module
# raises AttributeError.
__all__ = [
    "decode_Xstrings",
    "extract_support",
    "unique_sorted_int64",
    "all_pairwise_pkeys_support",
    "stabilizer_renyi_sum",
]
@njit(parallel=True, cache=True)
def decode_xstrings(xp_keys: np.ndarray, loc_dims: np.ndarray) -> np.ndarray:
    """Expand encoded X-string keys into one per-site shift vector per key.

    Parameters
    ----------
    xp_keys : np.ndarray
        1D array of encoded keys; one output row is produced per entry.
    loc_dims : np.ndarray
        1D array of local dimensions; its length fixes the number of columns.

    Returns
    -------
    np.ndarray
        ``uint16`` array of shape ``(len(xp_keys), len(loc_dims))`` whose row
        ``r`` is ``decode_key_to_config(xp_keys[r], loc_dims)``.
    """
    num_keys = xp_keys.shape[0]
    decoded = np.empty((num_keys, loc_dims.shape[0]), dtype=np.uint16)
    # Rows are independent of each other, so they are filled in a prange loop.
    for row in prange(num_keys):
        decoded[row, :] = decode_key_to_config(xp_keys[row], loc_dims)
    return decoded


# Legacy mixed-case spelling kept as an alias of the snake_case function.
decode_Xstrings = decode_xstrings  # pylint: disable=invalid-name
[docs]
@njit(cache=True)
def unique_sorted_int64(arr_sorted: np.ndarray) -> np.ndarray:
    """Collapse runs of equal neighbours in a sorted array.

    Parameters
    ----------
    arr_sorted : np.ndarray
        1D array in which equal values are adjacent (e.g. sorted).

    Returns
    -------
    np.ndarray
        A new ``int64`` array holding the first element of every run, in the
        original order. For an empty input the input array itself is returned.
    """
    total = arr_sorted.shape[0]
    if total == 0:
        return arr_sorted

    # Pass 1: count run starts so the output can be allocated at exact size.
    distinct = 1
    for pos in range(1, total):
        if arr_sorted[pos] != arr_sorted[pos - 1]:
            distinct += 1

    # Pass 2: copy the first element of each run.
    result = np.empty(distinct, dtype=np.int64)
    result[0] = arr_sorted[0]
    cursor = 1
    for pos in range(1, total):
        current = arr_sorted[pos]
        if current != arr_sorted[pos - 1]:
            result[cursor] = current
            cursor += 1
    return result
[docs]
@njit(parallel=True, cache=True)
def all_pairwise_pkeys_support(
    support_configs: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
) -> np.ndarray:
    """Encode the site-wise shift between every ordered pair of configurations.

    Parameters
    ----------
    support_configs : np.ndarray
        2D array of shape ``(n_cfgs, n_sites)``; one configuration per row.
    loc_dims : np.ndarray
        Local dimension of each site, used as the modulus of the shift.
    strides : np.ndarray
        Per-site weights used to fold the shift vector into a single key.

    Returns
    -------
    np.ndarray
        ``int64`` array of length ``n_cfgs * n_cfgs``. Entry
        ``row * n_cfgs + col`` is
        ``sum_k ((cfg[col, k] - cfg[row, k]) % loc_dims[k]) * strides[k]``.
        Duplicates are not removed.
    """
    num_cfgs, num_sites = support_configs.shape
    encoded_shifts = np.empty(num_cfgs * num_cfgs, dtype=np.int64)
    # Each source row owns a disjoint slice of the output, so rows can run in
    # parallel without write conflicts.
    for src in prange(num_cfgs):
        row_offset = src * num_cfgs
        for dst in range(num_cfgs):
            encoded = np.int64(0)
            for site in range(num_sites):
                # Cast to int64 before subtracting so the difference is
                # computed in signed arithmetic.
                start_digit = np.int64(support_configs[src, site])
                end_digit = np.int64(support_configs[dst, site])
                modulus = np.int64(loc_dims[site])
                encoded += ((end_digit - start_digit) % modulus) * strides[site]
            encoded_shifts[row_offset + dst] = encoded
    return encoded_shifts
@njit(cache=True, inline="always")
def encode_shifted_key(
    config_row: np.ndarray, pvec: np.ndarray, loc_dims: np.ndarray, strides: np.ndarray
):
    """Return the key of ``config_row`` after adding the shift ``pvec``.

    For every site ``k`` the digit ``(config_row[k] + pvec[k]) % loc_dims[k]``
    is weighted by ``strides[k]``; the key is the sum of those terms. The
    number of sites is taken from ``loc_dims``.
    """
    key = np.int64(0)
    for site in range(loc_dims.shape[0]):
        # Promote to int64 before adding so the sum is not computed in the
        # (possibly narrow) input dtype.
        digit_sum = np.int64(config_row[site]) + np.int64(pvec[site])
        key += (digit_sum % np.int64(loc_dims[site])) * strides[site]
    return key
@njit(cache=True, inline="always")
def encode_configs_pair_key(
    cfg1: np.ndarray, cfg2: np.ndarray, loc_dims: np.ndarray, strides: np.ndarray
):
    """Return the key of the site-wise modular sum of two configurations.

    The combined digit on site ``k`` is ``(cfg1[k] + cfg2[k]) % loc_dims[k]``;
    the returned key is ``sum_k digit_k * strides[k]``. The number of sites is
    taken from ``loc_dims``.
    """
    num_sites = loc_dims.shape[0]
    combined_key = np.int64(0)
    for site in range(num_sites):
        modulus = np.int64(loc_dims[site])
        # Both digits are widened to int64 before they are added.
        first_digit = np.int64(cfg1[site])
        second_digit = np.int64(cfg2[site])
        combined_digit = (first_digit + second_digit) % modulus
        combined_key += combined_digit * strides[site]
    return combined_key
@njit(cache=True)
def exact_xstring_from_support(  # pylint: disable=too-many-locals
    pkey,
    support_configs: np.ndarray,
    support_coeffs: np.ndarray,
    support_keys: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
):
    """Compute the exact contribution of one X-string on a truncated support.

    Let ``p`` be the shift vector decoded from ``pkey``. For every support
    configuration ``a`` whose shifted partner ``a + p`` (site-wise, modulo
    ``loc_dims``) is also found in the support, define::

        A[a] = coeff[a] * conj(coeff[a + p])

    The function returns ``sum_g |B[g]|**2`` with the self-convolution::

        B[g] = sum over (a, b) with a + b == g (site-wise, modulo loc_dims)
               of A[a] * A[b]

    With ``f(q) = sum_a A[a] * chi_q(a)`` (``chi_q`` the site-wise phase
    characters), ``B`` are the Fourier coefficients of ``f(q)**2``, so by
    Parseval the result is proportional to ``sum_q |f(q)|**4``.
    NOTE(review): ``f(q)`` is presumably the expectation value of the
    generalized Pauli string ``X^p Z^q``; no normalisation by the Hilbert-space
    dimension is applied here.

    Parameters
    ----------
    pkey
        Encoded X-string key; decoded with ``decode_key_to_config``.
    support_configs : np.ndarray
        2D array, one support configuration per row.
    support_coeffs : np.ndarray
        Amplitude of each support configuration (indexed like the rows of
        ``support_configs``).
    support_keys : np.ndarray
        Encoded keys searched with ``binary_search_sorted``; the index it
        returns is used to index ``support_coeffs``, and a negative index is
        treated as "not in the support".
        NOTE(review): presumably sorted ascending and aligned with
        ``support_configs`` -- confirm against the caller.
    loc_dims : np.ndarray
        Local dimension of each site.
    strides : np.ndarray
        Per-site weights of the key encoding.

    Returns
    -------
    np.float64
        ``sum_g |B[g]|**2``; ``0.0`` when no configuration has its shifted
        partner inside the support.
    """
    n_support = support_configs.shape[0]
    pvec = decode_key_to_config(np.int64(pkey), loc_dims)

    # Step 1: collect the nonzero A[a] together with the row index of a.
    a_vals = np.empty(n_support, dtype=np.complex128)
    a_cfg_idx = np.empty(n_support, dtype=np.int64)
    n_nonzero = 0
    for cfg_idx in range(n_support):
        shifted_key = encode_shifted_key(
            support_configs[cfg_idx], pvec, loc_dims, strides
        )
        shifted_cfg_idx = binary_search_sorted(support_keys, shifted_key)
        if shifted_cfg_idx >= 0:
            a_vals[n_nonzero] = support_coeffs[cfg_idx] * np.conj(
                support_coeffs[shifted_cfg_idx]
            )
            a_cfg_idx[n_nonzero] = cfg_idx
            n_nonzero += 1
    if n_nonzero == 0:
        return np.float64(0.0)

    # Step 2: enumerate all ordered pairs (a, b), keyed by the encoding of
    # a + b. The value paired with the *sum* key must be A[a] * A[b] without
    # a conjugate: A[a] * conj(A[b]) belongs to the *difference* a - b.
    # Grouping the conjugated product by a + b yields
    # sum_q |f(q)|^2 |f(-q)|^2 instead of sum_q |f(q)|^4, which only agrees
    # when every local dimension is 2 or all A[a] are real.
    n_pairs = n_nonzero * n_nonzero
    pair_keys = np.empty(n_pairs, dtype=np.int64)
    pair_vals = np.empty(n_pairs, dtype=np.complex128)
    slot = 0
    for ia in range(n_nonzero):
        cfg_i = support_configs[a_cfg_idx[ia]]
        a_i = a_vals[ia]
        for ja in range(n_nonzero):
            cfg_j = support_configs[a_cfg_idx[ja]]
            pair_keys[slot] = encode_configs_pair_key(cfg_i, cfg_j, loc_dims, strides)
            pair_vals[slot] = a_i * a_vals[ja]
            slot += 1

    # Step 3: sort by key so equal keys are contiguous, then accumulate each
    # run into B[g] and add |B[g]|^2 whenever the key changes.
    order = np.argsort(pair_keys)
    pair_keys = pair_keys[order]
    pair_vals = pair_vals[order]
    tp_value = np.float64(0.0)
    current_key = pair_keys[0]
    acc = np.complex128(0.0 + 0.0j)
    for pair_idx in range(n_pairs):
        key = pair_keys[pair_idx]
        if key != current_key:
            tp_value += acc.real * acc.real + acc.imag * acc.imag
            current_key = key
            acc = pair_vals[pair_idx]
        else:
            acc += pair_vals[pair_idx]
    # Flush the last run, which has no following key change to trigger it.
    tp_value += acc.real * acc.real + acc.imag * acc.imag
    return tp_value


# Legacy mixed-case spelling kept as an alias of the snake_case function.
exact_Xstring_from_support = exact_xstring_from_support  # pylint: disable=invalid-name
[docs]
@njit(parallel=True, cache=True)
def stabilizer_renyi_sum(
    pkeys_uniq: np.ndarray,
    support_configs: np.ndarray,
    support_coeffs: np.ndarray,
    support_keys: np.ndarray,
    loc_dims: np.ndarray,
    strides: np.ndarray,
) -> np.float64:
    """Sum the per-X-string contributions over all given keys.

    Parameters
    ----------
    pkeys_uniq : np.ndarray
        1D array of encoded X-string keys; one contribution is computed per
        entry (no de-duplication is done here).
    support_configs, support_coeffs, support_keys, loc_dims, strides
        Passed unchanged to ``exact_xstring_from_support`` for every key.

    Returns
    -------
    np.float64
        Sum of ``exact_xstring_from_support(key, ...)`` over ``pkeys_uniq``;
        ``0.0`` for an empty key array.
    """
    num_keys = pkeys_uniq.shape[0]
    contributions = np.zeros(num_keys, dtype=np.float64)
    # Each key is evaluated independently and writes only its own slot.
    for key_pos in prange(num_keys):
        contributions[key_pos] = exact_xstring_from_support(
            pkeys_uniq[key_pos],
            support_configs,
            support_coeffs,
            support_keys,
            loc_dims,
            strides,
        )

    # The final reduction is a plain sequential loop in index order.
    total = np.float64(0.0)
    for term in contributions:
        total += term
    return total