"""Low-level configuration encoding and lookup helpers.
This module groups together the mixed-radix helpers used to encode/decode
many-body configurations and the small lexicographic search kernels used on
sorted configuration tables.
"""
import numpy as np
from numba import njit, prange
# Public API of this module: the mixed-radix encode/decode helpers followed by
# the lookup kernels that operate on configuration tables.
__all__ = [
    "compute_strides",
    "encode_config",
    "encode_all_configs",
    "decode_key_to_config",
    "binary_search_sorted",
    "index_to_config",
    "config_to_index",
    "compare_configs",
    "config_to_index_linsearch",
    "config_to_index_binarysearch",
]
[docs]
@njit(cache=True)
def compute_strides(loc_dims: np.ndarray) -> np.ndarray:
    """Return the mixed-radix place value of every site.

    The last site varies fastest: ``strides[-1] == 1`` and
    ``strides[k] == prod(loc_dims[k + 1:])``. Summing
    ``config[k] * strides[k]`` over sites gives a configuration's int64 key.

    Parameters
    ----------
    loc_dims : np.ndarray
        1D array holding the local dimension of each site.

    Returns
    -------
    np.ndarray
        int64 array with one stride per site.

    Notes
    -----
    The running product is not checked for overflow.
    """
    n_sites = loc_dims.shape[0]
    strides = np.empty(n_sites, dtype=np.int64)
    place_value = np.int64(1)
    site = n_sites - 1
    # Walk from the rightmost (fastest-varying) site towards site 0.
    while site >= 0:
        strides[site] = place_value
        place_value = place_value * np.int64(loc_dims[site])
        site -= 1
    return strides
[docs]
@njit(cache=True)
def encode_config(config: np.ndarray, strides: np.ndarray) -> np.int64:
    """Map one configuration to its key ``sum(config[k] * strides[k])``.

    Parameters
    ----------
    config : np.ndarray
        1D array of per-site local states.
    strides : np.ndarray
        int64 place values, one per site (same length as ``config``).

    Returns
    -------
    np.int64
        The encoded key.
    """
    total = np.int64(0)
    for site, level in enumerate(config):
        total += np.int64(level) * strides[site]
    return total
[docs]
@njit(parallel=True, cache=True)
def encode_all_configs(configs: np.ndarray, strides: np.ndarray) -> np.ndarray:
    """Encode every row of ``configs`` into an int64 key.

    Row ``r`` maps to ``sum(configs[r, k] * strides[k])``. Rows are processed
    with ``prange``, and each iteration writes only its own output slot.

    Parameters
    ----------
    configs : np.ndarray
        2D array, one configuration per row.
    strides : np.ndarray
        int64 place values, one per column of ``configs``.

    Returns
    -------
    np.ndarray
        int64 array with one key per row.
    """
    n_rows = configs.shape[0]
    n_cols = configs.shape[1]
    out = np.empty(n_rows, dtype=np.int64)
    for row in prange(n_rows):
        # Accumulator is declared inside the parallel body, so it is per-row.
        acc = np.int64(0)
        for col in range(n_cols):
            acc += np.int64(configs[row, col]) * strides[col]
        out[row] = acc
    return out
[docs]
@njit(cache=True)
def decode_key_to_config(key: np.int64, loc_dims: np.ndarray) -> np.ndarray:
    """Recover the per-site digits of ``key`` in the mixed radix ``loc_dims``.

    Digits are peeled off starting from the last site (remainder modulo that
    site's dimension, then integer division). This mirrors the layout produced
    by ``compute_strides``, where the last site has stride 1.

    Parameters
    ----------
    key : np.int64
        Encoded configuration key.
    loc_dims : np.ndarray
        1D array holding the local dimension of each site.

    Returns
    -------
    np.ndarray
        uint16 array with one local state per site.
    """
    n_sites = loc_dims.shape[0]
    digits = np.empty(n_sites, dtype=np.uint16)
    rest = np.int64(key)
    for offset in range(n_sites):
        site = n_sites - 1 - offset
        base = np.int64(loc_dims[site])
        digits[site] = np.uint16(rest % base)
        rest = rest // base
    return digits
[docs]
@njit(cache=True)
def binary_search_sorted(keys_sorted: np.ndarray, target: np.int64) -> int:
    """Locate ``target`` in an ascending int64 array.

    This is a closed-interval binary search. It returns the first probed
    position whose value equals ``target``, which is not necessarily the
    leftmost duplicate.

    Parameters
    ----------
    keys_sorted : np.ndarray
        1D array sorted in ascending order.
    target : np.int64
        Key to look for.

    Returns
    -------
    int
        Index of a matching element, or -1 if ``target`` is absent.
    """
    left = 0
    right = keys_sorted.shape[0] - 1
    while right >= left:
        center = (left + right) // 2
        probe = keys_sorted[center]
        if probe < target:
            left = center + 1
        elif probe > target:
            right = center - 1
        else:
            return center
    return -1
