diff --git a/README.md b/README.md index 757b9862..760a4933 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ $ python install -e . **Version:**
-3.1.0 +3.1.1 Author: Alexander G. Ororbia II
diff --git a/history.txt b/history.txt index 4345a232..a8c1e724 100644 --- a/history.txt +++ b/history.txt @@ -99,4 +99,6 @@ History * new suite of competitive learning synapses (generalized formats), including vector-quantization, self-organizing map, adaptive resonance theory (contus-version), and modern hopfield network * revisions to metric and model utils * some additional clean-up, including supported retinal ganglion encoder + * fixed pkg-resource header (in both ngcsimlib and ngclearn) to account for newer python(s) + diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py index 6121286c..5670cfca 100644 --- a/ngclearn/__init__.py +++ b/ngclearn/__init__.py @@ -31,7 +31,7 @@ "currently installed!") ################################################################################## -## Needed to preload is called before anything in ngclearn +## Following are needed to preload is called before anything in ngclearn from pathlib import Path from sys import argv import numpy diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index 0b81afdc..94a2c116 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -30,6 +30,9 @@ from .input_encoders.ganglionCell import RetinalGanglionCell from .input_encoders.latencyCell import LatencyCell from .input_encoders.phasorCell import PhasorCell +#from .input_encoders.populationCoderCell import PopulationCoderCell +#from .input_encoders.gridCell import GridCell +#from .input_encoders.placeCell import PlaceCell ## point to synapse component types from .synapses.denseSynapse import DenseSynapse diff --git a/ngclearn/components/input_encoders/__init__.py b/ngclearn/components/input_encoders/__init__.py index 9af3241b..bbee5280 100644 --- a/ngclearn/components/input_encoders/__init__.py +++ b/ngclearn/components/input_encoders/__init__.py @@ -1,7 +1,9 @@ from .bernoulliCell import BernoulliCell from .poissonCell import PoissonCell from .latencyCell import LatencyCell -from 
.ganglionCell import RetinalGanglionCell from .phasorCell import PhasorCell - +from .ganglionCell import RetinalGanglionCell +#from .populationCoderCell import PopulationCoderCell +#from .gridCell import GridCell +#from .placeCell import PlaceCell diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py index 132f7d92..68e87539 100755 --- a/ngclearn/components/input_encoders/bernoulliCell.py +++ b/ngclearn/components/input_encoders/bernoulliCell.py @@ -27,7 +27,12 @@ class BernoulliCell(JaxComponent): """ def __init__( - self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs + self, + name: str, + n_units: int, + batch_size: int = 1, + key: Union[jax.Array, None] = None, + **kwargs ): super().__init__(name=name, key=key) diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index a35c7b0a..b7e58c05 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -8,26 +8,28 @@ def _create_gaussian_filter(patch_shape, sigma): ## Create a 2D Gaussian kernel centered on patch_shape with given sigma. 
px, py = patch_shape - x_ = jnp.linspace(0, px - 1, px) y_ = jnp.linspace(0, py - 1, py) - x, y = jnp.meshgrid(x_, y_) - xc = px // 2 yc = py // 2 - - filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2))) - return filter / jnp.sum(filter) + _filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2))) + return _filter / jnp.sum(_filter) def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1): g1 = _create_gaussian_filter(patch_shape, sigma=sigma) g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k) - dog = g1 - lmbda * g2 - return dog #- jnp.mean(dog) + +def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6): + g1 = _create_gaussian_filter(patch_shape, sigma=sigma) + g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k) + rog = g1 / (g2 + 1e-8) + return rog + + def _create_patches(obs, patch_shape, step_shape): """ Extract 2D patches from a batch of images using a sliding window. @@ -67,10 +69,36 @@ def _create_patches(obs, patch_shape, step_shape): return patches +def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape): + # patches: (N, nx * ny, px, py) + + B = len(patches) + nx, ny = nx_ny + ix, iy = area_shape + px, py = patch_shape + sx, sy = step_shape + x = jnp.zeros((B, ix, iy)) + counts = jnp.zeros((ix, iy)) + + idx = 0 + for i in range(ny): + for j in range(nx): + di = i * sx + dj = j * sy + x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx]) + counts = counts.at[di:di + px, dj:dj + py].add(1.0) + idx += 1 + + return x / counts[None, :, :] + + class RetinalGanglionCell(JaxComponent): """ - A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain. + A group of retinal ganglion cell that sense input stimuli and send out filtered + signals (as output). Note that these simulated cells employ internal generalized + filters based on either Gaussian or difference-of-Gaussian kernels) to recover + historical receptive field processing effects. 
| --- Cell Input Compartments: --- | inputs - input (takes in external signals) @@ -85,32 +113,34 @@ class RetinalGanglionCell(JaxComponent): filter_type: string name of filter function (Default: identity) :Note: supported filters include "gaussian", "difference_of_gaussian" - sigma: standard deviation of gaussian kernel + sigma: standard deviation of (gaussian) kernel - area_shape: receptive field area of ganglion cells in this module all together + area_shape: shape of receptive field area of ganglion cells in this module (all together) n_cells: number of ganglion cells in this module - patch_shape: each ganglion cell receptive field area + patch_shape: shape of each ganglion cell's receptive field area - step_shape: the non-overlapping area between each two ganglion cells + step_shape: the non-overlapping area between each pair (two) of ganglion cells - batch_size: batch size dimension of this cell (Default: 1) + batch_size: batch size dimension of this cell/module (Default: 1) """ - def __init__(self, name: str, - filter_type: str, - area_shape: Tuple[int, int], - n_cells: int, - patch_shape: Tuple[int, int], - step_shape: Tuple[int, int], - batch_size: int = 1, - sigma: float = 1.0, - key: Union[jax.Array, None] = None, - **kwargs): + def __init__( + self, + name: str, + filter_type: str, + area_shape: Tuple[int, int], + n_cells: int, + patch_shape: Tuple[int, int], + step_shape: Tuple[int, int], + batch_size: int = 1, + sigma: float = 1.0, + key: Union[jax.Array, None] = None, + **kwargs + ): super().__init__(name=name, key=key) - ## Layer Size Setup self.filter_type = filter_type self.n_cells = n_cells @@ -121,36 +151,43 @@ def __init__(self, name: str, self.patch_shape = patch_shape self.step_shape = step_shape - filter = jnp.ones(self.patch_shape) + _filter = jnp.ones(self.patch_shape) - if filter_type == 'gaussian': - filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) - elif filter_type == 'difference_of_gaussian': - filter = 
_create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) + if self.filter_type == 'gaussian': + print("filter type is ", self.filter_type) + _filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma) + + elif self.filter_type in ["difference_of_gaussian", "DoG"]: + print("filter type is difference of gaussian: f(x) = p1 - p2") + _filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) + + elif self.filter_type in ["ratio_of_gaussian", "RoG"]: + print("filter type is ratio of gaussian: f(x) = p1 / p2") + _filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma) # ═════════════════ compartments initial values ════════════════════ - in_restVals = jnp.zeros((batch_size, - *self.area_shape)) ## input: (B | ix | iy) + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) - self.n_cells * self.patch_shape[0] * self.patch_shape[1])) + out_restVals = jnp.zeros( + (batch_size, self.n_cells * self.patch_shape[0] * self.patch_shape[1]) + ) ## output.shape: (B | n_cells * px * py) # ═══════════════════ set compartments ══════════════════════ self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment - self.filter = Compartment(filter, display_name="Filter") # Filter compartment + self.filter = Compartment(_filter, display_name="Filter") # Filter compartment self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment @compilable def advance_state(self, t): inputs = self.inputs.get() - filter = self.filter.get() + _filter = self.filter.get() px, py = self.patch_shape # ═══════════════════ extract pathches for filters ══════════════════ input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape) # ═══════════════════ apply filter to all pathches ══════════════════ - filtered_input = input_patches * filter ## 
shape: (B | n_cells | px | py) + filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py) # ════════════ reshape all cells responses to a single input to brain ════════════ filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py) @@ -160,31 +197,20 @@ def advance_state(self, t): self.outputs.set(outputs) + + @compilable def reset(self): ## reset core components/statistics - # self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy) out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) self.inputs.set(in_restVals) self.outputs.set(out_restVals) - # Viet: NOTE: we should not need this function since the reset function - # one could set the batch size then do reset - # @compilable - # def batched_reset(self, batch_size): - # in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - - # out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) - # self.n_cells * self.patch_shape[0] * self.patch_shape[1])) - - # self.inputs.set(in_restVals) - # self.outputs.set(out_restVals) - @classmethod def help(cls): ## component help function properties = { - "cell_type": "RetinalGanglionCell - filters the input stimuli, " + "cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics" } compartment_props = { "inputs": @@ -196,11 +222,11 @@ def help(cls): ## component help function } hyperparams = { "filter_type": "Type of the filter for preprocessing the input", - "sigma": "Standard deviation of gaussian kernel", + "sigma": "Standard deviation of gaussian kernel/filter", "area_shape": "Effective receptive field area shape of ganglion cells in this module", - "n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer", 
- "patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module", - "step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module", + "n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer", + "patch_shape": "Classical receptive field area shape of individual ganglion cells in this module", + "step_shape": "Extra-classical receptive field area shape each ganglion cell in this module", "batch_size": "Batch size dimension of this component" } info = {cls.__name__: properties, @@ -212,13 +238,15 @@ def help(cls): ## component help function if __name__ == '__main__': from ngcsimlib.context import Context with Context("Bar") as bar: - X = RetinalGanglionCell("RGC", filter_type="gaussian", - sigma=2.3, - area_shape=(16, 26), - n_cells = 3, - patch_shape=(16, 16), - step_shape=(0, 5) - ) + X = RetinalGanglionCell( + "RGC", + filter_type="gaussian", + sigma=2.3, + area_shape=(16, 26), + n_cells = 3, + patch_shape=(16, 16), + step_shape=(0, 5) + ) print(X) diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py index 810776ab..333479ae 100644 --- a/ngclearn/components/input_encoders/poissonCell.py +++ b/ngclearn/components/input_encoders/poissonCell.py @@ -32,8 +32,13 @@ class PoissonCell(JaxComponent): @deprecate_args(max_freq="target_freq") def __init__( - self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1, - key: Union[jax.Array, None] = None, **kwargs + self, + name: str, + n_units: int, + target_freq: float = 63.75, + batch_size: int = 1, + key: Union[jax.Array, None] = None, + **kwargs ): super().__init__(name=name, key=key) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index d12bd33f..9e072de5 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ 
b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -1,175 +1,174 @@ -# %% - -from ngclearn.components.jaxComponent import JaxComponent -from jax import numpy as jnp, jit -from ngclearn import compilable #from ngcsimlib.parser import compilable -from ngclearn import Compartment #from ngcsimlib.compartment import Compartment - -class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell - """ - A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically, - this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood). - - | --- Cell Input Compartments: --- - | mu - predicted value (takes in external signals) - | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2 - | target - desired/goal value (takes in external signals) - | modulator - modulation signal (takes in optional external signals) - | mask - binary/gating mask to apply to error neuron calculations - | --- Cell Output Compartments: --- - | L - local loss function embodied by this cell - | dmu - derivative of L w.r.t. mu - | dSigma - derivative of L w.r.t. Sigma - | dtarget - derivative of L w.r.t. 
target - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - batch_size: batch size dimension of this cell (Default: 1) - - sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution; - Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses - to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument - (Default: 1) - """ - def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): - super().__init__(name, **kwargs) - - ## Layer Size Setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - sigma_shape = (1,1) - if not isinstance(sigma, float) and not isinstance(sigma, int): - sigma_shape = jnp.array(sigma).shape - self.sigma_shape = sigma_shape - self.shape = shape - self.batch_size = batch_size - self.n_units = n_units - - ## Convolution shape setup - self.width = self.height = n_units - - ## Compartment setup - restVals = jnp.zeros(_shape) - self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment - self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire - self.dmu = Compartment(restVals) # derivative mean - _Sigma = jnp.zeros(sigma_shape) - self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance") - self.dSigma = Compartment(_Sigma) - self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. 
input wire - self.dtarget = Compartment(restVals) # derivative target - self.modulator = Compartment(restVals + 1.0) # to be set/consumed - self.mask = Compartment(restVals + 1.0) - - @staticmethod - def _eval_log_density(target, mu, Sigma): ## Gaussian log likelihood - ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix - _dmu = (target - mu) - #_numerator = 1. # 0.5 - log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma) - return log_density, _dmu ## return density and raw delta - - @compilable - def advance_state(self, dt): ## compute Gaussian error cell output (fixed-point) - # Get the variables - mu = self.mu.get() - target = self.target.get() - Sigma = self.Sigma.get() - modulator = self.modulator.get() - mask = self.mask.get() - - # Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit - # behavior of the local cost functional: - # FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2 - # but should support full log likelihood of the multivariate Gaussian with covariance of different types - # TODO: could introduce a variant of GaussianErrorCell that moves according to an ODE - # (using integration time constant dt) - - L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) - ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu) - dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma - dtarget = -dmu # reverse of e ## -(target - mu)/Sigma - dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma - - dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... - dtarget = dtarget * modulator * mask - mask = mask * 0. + 1. 
## "eat" the mask as it should only apply at time t - - # Update compartments - self.dmu.set(dmu) - self.dtarget.set(dtarget) - self.dSigma.set(dSigma) - self.L.set(jnp.squeeze(L)) - self.mask.set(mask) - - @compilable - def reset(self): ## reset core components/statistics - self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member - - @compilable - def batched_reset(self, batch_size): - _shape = (batch_size, self.shape[0]) - if len(self.shape) > 1: - _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) - restVals = jnp.zeros(_shape) - dmu = restVals - dtarget = restVals - dSigma = jnp.zeros(self.sigma_shape) - target = restVals - mu = restVals - modulator = mu + 1. - L = 0. #jnp.zeros((1, 1)) - mask = jnp.ones(_shape) - - self.dmu.set(dmu) - self.dtarget.set(dtarget) - self.dSigma.set(dSigma) - if not self.target.targeted: - self.target.set(target) - if not self.mu.targeted: - self.mu.set(mu) - self.modulator.set(modulator) - self.L.set(L) - self.mask.set(mask) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "GaussianErrorCell - computes mismatch/error signals at " - "each time step t (between a `target` and a prediction `mu`)" - } - compartment_props = { - "inputs": - {"mu": "External input prediction value(s)", - "Sigma": "External variance/covariance prediction value(s)", - "target": "External input target signal value(s)", - "modulator": "External input modulatory/scaling signal(s)", - "mask": "External binary/gating mask to apply to signals"}, - "outputs": - {"L": "Local loss / free-energy value embodied by this error-cell", - "dmu": "first derivative of loss w.r.t. prediction value(s)", - "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)", - "dtarget": "first derivative of loss w.r.t. 
target value(s)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "sigma": "External input variance value (currently fixed and not learnable)" - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "Gaussian(x=target; mu, sigma)", - "hyperparameters": hyperparams} - return info - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = GaussianErrorCell("X", 9) - print(X) +# %% + +from ngclearn.components.jaxComponent import JaxComponent +from jax import numpy as jnp, jit +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment + +class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell + """ + A simple (non-spiking) Gaussian error cell - this is a fixed-point calculation of a mismatch signal. Specifically, + this error cell offers a configurable variance and calculates its local free energy (Gaussian log likelihood). + + | --- Cell Input Compartments: --- + | mu - predicted value (takes in external signals) + | Sigma - predicted covariance (takes in external signals), or, if just a scalar, then it's sigma^2 + | target - desired/goal value (takes in external signals) + | modulator - modulation signal (takes in optional external signals) + | mask - binary/gating mask to apply to error neuron calculations + | --- Cell Output Compartments: --- + | L - local loss function embodied by this cell + | dmu - derivative of L w.r.t. mu + | dSigma - derivative of L w.r.t. Sigma + | dtarget - derivative of L w.r.t. 
target + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + batch_size: batch size dimension of this cell (Default: 1) + + sigma: initial/fixed value for prediction covariance matrix (𝚺) in multivariate gaussian distribution; + Note that if the compartment `Sigma` is never used, then this cell assumes that the covariance collapses + to a constant/fixed `sigma^2`, i.e., Sigma = sigma^2, where `sigma` is a scalar standard deviation argument + (Default: 1) + """ + def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): + super().__init__(name, **kwargs) + + ## Layer Size Setup + _shape = (batch_size, n_units) ## default shape is 2D/matrix + if shape is None: + shape = (n_units,) ## we set shape to be equal to n_units if nothing provided + else: + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor + sigma_shape = (1,1) + if not isinstance(sigma, float) and not isinstance(sigma, int): + sigma_shape = jnp.array(sigma).shape + self.sigma_shape = sigma_shape + self.shape = shape + self.n_units = n_units + self.batch_size = batch_size + + ## Convolution shape setup + self.width = self.height = n_units + + ## Compartment setup + restVals = jnp.zeros(_shape) + self.L = Compartment(0., display_name="Gaussian Log likelihood", units="nats") # loss compartment + self.mu = Compartment(restVals, display_name="Gaussian mean") # mean/mean name. input wire + self.dmu = Compartment(restVals) # derivative mean + _Sigma = jnp.zeros(sigma_shape) + self.Sigma = Compartment(_Sigma + sigma, display_name="Gaussian variance/covariance") + self.dSigma = Compartment(_Sigma) + self.target = Compartment(restVals, display_name="Gaussian data/target variable") # target. 
input wire + self.dtarget = Compartment(restVals) # derivative target + self.modulator = Compartment(restVals + 1.0) # to be set/consumed + self.mask = Compartment(restVals + 1.0) + + @staticmethod + def eval_log_density(target, mu, Sigma): + ## NOTE: ln(p) = -(x - mu)^2 * 1/(2 Sigma), where Sigma might be sigma^2 or covariance matrix + _dmu = (target - mu) + #_numerator = 1. # 0.5 + log_density = -jnp.sum(jnp.square(_dmu)) * (1./((Sigma ** 2) * 2)) #* (_numerator / Sigma) + return log_density, _dmu ## return density and raw delta + + @compilable + def advance_state(self, dt): ## compute Gaussian error cell output + # Get the variables + mu = self.mu.get() + target = self.target.get() + Sigma = self.Sigma.get() + modulator = self.modulator.get() + mask = self.mask.get() + + # Move Gaussian cell dynamics one step forward. Specifically, this + # routine emulates the error unit + ''' + ## This commented-out block of code should be adapted to replace the + ## five lines below it in future iterations (more accurate/flexible) + L, _dmu = GaussianErrorCell._eval_log_density(target, mu, Sigma) # L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + ## _dmu => "raw" e (error unit/mis-match) # _dmu = (target - mu) + dmu = _dmu * (1./ Sigma) ## obtain precision-scaled e: (target - mu)/Sigma + dtarget = -dmu # reverse of e ## -(target - mu)/Sigma + dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for Sigma + ''' + _dmu = (target - mu) # e (error unit) + dmu = _dmu / Sigma + dtarget = -dmu # reverse of e + dSigma = Sigma * 0 + 1. # no derivative is calculated at this time for sigma + L = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma) + #L = GaussianErrorCell.eval_log_density(target, mu, Sigma) + + dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance... + dtarget = dtarget * modulator * mask + mask = mask * 0. + 1. 
## "eat" the mask as it should only apply at time t + + # Update compartments + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.L.set(jnp.squeeze(L)) + self.mask.set(mask) + + @compilable + def reset(self): ## reset core components/statistics + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + restVals = jnp.zeros(_shape) + dmu = restVals + dtarget = restVals + dSigma = jnp.zeros(self.sigma_shape) + target = restVals + mu = restVals + modulator = mu + 1. + L = 0. #jnp.zeros((1, 1)) + mask = jnp.ones(_shape) + + self.dmu.set(dmu) + self.dtarget.set(dtarget) + self.dSigma.set(dSigma) + self.target.set(target) + self.mu.set(mu) + self.modulator.set(modulator) + self.L.set(L) + self.mask.set(mask) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "GaussianErrorCell - computes mismatch/error signals at " + "each time step t (between a `target` and a prediction `mu`)" + } + compartment_props = { + "inputs": + {"mu": "External input prediction value(s)", + "Sigma": "External variance/covariance prediction value(s)", + "target": "External input target signal value(s)", + "modulator": "External input modulatory/scaling signal(s)", + "mask": "External binary/gating mask to apply to signals"}, + "outputs": + {"L": "Local loss / free-energy value embodied by this error-cell", + "dmu": "first derivative of loss w.r.t. prediction value(s)", + "dSigma": "first derivative of loss w.r.t. variance/covariance value(s)", + "dtarget": "first derivative of loss w.r.t. 
target value(s)"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "batch_size": "Batch size dimension of this component", + "sigma": "External input variance value (currently fixed and not learnable)" + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "Gaussian(x=target; mu, sigma)", + "hyperparameters": hyperparams} + return info + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = GaussianErrorCell("X", 9) + print(X) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 98f60325..3dcbdcc8 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -1,328 +1,324 @@ -# %% - -from jax import numpy as jnp, random, jit - -from ngclearn import compilable #from ngcsimlib.parser import compilable -from ngclearn import Compartment #from ngcsimlib.compartment import Compartment -from ngclearn.components.jaxComponent import JaxComponent -from ngclearn.utils.model_utils import create_function, threshold_soft, \ - threshold_cauchy -from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ - step_euler, step_rk2, step_rk4 -from ngcsimlib.logger import info - - -def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = jnp.sign(z) ## d/dx of Laplace is signum - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = (z * 2)/(1. 
+ jnp.square(z)) - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = jnp.exp(-jnp.square(z)) * z * 2 - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics - z_leak = z # * 2 ## Default: assume Gaussian - dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) - return dz_dt - -# @jit -def _modulate(j, dfx_val): - """ - Apply a signal modulator to j (typically of the form of a derivative/dampening function) - - Args: - j: current/stimulus value to modulate - - dfx_val: modulator signal - - Returns: - modulated j value - """ - return j * dfx_val - -# @partial(jit, static_argnames=["integType", "priorType"]) -def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): - """ - Runs leaky rate-coded state dynamics one step in time. - - Args: - dt: integration time constant - - j: input (bottom-up) electrical/stimulus current - - j_td: modulatory (top-down) electrical/stimulus pressure - - z: current value of membrane/state - - tau_m: membrane/state time constant - - leak_gamma: strength of leak to apply to membrane/state - - integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) - - priorType: scale-shift prior distribution to impose over neural dynamics - - Returns: - New value of membrane/state for next time step - """ - _dfz_fns = { - 0: lambda t, z, params: _dfz_internal_gaussian(z, *params), - 1: lambda t, z, params: _dfz_internal_laplace(z, *params), - 2: lambda t, z, params: _dfz_internal_cauchy(z, *params), - 3: lambda t, z, params: _dfz_internal_exp(z, *params), - } - _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian) - _step_fns = { - 0: step_euler, - 1: step_rk2, - 2: step_rk4, - } - _step_fn = _step_fns.get(integType, step_euler) - params = (j, j_td, tau_m, leak_gamma) - _, _z = _step_fn(0., z, _dfz_fn, 
dt, params) - return _z - -# @jit -def _run_cell_stateless(j): - """ - A simplification of running a stateless set of dynamics over j (an identity - functional form of dynamics). - - Args: - j: stimulus to do nothing to - - Returns: - the stimulus - """ - return j + 0 - -class RateCell(JaxComponent): ## Rate-coded/real-valued cell - """ - A non-spiking cell driven by the gradient dynamics of neural generative - coding-driven predictive processing. - - The specific differential equation that characterizes this cell - is (for adjusting v, given current j, over time) is: - - | tau_m * dz/dt = lambda * prior(z) + (j + j_td) - | where j is the set of general incoming input signals (e.g., message-passed signals) - | and j_td is taken to be the set of top-down pressure signals - - | --- Cell Input Compartments: --- - | j - input pressure (takes in external signals) - | j_td - input/top-down pressure input (takes in external signals) - | --- Cell State Compartments --- - | z - rate activity - | --- Cell Output Compartments: --- - | zF - post-activation function activity, i.e., fx(z) - - Args: - name: the string name of this cell - - n_units: number of cellular entities (neural population size) - - tau_m: membrane/state time constant (milliseconds) - - prior: a kernel for specifying the type of centered scale-shift distribution - to impose over neuronal dynamics, applied to each neuron or - dimension within this component (Default: ("gaussian", 0)); this is - a tuple with 1st element containing a string name of the distribution - one wants to use while the second value is a `leak rate` scalar - that controls the influence/weighting that this distribution - has on the dynamics; for example, ("laplacian, 0.001") means that a - centered laplacian distribution scaled by `0.001` will be injected - into this cell's dynamics ODE each step of simulated time - - :Note: supported scale-shift distributions include "laplacian", - "cauchy", "exp", and "gaussian" - - act_fx: string name 
of activation function/nonlinearity to use - - output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) - - integration_type: type of integration to use for this cell's dynamics; - current supported forms include "euler" (Euler/RK-1 integration) - and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") - - :Note: setting the integration type to the midpoint method will - increase the accuray of the estimate of the cell's evolution - at an increase in computational cost (and simulation time) - - resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) - """ - - def __init__( - self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), - integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): - jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)} - this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)} - super().__init__(name, **jax_comp_kwargs) - - ## membrane parameter setup (affects ODE integration) - self.output_scale = output_scale - self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode - self.is_stateful = is_stateful - if isinstance(tau_m, float): - if tau_m <= 0: ## trigger stateless mode - self.is_stateful = False - priorType, leakRate = prior - priorTypeDict = { - "gaussian": 0, - "laplacian": 1, - "cauchy": 2, - "exp": 3 - } - self.priorType = priorTypeDict.get(priorType, 0) - self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) - thresholdType, thr_lmbda = threshold - self.thresholdType = thresholdType ## type of thresholding function to use - self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics - self.resist_scale = resist_scale ## a "resistance" scaling factor - - ## integration properties - self.integrationType = integration_type - self.intgFlag = 
get_integrator_code(self.integrationType) - - ## Layer size setup - _shape = (batch_size, n_units) ## default shape is 2D/matrix - if shape is None: - shape = (n_units,) ## we set shape to be equal to n_units if nothing provided - else: - _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor - self.shape = shape - self.n_units = n_units - self.batch_size = batch_size - - omega_0 = None - if act_fx == "sine": - omega_0 = this_class_kwargs["omega_0"] - self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) - - # compartments (state of the cell & parameters will be updated through stateless calls) - restVals = jnp.zeros(_shape) - self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current - self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity - self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure - self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity - - @compilable - def advance_state(self, dt): - # Get the compartment values - j = self.j.get() - j_td = self.j_td.get() - z = self.z.get() - - #if tau_m > 0.: - if self.is_stateful: - ### run a step of integration over neuronal dynamics - ## Notes: - ## self.pressure <-- "top-down" expectation / contextual pressure - ## self.current <-- "bottom-up" data-dependent signal - dfx_val = self.dfx(z) - j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics) - j = j * self.resist_scale - tmp_z = _run_cell( - dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag, - priorType=self.priorType - ) - ## apply optional thresholding sub-dynamics - if self.thresholdType == "soft_threshold": - tmp_z = threshold_soft(tmp_z, self.thr_lmbda) - elif self.thresholdType == "cauchy_threshold": - tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda) - z = tmp_z 
## pre-activation function value(s) - zF = self.fx(z) * self.output_scale ## post-activation function value(s) - else: - ## run in "stateless" mode (when no membrane time constant provided) - j_total = j + j_td - z = _run_cell_stateless(j_total) - zF = self.fx(z) * self.output_scale - - # Update compartments - self.j.set(j) - self.j_td.set(j_td) - self.z.set(z) - self.zF.set(zF) - - @compilable - def reset(self): ## reset core components/statistics - self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member - - @compilable - def batched_reset(self, batch_size): - _shape = (batch_size, self.shape[0]) - if len(self.shape) > 1: - _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) - restVals = jnp.zeros(_shape) - self.j.set(restVals) - self.j_td.set(restVals) - self.z.set(restVals) - self.zF.set(restVals) - - # def save(self, directory, **kwargs): - # ## do a protected save of constants, depending on whether they are floats or arrays - # tau_m = (self.tau_m if isinstance(self.tau_m, float) - # else jnp.ones([[self.tau_m]])) - # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) - # else jnp.ones([[self.priorLeakRate]])) - # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) - # else jnp.ones([[self.resist_scale]])) - # - # file_name = directory + "/" + self.name + ".npz" - # jnp.savez(file_name, - # tau_m=tau_m, priorLeakRate=priorLeakRate, - # resist_scale=resist_scale) #, key=self.key.value) - # - # def load(self, directory, seeded=False, **kwargs): - # file_name = directory + "/" + self.name + ".npz" - # data = jnp.load(file_name) - # ## constants loaded in - # self.tau_m = data['tau_m'] - # self.priorLeakRate = data['priorLeakRate'] - # self.resist_scale = data['resist_scale'] - # #if seeded: - # # self.key.set(data['key']) - - @classmethod - def help(cls): ## component help function - properties = { - "cell_type": "RateCell - evolves neurons according to rate-coded/" - 
"continuous dynamics " - } - compartment_props = { - "inputs": - {"j": "External input stimulus value(s)", - "j_td": "External top-down input stimulus value(s); these get " - "multiplied by the derivative of f(x), i.e., df(x)"}, - "states": - {"z": "Update to rate-coded continuous dynamics; value at time t"}, - "outputs": - {"zF": "Nonlinearity/function applied to rate-coded dynamics; f(z)"}, - } - hyperparams = { - "n_units": "Number of neuronal cells to model in this layer", - "batch_size": "Batch size dimension of this component", - "tau_m": "Cell state/membrane time constant", - "prior": "What kind of kurtotic prior to place over neuronal dynamics?", - "act_fx": "Elementwise activation function to apply over cell state `z`", - "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", - "integration_type": "Type of numerical integration to use for the cell dynamics", - } - info = {cls.__name__: properties, - "compartments": compartment_props, - "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", - "hyperparameters": hyperparams} - return info - -if __name__ == '__main__': - from ngcsimlib.context import Context - with Context("Bar") as bar: - X = RateCell("X", 9, 0.03) - print(X) +# %% + +from jax import numpy as jnp, random, jit + +from ngclearn import compilable #from ngcsimlib.parser import compilable +from ngclearn import Compartment #from ngcsimlib.compartment import Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils.model_utils import create_function, threshold_soft, \ + threshold_cauchy +from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \ + step_euler, step_rk2, step_rk4 +from ngcsimlib.logger import info + + +def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = jnp.sign(z) ## d/dx of Laplace is signum + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_cauchy(z, j, j_td, tau_m, 
leak_gamma): ## raw dynamics + z_leak = (z * 2)/(1. + jnp.square(z)) + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = jnp.exp(-jnp.square(z)) * z * 2 + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics + z_leak = z # * 2 ## Default: assume Gaussian + dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m) + return dz_dt + +# @jit +def _modulate(j, dfx_val): + """ + Apply a signal modulator to j (typically of the form of a derivative/dampening function) + + Args: + j: current/stimulus value to modulate + + dfx_val: modulator signal + + Returns: + modulated j value + """ + return j * dfx_val + +# @partial(jit, static_argnames=["integType", "priorType"]) +def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0): + """ + Runs leaky rate-coded state dynamics one step in time. 
+ + Args: + dt: integration time constant + + j: input (bottom-up) electrical/stimulus current + + j_td: modulatory (top-down) electrical/stimulus pressure + + z: current value of membrane/state + + tau_m: membrane/state time constant + + leak_gamma: strength of leak to apply to membrane/state + + integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4) + + priorType: scale-shift prior distribution to impose over neural dynamics + + Returns: + New value of membrane/state for next time step + """ + _dfz_fns = { + 0: lambda t, z, params: _dfz_internal_gaussian(z, *params), + 1: lambda t, z, params: _dfz_internal_laplace(z, *params), + 2: lambda t, z, params: _dfz_internal_cauchy(z, *params), + 3: lambda t, z, params: _dfz_internal_exp(z, *params), + } + _dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian) + _step_fns = { + 0: step_euler, + 1: step_rk2, + 2: step_rk4, + } + _step_fn = _step_fns.get(integType, step_euler) + params = (j, j_td, tau_m, leak_gamma) + _, _z = _step_fn(0., z, _dfz_fn, dt, params) + return _z + +# @jit +def _run_cell_stateless(j): + """ + A simplification of running a stateless set of dynamics over j (an identity + functional form of dynamics). + + Args: + j: stimulus to do nothing to + + Returns: + the stimulus + """ + return j + 0 + +class RateCell(JaxComponent): ## Rate-coded/real-valued cell + """ + A non-spiking cell driven by the gradient dynamics of neural generative + coding-driven predictive processing. 
+ + The specific differential equation that characterizes this cell + is (for adjusting v, given current j, over time) is: + + | tau_m * dz/dt = lambda * prior(z) + (j + j_td) + | where j is the set of general incoming input signals (e.g., message-passed signals) + | and j_td is taken to be the set of top-down pressure signals + + | --- Cell Input Compartments: --- + | j - input pressure (takes in external signals) + | j_td - input/top-down pressure input (takes in external signals) + | --- Cell State Compartments --- + | z - rate activity + | --- Cell Output Compartments: --- + | zF - post-activation function activity, i.e., fx(z) + + Args: + name: the string name of this cell + + n_units: number of cellular entities (neural population size) + + tau_m: membrane/state time constant (milliseconds) + + prior: a kernel for specifying the type of centered scale-shift distribution + to impose over neuronal dynamics, applied to each neuron or + dimension within this component (Default: ("gaussian", 0)); this is + a tuple with 1st element containing a string name of the distribution + one wants to use while the second value is a `leak rate` scalar + that controls the influence/weighting that this distribution + has on the dynamics; for example, ("laplacian, 0.001") means that a + centered laplacian distribution scaled by `0.001` will be injected + into this cell's dynamics ODE each step of simulated time + + :Note: supported scale-shift distributions include "laplacian", + "cauchy", "exp", and "gaussian" + + act_fx: string name of activation function/nonlinearity to use + + output_scale: factor to multiply output of nonlinearity of this cell by (Default: 1.) 
+ + integration_type: type of integration to use for this cell's dynamics; + current supported forms include "euler" (Euler/RK-1 integration) + and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler") + + :Note: setting the integration type to the midpoint method will + increase the accuray of the estimate of the cell's evolution + at an increase in computational cost (and simulation time) + + resist_scale: a scaling factor applied to incoming pressure `j` (default: 1) + """ + + def __init__( + self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.), + integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs): + jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)} + this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)} + super().__init__(name, **jax_comp_kwargs) + + ## membrane parameter setup (affects ODE integration) + self.output_scale = output_scale + self.tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode + self.is_stateful = is_stateful + if isinstance(tau_m, float): + if tau_m <= 0: ## trigger stateless mode + self.is_stateful = False + priorType, leakRate = prior + priorTypeDict = { + "gaussian": 0, + "laplacian": 1, + "cauchy": 2, + "exp": 3 + } + self.priorType = priorTypeDict.get(priorType, 0) + self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior) + thresholdType, thr_lmbda = threshold + self.thresholdType = thresholdType ## type of thresholding function to use + self.thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics + self.resist_scale = resist_scale ## a "resistance" scaling factor + + ## integration properties + self.integrationType = integration_type + self.intgFlag = get_integrator_code(self.integrationType) + + ## Layer size setup + _shape = (batch_size, n_units) ## default shape is 2D/matrix + if shape is None: + 
shape = (n_units,) ## we set shape to be equal to n_units if nothing provided + else: + _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor + self.shape = shape + self.n_units = n_units + self.batch_size = batch_size + + omega_0 = None + if act_fx == "sine": + omega_0 = this_class_kwargs["omega_0"] + self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0) + + # compartments (state of the cell & parameters will be updated through stateless calls) + restVals = jnp.zeros(_shape) + self.j = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current + self.zF = Compartment(restVals, display_name="Transformed Rate Activity") # rate-coded output - activity + self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure + self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity + + @compilable + def advance_state(self, dt): + # Get the compartment values + j = self.j.get() + j_td = self.j_td.get() + z = self.z.get() + + #if tau_m > 0.: + if self.is_stateful: + ### run a step of integration over neuronal dynamics + ## Notes: + ## self.pressure <-- "top-down" expectation / contextual pressure + ## self.current <-- "bottom-up" data-dependent signal + dfx_val = self.dfx(z) + j = _modulate(j, dfx_val) + j = j * self.resist_scale + tmp_z = _run_cell( + dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag, + priorType=self.priorType + ) + ## apply optional thresholding sub-dynamics + if self.thresholdType == "soft_threshold": + tmp_z = threshold_soft(tmp_z, self.thr_lmbda) + elif self.thresholdType == "cauchy_threshold": + tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda) + z = tmp_z ## pre-activation function value(s) + zF = self.fx(z) * self.output_scale ## post-activation function value(s) + else: + ## run in "stateless" mode (when no membrane time constant provided) + j_total = j + 
j_td + z = _run_cell_stateless(j_total) + zF = self.fx(z) * self.output_scale + + # Update compartments + self.j.set(j) + self.j_td.set(j_td) + self.z.set(z) + self.zF.set(zF) + + @compilable + def reset(self): #, batch_size, shape): #n_units + _shape = (self.batch_size, self.shape[0]) + if len(self.shape) > 1: + _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + restVals = jnp.zeros(_shape) + self.j.set(restVals) + self.j_td.set(restVals) + self.z.set(restVals) + self.zF.set(restVals) + + # def save(self, directory, **kwargs): + # ## do a protected save of constants, depending on whether they are floats or arrays + # tau_m = (self.tau_m if isinstance(self.tau_m, float) + # else jnp.ones([[self.tau_m]])) + # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float) + # else jnp.ones([[self.priorLeakRate]])) + # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float) + # else jnp.ones([[self.resist_scale]])) + # + # file_name = directory + "/" + self.name + ".npz" + # jnp.savez(file_name, + # tau_m=tau_m, priorLeakRate=priorLeakRate, + # resist_scale=resist_scale) #, key=self.key.value) + # + # def load(self, directory, seeded=False, **kwargs): + # file_name = directory + "/" + self.name + ".npz" + # data = jnp.load(file_name) + # ## constants loaded in + # self.tau_m = data['tau_m'] + # self.priorLeakRate = data['priorLeakRate'] + # self.resist_scale = data['resist_scale'] + # #if seeded: + # # self.key.set(data['key']) + + @classmethod + def help(cls): ## component help function + properties = { + "cell_type": "RateCell - evolves neurons according to rate-coded/" + "continuous dynamics " + } + compartment_props = { + "inputs": + {"j": "External input stimulus value(s)", + "j_td": "External top-down input stimulus value(s); these get " + "multiplied by the derivative of f(x), i.e., df(x)"}, + "states": + {"z": "Update to rate-coded continuous dynamics; value at time t"}, + "outputs": + {"zF": 
"Nonlinearity/function applied to rate-coded dynamics; f(z)"}, + } + hyperparams = { + "n_units": "Number of neuronal cells to model in this layer", + "batch_size": "Batch size dimension of this component", + "tau_m": "Cell state/membrane time constant", + "prior": "What kind of kurtotic prior to place over neuronal dynamics?", + "act_fx": "Elementwise activation function to apply over cell state `z`", + "threshold": "What kind of iterative thresholding function to place over neuronal dynamics?", + "integration_type": "Type of numerical integration to use for the cell dynamics", + } + info = {cls.__name__: properties, + "compartments": compartment_props, + "dynamics": "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)", + "hyperparameters": hyperparams} + return info + +if __name__ == '__main__': + from ngcsimlib.context import Context + with Context("Bar") as bar: + X = RateCell("X", 9, 0.03) + print(X) diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py index ee1ecb02..92c9c5e9 100755 --- a/ngclearn/components/synapses/denseSynapse.py +++ b/ngclearn/components/synapses/denseSynapse.py @@ -83,7 +83,7 @@ def __init__( self.inputs = Compartment(preVals) self.outputs = Compartment(postVals) self.weights = Compartment(weights) - _mask = 1. 
+ _mask = jnp.ones((1, 1)) if mask is not None: _mask = mask self.mask = Compartment(_mask) diff --git a/ngclearn/components/synapses/hebbian/__init__.py b/ngclearn/components/synapses/hebbian/__init__.py index 05bfd207..0a1630c3 100644 --- a/ngclearn/components/synapses/hebbian/__init__.py +++ b/ngclearn/components/synapses/hebbian/__init__.py @@ -4,4 +4,5 @@ from .expSTDPSynapse import ExpSTDPSynapse from .eventSTDPSynapse import EventSTDPSynapse from .BCMSynapse import BCMSynapse +from .gerstnerHebbianSynapse import GerstnerHebbianSynapse ## Taylor-expansion Hebbian model diff --git a/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py new file mode 100644 index 00000000..d1efca5a --- /dev/null +++ b/ngclearn/components/synapses/hebbian/gerstnerHebbianSynapse.py @@ -0,0 +1,113 @@ +import jax.numpy as jnp +from jax import random, jit + +from ngclearn import compilable +from ngclearn import Compartment +from ngclearn.components.synapses import DenseSynapse +from ngclearn.utils import tensorstats +from ngcsimlib import deprecate_args +#from ngclearn.utils.io_utils import save_pkl, load_pkl + +class GerstnerHebbianSynapse(DenseSynapse): + """ + A synapse component that implements Gerstner's general Hebbian + learning (Taylor) expansion (Equation 3 from Gerstner & Kistler, 2002). + + Note that this synpatic update model can recover several classical forms + of Hebbian-like update rules, including the covariance rule. + + There are other higher-order terms possible, i.e., \Theta(xy), such as + x * y2 and y x^2, etc. 
+ + | c2_corr > 0 and c0 = c1_pre = c1_post = 0 => Hebbian update + | c2_corr < 0 and c0 = c1_pre = c1_post = 0 => anti-Hebbian update + | c2_corr = 1 and c1_pre = -x_theta < 0 + + """ + def __init__( + self, + name, + shape, ## (post_dim, pre_dim) + eta=0.01, ## global step-size + coeffs=None, ## these configure which kind of Hebb learning is done + weight_init=None, + p_conn=1., + resist_scale=1., + sign_value=1., + batch_size=1, + **kwargs + ): + bias_init = None ## no biases are included in Gerster's formulation + super().__init__( + name, + shape=shape, + weight_init=weight_init, + bias_init=bias_init, + resist_scale=resist_scale, + p_conn=p_conn, + batch_size=batch_size, + **kwargs + ) + ## General Hebbian meta-parameters + self.eta = eta + self.sign_value = sign_value + + ## Expansion coefficients (c0, c1_pre, c1_post, c2_corr) + if coeffs is None: ## Default to standard bilinear Hebb + self.coeffs = { + 'c0': 0., 'c1_pre': 0., 'c1_post': 0., 'c2_corr': 1.0 + } + else: + self.coeffs = coeffs + self.c0 = self.coeffs['c0'] + self.c1_pre = self.coeffs['c1_pre'] + self.c1_post = self.coeffs['c1_post'] + self.c2_corr = self.coeffs['c2_corr'] + + # Initialize Weights (using JAX PRNG) + #init_key, _ = random.split(self.key) + #w_init = random.normal(init_key, shape) * 0.05 + + # Compartments (ngc-learn state management) + #self.weights = Compartment(w_init) + self.pre = Compartment(jnp.zeros((1, shape[1]))) + self.post = Compartment(jnp.zeros((1, shape[0]))) + + @compilable + def evolve(self, **kwargs): + """ + Updates weights using the Gerstner general expansion. + Assumes pre_act and post_act compartments have been populated. 
+ """ + # Retrieve current states + W = self.weights.get() + x = self.pre.get() # pre-synaptic activity (batch, pre_dim) + y = self.post.get() # post-synaptic activity (batch, post_dim) + batch_size = self.batch_size + + ## Bilinear Term (c2): correlation matrix + ### (post_dim, batch) @ (batch, pre_dim) -> (post_dim, pre_dim) + dW_corr = jnp.matmul(x.T, y) * (1./batch_size) + ## Linear pre-synaptic term (c1_pre) + ### Average over batch then broadcast to match weight matrix + dW_pre = jnp.sum(x, axis=0, keepdims=True).T * (1./batch_size) + ## Linear post-synaptic term (c1_post) + dW_post = jnp.sum(y, axis=0, keepdims=True) * (1./batch_size) + + ## Apply Equation 3 Taylor expansion + dW = (self.c0 * W + ## synaptic decay + self.c1_pre * dW_pre + ## bilinear term + self.c1_post * dW_post + ## pre-synaptic gating term + self.c2_corr * dW_corr ## post-synpatic gating term + ) + ## perform a step of Hebbian ascent + W = W + self.eta * dW + ## Update weights + self.weights.set(W) + + @compilable + def reset(self, **kwargs): + """Clears activity compartments""" + self.pre.set( jnp.zeros((self.batch_size, self.shape[1])) ) + self.post.set( jnp.zeros((self.batch_size, self.shape[0])) ) + diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py index 5860d2bc..06cd09bd 100644 --- a/ngclearn/modules/regression/elastic_net.py +++ b/ngclearn/modules/regression/elastic_net.py @@ -77,13 +77,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post diff --git a/ngclearn/modules/regression/lasso.py b/ngclearn/modules/regression/lasso.py index 15a014bb..8d5f91d5 100644 --- a/ngclearn/modules/regression/lasso.py +++ b/ngclearn/modules/regression/lasso.py @@ -76,12 +76,14 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('lasso', lasso_lmbda), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post diff --git a/ngclearn/modules/regression/ridge.py b/ngclearn/modules/regression/ridge.py index dfbacb03..84591670 100644 --- a/ngclearn/modules/regression/ridge.py +++ b/ngclearn/modules/regression/ridge.py @@ -76,13 +76,15 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l self.W = HebbianSynapse( "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1, weight_init=dist.constant(value=weight_fill), prior=('ridge', ridge_lmbda), w_bound=0., - optim_type=optim_type, key=subkeys[0] + optim_type=optim_type, key=subkeys[0], batch_size=batch_size ) - self.err = GaussianErrorCell("err", n_units=sys_dim) + self.err = GaussianErrorCell("err", n_units=sys_dim, batch_size=batch_size) # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
self.W.batch_size = batch_size self.err.batch_size = batch_size + self.W.reset() + self.err.reset() # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ self.W.outputs >> self.err.mu self.err.dmu >> self.W.post diff --git a/ngclearn/utils/JaxProcessesMixin.py b/ngclearn/utils/JaxProcessesMixin.py index 3b849f09..3958dfb3 100644 --- a/ngclearn/utils/JaxProcessesMixin.py +++ b/ngclearn/utils/JaxProcessesMixin.py @@ -9,6 +9,10 @@ from ngcsimlib._src.process.baseProcess import BaseProcess class JaxCompiledMethod(CompiledMethod): + """ + A wrapper for a compiled method that includes jax's jit wrapped. Used + exclusively by the mixin and shouldn't be used elsewhere. + """ def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals): super().__init__(fn, fn_ast, auxiliary_ast, namespace, extra_globals) self._fn = jax.jit(fn) @@ -16,10 +20,19 @@ def __init__(self, fn, fn_ast, auxiliary_ast, namespace, extra_globals): @property def source_fn(self): + """ + The source method not wrapped in jit + """ return self._fn_source @classmethod def wrap(cls, compiledMethod: CompiledMethod): + """ + Helper method to expand on a base compiled method + Args: + compiledMethod: The method to be expanded upon + Returns: the JaxCompiledMethod based on the input + """ return cls(compiledMethod._fn, compiledMethod.ast, compiledMethod.auxiliary_ast, @@ -28,7 +41,16 @@ def wrap(cls, compiledMethod: CompiledMethod): class JaxProcessesMixin: + """ + A mixin for the base Process that adds JAX functionality such as scan and + implicit jit wrapping + """ def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs): + """ + Look at the BaseProcess class for information about other arguments + Args: + use_jit: a flag for if the process should implicitly jit wrap + """ super().__init__(name, *args, **kwargs) self._previous_result = None self._previous_state = None @@ -36,27 +58,51 @@ def __init__(self: "BaseProcess", name, *args, use_jit=True, **kwargs): @property def 
previous_result(self): + """ + Stores and returns the last result of scan (the second returned value) + """ return self._previous_result @property def previous_state(self): + """ + Stores and returns the last returned state of scan (the first returned + value) + """ return self._previous_state def clear(self): + """ + Clears out the previous result and state from scan + """ self._previous_result = None self._previous_state = None - def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True): + def scan(self: "BaseProcess", inputs, current_state=None, store_state: bool = True, store_results: bool = True): + """ + Runs the process through jax's scan method + Args: + inputs: The inputs for scan (use pack rows to generate), must be a jax array + current_state: Optional, the current state of the model, if none uses current global state + store_state: Optional flag, should the final state be stored in the process + store_results: Optional flag, should the final result be stored in the process + + Returns: the final state, the final result + + """ state = current_state or stateManager.state final_state, result = jax.lax.scan(self.run.compiled, state, inputs) - if save_state: + if store_state: self._previous_state = final_state if store_results: self._previous_result = result return final_state, result def compile(self: "baseProcess"): + """ + For use by the compiler + """ super().compile() if self._use_jit: self.run.compiled = JaxCompiledMethod.wrap(self.run.compiled) diff --git a/ngclearn/utils/analysis/__init__.py b/ngclearn/utils/analysis/__init__.py index 082c68cd..31506c3e 100644 --- a/ngclearn/utils/analysis/__init__.py +++ b/ngclearn/utils/analysis/__init__.py @@ -2,4 +2,4 @@ from .linear_probe import LinearProbe from .attentive_probe import AttentiveProbe from .knn_probe import KNNProbe - +from .kmeans_probe import KMeansProbe diff --git a/ngclearn/utils/analysis/effective_dim.py 
b/ngclearn/utils/analysis/effective_dim.py new file mode 100644 index 00000000..7b46f318 --- /dev/null +++ b/ngclearn/utils/analysis/effective_dim.py @@ -0,0 +1,46 @@ +from jax import numpy as jnp + +def participation_ratio(latent_codes): + """ + Calculates the participation ratio coefficient for a set of latent codes + + Args: + latent_codes: a set of (N x D) latent code vectors (one row per vector code) + + Returns: + scalar measurement of the effective dimension + """ + Z = latent_codes + Zc = Z - Z.mean(axis=0, keepdims=True) + cov = (Zc.T @ Zc) / (Zc.shape[0] - 1) + + tr = jnp.trace(cov) + tr2_cov = tr * tr + cov2_tr = jnp.trace(cov @ cov) + + return tr2_cov / cov2_tr if cov2_tr > 0 else float("nan") + + + + +def rankme(Z, eps=1e-7): + """ + Calculates the effective rank of for a code matrix Z + effective rank = exp(Shannon entropy), from Garrido, Balestriero, + Najman & LeCun, "RankMe: Assessing the Downstream Performance of Pretrained + Self-Supervised Representations by Their Rank" (ICML 2023, arXiv:2210.02885). 
+ + Args: + latent_codes: a set of (N x D) latent code vectors (one row per vector code) + + Returns: + scalar measurement of the effective dimension + """ + + singular_values = jnp.linalg.svd(Z, compute_uv=False) ## singular values of Z + sum_singular_vals = jnp.sum(singular_values) ## L1 + if sum_singular_vals <= 0: + return float("nan") + p = singular_values / sum_singular_vals + eps ## L1-normalized singular value + shannon_entropy = -jnp.sum(p * jnp.log(p)) ## Shannon entropy + return jnp.exp(shannon_entropy) ## exp(Shannon entropy) = effective rank diff --git a/ngclearn/utils/analysis/kmeans_probe.py b/ngclearn/utils/analysis/kmeans_probe.py new file mode 100644 index 00000000..676a00cd --- /dev/null +++ b/ngclearn/utils/analysis/kmeans_probe.py @@ -0,0 +1,109 @@ +import jax +from ngcsimlib import deprecate_args +from ngclearn.utils.analysis.probe import Probe +from jax import jit, random, numpy as jnp, lax, nn +from functools import partial as bind +from ngclearn.utils.metric_utils import measure_ARI + +@bind(jax.jit, static_argnums=[2]) +def _run_kmeans_probe(_embeddings, centroids, n_clusters): + ## Broadcast distances: (n_samples, 1, n_features) - (1, n_clusters, n_features) + distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1) + labels_pred = jnp.argmin(distances, axis=1) + ## Re-estimate centroids/means + one_hot_preds = labels_pred[:, None] == jnp.arange(n_clusters) + counts = jnp.maximum(one_hot_preds.sum(axis=0, keepdims=True).T, 1.0) + centroids = jnp.dot(one_hot_preds.T.astype(jnp.float32), _embeddings) / counts + return centroids + +@bind(jax.jit, static_argnums=[2]) +def _predict_with_probe(_embeddings, centroids, n_clusters): + ## Final pass to compute stable predictions + distances = jnp.sum((_embeddings[:, None, :] - centroids[None, :, :]) ** 2, axis=-1) + labels_pred = jnp.argmin(distances, axis=1) + Y_pred = nn.one_hot(labels_pred, n_clusters) + return labels_pred, Y_pred + +class KMeansProbe(Probe): + """ + 
This implements a K-means clustering probe, which is useful for evaluating the quality of + encodings/embeddings in light of the ability to cluster downstream data. Currently, this + probe only supports L2/Euclidean distance-based clustering. + + Args: + dkey: init seed key + + source_seq_length: length of input sequence (e.g., height x width of the image feature) + + input_dim: input dimensionality of probe + + out_dim: output dimensionality of probe - number of clusters for this probe to create + + batch_size: + + """ + + def __init__( + self, + dkey, + source_seq_length, + input_dim, + out_dim=2, ## number of clusters/centroids to uncover + batch_size=1, + **kwargs + ): + super().__init__(dkey, batch_size, **kwargs) + self.dkey, *subkeys = random.split(self.dkey, 3) + self.source_seq_length = source_seq_length + self.input_dim = input_dim + self.n_clusters = self.out_dim = out_dim + ## centroids that will be uncovered by this probe + self.centroids : jax.Array = None + + def _init(self, embeddings): + _embeddings = embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + ## choose random data-points to serve as centroids at iteration 0 + self.dkey, *subkeys = random.split(self.dkey, 15) + n_samples, n_features = _embeddings.shape + random_indices = random.choice( + subkeys[0], n_samples, shape=(self.n_clusters,), replace=False + ) + self.centroids = _embeddings[random_indices] + + def process(self, embeddings, dkey=None): + _embeddings = embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + ## Compute final geometric vs semantic conformity via ARI + _, Y_pred = _predict_with_probe(_embeddings, self.centroids, self.n_clusters) + return Y_pred ## (B, C) + + def update(self, embeddings, labels, dkey=None): + _embeddings = 
embeddings + if len(_embeddings.shape) > 2: + flat_dim = embeddings.shape[1] * embeddings.shape[2] + _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) + self.centroids = _run_kmeans_probe(_embeddings, self.centroids, self.n_clusters) + L = 0. ## FIXME: should be clustering loss + predictions = self.process(_embeddings) + return L, predictions + + def fit(self, dataset, dev_dataset=None, n_iter=20, patience=20): + data, labels = dataset + _labels = jnp.argmax(labels, axis=-1) + + self._init(data) ## init K-means centroids + ari = 0. + for i in range(n_iter): ## Run vectorized K-Means optimization loop + _L, py = self.update(data, labels) + labels_pred = jnp.argmax(py, axis=1) + ari_i = measure_ARI(_labels, labels_pred) + print(f"\r{i}: ARI = {ari_i}", end="") + if ari_i > ari: + ari = ari_i + print() + return ari diff --git a/ngclearn/utils/analysis/knn_probe.py b/ngclearn/utils/analysis/knn_probe.py index 6935a3a0..b9656371 100644 --- a/ngclearn/utils/analysis/knn_probe.py +++ b/ngclearn/utils/analysis/knn_probe.py @@ -2,28 +2,34 @@ import numpy as np from ngcsimlib import deprecate_args from ngclearn.utils.analysis.probe import Probe -from ngclearn.utils.model_utils import kwta from jax import jit, random, numpy as jnp, lax, nn from functools import partial as bind -from ngclearn.utils.distribution_generator import DistributionGenerator - -@bind(jax.jit, static_argnums=[2, 3]) -def _run_knn_probe(_embeddings, Wx, K, dist_order=2): - ## Notes: - ### We do some 3D tensor math to handle a batch of predictions that need to be made - ### B = batch-size, D = embedding/input dim, C = number classes, N = number of memories - _Wx = jnp.expand_dims(Wx, axis=0) ## 3D tensor format of KNN params (1 x N x D) - embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## 3D projection of input signals (B x 1 x D) - D = embed_tensor - _Wx ## compute 3D batched delta tensor (B x N x D) - ## get batched (negative) distance measurements - dist = jnp.linalg.norm(D, 
ord=dist_order, axis=2, keepdims=True) ## (B x N x 1) - ## else, default -> euclidean - ### Note: negative distance allows us to find minimal points w/ maximal functions - dist = -jnp.squeeze(dist, axis=2) ## (B x N) - ## now get K winners per sample in batch + + +@bind(jax.jit, static_argnums=[2, 3, 4]) +def _run_knn_probe(_embeddings, Wx, K, dist_order=2, dist_metric="minkowski"): + if dist_metric == "cosine": + ## normalize the incoming batch embeddings along the feature axis (axis 1) + ### add a tiny epsilon to prevent division-by-zero errors + eps = 1e-12 + embed_norm = _embeddings / (jnp.linalg.norm(_embeddings, axis=1, keepdims=True) + eps) + ## normalize the internal memory database array (axis 1) + Wx_norm = Wx / (jnp.linalg.norm(Wx, axis=1, keepdims=True) + eps) + ## compute batched cosine similarity using standard 2D matrix multiplication + ### (B x D) @ (D x N) -> yields a (B x N) similarity matrix directly! + dist = jnp.matmul(embed_norm, Wx_norm.T) + else: # Default back to your Minkowski setup + _Wx = jnp.expand_dims(Wx, axis=0) ## (1 x N x D) + embed_tensor = jnp.expand_dims(_embeddings, axis=1) ## (B x 1 x D) + D = embed_tensor - _Wx ## (B x N x D) + dist = jnp.linalg.norm(D, ord=dist_order, axis=2, keepdims=True) ## (B x N x 1) + dist = -jnp.squeeze(dist, axis=2) ## (B x N) + + # lax.top_k naturally grabs the maximums (highest similarity or smallest negative distance) values, indices = lax.top_k(dist, K) return values, indices + class KNNProbe(Probe): """ This implements a K-nearest neighbors (KNN) probe, which is useful for evaluating the quality of @@ -82,16 +88,21 @@ def __init__( self.vote_fx = 0 ## 0 -> mode prediction; 1 -> mean prediction if vote_style == "mean": self.vote_fx = 1 + self.distance_function = distance_function - dist_fun, dist_order = distance_function ## Default: ("minkowski", 2) -> Euclidean - if "euclidean" in dist_fun.lower(): + dist_fun, dist_order = distance_function + self.dist_metric = "minkowski" # default tracker + 
if "cosine" in dist_fun.lower(): + self.dist_metric = "cosine" + dist_order = 2 ## fallback assignment + elif "euclidean" in dist_fun.lower(): dist_order = 2 elif "manhattan" in dist_fun.lower(): dist_order = 1 elif "chebyshev" in dist_fun.lower(): dist_order = jnp.inf - ## TODO: add in cosine-distance (and maybe Mahalanobis distance) self.dist_order = dist_order ## set distance order p + self.predictor_type = predictor_type self.pred_fx = 0 if "regressor" == predictor_type: @@ -102,14 +113,18 @@ def __init__( Wx = Wy = jnp.ones((1, 1)) ## Wy will be assumed to be one-hot encoded self.probe_params = (Wx, Wy) - def process(self, embeddings, dkey=None): ## TODO: JIT-i-fy this + def process(self, embeddings, dkey=None): _embeddings = embeddings if len(_embeddings.shape) > 2: flat_dim = embeddings.shape[1] * embeddings.shape[2] _embeddings = jnp.reshape(_embeddings, (embeddings.shape[0], flat_dim)) - Wx, Wy = self.probe_params ## pull out KNN parameters - values, indices = _run_knn_probe(_embeddings, Wx, self.K, self.dist_order) + Wx, Wy = self.probe_params + + # Pass the explicit metric string directly to the JIT-compiled loop + values, indices = _run_knn_probe( + _embeddings, Wx, self.K, self.dist_order, self.dist_metric + ) ## do K-neighbor voting scheme (find mode/frequency prediction) Y_counts = jnp.zeros((_embeddings.shape[0], Wy.shape[1])) @@ -139,25 +154,25 @@ def update(self, embeddings, labels, dkey=None): Wy = labels self.probe_params = (Wx, Wy) -if __name__ == '__main__': - seed = 42 - D = 7 - C = 5 - dkey = random.PRNGKey(seed) - dkey, *subkeys = random.split(dkey, 3) - knn = KNNProbe( - subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean" - ) - X = random.uniform(subkeys[1], shape=(10, D)) - Y = jnp.concat( - [ - jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]), - jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]), - 
jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]]) - ], - axis=0 - ) - knn.update(X, Y) ## fit KNN to data - print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y +# if __name__ == '__main__': +# seed = 42 +# D = 7 +# C = 5 +# dkey = random.PRNGKey(seed) +# dkey, *subkeys = random.split(dkey, 3) +# knn = KNNProbe( +# subkeys[0], 1, input_dim=D, out_dim=C, K=1, dist_function="euclidean" +# ) +# X = random.uniform(subkeys[1], shape=(10, D)) +# Y = jnp.concat( +# [ +# jnp.ones((2, C)) * jnp.array([[1., 0., 0., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 1., 0., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 1., 0., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 1., 0.]]), +# jnp.ones((2, C)) * jnp.array([[0., 0., 0., 0., 1.]]) +# ], +# axis=0 +# ) +# knn.update(X, Y) ## fit KNN to data +# print(knn.process(X)) ## should construct the (smeared) identity matrix, exactly same as Y diff --git a/ngclearn/utils/filters/__init__.py b/ngclearn/utils/filters/__init__.py new file mode 100644 index 00000000..7f778d24 --- /dev/null +++ b/ngclearn/utils/filters/__init__.py @@ -0,0 +1,3 @@ +from .gauss_filter import gaussian_filter +from .cortical_gauss_filter import cortical_gaussian_filter + diff --git a/ngclearn/utils/filters/cortical_gauss_filter.py b/ngclearn/utils/filters/cortical_gauss_filter.py new file mode 100644 index 00000000..5a57eaa9 --- /dev/null +++ b/ngclearn/utils/filters/cortical_gauss_filter.py @@ -0,0 +1,94 @@ +""" +Support for a function/pure "cortical" Gaussian filter - this can be configured to facilitate a filter that +engages in a difference-of-Gaussians-like or ratio-of-Gaussians-like process +""" + +import jax.numpy as jnp +from jax import lax, jit +from functools import partial + +def _calc_gaussian_kernel_2D( + sigma: float, ## standard deviation of kernel + radius: int ## controls shape of kernel +) -> jnp.ndarray: ## internal co-routine for Gaussian kernel + ## create a normalized 2D Gaussian 
kernel with shape (1, 1, 2*radius+1, 2*radius+1) + x = jnp.arange(-radius, radius + 1) + xx, yy = jnp.meshgrid(x, x) + kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2) + kernel = kernel / jnp.sum(kernel) + return kernel[jnp.newaxis, jnp.newaxis, :, :] + +@partial(jit, static_argnums=[3, 4, 5]) +def cortical_gaussian_filter( + images: jnp.ndarray, ## expected input shape : (B, C, H, W) + sigma_center: float, ## center excitation + sigma_surround: float, ## surround inhibition + kernel_size: int, ## kernel radius + use_ratio:bool=False, ## triggers either RoG vs DoG mode + semi_sat_constant:float=0.1, ## sigma_h (controls contrast-gain) + excitation_exp:float=2.0, ## p-exponent (typically in range of 1.5 - 2.0) + inhibition_exp:float=2.0, ## q-exponent + edge_pad_mode:str='edge' +) -> jnp.ndarray: + """ + Applies a configurable rectified Gaussian filter (either difference-of-Gaussians, i.e., DoG, or ratio-of-Gaussians, + i.e., RoG) to a tensor batch of 2D images (each of CxHxW tensor shape/format). Note that this variant filter + means that DoG mode acts more as a (half-wave) rectified subtraction of two Gaussian kernels and RoG mode acts more + as simple form of divisive normalization. 
+ + Args: + images: input image tensor of shape (B, C, H, W) + + sigma_center: standard deviation for narrow / center blur + + sigma_surround: standard deviation for wide / surround blur + + kernel_size: kernel radius (window size will be `2*radius + 1`) + + use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False) + + semi_sat_constant: suppresses amplified micro-variations in sensory/image space + + excitation_exp: p-exponent (typically in range of 1.5 - 2.0) (Default: 2.0) + + inhibition_exp: q-exponent (typically in range of 1.5 - 2.0) (Default: 2.0) + + edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge") + + Returns: + An output tensor of shape (B, C, H, W) + """ + ## set up edge-artifact padding correction + padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size)) + padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode) + + ## set up Gaussian kernels + k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) + k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) + + dn = lax.ConvDimensionNumbers( + lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3) + ) + num_channels = images.shape[1] + + ## apply standard spatial convolutions + blur_center = lax.conv_general_dilated( + padded_x, k1, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + ) + blur_surround = lax.conv_general_dilated( + padded_x, k2, window_strides=(1, 1), padding='VALID', dimension_numbers=dn, feature_group_count=num_channels + ) + if use_ratio: + ## cortically-plausible divisive normalization (or a cortical ratio-of-Gaussians; RoG) + ### first, apply half-wave rectification + rectified_center = jnp.maximum(0.0, blur_center) + rectified_surround = jnp.maximum(0.0, blur_surround) + ## next, apply nonlinear response exponents + numerator = jnp.power(rectified_center, excitation_exp) + denominator = 
jnp.power(rectified_surround, inhibition_exp) + (semi_sat_constant ** 2) + output = numerator / denominator ## calculate ratio + else: ## cortically-plausible difference-of-Gaussians (cortical DoG) + ### this is modeled via rectified linear subtraction under an output threshold (0 in the case below) + output = jnp.maximum(0.0, blur_center - blur_surround) + return output + diff --git a/ngclearn/utils/filters/gauss_filter.py b/ngclearn/utils/filters/gauss_filter.py new file mode 100644 index 00000000..ef592b52 --- /dev/null +++ b/ngclearn/utils/filters/gauss_filter.py @@ -0,0 +1,93 @@ +""" +Support for a function/pure Gaussian filter - this can be configured to facilitate +difference-of-Gaussians (DoG) or ratio-of-Gaussians (RoG). +""" + +import jax.numpy as jnp +from jax import lax, jit +from functools import partial + +def _calc_gaussian_kernel_2D( ## internal co-routine for Gaussian kernel + sigma: float, ## standard deviation of kernel + radius: int ## controls shape of kernel +) -> jnp.ndarray: + ## create a normalized 2D Gaussian kernel with shape (1, 1, 2*radius+1, 2*radius+1) + x = jnp.arange(-radius, radius + 1) + xx, yy = jnp.meshgrid(x, x) + kernel = jnp.exp(-0.5 * (xx ** 2 + yy ** 2) / sigma ** 2) + kernel = kernel / jnp.sum(kernel) + ## reshape output to: (out_channels, in_channels, height, width) to support lax.conv_general_dilated w/in filter + return kernel[jnp.newaxis, jnp.newaxis, :, :] + +@partial(jit, static_argnums=[3, 4]) +def gaussian_filter( ## core filter routine + images: jnp.ndarray, ## input image batch + sigma_center: float, ## sigma1 + sigma_surround: float, ## sigma2 + kernel_size : int, ## radius + use_ratio:bool=False, ## if True, this becomes a ratio-of-Gaussians + edge_pad_mode:str="edge" ## "reflect" +) -> jnp.ndarray: + """ + Applies a configurable Gaussian filter (either difference-of-Gaussians or ratio-of-Gaussians) to a tensor batch of + 2D images (of CxHxW tensor shape).
+ + Args: + images: input image tensor of shape (B, C, H, W) + + sigma_center: standard deviation for narrow / center blur + + sigma_surround: standard deviation for wide / surround blur + + kernel_size: kernel radius (window size will be `2*radius + 1`) + + use_ratio: if True, this filter applies a ratio-of-Gaussians (RoG) filter (Default: False) + + edge_pad_mode: type of image edge-clamping/padding to use, either "edge" or "reflect" (Default: "edge") + + Returns: + An output tensor of shape (B, C, H, W) + """ + ## pad spatial dimensions (H, W) using edge/reflect-clamping in order to remove artifacts + ### format for 4D tensor (B, C, H, W) => + ### result: ((Before_B, After_B), (Before_C, After_C), (Before_H, After_H), (Before_W, After_W)) + padding_config = ((0, 0), (0, 0), (kernel_size, kernel_size), (kernel_size, kernel_size)) + padded_x = jnp.pad(images, padding_config, mode=edge_pad_mode) + + ## construct two 2D Gaussian kernels + k1 = _calc_gaussian_kernel_2D(sigma_center, kernel_size) ## center kernel + k2 = _calc_gaussian_kernel_2D(sigma_surround, kernel_size) ## surround kernel + + ## define dimension ordering for lax.conv ('NCHW' standard layout) + dn = lax.ConvDimensionNumbers( + lhs_spec=(0, 1, 2, 3), ## (batch, channel, height, width) + rhs_spec=(0, 1, 2, 3), ## (out_channel, in_channel, height, width) + out_spec=(0, 1, 2, 3) ## (batch, channel, height, width) + ) + num_channels = images.shape[1] ## get channel count dynamically for independent channel-wise filtering + + ## below performs spatial convolutions w/ "VALID" padding on the edge-padded input + blur_center = lax.conv_general_dilated( ## center Gaussian + padded_x, + k1, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=dn, + feature_group_count=num_channels + ) + blur_surround = lax.conv_general_dilated( ## surround Gaussian + padded_x, + k2, + window_strides=(1, 1), + padding='VALID', + dimension_numbers=dn, + feature_group_count=num_channels + ) + ## Perform final filter 
calculation + if use_ratio: ## apply ratio-of-Gaussians (RoG) + eps = 1e-5 + output = blur_center / (blur_surround + eps) ## calculate kernel ratio + else: ## apply difference-of-Gaussians (DoG) + output = blur_center - blur_surround ## calculate kernel difference + return output ## final shape: (B, C, H, W) + diff --git a/ngclearn/utils/io_utils.py b/ngclearn/utils/io_utils.py index 8553af44..5365b34e 100755 --- a/ngclearn/utils/io_utils.py +++ b/ngclearn/utils/io_utils.py @@ -78,3 +78,4 @@ def load_pkl(directory: str, name: str) -> Any: with open(file_name, 'rb') as f: data = pickle.load(f) return data + diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py index 457043cb..e30f3d9c 100755 --- a/ngclearn/utils/metric_utils.py +++ b/ngclearn/utils/metric_utils.py @@ -67,12 +67,12 @@ def measure_breadth_TC(spikes, preserve_batch=False): spikes: full spike train matrix; shape is (T x D) where D is number of neurons in a group/cluster - preserve_batch: if True, will return one score per sample in batch + preserve_batch: if True, will return one score per neuron in train/window (Default: False), otherwise, returns scalar average score Returns: - a 1 x D Fano factor vector (one factor per neuron) OR a single - average Fano factor across the neuronal group + a 1 x D BTC vector (one factor per neuron) OR a single + average BTC across the neuronal group """ mu = jnp.mean(spikes, axis=0, keepdims=True) sigSqr = jnp.square(jnp.std(spikes, axis=0, keepdims=True)) @@ -82,6 +82,42 @@ def measure_breadth_TC(spikes, preserve_batch=False): BTC = jnp.mean(BTC) return BTC +@partial(jit, static_argnums=[1]) +def measure_gini_index(codes, preserve_batch=True): + """ + Calculates the gini index of a group of neurons represented as vector code samples. + Gini index measures the sparseness of the values within each vector code, where + a higher index value indicates higher sparsity and a lower index value indicates a + lower sparsity (higher density).
+ + Args: + codes: a batch of neural codes; shape is (N x D) where D is number of + neurons in a group/cluster and N is number of samples + + preserve_batch: if True, will return one score per sample in batch + (Default: True), otherwise, returns scalar average score + + Returns: + a N x 1 Gini index vector (one score per sample) OR a single + average Gini score for the whole sample/set of codes + """ + ## Gini index + ### values closer to 1 indicate high sparsity (sparser codes) + ### values closer to 0 indicate lower sparsity (denser codes) + _codes = codes + (jnp.sum(codes, axis=1, keepdims=True) <= 0.) + 1e-8 + ### note that the calculation below is faster than the mean-absolute-value + ### form of gini-index; below calculation requires sorting but yields a + ### lower-complexity calculation + D = codes.shape[1] ## length of vector + codes_sorted = jnp.sort(jnp.abs(_codes), axis=1) ## sort all codes w/in batch matrix + index = jnp.arange(1, D + 1) + term1 = jnp.sum((2 * index - D - 1) * codes_sorted, axis=1, keepdims=True) + term2 = D * jnp.sum(codes_sorted, axis=1, keepdims=True) + gini = term1 / term2 ## calc final ratio + if not preserve_batch: + gini = jnp.mean(gini) ## this is the mean gini-index + return gini + @partial(jit, static_argnums=[2, 3]) def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=False): """ @@ -89,9 +125,13 @@ def measure_sparsity(codes, tolerance=0., preserve_batch=True, flip_measure=Fals this matrix is a non-negative vector. Formally, this means we compute, per i-th row: + | rho(x_i) = num_zeros(x_i) / dim(x_i) + and for a global score for matrix X with N codes/rows, we measure: + | rho_mean(X) = 1/N Sum^N_{i=1} rho(x_i) + where lower/closer to 0 means codes more sparse and closer to 1 means codes are more dense.
@@ -136,9 +176,9 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st y: target / ground-truth (design) matrix; shape is (N x C) OR an array of class integers of length N (with "extract_label_indx = True") - extract_label_indx: run an argmax to pull class integer indices from + extract_label_indx: when True, run an argmax to pull class integer indices from "y", assuming y is a one-hot binary encoding matrix (Default: True), - otherwise, this assumes "y" is an array of class integer indices + otherwise, if False, this treats "y" as an array of class integer indices of length N Returns: @@ -423,3 +463,289 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10 if not preserve_batch: bce = jnp.mean(bce) return bce + +@partial(jit, static_argnums=[1]) +def measure_hoyer_sparsity(codes: jnp.ndarray, preserve_batch: bool=False) -> float: + """ + Measures the Hoyer sparsity for a set of latent codes. + Hoyer sparsity lies in [0, 1], where a value of 0.0 indicates something is dense and + a value of 1 indicates something is extremely sparse.
+ + Args: + codes: matrix (shape: N x D) of non-negative codes to measure + sparsity of (per row); D is flattened latent code size + + preserve_batch: if True, will return one score per sample in batch + (Default: False), otherwise, returns scalar mean score + + Returns: + an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise + """ + # Flatten everything past the batch dimension + x = jnp.reshape(codes, (codes.shape[0], -1)) + N = x.shape[1] + + l1 = jnp.sum(jnp.abs(x), axis=1) + l2 = jnp.sqrt(jnp.sum(jnp.square(x), axis=1) + 1e-8) # epsilon to avoid division by zero + + hoyer = (jnp.sqrt(N) - (l1 / l2)) / (jnp.sqrt(N) - 1.0) + if not preserve_batch: + hoyer = jnp.mean(hoyer) # calc average sparsity across set/batch + return hoyer + +@partial(jit, static_argnums=[1]) +def measure_excess_kurtosis(codes: jnp.ndarray, preserve_batch: bool=False) -> float: + """ + Measures the peak and heavy-tailedness of a set of neural activation codes. Note that + higher values (> 0) indicate sparse, localized 'high-burst' activations. 
+ + Args: + codes: matrix (shape: N x D) of non-negative codes to measure + sparsity of (per row) + + preserve_batch: if True, will return one score per sample in batch + (Default: False), otherwise, returns scalar mean score + + Returns: + an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise + """ + x = jnp.reshape(codes, (codes.shape[0], -1)) + mean = jnp.mean(x, axis=1, keepdims=True) ## 1st moment + variance = jnp.var(x, axis=1, keepdims=True) ## 2nd moment + + ## 4th central moment divided by variance squared + fourth_moment = jnp.mean(jnp.power(x - mean, 4), axis=1, keepdims=True) + kurtosis = fourth_moment / (jnp.square(variance) + 1e-8) ## kurtosis of distribution + excess_kurtosis = kurtosis - 3.0 ## calc "excess kurtosis" by subtracting 3 + if not preserve_batch: + excess_kurtosis = jnp.mean(excess_kurtosis) ## calc avg excess-kurtosis over set/batch + return excess_kurtosis + + +### class conformity metrics ### + +@partial(jit, static_argnums=[2, 3]) +def _compute_contingency_table( ## vectorized construction of contingency matrix + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + ## Computes a contingency matrix table + ## This routine expects true integer labels and predicted integer labels (1D arrays of size N) + + # Create indicator masks across all unique classes/clusters + # find unique IDs safely up to a static maximum size (or provide num_classes) + # n_classes = n_true = jnp.max(labels_true) + 1 + # n_clusters = n_pred = jnp.max(labels_pred) + 1 + + # Broadcast to form a full one-hot lookup map + true_mask = labels_true[:, None] == jnp.arange(n_classes) + pred_mask = labels_pred[:, None] == jnp.arange(n_clusters) + + # Contingency matrix is the matrix product of boolean indicators + contingency = jnp.dot(true_mask.T.astype(jnp.float32), pred_mask.astype(jnp.float32)) + return contingency + + +def measure_ARI( + labels_true: jnp.ndarray, + labels_pred: 
jnp.ndarray +) -> jnp.ndarray: + """ + Computes the adjusted Rand index (ARI), which measures similarity between two + sets of indices (ground truth against a clustering's produced indices) via counting the + pairs of data points assigned to same or different clusters (adjusted for chance). This + measurement is at most `1`, where values near `0` indicate a random labeling/assignment (it can be negative) and `1` indicates + perfect agreement. + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels. + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels. + + Returns: + scalar ARI of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _calc_adjusted_rand_index(labels_true, labels_pred, n_classes, n_clusters) + + +@partial(jit, static_argnums=[2, 3]) +def _calc_adjusted_rand_index( ## ARI + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + if n_samples <= 1: + return jnp.array(1.0) + + ## Get contingency matrix (n_classes x n_clusters) + contingency = _compute_contingency_table( + labels_true, + labels_pred, + n_classes, + n_clusters + ) + + ## Calculate combination sums n_ijC2 = (n_ij * (n_ij - 1)) / 2 + sum_nij_c2 = jnp.sum((contingency * (contingency - 1.0)) / 2.0) + + ## Sums across margins (rows and columns) + sum_a = jnp.sum(contingency, axis=1) + sum_b = jnp.sum(contingency, axis=0) + + ## Margin pair combinations + sum_a_c2 = jnp.sum((sum_a * (sum_a - 1.0)) / 2.0) + sum_b_c2 = jnp.sum((sum_b * (sum_b - 1.0)) / 2.0) + + ## Expected index and Max index math formulas + total_c2 = (n_samples * (n_samples - 1.0)) / 2.0 + expected_index = (sum_a_c2 * sum_b_c2) / total_c2 + max_index = (sum_a_c2 + sum_b_c2) / 2.0 + + ## Prevent division by zero if everything is perfectly clustered or uniform +
denominator = max_index - expected_index + ari = jnp.where(denominator == 0.0, 1.0, (sum_nij_c2 - expected_index) / denominator) + return ari + + +def measure_FMI( + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray +) -> jnp.ndarray: + """ + Calculates the Fowlkes-Mallows Index (FMI), which measures similarity between two sets of + indices - this score is the geometric mean of pair-wise recall and precision. + This measurement lies in `[0, 1]`, where higher is better (indicating greater similarity between + two clustering sets of identifiers). + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels. + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels. + + Returns: + scalar FMI of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _measure_fowlkes_mallows_index(labels_true, labels_pred, n_classes, n_clusters) + + +@partial(jit, static_argnums=[2, 3]) +def _measure_fowlkes_mallows_index( ## FMI + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + # Handle edge case for single or empty samples safely + if n_samples <= 1: + return jnp.array(0.0, dtype=jnp.float32) + + contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters) + + ## Compute marginal sums (sums along rows and columns) + sum_true = jnp.sum(contingency, axis=1) + sum_pred = jnp.sum(contingency, axis=0) + + ## Calculate pairwise combinations using the matrix shortcut: nC2 = 0.5 * (sum(x^2) - N) + # True Positives pair combinations (tk) + tk = 0.5 * (jnp.sum(contingency ** 2) - n_samples) + ## Total pairs clustered together in ground truth (tr) + tr = 0.5 * (jnp.sum(sum_true ** 2) - n_samples) + ## Total pairs clustered together in predictions (tc) + tc = 0.5 * 
(jnp.sum(sum_pred ** 2) - n_samples) + + ## Compute FMI = tk / sqrt(tr * tc) + # Prevent division by zero if there are no pair splits/matches + denominator = jnp.sqrt(tr * tc) + fmi = jnp.where(denominator == 0.0, 0.0, tk / denominator) + return fmi + + +def measure_Vmeasure( ## V-Measure + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + beta: float = 1.0 +) -> jnp.ndarray: + """ + Calculates the V-Measure scoring metric for class conformity. This measurement compares + predicted cluster indices ("labels_pred") against ground truth indices ("labels_true") and + represents the harmonic mean of homogeneity (where each cluster contains only members of a single class) + as well as completeness (where all members of a given class are assigned to the same cluster). + This measurement (higher is better) lies in `[0,1]` where `1` indicates perfect, correct clustering. + + Args: + labels_true: 1D array of shape (n_samples,) with true integer class labels + + labels_pred: 1D array of shape (n_samples,) with predicted integer cluster labels + + beta: Weight factor. Ratios > 1.0 favor completeness, < 1.0 favor homogeneity. 
+ + Returns: + scalar V-measure of these two sets of indices + """ + ## Dynamically find dimensions up to a statically bounded maximum + n_classes = int(jnp.max(labels_true) + 1) + n_clusters = int(jnp.max(labels_pred) + 1) + return _measure_v_measure_score(labels_true, labels_pred, n_classes, n_clusters, beta) + + +@partial(jit, static_argnums=[2, 3, 4]) +def _measure_v_measure_score( ## V-Measure + labels_true: jnp.ndarray, + labels_pred: jnp.ndarray, + n_classes: int, + n_clusters: int, + beta: float = 1.0 +) -> jnp.ndarray: + n_samples = labels_true.shape[0] + + ## Handle edge case for single or empty samples safely + if n_samples <= 1: + return jnp.array(0.0, dtype=jnp.float32) + + contingency = _compute_contingency_table(labels_true, labels_pred, n_classes, n_clusters) + + ## Calculate Marginal Sums (Row and Column totals) + sum_true = jnp.sum(contingency, axis=1) + sum_pred = jnp.sum(contingency, axis=0) + + ## Compute Base Entropies H(True) and H(Pred) + p_true = sum_true / n_samples + h_true = -jnp.sum(jnp.where(p_true > 0.0, p_true * jnp.log(p_true), 0.0)) + + p_pred = sum_pred / n_samples + h_pred = -jnp.sum(jnp.where(p_pred > 0.0, p_pred * jnp.log(p_pred), 0.0)) + + ## Compute Joint Entropy H(True, Pred) + p_joint = contingency / n_samples + h_joint = -jnp.sum(jnp.where(p_joint > 0.0, p_joint * jnp.log(p_joint), 0.0)) + + ## Derive Conditional Entropies: H(True|Pred) and H(Pred|True) using identity rule + h_true_given_pred = h_joint - h_pred + h_pred_given_true = h_joint - h_true + + ## Compute Homogeneity (H) and Completeness (C) + ## If base entropy is 0, the metric is perfectly satisfied (1.0) + homogeneity = jnp.where(h_true == 0.0, 1.0, 1.0 - (h_true_given_pred / h_true)) + completeness = jnp.where(h_pred == 0.0, 1.0, 1.0 - (h_pred_given_true / h_pred)) + + ## Compute Weighted Harmonic Mean (V-Measure) + denominator = beta * homogeneity + completeness + + ## Prevent division by zero if both metrics are zero + v_measure = jnp.where( + denominator == 
0.0, + 0.0, + (1.0 + beta) * homogeneity * completeness / denominator + ) + return v_measure diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py index bf20079b..aa8ef9cb 100755 --- a/ngclearn/utils/model_utils.py +++ b/ngclearn/utils/model_utils.py @@ -73,8 +73,7 @@ def create_function(fun_name, args=None): Args: fun_name: string name of activation function to produce; Currently supports: "tanh", "bkwta" (binary K-winners-take-all), "sigmoid", "relu", "lrelu", "relu6", - "elu", "silu", "gelu", "softplus", "softmax" (derivative not supported), "unit_threshold", "heaviside", - "identity" + "elu", "silu", "gelu", "softplus", "softmax", "unit_threshold", "heaviside", "identity" Returns: function fx, first derivative of function (w.r.t. input) dfx @@ -108,7 +107,7 @@ def create_function(fun_name, args=None): elif fun_name == "elu": fx = elu dfx = d_elu - elif fun_name == "silu": + elif fun_name == "silu": # NOTE: this is also the swish function fx = silu dfx = d_silu elif fun_name == "gelu": @@ -122,26 +121,33 @@ def create_function(fun_name, args=None): dfx = d_softplus elif fun_name == "softmax": fx = softmax - ## NOTE: below is an improper derivative proxy - ## correct dfx is a Jacobian of softmax (not currently supported!) - dfx = d_identity + dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx elif fun_name == "unit_threshold": fx = threshold ## default threshold is 1 (thus unit) - dfx = d_threshold ## STE approximation + dfx = d_threshold ## NOTE: STE approximation elif "heaviside" in fun_name: fx = heaviside - dfx = d_heaviside ## STE approximation + dfx = d_heaviside ## NOTE: STE approximation elif fun_name == "identity": fx = identity dfx = d_identity - else: + else: ## throw exception for un-supported activation raise RuntimeError( "Activation function (" + fun_name + ") is not recognized/supported!" 
- ) + ) return fx, dfx @partial(jit, static_argnums=[1]) -def bkwta(x, nWTA=5): #5 10 15 #K=50): +def bkwta(x, nWTA=5): ## binarized k-winner-take-all function + """ + The binarized K winner-take-all (K-WTA) function: + + Args: + x: input (tensor) value (real-valued) + + Returns: + output (tensor) value (binary values) + """ values, indices = lax.top_k(x, nWTA) # Note: we do not care to sort the indices kth = jnp.expand_dims(jnp.min(values,axis=1),axis=1) # must do comparison per sample in potential mini-batch topK = jnp.greater_equal(x, kth).astype(jnp.float32) # cast booleans to floats @@ -214,7 +220,6 @@ def clamp_max(x, max_val): _x = x * mask + (1. - mask) * max_val return _x - @jit def one_hot(P): """ @@ -254,7 +259,7 @@ def chebyshev_norm(d, axis=-1, keepdims=False): @jit def binarize(data, threshold=0.5): """ - Converts the vector *data* to its binary equivalent + Converts the vector *data* to its binary equivalent. Args: data: the data to binarize (real-valued) @@ -357,10 +362,13 @@ def d_telu(x): ex = jnp.exp(x) tanh_ex = jnp.tanh(ex) return tanh_ex + x * ex * (1.0 - tanh_ex ** 2) + @jit def sine(x, omega_0=30): """ - f(x) = sin(x * omega_0). + The sine function, parameterized by frequency `omega`: + + | f(x) = sin(x * omega_0). Args: x: input (tensor) value @@ -373,14 +381,15 @@ def sine(x, omega_0=30): @jit def d_sine(x, omega_0=30): """ - frequency = omega_0 - frequency * cos(x * frequency). + The derivative of the sine function: + + | f'(x) = frequency * cos(x * frequency); where frequency = omega_0 Args: x: input (tensor) value Returns: - output (tensor) value + output (tensor) derivative value (with respect to input) """ return omega_0 * jnp.cos(omega_0 * x) @@ -492,7 +501,9 @@ def d_relu6(x): @jit def softplus(x): """ - The softplus elementwise function. 
+ The softplus elementwise function: + + | f(x) = ln(1 + exp(x)) Args: x: input (tensor) value @@ -514,31 +525,94 @@ def d_softplus(x): output (tensor) derivative value (with respect to input argument) """ ## d/dx of softplus = logistic sigmoid - return nn.sigmoid(x) + return sigmoid(x) #nn.sigmoid(x) @jit def threshold(x, thr=1.): + """ + The threshold function (or Heaviside but with a non-zero boundary): + + | f(x) = 1 if x >= thr, otherwise 0 (for x < thr) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ return (x >= thr).astype(jnp.float32) @jit def d_threshold(x, thr=1.): - return x * 0. + 1. ## straight-thru estimator + """ + Derivative of the threshold function; specifically, this employs the + straight-through estimator (STE) as a proxy/surrogate derivative instead. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ + return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def heaviside(x): + """ + The Heaviside function: + + | f(x) = 1 if x >= 0, otherwise 0 (for x < 0) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ return (x >= 0.).astype(jnp.float32) @jit def d_heaviside(x): - return x * 0. + 1. ## straight-thru estimator + """ + Derivative of the Heaviside function; specifically, this employs the + straight-through estimator (STE) as a proxy/surrogate derivative instead. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ + return x * 0. + 1. ## NOTE: straight-thru estimator (STE) @jit def sigmoid(x): - return nn.sigmoid(x) + """ + The sigmoid / logistic-link function: + + | f(x) = 1/(1 + exp(-x)) + + Args: + x: input (tensor) value + + Returns: + output (tensor) value + """ + sigm_x = 1./ (1. 
+ jnp.exp(-x)) + return sigm_x #nn.sigmoid(x) @jit def d_sigmoid(x): - sigm_x = nn.sigmoid(x) ## pre-compute once + """ + Derivative of the sigmoid / logistic-link function. + + Args: + x: input (tensor) value + + Returns: + output (tensor) derivative value (with respect to input argument) + """ + sigm_x = sigmoid(x) #nn.sigmoid(x) ## pre-compute once return sigm_x * (1. - sigm_x) def inverse_sigmoid(x, clip_bound=0.03): ## wrapper call for naming convention ease @@ -590,7 +664,9 @@ def d_swish(x, beta): @jit def silu(x): """ - Applies the sigmoid-weighted linear unit (SiLU or SiL) activation. + Applies the sigmoid-weighted linear unit (SiLU or SiL) activation. + Note that this is primarily a convenience wrapper function for + the `swish` activation. Args: x: data to transform via inverse logistic function @@ -607,7 +683,8 @@ def d_silu(x): @jit def gelu(x): """ - Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used). + Applies the Gaussian Error Linear Unit (GeLU) activation + (specifically, a fast approximation is used via a weighted `swish`). Args: x: data to transform via inverse logistic function @@ -635,7 +712,7 @@ def elu(x, alpha=1.): Returns: output of the GeLU activation """ - mask = x >= 0. + mask = x >= 0. ## pre-compute mask return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask) @jit @@ -653,21 +730,68 @@ def softmax(x, tau=0.0): Args: x: a (N x D) input argument (pre-activity) to the softmax operator - tau: probability sharpening/softening factor + tau: probability sharpening/softening factor, if > 0.; else, <= 0 disables + this (Default: 0.) Returns: a (N x D) probability distribution output block """ - if tau > 0.0: - x = x / tau + #if tau > 0.0: + # x = x / tau + _m = tau > 0. + _tau = tau * _m + (1. 
- _m) ## sets _tau=1 if tau <= 0 + x = x * (1./ _tau) max_x = jnp.max(x, axis=1, keepdims=True) exp_x = jnp.exp(x - max_x) return exp_x / jnp.sum(exp_x, axis=1, keepdims=True) +@partial(jit, static_argnums=[2]) +def d_softmax(x, tau=0., vmap_form=False): ## temperature-controlled softmax derivative co-routine + """ + Derivative of the softmax function. + Note that this returns specifically the Jacobian tensor `Jx` of softmax(x) w.r.t. + potential batch set of vectors (one per row). + + Args: + x: input (tensor) value (B x D) + + vmap_form: optional algorithm switch flag; if True, `Jx` is computed using + Jax vmap (Default: False) + + Returns: + output (tensor) derivative values (Jacobian with respect to input argument; B x D x D) + """ + _m = tau > 0. + _tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0 + Jx = 0. ## d_softmax(x)/d_x is a Jacobian matrix per sample + ## calculate softmax along feature dimension (axis=-1) + s = softmax(x, tau=_tau) # nn.softmax(x, axis=-1) ## (BxD) + if not vmap_form: ### use pure tensorized batch-identity trick algorithm + diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) ## Shape: (BxDx1) * (1xDxD) => (BxDxD) + ## batched outer(s, s) ~> outer product for each batch vector + outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) ## (BxDx1) * (Bx1xD) => (BxDxD) + Jx = (diag_s - outer_s) * (1. 
/ _tau) + else: ### switch to vmap algorithm + ## calc outer product using einsum (clean and readable) + outer_s = jnp.einsum('bi,bj->bij', s, s) ## (BxDxD) + ## fast batched diagonal insertion via a diagonal mask + d = s.shape[-1] + diag_indices = jnp.arange(d) + ## jax.at subtracts outer product from diagonal + ## (s - s^2) for diagonal, (-s_i s_j) for off-diagonal + ## avoids constructing a giant identity matrix + jacobian = -outer_s + ## vmap over index updates across batch + def add_diag(J_matrix, s_vector): + return J_matrix.at[diag_indices, diag_indices].add(s_vector) + Jx = ( jax.vmap(add_diag)(jacobian, s) ) * (1. / _tau) + return Jx ## return full, final Jacobian + @jit def threshold_soft(x, lmbda): """ - A soft threshold routine applied to each dimension of input + A soft threshold routine applied to each dimension of input. + (Note that this function does not contain a complementary derivative.) Args: x: data to apply threshold function over @@ -684,7 +808,8 @@ def threshold_soft(x, lmbda): @jit def threshold_cauchy(x, lmbda): """ - A Cauchy distributional threshold routine applied to each dimension of input + A Cauchy distributional threshold routine applied to each dimension of input. + (Note that this function does not contain a complementary derivative.) Args: x: data to apply threshold function over @@ -770,6 +895,23 @@ def create_block_matrix(map_matrix, group_shape, alpha_inh=-1., alpha_exc=1.): gmat = jnp.concatenate(gmat, axis=0) return gmat +@partial(jit, static_argnums=[0]) +def eye_wrapped(N, k, values): + """ + Creates an N x N matrix with a wrapped off-diagonal. 
+ + Args: + N: Size of the square matrix (N x N) + + k: Diagonal offset (positive=above, negative=below) + + values: Array of values to place (length should match N) + """ + matrix = jnp.zeros((N, N)) ## Create empty matrix + row_indices = jnp.arange(N) ## Generate indices for the diagonal + col_indices = (row_indices + k) % N ## Wrap column indices using modulo + return matrix.at[row_indices, col_indices].set(values) ## Fill diagonal using efficient indexing + @partial(jit, static_argnums=[1, 2, 3, 4]) def normalize_block_matrix(matrix, block_size, order=2, axis=0, norm_targ=1.): """ diff --git a/ngclearn/utils/viz/__init__.py b/ngclearn/utils/viz/__init__.py index e37c46fc..605d9bef 100644 --- a/ngclearn/utils/viz/__init__.py +++ b/ngclearn/utils/viz/__init__.py @@ -2,3 +2,5 @@ from . import raster from . import spike_plot from . import synapse_plot +from . import classification_analysis + diff --git a/ngclearn/utils/viz/classification_analysis.py b/ngclearn/utils/viz/classification_analysis.py new file mode 100644 index 00000000..cce819f5 --- /dev/null +++ b/ngclearn/utils/viz/classification_analysis.py @@ -0,0 +1,57 @@ +import matplotlib.pyplot as plt +import numpy as np + +def visualize_confusion_heatmap( + confuse_matrix, + classes, + out_fname, + figure_title="Confusion Matrix", + color_map="Blues", # "Greens" "Reds" + fontsize=10, + norm_by="none" +): + _conf = confuse_matrix + if "recall" in norm_by: ## normalize by row (recall) + _conf = (_conf / np.sum(_conf, axis=1, keepdims=True)) * 100 + elif "precision" in norm_by: ## normalize by col (precision) + _conf = (_conf / np.sum(_conf, axis=0, keepdims=True)) * 100 + ## Initialize plot + fig, ax = plt.subplots(figsize=(6, 6)) + vmin = np.floor(np.min(_conf)) + vmax = np.ceil(np.max(_conf)) + im = ax.imshow( + _conf, interpolation="nearest", cmap=color_map, vmin=vmin, vmax=vmax + ) + ## Add color bar + fig.colorbar(im, ax=ax, shrink=0.75) + ## Configure axes labels + ax.set_xticks(np.arange(len(classes))) + 
ax.set_yticks(np.arange(len(classes))) + ax.set_xticklabels(classes) + ax.set_yticklabels(classes) + + ax.set_xlabel("Predicted Label", fontsize=12, labelpad=10) + ax.set_ylabel("True Label", fontsize=12, labelpad=10) + ax.set_title(figure_title, fontsize=14, pad=15) + + ## Loop to print numbers on top of each colored cell + threshold = _conf.max() / 2.0 # Find midpoint to flip text color for readability + for i in range(_conf.shape[0]): + for j in range(_conf.shape[1]): + ## Use white text on dark cells, black text on light cells + color = "white" if _conf[i, j] > threshold else "black" + ax.text( + j, + i, + f"{_conf[i, j]:.1f}", #format(confuse_matrix[i, j], "d"), + ha="center", + va="center", + color=color, + fontsize=fontsize, + weight="bold", + ) + + ## Ensure layout fits tightly + plt.tight_layout() + plt.savefig(out_fname) + diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py index 3f32057d..3300feef 100755 --- a/ngclearn/utils/viz/dim_reduce.py +++ b/ngclearn/utils/viz/dim_reduce.py @@ -28,7 +28,12 @@ def extract_pca_latents(vectors): ## PCA mapping routine z_2D = vectors return z_2D -def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine +def extract_tsne_latents( + vectors, + perplexity=30, + n_pca_comp=32, + batch_size=500 +): ## tSNE mapping routine """ Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the t-distributed stochastic neighbor embedding algorithm (t-SNE). 
This algorithm also uses PCA to produce an @@ -41,9 +46,9 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): perplexity: the perplexity control factor for t-SNE (Default: 30) n_pca_comp: number of PCA top components (sorted by eigen-values) to retain/extract before continuing - with t-SNE dimensionality reduction + with t-SNE dimensionality reduction (Default: 32) - batch_size: number of sampled embedding vectors to use per iteration of online internal PCA + batch_size: number of sampled embedding vectors to use per iteration of online internal PCA (Default: 500) Returns: a matrix (K x 2) of projected vectors (to 2D space) @@ -67,7 +72,15 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): z_2D = vectors return z_2D -def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., cmap=None): +def plot_latents( + code_vectors, + labels, + plot_fname="2Dcode_plot.jpg", + alpha=1., + cmap=None, + xaxis_title=None, + yaxis_title=None +): """ Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes (produced by either PCA or t-SNE). 
@@ -84,6 +97,11 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c alpha: alpha intensity level to present colors in scatterplot cmap: custom color-map to provide + + xaxis_title: string denoting title to place for X-axis (Default: None) + + yaxis_title: string denoting title to place for Y-axis (Default: None) + """ curr_backend = plt.rcParams["backend"] matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting @@ -98,10 +116,17 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., c if _cmap is None: _cmap = default_cmap #print("> USING DEFAULT CMAP!") - plt.scatter(code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha) + plt.scatter( + code_vectors[:, 0], code_vectors[:, 1], c=lab, cmap=_cmap, alpha=alpha + ) colorbar = plt.colorbar() #colorbar.set_alpha(1) #plt.draw_all() + if xaxis_title is not None: + plt.xlabel(xaxis_title, fontsize=16, fontweight="bold") + if yaxis_title is not None: + plt.ylabel(yaxis_title, fontsize=16, fontweight="bold") + plt.grid() plt.savefig("{0}".format(plot_fname), dpi=300) plt.clf() diff --git a/ngclearn/utils/viz/raster.py b/ngclearn/utils/viz/raster.py index dff8745b..875a52e8 100755 --- a/ngclearn/utils/viz/raster.py +++ b/ngclearn/utils/viz/raster.py @@ -140,3 +140,4 @@ def create_overlay_raster_plot(spike_train, targ_train, Y, idxs, s=1.5, c="black plt.savefig(plot_fname + '_' + str(idx) + suffix) plt.clf() plt.close() + diff --git a/ngclearn/utils/viz/synapse_plot.py b/ngclearn/utils/viz/synapse_plot.py index 0f268491..f62a47da 100644 --- a/ngclearn/utils/viz/synapse_plot.py +++ b/ngclearn/utils/viz/synapse_plot.py @@ -9,8 +9,13 @@ import jax.numpy as jnp - -def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): +def visualize( + thetas, + sizes, + prefix, + order=None, + suffix='.jpg' +): """ Args: @@ -22,8 +27,6 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): suffix: """ - - if order is None: order 
= ['C' for _ in range(len(thetas))] @@ -52,8 +55,11 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): point = start + 1 + i + (r * extra) plt.subplot(n_rows_total, n_cols_total, point) - filter = T[i, :] - plt.imshow(np.reshape(filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]), cmap=plt.cm.bone, interpolation='nearest') + _filter = T[i, :] + plt.imshow( + np.reshape(_filter, (sizes[idx][0], sizes[idx][1]), order=order[idx]), + cmap=plt.cm.bone, interpolation='nearest' + ) plt.axis("off") plt.subplots_adjust(top=0.9) @@ -62,7 +68,14 @@ def visualize(thetas, sizes, prefix, order=None, suffix='.jpg'): plt.close() -def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffix='.jpg'): +def visualize_labels( + thetas, + sizes, + prefix, + space_width=None, + widths=None, + suffix='.jpg' +): """ Args: @@ -139,14 +152,34 @@ def visualize_labels(thetas, sizes, prefix, space_width=None, widths=None, suffi fig.savefig(prefix+suffix, bbox_inches='tight') plt.close(fig) -def visualize_frame(frame, path='.', name='tmp', suffix='.jpg', **kwargs): +def visualize_frame( + frame, + path='.', + name='tmp', + suffix='.jpg', + **kwargs +): iio.imwrite(path + '/' + name + suffix, frame.astype(jnp.uint8), **kwargs) -def visualize_gif(frames, path='.', name='tmp', suffix='.jpg', **kwargs): +def visualize_gif( + frames, + path='.', + name='tmp', + suffix='.jpg', + **kwargs +): _frames = [f.astype(jnp.uint8) for f in frames] iio.imwrite(path + '/' + name + '.gif', _frames, **kwargs) -def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs): +def make_video( + f_start, + f_end, + path, + prefix, + suffix='.jpg', + skip=1, + **kwargs +): images = [] for i in range(f_start, f_end+1, skip): print("Reading frame " + str(i)) @@ -154,10 +187,13 @@ def make_video(f_start, f_end, path, prefix, suffix='.jpg', skip=1, **kwargs): print("writing gif") iio.imwrite(path + '/training.gif', images, **kwargs) - -# def visualize_norm(thetas, 
sizes, prefix, suffix='.jpg'): - -def viz_block(thetas, sizes, prefix, suffix=".jpg", padding=1, low_rez=True): +def viz_block( + thetas, + sizes, prefix, + suffix=".jpg", + padding=1, + low_rez=True +): num_filters = [T.shape[1] for T in thetas] n_cols = [math.ceil(math.sqrt(nf)) for nf in num_filters] n_rows = [math.ceil(nf / c) for nf, c in zip(num_filters, n_cols)] diff --git a/pyproject.toml b/pyproject.toml index bd1c7194..3458b230 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" # using setuptool building engine [project] name = "ngclearn" -version = "3.1.0" +version = "3.1.1" description = "Simulation software for building and analyzing computational neuroscience models, brain-inspired computing systems, and NeuroAI agents." authors = [ {name = "Alexander Ororbia", email = "ago@cs.rit.edu"}, diff --git a/requirements.txt b/requirements.txt index 09ab6214..d68c12aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,10 @@ numpy>=1.26.4 scikit-learn>=1.6.1 scipy>=1.14.1 matplotlib>=3.9.4 -# patchify # patchify has issues with pip installation +# patchify ## note: patchify has issues with pip installation jax>=0.4.28 jaxlib>=0.4.28 -ngcsimlib>=3.0.0 +ngcsimlib>=3.1.0 imageio>=2.37.0 pandas>=2.2.3 +typing_extensions>=4.15.0