"""Partition-building helpers for subsystem/environment factorizations."""
import numpy as np
from numba import njit, prange
from edlgt.tools.config_encoding import (
compare_configs,
compute_strides,
config_to_index_binarysearch,
config_to_index_linsearch,
encode_all_configs,
)
# Public API of this module: the names exported by a star-import.
__all__ = [
"exclude_columns",
"subenv_map_to_unique_indices",
"unique_configs_with_inverse",
"can_encode_partition_configs",
"build_partition_metadata",
]
[docs]
@njit(parallel=True, cache=True)
def exclude_columns(data_matrix, exclude_indices):
    """Return a copy of a matrix with selected columns removed.

    Parameters
    ----------
    data_matrix : 2-D array
        Source matrix. It is only read, never modified.
    exclude_indices : sequence of int
        Indices of the columns to drop. Repeated entries are allowed and
        drop the column only once.

    Returns
    -------
    reduced_matrix : 2-D array
        Newly allocated array with the dtype of ``data_matrix`` and the same
        number of rows, holding the surviving columns in their original
        left-to-right order.
    """
    num_rows = data_matrix.shape[0]
    num_cols = data_matrix.shape[1]
    # Build the exclusion mask once so the parallel row loop only performs
    # direct indexed copies of the kept source columns.
    exclude_mask = np.zeros(num_cols, dtype=np.bool_)
    for exclude_index in exclude_indices:
        exclude_mask[exclude_index] = True
    # Size the output from the mask, not from ``len(exclude_indices)``: a
    # repeated index would otherwise make ``kept_indices`` too short and the
    # fill loop below would write past its end.
    num_cols_remaining = 0
    for col in range(num_cols):
        if not exclude_mask[col]:
            num_cols_remaining += 1
    kept_indices = np.empty(num_cols_remaining, dtype=np.int64)
    kept_count = 0
    for col in range(num_cols):
        if not exclude_mask[col]:
            kept_indices[kept_count] = col
            kept_count += 1
    reduced_matrix = np.empty((num_rows, num_cols_remaining), dtype=data_matrix.dtype)
    # Rows are independent, so the copy is distributed with ``prange``.
    for row in prange(num_rows):
        for new_col_idx in range(num_cols_remaining):
            reduced_matrix[row, new_col_idx] = data_matrix[
                row, kept_indices[new_col_idx]
            ]
    return reduced_matrix
@njit(cache=True)
def _is_sorted_config_table(config_table: np.ndarray) -> bool:
    """Tell whether consecutive rows of a configuration table are in order.

    Each adjacent pair of rows is handed to ``compare_configs``; the table is
    rejected as soon as one comparison comes back positive. A table with
    fewer than two rows has no pairs to check and is accepted.
    """
    num_pairs = config_table.shape[0] - 1
    for pair in range(num_pairs):
        upper_row = config_table[pair]
        lower_row = config_table[pair + 1]
        # A positive result is treated as "upper row sorts after lower row".
        if compare_configs(upper_row, lower_row) > 0:
            return False
    return True
[docs]
@njit(cache=True, parallel=True)
def subenv_map_to_unique_indices(
    subsystem_configs: np.ndarray,
    environment_configs: np.ndarray,
    unique_subsys_configs: np.ndarray,
    unique_env_configs: np.ndarray,
):
    """Look up every subsystem/environment row in its unique-row table.

    Parameters
    ----------
    subsystem_configs, environment_configs : np.ndarray
        Row-wise configurations. Both are indexed with the same row index,
        which runs over ``subsystem_configs.shape[0]`` rows.
    unique_subsys_configs, unique_env_configs : np.ndarray
        Lookup tables searched for the corresponding rows.

    Returns
    -------
    subsys_map, env_map : np.ndarray
        Two ``int64`` arrays, one entry per input row, holding whatever the
        lookup helper returned for that row (presumably the matching table
        row index; confirm against ``edlgt.tools.config_encoding``).
    """
    num_states = subsystem_configs.shape[0]
    # Choose the search strategy once per table, before the parallel loop:
    # binary search for a sorted table, linear search otherwise.
    use_bisect_subsys = _is_sorted_config_table(unique_subsys_configs)
    use_bisect_env = _is_sorted_config_table(unique_env_configs)
    subsys_map = np.empty(num_states, dtype=np.int64)
    env_map = np.empty(num_states, dtype=np.int64)
    for state in prange(num_states):
        subsys_row = subsystem_configs[state]
        env_row = environment_configs[state]
        if use_bisect_subsys:
            subsys_map[state] = config_to_index_binarysearch(
                subsys_row, unique_subsys_configs
            )
        else:
            subsys_map[state] = config_to_index_linsearch(
                subsys_row, unique_subsys_configs
            )
        if use_bisect_env:
            env_map[state] = config_to_index_binarysearch(
                env_row, unique_env_configs
            )
        else:
            env_map[state] = config_to_index_linsearch(
                env_row, unique_env_configs
            )
    return subsys_map, env_map
[docs]
def can_encode_partition_configs(loc_dims: np.ndarray) -> bool:
    """Report whether the product of the local dimensions fits in int64.

    The running product is accumulated with Python integers, so the check
    itself cannot overflow.

    Parameters
    ----------
    loc_dims : array-like of int
        Local dimensions; converted to an ``int64`` array before the scan.

    Returns
    -------
    bool
        ``True`` when the full product stays within the signed 64-bit
        maximum, ``False`` as soon as one more factor would exceed it.

    Raises
    ------
    ValueError
        If a non-positive dimension is reached before the scan stops.
    """
    limit = np.iinfo(np.int64).max
    capacity = 1
    for raw_dim in np.asarray(loc_dims, dtype=np.int64):
        if raw_dim <= 0:
            raise ValueError("local dimensions must be positive")
        factor = int(raw_dim)
        # Exact arbitrary-precision product: no wrap-around to worry about.
        if capacity * factor > limit:
            return False
        capacity = capacity * factor
    return True
[docs]
def unique_configs_with_inverse(
    configs: np.ndarray, loc_dims: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Get unique rows and inverse map, preferring packed-key uniqueness.

    Parameters
    ----------
    configs : np.ndarray
        2-D table with one configuration per row.
    loc_dims : np.ndarray
        Local dimensions used to decide whether rows can be packed into
        integer keys and, if so, to compute the packing strides.

    Returns
    -------
    unique_configs : np.ndarray
        The distinct rows of ``configs``.
    inverse_map : np.ndarray
        1-D integer array with one entry per input row, giving the index of
        that row in ``unique_configs``.
    """
    if configs.shape[1] == 0:
        # Zero-width rows are all identical: a single empty unique row, and
        # every input row maps to it.
        unique_configs = np.empty((1, 0), dtype=configs.dtype)
        inverse_map = np.zeros(configs.shape[0], dtype=np.int64)
        return unique_configs, inverse_map
    if can_encode_partition_configs(loc_dims):
        # Fast path: deduplicate packed integer keys instead of whole rows.
        strides = compute_strides(np.asarray(loc_dims, dtype=np.int64))
        encoded_configs = encode_all_configs(configs, strides)
        _, unique_indices, inverse_map = np.unique(
            encoded_configs, return_index=True, return_inverse=True
        )
        unique_configs = np.ascontiguousarray(configs[unique_indices])
        return unique_configs, inverse_map
    # Fallback: row-wise uniqueness. NumPy 2.0.0 returned a non-1-D inverse
    # when ``axis`` is given, so flatten it to keep the result 1-D on every
    # NumPy version.
    unique_configs, inverse_map = np.unique(configs, axis=0, return_inverse=True)
    return unique_configs, inverse_map.reshape(-1)