diff --git a/README.md b/README.md
index 757b9862..760a4933 100644
--- a/README.md
+++ b/README.md
@@ -119,7 +119,7 @@ $ python install -e .
**Version:**
-3.1.0
+3.1.1
Author:
Alexander G. Ororbia II
diff --git a/history.txt b/history.txt
index 4345a232..a8c1e724 100644
--- a/history.txt
+++ b/history.txt
@@ -99,4 +99,6 @@ History
* new suite of competitive learning synapses (generalized formats), including vector-quantization, self-organizing map, adaptive resonance theory (contus-version), and modern hopfield network
* revisions to metric and model utils
* some additional clean-up, including supported retinal ganglion encoder
+ * fixed pkg-resource header (in both ngcsimlib and ngclearn) to account for newer python(s)
+
diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py
index 6121286c..5670cfca 100644
--- a/ngclearn/__init__.py
+++ b/ngclearn/__init__.py
@@ -31,7 +31,7 @@
"currently installed!")
##################################################################################
-## Needed to preload is called before anything in ngclearn
+## The following are needed so that preload is called before anything else in ngclearn
from pathlib import Path
from sys import argv
import numpy
diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py
index 0b81afdc..94a2c116 100644
--- a/ngclearn/components/__init__.py
+++ b/ngclearn/components/__init__.py
@@ -30,6 +30,9 @@
from .input_encoders.ganglionCell import RetinalGanglionCell
from .input_encoders.latencyCell import LatencyCell
from .input_encoders.phasorCell import PhasorCell
+#from .input_encoders.populationCoderCell import PopulationCoderCell
+#from .input_encoders.gridCell import GridCell
+#from .input_encoders.placeCell import PlaceCell
## point to synapse component types
from .synapses.denseSynapse import DenseSynapse
diff --git a/ngclearn/components/input_encoders/__init__.py b/ngclearn/components/input_encoders/__init__.py
index 9af3241b..bbee5280 100644
--- a/ngclearn/components/input_encoders/__init__.py
+++ b/ngclearn/components/input_encoders/__init__.py
@@ -1,7 +1,9 @@
from .bernoulliCell import BernoulliCell
from .poissonCell import PoissonCell
from .latencyCell import LatencyCell
-from .ganglionCell import RetinalGanglionCell
from .phasorCell import PhasorCell
-
+from .ganglionCell import RetinalGanglionCell
+#from .populationCoderCell import PopulationCoderCell
+#from .gridCell import GridCell
+#from .placeCell import PlaceCell
diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py
index 132f7d92..68e87539 100755
--- a/ngclearn/components/input_encoders/bernoulliCell.py
+++ b/ngclearn/components/input_encoders/bernoulliCell.py
@@ -27,7 +27,12 @@ class BernoulliCell(JaxComponent):
"""
def __init__(
- self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs
+ self,
+ name: str,
+ n_units: int,
+ batch_size: int = 1,
+ key: Union[jax.Array, None] = None,
+ **kwargs
):
super().__init__(name=name, key=key)
diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py
index a35c7b0a..b7e58c05 100644
--- a/ngclearn/components/input_encoders/ganglionCell.py
+++ b/ngclearn/components/input_encoders/ganglionCell.py
@@ -8,26 +8,28 @@
def _create_gaussian_filter(patch_shape, sigma):
## Create a 2D Gaussian kernel centered on patch_shape with given sigma.
px, py = patch_shape
-
x_ = jnp.linspace(0, px - 1, px)
y_ = jnp.linspace(0, py - 1, py)
-
x, y = jnp.meshgrid(x_, y_)
-
xc = px // 2
yc = py // 2
-
- filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
- return filter / jnp.sum(filter)
+ _filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
+ return _filter / jnp.sum(_filter)
def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
-
dog = g1 - lmbda * g2
-
return dog #- jnp.mean(dog)
+
+def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6):
+ g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
+ g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
+ rog = g1 / (g2 + 1e-8)
+ return rog
+
+
def _create_patches(obs, patch_shape, step_shape):
"""
Extract 2D patches from a batch of images using a sliding window.
@@ -67,10 +69,36 @@ def _create_patches(obs, patch_shape, step_shape):
return patches
+def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape):
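+ # patches: (B, nx * ny, px, py); overlapping patches are summed, then averaged into (B, ix, iy)
+ # NOTE(review): pixels covered by no patch keep a count of 0, so the final division yields NaN there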
+ B = len(patches)
+ nx, ny = nx_ny
+ ix, iy = area_shape
+ px, py = patch_shape
+ sx, sy = step_shape
+ x = jnp.zeros((B, ix, iy))
+ counts = jnp.zeros((ix, iy))
+
+ idx = 0
+ for i in range(ny):
+ for j in range(nx):
+ di = i * sx
+ dj = j * sy
+ x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx])
+ counts = counts.at[di:di + px, dj:dj + py].add(1.0)
+ idx += 1
+
+ return x / counts[None, :, :]
+
+
class RetinalGanglionCell(JaxComponent):
"""
- A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain.
+ A group of retinal ganglion cells that sense input stimuli and send out filtered
+ signals (as output). Note that these simulated cells employ internal generalized
+ filters (based on Gaussian, difference-of-Gaussian, or ratio-of-Gaussian kernels)
+ to recover historical receptive field processing effects.
| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
@@ -85,32 +113,34 @@ class RetinalGanglionCell(JaxComponent):
filter_type: string name of filter function (Default: identity)
:Note: supported filters include "gaussian", "difference_of_gaussian"
- sigma: standard deviation of gaussian kernel
+ sigma: standard deviation of (gaussian) kernel
- area_shape: receptive field area of ganglion cells in this module all together
+ area_shape: shape of receptive field area of ganglion cells in this module (all together)
n_cells: number of ganglion cells in this module
- patch_shape: each ganglion cell receptive field area
+ patch_shape: shape of each ganglion cell's receptive field area
- step_shape: the non-overlapping area between each two ganglion cells
+ step_shape: the non-overlapping area between each pair of ganglion cells
- batch_size: batch size dimension of this cell (Default: 1)
+ batch_size: batch size dimension of this cell/module (Default: 1)
"""
- def __init__(self, name: str,
- filter_type: str,
- area_shape: Tuple[int, int],
- n_cells: int,
- patch_shape: Tuple[int, int],
- step_shape: Tuple[int, int],
- batch_size: int = 1,
- sigma: float = 1.0,
- key: Union[jax.Array, None] = None,
- **kwargs):
+ def __init__(
+ self,
+ name: str,
+ filter_type: str,
+ area_shape: Tuple[int, int],
+ n_cells: int,
+ patch_shape: Tuple[int, int],
+ step_shape: Tuple[int, int],
+ batch_size: int = 1,
+ sigma: float = 1.0,
+ key: Union[jax.Array, None] = None,
+ **kwargs
+ ):
super().__init__(name=name, key=key)
-
## Layer Size Setup
self.filter_type = filter_type
self.n_cells = n_cells
@@ -121,36 +151,43 @@ def __init__(self, name: str,
self.patch_shape = patch_shape
self.step_shape = step_shape
- filter = jnp.ones(self.patch_shape)
+ _filter = jnp.ones(self.patch_shape)
- if filter_type == 'gaussian':
- filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
- elif filter_type == 'difference_of_gaussian':
- filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
+ if self.filter_type == 'gaussian':
+ print("filter type is ", self.filter_type)
+ _filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
+
+ elif self.filter_type in ["difference_of_gaussian", "DoG"]:
+ print("filter type is difference of gaussian: f(x) = p1 - p2")
+ _filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
+
+ elif self.filter_type in ["ratio_of_gaussian", "RoG"]:
+ print("filter type is ratio of gaussian: f(x) = p1 / p2")
+ _filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma)
# ═════════════════ compartments initial values ════════════════════
- in_restVals = jnp.zeros((batch_size,
- *self.area_shape)) ## input: (B | ix | iy)
+ in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
- out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
- self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
+ out_restVals = jnp.zeros(
+ (batch_size, self.n_cells * self.patch_shape[0] * self.patch_shape[1])
+ ) ## output.shape: (B | n_cells * px * py)
# ═══════════════════ set compartments ══════════════════════
self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment
- self.filter = Compartment(filter, display_name="Filter") # Filter compartment
+ self.filter = Compartment(_filter, display_name="Filter") # Filter compartment
self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment
@compilable
def advance_state(self, t):
inputs = self.inputs.get()
- filter = self.filter.get()
+ _filter = self.filter.get()
px, py = self.patch_shape
# ═══════════════════ extract pathches for filters ══════════════════
input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)
# ═══════════════════ apply filter to all pathches ══════════════════
- filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
+ filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py)
# ════════════ reshape all cells responses to a single input to brain ════════════
filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py)
@@ -160,31 +197,20 @@ def advance_state(self, t):
self.outputs.set(outputs)
+
+
@compilable
def reset(self): ## reset core components/statistics
- # self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy)
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
self.inputs.set(in_restVals)
self.outputs.set(out_restVals)
- # Viet: NOTE: we should not need this function since the reset function
- # one could set the batch size then do reset
- # @compilable
- # def batched_reset(self, batch_size):
- # in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)
-
- # out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
- # self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
-
- # self.inputs.set(in_restVals)
- # self.outputs.set(out_restVals)
-
@classmethod
def help(cls): ## component help function
properties = {
- "cell_type": "RetinalGanglionCell - filters the input stimuli, "
+ "cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics"
}
compartment_props = {
"inputs":
@@ -196,11 +222,11 @@ def help(cls): ## component help function
}
hyperparams = {
"filter_type": "Type of the filter for preprocessing the input",
- "sigma": "Standard deviation of gaussian kernel",
+ "sigma": "Standard deviation of gaussian kernel/filter",
"area_shape": "Effective receptive field area shape of ganglion cells in this module",
- "n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer",
- "patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module",
- "step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module",
+ "n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer",
+ "patch_shape": "Classical receptive field area shape of individual ganglion cells in this module",
+ "step_shape": "Extra-classical receptive field area shape each ganglion cell in this module",
"batch_size": "Batch size dimension of this component"
}
info = {cls.__name__: properties,
@@ -212,13 +238,15 @@ def help(cls): ## component help function
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
- X = RetinalGanglionCell("RGC", filter_type="gaussian",
- sigma=2.3,
- area_shape=(16, 26),
- n_cells = 3,
- patch_shape=(16, 16),
- step_shape=(0, 5)
- )
+ X = RetinalGanglionCell(
+ "RGC",
+ filter_type="gaussian",
+ sigma=2.3,
+ area_shape=(16, 26),
+ n_cells = 3,
+ patch_shape=(16, 16),
+ step_shape=(0, 5)
+ )
print(X)
diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py
index 810776ab..333479ae 100644
--- a/ngclearn/components/input_encoders/poissonCell.py
+++ b/ngclearn/components/input_encoders/poissonCell.py
@@ -32,8 +32,13 @@ class PoissonCell(JaxComponent):
@deprecate_args(max_freq="target_freq")
def __init__(
- self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
- key: Union[jax.Array, None] = None, **kwargs
+ self,
+ name: str,
+ n_units: int,
+ target_freq: float = 63.75,
+ batch_size: int = 1,
+ key: Union[jax.Array, None] = None,
+ **kwargs
):
super().__init__(name=name, key=key)
diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py
index d12bd33f..9e072de5 100755
--- a/ngclearn/components/neurons/graded/gaussianErrorCell.py
+++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py
@@ -1,175 +1,174 @@
-# %%
-
-from ngclearn.components.jaxComponent import JaxComponent
-from jax import numpy as jnp, jit
-from ngclearn import compilable #from ngcsimlib.parser import compilable
-from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
-
-class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
- """
- A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically,
- this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood).
-
- | --- Cell Input Compartments: ---
- | mu - predicted value (takes in external signals)
- | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2
- | target - desired/goal value (takes in external signals)
- | modulator - modulation signal (takes in optional external signals)
- | mask - binary/gating mask to apply to error neuron calculations
- | --- Cell Output Compartments: ---
- | L - local loss function embodied by this cell
- | dmu - derivative of L w.r.t. mu
- | dSigma - derivative of L w.r.t. Sigma
- | dtarget - derivative of L w.r.t. target
-
- Args:
- name: the string name of this cell
-
- n_units: number of cellular entities (neural population size)
-
- batch_size: batch size dimension of this cell (Default: 1)
-
- sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
- Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
- to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument
- (Default: 1)
- """
- def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
- super().__init__(name, **kwargs)
-
- ## Layer Size Setup
- _shape = (batch_size, n_units) ## default shape is 2D/matrix
- if shape is None:
- shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
- else:
- _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
- sigma_shape = (1,1)
- if not isinstance(sigma, float) and not isinstance(sigma, int):
- sigma_shape = jnp.array(sigma).shape
- self.sigma_shape = sigma_shape
- self.shape = shape
- self.batch_size = batch_size
- self.n_units = n_units
-
- ## Convolution shape setup
- self.width = self.height = n_units
-
- ## Compartment setup
- restVals = jnp.zeros(_shape)
- self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment
- self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire
- self.dmu = Compartment(restVals) # derivative mean
- _Sigma = jnp.zeros(sigma_shape)
- self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance")
- self.dSigma = Compartment(_Sigma)
- self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. input wire
- self.dtarget = Compartment(restVals) # derivative target
- self.modulator = Compartment(restVals + 1.0) # to be set/consumed
- self.mask = Compartment(restVals + 1.0)
-
- @staticmethod
- def _eval_log_density(target, mu, Sigma): ## Gaussian log likelihood
- ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix
- _dmu = (target - mu)
- #_numerator = 1. # 0.5
- log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma)
- return log_density, _dmu ## return density and raw delta
-
- @compilable
- def advance_state(self, dt): ## compute Gaussian error cell output (fixed-point)
- # Get the variables
- mu = self.mu.get()
- target = self.target.get()
- Sigma = self.Sigma.get()
- modulator = self.modulator.get()
- mask = self.mask.get()
-
- # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
- # behavior of the local cost functional:
- # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
- # but should support full log likelihood of the multivariate Gaussian with covariance of different types
- # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE
- # (using integration time constant dt)
-
- L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
- ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu)
- dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma
- dtarget = -dmu # reverse of e ## -(target - mu)/Sigma
- dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma
-
- dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
- dtarget = dtarget * modulator * mask
- mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
-
- # Update compartments
- self.dmu.set(dmu)
- self.dtarget.set(dtarget)
- self.dSigma.set(dSigma)
- self.L.set(jnp.squeeze(L))
- self.mask.set(mask)
-
- @compilable
- def reset(self): ## reset core components/statistics
- self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
-
- @compilable
- def batched_reset(self, batch_size):
- _shape = (batch_size, self.shape[0])
- if len(self.shape) > 1:
- _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
- restVals = jnp.zeros(_shape)
- dmu = restVals
- dtarget = restVals
- dSigma = jnp.zeros(self.sigma_shape)
- target = restVals
- mu = restVals
- modulator = mu + 1.
- L = 0. #jnp.zeros((1, 1))
- mask = jnp.ones(_shape)
-
- self.dmu.set(dmu)
- self.dtarget.set(dtarget)
- self.dSigma.set(dSigma)
- if not self.target.targeted:
- self.target.set(target)
- if not self.mu.targeted:
- self.mu.set(mu)
- self.modulator.set(modulator)
- self.L.set(L)
- self.mask.set(mask)
-
- @classmethod
- def help(cls): ## component help function
- properties = {
- "cell_type": "GaussianErrorCell - computes mismatch/error signals at "
- "each time step t (between a `target` and a prediction `mu`)"
- }
- compartment_props = {
- "inputs":
- {"mu": "External input prediction value(s)",
- "Sigma": "External variance/covariance prediction value(s)",
- "target": "External input target signal value(s)",
- "modulator": "External input modulatory/scaling signal(s)",
- "mask": "External binary/gating mask to apply to signals"},
- "outputs":
- {"L": "Local loss / free-energy value embodied by this error-cell",
- "dmu": "first derivative of loss w.r.t. prediction value(s)",
- "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)",
- "dtarget": "first derivative of loss w.r.t. target value(s)"},
- }
- hyperparams = {
- "n_units": "Number of neuronal cells to model in this layer",
- "batch_size": "Batch size dimension of this component",
- "sigma": "External input variance value (currently fixed and not learnable)"
- }
- info = {cls.__name__: properties,
- "compartments": compartment_props,
- "dynamics": "Gaussian(x=target; mu, sigma)",
- "hyperparameters": hyperparams}
- return info
-
-if __name__ == '__main__':
- from ngcsimlib.context import Context
- with Context("Bar") as bar:
- X = GaussianErrorCell("X", 9)
- print(X)
+# %%
+
+from ngclearn.components.jaxComponent import JaxComponent
+from jax import numpy as jnp, jit
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+
+class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
+ """
+ A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically,
+ this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood).
+
+ | --- Cell Input Compartments: ---
+ | mu - predicted value (takes in external signals)
+ | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2
+ | target - desired/goal value (takes in external signals)
+ | modulator - modulation signal (takes in optional external signals)
+ | mask - binary/gating mask to apply to error neuron calculations
+ | --- Cell Output Compartments: ---
+ | L - local loss function embodied by this cell
+ | dmu - derivative of L w.r.t. mu
+ | dSigma - derivative of L w.r.t. Sigma
+ | dtarget - derivative of L w.r.t. target
+
+ Args:
+ name: the string name of this cell
+
+ n_units: number of cellular entities (neural population size)
+
+ batch_size: batch size dimension of this cell (Default: 1)
+
+ sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution;
+ Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses
+ to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument
+ (Default: 1)
+ """
+ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
+ super().__init__(name, **kwargs)
+
+ ## Layer Size Setup
+ _shape = (batch_size, n_units) ## default shape is 2D/matrix
+ if shape is None:
+ shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
+ else:
+ _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
+ sigma_shape = (1,1)
+ if not isinstance(sigma, float) and not isinstance(sigma, int):
+ sigma_shape = jnp.array(sigma).shape
+ self.sigma_shape = sigma_shape
+ self.shape = shape
+ self.n_units = n_units
+ self.batch_size = batch_size
+
+ ## Convolution shape setup
+ self.width = self.height = n_units
+
+ ## Compartment setup
+ restVals = jnp.zeros(_shape)
+ self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment
+ self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire
+ self.dmu = Compartment(restVals) # derivative mean
+ _Sigma = jnp.zeros(sigma_shape)
+ self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance")
+ self.dSigma = Compartment(_Sigma)
+ self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. input wire
+ self.dtarget = Compartment(restVals) # derivative target
+ self.modulator = Compartment(restVals + 1.0) # to be set/consumed
+ self.mask = Compartment(restVals + 1.0)
+
+ @staticmethod
+ def eval_log_density(target, mu, Sigma):
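+ ## NOTE: intended ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix; NOTE(review): the expression below divides by 2 * Sigma^2 instead -- confirm which scaling is intended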
+ _dmu = (target - mu)
+ #_numerator = 1. # 0.5
+ log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma)
+ return log_density, _dmu ## return density and raw delta
+
+ @compilable
+ def advance_state(self, dt): ## compute Gaussian error cell output
+ # Get the variables
+ mu = self.mu.get()
+ target = self.target.get()
+ Sigma = self.Sigma.get()
+ modulator = self.modulator.get()
+ mask = self.mask.get()
+
+ # Move Gaussian cell dynamics one step forward. Specifically, this
+ # routine emulates the error unit
+ '''
+ ## This commented-out block of code should be adapted to replace the
+ ## five lines below it in future iterations (more accurate/flexible)
+ L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu)
+ dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma
+ dtarget = -dmu # reverse of e ## -(target - mu)/Sigma
+ dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma
+ '''
+ _dmu = (target - mu) # e (error unit)
+ dmu = _dmu / Sigma
+ dtarget = -dmu # reverse of e
+ dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma
+ L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
+ #L = GaussianErrorCell.eval_log_density(target, mu, Sigma)
+
+ dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
+ dtarget = dtarget * modulator * mask
+ mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
+
+ # Update compartments
+ self.dmu.set(dmu)
+ self.dtarget.set(dtarget)
+ self.dSigma.set(dSigma)
+ self.L.set(jnp.squeeze(L))
+ self.mask.set(mask)
+
+ @compilable
+ def reset(self): ## reset core components/statistics
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
+ restVals = jnp.zeros(_shape)
+ dmu = restVals
+ dtarget = restVals
+ dSigma = jnp.zeros(self.sigma_shape)
+ target = restVals
+ mu = restVals
+ modulator = mu + 1.
+ L = 0. #jnp.zeros((1, 1))
+ mask = jnp.ones(_shape)
+
+ self.dmu.set(dmu)
+ self.dtarget.set(dtarget)
+ self.dSigma.set(dSigma)
+ self.target.set(target)
+ self.mu.set(mu)
+ self.modulator.set(modulator)
+ self.L.set(L)
+ self.mask.set(mask)
+
+ @classmethod
+ def help(cls): ## component help function
+ properties = {
+ "cell_type": "GaussianErrorCell - computes mismatch/error signals at "
+ "each time step t (between a `target` and a prediction `mu`)"
+ }
+ compartment_props = {
+ "inputs":
+ {"mu": "External input prediction value(s)",
+ "Sigma": "External variance/covariance prediction value(s)",
+ "target": "External input target signal value(s)",
+ "modulator": "External input modulatory/scaling signal(s)",
+ "mask": "External binary/gating mask to apply to signals"},
+ "outputs":
+ {"L": "Local loss / free-energy value embodied by this error-cell",
+ "dmu": "first derivative of loss w.r.t. prediction value(s)",
+ "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)",
+ "dtarget": "first derivative of loss w.r.t. target value(s)"},
+ }
+ hyperparams = {
+ "n_units": "Number of neuronal cells to model in this layer",
+ "batch_size": "Batch size dimension of this component",
+ "sigma": "External input variance value (currently fixed and not learnable)"
+ }
+ info = {cls.__name__: properties,
+ "compartments": compartment_props,
+ "dynamics": "Gaussian(x=target; mu, sigma)",
+ "hyperparameters": hyperparams}
+ return info
+
+if __name__ == '__main__':
+ from ngcsimlib.context import Context
+ with Context("Bar") as bar:
+ X = GaussianErrorCell("X", 9)
+ print(X)
diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py
index 98f60325..3dcbdcc8 100755
--- a/ngclearn/components/neurons/graded/rateCell.py
+++ b/ngclearn/components/neurons/graded/rateCell.py
@@ -1,328 +1,324 @@
-# %%
-
-from jax import numpy as jnp, random, jit
-
-from ngclearn import compilable #from ngcsimlib.parser import compilable
-from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
-from ngclearn.components.jaxComponent import JaxComponent
-from ngclearn.utils.model_utils import create_function, threshold_soft, \
- threshold_cauchy
-from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
- step_euler, step_rk2, step_rk4
-from ngcsimlib.logger import info
-
-
-def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
- z_leak = jnp.sign(z) ## d/dx of Laplace is signum
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
- z_leak = (z * 2)/(1. + jnp.square(z))
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
- z_leak = jnp.exp(-jnp.square(z)) * z * 2
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
- z_leak = z # * 2 ## Default: assume Gaussian
- dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
- return dz_dt
-
-# @jit
-def _modulate(j, dfx_val):
- """
- Apply a signal modulator to j (typically of the form of a derivative/dampening function)
-
- Args:
- j: current/stimulus value to modulate
-
- dfx_val: modulator signal
-
- Returns:
- modulated j value
- """
- return j * dfx_val
-
-# @partial(jit, static_argnames=["integType", "priorType"])
-def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0):
- """
- Runs leaky rate-coded state dynamics one step in time.
-
- Args:
- dt: integration time constant
-
- j: input (bottom-up) electrical/stimulus current
-
- j_td: modulatory (top-down) electrical/stimulus pressure
-
- z: current value of membrane/state
-
- tau_m: membrane/state time constant
-
- leak_gamma: strength of leak to apply to membrane/state
-
- integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4)
-
- priorType: scale-shift prior distribution to impose over neural dynamics
-
- Returns:
- New value of membrane/state for next time step
- """
- _dfz_fns = {
- 0: lambda t, z, params: _dfz_internal_gaussian(z, *params),
- 1: lambda t, z, params: _dfz_internal_laplace(z, *params),
- 2: lambda t, z, params: _dfz_internal_cauchy(z, *params),
- 3: lambda t, z, params: _dfz_internal_exp(z, *params),
- }
- _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian)
- _step_fns = {
- 0: step_euler,
- 1: step_rk2,
- 2: step_rk4,
- }
- _step_fn = _step_fns.get(integType, step_euler)
- params = (j, j_td, tau_m, leak_gamma)
- _, _z = _step_fn(0., z, _dfz_fn, dt, params)
- return _z
-
-# @jit
-def _run_cell_stateless(j):
- """
- A simplification of running a stateless set of dynamics over j (an identity
- functional form of dynamics).
-
- Args:
- j: stimulus to do nothing to
-
- Returns:
- the stimulus
- """
- return j + 0
-
-class RateCell(JaxComponent): ## Rate-coded/real-valued cell
- """
- A non-spiking cell driven by the gradient dynamics of neural generative
- coding-driven predictive processing.
-
- The specific differential equation that characterizes this cell
- is (for adjusting v, given current j, over time) is:
-
- | tau_m * dz/dt = lambda * prior(z) + (j + j_td)
- | where j is the set of general incoming input signals (e.g., message-passed signals)
- | and j_td is taken to be the set of top-down pressure signals
-
- | --- Cell Input Compartments: ---
- | j - input pressure (takes in external signals)
- | j_td - input/top-down pressure input (takes in external signals)
- | --- Cell State Compartments ---
- | z - rate activity
- | --- Cell Output Compartments: ---
- | zF - post-activation function activity, i.e., fx(z)
-
- Args:
- name: the string name of this cell
-
- n_units: number of cellular entities (neural population size)
-
- tau_m: membrane/state time constant (milliseconds)
-
- prior: a kernel for specifying the type of centered scale-shift distribution
- to impose over neuronal dynamics, applied to each neuron or
- dimension within this component (Default: ("gaussian", 0)); this is
- a tuple with 1st element containing a string name of the distribution
- one wants to use while the second value is a `leak rate` scalar
- that controls the influence/weighting that this distribution
- has on the dynamics; for example, ("laplacian, 0.001") means that a
- centered laplacian distribution scaled by `0.001` will be injected
- into this cell's dynamics ODE each step of simulated time
-
- :Note: supported scale-shift distributions include "laplacian",
- "cauchy", "exp", and "gaussian"
-
- act_fx: string name of activation function/nonlinearity to use
-
- output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.)
-
- integration_type: type of integration to use for this cell's dynamics;
- current supported forms include "euler" (Euler/RK-1 integration)
- and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
-
- :Note: setting the integration type to the midpoint method will
- increase the accuray of the estimate of the cell's evolution
- at an increase in computational cost (and simulation time)
-
- resist_scale: a scaling factor applied to incoming pressure `j` (default: 1)
- """
-
- def __init__(
- self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.),
- integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
- jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)}
- this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)}
- super().__init__(name, **jax_comp_kwargs)
-
- ## membrane parameter setup (affects ODE integration)
- self.output_scale = output_scale
- self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
- self.is_stateful = is_stateful
- if isinstance(tau_m, float):
- if tau_m <= 0: ## trigger stateless mode
- self.is_stateful = False
- priorType, leakRate = prior
- priorTypeDict = {
- "gaussian": 0,
- "laplacian": 1,
- "cauchy": 2,
- "exp": 3
- }
- self.priorType = priorTypeDict.get(priorType, 0)
- self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
- thresholdType, thr_lmbda = threshold
- self.thresholdType = thresholdType ## type of thresholding function to use
- self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
- self.resist_scale = resist_scale ## a "resistance" scaling factor
-
- ## integration properties
- self.integrationType = integration_type
- self.intgFlag = get_integrator_code(self.integrationType)
-
- ## Layer size setup
- _shape = (batch_size, n_units) ## default shape is 2D/matrix
- if shape is None:
- shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
- else:
- _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
- self.shape = shape
- self.n_units = n_units
- self.batch_size = batch_size
-
- omega_0 = None
- if act_fx == "sine":
- omega_0 = this_class_kwargs["omega_0"]
- self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0)
-
- # compartments (state of the cell & parameters will be updated through stateless calls)
- restVals = jnp.zeros(_shape)
- self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
- self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity
- self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
- self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
-
- @compilable
- def advance_state(self, dt):
- # Get the compartment values
- j = self.j.get()
- j_td = self.j_td.get()
- z = self.z.get()
-
- #if tau_m > 0.:
- if self.is_stateful:
- ### run a step of integration over neuronal dynamics
- ## Notes:
- ## self.pressure <-- "top-down" expectation / contextual pressure
- ## self.current <-- "bottom-up" data-dependent signal
- dfx_val = self.dfx(z)
- j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics)
- j = j * self.resist_scale
- tmp_z = _run_cell(
- dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,
- priorType=self.priorType
- )
- ## apply optional thresholding sub-dynamics
- if self.thresholdType == "soft_threshold":
- tmp_z = threshold_soft(tmp_z, self.thr_lmbda)
- elif self.thresholdType == "cauchy_threshold":
- tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda)
- z = tmp_z ## pre-activation function value(s)
- zF = self.fx(z) * self.output_scale ## post-activation function value(s)
- else:
- ## run in "stateless" mode (when no membrane time constant provided)
- j_total = j + j_td
- z = _run_cell_stateless(j_total)
- zF = self.fx(z) * self.output_scale
-
- # Update compartments
- self.j.set(j)
- self.j_td.set(j_td)
- self.z.set(z)
- self.zF.set(zF)
-
- @compilable
- def reset(self): ## reset core components/statistics
- self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
-
- @compilable
- def batched_reset(self, batch_size):
- _shape = (batch_size, self.shape[0])
- if len(self.shape) > 1:
- _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
- restVals = jnp.zeros(_shape)
- self.j.set(restVals)
- self.j_td.set(restVals)
- self.z.set(restVals)
- self.zF.set(restVals)
-
- # def save(self, directory, **kwargs):
- # ## do a protected save of constants, depending on whether they are floats or arrays
- # tau_m = (self.tau_m if isinstance(self.tau_m, float)
- # else jnp.ones([[self.tau_m]]))
- # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
- # else jnp.ones([[self.priorLeakRate]]))
- # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
- # else jnp.ones([[self.resist_scale]]))
- #
- # file_name = directory + "/" + self.name + ".npz"
- # jnp.savez(file_name,
- # tau_m=tau_m, priorLeakRate=priorLeakRate,
- # resist_scale=resist_scale) #, key=self.key.value)
- #
- # def load(self, directory, seeded=False, **kwargs):
- # file_name = directory + "/" + self.name + ".npz"
- # data = jnp.load(file_name)
- # ## constants loaded in
- # self.tau_m = data['tau_m']
- # self.priorLeakRate = data['priorLeakRate']
- # self.resist_scale = data['resist_scale']
- # #if seeded:
- # # self.key.set(data['key'])
-
- @classmethod
- def help(cls): ## component help function
- properties = {
- "cell_type": "RateCell - evolves neurons according to rate-coded/"
- "continuous dynamics "
- }
- compartment_props = {
- "inputs":
- {"j": "External input stimulus value(s)",
- "j_td": "External top-down input stimulus value(s); these get "
- "multiplied by the derivative of f(x), i.e., df(x)"},
- "states":
- {"z": "Update to rate-coded continuous dynamics; value at time t"},
- "outputs":
- {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"},
- }
- hyperparams = {
- "n_units": "Number of neuronal cells to model in this layer",
- "batch_size": "Batch size dimension of this component",
- "tau_m": "Cell state/membrane time constant",
- "prior": "What kind of kurtotic prior to place over neuronal dynamics?",
- "act_fx": "Elementwise activation function to apply over cell state `z`",
- "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?",
- "integration_type": "Type of numerical integration to use for the cell dynamics",
- }
- info = {cls.__name__: properties,
- "compartments": compartment_props,
- "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)",
- "hyperparameters": hyperparams}
- return info
-
-if __name__ == '__main__':
- from ngcsimlib.context import Context
- with Context("Bar") as bar:
- X = RateCell("X", 9, 0.03)
- print(X)
+# %%
+
+from jax import numpy as jnp, random, jit
+
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.jaxComponent import JaxComponent
+from ngclearn.utils.model_utils import create_function, threshold_soft, \
+ threshold_cauchy
+from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
+ step_euler, step_rk2, step_rk4
+from ngcsimlib.logger import info
+
+
+def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
+ z_leak = jnp.sign(z) ## d/dx of Laplace is signum
+ dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
+ return dz_dt
+
+def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
+ z_leak = (z * 2)/(1. + jnp.square(z))
+ dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
+ return dz_dt
+
+def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
+ z_leak = jnp.exp(-jnp.square(z)) * z * 2
+ dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
+ return dz_dt
+
+def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
+ z_leak = z # * 2 ## Default: assume Gaussian
+ dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
+ return dz_dt
+
+# @jit
+def _modulate(j, dfx_val):
+ """
+ Apply a signal modulator to j (typically of the form of a derivative/dampening function)
+
+ Args:
+ j: current/stimulus value to modulate
+
+ dfx_val: modulator signal
+
+ Returns:
+ modulated j value
+ """
+ return j * dfx_val
+
+# @partial(jit, static_argnames=["integType", "priorType"])
+def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0):
+ """
+ Runs leaky rate-coded state dynamics one step in time.
+
+ Args:
+ dt: integration time constant
+
+ j: input (bottom-up) electrical/stimulus current
+
+ j_td: modulatory (top-down) electrical/stimulus pressure
+
+ z: current value of membrane/state
+
+ tau_m: membrane/state time constant
+
+ leak_gamma: strength of leak to apply to membrane/state
+
+ integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4)
+
+ priorType: scale-shift prior distribution to impose over neural dynamics
+
+ Returns:
+ New value of membrane/state for next time step
+ """
+ _dfz_fns = {
+ 0: lambda t, z, params: _dfz_internal_gaussian(z, *params),
+ 1: lambda t, z, params: _dfz_internal_laplace(z, *params),
+ 2: lambda t, z, params: _dfz_internal_cauchy(z, *params),
+ 3: lambda t, z, params: _dfz_internal_exp(z, *params),
+ }
+ _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian)
+ _step_fns = {
+ 0: step_euler,
+ 1: step_rk2,
+ 2: step_rk4,
+ }
+ _step_fn = _step_fns.get(integType, step_euler)
+ params = (j, j_td, tau_m, leak_gamma)
+ _, _z = _step_fn(0., z, _dfz_fn, dt, params)
+ return _z
+
+# @jit
+def _run_cell_stateless(j):
+ """
+ A simplification of running a stateless set of dynamics over j (an identity
+ functional form of dynamics).
+
+ Args:
+ j: stimulus to do nothing to
+
+ Returns:
+ the stimulus
+ """
+ return j + 0
+
+class RateCell(JaxComponent): ## Rate-coded/real-valued cell
+ """
+ A non-spiking cell driven by the gradient dynamics of neural generative
+ coding-driven predictive processing.
+
+ The specific differential equation that characterizes this cell
+ is (for adjusting z, given current j, over time):
+
+ | tau_m * dz/dt = lambda * prior(z) + (j + j_td)
+ | where j is the set of general incoming input signals (e.g., message-passed signals)
+ | and j_td is taken to be the set of top-down pressure signals
+
+ | --- Cell Input Compartments: ---
+ | j - input pressure (takes in external signals)
+ | j_td - input/top-down pressure input (takes in external signals)
+ | --- Cell State Compartments ---
+ | z - rate activity
+ | --- Cell Output Compartments: ---
+ | zF - post-activation function activity, i.e., fx(z)
+
+ Args:
+ name: the string name of this cell
+
+ n_units: number of cellular entities (neural population size)
+
+ tau_m: membrane/state time constant (milliseconds)
+
+ prior: a kernel for specifying the type of centered scale-shift distribution
+ to impose over neuronal dynamics, applied to each neuron or
+ dimension within this component (Default: ("gaussian", 0)); this is
+ a tuple with 1st element containing a string name of the distribution
+ one wants to use while the second value is a `leak rate` scalar
+ that controls the influence/weighting that this distribution
+ has on the dynamics; for example, ("laplacian", 0.001) means that a
+ centered laplacian distribution scaled by `0.001` will be injected
+ into this cell's dynamics ODE each step of simulated time
+
+ :Note: supported scale-shift distributions include "laplacian",
+ "cauchy", "exp", and "gaussian"
+
+ act_fx: string name of activation function/nonlinearity to use
+
+ output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.)
+
+ integration_type: type of integration to use for this cell's dynamics;
+ current supported forms include "euler" (Euler/RK-1 integration)
+ and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
+
+ :Note: setting the integration type to the midpoint method will
+ increase the accuracy of the estimate of the cell's evolution
+ at an increase in computational cost (and simulation time)
+
+ resist_scale: a scaling factor applied to incoming pressure `j` (default: 1)
+ """
+
+ def __init__(
+ self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.),
+ integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
+ jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)}
+ this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)}
+ super().__init__(name, **jax_comp_kwargs)
+
+ ## membrane parameter setup (affects ODE integration)
+ self.output_scale = output_scale
+ self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
+ self.is_stateful = is_stateful
+ if isinstance(tau_m, float):
+ if tau_m <= 0: ## trigger stateless mode
+ self.is_stateful = False
+ priorType, leakRate = prior
+ priorTypeDict = {
+ "gaussian": 0,
+ "laplacian": 1,
+ "cauchy": 2,
+ "exp": 3
+ }
+ self.priorType = priorTypeDict.get(priorType, 0)
+ self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
+ thresholdType, thr_lmbda = threshold
+ self.thresholdType = thresholdType ## type of thresholding function to use
+ self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
+ self.resist_scale = resist_scale ## a "resistance" scaling factor
+
+ ## integration properties
+ self.integrationType = integration_type
+ self.intgFlag = get_integrator_code(self.integrationType)
+
+ ## Layer size setup
+ _shape = (batch_size, n_units) ## default shape is 2D/matrix
+ if shape is None:
+ shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
+ else:
+ _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
+ self.shape = shape
+ self.n_units = n_units
+ self.batch_size = batch_size
+
+ omega_0 = None
+ if act_fx == "sine":
+ omega_0 = this_class_kwargs["omega_0"]
+ self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0)
+
+ # compartments (state of the cell & parameters will be updated through stateless calls)
+ restVals = jnp.zeros(_shape)
+ self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
+ self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity
+ self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
+ self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
+
+ @compilable
+ def advance_state(self, dt):
+ # Get the compartment values
+ j = self.j.get()
+ j_td = self.j_td.get()
+ z = self.z.get()
+
+ #if tau_m > 0.:
+ if self.is_stateful:
+ ### run a step of integration over neuronal dynamics
+ ## Notes:
+ ## j_td <-- "top-down" expectation / contextual pressure
+ ## j <-- "bottom-up" data-dependent signal
+ dfx_val = self.dfx(z)
+ j = _modulate(j, dfx_val)
+ j = j * self.resist_scale
+ tmp_z = _run_cell(
+ dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,
+ priorType=self.priorType
+ )
+ ## apply optional thresholding sub-dynamics
+ if self.thresholdType == "soft_threshold":
+ tmp_z = threshold_soft(tmp_z, self.thr_lmbda)
+ elif self.thresholdType == "cauchy_threshold":
+ tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda)
+ z = tmp_z ## pre-activation function value(s)
+ zF = self.fx(z) * self.output_scale ## post-activation function value(s)
+ else:
+ ## run in "stateless" mode (when no membrane time constant provided)
+ j_total = j + j_td
+ z = _run_cell_stateless(j_total)
+ zF = self.fx(z) * self.output_scale
+
+ # Update compartments
+ self.j.set(j)
+ self.j_td.set(j_td)
+ self.z.set(z)
+ self.zF.set(zF)
+
+ @compilable
+ def reset(self): #, batch_size, shape): #n_units
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
+ restVals = jnp.zeros(_shape)
+ self.j.set(restVals)
+ self.j_td.set(restVals)
+ self.z.set(restVals)
+ self.zF.set(restVals)
+
+ # def save(self, directory, **kwargs):
+ # ## do a protected save of constants, depending on whether they are floats or arrays
+ # tau_m = (self.tau_m if isinstance(self.tau_m, float)
+ # else jnp.ones([[self.tau_m]]))
+ # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
+ # else jnp.ones([[self.priorLeakRate]]))
+ # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
+ # else jnp.ones([[self.resist_scale]]))
+ #
+ # file_name = directory + "/" + self.name + ".npz"
+ # jnp.savez(file_name,
+ # tau_m=tau_m, priorLeakRate=priorLeakRate,
+ # resist_scale=resist_scale) #, key=self.key.value)
+ #
+ # def load(self, directory, seeded=False, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # ## constants loaded in
+ # self.tau_m = data['tau_m']
+ # self.priorLeakRate = data['priorLeakRate']
+ # self.resist_scale = data['resist_scale']
+ # #if seeded:
+ # # self.key.set(data['key'])
+
+ @classmethod
+ def help(cls): ## component help function
+ properties = {
+ "cell_type": "RateCell - evolves neurons according to rate-coded/"
+ "continuous dynamics "
+ }
+ compartment_props = {
+ "inputs":
+ {"j": "External input stimulus value(s)",
+ "j_td": "External top-down input stimulus value(s); these get "
+ "multiplied by the derivative of f(x), i.e., df(x)"},
+ "states":
+ {"z": "Update to rate-coded continuous dynamics; value at time t"},
+ "outputs":
+ {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"},
+ }
+ hyperparams = {
+ "n_units": "Number of neuronal cells to model in this layer",
+ "batch_size": "Batch size dimension of this component",
+ "tau_m": "Cell state/membrane time constant",
+ "prior": "What kind of kurtotic prior to place over neuronal dynamics?",
+ "act_fx": "Elementwise activation function to apply over cell state `z`",
+ "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?",
+ "integration_type": "Type of numerical integration to use for the cell dynamics",
+ }
+ info = {cls.__name__: properties,
+ "compartments": compartment_props,
+ "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)",
+ "hyperparameters": hyperparams}
+ return info
+
+if __name__ == '__main__':
+ from ngcsimlib.context import Context
+ with Context("Bar") as bar:
+ X = RateCell("X", 9, 0.03)
+ print(X)
diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py
index ee1ecb02..92c9c5e9 100755
--- a/ngclearn/components/synapses/denseSynapse.py
+++ b/ngclearn/components/synapses/denseSynapse.py
@@ -83,7 +83,7 @@ def __init__(
self.inputs = Compartment(preVals)
self.outputs = Compartment(postVals)
self.weights = Compartment(weights)
- _mask = 1.
+ _mask = jnp.ones((1, 1))
if mask is not None:
_mask = mask
self.mask = Compartment(_mask)
diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py
index 05bfd207..0a1630c3 100644
--- a/ngclearn/components/synapses/hebbian/__init__.py
+++ b/ngclearn/components/synapses/hebbian/__init__.py
@@ -4,4 +4,5 @@
from .expSTDPSynapse import ExpSTDPSynapse
from .eventSTDPSynapse import EventSTDPSynapse
from .BCMSynapse import BCMSynapse
+from .gerstnerHebbianSynapse import GerstnerHebbianSynapse ## Taylor-expansion Hebbian model
diff --git a/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py
new file mode 100644
index 00000000..d1efca5a
--- /dev/null
+++ b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py
@@ -0,0 +1,113 @@
+import jax.numpy as jnp
+from jax import random, jit
+
+from ngclearn import compilable
+from ngclearn import Compartment
+from ngclearn.components.synapses import DenseSynapse
+from ngclearn.utils import tensorstats
+from ngcsimlib import deprecate_args
+#from ngclearn.utils.io_utils import save_pkl, load_pkl
+
+class GerstnerHebbianSynapse(DenseSynapse):
+ """
+ A synapse component that implements Gerstner's general Hebbian
+ learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002).
+
+ Note that this synaptic update model can recover several classical forms
+ of Hebbian-like update rules, including the covariance rule.
+
+ There are other higher-order terms possible, i.e., \\Theta(xy), such as
+ x * y^2 and y * x^2, etc.
+
+ | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update
+ | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update
+ | c2_corr = 1 and c1_pre = -x_theta < 0
+
+ """
+ def __init__(
+ self,
+ name,
+ shape, ## (post_dim, pre_dim)
+ eta=0.01, ## global step-size
+ coeffs=None, ## these configure which kind of Hebb learning is done
+ weight_init=None,
+ p_conn=1.,
+ resist_scale=1.,
+ sign_value=1.,
+ batch_size=1,
+ **kwargs
+ ):
+ bias_init = None ## no biases are included in Gerstner's formulation
+ super().__init__(
+ name,
+ shape=shape,
+ weight_init=weight_init,
+ bias_init=bias_init,
+ resist_scale=resist_scale,
+ p_conn=p_conn,
+ batch_size=batch_size,
+ **kwargs
+ )
+ ## General Hebbian meta-parameters
+ self.eta = eta
+ self.sign_value = sign_value
+
+ ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr)
+ if coeffs is None: ## Default to standard bilinear Hebb
+ self.coeffs = {
+ 'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0
+ }
+ else:
+ self.coeffs = coeffs
+ self.c0 = self.coeffs['c0']
+ self.c1_pre = self.coeffs['c1_pre']
+ self.c1_post = self.coeffs['c1_post']
+ self.c2_corr = self.coeffs['c2_corr']
+
+ # Initialize Weights (using JAX PRNG)
+ #init_key, _ = random.split(self.key)
+ #w_init = random.normal(init_key, shape) * 0.05
+
+ # Compartments (ngc-learn state management)
+ #self.weights = Compartment(w_init)
+ self.pre = Compartment(jnp.zeros((1, shape[1])))
+ self.post = Compartment(jnp.zeros((1, shape[0])))
+
+ @compilable
+ def evolve(self, **kwargs):
+ """
+ Updates weights using the Gerstner general expansion.
+ Assumes the `pre` and `post` compartments have been populated.
+ """
+ # Retrieve current states
+ W = self.weights.get()
+ x = self.pre.get() # pre-synaptic activity (batch, pre_dim)
+ y = self.post.get() # post-synaptic activity (batch, post_dim)
+ batch_size = self.batch_size
+
+ ## Bilinear Term (c2): correlation matrix
+ ### (pre_dim, batch) @ (batch, post_dim) -> (pre_dim, post_dim)
+ dW_corr = jnp.matmul(x.T, y) * (1./batch_size)
+ ## Linear pre-synaptic term (c1_pre)
+ ### Average over batch then broadcast to match weight matrix
+ dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size)
+ ## Linear post-synaptic term (c1_post)
+ dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size)
+
+ ## Apply Equation 3 Taylor expansion
+ dW = (self.c0 * W + ## synaptic decay
+ self.c1_pre * dW_pre + ## linear pre-synaptic term
+ self.c1_post * dW_post + ## linear post-synaptic term
+ self.c2_corr * dW_corr ## bilinear (correlation) term
+ )
+ ## perform a step of Hebbian ascent
+ W = W + self.eta * dW
+ ## Update weights
+ self.weights.set(W)
+
+ @compilable
+ def reset(self, **kwargs):
+ """Clears activity compartments"""
+ self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) )
+ self.post.set( jnp.zeros((self.batch_size, self.shape[0])) )
+
diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py
index 5860d2bc..06cd09bd 100644
--- a/ngclearn/modules/regression/elastic_net.py
+++ b/ngclearn/modules/regression/elastic_net.py
@@ -77,13 +77,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
self.W = HebbianSynapse(
"W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
weight_init=dist.constant(value=weight_fill), prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0.,
- optim_type=optim_type, key=subkeys[0]
+ optim_type=optim_type, key=subkeys[0], batch_size=batch_size
)
- self.err = GaussianErrorCell("err", n_units=sys_dim)
+ self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
+ self.W.reset()
+ self.err.reset()
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.outputs >> self.err.mu
self.err.dmu >> self.W.post
diff --git a/ngclearn/modules/regression/lasso.py b/ngclearn/modules/regression/lasso.py
index 15a014bb..8d5f91d5 100644
--- a/ngclearn/modules/regression/lasso.py
+++ b/ngclearn/modules/regression/lasso.py
@@ -76,12 +76,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
self.W = HebbianSynapse(
"W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
weight_init=dist.constant(value=weight_fill), prior=('lasso', lasso_lmbda), w_bound=0.,
- optim_type=optim_type, key=subkeys[0]
+ optim_type=optim_type, key=subkeys[0], batch_size=batch_size
)
- self.err = GaussianErrorCell("err", n_units=sys_dim)
+ self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
+ self.W.reset()
+ self.err.reset()
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.outputs >> self.err.mu
self.err.dmu >> self.W.post
diff --git a/ngclearn/modules/regression/ridge.py b/ngclearn/modules/regression/ridge.py
index dfbacb03..84591670 100644
--- a/ngclearn/modules/regression/ridge.py
+++ b/ngclearn/modules/regression/ridge.py
@@ -76,13 +76,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
self.W = HebbianSynapse(
"W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
weight_init=dist.constant(value=weight_fill), prior=('ridge', ridge_lmbda), w_bound=0.,
- optim_type=optim_type, key=subkeys[0]
+ optim_type=optim_type, key=subkeys[0], batch_size=batch_size
)
- self.err = GaussianErrorCell("err", n_units=sys_dim)
+ self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
+ self.W.reset()
+ self.err.reset()
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.outputs >> self.err.mu
self.err.dmu >> self.W.post
diff --git a/ngclearn/utils/JaxProcessesMixin.py b/ngclearn/utils/JaxProcessesMixin.py
index 3b849f09..3958dfb3 100644
--- a/ngclearn/utils/JaxProcessesMixin.py
+++ b/ngclearn/utils/JaxProcessesMixin.py
@@ -9,6 +9,10 @@
from ngcsimlib._src.process.baseProcess import BaseProcess
class JaxCompiledMethod(CompiledMethod):
+ """
+ A wrapper for a compiled method whose function is additionally wrapped in jax's jit. Used
+ exclusively by the mixin and shouldn't be used elsewhere.
+ """
def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals)
self._fn = jax.jit(fn)
@@ -16,10 +20,19 @@ def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals):
@property
def source_fn(self):
+ """
+ The source method not wrapped in jit
+ """
return self._fn_source
@classmethod
def wrap(cls, compiledMethod: CompiledMethod):
+ """
+ Helper method to expand on a base compiled method
+ Args:
+ compiledMethod: The method to be expanded upon
+ Returns: the JaxCompiledMethod based on the input
+ """
return cls(compiledMethod._fn,
compiledMethod.ast,
compiledMethod.auxiliary_ast,
@@ -28,7 +41,16 @@ def wrap(cls, compiledMethod: CompiledMethod):
class JaxProcessesMixin:
+ """
+ A mixin for the base Process that adds JAX functionality such as scan and
+ implicit jit wrapping
+ """
def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
+ """
+ Look at the BaseProcess class for information about other arguments
+ Args:
+ use_jit: a flag for if the process should implicitly jit wrap
+ """
super().__init__(name, *args, **kwargs)
self._previous_result = None
self._previous_state = None
@@ -36,27 +58,51 @@ def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs):
@property
def previous_result(self):
+ """
+ Returns the last stored result of scan (the second returned value)
+ """
return self._previous_result
@property
def previous_state(self):
+ """
+ Returns the last stored final state of scan (the first returned
+ value)
+ """
return self._previous_state
def clear(self):
+ """
+ Clears out the previous result and state from scan
+ """
self._previous_result = None
self._previous_state = None
- def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True):
+ def scan(self: "BaseProcess", inputs, current_state=None, store_state: bool = True, store_results: bool = True):
+ """
+ Runs the process through jax's scan method
+ Args:
+ inputs: The inputs for scan (use pack rows to generate), must be a jax array
+ current_state: Optional, the current state of the model, if none uses current global state
+ store_state: Optional flag, should the final state be stored in the process
+ store_results: Optional flag, should the final result be stored in the process
+
+ Returns: the final state, the final result
+
+ """
state = current_state or stateManager.state
final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
- if save_state:
+ if store_state:
self._previous_state = final_state
if store_results:
self._previous_result = result
return final_state, result
def compile(self: "baseProcess"):
+ """
+ For use by the compiler
+ """
super().compile()
if self._use_jit:
self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled)
diff --git a/ngclearn/utils/analysis/__init__.py b/ngclearn/utils/analysis/__init__.py
index 082c68cd..31506c3e 100644
--- a/ngclearn/utils/analysis/__init__.py
+++ b/ngclearn/utils/analysis/__init__.py
@@ -2,4 +2,4 @@
from .linear_probe import LinearProbe
from .attentive_probe import AttentiveProbe
from .knn_probe import KNNProbe
-
+from .kmeans_probe import KMeansProbe
diff --git a/ngclearn/utils/analysis/effective_dim.py b/ngclearn/utils/analysis/effective_dim.py
new file mode 100644
index 00000000..7b46f318
--- /dev/null
+++ b/ngclearn/utils/analysis/effective_dim.py
@@ -0,0 +1,46 @@
+from jax import numpy as jnp
+
+def participation_ratio(latent_codes):
+ """
+ Calculates the participation ratio coefficient for a set of latent codes
+
+ Args:
+ latent_codes: a set of (N x D) latent code vectors (one row per vector code)
+
+ Returns:
+ scalar measurement of the effective dimension
+ """
+ Z = latent_codes
+ Zc = Z - Z.mean(axis=0, keepdims=True)
+ cov = (Zc.T @ Zc) / (Zc.shape[0] - 1)
+
+ tr = jnp.trace(cov)
+ tr2_cov = tr * tr
+ cov2_tr = jnp.trace(cov @ cov)
+
+ return tr2_cov / cov2_tr if cov2_tr > 0 else float("nan")
+
+
+
+
+def rankme(Z, eps=1e-7):
+ """
+ Calculates the effective rank of a code matrix Z
+ effective rank = exp(Shannon entropy), from Garrido, Balestriero,
+ Najman & LeCun, "RankMe: Assessing the Downstream Performance of Pretrained
+ Self-Supervised Representations by Their Rank" (ICML 2023, arXiv:2210.02885).
+
+ Args:
+ Z: a set of (N x D) latent code vectors (one row per vector code)
+
+ Returns:
+ scalar measurement of the effective dimension
+ """
+
+ singular_values = jnp.linalg.svd(Z, compute_uv=False) ## singular values of Z
+ sum_singular_vals = jnp.sum(singular_values) ## L1
+ if sum_singular_vals <= 0:
+ return float("nan")
+ p = singular_values / sum_singular_vals + eps ## L1-normalized singular value
+ shannon_entropy = -jnp.sum(p * jnp.log(p)) ## Shannon entropy
+ return jnp.exp(shannon_entropy) ## exp(Shannon entropy) = effective rank
diff --git a/ngclearn/utils/analysis/kmeans_probe.py b/ngclearn/utils/analysis/kmeans_probe.py
new file mode 100644
index 00000000..676a00cd
--- /dev/null
+++ b/ngclearn/utils/analysis/kmeans_probe.py
@@ -0,0 +1,109 @@
+import jax
+from ngcsimlib import deprecate_args
+from ngclearn.utils.analysis.probe import Probe
+from jax import jit, random, numpy as jnp, lax, nn
+from functools import partial as bind
+from ngclearn.utils.metric_utils import measure_ARI
+
+@bind(jax.jit, static_argnums=[2])
+def _run_kmeans_probe(_embeddings, centroids, n_clusters):
+ ## Broadcast distances: (n_samples, 1, n_features) - (1, n_clusters, n_features)
+ distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
+ labels_pred = jnp.argmin(distances, axis=1)
+ ## Re-estimate centroids/means
+ one_hot_preds = labels_pred[:, None] == jnp.arange(n_clusters)
+ counts = jnp.maximum(one_hot_preds.sum(axis=0, keepdims=True).T, 1.0)
+ centroids = jnp.dot(one_hot_preds.T.astype(jnp.float32), _embeddings) / counts
+ return centroids
+
+@bind(jax.jit, static_argnums=[2])
+def _predict_with_probe(_embeddings, centroids, n_clusters):
+ ## Final pass to compute stable predictions
+ distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
+ labels_pred = jnp.argmin(distances, axis=1)
+ Y_pred = nn.one_hot(labels_pred, n_clusters)
+ return labels_pred, Y_pred
+
+class KMeansProbe(Probe):
+ """
+ This implements a K-means clustering probe, which is useful for evaluating the quality of
+ encodings/embeddings in light of the ability to cluster downstream data. Currently, this
+ probe only supports L2/Euclidean distance-based clustering.
+
+ Args:
+ dkey: init seed key
+
+ source_seq_length: length of input sequence (e.g., height x width of the image feature)
+
+ input_dim: input dimensionality of probe
+
+ out_dim: output dimensionality of probe - number of clusters for this probe to create
+
+ batch_size: batch size passed through to the parent Probe constructor
+
+ """
+
+ def __init__(
+ self,
+ dkey,
+ source_seq_length,
+ input_dim,
+ out_dim=2, ## number of clusters/centroids to uncover
+ batch_size=1,
+ **kwargs
+ ):
+ super().__init__(dkey, batch_size, **kwargs)
+ self.dkey, *subkeys = random.split(self.dkey, 3)
+ self.source_seq_length = source_seq_length
+ self.input_dim = input_dim
+ self.n_clusters = self.out_dim = out_dim
+ ## centroids that will be uncovered by this probe
+ self.centroids : jax.Array = None
+
+ def _init(self, embeddings):
+ _embeddings = embeddings
+ if len(_embeddings.shape) > 2:
+ flat_dim = embeddings.shape[1] * embeddings.shape[2]
+ _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
+ ## choose random data-points to serve as centroids at iteration 0
+ self.dkey, *subkeys = random.split(self.dkey, 15)
+ n_samples, n_features = _embeddings.shape
+ random_indices = random.choice(
+ subkeys[0], n_samples, shape=(self.n_clusters,), replace=False
+ )
+ self.centroids = _embeddings[random_indices]
+
+ def process(self, embeddings, dkey=None):
+ _embeddings = embeddings
+ if len(_embeddings.shape) > 2:
+ flat_dim = embeddings.shape[1] * embeddings.shape[2]
+ _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
+ ## Assign each embedding to its nearest centroid (yields one-hot cluster predictions)
+ _, Y_pred = _predict_with_probe(_embeddings, self.centroids, self.n_clusters)
+ return Y_pred ## (B, C)
+
+ def update(self, embeddings, labels, dkey=None):
+ _embeddings = embeddings
+ if len(_embeddings.shape) > 2:
+ flat_dim = embeddings.shape[1] * embeddings.shape[2]
+ _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
+ self.centroids = _run_kmeans_probe(_embeddings, self.centroids, self.n_clusters)
+ L = 0. ## FIXME: should be clustering loss
+ predictions = self.process(_embeddings)
+ return L, predictions
+
+ def fit(self, dataset, dev_dataset=None, n_iter=20, patience=20):
+ data, labels = dataset
+ _labels = jnp.argmax(labels, axis=-1)
+
+ self._init(data) ## init K-means centroids
+ ari = 0.
+ for i in range(n_iter): ## Run vectorized K-Means optimization loop
+ _L, py = self.update(data, labels)
+ labels_pred = jnp.argmax(py, axis=1)
+ ari_i = measure_ARI(_labels, labels_pred)
+ print(f"\r{i}: ARI = {ari_i}", end="")
+ if ari_i > ari:
+ ari = ari_i
+ print()
+ return ari
diff --git a/ngclearn/utils/analysis/knn_probe.py b/ngclearn/utils/analysis/knn_probe.py
index 6935a3a0..b9656371 100644
--- a/ngclearn/utils/analysis/knn_probe.py
+++ b/ngclearn/utils/analysis/knn_probe.py
@@ -2,28 +2,34 @@
import numpy as np
from ngcsimlib import deprecate_args
from ngclearn.utils.analysis.probe import Probe
-from ngclearn.utils.model_utils import kwta
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
-from ngclearn.utils.distribution_generator import DistributionGenerator
-
-@bind(jax.jit, static_argnums=[2, 3])
-def _run_knn_probe(_embeddings, Wx, K, dist_order=2):
- ## Notes:
- ### We do some 3D tensor math to handle a batch of predictions that need to be made
- ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories
- _Wx = jnp.expand_dims(Wx, axis=0) ## 3D tensor format of KNN params (1 x N x D)
- embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## 3D projection of input signals (B x 1 x D)
- D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D)
- ## get batched (negative) distance measurements
- dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1)
- ## else, default -> euclidean
- ### Note: negative distance allows us to find minimal points w/ maximal functions
- dist = -jnp.squeeze(dist, axis=2) ## (B x N)
- ## now get K winners per sample in batch
+
+
+@bind(jax.jit, static_argnums=[2, 3, 4])
+def _run_knn_probe(_embeddings, Wx, K, dist_order=2, dist_metric="minkowski"):
+ if dist_metric == "cosine":
+ ## normalize the incoming batch embeddings along the feature axis (axis 1)
+ ### add a tiny epsilon to prevent division-by-zero errors
+ eps = 1e-12
+ embed_norm = _embeddings / (jnp.linalg.norm(_embeddings, axis=1, keepdims=True) + eps)
+ ## normalize the internal memory database array (axis 1)
+ Wx_norm = Wx / (jnp.linalg.norm(Wx, axis=1, keepdims=True) + eps)
+ ## compute batched cosine similarity using standard 2D matrix multiplication
+ ### (B x D) @ (D x N) -> yields a (B x N) similarity matrix directly!
+ dist = jnp.matmul(embed_norm, Wx_norm.T)
+ else: # Default back to your Minkowski setup
+ _Wx = jnp.expand_dims(Wx, axis=0) ## (1 x N x D)
+ embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## (B x 1 x D)
+ D = embed_tensor - _Wx ## (B x N x D)
+ dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1)
+ dist = -jnp.squeeze(dist, axis=2) ## (B x N)
+
+ # lax.top_k naturally grabs the maximums (highest similarity or smallest negative distance)
values, indices = lax.top_k(dist, K)
return values, indices
+
class KNNProbe(Probe):
"""
This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of
@@ -82,16 +88,21 @@ def __init__(
self.vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction
if vote_style == "mean":
self.vote_fx = 1
+
self.distance_function = distance_function
- dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean
- if "euclidean" in dist_fun.lower():
+ dist_fun, dist_order = distance_function
+ self.dist_metric = "minkowski" # default tracker
+ if "cosine" in dist_fun.lower():
+ self.dist_metric = "cosine"
+ dist_order = 2 ## fallback assignment
+ elif "euclidean" in dist_fun.lower():
dist_order = 2
elif "manhattan" in dist_fun.lower():
dist_order = 1
elif "chebyshev" in dist_fun.lower():
dist_order = jnp.inf
- ## TODO: add in cosine-distance (and maybe Mahalanobis distance)
self.dist_order = dist_order ## set distance order p
+
self.predictor_type = predictor_type
self.pred_fx = 0
if "regressor" == predictor_type:
@@ -102,14 +113,18 @@ def __init__(
Wx = Wy = jnp.ones((1, 1)) ## Wy will be assumed to be one-hot encoded
self.probe_params = (Wx, Wy)
- def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this
+ def process(self, embeddings, dkey=None):
_embeddings = embeddings
if len(_embeddings.shape) > 2:
flat_dim = embeddings.shape[1] * embeddings.shape[2]
_embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim))
- Wx, Wy = self.probe_params ## pull out KNN parameters
- values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_order)
+ Wx, Wy = self.probe_params
+
+ # Pass the explicit metric string directly to the JIT-compiled loop
+ values, indices = _run_knn_probe(
+ _embeddings, Wx, self.K, self.dist_order, self.dist_metric
+ )
## do K-neighbor voting scheme (find mode/frequency prediction)
Y_counts = jnp.zeros((_embeddings.shape[0], Wy.shape[1]))
@@ -139,25 +154,25 @@ def update(self, embeddings, labels, dkey=None):
Wy = labels
self.probe_params = (Wx, Wy)
-if __name__ == '__main__':
- seed = 42
- D = 7
- C = 5
- dkey = random.PRNGKey(seed)
- dkey, *subkeys = random.split(dkey, 3)
- knn = KNNProbe(
- subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
- )
- X = random.uniform(subkeys[1], shape=(10, D))
- Y = jnp.concat(
- [
- jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]),
- jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]),
- jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]),
- jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]),
- jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]])
- ],
- axis=0
- )
- knn.update(X, Y) ## fit KNN to data
- print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y
+# if __name__ == '__main__':
+# seed = 42
+# D = 7
+# C = 5
+# dkey = random.PRNGKey(seed)
+# dkey, *subkeys = random.split(dkey, 3)
+# knn = KNNProbe(
+# subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean"
+# )
+# X = random.uniform(subkeys[1], shape=(10, D))
+# Y = jnp.concat(
+# [
+# jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]),
+# jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]),
+# jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]),
+# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]),
+# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]])
+# ],
+# axis=0
+# )
+# knn.update(X, Y) ## fit KNN to data
+# print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y
diff --git a/ngclearn/utils/filters/__init__.py b/ngclearn/utils/filters/__init__.py
new file mode 100644
index 00000000..7f778d24
--- /dev/null
+++ b/ngclearn/utils/filters/__init__.py
@@ -0,0 +1,3 @@
+from .gauss_filter import gaussian_filter
+from .cortical_gauss_filter import cortical_gaussian_filter
+
diff --git a/ngclearn/utils/filters/cortical_gauss_filter.py b/ngclearn/utils/filters/cortical_gauss_filter.py
new file mode 100644
index 00000000..5a57eaa9
--- /dev/null
+++ b/ngclearn/utils/filters/cortical_gauss_filter.py
@@ -0,0 +1,94 @@
+"""
+Support for a functional/pure "cortical" Gaussian filter - this can be configured to facilitate a filter that
+engages in a difference-of-Gaussians-like or ratio-of-Gaussians-like process
+"""
+
+import jax.numpy as jnp
+from jax import lax, jit
+from functools import partial
+
+def _calc_gaussian_kernel_2D(
+ sigma: float, ## standard deviation of kernel
+ radius: int ## controls shape of kernel
+) -> jnp.ndarray: ## internal co-routine for Gaussian kernel
+ ## create a normalized 2D Gaussian kernel with shape (1, 1, 2*radius+1, 2*radius+1)
+ x = jnp.arange(-radius, radius + 1)
+ xx, yy = jnp.meshgrid(x, x)
+ kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2)
+ kernel = kernel / jnp.sum(kernel)
+ return kernel[jnp.newaxis, jnp.newaxis, :, :]
+
+@partial(jit, static_argnums=[3, 4, 5])
+def cortical_gaussian_filter(
+ images: jnp.ndarray, ## expected input shape : (B, C, H, W)
+ sigma_center: float, ## center excitation
+ sigma_surround: float, ## surround inhibition
+ kernel_size: int, ## kernel radius
+ use_ratio:bool=False, ## triggers either RoG vs DoG mode
+ semi_sat_constant:float=0.1, ## sigma_h (controls contrast-gain)
+ excitation_exp:float=2.0, ## p-exponent (typically in range of 1.5 - 2.0)
+ inhibition_exp:float=2.0, ## q-exponent
+ edge_pad_mode:str='edge'
+) -> jnp.ndarray:
+ """
+ Applies a configurable rectified Gaussian filter (either difference-of-Gaussians, i.e., DoG, or ratio-of-Gaussians,
+ i.e., RoG) to a tensor batch of 2D images (each of CxHxW tensor shape/format). Note that this variant filter
+ means that DoG mode acts more as a (half-wave) rectified subtraction of two Gaussian kernels and RoG mode acts more
+ as a simple form of divisive normalization.
+
+ Args:
+ images: input image tensor of shape (B, C, H, W)
+
+ sigma_center: standard deviation for narrow / center blur
+
+ sigma_surround: standard deviation for wide / surround blur
+
+ kernel_size: kernel radius (window size will be `2*radius + 1`)
+
+ use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False)
+
+ semi_sat_constant: suppresses amplified micro-variations in sensory/image space
+
+ excitation_exp: p-exponent (typically in range of 1.5 - 2.0) (Default: 2.0)
+
+ inhibition_exp: q-exponent (typically in range of 1.5 - 2.0) (Default: 2.0)
+
+ edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge")
+
+ Returns:
+ An output tensor of shape (B, C, H, W)
+ """
+ ## set up edge-artifact padding correction
+ padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size))
+ padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode)
+
+ ## set up Gaussian kernels
+ k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size)
+ k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size)
+
+ dn = lax.ConvDimensionNumbers(
+ lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3)
+ )
+ num_channels = images.shape[1]
+
+ ## apply standard spatial convolutions
+ blur_center = lax.conv_general_dilated(
+ padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels
+ )
+ blur_surround = lax.conv_general_dilated(
+ padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels
+ )
+ if use_ratio:
+ ## cortically-plausible divisive normalization (or a cortical ratio-of-Gaussians; RoG)
+ ### first, apply half-wave rectification
+ rectified_center = jnp.maximum(0.0, blur_center)
+ rectified_surround = jnp.maximum(0.0, blur_surround)
+ ## next, apply nonlinear response exponents
+ numerator = jnp.power(rectified_center, excitation_exp)
+ denominator = jnp.power(rectified_surround, inhibition_exp) + (semi_sat_constant ** 2)
+ output = numerator / denominator ## calculate ratio
+ else: ## cortically-plausible difference-of-Gaussians (cortical DoG)
+ ### this is modeled via rectified linear subtraction under an output threshold (0 in the case below)
+ output = jnp.maximum(0.0, blur_center - blur_surround)
+ return output
+
diff --git a/ngclearn/utils/filters/gauss_filter.py b/ngclearn/utils/filters/gauss_filter.py
new file mode 100644
index 00000000..ef592b52
--- /dev/null
+++ b/ngclearn/utils/filters/gauss_filter.py
@@ -0,0 +1,93 @@
+"""
+Support for a functional/pure Gaussian filter - this can be configured to facilitate
+difference-of-Gaussians (DoG) or ratio-of-Gaussians (RoG).
+"""
+
+import jax.numpy as jnp
+from jax import lax, jit
+from functools import partial
+
+def _calc_gaussian_kernel_2D( ## internal co-routine for Gaussian kernel
+ sigma: float, ## standard deviation of kernel
+ radius: int ## controls shape of kernel
+) -> jnp.ndarray:
+ ## create a normalized 2D Gaussian kernel with shape (1, 1, 2*radius+1, 2*radius+1)
+ x = jnp.arange(-radius, radius + 1)
+ xx, yy = jnp.meshgrid(x, x)
+ kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2)
+ kernel = kernel / jnp.sum(kernel)
+ ## reshape output to: (out_channels, in_channels, height, width) to support lax.conv_general_dilated w/in filter
+ return kernel[jnp.newaxis, jnp.newaxis, :, :]
+
+@partial(jit, static_argnums=[3, 4])
+def gaussian_filter( ## core filter routine
+ images: jnp.ndarray, ## input image batch
+ sigma_center: float, ## sigma1
+ sigma_surround: float, ## sigma2
+ kernel_size : int, ## radius
+ use_ratio:bool=False, ## if True, this becomes a ratio-of-Gaussians
+ edge_pad_mode:str="edge" ## "reflect"
+) -> jnp.ndarray:
+ """
+ Applies a configurable Gaussian filter (either difference-of-Gaussians or ratio-of-Gaussians) to a tensor batch of
+ 2D images (of CxHxW tensor shape).
+
+ Args:
+ images: input image tensor of shape (B, C, H, W)
+
+ sigma_center: standard deviation for narrow / center blur
+
+ sigma_surround: standard deviation for wide / surround blur
+
+ kernel_size: kernel radius (window size will be `2*radius + 1`)
+
+ use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False)
+
+ edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge")
+
+ Returns:
+ An output tensor of shape (B, C, H, W)
+ """
+ ## pad spatial dimensions (H, W) using edge/reflect-clamping in order to remove artifacts
+ ### format for 4D tensor (B, C, H, W) =>
+ ### result: ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W))
+ padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size))
+ padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode)
+
+ ## construct two 2D Gaussian kernels
+ k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel
+ k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel
+
+ ## define dimension ordering for lax.conv ('NCHW' standard layout)
+ dn = lax.ConvDimensionNumbers(
+ lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width)
+ rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width)
+ out_spec=(0, 1, 2, 3) ## (batch, channel, height, width)
+ )
+ num_channels = images.shape[1] ## get channel count dynamically for independent channel-wise filtering
+
+ ## below performs spatial convolutions w/ "VALID" padding on the edge-padded input
+ blur_center = lax.conv_general_dilated( ## center Gaussian
+ padded_x,
+ k1,
+ window_strides=(1, 1),
+ padding='VALID',
+ dimension_numbers=dn,
+ feature_group_count=num_channels
+ )
+ blur_surround = lax.conv_general_dilated( ## surround Gaussian
+ padded_x,
+ k2,
+ window_strides=(1, 1),
+ padding='VALID',
+ dimension_numbers=dn,
+ feature_group_count=num_channels
+ )
+ ## Perform final filter calculation
+ if use_ratio: ## apply ratio-of-Gaussians (RoG)
+ eps = 1e-5
+ output = blur_center / (blur_surround + eps) ## calculate kernel ratio
+ else: ## apply difference-of-Gaussians (DoG)
+ output = blur_center - blur_surround ## calculate kernel difference
+ return output ## final shape: (B, C, H, W)
+
diff --git a/ngclearn/utils/io_utils.py b/ngclearn/utils/io_utils.py
index 8553af44..5365b34e 100755
--- a/ngclearn/utils/io_utils.py
+++ b/ngclearn/utils/io_utils.py
@@ -78,3 +78,4 @@ def load_pkl(directory: str, name: str) -> Any:
with open(file_name, 'rb') as f:
data = pickle.load(f)
return data
+
diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py
index 457043cb..e30f3d9c 100755
--- a/ngclearn/utils/metric_utils.py
+++ b/ngclearn/utils/metric_utils.py
@@ -67,12 +67,12 @@ def measure_breadth_TC(spikes, preserve_batch=False):
spikes: full spike train matrix; shape is (T x D) where D is number of
neurons in a group/cluster
- preserve_batch: if True, will return one score per sample in batch
+ preserve_batch: if True, will return one score per neuron in train/window
(Default: False), otherwise, returns scalar average score
Returns:
- a 1 x D Fano factor vector (one factor per neuron) OR a single
- average Fano factor across the neuronal group
+ a 1 x D BTC vector (one factor per neuron) OR a single
+ average BTC across the neuronal group
"""
mu = jnp.mean(spikes, axis=0, keepdims=True)
sigSqr = jnp.square(jnp.std(spikes, axis=0, keepdims=True))
@@ -82,6 +82,42 @@ def measure_breadth_TC(spikes, preserve_batch=False):
BTC = jnp.mean(BTC)
return BTC
+@partial(jit, static_argnums=[1])
+def measure_gini_index(codes, preserve_batch=True):
+ """
+ Calculates the Gini index of a group of neurons represented as vector code samples.
+ Gini index measures the sparseness of the values within each vector code, where
+ a higher index value indicates higher sparsity and a lower index value indicates a
+ lower sparsity (higher density).
+
+ Args:
+ codes: a batch of neural codes; shape is (N x D) where D is number of
+ neurons in a group/cluster and N is number of samples
+
+ preserve_batch: if True, will return one score per sample in batch
+ (Default: True), otherwise, returns scalar average score
+
+ Returns:
+ an N x 1 Gini index vector (one score per sample) OR a single
+ average Gini score for the whole sample/set of codes
+ """
+ ## Gini index
+ ### values closer to 1 indicate high sparsity (sparser codes)
+ ### values closer to 0 indicate lower sparsity (denser codes)
+ _codes = codes + (jnp.sum(codes, axis=1, keepdims=True) <= 0.) + 1e-8
+ ### note that the calculation below is faster than the mean-absolute-value
+ ### form of gini-index; below calculation requires sorting but yields a
+ ### lower-complexity calculation
+ D = codes.shape[1] ## length of vector
+ codes_sorted = jnp.sort(jnp.abs(_codes), axis=1) ## sort all codes w/in batch matrix
+ index = jnp.arange(1, D + 1)
+ term1 = jnp.sum((2 * index - D - 1) * codes_sorted, axis=1, keepdims=True)
+ term2 = D * jnp.sum(codes_sorted, axis=1, keepdims=True)
+ gini = term1 / term2 ## calc final ratio
+ if not preserve_batch:
+ gini = jnp.mean(gini) ## this is the mean gini-index
+ return gini
+
@partial(jit, static_argnums=[2, 3])
def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=False):
"""
@@ -89,9 +125,13 @@ def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=Fals
this matrix is a non-negative vector.
Formally, this means we compute, per i-th row:
+
| rho(x_i) = num_zeros(x_i) / dim(x_i)
+
and for a global score for matrix X with N codes/rows, we measure:
+
| rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i)
+
where lower/closer to 0 means codes more sparse and closer to 1 means
codes are more dense.
@@ -136,9 +176,9 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st
y: target / ground-truth (design) matrix; shape is (N x C) OR an array
of class integers of length N (with "extract_label_indx = True")
- extract_label_indx: run an argmax to pull class integer indices from
+ extract_label_indx: when True, run an argmax to pull class integer indices from
"y", assuming y is a one-hot binary encoding matrix (Default: True),
- otherwise, this assumes "y" is an array of class integer indices
+ otherwise, if False, this treats "y" as an array of class integer indices
of length N
Returns:
@@ -423,3 +463,289 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
if not preserve_batch:
bce = jnp.mean(bce)
return bce
+
+@partial(jit, static_argnums=[1])
+def measure_hoyer_sparsity(codes: jnp.ndarray, preserve_batch: bool=False) -> float:
+ """
+ Measures the Hoyer sparsity for a set of latent codes.
+ Hoyer sparsity lies in [0, 1], where a value of 0 indicates a fully dense code (all entries
+ of equal magnitude) and a value of 1 indicates a maximally sparse code (a single non-zero entry).
+
+ Args:
+ codes: matrix (shape: N x D) of non-negative codes to measure
+ sparsity of (per row); D is flattened latent code size
+
+ preserve_batch: if True, will return one score per sample in batch
+ (Default: False), otherwise, returns scalar mean score
+
+ Returns:
+ a length-N vector of scores (if preserve_batch=True) OR a scalar mean score otherwise
+ """
+ # Flatten everything past the batch dimension
+ x = jnp.reshape(codes, (codes.shape[0], -1))
+ N = x.shape[1]
+
+ l1 = jnp.sum(jnp.abs(x), axis=1)
+ l2 = jnp.sqrt(jnp.sum(jnp.square(x), axis=1) + 1e-8) # epsilon to avoid division by zero
+
+ hoyer = (jnp.sqrt(N) - (l1 / l2)) / (jnp.sqrt(N) - 1.0)
+ if not preserve_batch:
+ hoyer = jnp.mean(hoyer) # calc average sparsity across set/batch
+ return hoyer
+
+@partial(jit, static_argnums=[1])
+def measure_excess_kurtosis(codes: jnp.ndarray, preserve_batch: bool=False) -> float:
+ """
+ Measures the peak and heavy-tailedness of a set of neural activation codes. Note that
+ higher values (> 0) indicate sparse, localized 'high-burst' activations.
+
+ Args:
+ codes: matrix (shape: N x D) of non-negative codes to measure
+ sparsity of (per row)
+
+ preserve_batch: if True, will return one score per sample in batch
+ (Default: False), otherwise, returns scalar mean score
+
+ Returns:
+ an (N x 1) column vector (if preserve_batch=True) OR a scalar mean score otherwise
+ """
+ x = jnp.reshape(codes, (codes.shape[0], -1))
+ mean = jnp.mean(x, axis=1, keepdims=True) ## 1st moment
+ variance = jnp.var(x, axis=1, keepdims=True) ## 2nd moment
+
+ ## 4th central moment divided by variance squared
+ fourth_moment = jnp.mean(jnp.power(x - mean, 4), axis=1, keepdims=True)
+ kurtosis = fourth_moment / (jnp.square(variance) + 1e-8) ## kurtosis of distribution
+ excess_kurtosis = kurtosis - 3.0 ## calc "excess kurtosis" by subtracting 3
+ if not preserve_batch:
+ excess_kurtosis = jnp.mean(excess_kurtosis) ## calc avg excess-kurtosis over set/batch
+ return excess_kurtosis
+
+
+### class conformity metrics ###
+
+@partial(jit, static_argnums=[2, 3])
+def _compute_contingency_table( ## vectorized construction of contingency matrix
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray,
+ n_classes: int,
+ n_clusters: int
+) -> jnp.ndarray:
+ ## Computes a contingency matrix table
+ ## This routine expects true integer labels and predicted integer labels (1D arrays of size N)
+
+ # Create indicator masks across all unique classes/clusters
+ # find unique IDs safely up to a static maximum size (or provide num_classes)
+ # n_classes = n_true = jnp.max(labels_true) + 1
+ # n_clusters = n_pred = jnp.max(labels_pred) + 1
+
+ # Broadcast to form a full one-hot lookup map
+ true_mask = labels_true[:, None] == jnp.arange(n_classes)
+ pred_mask = labels_pred[:, None] == jnp.arange(n_clusters)
+
+ # Contingency matrix is the matrix product of boolean indicators
+ contingency = jnp.dot(true_mask.T.astype(jnp.float32), pred_mask.astype(jnp.float32))
+ return contingency
+
+
+def measure_ARI(
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray
+) -> jnp.ndarray:
+ """
+ Computes the adjusted Rand index (ARI), which measures similarity between two
+ sets of indices (ground truth against a clustering's produced indices) via counting the
+ pairs of data points assigned to same or different clusters (adjusted for chance). This
+ measurement is at most `1` (perfect agreement), is close to `0` for a random labeling/assignment,
+ and can be negative when agreement is worse than expected by chance.
+
+ Args:
+ labels_true: 1D array of shape (n_samples,) with true integer class labels.
+
+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.
+
+ Returns:
+ scalar ARI of these two sets of indices
+ """
+ ## Infer the number of classes/clusters as (largest label + 1); passed below as static jit arguments
+ n_classes = int(jnp.max(labels_true) + 1)
+ n_clusters = int(jnp.max(labels_pred) + 1)
+ return _calc_adjusted_rand_index(labels_true, labels_pred, n_classes, n_clusters)
+
+
+@partial(jit, static_argnums=[2, 3])
+def _calc_adjusted_rand_index( ## ARI
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray,
+ n_classes: int,
+ n_clusters: int
+) -> jnp.ndarray:
+ n_samples = labels_true.shape[0]
+ if n_samples <= 1:
+ return jnp.array(1.0)
+
+ ## Get contingency matrix (n_classes x n_clusters)
+ contingency = _compute_contingency_table(
+ labels_true,
+ labels_pred,
+ n_classes,
+ n_clusters
+ )
+
+ ## Calculate combination sums n_ijC2 = (n_ij * (n_ij - 1)) / 2
+ sum_nij_c2 = jnp.sum((contingency * (contingency - 1.0)) / 2.0)
+
+ ## Sums across margins (rows and columns)
+ sum_a = jnp.sum(contingency, axis=1)
+ sum_b = jnp.sum(contingency, axis=0)
+
+ ## Margin pair combinations
+ sum_a_c2 = jnp.sum((sum_a * (sum_a - 1.0)) / 2.0)
+ sum_b_c2 = jnp.sum((sum_b * (sum_b - 1.0)) / 2.0)
+
+ ## Expected index and Max index math formulas
+ total_c2 = (n_samples * (n_samples - 1.0)) / 2.0
+ expected_index = (sum_a_c2 * sum_b_c2) / total_c2
+ max_index = (sum_a_c2 + sum_b_c2) / 2.0
+
+ ## Prevent division by zero if everything is perfectly clustered or uniform
+ denominator = max_index - expected_index
+ ari = jnp.where(denominator == 0.0, 1.0, (sum_nij_c2 - expected_index) / denominator)
+ return ari
+
+
+def measure_FMI(
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray
+) -> jnp.ndarray:
+ """
+ Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of
+ indices - this score is the geometric mean of pair-wise recall and precision.
+ This measurement lies in `[0, 1]`, where higher is better (indicating greater similarity between
+ two clustering sets of identifiers).
+
+ Args:
+ labels_true: 1D array of shape (n_samples,) with true integer class labels.
+
+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels.
+
+ Returns:
+ scalar FMI of these two sets of indices
+ """
+ ## Infer the number of classes/clusters as (largest label + 1); passed below as static jit arguments
+ n_classes = int(jnp.max(labels_true) + 1)
+ n_clusters = int(jnp.max(labels_pred) + 1)
+ return _measure_fowlkes_mallows_index(labels_true, labels_pred, n_classes, n_clusters)
+
+
+@partial(jit, static_argnums=[2, 3])
+def _measure_fowlkes_mallows_index( ## FMI
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray,
+ n_classes: int,
+ n_clusters: int
+) -> jnp.ndarray:
+ n_samples = labels_true.shape[0]
+ # Handle edge case for single or empty samples safely
+ if n_samples <= 1:
+ return jnp.array(0.0, dtype=jnp.float32)
+
+ contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)
+
+ ## Compute marginal sums (sums along rows and columns)
+ sum_true = jnp.sum(contingency, axis=1)
+ sum_pred = jnp.sum(contingency, axis=0)
+
+ ## Calculate pairwise combinations using the matrix shortcut: nC2 = 0.5 * (sum(x^2) - N)
+ # True Positives pair combinations (tk)
+ tk = 0.5 * (jnp.sum(contingency ** 2) - n_samples)
+ ## Total pairs clustered together in ground truth (tr)
+ tr = 0.5 * (jnp.sum(sum_true ** 2) - n_samples)
+ ## Total pairs clustered together in predictions (tc)
+ tc = 0.5 * (jnp.sum(sum_pred ** 2) - n_samples)
+
+ ## Compute FMI = tk / sqrt(tr * tc)
+ # Prevent division by zero if there are no pair splits/matches
+ denominator = jnp.sqrt(tr * tc)
+ fmi = jnp.where(denominator == 0.0, 0.0, tk / denominator)
+ return fmi
+
+
+def measure_Vmeasure( ## V-Measure
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray,
+ beta: float = 1.0
+) -> jnp.ndarray:
+ """
+ Calculates the V-Measure scoring metric for class conformity. This measurement compares
+ predicted cluster indices ("labels_pred") against ground truth indices ("labels_true") and
+ represents the harmonic mean of homogeneity (where each cluster contains only members of a single class)
+ as well as completeness (where all members of a given class are assigned to the same cluster).
+ This measurement (higher is better) lies in `[0,1]` where `1` indicates perfect, correct clustering.
+
+ Args:
+ labels_true: 1D array of shape (n_samples,) with true integer class labels
+
+ labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels
+
+ beta: weight factor; values > 1.0 favor completeness, values < 1.0 favor homogeneity (Default: 1.0)
+
+ Returns:
+ scalar V-measure of these two sets of indices
+ """
+ ## Infer the number of classes/clusters as (largest label + 1); passed below as static jit arguments
+ n_classes = int(jnp.max(labels_true) + 1)
+ n_clusters = int(jnp.max(labels_pred) + 1)
+ return _measure_v_measure_score(labels_true, labels_pred, n_classes, n_clusters, beta)
+
+
+@partial(jit, static_argnums=[2, 3, 4])
+def _measure_v_measure_score( ## V-Measure
+ labels_true: jnp.ndarray,
+ labels_pred: jnp.ndarray,
+ n_classes: int,
+ n_clusters: int,
+ beta: float = 1.0
+) -> jnp.ndarray:
+ n_samples = labels_true.shape[0]
+
+ ## Handle edge case for single or empty samples safely
+ if n_samples <= 1:
+ return jnp.array(0.0, dtype=jnp.float32)
+
+ contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters)
+
+ ## Calculate Marginal Sums (Row and Column totals)
+ sum_true = jnp.sum(contingency, axis=1)
+ sum_pred = jnp.sum(contingency, axis=0)
+
+ ## Compute Base Entropies H(True) and H(Pred)
+ p_true = sum_true / n_samples
+ h_true = -jnp.sum(jnp.where(p_true > 0.0, p_true * jnp.log(p_true), 0.0))
+
+ p_pred = sum_pred / n_samples
+ h_pred = -jnp.sum(jnp.where(p_pred > 0.0, p_pred * jnp.log(p_pred), 0.0))
+
+ ## Compute Joint Entropy H(True, Pred)
+ p_joint = contingency / n_samples
+ h_joint = -jnp.sum(jnp.where(p_joint > 0.0, p_joint * jnp.log(p_joint), 0.0))
+
+ ## Derive Conditional Entropies: H(True|Pred) and H(Pred|True) using identity rule
+ h_true_given_pred = h_joint - h_pred
+ h_pred_given_true = h_joint - h_true
+
+ ## Compute Homogeneity (H) and Completeness (C)
+ ## If base entropy is 0, the metric is perfectly satisfied (1.0)
+ homogeneity = jnp.where(h_true == 0.0, 1.0, 1.0 - (h_true_given_pred / h_true))
+ completeness = jnp.where(h_pred == 0.0, 1.0, 1.0 - (h_pred_given_true / h_pred))
+
+ ## Compute Weighted Harmonic Mean (V-Measure)
+ denominator = beta * homogeneity + completeness
+
+ ## Prevent division by zero if both metrics are zero
+ v_measure = jnp.where(
+ denominator == 0.0,
+ 0.0,
+ (1.0 + beta) * homogeneity * completeness / denominator
+ )
+ return v_measure
diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py
index bf20079b..aa8ef9cb 100755
--- a/ngclearn/utils/model_utils.py
+++ b/ngclearn/utils/model_utils.py
@@ -73,8 +73,7 @@ def create_function(fun_name, args=None):
Args:
fun_name: string name of activation function to produce;
Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6",
- "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside",
- "identity"
+ "elu", "silu", "gelu", "softplus", "softmax", "unit_threshold", "heaviside", "identity"
Returns:
function fx, first derivative of function (w.r.t. input) dfx
@@ -108,7 +107,7 @@ def create_function(fun_name, args=None):
elif fun_name == "elu":
fx = elu
dfx = d_elu
- elif fun_name == "silu":
+ elif fun_name == "silu": # NOTE: this is also the swish function
fx = silu
dfx = d_silu
elif fun_name == "gelu":
@@ -122,26 +121,33 @@ def create_function(fun_name, args=None):
dfx = d_softplus
elif fun_name == "softmax":
fx = softmax
- ## NOTE: below is an improper derivative proxy
- ## correct dfx is a Jacobian of softmax (not currently supported!)
- dfx = d_identity
+ dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx
elif fun_name == "unit_threshold":
fx = threshold ## default threshold is 1 (thus unit)
- dfx = d_threshold ## STE approximation
+ dfx = d_threshold ## NOTE: STE approximation
elif "heaviside" in fun_name:
fx = heaviside
- dfx = d_heaviside ## STE approximation
+ dfx = d_heaviside ## NOTE: STE approximation
elif fun_name == "identity":
fx = identity
dfx = d_identity
- else:
+ else: ## throw exception for un-supported activation
raise RuntimeError(
"Activation function (" + fun_name + ") is not recognized/supported!"
- )
+ )
return fx, dfx
@partial(jit, static_argnums=[1])
-def bkwta(x, nWTA=5): #5 10 15 #K=50):
+def bkwta(x, nWTA=5): ## binarized k-winner-take-all function
+ """
+ The binarized K winner-take-all (K-WTA) function:
+
+ Args:
+ x: input (tensor) value (real-valued)
+
+ Returns:
+ output (tensor) value (binary values)
+ """
values, indices = lax.top_k(x, nWTA) # Note: we do not care to sort the indices
kth = jnp.expand_dims(jnp.min(values,axis=1),axis=1) # must do comparison per sample in potential mini-batch
topK = jnp.greater_equal(x, kth).astype(jnp.float32) # cast booleans to floats
@@ -214,7 +220,6 @@ def clamp_max(x, max_val):
_x = x * mask + (1. - mask) * max_val
return _x
-
@jit
def one_hot(P):
"""
@@ -254,7 +259,7 @@ def chebyshev_norm(d, axis=-1, keepdims=False):
@jit
def binarize(data, threshold=0.5):
"""
- Converts the vector *data* to its binary equivalent
+ Converts the vector *data* to its binary equivalent.
Args:
data: the data to binarize (real-valued)
@@ -357,10 +362,13 @@ def d_telu(x):
ex = jnp.exp(x)
tanh_ex = jnp.tanh(ex)
return tanh_ex + x * ex * (1.0 - tanh_ex ** 2)
+
@jit
def sine(x, omega_0=30):
"""
- f(x) = sin(x * omega_0).
+ The sine function, parameterized by frequency `omega_0`:
+
+ | f(x) = sin(x * omega_0).
Args:
x: input (tensor) value
@@ -373,14 +381,15 @@ def sine(x, omega_0=30):
@jit
def d_sine(x, omega_0=30):
"""
- frequency = omega_0
- frequency * cos(x * frequency).
+ The derivative of the sine function:
+
+ | f'(x) = frequency * cos(x * frequency); where frequency = omega_0
Args:
x: input (tensor) value
Returns:
- output (tensor) value
+ output (tensor) derivative value (with respect to input)
"""
return omega_0 * jnp.cos(omega_0 * x)
@@ -492,7 +501,9 @@ def d_relu6(x):
@jit
def softplus(x):
"""
- The softplus elementwise function.
+ The softplus elementwise function:
+
+ | f(x) = ln(1 + exp(x))
Args:
x: input (tensor) value
@@ -514,31 +525,94 @@ def d_softplus(x):
output (tensor) derivative value (with respect to input argument)
"""
## d/dx of softplus = logistic sigmoid
- return nn.sigmoid(x)
+ return sigmoid(x) #nn.sigmoid(x)
@jit
def threshold(x, thr=1.):
+ """
+ The threshold function (or Heaviside but with a non-zero boundary):
+
+ | f(x) = 1 if x >= thr, otherwise 0 (for x < thr)
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) value
+ """
return (x >= thr).astype(jnp.float32)
@jit
def d_threshold(x, thr=1.):
- return x * 0. + 1. ## straight-thru estimator
+ """
+ Derivative of the threshold function; specifically, this employs the
+ straight-through estimator (STE) as a proxy/surrogate derivative instead.
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) derivative value (with respect to input argument)
+ """
+ return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
@jit
def heaviside(x):
+ """
+ The Heaviside function:
+
+ | f(x) = 1 if x >= 0, otherwise 0 (for x < 0)
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) value
+ """
return (x >= 0.).astype(jnp.float32)
@jit
def d_heaviside(x):
- return x * 0. + 1. ## straight-thru estimator
+ """
+ Derivative of the Heaviside function; specifically, this employs the
+ straight-through estimator (STE) as a proxy/surrogate derivative instead.
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) derivative value (with respect to input argument)
+ """
+ return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
@jit
def sigmoid(x):
- return nn.sigmoid(x)
+ """
+ The sigmoid / logistic-link function:
+
+ | f(x) = 1/(1 + exp(-x))
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) value
+ """
+ sigm_x = 1./ (1. + jnp.exp(-x))
+ return sigm_x #nn.sigmoid(x)
@jit
def d_sigmoid(x):
- sigm_x = nn.sigmoid(x) ## pre-compute once
+ """
+ Derivative of the sigmoid / logistic-link function.
+
+ Args:
+ x: input (tensor) value
+
+ Returns:
+ output (tensor) derivative value (with respect to input argument)
+ """
+ sigm_x = sigmoid(x) #nn.sigmoid(x) ## pre-compute once
return sigm_x * (1. - sigm_x)
def inverse_sigmoid(x, clip_bound=0.03): ## wrapper call for naming convention ease
@@ -590,7 +664,9 @@ def d_swish(x, beta):
@jit
def silu(x):
"""
- Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
+ Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
+ Note that this is primarily a convenience wrapper function for
+ the `swish` activation.
Args:
x: data to transform via inverse logistic function
@@ -607,7 +683,8 @@ def d_silu(x):
@jit
def gelu(x):
"""
- Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used).
+ Applies the Gaussian Error Linear Unit (GeLU) activation
+ (specifically, a fast approximation is used via a weighted `swish`).
Args:
x: data to transform via inverse logistic function
@@ -635,7 +712,7 @@ def elu(x, alpha=1.):
Returns:
output of the GeLU activation
"""
- mask = x >= 0.
+ mask = x >= 0. ## pre-compute mask
return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)
@jit
@@ -653,21 +730,68 @@ def softmax(x, tau=0.0):
Args:
x: a (N x D) input argument (pre-activity) to the softmax operator
- tau: probability sharpening/softening factor
+ tau: temperature (probability sharpening/softening) factor, applied when > 0.;
+ a value <= 0. disables temperature scaling (Default: 0.)
Returns:
a (N x D) probability distribution output block
"""
- if tau > 0.0:
- x = x / tau
+ #if tau > 0.0:
+ # x = x / tau
+ _m = tau > 0.
+ _tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0
+ x = x * (1./ _tau)
max_x = jnp.max(x, axis=1, keepdims=True)
exp_x = jnp.exp(x - max_x)
return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
+@partial(jit, static_argnums=[2])
+def d_softmax(x, tau=0., vmap_form=False): ## temperature-controlled softmax derivative co-routine
+ """
+ Derivative of the softmax function.
+ Note that this returns specifically the Jacobian tensor `Jx` of softmax(x) w.r.t.
+ potential batch set of vectors (one per row).
+
+ Args:
+ x: input (tensor) value (B x D)
+
+ vmap_form: optional algorithm switch flag; if True, `Jx` is computed using
+ Jax vmap (Default: False)
+
+ Returns:
+ output (tensor) derivative values (Jacobian with respect to input argument; B x D x D)
+ """
+ _m = tau > 0.
+ _tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0
+ Jx = 0. ## d_softmax(x)/d_x is a Jacobian matrix per sample
+ ## calculate softmax along feature dimension (axis=-1)
+ s = softmax(x, tau=_tau) # nn.softmax(x, axis=-1) ## (BxD)
+ if not vmap_form: ### use pure tensorized batch-identity trick algorithm
+ diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) ## Shape: (BxDx1) * (1xDxD) => (BxDxD)
+ ## batched outer(s, s) ~> outer product for each batch vector
+ outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) ## (BxDx1) * (Bx1xD) => (BxDxD)
+ Jx = (diag_s - outer_s) * (1. / _tau)
+ else: ### switch to vmap algorithm
+ ## calc outer product using einsum (clean and readable)
+ outer_s = jnp.einsum('bi,bj->bij', s, s) ## (BxDxD)
+ ## fast batched diagonal insertion via a diagonal mask
+ d = s.shape[-1]
+ diag_indices = jnp.arange(d)
+ ## start from -outer(s, s), then add s onto its diagonal via `.at[...].add`, giving
+ ## (s_i - s_i^2) on the diagonal and (-s_i s_j) off the diagonal
+ ## avoids constructing a giant identity matrix
+ jacobian = -outer_s
+ ## vmap over index updates across batch
+ def add_diag(J_matrix, s_vector):
+ return J_matrix.at[diag_indices, diag_indices].add(s_vector)
+ Jx = ( jax.vmap(add_diag)(jacobian, s) ) * (1. / _tau)
+ return Jx ## return full, final Jacobian
+
@jit
def threshold_soft(x, lmbda):
"""
- A soft threshold routine applied to each dimension of input
+ A soft threshold routine applied to each dimension of input.
+ (Note that this function does not contain a complementary derivative.)
Args:
x: data to apply threshold function over
@@ -684,7 +808,8 @@ def threshold_soft(x, lmbda):
@jit
def threshold_cauchy(x, lmbda):
"""
- A Cauchy distributional threshold routine applied to each dimension of input
+ A Cauchy distributional threshold routine applied to each dimension of input.
+ (Note that this function does not contain a complementary derivative.)
Args:
x: data to apply threshold function over
@@ -770,6 +895,23 @@ def create_block_matrix(map_matrix, group_shape, alpha_inh=-1., alpha_exc=1.):
gmat = jnp.concatenate(gmat, axis=0)
return gmat
+@partial(jit, static_argnums=[0])
+def eye_wrapped(N, k, values):
+ """
+ Creates an N x N matrix with a wrapped off-diagonal.
+
+ Args:
+ N: Size of the square matrix (N x N)
+
+ k: Diagonal offset (positive=above, negative=below)
+
+ values: Array of values to place (length should match N)
+ """
+ matrix = jnp.zeros((N, N)) ## Create empty matrix
+ row_indices = jnp.arange(N) ## Generate indices for the diagonal
+ col_indices = (row_indices + k) % N ## Wrap column indices using modulo
+ return matrix.at[row_indices, col_indices].set(values) ## Fill diagonal using efficient indexing
+
@partial(jit, static_argnums=[1, 2, 3, 4])
def normalize_block_matrix(matrix, block_size, order=2, axis=0, norm_targ=1.):
"""
diff --git a/ngclearn/utils/viz/__init__.py b/ngclearn/utils/viz/__init__.py
index e37c46fc..605d9bef 100644
--- a/ngclearn/utils/viz/__init__.py
+++ b/ngclearn/utils/viz/__init__.py
@@ -2,3 +2,5 @@
from . import raster
from . import spike_plot
from . import synapse_plot
+from . import classification_analysis
+
diff --git a/ngclearn/utils/viz/classification_analysis.py b/ngclearn/utils/viz/classification_analysis.py
new file mode 100644
index 00000000..cce819f5
--- /dev/null
+++ b/ngclearn/utils/viz/classification_analysis.py
@@ -0,0 +1,57 @@
+import matplotlib.pyplot as plt
+import numpy as np
+
+def visualize_confusion_heatmap(
+ confuse_matrix,
+ classes,
+ out_fname,
+ figure_title="Confusion Matrix",
+ color_map="Blues", # "Greens" "Reds"
+ fontsize=10,
+ norm_by="none"
+):
+ _conf = confuse_matrix
+ if "recall" in norm_by: ## normalize by row (recall)
+ _conf = (_conf / np.sum(_conf, axis=1, keepdims=True)) * 100
+ elif "precision" in norm_by: ## normalize by col (precision)
+ _conf = (_conf / np.sum(_conf, axis=0, keepdims=True)) * 100
+ ## Initialize plot
+ fig, ax = plt.subplots(figsize=(6, 6))
+ vmin = np.floor(np.min(_conf))
+ vmax = np.ceil(np.max(_conf))
+ im = ax.imshow(
+ _conf, interpolation="nearest", cmap=color_map, vmin=vmin, vmax=vmax
+ )
+ ## Add color bar
+ fig.colorbar(im, ax=ax, shrink=0.75)
+ ## Configure axes labels
+ ax.set_xticks(np.arange(len(classes)))
+ ax.set_yticks(np.arange(len(classes)))
+ ax.set_xticklabels(classes)
+ ax.set_yticklabels(classes)
+
+ ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10)
+ ax.set_ylabel("True Label", fontsize=12, labelpad=10)
+ ax.set_title(figure_title, fontsize=14, pad=15)
+
+ ## Loop to print numbers on top of each colored cell
+ threshold = _conf.max() / 2.0 # Find midpoint to flip text color for readability
+ for i in range(_conf.shape[0]):
+ for j in range(_conf.shape[1]):
+ ## Use white text on dark cells, black text on light cells
+ color = "white" if _conf[i, j] > threshold else "black"
+ ax.text(
+ j,
+ i,
+ f"{_conf[i, j]:.1f}", #format(confuse_matrix[i, j], "d"),
+ ha="center",
+ va="center",
+ color=color,
+ fontsize=fontsize,
+ weight="bold",
+ )
+
+ ## Ensure layout fits tightly
+ plt.tight_layout()
+ plt.savefig(out_fname)
+
diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py
index 3f32057d..3300feef 100755
--- a/ngclearn/utils/viz/dim_reduce.py
+++ b/ngclearn/utils/viz/dim_reduce.py
@@ -28,7 +28,12 @@ def extract_pca_latents(vectors): ## PCA mapping routine
z_2D = vectors
return z_2D
-def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine
+def extract_tsne_latents(
+ vectors,
+ perplexity=30,
+ n_pca_comp=32,
+ batch_size=500
+): ## tSNE mapping routine
"""
Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the
t-distributed stochastic neighbor embedding algorithm (t-SNE). This algorithm also uses PCA to produce an
@@ -41,9 +46,9 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500):
perplexity: the perplexity control factor for t-SNE (Default: 30)
n_pca_comp: number of PCA top components (sorted by eigen-values) to retain/extract before continuing
- with t-SNE dimensionality reduction
+ with t-SNE dimensionality reduction (Default: 32)
- batch_size: number of sampled embedding vectors to use per iteration of online internal PCA
+ batch_size: number of sampled embedding vectors to use per iteration of online internal PCA (Default: 500)
Returns:
a matrix (K x 2) of projected vectors (to 2D space)
@@ -67,7 +72,15 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500):
z_2D = vectors
return z_2D
-def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., cmap=None):
+def plot_latents(
+ code_vectors,
+ labels,
+ plot_fname="2Dcode_plot.jpg",
+ alpha=1.,
+ cmap=None,
+ xaxis_title=None,
+ yaxis_title=None
+):
"""
Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes
(produced by either PCA or t-SNE).
@@ -84,6 +97,11 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c
alpha: alpha intensity level to present colors in scatterplot
cmap: custom color-map to provide
+
+ xaxis_title: string denoting title to place for X-axis (Default: None)
+
+ yaxis_title: string denoting title to place for Y-axis (Default: None)
+
"""
curr_backend = plt.rcParams["backend"]
matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting
@@ -98,10 +116,17 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c
if _cmap is None:
_cmap = default_cmap
#print("> USING DEFAULT CMAP!")
- plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha)
+ plt.scatter(
+ code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha
+ )
colorbar = plt.colorbar()
#colorbar.set_alpha(1)
#plt.draw_all()
+ if xaxis_title is not None:
+ plt.xlabel(xaxis_title, fontsize=16, fontweight="bold")
+ if yaxis_title is not None:
+ plt.ylabel(yaxis_title, fontsize=16, fontweight="bold")
+
plt.grid()
plt.savefig("{0}".format(plot_fname), dpi=300)
plt.clf()
diff --git a/ngclearn/utils/viz/raster.py b/ngclearn/utils/viz/raster.py
index dff8745b..875a52e8 100755
--- a/ngclearn/utils/viz/raster.py
+++ b/ngclearn/utils/viz/raster.py
@@ -140,3 +140,4 @@ def create_overlay_raster_plot(spike_train, targ_train, Y, idxs, s=1.5, c="black
plt.savefig(plot_fname + '_' + str(idx) + suffix)
plt.clf()
plt.close()
+
diff --git a/ngclearn/utils/viz/synapse_plot.py b/ngclearn/utils/viz/synapse_plot.py
index 0f268491..f62a47da 100644
--- a/ngclearn/utils/viz/synapse_plot.py
+++ b/ngclearn/utils/viz/synapse_plot.py
@@ -9,8 +9,13 @@
import jax.numpy as jnp
-
-def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'):
+def visualize(
+ thetas,
+ sizes,
+ prefix,
+ order=None,
+ suffix='.jpg'
+):
"""
Args:
@@ -22,8 +27,6 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'):
suffix:
"""
-
-
if order is None:
order = ['C' for _ in range(len(thetas))]
@@ -52,8 +55,11 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'):
point = start + 1 + i + (r * extra)
plt.subplot(n_rows_total, n_cols_total, point)
- filter = T[i, :]
- plt.imshow(np.reshape(filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]), cmap=plt.cm.bone, interpolation='nearest')
+ _filter = T[i, :]
+ plt.imshow(
+ np.reshape(_filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]),
+ cmap=plt.cm.bone, interpolation='nearest'
+ )
plt.axis("off")
plt.subplots_adjust(top=0.9)
@@ -62,7 +68,14 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'):
plt.close()
-def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffix='.jpg'):
+def visualize_labels(
+ thetas,
+ sizes,
+ prefix,
+ space_width=None,
+ widths=None,
+ suffix='.jpg'
+):
"""
Args:
@@ -139,14 +152,34 @@ def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffi
fig.savefig(prefix+suffix, bbox_inches='tight')
plt.close(fig)
-def visualize_frame(frame, path='.', name='tmp', suffix='.jpg', **kwargs):
+def visualize_frame(
+ frame,
+ path='.',
+ name='tmp',
+ suffix='.jpg',
+ **kwargs
+):
iio.imwrite(path + '/' + name + suffix, frame.astype(jnp.uint8), **kwargs)
-def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs):
+def visualize_gif(
+ frames,
+ path='.',
+ name='tmp',
+ suffix='.jpg',
+ **kwargs
+):
_frames = [f.astype(jnp.uint8) for f in frames]
iio.imwrite(path + '/' + name + '.gif', _frames, **kwargs)
-def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs):
+def make_video(
+ f_start,
+ f_end,
+ path,
+ prefix,
+ suffix='.jpg',
+ skip=1,
+ **kwargs
+):
images = []
for i in range(f_start, f_end+1, skip):
print("Reading frame " + str(i))
@@ -154,10 +187,13 @@ def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs):
print("writing gif")
iio.imwrite(path + '/training.gif', images, **kwargs)
-
-# def visualize_norm(thetas, sizes, prefix, suffix='.jpg'):
-
-def viz_block(thetas, sizes, prefix, suffix=".jpg", padding=1, low_rez=True):
+def viz_block(
+ thetas,
+ sizes, prefix,
+ suffix=".jpg",
+ padding=1,
+ low_rez=True
+):
num_filters = [T.shape[1] for T in thetas]
n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters]
n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)]
diff --git a/pyproject.toml b/pyproject.toml
index bd1c7194..3458b230 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" # using setuptool building engine
[project]
name = "ngclearn"
-version = "3.1.0"
+version = "3.1.1"
description = "Simulation software for building and analyzing computational neuroscience models, brain-inspired computing systems, and NeuroAI agents."
authors = [
{name = "Alexander Ororbia", email = "ago@cs.rit.edu"},
diff --git a/requirements.txt b/requirements.txt
index 09ab6214..d68c12aa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,9 +2,10 @@ numpy>=1.26.4
scikit-learn>=1.6.1
scipy>=1.14.1
matplotlib>=3.9.4
-# patchify # patchify has issues with pip installation
+# patchify ## note: patchify has issues with pip installation
jax>=0.4.28
jaxlib>=0.4.28
-ngcsimlib>=3.0.0
+ngcsimlib>=3.1.0
imageio>=2.37.0
pandas>=2.2.3
+typing_extensions>=4.15.0