# Source code for edlgt.tools.config_encoding

"""Low-level configuration encoding and lookup helpers.

This module groups together the mixed-radix helpers used to encode/decode
many-body configurations and the small lexicographic search kernels used on
sorted configuration tables.
"""

import numpy as np
from numba import njit, prange

__all__ = [
    # Mixed-radix int64 key helpers and search on sorted key arrays.
    "compute_strides", "encode_config", "encode_all_configs",
    "decode_key_to_config", "binary_search_sorted",
    # Linear-index <-> configuration conversions and config-table lookups.
    "index_to_config", "config_to_index", "compare_configs",
    "config_to_index_linsearch", "config_to_index_binarysearch",
]


@njit(cache=True)
def compute_strides(loc_dims: np.ndarray) -> np.ndarray:
    """Return mixed-radix strides for encoding site configurations as int64 keys.

    The last site varies fastest: its stride is 1, and every earlier site's
    stride is the product of the local dimensions of all sites to its right.

    Args:
        loc_dims: 1D array holding the local dimension (radix) of each site.

    Returns:
        1D int64 array ``strides`` such that a configuration's key is
        ``sum(config[k] * strides[k])``.
    """
    num_sites = loc_dims.shape[0]
    out = np.empty(num_sites, dtype=np.int64)
    stride = np.int64(1)
    site = num_sites - 1
    # Walk right-to-left, accumulating the product of radices seen so far.
    while site >= 0:
        out[site] = stride
        stride = stride * np.int64(loc_dims[site])
        site -= 1
    return out
@njit(cache=True)
def encode_config(config: np.ndarray, strides: np.ndarray) -> np.int64:
    """Encode one configuration into an int64 key.

    Args:
        config: 1D array of per-site values.
        strides: Per-site strides, e.g. from ``compute_strides``.

    Returns:
        ``sum(config[k] * strides[k])`` accumulated in int64.
    """
    total = np.int64(0)
    for site, value in enumerate(config):
        # Cast before multiplying so small unsigned dtypes cannot overflow.
        total += np.int64(value) * strides[site]
    return total
@njit(parallel=True, cache=True)
def encode_all_configs(configs: np.ndarray, strides: np.ndarray) -> np.ndarray:
    """Encode every row of a configuration table into an int64 key.

    Rows are processed independently with ``prange``.

    Args:
        configs: 2D array, one configuration per row.
        strides: Per-site strides, e.g. from ``compute_strides``.

    Returns:
        1D int64 array whose i-th entry is the key of ``configs[i]``.
    """
    num_rows, num_cols = configs.shape
    out = np.empty(num_rows, dtype=np.int64)
    for row in prange(num_rows):
        acc = np.int64(0)
        for col in range(num_cols):
            acc += strides[col] * np.int64(configs[row, col])
        out[row] = acc
    return out
@njit(cache=True)
def decode_key_to_config(key: np.int64, loc_dims: np.ndarray) -> np.ndarray:
    """Decode an int64 key back into a per-site configuration vector.

    Digits are peeled off starting from the last site (the fastest-varying
    one), which matches the ordering produced by ``compute_strides``.

    Args:
        key: Encoded configuration.
        loc_dims: 1D array holding the local dimension (radix) of each site.

    Returns:
        1D uint16 array of per-site values. Values are stored as uint16, so
        local dimensions are assumed to fit in that type.
    """
    num_sites = loc_dims.shape[0]
    digits = np.empty(num_sites, dtype=np.uint16)
    rest = np.int64(key)
    site = num_sites - 1
    while site >= 0:
        radix = np.int64(loc_dims[site])
        digits[site] = np.uint16(rest % radix)
        rest = rest // radix
        site -= 1
    return digits
