# Source code for bayes_hdc.vsa

# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh

"""Vector Symbolic Architecture (VSA) model implementations.

This module provides different VSA models, each with their own binding,
bundling, and similarity operations. All models follow a consistent API.

The term *Vector Symbolic Architecture* was coined by Gayler (2003) as
an umbrella for the family of fixed-dimensional algebraic models (HRR,
BSC, MAP, ...) that descend from Smolensky's tensor-product binding
(Smolensky 1990) but compress it to a fixed dimension.

References:
Gayler, R. W. (2003). Vector Symbolic Architectures answer Jackendoff's
challenges for cognitive neuroscience. In Proc. ICCS/ASCS-2003,
pp. 133-138. arXiv:cs/0412059.
Kanerva, P. (2009). Hyperdimensional Computing: An Introduction to
Computing in Distributed Representation with High-Dimensional Random
Vectors. Cognitive Computation 1(2): 139-159.
Kleyko, D., Rachkovskij, D. A., Osipov, E., Rahimi, A. (2023).
A Survey on Hyperdimensional Computing aka Vector Symbolic
Architectures, Part I. ACM Computing Surveys 55(6): Article 130.
"""

from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

from bayes_hdc import functional as F
from bayes_hdc._compat import register_dataclass
from bayes_hdc.constants import EPS


@register_dataclass
@dataclass
class VSAModel:
    """Base class for VSA models defining the interface.

    Subclasses implement the four core VSA operations (bind, bundle,
    inverse, similarity) plus random hypervector generation, all
    operating on ``jax.Array`` hypervectors. The class is registered as
    a dataclass pytree (via ``register_dataclass`` — presumably wrapping
    ``jax.tree_util.register_dataclass``; confirm in ``_compat``) so
    model instances can be passed through ``jax.jit``; both fields are
    marked static (non-traced) metadata.
    """

    # Short model identifier, e.g. "bsc" or "map"; static pytree metadata.
    name: str = field(metadata=dict(static=True))
    # Hypervector dimensionality; static pytree metadata.
    dimensions: int = field(metadata=dict(static=True))

    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind two hypervectors."""
        raise NotImplementedError

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle multiple hypervectors."""
        raise NotImplementedError

    def inverse(self, x: jax.Array) -> jax.Array:
        """Compute the inverse of a hypervector."""
        raise NotImplementedError

    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute similarity between hypervectors."""
        raise NotImplementedError

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random hypervectors."""
        raise NotImplementedError
@register_dataclass
@dataclass
class BSC(VSAModel):
    """Binary Spatter Codes (BSC).

    Binary hypervectors with XOR binding, majority bundling, Hamming
    similarity. Originally introduced by Kanerva (1997) as the *Spatter
    Code*; the BSC acronym was retro-fitted by the modern HDC literature
    (Kleyko et al. 2023 Part I §2.3.6). The canonical introduction to
    the operations is Kanerva (2009).

    References:
        Kanerva, P. (1997). Fully Distributed Representation. In Proc.
        RWC '97, pp. 358-365.
        Kanerva, P. (2009). Hyperdimensional Computing: An Introduction.
        Cognitive Computation 1(2): 139-159.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "BSC":
        """Create a BSC model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized BSC model
        """
        return BSC(name="bsc", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using XOR."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using majority rule."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse is identity for XOR."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute Hamming similarity."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random binary hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random binary hypervectors with ~50% ones
        """
        return jax.random.bernoulli(key, 0.5, shape=shape)
