Source code for bayes_hdc.utils
# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh
"""Utility functions for Bayes-HDC."""
import time
from typing import Any, Callable, Union
import jax
import jax.numpy as jnp
[docs]
def normalize(x: jax.Array, axis: int = -1, eps: float = 1e-8) -> jax.Array:
"""Normalize vectors to unit length.
Args:
x: Input array
axis: Axis along which to normalize (default: -1)
eps: Small constant to avoid division by zero
"""
norm = jnp.linalg.norm(x, axis=axis, keepdims=True)
return x / (norm + eps)
[docs]
def benchmark_function(
fn: Callable[..., Any], *args: Any, num_trials: int = 100, warmup: int = 10, **kwargs: Any
) -> dict[str, Union[float, int]]:
"""Benchmark a JAX function with proper warmup and async handling.
Args:
fn: Function to benchmark
*args: Positional arguments to fn
num_trials: Number of trials to run
warmup: Number of warmup trials
**kwargs: Keyword arguments to fn
Returns:
Dictionary with timing statistics (mean, std, min, max, median in ms)
"""
for _ in range(warmup):
result = fn(*args, **kwargs)
jax.block_until_ready(result)
times_list: list[float] = []
for _ in range(num_trials):
start = time.time()
result = fn(*args, **kwargs)
jax.block_until_ready(result)
end = time.time()
times_list.append((end - start) * 1000)
times = jnp.array(times_list)
return {
"mean_ms": float(jnp.mean(times)),
"std_ms": float(jnp.std(times)),
"min_ms": float(jnp.min(times)),
"max_ms": float(jnp.max(times)),
"median_ms": float(jnp.median(times)),
"num_trials": num_trials,
}
__all__ = [
"normalize",
"benchmark_function",
]