# SPDX-License-Identifier: MIT
# Copyright (c) 2026 Rajdeep Singh
"""Memory modules for Hyperdimensional Computing.
Sparse Distributed Memory (SDM), Hopfield networks, and attention-based retrieval.
"""
from dataclasses import dataclass, field
from dataclasses import replace as dc_replace
from typing import Optional
import jax
import jax.numpy as jnp
from bayes_hdc import functional as F
from bayes_hdc._compat import register_dataclass
@register_dataclass
@dataclass
class SparseDistributedMemory:
    """Sparse Distributed Memory (SDM) for content-addressable storage.

    Kanerva-style SDM: addresses are matched against a fixed set of random
    unit-norm hard locations. A write adds ``value`` into every location
    whose cosine similarity to the address is at least ``1 - radius``; a
    read sums the contents of the activated locations and re-normalizes.

    Attributes:
        locations: Hard-location address matrix of shape
            ``(num_locations, dimensions)``; rows are unit-normalized.
        contents: Accumulated stored data, same shape as ``locations``.
        dimensions: Hypervector dimensionality (static pytree metadata).
        radius: Activation radius in cosine-similarity units; a location
            activates when ``cos_sim >= 1 - radius`` (static pytree metadata).
    """

    locations: jax.Array  # (num_locations, dimensions), unit-norm rows
    contents: jax.Array  # (num_locations, dimensions), additive store
    dimensions: int = field(metadata=dict(static=True))
    radius: float = field(metadata=dict(static=True))

    @staticmethod
    def create(
        num_locations: int,
        dimensions: int,
        radius: float = 0.0,
        key: Optional[jax.Array] = None,
    ) -> "SparseDistributedMemory":
        """Build an empty SDM with random unit-norm hard locations.

        Args:
            num_locations: Number of hard locations.
            dimensions: Dimensionality of addresses and stored vectors.
            radius: Cosine-similarity activation radius; the default 0.0
                means a location activates only on a near-exact match.
            key: PRNG key for sampling locations; defaults to
                ``PRNGKey(0)`` for reproducibility when omitted.

        Returns:
            A ``SparseDistributedMemory`` with zeroed contents.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        locs = jax.random.normal(key, (num_locations, dimensions))
        # Normalize rows so cosine similarity against them is well-behaved;
        # the epsilon guards against a (vanishingly unlikely) zero-norm draw.
        locs = locs / (jnp.linalg.norm(locs, axis=-1, keepdims=True) + 1e-8)
        return SparseDistributedMemory(
            locations=locs,
            contents=jnp.zeros((num_locations, dimensions)),
            dimensions=dimensions,
            radius=radius,
        )

    def write(self, address: jax.Array, value: jax.Array) -> "SparseDistributedMemory":
        """Return a new memory with ``value`` added into activated locations.

        Purely functional: ``self`` is left unchanged. Locations whose
        cosine similarity to ``address`` falls below ``1 - radius``
        receive a zero delta.
        """
        sims = jax.vmap(lambda loc: F.cosine_similarity(address, loc))(self.locations)
        mask = sims >= (1.0 - self.radius)
        # Broadcast the boolean activation mask over the feature axis.
        delta = mask[:, None].astype(jnp.float32) * value
        # dc_replace keeps static metadata intact, consistent with the
        # other memory classes in this module.
        return dc_replace(self, contents=self.contents + delta)

    @jax.jit
    def read(self, address: jax.Array) -> jax.Array:
        """Read the unit-normalized sum of contents at activated locations.

        Returns:
            A ``(dimensions,)`` vector. The epsilon in the denominator
            avoids division by zero when no location activates (result is
            then approximately the zero vector).
        """
        sims = jax.vmap(lambda loc: F.cosine_similarity(address, loc))(self.locations)
        mask = sims >= (1.0 - self.radius)
        summed = jnp.sum(self.contents * mask[:, None], axis=0)
        return summed / (jnp.linalg.norm(summed) + 1e-8)
@register_dataclass
@dataclass
class HopfieldMemory:
    """Modern continuous Hopfield network (Ramsauer et al. 2020).
    One-step softmax-attention retrieval over stored patterns. Distinct
    from the classical sign-thresholded recurrent Hopfield network
    (Hopfield 1982) — which settles via repeated application of a sign
    update — and from the spiking-neuron cleanup memories of Stewart,
    Tang & Eliasmith (2010), which run on populations of leaky-
    integrate-and-fire neurons via the Neural Engineering Framework.
    Retrieval here is a single feed-forward softmax over cosine
    similarities to the stored patterns; no recurrent settling.
    References:
        Ramsauer, H. et al. (2020). Hopfield Networks is All You Need.
            arXiv:2008.02217.
        Hopfield, J. J. (1982). Neural networks and physical systems with
            emergent collective computational abilities. PNAS 79(8): 2554-2558.
        Stewart, T. C., Tang, Y., Eliasmith, C. (2010). A Biologically
            Realistic Cleanup Memory: Autoassociation in Spiking Neurons.
            Cognitive Systems Research 12: 84-92.
    """

    patterns: jax.Array  # (num_patterns, dimensions), unit-norm rows
    dimensions: int = field(metadata=dict(static=True))
    beta: float = field(metadata=dict(static=True), default=1.0)

    @staticmethod
    def create(
        dimensions: int,
        beta: float = 1.0,
    ) -> "HopfieldMemory":
        """Create an empty memory.

        Args:
            dimensions: Dimensionality of stored patterns.
            beta: Inverse temperature of the retrieval softmax; larger
                values make retrieval sharper (closer to nearest-pattern).
        """
        return HopfieldMemory(
            patterns=jnp.zeros((0, dimensions)),
            dimensions=dimensions,
            beta=beta,
        )

    def add(self, pattern: jax.Array) -> "HopfieldMemory":
        """Return a new memory with unit-normalized ``pattern`` appended.

        Purely functional: ``self`` is left unchanged.
        """
        # Normalization is invariant to the reshape, so using the raw
        # pattern's norm is equivalent; epsilon guards a zero vector.
        p = pattern.reshape(-1) / (jnp.linalg.norm(pattern) + 1e-8)
        new_patterns = jnp.concatenate([self.patterns, p[None, :]], axis=0)
        return dc_replace(self, patterns=new_patterns)

    @jax.jit
    def retrieve(self, query: jax.Array) -> jax.Array:
        """One-step softmax-attention read-out for ``query``.

        Returns:
            The ``beta``-softmax-weighted combination of stored patterns,
            or a zero vector when the memory is empty.
        """
        q = query.reshape(-1)
        # Shape tests are static under jit, so this branch is resolved
        # at trace time rather than at runtime.
        if self.patterns.shape[0] == 0:
            return jnp.zeros_like(q)
        sims = jax.vmap(lambda p: F.cosine_similarity(q, p))(self.patterns)
        weights = jax.nn.softmax(self.beta * sims)
        return jnp.sum(self.patterns * weights[:, None], axis=0)
@register_dataclass
@dataclass
class AttentionMemory:
    """Attention-based retrieval with key-value storage and multi-head support.

    Keys and values are stored as growing ``(n, dimensions)`` matrices;
    retrieval is scaled dot-product attention over the stored keys.

    Attributes:
        keys: ``(n, dimensions)`` stored attention keys.
        values: ``(n, dimensions)`` stored attention values.
        dimensions: Key/value dimensionality (static pytree metadata).
        temperature: Extra multiplicative factor on the sqrt(d) score
            scale; >1 softens attention, <1 sharpens it (static).
        num_heads: Number of attention heads; must divide ``dimensions``
            when greater than 1 (static pytree metadata).
    """

    keys: jax.Array
    values: jax.Array
    dimensions: int = field(metadata=dict(static=True))
    temperature: float = field(metadata=dict(static=True), default=1.0)
    num_heads: int = field(metadata=dict(static=True), default=1)

    @staticmethod
    def create(
        dimensions: int,
        temperature: float = 1.0,
        num_heads: int = 1,
    ) -> "AttentionMemory":
        """Create an empty attention memory.

        Raises:
            ValueError: If ``num_heads > 1`` and ``dimensions`` is not
                divisible by ``num_heads`` (heads could not be reshaped).
        """
        if num_heads > 1 and dimensions % num_heads != 0:
            raise ValueError(
                f"dimensions ({dimensions}) must be divisible by num_heads ({num_heads})"
            )
        return AttentionMemory(
            keys=jnp.zeros((0, dimensions)),
            values=jnp.zeros((0, dimensions)),
            dimensions=dimensions,
            temperature=temperature,
            num_heads=num_heads,
        )

    def write(self, key: jax.Array, value: jax.Array) -> "AttentionMemory":
        """Return a new memory with one (key, value) pair appended.

        Purely functional: ``self`` is left unchanged.
        """
        k = key.reshape(1, -1)
        v = value.reshape(1, -1)
        return dc_replace(
            self,
            keys=jnp.concatenate([self.keys, k], axis=0),
            values=jnp.concatenate([self.values, v], axis=0),
        )

    def write_batch(self, keys: jax.Array, values: jax.Array) -> "AttentionMemory":
        """Return a new memory with ``(m, dimensions)`` batches appended.

        ``keys`` and ``values`` are expected to already be 2-D with
        matching leading size — no reshape is applied here, unlike
        :meth:`write`.
        """
        return dc_replace(
            self,
            keys=jnp.concatenate([self.keys, keys], axis=0),
            values=jnp.concatenate([self.values, values], axis=0),
        )

    @jax.jit
    def retrieve(self, query: jax.Array) -> jax.Array:
        """Scaled dot-product attention read-out for ``query``.

        Uses single-head attention when ``num_heads == 1``; otherwise
        splits query/keys/values into heads, attends per head, and
        concatenates the per-head results.

        Returns:
            A ``(dimensions,)`` vector; zeros when the memory is empty.
        """
        q = query.reshape(-1)
        # Static shape test: resolved at trace time under jit.
        if self.keys.shape[0] == 0:
            return jnp.zeros(self.dimensions)
        if self.num_heads == 1:
            scale = (self.dimensions**0.5) * self.temperature
            scores = self.keys @ q / scale
            weights = jax.nn.softmax(scores)
            return jnp.sum(self.values * weights[:, None], axis=0)
        else:
            head_dim = self.dimensions // self.num_heads
            q_heads = q.reshape(self.num_heads, head_dim)
            k_heads = self.keys.reshape(-1, self.num_heads, head_dim)
            v_heads = self.values.reshape(-1, self.num_heads, head_dim)
            scale = (head_dim**0.5) * self.temperature
            # scores[h, n] = <q_head h, key n head h> / scale
            scores = jnp.einsum("hd,nhd->hn", q_heads, k_heads) / scale
            weights = jax.nn.softmax(scores, axis=-1)
            result = jnp.einsum("hn,nhd->hd", weights, v_heads)
            return result.reshape(-1)

    @jax.jit
    def retrieve_with_weights(self, query: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Single-head attention read-out plus the attention weights.

        NOTE(review): this path always uses single-head attention, so for
        ``num_heads > 1`` its result differs from :meth:`retrieve` — the
        asymmetry looks unintended; confirm against callers before relying
        on it with multi-head memories.

        Returns:
            ``(result, weights)`` where ``weights`` has shape ``(n,)``;
            ``(zeros, empty array)`` when the memory is empty.
        """
        q = query.reshape(-1)
        if self.keys.shape[0] == 0:
            return jnp.zeros(self.dimensions), jnp.array([])
        scale = (self.dimensions**0.5) * self.temperature
        scores = self.keys @ q / scale
        weights = jax.nn.softmax(scores)
        result = jnp.sum(self.values * weights[:, None], axis=0)
        return result, weights
# Public API of this module; mirrors the three memory classes defined above.
__all__ = ["SparseDistributedMemory", "HopfieldMemory", "AttentionMemory"]