[docs]
@njit(cache=True)
def index_to_config(qmb_index, loc_dims):
    """Expand a linear many-body basis index into per-site local states.

    The index is read as a mixed-radix number whose last site is the least
    significant digit.

    Parameters
    ----------
    qmb_index : int
        Linear basis index.
    loc_dims : array-like of int
        Local dimension of each site.

    Returns
    -------
    np.ndarray
        uint8 array with one local state per site. Because of the uint8
        storage, local states of 256 or more cannot be represented.
    """
    n_sites = len(loc_dims)
    digits = np.zeros(n_sites, dtype=np.uint8)
    remaining = qmb_index
    for offset in range(n_sites):
        site = n_sites - 1 - offset
        base = loc_dims[site]
        digits[site] = remaining % base
        remaining = remaining // base
    return digits
[docs]
@njit(cache=True)
def config_to_index(config, loc_dims):
    """Collapse per-site local states into a linear many-body basis index.

    This is the inverse of the mixed-radix expansion used by
    ``index_to_config``: the last site is the least significant digit.

    Parameters
    ----------
    config : array-like of int
        Local state of each site.
    loc_dims : array-like of int
        Local dimension of each site.

    Returns
    -------
    int
        The linear basis index.
    """
    n_sites = len(config)
    result = 0
    place = 1
    for offset in range(1, n_sites + 1):
        site = n_sites - offset
        result = result + config[site] * place
        place = place * loc_dims[site]
    return result
[docs]
@njit(cache=True)
def compare_configs(config1, config2):
    """Lexicographically compare two configurations.

    Sites are compared in order starting from index 0, and the first
    differing site decides the result. If one configuration is a strict
    prefix of the other, the shorter one orders first. This is the same rule
    Python uses for sequence comparison.

    Parameters
    ----------
    config1, config2 : array-like of int
        1D configurations to compare (lengths may differ).

    Returns
    -------
    int
        -1 if ``config1 < config2``, 1 if ``config1 > config2``, 0 if equal.
    """
    len1 = len(config1)
    len2 = len(config2)
    # Only index the overlapping prefix. Numba does not bounds-check array
    # reads by default, so walking config1 past the end of a shorter config2
    # would read invalid memory.
    n_common = min(len1, len2)
    for site_index in range(n_common):
        value1 = config1[site_index]
        value2 = config2[site_index]
        if value1 < value2:
            return -1
        if value1 > value2:
            return 1
    # Equal common prefix: the shorter configuration orders first.
    if len1 < len2:
        return -1
    if len1 > len2:
        return 1
    return 0
[docs]
@njit(cache=True)
def config_to_index_linsearch(config, unique_configs):
    """Return the first row of ``unique_configs`` equal to ``config``.

    Rows are scanned top to bottom, and each row is compared site by site
    until the first mismatch. ``config`` is assumed to have as many entries
    as ``unique_configs`` has columns; this is not checked.

    Parameters
    ----------
    config : array-like of int
        1D configuration to find.
    unique_configs : np.ndarray
        2D table, one configuration per row.

    Returns
    -------
    int
        Row index of the first match, or -1 if no row matches.
    """
    width = len(config)
    for row in range(unique_configs.shape[0]):
        site = 0
        while site < width and config[site] == unique_configs[row, site]:
            site += 1
        if site == width:
            return row
    return -1
[docs]
@njit(cache=True)
def config_to_index_binarysearch(config, unique_configs):
    """Locate ``config`` in a table of configurations sorted row-wise.

    The rows of ``unique_configs`` must be in ascending order under
    ``compare_configs``. This is a closed-interval binary search that returns
    the first probed row equal to ``config``.

    Parameters
    ----------
    config : array-like of int
        1D configuration to find.
    unique_configs : np.ndarray
        2D table, one configuration per row, sorted lexicographically.

    Returns
    -------
    int
        Row index of a match, or -1 if ``config`` is absent.
    """
    lo = 0
    hi = len(unique_configs) - 1
    while hi >= lo:
        mid = (lo + hi) >> 1
        order = compare_configs(unique_configs[mid], config)
        if order < 0:
            lo = mid + 1
        elif order > 0:
            hi = mid - 1
        else:
            return mid
    return -1