Source code for bayes_hdc.functional

# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh

"""Core functional operations for Hyperdimensional Computing.

This module provides the fundamental operations for manipulating hypervectors:
- Binding: Combines two hypervectors into a dissimilar result
- Bundling: Aggregates multiple hypervectors into a similar result
- Permutation: Reorders elements to encode sequences
- Similarity: Measures relatedness between hypervectors
- Composite: Multi-vector bind/bundle, n-grams, sequences, hash tables
"""

import functools
from typing import Callable, Optional, Union

import jax
import jax.numpy as jnp

from bayes_hdc.constants import EPS


@jax.jit
def bind_bsc(x: jax.Array, y: jax.Array) -> jax.Array:
    """Bind two hypervectors using XOR for Binary Spatter Codes.

    Binding creates a new hypervector that is dissimilar to both inputs.
    XOR is its own inverse, so unbinding is identical to binding.

    Args:
        x: Binary hypervector of shape (..., d)
        y: Binary hypervector of shape (..., d)

    Returns:
        Bound hypervector of shape (..., d), dissimilar to both x and y

    References:
        Kanerva, P. (1997). Fully Distributed Representation.
        In Proc. RWC '97, pp. 358-365.
    """
    return jnp.logical_xor(x, y)
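

# Illustrative usage sketch (hypothetical _demo_* helper, not part of the
# module API; d = 10_000 is an assumed dimension). It exercises the
# self-inverse property: XOR-binding the bound pair with y recovers x.
def _demo_bind_bsc() -> None:
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.bernoulli(key_x, 0.5, (10_000,))
    y = jax.random.bernoulli(key_y, 0.5, (10_000,))
    z = bind_bsc(x, y)
    # Unbinding is the same operation: XOR with y undoes XOR with y.
    assert bool(jnp.all(bind_bsc(z, y) == x))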


def bundle_bsc(
    vectors: jax.Array,
    axis: int = 0,
    key: Optional[jax.Array] = None,
) -> jax.Array:
    """Bundle hypervectors using majority rule for Binary Spatter Codes.

    Bundling creates a new hypervector similar to all inputs by taking the
    majority vote at each dimension. For an even number of input vectors,
    ties at exactly half the count are broken according to ``key``:

    * ``key is None`` (default): deterministic; ties map to ``False``. This
      is fast, JIT-friendly, and matches the historical behaviour of this
      library.
    * ``key`` is a ``jax.random.PRNGKey``: stochastic; ties are broken by
      independent fair coin flips per component, matching Kanerva (1997)'s
      prescription for an unbiased majority rule under even input counts.

    For an odd number of input vectors there is no tie and ``key`` has no
    effect.

    Args:
        vectors: Binary hypervectors; ``axis`` indexes the vectors to bundle
        axis: Axis along which to bundle (default: 0)
        key: Optional ``jax.random.PRNGKey``. When provided, ties are broken
            by stochastic coin flip; when ``None``, ties map
            deterministically to ``False``.

    Returns:
        Bundled hypervector, similar to all inputs

    References:
        Kanerva, P. (1997). Fully Distributed Representation.
        In Proc. RWC '97, pp. 358-365.
    """
    counts = jnp.sum(vectors, axis=axis)
    shape_size = vectors.shape[axis]
    threshold = shape_size / 2.0
    if key is None:
        return counts > threshold
    # Stochastic tie-break: where counts == threshold, return a random bit.
    is_tie = counts == threshold
    coin = jax.random.bernoulli(key, p=0.5, shape=counts.shape)
    return jnp.where(is_tie, coin, counts > threshold)


@jax.jit
def inverse_bsc(x: jax.Array) -> jax.Array:
    """Compute inverse for BSC (identity since XOR is self-inverse)."""
    return x  # XOR is self-inverse


@jax.jit
def hamming_similarity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Compute normalized Hamming similarity between binary hypervectors.

    Returns the fraction of matching bits between two binary vectors.
    Random vectors have similarity ≈ 0.5.

    Args:
        x: Binary hypervector of shape (..., d)
        y: Binary hypervector of shape (..., d)

    Returns:
        Similarity score in [0, 1], where 1 is identical and 0.5 is random
    """
    matches = jnp.logical_not(jnp.logical_xor(x, y))
    return jnp.mean(matches.astype(jnp.float32), axis=-1)
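

# Illustrative usage sketch (hypothetical _demo_* helper; shapes assumed).
# Bundling three random vectors by majority yields a vector that matches
# each input on ~75% of bits, well above the ~50% chance level for
# unrelated vectors.
def _demo_bundle_and_similarity() -> None:
    vecs = jax.random.bernoulli(jax.random.PRNGKey(0), 0.5, (3, 10_000))
    bundle = bundle_bsc(vecs)  # odd count, so no ties and no key needed
    sims = jax.vmap(hamming_similarity, in_axes=(0, None))(vecs, bundle)
    assert bool(jnp.all(sims > 0.6))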


@jax.jit
def bind_map(x: jax.Array, y: jax.Array) -> jax.Array:
    """Bind two hypervectors using element-wise multiplication for MAP.

    For real-valued vectors (MAP model), binding is element-wise
    multiplication. The result is dissimilar to both inputs.

    Args:
        x: Real-valued hypervector of shape (..., d)
        y: Real-valued hypervector of shape (..., d)

    Returns:
        Bound hypervector of shape (..., d)
    """
    return x * y


def bundle_map(vectors: jax.Array, axis: int = 0) -> jax.Array:
    """Bundle hypervectors using normalized sum for MAP.

    For real-valued vectors, bundling is the normalized sum. The result is
    similar to all inputs (high cosine similarity).

    Args:
        vectors: Real-valued hypervectors; ``axis`` indexes the vectors to
            bundle
        axis: Axis along which to bundle (default: 0)

    Returns:
        Bundled and normalized hypervector
    """
    summed = jnp.sum(vectors, axis=axis)
    norm = jnp.linalg.norm(summed, axis=-1, keepdims=True)
    return summed / (norm + EPS)


@jax.jit
def inverse_map(x: jax.Array, eps: float = EPS) -> jax.Array:
    """Compute inverse for MAP using element-wise reciprocal.

    For MAP binding (element-wise multiplication), the inverse is
    element-wise reciprocal: bind(bind(x, y), inverse(y)) = x. Near-zero
    elements return 0 (no inverse; bind with 0 destroys information).

    Args:
        x: Real-valued hypervector of shape (..., d)
        eps: Small constant for numerical stability (default: EPS)

    Returns:
        Inverse hypervector
    """
    safe_inv = jnp.where(jnp.abs(x) > eps, 1.0 / x, 0.0)
    return safe_inv
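

# Illustrative usage sketch (hypothetical _demo_* helper; shapes assumed).
# For bipolar vectors (entries in {-1, +1}) the element-wise reciprocal is
# the vector itself, so bind followed by inverse-bind recovers x exactly.
def _demo_map_roundtrip() -> None:
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jnp.sign(jax.random.normal(key_x, (10_000,)))
    y = jnp.sign(jax.random.normal(key_y, (10_000,)))
    recovered = bind_map(bind_map(x, y), inverse_map(y))
    assert bool(jnp.allclose(recovered, x))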


def vector_intersect(
    x: jax.Array,
    y: jax.Array,
    atoms: jax.Array,
) -> jax.Array:
    r"""Holistic vector intersection (Gayler & Levy 2009).

    Given two bundle hypervectors ``x`` and ``y`` and a known atom set
    ``atoms`` of shape ``(N, d)``, return a bundle hypervector that contains
    *only* the atoms present in both ``x`` and ``y``, weighted by their
    joint projection.

    For each atom ``a_i`` we compute
    ``s_i = max(cos(x, a_i), 0) * max(cos(y, a_i), 0)``, a non-negative
    joint-membership weight that is large when ``a_i`` is similar to both
    inputs and zero when it is absent from either. The output is
    :math:`\sum_i s_i \cdot a_i`.

    This is the explicit-atom-set realisation of the cleanup-memory
    construction in Gayler & Levy (2009, §"Distributed Implementation",
    Figure 2). It is a soft logical AND on bundles, the primitive that
    makes VSA-based graph isomorphism / analogical mapping possible.

    Args:
        x: First bundle hypervector of shape ``(d,)``.
        y: Second bundle hypervector of shape ``(d,)``.
        atoms: Known atom set of shape ``(N, d)``. The intersection is
            constrained to lie in the span of these atoms.

    Returns:
        Hypervector of shape ``(d,)`` representing ``x ∧ y``: the atoms
        shared by both bundles, weighted by joint similarity.

    References:
        Gayler, R. W., Levy, S. D. (2009). A Distributed Basis for
        Analogical Mapping. In Proc. 2nd Int. Conf. on Analogy
        (ANALOGY-2009), pp. 165-174.
    """
    # Per-atom cosine to each input.
    sim_x = jax.vmap(lambda a: cosine_similarity(x, a))(atoms)  # (N,)
    sim_y = jax.vmap(lambda a: cosine_similarity(y, a))(atoms)  # (N,)
    # Joint membership: clipped at 0, multiplied. Atoms absent from either
    # input contribute nothing to the result.
    joint = jnp.maximum(sim_x, 0.0) * jnp.maximum(sim_y, 0.0)  # (N,)
    # Bundle of atoms weighted by joint membership.
    return jnp.sum(atoms * joint[:, None], axis=0)


@jax.jit
def transformation_vector(a: jax.Array, b: jax.Array) -> jax.Array:
    r"""Construct a transformation hypervector :math:`T = a^{-1} \star b`.

    The transformation vector encodes "the rule that maps :math:`a` to
    :math:`b`": applying it via :func:`bind_map` recovers :math:`b` from
    :math:`a`, and bundles of transformation vectors over many example
    pairs serve as a learned rule under the Rasmussen-Eliasmith (2011)
    inductive-reasoning recipe and Kanerva's (2010) "Dollar of Mexico"
    analogical-mapping construction.

    For Binary Spatter Codes (where XOR is self-inverse and ``inverse_bsc``
    is the identity) this collapses to ``bind_bsc(a, b)``; for MAP it is
    ``bind_map(inverse_map(a), b)``.

    Args:
        a: Source hypervector of shape ``(..., d)``.
        b: Target hypervector of shape ``(..., d)``.

    Returns:
        The transformation hypervector ``inverse(a) * b`` (real-valued MAP
        convention; element-wise reciprocal of ``a`` followed by
        element-wise product with ``b``).

    References:
        Kanerva, P. (2010). What We Mean When We Say "What's the Dollar of
        Mexico?": Prototypes and Mapping in Concept Space.
        Rasmussen, D., Eliasmith, C. (2011). A Neural Model of Rule
        Generation in Inductive Reasoning. Topics in Cognitive Science
        3(1): 140-153.
    """
    return bind_map(inverse_map(a), b)
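

# Illustrative usage sketch (hypothetical _demo_* helper; shapes assumed).
# The transformation vector T = inverse(a) * b applied to the source
# recovers the target; recovery is exact here because bipolar entries have
# exact reciprocals. Bundling such T's over many (a, b) pairs gives the
# learned-rule construction referenced in the docstring above.
def _demo_transformation_vector() -> None:
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jnp.sign(jax.random.normal(key_a, (10_000,)))
    b = jnp.sign(jax.random.normal(key_b, (10_000,)))
    t = transformation_vector(a, b)
    assert bool(jnp.allclose(bind_map(t, a), b))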


@jax.jit
def cosine_similarity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Compute cosine similarity between real-valued hypervectors.

    Returns the cosine of the angle between two vectors. Random unit
    vectors have similarity ≈ 0.

    Args:
        x: Real-valued hypervector of shape (..., d)
        y: Real-valued hypervector of shape (..., d)

    Returns:
        Similarity score in [-1, 1], where 1 is identical, -1 is opposite,
        and 0 is orthogonal
    """
    x_norm = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + EPS)
    y_norm = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + EPS)
    return jnp.clip(jnp.sum(x_norm * y_norm, axis=-1), -1.0, 1.0)


@jax.jit
def permute(x: jax.Array, shifts: int = 1) -> jax.Array:
    """Cyclically permute a hypervector to encode sequence information.

    Permutation reorders elements to represent positional or sequential
    information. Cyclic shifts preserve the distribution of values.

    Args:
        x: Hypervector of shape (..., d)
        shifts: Number of positions to shift (default: 1)

    Returns:
        Permuted hypervector of shape (..., d)
    """
    return jnp.roll(x, shifts, axis=-1)


@functools.partial(jax.jit, static_argnames=("similarity_fn", "return_similarity"))
def cleanup(
    query: jax.Array,
    memory: jax.Array,
    similarity_fn: Callable[[jax.Array, jax.Array], jax.Array] = cosine_similarity,
    return_similarity: bool = False,
) -> Union[jax.Array, tuple[jax.Array, jax.Array]]:
    """Find the most similar vector in memory to the query.

    Cleanup is used to retrieve the closest known hypervector from memory,
    useful for error correction and symbol retrieval after a bind / unbind
    sequence has introduced approximation noise. This is the abstract-vector
    cleanup operation of Kanerva (2009); for the spiking-neuron
    implementation see Stewart, Tang & Eliasmith (2010, *Cognitive Systems
    Research* 12: 84-92), which is out of scope here. Resonator networks
    (Frady et al. 2020) are a related but distinct factorisation algorithm
    built on top of cleanup.

    Note that ``similarity_fn`` must be a static argument: JAX cannot trace
    a Python callable as an array input, so it is listed in
    ``static_argnames`` alongside ``return_similarity``.

    Args:
        query: Query hypervector of shape (..., d)
        memory: Memory hypervectors of shape (n, d)
        similarity_fn: Function to compute similarity
            (default: cosine_similarity)
        return_similarity: Whether to return similarity scores
            (default: False)

    Returns:
        Most similar vector from memory, or (vector, similarity) if
        return_similarity=True

    References:
        Kanerva, P. (2009). Hyperdimensional Computing: An Introduction.
        Cognitive Computation 1(2): 139-159.
    """
    similarities = jax.vmap(lambda m: similarity_fn(query, m))(memory)
    best_idx = jnp.argmax(similarities)
    best_vector = memory[best_idx]
    if return_similarity:
        return best_vector, similarities[best_idx]
    return best_vector
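

# Illustrative usage sketch (hypothetical _demo_* helper; shapes and noise
# level assumed). Unbinding a noisy trace yields only an approximation of
# the stored symbol; cleanup snaps it back to the nearest codevector.
def _demo_cleanup() -> None:
    d = 10_000
    keys = jax.random.split(jax.random.PRNGKey(0), 3)
    memory = jnp.sign(jax.random.normal(keys[0], (4, d)))  # 4 known symbols
    role = jnp.sign(jax.random.normal(keys[1], (d,)))
    trace = bind_map(role, memory[2]) + 0.5 * jax.random.normal(keys[2], (d,))
    noisy = bind_map(inverse_map(role), trace)  # approximately memory[2]
    restored, sim = cleanup(noisy, memory, return_similarity=True)
    assert bool(jnp.allclose(restored, memory[2])) and float(sim) > 0.5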


# Batch versions for common operations
batch_bind_bsc = jax.vmap(bind_bsc, in_axes=(0, 0))
batch_bind_map = jax.vmap(bind_map, in_axes=(0, 0))
batch_hamming_similarity = jax.vmap(hamming_similarity, in_axes=(0, None))
batch_cosine_similarity = jax.vmap(cosine_similarity, in_axes=(0, None))


@jax.jit
def bind_hrr(x: jax.Array, y: jax.Array) -> jax.Array:
    """Bind two hypervectors using circular convolution for HRR.

    Circular convolution in the spatial domain is equivalent to
    element-wise multiplication in the Fourier domain, making it efficient
    to compute via the FFT. This is the canonical HRR binding of Plate
    (1995, 1994/2003); Jones & Mewhort (2007) is the canonical
    cognitive-science application (the BEAGLE composite holographic
    lexicon).

    Args:
        x: Real-valued hypervector of shape (..., d)
        y: Real-valued hypervector of shape (..., d)

    Returns:
        Bound hypervector via circular convolution

    References:
        Plate, T. A. (1995). Holographic Reduced Representations. IEEE
        Transactions on Neural Networks 6(3): 623-641.
        Plate, T. A. (2003). Holographic Reduced Representation:
        Distributed Representation for Cognitive Structures. CSLI
        Publications.
        Jones, M. N., Mewhort, D. J. K. (2007). Representing Word Meaning
        and Order Information in a Composite Holographic Lexicon.
        Psychological Review 114(1): 1-37.
    """
    x_fft = jnp.fft.fft(x, axis=-1)
    y_fft = jnp.fft.fft(y, axis=-1)
    result_fft = x_fft * y_fft
    return jnp.fft.ifft(result_fft, axis=-1).real


@jax.jit
def inverse_hrr(x: jax.Array) -> jax.Array:
    """Approximate inverse for HRR: Plate's *involution* ``x*``.

    The involution is the vector ``x*`` with ``(x*)_i = x_{(-i) mod d}``,
    i.e. the first element is preserved and the remaining ``d - 1``
    elements are reversed. For a length-4 example ``[c_0, c_1, c_2, c_3]``
    this returns ``[c_0, c_3, c_2, c_1]``, matching Plate (1995, §II.F)
    verbatim.

    The involution is an *approximate* inverse:
    ``bind_hrr(x, inverse_hrr(x))`` is approximately the unit impulse, with
    the approximation tightening as the dimension grows. It differs from
    the *exact* inverse ``F^{-1}(1 / F(x))``, which exists only when no
    Fourier coefficient of ``x`` vanishes; the library uses the involution
    because it is cheap, always defined, and the standard choice in HRR
    libraries.

    Args:
        x: Real-valued hypervector of shape (..., d)

    Returns:
        The involution ``x*``, shape (..., d).

    References:
        Plate, T. A. (1995). Holographic Reduced Representations. IEEE
        Transactions on Neural Networks 6(3): 623-641. (See §II.F for the
        involution definition.)
    """
    return jnp.concatenate([x[..., :1], jnp.flip(x[..., 1:], axis=-1)], axis=-1)
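

# Illustrative usage sketch (hypothetical _demo_* helper; d = 2048 and the
# N(0, 1/d) initialisation follow Plate's convention but are assumptions
# here). Convolving with the involution only approximately undoes binding,
# so the decoded vector is merely close to y, not equal to it.
def _demo_hrr_roundtrip() -> None:
    d = 2048
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (d,)) / jnp.sqrt(d)
    y = jax.random.normal(key_y, (d,)) / jnp.sqrt(d)
    noisy_y = bind_hrr(bind_hrr(x, y), inverse_hrr(x))
    assert float(cosine_similarity(noisy_y, y)) > 0.5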


# HRR bundle reuses MAP bundle
bundle_hrr = bundle_map


@jax.jit
def bind_cgr(x: jax.Array, y: jax.Array, q: int) -> jax.Array:
    """Bind using modular addition for Cyclic Group Representation.

    Args:
        x: Integer hypervector with values in {0, ..., q-1}, shape (..., d)
        y: Integer hypervector with values in {0, ..., q-1}, shape (..., d)
        q: Size of the cyclic group

    Returns:
        Bound hypervector: (x + y) mod q
    """
    return (x + y) % q


def bundle_cgr(vectors: jax.Array, q: int, axis: int = 0) -> jax.Array:
    """Bundle using component-wise mode for CGR.

    Selects the most frequent value at each dimension.

    Args:
        vectors: Integer hypervectors with values in {0, ..., q-1}
        q: Size of the cyclic group
        axis: Axis along which to bundle (default: 0)

    Returns:
        Bundled hypervector with mode value at each dimension
    """
    one_hot = jax.nn.one_hot(vectors, q)
    counts = jnp.sum(one_hot, axis=axis)
    return jnp.argmax(counts, axis=-1).astype(jnp.int32)


@jax.jit
def inverse_cgr(x: jax.Array, q: int) -> jax.Array:
    """Inverse via modular negation for CGR.

    Args:
        x: Integer hypervector with values in {0, ..., q-1}
        q: Size of the cyclic group

    Returns:
        Inverse: (q - x) mod q
    """
    return (q - x) % q


@jax.jit
def matching_similarity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Fraction of matching elements between integer hypervectors.

    Random vectors with q levels have expected similarity 1/q.

    Args:
        x: Integer hypervector of shape (..., d)
        y: Integer hypervector of shape (..., d)

    Returns:
        Similarity in [0, 1]
    """
    return jnp.mean((x == y).astype(jnp.float32), axis=-1)
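

# Illustrative usage sketch (hypothetical _demo_* helper; q = 16 and the
# shapes are assumptions). Modular addition binds, modular negation unbinds
# exactly, and unrelated vectors match on roughly 1/q of their components.
def _demo_cgr_roundtrip() -> None:
    q, d = 16, 10_000
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.randint(key_x, (d,), 0, q)
    y = jax.random.randint(key_y, (d,), 0, q)
    z = bind_cgr(x, y, q)
    assert bool(jnp.all(bind_cgr(z, inverse_cgr(y, q), q) == x))
    assert abs(float(matching_similarity(x, y)) - 1.0 / q) < 0.02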


# MCR bind reuses CGR bind
bind_mcr = bind_cgr

# MCR inverse reuses CGR inverse
inverse_mcr = inverse_cgr


def bundle_mcr(vectors: jax.Array, q: int, axis: int = 0) -> jax.Array:
    """Bundle using phasor sum and snap-to-grid for MCR.

    Converts indices to complex phasors (q-th roots of unity), sums them,
    and snaps back to the nearest discrete phase level.

    Args:
        vectors: Integer hypervectors with values in {0, ..., q-1}
        q: Number of phase levels
        axis: Axis along which to bundle (default: 0)

    Returns:
        Bundled hypervector with values snapped to nearest phase index
    """
    phases = 2 * jnp.pi * vectors.astype(jnp.float32) / q
    phasors = jnp.exp(1j * phases)
    summed = jnp.sum(phasors, axis=axis)
    result_angles = jnp.angle(summed) % (2 * jnp.pi)
    result_indices = jnp.round(result_angles * q / (2 * jnp.pi)) % q
    return result_indices.astype(jnp.int32)


@jax.jit
def phasor_similarity(x: jax.Array, y: jax.Array, q: int) -> jax.Array:
    """Similarity via phasor inner product for MCR.

    Converts to phasors and computes the real part of the normalized inner
    product. Random vectors have expected similarity ~0.

    Args:
        x: Integer hypervector with values in {0, ..., q-1}
        y: Integer hypervector with values in {0, ..., q-1}
        q: Number of phase levels

    Returns:
        Similarity in [-1, 1]
    """
    phases_x = 2 * jnp.pi * x.astype(jnp.float32) / q
    phases_y = 2 * jnp.pi * y.astype(jnp.float32) / q
    phasors_x = jnp.exp(1j * phases_x)
    phasors_y = jnp.exp(1j * phases_y)
    return jnp.real(jnp.mean(phasors_x * jnp.conj(phasors_y), axis=-1))
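

# Illustrative usage sketch (hypothetical _demo_* helper; q = 16 and the
# shapes are assumptions). Phasor bundling keeps the result measurably
# similar to every input, which plain modular addition would not.
def _demo_mcr_bundle() -> None:
    q, d = 16, 10_000
    vecs = jax.random.randint(jax.random.PRNGKey(0), (5, d), 0, q)
    bundle = bundle_mcr(vecs, q)
    sims = jnp.array([phasor_similarity(v, bundle, q) for v in vecs])
    assert bool(jnp.all(sims > 0.2))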


@jax.jit
def bind_vtb(x: jax.Array, y: jax.Array) -> jax.Array:
    """Bind using matrix multiplication for VTB.

    Reshapes d-dimensional vectors into sqrt(d) x sqrt(d) matrices and
    multiplies them. Requires d to be a perfect square.

    Args:
        x: Real-valued hypervector of shape (..., d)
        y: Real-valued hypervector of shape (..., d)

    Returns:
        Bound hypervector of shape (..., d)
    """
    d = x.shape[-1]
    n = round(d**0.5)
    X = x.reshape(*x.shape[:-1], n, n)
    Y = y.reshape(*y.shape[:-1], n, n)
    return (X @ Y).reshape(*x.shape[:-1], d)


@jax.jit
def inverse_vtb(x: jax.Array) -> jax.Array:
    """Inverse via matrix pseudoinverse for VTB."""
    d = x.shape[-1]
    n = round(d**0.5)
    X = x.reshape(*x.shape[:-1], n, n)
    return jnp.linalg.pinv(X).reshape(*x.shape[:-1], d)
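

# Illustrative usage sketch (hypothetical _demo_* helper; d = 1024 = 32 * 32
# is an assumed perfect-square dimension). Because VTB binding is a matrix
# product, unbinding with the pseudoinverse recovers the second factor up
# to numerical error.
def _demo_vtb_roundtrip() -> None:
    d = 1024
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (d,))
    y = jax.random.normal(key_y, (d,))
    recovered = bind_vtb(inverse_vtb(x), bind_vtb(x, y))
    assert float(cosine_similarity(recovered, y)) > 0.99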


# VTB bundle reuses MAP bundle
bundle_vtb = bundle_map


# ---------------------------------------------------------------------------
# Similarity: dot product
# ---------------------------------------------------------------------------


@jax.jit
def dot_similarity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Compute dot product similarity between hypervectors.

    Args:
        x: Hypervector of shape (..., d)
        y: Hypervector of shape (..., d)

    Returns:
        Dot product (scalar or batch)
    """
    return jnp.sum(x * y, axis=-1)


# ---------------------------------------------------------------------------
# Negation (bundling inverse)
# ---------------------------------------------------------------------------


@jax.jit
def negative_bsc(x: jax.Array) -> jax.Array:
    """Bundling inverse for BSC (bit flip)."""
    return jnp.logical_not(x)


@jax.jit
def negative_map(x: jax.Array) -> jax.Array:
    """Bundling inverse for MAP (element-wise negation)."""
    return -x


# ---------------------------------------------------------------------------
# Multi-vector operations
# ---------------------------------------------------------------------------


def multibind_map(vectors: jax.Array, axis: int = 0) -> jax.Array:
    """Bind all vectors along an axis via element-wise product (MAP/HRR).

    Generalises :func:`bind_map` to *n* vectors.

    Args:
        vectors: Array with shape (n, ..., d)
        axis: Axis along which to reduce (default: 0)

    Returns:
        Single hypervector equal to vectors[0] * vectors[1] * ...
    """
    return jnp.prod(vectors, axis=axis)


def multibind_bsc(vectors: jax.Array, axis: int = 0) -> jax.Array:
    """Bind all vectors along an axis via cumulative XOR (BSC).

    XOR of n binary vectors: bit is 1 iff an odd number of inputs are 1.

    Args:
        vectors: Boolean array with shape (n, ..., d)
        axis: Axis along which to reduce (default: 0)

    Returns:
        XOR-reduction of all vectors along axis
    """
    counts = jnp.sum(vectors.astype(jnp.int32), axis=axis)
    return (counts % 2) == 1


def cross_product(
    set_a: jax.Array,
    set_b: jax.Array,
    bind_fn: Callable = bind_map,
) -> jax.Array:
    """Compute the cross product (all pairwise bindings) of two sets.

    Returns an array of shape (n, m, d) where element [i, j] is
    bind(set_a[i], set_b[j]).

    Args:
        set_a: First set of shape (n, d)
        set_b: Second set of shape (m, d)
        bind_fn: Binding function (default: bind_map)
    """
    return jax.vmap(lambda a: jax.vmap(lambda b: bind_fn(a, b))(set_b))(set_a)


# ---------------------------------------------------------------------------
# Composite encodings
# ---------------------------------------------------------------------------


def hash_table(
    keys: jax.Array,
    values: jax.Array,
    bind_fn: Callable = bind_map,
) -> jax.Array:
    """Create a hash-table hypervector by bundling bound (key, value) pairs.

    hash_table = Σ bind(k_i, v_i)

    Args:
        keys: Key hypervectors of shape (n, d)
        values: Value hypervectors of shape (n, d)
        bind_fn: Binding function (default: bind_map)

    Returns:
        Hash-table hypervector of shape (d,)
    """
    pairs = jax.vmap(bind_fn)(keys, values)
    return jnp.sum(pairs, axis=0)


def ngrams(
    vectors: jax.Array,
    n: int = 3,
    bind_fn: Callable = bind_map,
) -> jax.Array:
    """Compute the n-gram representation of a sequence of hypervectors.

    Each n-gram is the binding of n positionally-permuted consecutive
    vectors, then all n-grams are bundled (summed).

    Args:
        vectors: Sequence of shape (m, d)
        n: Size of each n-gram (default: 3)
        bind_fn: Binding function (default: bind_map)

    Returns:
        N-gram hypervector of shape (d,)
    """
    m = vectors.shape[0]
    if m < n:
        raise ValueError(f"Need at least {n} vectors for {n}-grams, got {m}")
    result = jnp.zeros(vectors.shape[-1])
    for start in range(m - n + 1):
        gram = permute(vectors[start], shifts=n - 1)
        for offset in range(1, n):
            gram = bind_fn(
                gram, permute(vectors[start + offset], shifts=n - 1 - offset)
            )
        result = result + gram
    return result
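

# Illustrative usage sketch (hypothetical _demo_* helper; shapes assumed).
# Probing a hash table with an inverted key returns the paired value plus
# cross-talk from the other pairs; cleanup against the value codebook
# removes the cross-talk.
def _demo_hash_table() -> None:
    d = 10_000
    key_k, key_v = jax.random.split(jax.random.PRNGKey(0))
    keys_hv = jnp.sign(jax.random.normal(key_k, (3, d)))
    values_hv = jnp.sign(jax.random.normal(key_v, (3, d)))
    table = hash_table(keys_hv, values_hv)
    noisy_value = bind_map(inverse_map(keys_hv[1]), table)
    assert bool(jnp.allclose(cleanup(noisy_value, values_hv), values_hv[1]))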


def bundle_sequence(vectors: jax.Array) -> jax.Array:
    """Encode a sequence by bundling position-permuted vectors.

    sequence = Σ permute(v_i, shifts=m-1-i)

    Preserves order information through positional permutation.

    Args:
        vectors: Sequence of shape (m, d)

    Returns:
        Sequence hypervector of shape (d,)
    """
    m = vectors.shape[0]
    result = jnp.zeros(vectors.shape[-1])
    for i in range(m):
        result = result + permute(vectors[i], shifts=m - 1 - i)
    return result


def bind_sequence(
    vectors: jax.Array,
    bind_fn: Callable = bind_map,
) -> jax.Array:
    """Encode a sequence by binding position-permuted vectors.

    sequence = Π permute(v_i, shifts=m-1-i)

    Binding-based sequences are more noise-resistant than bundle-based for
    short sequences, at the cost of lossy retrieval.

    Args:
        vectors: Sequence of shape (m, d)
        bind_fn: Binding function (default: bind_map)

    Returns:
        Sequence hypervector of shape (d,)
    """
    m = vectors.shape[0]
    result = permute(vectors[0], shifts=m - 1)
    for i in range(1, m):
        result = bind_fn(result, permute(vectors[i], shifts=m - 1 - i))
    return result


def graph_encode(
    edges: jax.Array,
    node_hvs: jax.Array,
    *,
    directed: bool = False,
    bind_fn: Callable = bind_map,
) -> jax.Array:
    """Encode a graph as a single hypervector.

    Each edge (u, v) is encoded as bind(node_hvs[u], permute(node_hvs[v]))
    for directed graphs, or bind(node_hvs[u], node_hvs[v]) for undirected.
    All edge encodings are bundled.

    Args:
        edges: Edge list of shape (num_edges, 2) with node indices
        node_hvs: Node hypervectors of shape (num_nodes, d)
        directed: Whether edges are directed (default: False)
        bind_fn: Binding function (default: bind_map)

    Returns:
        Graph hypervector of shape (d,)
    """
    def encode_edge(edge: jax.Array) -> jax.Array:
        u_hv = node_hvs[edge[0]]
        v_hv = node_hvs[edge[1]]
        if directed:
            return bind_fn(u_hv, permute(v_hv))
        return bind_fn(u_hv, v_hv)

    edge_hvs = jax.vmap(encode_edge)(edges)
    return jnp.sum(edge_hvs, axis=0)


def resonator(
    codebooks: list[jax.Array],
    target: jax.Array,
    *,
    max_iters: int = 100,
    bind_fn: Callable = bind_map,
) -> list[jax.Array]:
    """Resonator network for factorising a composite hypervector.

    Given codebooks C_1 ... C_k and a target = bind(f_1, ..., f_k),
    iteratively estimates each factor f_i. Uses early stopping when
    estimates converge.

    Args:
        codebooks: List of k codebooks, each of shape (n_i, d)
        target: Target hypervector of shape (d,)
        max_iters: Maximum iterations (default: 100)
        bind_fn: Binding function (default: bind_map)

    Returns:
        List of k estimated factor hypervectors
    """
    k = len(codebooks)
    estimates = [codebooks[i][0] for i in range(k)]
    for _ in range(max_iters):
        new_estimates = []
        for i in range(k):
            # Unbind the current estimates of every other factor from the
            # target, leaving a noisy estimate of factor i.
            other = target
            for j in range(k):
                if j != i:
                    inv = inverse_map(estimates[j])
                    other = bind_fn(other, inv)
            # Snap the noisy estimate to the closest codevector.
            sims = jax.vmap(lambda c: cosine_similarity(other, c))(codebooks[i])
            best = codebooks[i][jnp.argmax(sims)]
            new_estimates.append(best)
        converged = all(
            bool(jnp.allclose(new_estimates[i], estimates[i])) for i in range(k)
        )
        estimates = new_estimates
        if converged:
            break
    return estimates
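

# Illustrative usage sketch (hypothetical _demo_* helper; codebook sizes and
# factor indices are assumptions, and convergence is probabilistic rather
# than guaranteed). A product of one codevector per codebook is factorised
# back into its parts by alternating unbind-and-cleanup passes.
def _demo_resonator() -> None:
    d = 10_000
    books = [
        jnp.sign(jax.random.normal(jax.random.PRNGKey(i), (8, d)))
        for i in range(3)
    ]
    target = multibind_map(jnp.stack([books[0][2], books[1][5], books[2][1]]))
    f0, f1, f2 = resonator(books, target)
    assert bool(jnp.allclose(f0, books[0][2]))
    assert bool(jnp.allclose(f1, books[1][5]))
    assert bool(jnp.allclose(f2, books[2][1]))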


# ---------------------------------------------------------------------------
# Additional similarity metrics (inspired by PyBHV)
# ---------------------------------------------------------------------------


@jax.jit
def jaccard_similarity(x: jax.Array, y: jax.Array) -> jax.Array:
    """Jaccard similarity between binary hypervectors.

    |x AND y| / |x OR y|. Returns 1 for identical vectors, ~0.33 for random.

    Args:
        x: Binary hypervector of shape (..., d)
        y: Binary hypervector of shape (..., d)

    Returns:
        Jaccard index in [0, 1]
    """
    intersection = jnp.sum(jnp.logical_and(x, y).astype(jnp.float32), axis=-1)
    union = jnp.sum(jnp.logical_or(x, y).astype(jnp.float32), axis=-1)
    return intersection / (union + EPS)


@jax.jit
def tversky_similarity(
    x: jax.Array,
    y: jax.Array,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> jax.Array:
    """Tversky similarity index between binary hypervectors.

    Generalises Jaccard (alpha=beta=1) and Dice (alpha=beta=0.5).

    Args:
        x: Binary prototype hypervector of shape (..., d)
        y: Binary variant hypervector of shape (..., d)
        alpha: Weight for x-only elements (default: 1.0)
        beta: Weight for y-only elements (default: 1.0)

    Returns:
        Tversky index in [0, 1]
    """
    x_f = x.astype(jnp.float32)
    y_f = y.astype(jnp.float32)
    intersection = jnp.sum(x_f * y_f, axis=-1)
    x_only = jnp.sum(x_f * (1 - y_f), axis=-1)
    y_only = jnp.sum((1 - x_f) * y_f, axis=-1)
    return intersection / (intersection + alpha * x_only + beta * y_only + EPS)


# ---------------------------------------------------------------------------
# Selection and threshold operations (inspired by PyBHV)
# ---------------------------------------------------------------------------


@jax.jit
def select_bsc(cond: jax.Array, when_true: jax.Array, when_false: jax.Array) -> jax.Array:
    """Element-wise MUX for binary hypervectors.

    Returns when_true where cond is True, when_false otherwise.

    Args:
        cond: Binary mask of shape (..., d)
        when_true: Binary hypervector returned where cond is True
        when_false: Binary hypervector returned where cond is False
    """
    return jnp.where(cond, when_true, when_false)


@jax.jit
def select_map(cond: jax.Array, when_pos: jax.Array, when_neg: jax.Array) -> jax.Array:
    """Element-wise MUX for real-valued hypervectors.

    Selects when_pos where cond > 0, when_neg otherwise.

    Args:
        cond: Real-valued mask of shape (..., d)
        when_pos: Hypervector returned where cond > 0
        when_neg: Hypervector returned where cond <= 0
    """
    return jnp.where(cond > 0, when_pos, when_neg)


def threshold(vectors: jax.Array, t: int) -> jax.Array:
    """Generalised majority: bit is 1 when at least *t* of the input
    vectors have a 1.

    Equivalent to standard majority when ``t = n // 2 + 1`` (for odd n).

    Args:
        vectors: Binary hypervectors of shape (n, d)
        t: Minimum count for a bit to be set

    Returns:
        Binary hypervector of shape (d,)
    """
    counts = jnp.sum(vectors.astype(jnp.int32), axis=0)
    return counts >= t


def window(vectors: jax.Array, lo: int, hi: int) -> jax.Array:
    """Window vote: bit is 1 when the count of 1s is in [lo, hi].

    Useful for agreement / disagreement filters.

    Args:
        vectors: Binary hypervectors of shape (n, d)
        lo: Minimum count (inclusive)
        hi: Maximum count (inclusive)

    Returns:
        Binary hypervector of shape (d,)
    """
    counts = jnp.sum(vectors.astype(jnp.int32), axis=0)
    return (counts >= lo) & (counts <= hi)


# ---------------------------------------------------------------------------
# Noise injection
# ---------------------------------------------------------------------------


def flip_fraction(key: jax.Array, x: jax.Array, fraction: float = 0.1) -> jax.Array:
    """Randomly flip a fraction of bits in a binary hypervector.

    Useful for generating noisy variants at a controlled Hamming distance.

    Args:
        key: JAX PRNG key
        x: Binary hypervector of shape (..., d)
        fraction: Fraction of bits to flip, in [0, 1]

    Returns:
        Noisy binary hypervector
    """
    mask = jax.random.bernoulli(key, fraction, shape=x.shape)
    return jnp.logical_xor(x, mask)


def add_noise_map(key: jax.Array, x: jax.Array, noise_level: float = 0.1) -> jax.Array:
    """Add Gaussian noise to a real-valued hypervector and re-normalise.

    Args:
        key: JAX PRNG key
        x: Real-valued hypervector of shape (..., d)
        noise_level: Standard deviation of the noise

    Returns:
        Noisy normalised hypervector
    """
    noisy = x + noise_level * jax.random.normal(key, shape=x.shape)
    norm = jnp.linalg.norm(noisy, axis=-1, keepdims=True)
    return noisy / (norm + EPS)


# ---------------------------------------------------------------------------
# Power and quantisation
# ---------------------------------------------------------------------------


@jax.jit
def fractional_power(x: jax.Array, p: float) -> jax.Array:
    """Raise a MAP hypervector to a fractional power.

    Computes sign(x) * |x|^p element-wise. This smoothly interpolates
    between the hard-quantised sign vector sign(x) (p -> 0) and x itself
    (p = 1), and can extrapolate beyond (p > 1). Widely used for encoding
    continuous attributes: bind(role, fractional_power(filler, value))
    produces representations that vary smoothly with *value*.

    Args:
        x: Real-valued hypervector of shape (..., d)
        p: Exponent (typically in [0, 2])

    Returns:
        Hypervector of shape (..., d)
    """
    return jnp.sign(x) * jnp.abs(x) ** p
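

# Illustrative usage sketch (hypothetical _demo_* helper; the input values
# are arbitrary). On bipolar vectors fractional_power is a no-op (|±1|^p is
# always 1), so the effect shows on non-unit magnitudes: the exponent pulls
# each component's magnitude toward 1 for p < 1 and away from 1 for p > 1,
# preserving its sign and reaching sign(x) at p = 0.
def _demo_fractional_power() -> None:
    x = jnp.array([-4.0, -0.25, 0.25, 4.0])
    assert bool(jnp.allclose(fractional_power(x, 1.0), x))
    assert bool(jnp.allclose(fractional_power(x, 0.5),
                             jnp.array([-2.0, -0.5, 0.5, 2.0])))
    assert bool(jnp.allclose(fractional_power(x, 0.0), jnp.sign(x)))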


@jax.jit
def soft_quantize(x: jax.Array) -> jax.Array:
    """Apply tanh for soft bipolar quantisation."""
    return jnp.tanh(x)


@jax.jit
def hard_quantize(x: jax.Array) -> jax.Array:
    """Snap each element to +1 or -1 (sign function, 0 maps to -1)."""
    return jnp.where(x > 0, 1.0, -1.0)


__all__ = [
    # BSC operations
    "bind_bsc",
    "bundle_bsc",
    "inverse_bsc",
    "negative_bsc",
    "hamming_similarity",
    # MAP operations
    "bind_map",
    "bundle_map",
    "inverse_map",
    "negative_map",
    "cosine_similarity",
    "dot_similarity",
    # HRR operations
    "bind_hrr",
    "bundle_hrr",
    "inverse_hrr",
    # CGR operations
    "bind_cgr",
    "bundle_cgr",
    "inverse_cgr",
    "matching_similarity",
    # MCR operations
    "bind_mcr",
    "bundle_mcr",
    "inverse_mcr",
    "phasor_similarity",
    # VTB operations
    "bind_vtb",
    "bundle_vtb",
    "inverse_vtb",
    # Universal operations
    "permute",
    "cleanup",
    # Multi-vector operations
    "multibind_map",
    "multibind_bsc",
    "cross_product",
    # Composite encodings
    "hash_table",
    "ngrams",
    "bundle_sequence",
    "bind_sequence",
    "graph_encode",
    "resonator",
    # Additional similarity metrics
    "jaccard_similarity",
    "tversky_similarity",
    # Selection and threshold
    "select_bsc",
    "select_map",
    "threshold",
    "window",
    # Noise injection
    "flip_fraction",
    "add_noise_map",
    # Power and quantisation
    "fractional_power",
    "soft_quantize",
    "hard_quantize",
    # Batch operations
    "batch_bind_bsc",
    "batch_bind_map",
    "batch_hamming_similarity",
    "batch_cosine_similarity",
]