@register_dataclass
@dataclass
class MAP(VSAModel):
    """Multiply-Add-Permute (MAP) coding.

    Real-valued vectors with element-wise multiply binding, normalized
    sum bundling, cosine similarity. The MAP scheme was introduced by
    Gayler (1998) and is the running example in Gayler (2003) where the
    term "Vector Symbolic Architecture" itself is coined.

    References:
        Gayler, R. W. (1998). Multiplicative binding, representation
        operators and analogy. In K. Holyoak, D. Gentner, B. Kokinov
        (eds.), Advances in Analogy Research, pp. 1-4. New Bulgarian
        University Press.
        Gayler, R. W. (2003). Vector Symbolic Architectures answer
        Jackendoff's challenges for cognitive neuroscience.
        arXiv:cs/0412059.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "MAP":
        """Create a MAP model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized MAP model
        """
        return MAP(name="map", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using element-wise multiplication."""
        return F.bind_map(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using normalized sum."""
        return F.bundle_map(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse via element-wise reciprocal."""
        return F.inverse_map(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random real-valued hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random normalized hypervectors sampled from normal distribution
        """
        vectors = jax.random.normal(key, shape=shape)
        # Normalize to unit length; EPS guards against division by zero.
        norm = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / (norm + EPS)
@register_dataclass
@dataclass
class HRR(VSAModel):
    """Holographic Reduced Representations (HRR).

    Real-valued vectors with circular convolution binding, normalized
    sum bundling, cosine similarity. HRR was introduced by Plate (1995,
    IEEE TNN) as a fixed-dimensional alternative to Smolensky's
    tensor-product binding; the book-length treatment is Plate (2003).
    Circular convolution is the canonical single-argument
    shift-equivariant bilinear operator on R^d (see
    ``bayes_hdc.equivariance``).

    References:
        Plate, T. A. (1995). Holographic Reduced Representations. IEEE
        Transactions on Neural Networks 6(3): 623-641.
        Plate, T. A. (2003). Holographic Reduced Representation:
        Distributed Representation for Cognitive Structures.
        CSLI Publications.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "HRR":
        """Create an HRR model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized HRR model
        """
        return HRR(name="hrr", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using circular convolution."""
        return F.bind_hrr(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using normalized sum."""
        return F.bundle_hrr(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse via element reversal."""
        return F.inverse_hrr(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random real-valued hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random normalized hypervectors sampled from normal distribution
        """
        vectors = jax.random.normal(key, shape=shape)
        # Normalize to unit length; EPS guards against division by zero.
        norm = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / (norm + EPS)
@register_dataclass
@dataclass
class FHRR(VSAModel):
    """Fourier Holographic Reduced Representations (FHRR).

    Hypervectors are complex unit phasors. Binding is element-wise
    complex multiplication, bundling is the normalized sum, and
    unbinding uses the complex conjugate. FHRR is the frequency-domain
    dual of HRR (Plate 1994/2003): circular convolution in the spatial
    domain becomes element-wise complex multiplication in the Fourier
    domain.

    References:
        Plate, T. A. (2003). Holographic Reduced Representation:
        Distributed Representation for Cognitive Structures. CSLI
        Publications. (FHRR is treated alongside HRR; see chapters on
        the Fourier-domain formulation.)
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "FHRR":
        """Build an FHRR model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized FHRR model
        """
        return FHRR(name="fhrr", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind via element-wise complex multiplication."""
        return jnp.multiply(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle by summing along ``axis`` and renormalizing to unit length."""
        total = jnp.sum(vectors, axis=axis)
        length = jnp.linalg.norm(total, axis=-1, keepdims=True)
        return total / (length + EPS)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Invert via complex conjugation (the reciprocal of a unit phasor)."""
        return jnp.conj(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Real part of the normalized complex inner product, clipped to [-1, 1]."""

        def _unit(v: jax.Array) -> jax.Array:
            return v / (jnp.linalg.norm(v, axis=-1, keepdims=True) + EPS)

        inner = jnp.sum(_unit(x) * jnp.conj(_unit(y)), axis=-1)
        # Clip to guard against floating-point excursions just past +/-1.
        return jnp.clip(jnp.real(inner), -1.0, 1.0)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Draw unit phasors with i.i.d. uniform phases in [0, 2*pi).

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random unit complex hypervectors
        """
        angles = jax.random.uniform(key, shape=shape, minval=0, maxval=2 * jnp.pi)
        return jnp.exp(1j * angles)
@register_dataclass
@dataclass
class BSBC(VSAModel):
    """Binary Sparse Block Codes (B-SBC).

    Block-sparse binary vectors with k_active ones per block, XOR
    binding, majority bundling. The BSC operations carry over directly
    (Kanerva 1997); the sparse-block construction follows the line
    traced in Kleyko et al. (2023) Part I §2.3.7 (sparse binary HDC
    family).

    References:
        Kanerva, P. (1997). Fully Distributed Representation. In Proc.
        RWC '97, pp. 358-365.
        Kleyko, D., Rachkovskij, D. A., Osipov, E., Rahimi, A. (2023).
        A Survey on HDC aka VSA, Part I. ACM Computing Surveys 55(6).
    """

    # Size of each block; static pytree metadata.
    block_size: int = field(metadata=dict(static=True), default=100)
    # Number of ones per block (sparsity level); static pytree metadata.
    k_active: int = field(metadata=dict(static=True), default=5)

    @staticmethod
    def create(
        dimensions: int = 10000,
        block_size: int = 100,
        k_active: int = 5,
    ) -> "BSBC":
        """Create a B-SBC model.

        Args:
            dimensions: Total dimensionality (must be divisible by block_size)
            block_size: Size of each block
            k_active: Number of ones per block (sparsity)

        Returns:
            Initialized BSBC model

        Raises:
            ValueError: If dimensions is not divisible by block_size, or
                k_active is outside [1, block_size].
        """
        if dimensions % block_size != 0:
            raise ValueError(
                f"dimensions ({dimensions}) must be divisible by block_size ({block_size})"
            )
        if k_active > block_size or k_active < 1:
            raise ValueError(f"k_active must be in [1, block_size], got {k_active}")
        return BSBC(
            name="bsbc",
            dimensions=dimensions,
            block_size=block_size,
            k_active=k_active,
        )

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using XOR (same as BSC)."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using majority rule."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse is identity for XOR."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute Hamming similarity."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random block-sparse binary hypervectors.

        Each hypervector consists of ``dimensions // block_size`` blocks
        with exactly ``k_active`` ones placed uniformly at random per
        block.

        Args:
            key: JAX random key
            shape: Shape of output array; its total size should be a
                multiple of ``dimensions``

        Returns:
            Boolean block-sparse hypervectors reshaped to ``shape``
        """
        num_blocks = self.dimensions // self.block_size

        def gen_block(key_b: jax.Array) -> jax.Array:
            # Choose k_active distinct positions within the block.
            perm = jax.random.permutation(key_b, self.block_size)
            block = jnp.zeros(self.block_size, dtype=jnp.bool_)
            return block.at[perm[: self.k_active]].set(True)

        # Number of full hypervectors implied by the requested shape.
        batch_size = max(1, int(jnp.prod(jnp.array(shape))) // self.dimensions)
        keys = jax.random.split(key, batch_size * num_blocks)
        # Reshape on the leading axes only, so this works with both
        # legacy uint32 keys (trailing dim 2) and new-style typed key
        # arrays (no trailing dim) — the old code hard-coded the
        # legacy (…, 2) layout.
        keys_per_hv = keys.reshape(batch_size, num_blocks, *keys.shape[1:])

        def make_hv(block_keys: jax.Array) -> jax.Array:
            blocks = jax.vmap(gen_block)(block_keys)
            return jnp.reshape(blocks, (self.dimensions,))

        hvs = jax.vmap(make_hv)(keys_per_hv)
        # A single 1-D request maps directly onto one hypervector.
        if batch_size == 1 and len(shape) == 1:
            return hvs[0]
        return jnp.reshape(hvs, shape)
@register_dataclass
@dataclass
class CGR(VSAModel):
    """Cyclic Group Representation (CGR).

    Integer hypervectors in Z_q with modular addition binding,
    component-wise mode bundling.
    """

    # Modulus of the cyclic group Z_q; static pytree metadata.
    q: int = field(metadata=dict(static=True), default=8)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 8) -> "CGR":
        """Create a CGR model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)
            q: Modulus of the cyclic group; must be >= 2 (default: 8)

        Returns:
            Initialized CGR model

        Raises:
            ValueError: If q < 2.
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        return CGR(name="cgr", dimensions=dimensions, q=q)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using modular addition."""
        return F.bind_cgr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using component-wise mode."""
        return F.bundle_cgr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse via modular negation."""
        return F.inverse_cgr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute fraction of matching elements."""
        return F.matching_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random integer hypervectors in {0, ..., q-1}."""
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
@register_dataclass
@dataclass
class MCR(VSAModel):
    """Modular Composite Representation (MCR).

    Integer phase vectors with modular addition binding, phasor sum
    bundling.
    """

    # Number of discrete phase levels; static pytree metadata.
    q: int = field(metadata=dict(static=True), default=64)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 64) -> "MCR":
        """Create an MCR model.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)
            q: Number of discrete phase levels; must be >= 2 (default: 64)

        Returns:
            Initialized MCR model

        Raises:
            ValueError: If q < 2.
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        return MCR(name="mcr", dimensions=dimensions, q=q)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using modular addition (phase addition)."""
        return F.bind_mcr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using phasor sum with snap-to-grid."""
        return F.bundle_mcr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse via modular negation (phase conjugate)."""
        return F.inverse_mcr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute phasor similarity."""
        return F.phasor_similarity(x, y, self.q)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random integer hypervectors in {0, ..., q-1}."""
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
@register_dataclass
@dataclass
class VTB(VSAModel):
    """Vector-Derived Transformation Binding (VTB).

    Real-valued vectors with matrix multiplication binding, normalized
    sum bundling.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "VTB":
        """Create a VTB model.

        Args:
            dimensions: Dimensionality of hypervectors; must be a
                perfect square since binding reshapes one operand into a
                sqrt(d) x sqrt(d) matrix (default: 10000)

        Returns:
            Initialized VTB model

        Raises:
            ValueError: If dimensions is not a perfect square.
        """
        n = round(dimensions**0.5)
        if n * n != dimensions:
            raise ValueError(f"VTB requires dimensions to be a perfect square, got {dimensions}")
        return VTB(name="vtb", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using matrix multiplication."""
        return F.bind_vtb(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle using normalized sum."""
        return F.bundle_vtb(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Inverse via matrix pseudoinverse."""
        return F.inverse_vtb(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Compute cosine similarity."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random normalized real-valued hypervectors."""
        vectors = jax.random.normal(key, shape=shape)
        # Normalize to unit length; EPS guards against division by zero.
        norm = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / (norm + EPS)
def create_vsa_model(model_type: str = "map", dimensions: int = 10000) -> VSAModel:
    """Factory function to create VSA models.

    Args:
        model_type: Type of VSA model ('bsc', 'map', 'hrr', 'fhrr',
            'bsbc', 'cgr', 'mcr', 'vtb')
        dimensions: Dimensionality of hypervectors (default: 10000)

    Returns:
        Initialized VSA model

    Raises:
        ValueError: If model_type is not a known model name.
    """
    models = {
        "bsc": BSC,
        "map": MAP,
        "hrr": HRR,
        "fhrr": FHRR,
        "bsbc": BSBC,
        "cgr": CGR,
        "mcr": MCR,
        "vtb": VTB,
    }
    model_cls = models.get(model_type)
    if model_cls is None:
        raise ValueError(
            f"Unknown VSA model: {model_type}. Available models: {list(models.keys())}"
        )
    # Each model class exposes a static `create` constructor.
    return model_cls.create(dimensions=dimensions)  # type: ignore[attr-defined]
# Public API of this module.
__all__ = [
    "VSAModel",
    "BSC",
    "MAP",
    "HRR",
    "FHRR",
    "BSBC",
    "CGR",
    "MCR",
    "VTB",
    "create_vsa_model",
]