# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh
"""Vector Symbolic Architecture (VSA) model implementations.
This module provides different VSA models, each with their own binding,
bundling, and similarity operations. All models follow a consistent API.
The term *Vector Symbolic Architecture* was coined by Gayler (2003) as
an umbrella for the family of fixed-dimensional algebraic models (HRR,
BSC, MAP, ...) that descend from Smolensky's tensor-product binding
(Smolensky 1990) but compress it to a fixed dimension.
References:
Gayler, R. W. (2003). Vector Symbolic Architectures answer Jackendoff's
challenges for cognitive neuroscience. In Proc. ICCS/ASCS-2003,
pp. 133-138. arXiv:cs/0412059.
Kanerva, P. (2009). Hyperdimensional Computing: An Introduction to
Computing in Distributed Representation with High-Dimensional Random
Vectors. Cognitive Computation 1(2): 139-159.
Kleyko, D., Rachkovskij, D. A., Osipov, E., Rahimi, A. (2023).
A Survey on Hyperdimensional Computing aka Vector Symbolic
Architectures, Part I. ACM Computing Surveys 55(6): Article 130.
"""
import math
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

from bayes_hdc import functional as F
from bayes_hdc._compat import register_dataclass
from bayes_hdc.constants import EPS
[docs]
@register_dataclass
@dataclass
class VSAModel:
    """Interface that every VSA model in this module implements.

    The base class only declares the five operations (bind, bundle, inverse,
    similarity, random); each of them raises ``NotImplementedError`` until a
    concrete subclass overrides it.
    """

    # Both fields are tagged ``static=True`` in their dataclass metadata.
    name: str = field(metadata={"static": True})
    dimensions: int = field(metadata={"static": True})

    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind hypervectors ``x`` and ``y``; must be overridden."""
        raise NotImplementedError()

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle the hypervectors in ``vectors`` along ``axis``; must be overridden."""
        raise NotImplementedError()

    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the binding inverse of ``x``; must be overridden."""
        raise NotImplementedError()

    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the similarity between ``x`` and ``y``; must be overridden."""
        raise NotImplementedError()

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Draw random hypervectors of the given ``shape``; must be overridden."""
        raise NotImplementedError()
[docs]
@register_dataclass
@dataclass
class BSC(VSAModel):
    """Binary Spatter Codes (BSC).

    Binary hypervectors with XOR binding, majority bundling and Hamming
    similarity. Kanerva (1997) introduced the scheme as the *Spatter Code*;
    the BSC acronym was retro-fitted by the modern HDC literature (Kleyko
    et al. 2023 Part I §2.3.6). The canonical introduction to the operations
    is Kanerva (2009).

    References:
        Kanerva, P. (1997). Fully Distributed Representation. In Proc.
        RWC '97, pp. 358-365.

        Kanerva, P. (2009). Hyperdimensional Computing: An Introduction.
        Cognitive Computation 1(2): 139-159.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "BSC":
        """Build a BSC model named ``"bsc"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized BSC model
        """
        model = BSC(name="bsc", dimensions=dimensions)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind ``x`` and ``y`` with XOR (delegates to ``F.bind_bsc``)."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` by majority rule (delegates to ``F.bundle_bsc``)."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the XOR inverse of ``x`` (delegates to ``F.inverse_bsc``)."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the Hamming similarity of ``x`` and ``y``."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random binary hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Bernoulli(0.5) samples of the requested shape, i.e. ~50% ones
        """
        return jax.random.bernoulli(key, 0.5, shape=shape)
[docs]
@register_dataclass
@dataclass
class MAP(VSAModel):
    """Multiply-Add-Permute (MAP) coding.

    Real-valued vectors with element-wise multiply binding, normalized sum
    bundling and cosine similarity. The MAP scheme was introduced by
    Gayler (1998) and is the running example in Gayler (2003), where the
    term "Vector Symbolic Architecture" itself is coined.

    References:
        Gayler, R. W. (1998). Multiplicative binding, representation
        operators and analogy. In K. Holyoak, D. Gentner, B. Kokinov (eds.),
        Advances in Analogy Research, pp. 1-4. New Bulgarian University
        Press.

        Gayler, R. W. (2003). Vector Symbolic Architectures answer
        Jackendoff's challenges for cognitive neuroscience.
        arXiv:cs/0412059.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "MAP":
        """Build a MAP model named ``"map"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized MAP model
        """
        model = MAP(name="map", dimensions=dimensions)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by element-wise multiplication (delegates to ``F.bind_map``)."""
        return F.bind_map(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` as a normalized sum (delegates to ``F.bundle_map``)."""
        return F.bundle_map(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the binding inverse of ``x`` (delegates to ``F.inverse_map``)."""
        return F.inverse_map(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the cosine similarity of ``x`` and ``y``."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random real-valued hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Standard-normal samples scaled to (approximately) unit L2 norm
            along the last axis
        """
        samples = jax.random.normal(key, shape=shape)
        # EPS in the denominator guards against division by a zero norm.
        lengths = jnp.linalg.norm(samples, axis=-1, keepdims=True)
        return samples / (lengths + EPS)
[docs]
@register_dataclass
@dataclass
class HRR(VSAModel):
    """Holographic Reduced Representations (HRR).

    Real-valued vectors with circular convolution binding, normalized sum
    bundling and cosine similarity. HRR was introduced by Plate (1995, IEEE
    TNN) as a fixed-dimensional alternative to Smolensky's tensor-product
    binding; the book-length treatment is Plate (2003). Circular convolution
    is the canonical single-argument shift-equivariant bilinear operator on
    R^d (see ``bayes_hdc.equivariance``).

    References:
        Plate, T. A. (1995). Holographic Reduced Representations. IEEE
        Transactions on Neural Networks 6(3): 623-641.

        Plate, T. A. (2003). Holographic Reduced Representation: Distributed
        Representation for Cognitive Structures. CSLI Publications.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "HRR":
        """Build an HRR model named ``"hrr"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized HRR model
        """
        model = HRR(name="hrr", dimensions=dimensions)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by circular convolution (delegates to ``F.bind_hrr``)."""
        return F.bind_hrr(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` as a normalized sum (delegates to ``F.bundle_hrr``)."""
        return F.bundle_hrr(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the binding inverse of ``x`` (delegates to ``F.inverse_hrr``)."""
        return F.inverse_hrr(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the cosine similarity of ``x`` and ``y``."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random real-valued hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Standard-normal samples scaled to (approximately) unit L2 norm
            along the last axis
        """
        samples = jax.random.normal(key, shape=shape)
        # EPS in the denominator guards against division by a zero norm.
        lengths = jnp.linalg.norm(samples, axis=-1, keepdims=True)
        return samples / (lengths + EPS)
[docs]
@register_dataclass
@dataclass
class FHRR(VSAModel):
    """Fourier Holographic Reduced Representations (FHRR).

    Complex-valued unit-phasor vectors with element-wise multiply binding
    and normalized sum bundling. FHRR was introduced by Plate (1994/2003) as
    the frequency-domain dual of HRR: circular convolution in the spatial
    domain becomes element-wise complex multiplication in the Fourier
    domain.

    References:
        Plate, T. A. (2003). Holographic Reduced Representation: Distributed
        Representation for Cognitive Structures. CSLI Publications. (FHRR is
        treated alongside HRR; see chapters on the Fourier-domain
        formulation.)
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "FHRR":
        """Build an FHRR model named ``"fhrr"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)

        Returns:
            Initialized FHRR model
        """
        model = FHRR(name="fhrr", dimensions=dimensions)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by element-wise multiplication of ``x`` and ``y``."""
        return x * y

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Sum ``vectors`` along ``axis`` and rescale by the L2 norm of the last axis."""
        total = jnp.sum(vectors, axis=axis)
        # EPS in the denominator guards against division by a zero norm.
        length = jnp.linalg.norm(total, axis=-1, keepdims=True)
        return total / (length + EPS)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the complex conjugate of ``x``."""
        return jnp.conj(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the cosine similarity of two complex vectors.

        Both inputs are scaled to unit norm along the last axis; the result
        is the real part of their Hermitian inner product, clipped to
        ``[-1, 1]``.
        """
        x_unit = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + EPS)
        y_unit = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + EPS)
        inner = jnp.sum(x_unit * jnp.conj(y_unit), axis=-1)
        # Clipping absorbs floating-point drift just outside the valid range.
        return jnp.clip(jnp.real(inner), -1.0, 1.0)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random complex hypervectors on the unit circle.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            ``exp(i * phase)`` with phases drawn uniformly from [0, 2*pi)
        """
        angles = jax.random.uniform(key, shape=shape, minval=0, maxval=2 * jnp.pi)
        return jnp.exp(1j * angles)
[docs]
@register_dataclass
@dataclass
class BSBC(VSAModel):
    """Binary Sparse Block Codes (B-SBC).

    Block-sparse binary vectors with ``k_active`` ones per block, XOR
    binding and majority bundling. The BSC operations carry over directly
    (Kanerva 1997); the sparse-block construction follows the line traced in
    Kleyko et al. (2023) Part I §2.3.7 (sparse binary HDC family).

    References:
        Kanerva, P. (1997). Fully Distributed Representation. In Proc.
        RWC '97, pp. 358-365.

        Kleyko, D., Rachkovskij, D. A., Osipov, E., Rahimi, A. (2023).
        A Survey on HDC aka VSA, Part I. ACM Computing Surveys 55(6).
    """

    # Both extra fields are tagged ``static=True`` like the base-class fields.
    block_size: int = field(metadata=dict(static=True), default=100)
    k_active: int = field(metadata=dict(static=True), default=5)

    @staticmethod
    def create(
        dimensions: int = 10000,
        block_size: int = 100,
        k_active: int = 5,
    ) -> "BSBC":
        """Create a B-SBC model named ``"bsbc"``.

        Args:
            dimensions: Total dimensionality (must be divisible by block_size)
            block_size: Size of each block
            k_active: Number of ones per block (sparsity)

        Returns:
            Initialized BSBC model

        Raises:
            ValueError: If ``dimensions`` is not a multiple of ``block_size``
                or ``k_active`` lies outside ``[1, block_size]``.
        """
        if dimensions % block_size != 0:
            raise ValueError(
                f"dimensions ({dimensions}) must be divisible by block_size ({block_size})"
            )
        if k_active > block_size or k_active < 1:
            raise ValueError(f"k_active must be in [1, block_size], got {k_active}")
        return BSBC(
            name="bsbc",
            dimensions=dimensions,
            block_size=block_size,
            k_active=k_active,
        )

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind using XOR, same as BSC (delegates to ``F.bind_bsc``)."""
        return F.bind_bsc(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` by majority rule (delegates to ``F.bundle_bsc``)."""
        return F.bundle_bsc(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the XOR inverse of ``x`` (delegates to ``F.inverse_bsc``)."""
        return F.inverse_bsc(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the Hamming similarity of ``x`` and ``y``."""
        return F.hamming_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Generate random block-sparse binary hypervectors.

        Each hypervector is the concatenation of
        ``dimensions // block_size`` blocks; every block is all ``False``
        except for ``k_active`` positions chosen by a random permutation.

        Args:
            key: JAX random key.
            shape: Shape of the output array. ``prod(shape) // dimensions``
                hypervectors (at least one) are generated and reshaped to
                ``shape``.

        Returns:
            Boolean array of block-sparse hypervectors. A ``shape`` with a
            zero-sized axis yields an empty boolean array.

        Note:
            For a 1-D ``shape`` holding fewer than ``2 * dimensions``
            elements, a single hypervector of length ``dimensions`` is
            returned even when ``shape[0] != dimensions``. This long-standing
            behavior is kept for backward compatibility.
        """
        # Shape arithmetic stays in pure Python so it is always concrete and
        # needs no device computation.
        total = math.prod(shape)
        if total == 0:
            return jnp.zeros(tuple(shape), dtype=jnp.bool_)

        num_blocks = self.dimensions // self.block_size
        batch_size = max(1, total // self.dimensions)

        def gen_block(key_b: jax.Array) -> jax.Array:
            # The first k_active entries of a random permutation pick the
            # active positions of one block without repetition.
            perm = jax.random.permutation(key_b, self.block_size)
            block = jnp.zeros(self.block_size, dtype=jnp.bool_)
            return block.at[perm[: self.k_active]].set(True)

        def make_hv(block_keys: jax.Array) -> jax.Array:
            blocks = jax.vmap(gen_block)(block_keys)
            return jnp.reshape(blocks, (self.dimensions,))

        # One subkey per block; the first split is discarded, which keeps the
        # random stream identical to earlier versions of this method.
        subkeys = jax.random.split(key, batch_size * num_blocks + 1)[1:]
        # Group keys per hypervector using the keys' own trailing shape
        # instead of hard-coding the 2-word layout of legacy uint32 keys.
        keys_per_hv = subkeys.reshape((batch_size, num_blocks) + subkeys.shape[1:])

        hvs = jax.vmap(make_hv)(keys_per_hv)
        if batch_size == 1 and len(shape) == 1:
            return hvs[0]
        return jnp.reshape(hvs, shape)
[docs]
@register_dataclass
@dataclass
class CGR(VSAModel):
    """Cyclic Group Representation (CGR).

    Integer hypervectors in Z_q with modular addition binding and
    component-wise mode bundling.
    """

    # Modulus of the cyclic group; tagged ``static=True`` like the base fields.
    q: int = field(metadata={"static": True}, default=8)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 8) -> "CGR":
        """Build a CGR model named ``"cgr"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)
            q: Modulus of the element values (default: 8)

        Returns:
            Initialized CGR model

        Raises:
            ValueError: If ``q`` is smaller than 2.
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        model = CGR(name="cgr", dimensions=dimensions, q=q)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by modular addition (delegates to ``F.bind_cgr`` with ``self.q``)."""
        return F.bind_cgr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` by component-wise mode (delegates to ``F.bundle_cgr``)."""
        return F.bundle_cgr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the modular negation of ``x`` (delegates to ``F.inverse_cgr``)."""
        return F.inverse_cgr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the fraction of matching elements (delegates to ``F.matching_similarity``)."""
        return F.matching_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample integer hypervectors uniformly from {0, ..., q-1}.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random integers in ``[0, q)`` of the requested shape
        """
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
[docs]
@register_dataclass
@dataclass
class MCR(VSAModel):
    """Modular Composite Representation (MCR).

    Integer phase vectors with modular addition binding and phasor sum
    bundling.
    """

    # Number of discrete phase levels; tagged ``static=True`` like the base fields.
    q: int = field(metadata={"static": True}, default=64)

    @staticmethod
    def create(dimensions: int = 10000, q: int = 64) -> "MCR":
        """Build an MCR model named ``"mcr"``.

        Args:
            dimensions: Dimensionality of hypervectors (default: 10000)
            q: Modulus of the element values (default: 64)

        Returns:
            Initialized MCR model

        Raises:
            ValueError: If ``q`` is smaller than 2.
        """
        if q < 2:
            raise ValueError(f"q must be >= 2, got {q}")
        model = MCR(name="mcr", dimensions=dimensions, q=q)
        return model

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by modular (phase) addition (delegates to ``F.bind_mcr`` with ``self.q``)."""
        return F.bind_mcr(x, y, self.q)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` by phasor sum with snap-to-grid (delegates to ``F.bundle_mcr``)."""
        return F.bundle_mcr(vectors, self.q, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the modular negation of ``x`` (delegates to ``F.inverse_mcr``)."""
        return F.inverse_mcr(x, self.q)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the phasor similarity of ``x`` and ``y`` (delegates to ``F.phasor_similarity``)."""
        return F.phasor_similarity(x, y, self.q)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample integer hypervectors uniformly from {0, ..., q-1}.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Random integers in ``[0, q)`` of the requested shape
        """
        return jax.random.randint(key, shape=shape, minval=0, maxval=self.q)
[docs]
@register_dataclass
@dataclass
class VTB(VSAModel):
    """Vector-Derived Transformation Binding (VTB).

    Real-valued vectors with matrix multiplication binding and normalized
    sum bundling.
    """

    @staticmethod
    def create(dimensions: int = 10000) -> "VTB":
        """Create a VTB model named ``"vtb"``.

        Args:
            dimensions: Dimensionality of hypervectors; must be a perfect
                square (default: 10000)

        Returns:
            Initialized VTB model

        Raises:
            ValueError: If ``dimensions`` is negative or not a perfect square.
        """
        # math.isqrt is exact for integers of any size. The explicit negative
        # check makes invalid input raise the documented ValueError; with a
        # float square root, a negative value produced a complex number and
        # an unrelated TypeError.
        if dimensions < 0 or math.isqrt(dimensions) ** 2 != dimensions:
            raise ValueError(f"VTB requires dimensions to be a perfect square, got {dimensions}")
        return VTB(name="vtb", dimensions=dimensions)

    @jax.jit
    def bind(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Bind by matrix multiplication (delegates to ``F.bind_vtb``)."""
        return F.bind_vtb(x, y)

    def bundle(self, vectors: jax.Array, axis: int = 0) -> jax.Array:
        """Bundle along ``axis`` as a normalized sum (delegates to ``F.bundle_vtb``)."""
        return F.bundle_vtb(vectors, axis=axis)

    @jax.jit
    def inverse(self, x: jax.Array) -> jax.Array:
        """Return the binding inverse of ``x`` (delegates to ``F.inverse_vtb``)."""
        return F.inverse_vtb(x)

    @jax.jit
    def similarity(self, x: jax.Array, y: jax.Array) -> jax.Array:
        """Return the cosine similarity of ``x`` and ``y``."""
        return F.cosine_similarity(x, y)

    def random(self, key: jax.Array, shape: tuple) -> jax.Array:
        """Sample random real-valued hypervectors.

        Args:
            key: JAX random key
            shape: Shape of output array

        Returns:
            Standard-normal samples scaled to (approximately) unit L2 norm
            along the last axis
        """
        vectors = jax.random.normal(key, shape=shape)
        # EPS in the denominator guards against division by a zero norm.
        norm = jnp.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / (norm + EPS)
[docs]
def create_vsa_model(
    model_type: str = "map", dimensions: int = 10000, **model_kwargs: int
) -> VSAModel:
    """Factory function to create VSA models.

    Args:
        model_type: Type of VSA model ('bsc', 'map', 'hrr', 'fhrr', 'bsbc',
            'cgr', 'mcr', 'vtb')
        dimensions: Dimensionality of hypervectors (default: 10000)
        **model_kwargs: Extra keyword arguments forwarded unchanged to the
            selected model's ``create`` (``block_size`` / ``k_active`` for
            'bsbc', ``q`` for 'cgr' and 'mcr'). Omit them to get that
            model's defaults.

    Returns:
        Initialized VSA model

    Raises:
        ValueError: If ``model_type`` is not one of the known model names,
            or if the selected model's ``create`` rejects its arguments.
        TypeError: If ``model_kwargs`` contains a keyword the selected
            model's ``create`` does not accept.
    """
    models = {
        "bsc": BSC,
        "map": MAP,
        "hrr": HRR,
        "fhrr": FHRR,
        "bsbc": BSBC,
        "cgr": CGR,
        "mcr": MCR,
        "vtb": VTB,
    }
    if model_type not in models:
        raise ValueError(
            f"Unknown VSA model: {model_type}. Available models: {list(models.keys())}"
        )
    # ``create`` is defined on each concrete model, not on the VSAModel base.
    return models[model_type].create(  # type: ignore[attr-defined]
        dimensions=dimensions, **model_kwargs
    )
# Public API of this module: the base interface, every concrete model class,
# and the string-keyed factory function.
__all__ = [
    "VSAModel",
    "BSC",
    "MAP",
    "HRR",
    "FHRR",
    "BSBC",
    "CGR",
    "MCR",
    "VTB",
    "create_vsa_model",
]