# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh
"""Encoders for transforming data into hypervectors.

This module provides encoding strategies for transforming different kinds of
data (discrete features, continuous values, high-dimensional inputs such as
images, graphs, and token sequences) into hypervectors.
"""
from dataclasses import dataclass, field
from typing import Optional, Union
import jax
import jax.numpy as jnp
from bayes_hdc import functional as F
from bayes_hdc._compat import register_dataclass
from bayes_hdc.constants import EPS
from bayes_hdc.vsa import VSAModel, create_vsa_model
@register_dataclass
@dataclass
class RandomEncoder:
"""Encoder using random hypervectors for discrete features.
Each unique feature value is mapped to a random hypervector from a codebook.
Multiple features are bundled together to form the final representation.
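
    A minimal usage sketch (only the ``create`` and ``encode`` API defined
    below is assumed)::

        import jax
        import jax.numpy as jnp

        enc = RandomEncoder.create(num_features=3, num_values=5, key=jax.random.PRNGKey(42))
        hv = enc.encode(jnp.array([0, 4, 2]))  # shape: (10000,)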
"""
# Data fields (traced by JAX)
codebook: jax.Array # Shape: (num_features, num_values, dimensions)
# Metadata fields (static, not traced)
num_features: int = field(metadata=dict(static=True))
num_values: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
vsa_model_name: str = field(metadata=dict(static=True), default="map")
@staticmethod
def create(
num_features: int,
num_values: int,
dimensions: int = 10000,
vsa_model: Union[str, VSAModel] = "map",
key: Optional[jax.Array] = None,
) -> "RandomEncoder":
"""Create a random encoder.
Args:
num_features: Number of features to encode
num_values: Number of possible values per feature
dimensions: Dimensionality of hypervectors (default: 10000)
vsa_model: VSA model to use ('bsc', 'map', 'hrr', 'fhrr') or VSAModel instance
key: JAX random key (default: PRNGKey(0))
Returns:
Initialized RandomEncoder
"""
if key is None:
key = jax.random.PRNGKey(0)
# Handle both string and VSAModel
if isinstance(vsa_model, str):
vsa_model_name = vsa_model
vsa = create_vsa_model(vsa_model, dimensions)
else:
vsa_model_name = vsa_model.name
vsa = vsa_model
# Generate random codebook
codebook = vsa.random(key, shape=(num_features, num_values, dimensions))
return RandomEncoder(
codebook=codebook,
num_features=num_features,
num_values=num_values,
dimensions=dimensions,
vsa_model_name=vsa_model_name,
)
@jax.jit
def encode(self, indices: jax.Array) -> jax.Array:
"""Encode discrete features as hypervectors.
Args:
indices: Feature indices of shape (num_features,) with values in [0, num_values).
Out-of-bounds indices are clamped to valid range.
Returns:
Encoded hypervector of shape (dimensions,)
"""
# Clamp indices to valid range to avoid out-of-bounds access
indices = jnp.clip(indices.astype(jnp.int32), 0, self.num_values - 1)
        # Select the hypervector for each feature: codebook[i, indices[i]]
        selected = self.codebook[jnp.arange(self.num_features), indices]
# Bundle all feature hypervectors
if self.vsa_model_name == "bsc":
return F.bundle_bsc(selected, axis=0)
else:
return F.bundle_map(selected, axis=0)
@jax.jit
def encode_batch(self, indices: jax.Array) -> jax.Array:
"""Encode a batch of samples.
Args:
indices: Batch of feature indices of shape (batch_size, num_features)
Returns:
Encoded hypervectors of shape (batch_size, dimensions)
"""
return jax.vmap(self.encode)(indices)
@register_dataclass
@dataclass
class LevelEncoder:
"""Encoder for continuous values using level hypervectors.
Continuous values are encoded by interpolating between level hypervectors,
creating a smooth representation where similar values map to similar hypervectors.
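
    A quick sketch of the similarity-preserving property (assuming only the
    ``create`` and ``encode`` API below)::

        enc = LevelEncoder.create(num_levels=100, min_value=0.0, max_value=1.0)
        hv_a = enc.encode(0.50)
        hv_b = enc.encode(0.51)  # nearby value -> high cosine similarity to hv_a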
"""
# Data fields
level_hvs: jax.Array # Shape: (num_levels, dimensions)
# Metadata fields
num_levels: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
min_value: float = field(metadata=dict(static=True))
max_value: float = field(metadata=dict(static=True))
vsa_model_name: str = field(metadata=dict(static=True), default="map")
encoding_type: str = field(metadata=dict(static=True), default="linear")
@staticmethod
def create(
num_levels: int = 100,
dimensions: int = 10000,
min_value: float = 0.0,
max_value: float = 1.0,
vsa_model: Union[str, VSAModel] = "map",
encoding_type: str = "linear",
key: Optional[jax.Array] = None,
) -> "LevelEncoder":
"""Create a level encoder.
Args:
num_levels: Number of levels for discretization (default: 100)
dimensions: Dimensionality of hypervectors (default: 10000)
min_value: Minimum value of the range (default: 0.0)
max_value: Maximum value of the range (default: 1.0)
vsa_model: VSA model to use ('bsc', 'map', 'hrr', 'fhrr')
encoding_type: 'linear' or 'circular' (default: 'linear')
key: JAX random key
Returns:
Initialized LevelEncoder
"""
if key is None:
key = jax.random.PRNGKey(0)
# Handle both string and VSAModel
if isinstance(vsa_model, str):
vsa_model_name = vsa_model
vsa = create_vsa_model(vsa_model, dimensions)
else:
vsa_model_name = vsa_model.name
vsa = vsa_model
if min_value >= max_value:
raise ValueError(
f"min_value ({min_value}) must be less than max_value ({max_value}). "
"Use a non-empty range for level encoding."
)
# Generate random level hypervectors
level_hvs = vsa.random(key, shape=(num_levels, dimensions))
return LevelEncoder(
level_hvs=level_hvs,
num_levels=num_levels,
dimensions=dimensions,
min_value=min_value,
max_value=max_value,
vsa_model_name=vsa_model_name,
encoding_type=encoding_type,
)
@jax.jit
def encode(self, value: Union[float, jax.Array]) -> jax.Array:
"""Encode a continuous value as a hypervector.
Args:
value: Continuous value to encode (scalar or array)
Returns:
Encoded hypervector of shape (dimensions,) or batch shape + (dimensions,)
"""
        # Map value into [0, 1] (range validated at create time), then scale
        # to a level position in [0, num_levels - 1]
value_range = self.max_value - self.min_value
normalized = (value - self.min_value) / jnp.maximum(value_range, EPS)
normalized = jnp.clip(normalized, 0.0, 1.0)
level_pos = normalized * (self.num_levels - 1)
# Get lower and upper level indices
lower_idx = jnp.floor(level_pos).astype(jnp.int32)
upper_idx = jnp.ceil(level_pos).astype(jnp.int32)
# Interpolation weight
weight = level_pos - lower_idx
# Get level hypervectors
lower_hv = self.level_hvs[lower_idx]
upper_hv = self.level_hvs[upper_idx]
# Linear interpolation for real-valued models
if self.vsa_model_name in ["map", "hrr", "fhrr"]:
# Weighted combination
encoded = (1 - weight[..., None]) * lower_hv + weight[..., None] * upper_hv
# Normalize
norm = jnp.linalg.norm(encoded, axis=-1, keepdims=True)
return encoded / (norm + EPS)
else: # BSC
# For binary, use threshold-based selection
return jnp.where(weight[..., None] > 0.5, upper_hv, lower_hv)
@jax.jit
def encode_batch(self, values: jax.Array) -> jax.Array:
"""Encode a batch of continuous values.
Args:
values: Batch of values of shape (batch_size,) or (batch_size, num_features)
Returns:
Encoded hypervectors
"""
return jax.vmap(self.encode)(values)
@register_dataclass
@dataclass
class ProjectionEncoder:
"""Encoder using random projection for high-dimensional data.
Projects high-dimensional input data into hypervector space using a
random projection matrix. Useful for images, text embeddings, etc.
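
    A minimal sketch for a flattened 28x28 image (assuming only the
    ``create`` and ``encode`` API below)::

        import jax.numpy as jnp

        enc = ProjectionEncoder.create(input_dim=28 * 28)
        hv = enc.encode(jnp.ones(28 * 28))  # shape: (10000,)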
"""
# Data fields
projection_matrix: jax.Array # Shape: (input_dim, dimensions)
# Metadata fields
input_dim: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
vsa_model_name: str = field(metadata=dict(static=True), default="map")
@staticmethod
def create(
input_dim: int,
dimensions: int = 10000,
vsa_model: Union[str, VSAModel] = "map",
key: Optional[jax.Array] = None,
) -> "ProjectionEncoder":
"""Create a projection encoder.
Args:
input_dim: Dimensionality of input data
dimensions: Dimensionality of hypervectors (default: 10000)
vsa_model: VSA model to use ('bsc', 'map', 'hrr', 'fhrr')
key: JAX random key
Returns:
Initialized ProjectionEncoder
"""
if key is None:
key = jax.random.PRNGKey(0)
# Handle both string and VSAModel
if isinstance(vsa_model, str):
vsa_model_name = vsa_model
else:
vsa_model_name = vsa_model.name
        # Random Gaussian projection matrix, scaled by 1/sqrt(input_dim)
projection_matrix = jax.random.normal(key, shape=(input_dim, dimensions))
projection_matrix = projection_matrix / jnp.sqrt(input_dim)
return ProjectionEncoder(
projection_matrix=projection_matrix,
input_dim=input_dim,
dimensions=dimensions,
vsa_model_name=vsa_model_name,
)
@jax.jit
def encode(self, x: jax.Array) -> jax.Array:
"""Encode input data as a hypervector.
Args:
x: Input data of shape (input_dim,)
Returns:
Encoded hypervector of shape (dimensions,)
"""
# Random projection
projected = jnp.dot(x, self.projection_matrix)
# Apply activation based on VSA model
if self.vsa_model_name == "bsc":
# Threshold for binary
return projected > 0
else:
# Normalize for real-valued
norm = jnp.linalg.norm(projected)
return projected / (norm + EPS)
@jax.jit
def encode_batch(self, x: jax.Array) -> jax.Array:
"""Encode a batch of inputs.
Args:
x: Batch of inputs of shape (batch_size, input_dim)
Returns:
Encoded hypervectors of shape (batch_size, dimensions)
"""
return jax.vmap(self.encode)(x)
@register_dataclass
@dataclass
class KernelEncoder:
"""Encoder using RBF kernel approximation (Random Fourier Features).
Approximates the RBF kernel k(x,y) = exp(-gamma ||x-y||^2) via random
Fourier features, mapping input to a hypervector space that preserves
kernel similarity.
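
    A sketch of the kernel-approximation property (note that ``encode``
    additionally L2-normalizes, so encoded dot products track the kernel
    only up to that normalization)::

        import jax.numpy as jnp

        enc = KernelEncoder.create(input_dim=4, gamma=0.5)
        hv_x = enc.encode(jnp.zeros(4))
        hv_y = enc.encode(jnp.ones(4))
        sim = jnp.dot(hv_x, hv_y)  # tracks exp(-gamma * ||x - y||^2)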
"""
    omega: jax.Array  # Shape: (input_dim, dimensions)
    bias: jax.Array  # Shape: (dimensions,)
input_dim: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
gamma: float = field(metadata=dict(static=True))
vsa_model_name: str = field(metadata=dict(static=True), default="map")
@staticmethod
def create(
input_dim: int,
dimensions: int = 10000,
gamma: float = 1.0,
vsa_model: Union[str, VSAModel] = "map",
key: Optional[jax.Array] = None,
) -> "KernelEncoder":
"""Create a kernel encoder.
Args:
input_dim: Dimensionality of input data
dimensions: Dimensionality of output hypervectors
            gamma: RBF kernel bandwidth, gamma = 1 / (2 * sigma^2)
vsa_model: VSA model ('map', 'hrr', 'fhrr' for real-valued)
key: JAX random key
Returns:
Initialized KernelEncoder
"""
if key is None:
key = jax.random.PRNGKey(0)
if isinstance(vsa_model, str):
vsa_model_name = vsa_model
else:
vsa_model_name = vsa_model.name
# Random Fourier features: omega ~ N(0, 2*gamma I), bias ~ U(0, 2*pi)
key_omega, key_bias = jax.random.split(key)
omega = jax.random.normal(key_omega, (input_dim, dimensions)) * jnp.sqrt(2.0 * gamma)
bias = jax.random.uniform(key_bias, (dimensions,), minval=0.0, maxval=2.0 * jnp.pi)
return KernelEncoder(
omega=omega,
bias=bias,
input_dim=input_dim,
dimensions=dimensions,
gamma=gamma,
vsa_model_name=vsa_model_name,
)
@jax.jit
def encode(self, x: jax.Array) -> jax.Array:
"""Encode input using RBF kernel approximation."""
proj = jnp.dot(x, self.omega) + self.bias
features = jnp.cos(proj) * jnp.sqrt(2.0 / self.dimensions)
if self.vsa_model_name == "bsc":
return features > 0
norm = jnp.linalg.norm(features) + EPS
return features / norm
@jax.jit
def encode_batch(self, x: jax.Array) -> jax.Array:
"""Encode a batch of inputs."""
return jax.vmap(self.encode)(x)
@register_dataclass
@dataclass
class GraphEncoder:
"""Encoder for graph structures (nodes and edges).
Encodes a graph by assigning random hypervectors to nodes and
bundling bound node pairs for edges. Graph = bundle of edge HVs.
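
    A minimal sketch for a three-node path graph (assuming only the
    ``create`` and ``encode_edges`` API below)::

        import jax.numpy as jnp

        enc = GraphEncoder.create(num_nodes=3)
        graph_hv = enc.encode_edges(jnp.array([[0, 1], [1, 2]]))  # shape: (10000,)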
"""
node_embeddings: jax.Array # Shape: (num_nodes, dimensions)
num_nodes: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
vsa_model_name: str = field(metadata=dict(static=True), default="map")
@staticmethod
def create(
num_nodes: int,
dimensions: int = 10000,
vsa_model: Union[str, VSAModel] = "map",
key: Optional[jax.Array] = None,
) -> "GraphEncoder":
"""Create a graph encoder.
Args:
num_nodes: Maximum number of nodes
dimensions: Hypervector dimensionality
vsa_model: VSA model for real-valued graphs
key: JAX random key
Returns:
Initialized GraphEncoder
"""
if key is None:
key = jax.random.PRNGKey(0)
if isinstance(vsa_model, str):
vsa_model_name = vsa_model
vsa = create_vsa_model(vsa_model, dimensions)
else:
vsa_model_name = vsa_model.name
vsa = vsa_model
node_embeddings = vsa.random(key, (num_nodes, dimensions))
return GraphEncoder(
node_embeddings=node_embeddings,
num_nodes=num_nodes,
dimensions=dimensions,
vsa_model_name=vsa_model_name,
)
def encode_edges(self, edges: jax.Array) -> jax.Array:
"""Encode graph as bundle of bound edge pairs.
Args:
edges: Array of shape (num_edges, 2) with node indices in [0, num_nodes).
Out-of-bounds indices are clamped to valid range.
Returns:
Graph hypervector of shape (dimensions,)
"""
        # Clamp node indices to valid range to avoid out-of-bounds access
        u = jnp.clip(edges[:, 0].astype(jnp.int32), 0, self.num_nodes - 1)
        v = jnp.clip(edges[:, 1].astype(jnp.int32), 0, self.num_nodes - 1)
        # Bind each source node with the permuted destination node (vmapped
        # over edges, so the encoding also works under jit/tracing), then
        # bundle all edge hypervectors into a single graph representation.
        bound = jax.vmap(
            lambda a, b: F.bind_map(a, F.permute(b, 1))
        )(self.node_embeddings[u], self.node_embeddings[v])
        return F.bundle_map(bound, axis=0)
@register_dataclass
@dataclass
class TokenEncoder:
r"""Encode token-ID sequences as positional hypervectors.
A vocabulary-sized codebook of random unit-norm hypervectors:
each token ID :math:`t \in [0, V)` is assigned a fixed random
vector :math:`c_t \in \mathbb{R}^d`. A sequence of token IDs is
encoded by looking up each token's vector and applying the
standard permute-bundle sequence encoding (Sahlgren et al. 2008;
Kanerva 2009) — flat for short sequences, hierarchical
(:class:`~bayes_hdc.structures.HierarchicalSequence`) for long
horizons where the per-position SNR of the flat construction
becomes the bottleneck.
    Tokenizer-agnostic: takes integer token IDs from any source
    (HuggingFace tokenizers, SentencePiece, tiktoken, BPE, character
    indices). Tokenization itself is the caller's responsibility::

        # HuggingFace example (transformers is the caller's dependency, not ours):
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
ids = tok.encode("the quick brown fox", return_tensors="jax")[0]
encoder = TokenEncoder.create(vocab_size=tok.vocab_size, dimensions=10000)
seq_hv = encoder.encode(ids)
Attributes:
codebook: Codebook of shape ``(vocab_size, dimensions)``.
Stable across calls; persist or save with
:func:`jax.numpy.save` for reproducible encodings.
vocab_size: Token-vocabulary cardinality :math:`V`.
dimensions: Hypervector dimension :math:`d`.
"""
codebook: jax.Array # (vocab_size, d)
vocab_size: int = field(metadata=dict(static=True))
dimensions: int = field(metadata=dict(static=True))
@staticmethod
def create(
vocab_size: int,
dimensions: int = 10_000,
vsa_model: Union[str, VSAModel] = "map",
key: Optional[jax.Array] = None,
) -> "TokenEncoder":
"""Build a fresh codebook of ``vocab_size`` random hypervectors.
Args:
vocab_size: Token-vocabulary cardinality.
dimensions: Hypervector dimension. Defaults to 10 000;
drop to ``4096`` for tighter memory at modest
capacity loss.
vsa_model: ``"map"`` (default — real-valued, unit-norm),
``"hrr"``, ``"fhrr"``, etc. For HDC text pipelines
                MAP is the conventional choice; the codebook is
                drawn from the chosen model and L2-normalized so
                ``permute_bundle_sequence`` retrieval works
                downstream.
key: ``jax.random.PRNGKey``. Defaults to ``PRNGKey(0)``;
pass an explicit key for reproducible runs.
Returns:
A frozen :class:`TokenEncoder` ready for ``encode()``.
"""
if vocab_size < 1:
raise ValueError(f"vocab_size must be >= 1, got {vocab_size}")
if dimensions < 1:
raise ValueError(f"dimensions must be >= 1, got {dimensions}")
if key is None:
key = jax.random.PRNGKey(0)
if isinstance(vsa_model, str):
vsa = create_vsa_model(vsa_model, dimensions)
else:
vsa = vsa_model
codebook = vsa.random(key, shape=(vocab_size, dimensions))
        # L2-normalize each row so every token's hypervector lies on the
        # unit sphere, as required by the permute-bundle SNR analysis that
        # the sequence encoders rely on.
norms = jnp.linalg.norm(codebook, axis=-1, keepdims=True)
codebook = codebook / (norms + EPS)
return TokenEncoder(
codebook=codebook,
vocab_size=vocab_size,
dimensions=dimensions,
)
@jax.jit
def lookup(self, token_id: jax.Array) -> jax.Array:
"""Return the hypervector for a single token ID."""
return self.codebook[token_id]
@jax.jit
def lookup_batch(self, token_ids: jax.Array) -> jax.Array:
"""Return the per-token hypervectors for a sequence of IDs.
Args:
token_ids: Integer IDs of shape ``(T,)``.
Returns:
Hypervectors of shape ``(T, d)``.
"""
return self.codebook[token_ids]
@jax.jit
def encode(self, token_ids: jax.Array) -> jax.Array:
"""Flat permute-bundle encoding of a token sequence.
Returns the flat :class:`~bayes_hdc.structures.Sequence`
representation: ``Σ_i P^{T-1-i} c_{t_i}``. For long sequences
(``T ≳ 200`` at ``d = 4 096``) prefer
:meth:`encode_hierarchical`; see ``BENCHMARKS.md`` for the
capacity table.
"""
items = self.codebook[token_ids]
return F.bundle_sequence(items)
def encode_hierarchical(
self,
token_ids: jax.Array,
chunk_size: int = 16,
):
"""Two-level chunked encoding for long-horizon sequences.
Returns a :class:`~bayes_hdc.structures.HierarchicalSequence`
rather than a raw hypervector, because retrieval requires the
cached chunk codebook stored on that object. ``get(i)`` on
the result returns the (still-noisy) item hypervector at
position ``i``; clean against ``self.codebook`` for symbolic
recovery.
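
        A sketch of the round trip, continuing the class-level example
        (``get`` comes from the :class:`HierarchicalSequence` API described
        above; the argmax cleanup against ``self.codebook`` is one common
        choice, not the only one)::

            hseq = encoder.encode_hierarchical(ids, chunk_size=16)
            noisy = hseq.get(0)  # noisy item hypervector at position 0
            token = jnp.argmax(encoder.codebook @ noisy)  # cleanup lookup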
"""
        # Local import to avoid a circular import with bayes_hdc.structures.
from bayes_hdc.structures import HierarchicalSequence
items = self.codebook[token_ids]
return HierarchicalSequence.from_vectors(items, chunk_size=chunk_size)
__all__ = [
"RandomEncoder",
"LevelEncoder",
"ProjectionEncoder",
"KernelEncoder",
"GraphEncoder",
"TokenEncoder",
]