@njit(cache=True)
def binary_search_sorted(keys_sorted: np.ndarray, target: np.int64) -> int:
    """Locate ``target`` in an ascending 1D array of keys.

    Args:
        keys_sorted: Keys sorted in ascending order.
        target: Key to look for.

    Returns:
        An index ``i`` with ``keys_sorted[i] == target``, or -1 if absent.
    """
    left = 0
    right = keys_sorted.shape[0] - 1
    while left <= right:
        middle = (left + right) // 2
        probe = keys_sorted[middle]
        if probe > target:
            right = middle - 1
        elif probe < target:
            left = middle + 1
        else:
            return middle
    return -1
@njit(cache=True)
def index_to_config(qmb_index, loc_dims):
    """Convert a linear many-body basis index into a site configuration.

    The last site varies fastest (mixed-radix digits with radices
    ``loc_dims``).

    Args:
        qmb_index: Linear basis index.
        loc_dims: Local dimension of each site.

    Returns:
        1D uint8 array of per-site values.
        NOTE(review): because the output is uint8, local dimensions above 256
        would not round-trip; confirm callers never need that.
    """
    n = len(loc_dims)
    digits = np.zeros(n, dtype=np.uint8)
    remaining = qmb_index
    site = n - 1
    while site >= 0:
        radix = loc_dims[site]
        digits[site] = remaining % radix
        remaining //= radix
        site -= 1
    return digits
@njit(cache=True)
def config_to_index(config, loc_dims):
    """Convert a site configuration into a linear many-body basis index.

    Uses the same ordering as ``index_to_config``: the last site varies
    fastest, so ``index = sum(config[k] * prod(loc_dims[k+1:]))``.

    Args:
        config: Per-site values.
        loc_dims: Local dimension of each site (at least ``len(config)``
            entries are read).

    Returns:
        The linear index as an int64.
    """
    # Accumulate explicitly in int64, as encode_config does.
    # - Under Numba, mixing int64 with a uint64 input would otherwise
    #   promote the arithmetic to float64.
    # - With the JIT disabled, NumPy 2 would evaluate `uint8 * python int`
    #   in uint8 and overflow.
    qmb_index = np.int64(0)
    multiplier = np.int64(1)
    for site_index in range(len(config) - 1, -1, -1):
        qmb_index += np.int64(config[site_index]) * multiplier
        multiplier *= np.int64(loc_dims[site_index])
    return qmb_index
@njit(cache=True)
def compare_configs(config1, config2):
    """Lexicographically compare two configurations.

    Only the first ``len(config1)`` entries are compared. No length check is
    made, so ``config2`` is expected to be at least as long as ``config1``.

    Returns:
        -1 if ``config1`` sorts before ``config2``, 1 if after, 0 if equal.
    """
    for pos in range(len(config1)):
        a = config1[pos]
        b = config2[pos]
        if a < b:
            return -1
        if a > b:
            return 1
    return 0
@njit(cache=True)
def config_to_index_linsearch(config, unique_configs):
    """Find the row of ``unique_configs`` equal to ``config`` by linear scan.

    Args:
        config: 1D configuration to look up.
        unique_configs: 2D table, one configuration per row (need not be
            sorted).

    Returns:
        Index of the first matching row, or -1 if none matches.
    """
    n_rows = unique_configs.shape[0]
    n_sites = len(config)
    for row in range(n_rows):
        site = 0
        # Advance while the row agrees with config; a full pass means a match.
        while site < n_sites and config[site] == unique_configs[row, site]:
            site += 1
        if site == n_sites:
            return row
    return -1
@njit(cache=True)
def config_to_index_binarysearch(config, unique_configs):
    """Find ``config`` in a lexicographically sorted table by binary search.

    Rows are ordered with ``compare_configs``.

    Args:
        config: 1D configuration to look up.
        unique_configs: 2D table whose rows are sorted lexicographically.

    Returns:
        Index of a matching row, or -1 if ``config`` is not present.
    """
    left = 0
    right = len(unique_configs) - 1
    while left <= right:
        mid = (left + right) // 2
        order = compare_configs(unique_configs[mid], config)
        if order < 0:
            left = mid + 1
        elif order > 0:
            right = mid - 1
        else:
            return mid
    return -1