# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh
"""Classification and learning models for Hyperdimensional Computing."""
from dataclasses import dataclass, field
from dataclasses import replace as dataclass_replace
from typing import Any, Optional, Union
import jax
import jax.numpy as jnp
from bayes_hdc import functional as F
from bayes_hdc._compat import register_dataclass
from bayes_hdc.constants import EPS
from bayes_hdc.vsa import VSAModel, create_vsa_model
@register_dataclass
@dataclass
class CentroidClassifier:
    """Centroid-based classifier for HDC.

    Stores one prototype hypervector per class. Classification finds the
    most similar prototype to the query.
    """

    prototypes: jax.Array  # (num_classes, dimensions)
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        initial_prototypes: Optional[jax.Array] = None,
        key: Optional[jax.Array] = None,
    ) -> "CentroidClassifier":
        """Create a centroid classifier.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            initial_prototypes: Optional initial prototypes of shape
                ``(num_classes, dimensions)``
            key: JAX random key for initialization (defaults to ``PRNGKey(0)``)

        Raises:
            ValueError: If ``initial_prototypes`` does not have shape
                ``(num_classes, dimensions)``.
        """
        if isinstance(vsa_model, str):
            vsa_model_name = vsa_model
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa_model_name = vsa_model.name
            vsa = vsa_model
        if initial_prototypes is not None:
            # Validate eagerly: a shape mismatch here would otherwise only
            # surface later as a confusing broadcasting error in similarity().
            expected_shape = (num_classes, dimensions)
            if tuple(initial_prototypes.shape) != expected_shape:
                raise ValueError(
                    f"initial_prototypes has shape {tuple(initial_prototypes.shape)}, "
                    f"expected {expected_shape}"
                )
            prototypes = initial_prototypes
        else:
            if key is None:
                key = jax.random.PRNGKey(0)
            prototypes = vsa.random(key, shape=(num_classes, dimensions))
        return CentroidClassifier(
            prototypes=prototypes,
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa_model_name,
        )

    @jax.jit
    def similarity(self, query: jax.Array) -> jax.Array:
        """Compute similarity between a single query and all class prototypes.

        Uses Hamming similarity for binary (BSC) hypervectors and cosine
        similarity otherwise. ``vsa_model_name`` is static metadata, so this
        Python branch is resolved at trace time.
        """
        if self.vsa_model_name == "bsc":
            return jax.vmap(lambda p: F.hamming_similarity(query, p))(self.prototypes)
        else:
            return jax.vmap(lambda p: F.cosine_similarity(query, p))(self.prototypes)

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels for queries.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Predicted class indices (scalar for a single query, array for a batch)
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]
        similarities = jax.vmap(self.similarity)(queries)
        predictions = jnp.argmax(similarities, axis=-1)
        if is_single:
            return predictions[0]
        return predictions

    @jax.jit
    def predict_proba(self, queries: jax.Array) -> jax.Array:
        """Predict class probabilities using softmax of similarities."""
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]
        similarities = jax.vmap(self.similarity)(queries)
        probs = jax.nn.softmax(similarities, axis=-1)
        if is_single:
            return probs[0]
        return probs

    def fit(self, train_hvs: jax.Array, train_labels: jax.Array) -> "CentroidClassifier":
        """Train classifier by computing class centroids.

        Classes with no training samples keep their current prototype.

        Args:
            train_hvs: Training hypervectors of shape (n_samples, dimensions)
            train_labels: Training labels of shape (n_samples,)

        Returns:
            Trained CentroidClassifier (new instance)

        Raises:
            ValueError: If the training set is empty.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit CentroidClassifier: training data is empty")
        new_prototypes_list = []
        for class_idx in range(self.num_classes):
            class_mask = train_labels == class_idx
            num_samples = jnp.sum(class_mask)
            if num_samples > 0:
                weights = jnp.where(class_mask[:, None], 1.0, 0.0)
                if self.vsa_model_name == "bsc":
                    # Majority vote per bit: a bit is set when more than half
                    # of the class samples have it set.
                    weighted_hvs = train_hvs.astype(jnp.float32) * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed > (num_samples / 2.0)
                else:
                    # Unit-normalised class mean (direction is all cosine needs).
                    weighted_hvs = train_hvs * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed / (jnp.linalg.norm(summed) + EPS)
                new_prototypes_list.append(centroid)
            else:
                new_prototypes_list.append(self.prototypes[class_idx])
        return self.replace(prototypes=jnp.stack(new_prototypes_list))

    def update_online(
        self, sample_hv: jax.Array, label: int, learning_rate: float = 0.1
    ) -> "CentroidClassifier":
        """Update classifier online with a single sample.

        For BSC, the sample is bundled into the prototype; otherwise an
        exponential-moving-average step of size ``learning_rate`` is taken
        and the result re-normalised.
        """
        old_prototype = self.prototypes[label]
        if self.vsa_model_name == "bsc":
            combined = jnp.stack([old_prototype, sample_hv])
            new_prototype = F.bundle_bsc(combined, axis=0)
        else:
            new_prototype = (1 - learning_rate) * old_prototype + learning_rate * sample_hv
            new_prototype = new_prototype / (jnp.linalg.norm(new_prototype) + EPS)
        return self.replace(prototypes=self.prototypes.at[label].set(new_prototype))

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy on test data."""
        predictions = self.predict(test_hvs)
        return jnp.mean(predictions == test_labels)

    def replace(self, **updates: Any) -> "CentroidClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class AdaptiveHDC:
    """Adaptive HDC classifier with iterative prototype refinement."""

    prototypes: jax.Array  # (num_classes, dimensions)
    num_updates: jax.Array  # (num_classes,) int32: refinement updates applied per class
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        key: Optional[jax.Array] = None,
    ) -> "AdaptiveHDC":
        """Create an AdaptiveHDC classifier with random prototypes.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            key: JAX random key for initialization (defaults to ``PRNGKey(0)``)
        """
        if isinstance(vsa_model, str):
            vsa_model_name = vsa_model
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa_model_name = vsa_model.name
            vsa = vsa_model
        if key is None:
            key = jax.random.PRNGKey(0)
        return AdaptiveHDC(
            prototypes=vsa.random(key, shape=(num_classes, dimensions)),
            num_updates=jnp.zeros(num_classes, dtype=jnp.int32),
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa_model_name,
        )

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels.

        Args:
            queries: Shape (batch_size, dimensions) or (dimensions,)

        Returns:
            Predicted class indices (scalar for a single query, array for a batch)
        """
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]
        # Static branch on the VSA model: Hamming for binary, cosine otherwise.
        if self.vsa_model_name == "bsc":
            similarities = jax.vmap(
                lambda q: jax.vmap(lambda p: F.hamming_similarity(q, p))(self.prototypes)
            )(queries)
        else:
            similarities = jax.vmap(
                lambda q: jax.vmap(lambda p: F.cosine_similarity(q, p))(self.prototypes)
            )(queries)
        predictions = jnp.argmax(similarities, axis=-1)
        if is_single:
            return predictions[0]
        return predictions

    def fit(
        self,
        train_hvs: jax.Array,
        train_labels: jax.Array,
        epochs: int = 1,
        learning_rate: float = 0.1,
    ) -> "AdaptiveHDC":
        """Train with iterative prototype refinement.

        Initialises each class prototype with the unit-normalised
        class mean, then walks the training set for `epochs` passes;
        on each misclassification the true-class prototype is moved
        toward the sample by `learning_rate` and re-normalised.
        Accuracy-preserving, single-sided LVQ update. The per-class
        ``num_updates`` counter records how many refinement steps each
        class received.

        Args:
            train_hvs: Training hypervectors of shape ``(n, d)``.
            train_labels: Training labels of shape ``(n,)``.
            epochs: Number of refinement epochs after the centroid init.
            learning_rate: Refinement step size.

        Raises:
            ValueError: If the training set is empty.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit AdaptiveHDC: training data is empty")
        classifier = self
        # Phase 1: centroid initialisation (same scheme as CentroidClassifier.fit).
        for class_idx in range(self.num_classes):
            class_mask = train_labels == class_idx
            num_samples = jnp.sum(class_mask)
            if num_samples > 0:
                weights = jnp.where(class_mask[:, None], 1.0, 0.0)
                if self.vsa_model_name == "bsc":
                    weighted_hvs = train_hvs.astype(jnp.float32) * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed > (num_samples / 2.0)
                else:
                    weighted_hvs = train_hvs * weights
                    summed = jnp.sum(weighted_hvs, axis=0)
                    centroid = summed / (jnp.linalg.norm(summed) + EPS)
                classifier = classifier.replace(
                    prototypes=classifier.prototypes.at[class_idx].set(centroid)
                )
        # Phase 2: per-sample LVQ refinement on misclassified samples.
        for _epoch in range(epochs):
            for i in range(len(train_hvs)):
                pred = classifier.predict(train_hvs[i])
                true_label = train_labels[i]
                if pred != true_label:
                    classifier = classifier._update_prototypes(
                        train_hvs[i], true_label, pred, learning_rate
                    )
        return classifier

    def _update_prototypes(
        self,
        sample_hv: jax.Array,
        true_label: Union[int, jax.Array],
        pred_label: Union[int, jax.Array],
        learning_rate: float,
    ) -> "AdaptiveHDC":
        """One-sided LVQ update: attract the true prototype toward the sample.

        Also increments ``num_updates`` for the true class, so the number of
        refinement steps per class can be inspected after training.
        """
        true_proto = self.prototypes[true_label]
        if self.vsa_model_name != "bsc":
            new_true_proto = true_proto + learning_rate * sample_hv
            new_true_proto = new_true_proto / (jnp.linalg.norm(new_true_proto) + EPS)
        else:
            new_true_proto = F.bundle_bsc(jnp.stack([true_proto, sample_hv]), axis=0)
        return self.replace(
            prototypes=self.prototypes.at[true_label].set(new_true_proto),
            num_updates=self.num_updates.at[true_label].add(1),
        )

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy."""
        predictions = self.predict(test_hvs)
        return jnp.mean(predictions == test_labels)

    def replace(self, **updates: Any) -> "AdaptiveHDC":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class LVQClassifier:
    """Learning Vector Quantization classifier.

    Prototypes are updated: move winner toward sample if correct,
    away if wrong.
    """

    prototypes: jax.Array  # (num_classes, dimensions)
    num_classes: int = field(metadata=dict(static=True))
    dimensions: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_classes: int,
        dimensions: int = 10000,
        vsa_model: Union[str, VSAModel] = "map",
        key: Optional[jax.Array] = None,
    ) -> "LVQClassifier":
        """Create an LVQ classifier with random prototypes.

        Args:
            num_classes: Number of classes
            dimensions: Dimensionality of hypervectors
            vsa_model: VSA model name or instance
            key: JAX random key for initialization (defaults to ``PRNGKey(0)``)
        """
        if isinstance(vsa_model, str):
            vsa = create_vsa_model(vsa_model, dimensions)
        else:
            vsa = vsa_model
        if key is None:
            key = jax.random.PRNGKey(0)
        return LVQClassifier(
            prototypes=vsa.random(key, (num_classes, dimensions)),
            num_classes=num_classes,
            dimensions=dimensions,
            vsa_model_name=vsa.name,
        )

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class labels by nearest prototype."""
        is_single = queries.ndim == 1
        if is_single:
            queries = queries[None, :]
        if self.vsa_model_name == "bsc":
            sims = jax.vmap(
                lambda q: jax.vmap(lambda p: F.hamming_similarity(q, p))(self.prototypes)
            )(queries)
        else:
            sims = jax.vmap(
                lambda q: jax.vmap(lambda p: F.cosine_similarity(q, p))(self.prototypes)
            )(queries)
        preds = jnp.argmax(sims, axis=-1)
        return preds[0] if is_single else preds

    def fit(
        self,
        train_hvs: jax.Array,
        train_labels: jax.Array,
        epochs: int = 10,
        lr: float = 0.1,
    ) -> "LVQClassifier":
        """Train with LVQ updates (winner-take-all, move toward/away).

        Args:
            train_hvs: Training hypervectors of shape ``(n, d)``.
            train_labels: Training labels of shape ``(n,)``.
            epochs: Number of passes over the training set.
            lr: LVQ step size.

        Raises:
            ValueError: If the training set is empty.
        """
        if train_hvs.shape[0] == 0:
            raise ValueError("Cannot fit LVQClassifier: training data is empty")
        clf = self
        for _ in range(epochs):
            for i in range(len(train_hvs)):
                x, y_true = train_hvs[i], int(train_labels[i])
                pred = int(clf.predict(x))
                winner = clf.prototypes[pred]
                if clf.vsa_model_name == "bsc":
                    # BSC prototypes are binary and JAX rejects boolean
                    # subtraction, so compute the LVQ step in float space,
                    # threshold back to bits, then bundle with the winner.
                    winner_f = winner.astype(jnp.float32)
                    step = lr * (x.astype(jnp.float32) - winner_f)
                    moved = winner_f + (step if pred == y_true else -step)
                    new_p = F.bundle_bsc(
                        jnp.stack([winner, moved > 0.5]),
                        axis=0,
                    )
                else:
                    step = lr * (x - winner)
                    new_p = winner + (step if pred == y_true else -step)
                    new_p = new_p / (jnp.linalg.norm(new_p) + EPS)
                clf = clf.replace(prototypes=clf.prototypes.at[pred].set(new_p))
        return clf

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy on test data."""
        preds = self.predict(test_hvs)
        return jnp.mean(preds == test_labels)

    def replace(self, **updates: Any) -> "LVQClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class RegularizedLSClassifier:
    r"""Regularized Least Squares classifier in hypervector space.

    Solves the ridge-regression objective
    :math:`\min_W \|XW - Y\|_F^2 + \lambda \|W\|_F^2`
    in closed form. Automatically selects primal or dual form based on
    :math:`n` vs. :math:`d`:

    - **Primal** (when :math:`n \geq d`):
      :math:`W = (X^\top X + \lambda I_d)^{-1} X^\top Y`.
      Conditioning on the :math:`d \times d` feature-covariance matrix.
    - **Dual** (when :math:`n < d`, i.e. the typical HDC regime with
      high-dim hypervectors and modest training sets):
      :math:`W = X^\top (X X^\top + \lambda I_n)^{-1} Y`.
      Conditioning on the :math:`n \times n` Gram matrix — numerically
      well-behaved when :math:`d \gg n` and avoids the rank deficiency
      that kills the primal form on small datasets.

    The two forms are mathematically equivalent when both are well
    posed; only the conditioning differs.
    """

    weights: jax.Array  # (dimensions, num_classes)
    dimensions: int = field(metadata=dict(static=True))
    num_classes: int = field(metadata=dict(static=True))
    reg: float = field(metadata=dict(static=True))

    @staticmethod
    def create(
        dimensions: int,
        num_classes: int,
        reg: float = 1.0,
    ) -> "RegularizedLSClassifier":
        """Create an untrained classifier with all-zero weights."""
        return RegularizedLSClassifier(
            weights=jnp.zeros((dimensions, num_classes)),
            dimensions=dimensions,
            num_classes=num_classes,
            reg=reg,
        )

    def fit(self, train_hvs: jax.Array, train_labels: jax.Array) -> "RegularizedLSClassifier":
        """Fit by solving regularised least squares.

        Uses whichever of the primal (d×d) or dual (n×n) formulation
        conditions better given the training-set size vs dimensionality.

        Raises:
            ValueError: If the training set is empty.
        """
        n_samples = train_hvs.shape[0]
        if n_samples == 0:
            raise ValueError("Cannot fit RegularizedLSClassifier: training data is empty")
        targets = jax.nn.one_hot(train_labels, self.num_classes)
        if n_samples >= self.dimensions:
            # Primal form: solve the (d × d) normal equations.
            covariance = train_hvs.T @ train_hvs + self.reg * jnp.eye(self.dimensions)
            cross = train_hvs.T @ targets
            solution = jnp.linalg.solve(covariance, cross)
        else:
            # Dual form: solve the (n × n) Gram system, far better
            # conditioned when d >> n.
            gram = train_hvs @ train_hvs.T + self.reg * jnp.eye(n_samples)
            dual_coef = jnp.linalg.solve(gram, targets)  # (n, num_classes)
            solution = train_hvs.T @ dual_coef  # (d, num_classes)
        return self.replace(weights=solution)

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict class indices as the argmax of the linear logits."""
        return jnp.argmax(queries @ self.weights, axis=-1)

    @jax.jit
    def score(self, test_hvs: jax.Array, test_labels: jax.Array) -> jax.Array:
        """Compute accuracy on test data."""
        return jnp.mean(self.predict(test_hvs) == test_labels)

    def replace(self, **updates: Any) -> "RegularizedLSClassifier":
        """Return a copy of this classifier with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class HDRegressor:
    r"""Continuous-output ridge regression on hypervector features.

    Where :class:`RegularizedLSClassifier` solves the ridge problem
    against one-hot class targets, :class:`HDRegressor` solves it
    against an arbitrary continuous target matrix
    :math:`Y \in \mathbb{R}^{n \times k}`. This is the natural HDC
    primitive for predicting continuous quantities — joint angles in
    a robot policy, position deltas in a tracking head, regression
    targets in a calibration table.

    The closed-form solution is the same as the classifier's, applied
    to a real-valued :math:`Y`:

    - **Primal** (:math:`n \geq d`):
      :math:`W = (X^\top X + \lambda I_d)^{-1} X^\top Y`.
    - **Dual** (:math:`n < d`):
      :math:`W = X^\top (X X^\top + \lambda I_n)^{-1} Y`.

    For a downstream split-conformal regression pipeline, pair this
    with :class:`~bayes_hdc.uncertainty.ConformalRegressor` to wrap
    the point predictions in finite-sample-coverage intervals.

    Attributes:
        weights: Coefficient matrix of shape ``(dimensions, output_dim)``.
        dimensions: Hypervector dimensionality :math:`d`.
        output_dim: Output dimensionality :math:`k`.
        reg: Tikhonov regularisation strength :math:`\lambda`.
    """

    weights: jax.Array  # (dimensions, output_dim)
    dimensions: int = field(metadata=dict(static=True))
    output_dim: int = field(metadata=dict(static=True))
    reg: float = field(metadata=dict(static=True))

    @staticmethod
    def create(
        dimensions: int,
        output_dim: int,
        reg: float = 1.0,
    ) -> "HDRegressor":
        """Create an untrained regressor with all-zero weights."""
        return HDRegressor(
            weights=jnp.zeros((dimensions, output_dim)),
            dimensions=dimensions,
            output_dim=output_dim,
            reg=reg,
        )

    def fit(self, train_hvs: jax.Array, train_targets: jax.Array) -> "HDRegressor":
        """Fit by solving regularised least squares on continuous targets.

        Args:
            train_hvs: Training hypervectors of shape ``(n, d)``.
            train_targets: Training targets of shape ``(n, k)`` or
                ``(n,)`` (rank-1 reshaped automatically to ``(n, 1)``).

        Returns:
            A fitted ``HDRegressor`` (immutable update; the original is
            unchanged).

        Raises:
            ValueError: If the training set is empty, if ``train_hvs`` and
                ``train_targets`` disagree on ``n``, or if the target width
                does not match ``output_dim``.
        """
        n = train_hvs.shape[0]
        if n == 0:
            raise ValueError("Cannot fit HDRegressor: training data is empty")
        Y = train_targets
        if Y.ndim == 1:
            Y = Y[:, None]
        if Y.shape[0] != n:
            raise ValueError(
                f"train_hvs and train_targets disagree on n: {train_hvs.shape[0]} vs {Y.shape[0]}"
            )
        if Y.shape[1] != self.output_dim:
            raise ValueError(
                f"train_targets has output_dim={Y.shape[1]} but the regressor "
                f"was created with output_dim={self.output_dim}"
            )
        if n >= self.dimensions:
            # Primal form: (d × d) system.
            XtX = train_hvs.T @ train_hvs + self.reg * jnp.eye(self.dimensions)
            XtY = train_hvs.T @ Y
            weights = jnp.linalg.solve(XtX, XtY)
        else:
            # Dual form: (n × n) system, far better conditioned when d >> n.
            K = train_hvs @ train_hvs.T + self.reg * jnp.eye(n)
            alpha = jnp.linalg.solve(K, Y)  # (n, k)
            weights = train_hvs.T @ alpha  # (d, k)
        return self.replace(weights=weights)

    @jax.jit
    def predict(self, queries: jax.Array) -> jax.Array:
        """Predict continuous outputs.

        Args:
            queries: Hypervectors of shape ``(d,)`` or ``(n, d)``.

        Returns:
            Predictions of shape ``(k,)`` for a single vector or ``(n, k)``
            for a batch. The trailing output axis is always kept: with
            ``output_dim=1`` a single vector yields shape ``(1,)`` and a
            batch yields ``(n, 1)``.
        """
        # (d,) @ (d, k) -> (k,); (n, d) @ (d, k) -> (n, k) — one expression
        # covers both ranks.
        return queries @ self.weights

    @jax.jit
    def score(self, test_hvs: jax.Array, test_targets: jax.Array) -> jax.Array:
        """Coefficient of determination R² on a test set.

        Returns the multi-output R² computed as
        :math:`1 - \\sum (y - \\hat y)^2 / \\sum (y - \\bar y)^2`,
        where the variances are summed across all output dimensions.
        Negative values indicate a fit worse than predicting the mean.
        """
        Y = test_targets
        if Y.ndim == 1:
            Y = Y[:, None]
        preds = self.predict(test_hvs)
        ss_res = jnp.sum((Y - preds) ** 2)
        ss_tot = jnp.sum((Y - jnp.mean(Y, axis=0, keepdims=True)) ** 2)
        return 1.0 - ss_res / (ss_tot + EPS)

    def replace(self, **updates: Any) -> "HDRegressor":
        """Return a copy of this regressor with the given fields replaced."""
        return dataclass_replace(self, **updates)
@register_dataclass
@dataclass
class ClusteringModel:
    """HDC-style k-means clustering.

    Encodes data into hypervectors, then iteratively assigns clusters
    by cosine similarity and updates centroids by bundling.
    Inspired by the ClusteringModel in hdlib (Cumbo et al., 2023).
    """

    centroids: jax.Array  # (num_clusters, dimensions), unit-normalised rows
    dimensions: int = field(metadata=dict(static=True))
    num_clusters: int = field(metadata=dict(static=True))
    vsa_model_name: str = field(metadata=dict(static=True), default="map")

    @staticmethod
    def create(
        num_clusters: int,
        dimensions: int = 10000,
        vsa_model: Union[str, "VSAModel"] = "map",
        key: Optional[jax.Array] = None,
    ) -> "ClusteringModel":
        """Create a clustering model with random unit-norm centroids."""
        if key is None:
            key = jax.random.PRNGKey(0)
        model_name = vsa_model if isinstance(vsa_model, str) else vsa_model.name
        raw = jax.random.normal(key, (num_clusters, dimensions))
        # Normalise each centroid so the similarity matmul is cosine-like.
        initial = raw / (jnp.linalg.norm(raw, axis=-1, keepdims=True) + EPS)
        return ClusteringModel(
            centroids=initial,
            dimensions=dimensions,
            num_clusters=num_clusters,
            vsa_model_name=model_name,
        )

    def fit(
        self,
        hvs: jax.Array,
        max_iters: int = 50,
    ) -> "ClusteringModel":
        """Fit clusters by iterating assignment and centroid update.

        Args:
            hvs: Hypervectors of shape (n, d)
            max_iters: Maximum iterations (default: 50)

        Returns:
            Updated ClusteringModel with refined centroids
        """
        current = self.centroids
        for _ in range(max_iters):
            # Assignment step: nearest centroid by dot-product similarity.
            labels = jnp.argmax(hvs @ current.T, axis=-1)
            updated = []
            for cluster_id in range(self.num_clusters):
                member_mask = labels == cluster_id
                member_count = jnp.sum(member_mask)
                member_sum = jnp.sum(hvs * member_mask[:, None], axis=0)
                # Empty clusters keep their previous centroid.
                mean_or_keep = jnp.where(
                    member_count > 0,
                    member_sum / (member_count + EPS),
                    current[cluster_id],
                )
                updated.append(mean_or_keep / (jnp.linalg.norm(mean_or_keep) + EPS))
            candidate: jax.Array = jnp.stack(updated)
            # Converged: centroids stopped moving.
            if jnp.allclose(candidate, current, atol=1e-6):
                break
            current = candidate
        return dataclass_replace(self, centroids=current)

    @jax.jit
    def predict(self, hvs: jax.Array) -> jax.Array:
        """Assign each hypervector to the closest centroid.

        Args:
            hvs: Hypervectors of shape (n, d) or (d,)

        Returns:
            Cluster assignments (scalar for single query, array for batch)
        """
        batched = hvs if hvs.ndim > 1 else hvs[None, :]
        assignments = jnp.argmax(batched @ self.centroids.T, axis=-1)
        return assignments if hvs.ndim > 1 else assignments[0]

    def replace(self, **updates: Any) -> "ClusteringModel":
        """Return a copy of this model with the given fields replaced."""
        return dataclass_replace(self, **updates)
# Public API of this module. HDRegressor was previously missing even though
# it is a fully documented public class defined above.
__all__ = [
    "CentroidClassifier",
    "AdaptiveHDC",
    "LVQClassifier",
    "RegularizedLSClassifier",
    "HDRegressor",
    "ClusteringModel",
]