Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e59641c
Align main with release (#150)
ago109 May 1, 2026
6feae7f
minor cleanup of filter variable name in ganglion-encoder
May 1, 2026
244ceab
update batch size setup for elestic net, lasso, and ridge
rxng8 May 2, 2026
a53e5c5
Merge branch 'main' of github.com:NACLab/ngc-learn
rxng8 May 2, 2026
a105cda
integrated gini-index into metric-utils
May 2, 2026
788d7a8
integrated gini-index into metric-utils
May 2, 2026
abe28f8
integrated gini-index into metric-utils
May 2, 2026
1f3883a
minor cleanup in utils
May 5, 2026
2f5f885
updates to utils/metrics/filters
May 15, 2026
1e3073a
added heatmap tool for utils.viz.classification_analysis
May 19, 2026
22ba7aa
added class-conformity metrics in metric_utils; integrated kmeans-pro…
May 19, 2026
abe7dfa
minor edits/updates
May 20, 2026
d527ab6
Refactor ganglion cell (#151)
Faezehabibi May 21, 2026
f47549a
mod to model_utils
May 21, 2026
9c3b1a0
Merge branch 'main' of github.com:NACLab/ngc-learn
May 21, 2026
63f3f79
mod to model_utils
May 21, 2026
23c7171
mod to model_utils
May 21, 2026
631d65e
clean-up to model_utils
May 21, 2026
91ff5f1
reverted back rate-cell/gauss-cell to v3.0.0 states
May 25, 2026
172a68d
Add effective_dim.py with participation_ratio function (#152)
Faezehabibi May 27, 2026
b8015fc
minor tweak to eff-dim measure
May 27, 2026
5647c68
Implement effective rank calculation in effective_dim.py (#153)
Faezehabibi Jun 1, 2026
df08087
Update JaxProcessesMixin.py
willgebhardt Jun 4, 2026
4c38490
Update to and adjustment to (with comments added) JaxProcessesMixin.p…
ago109 Jun 9, 2026
0ba7049
updates to utils, including integration of useful filters; included u…
Jun 9, 2026
ad0cc85
nudge for ngclearn docs to v3.1.1 minor update
Jun 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ $ python install -e .
</pre>

**Version:**<br>
3.1.0 <!--1.2.3-Beta--> <!-- -Alpha -->
3.1.1 <!--1.2.3-Beta--> <!-- -Alpha -->

Author:
Alexander G. Ororbia II<br>
Expand Down
2 changes: 2 additions & 0 deletions history.txt
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,6 @@ History
* new suite of competitive learning synapses (generalized formats), including vector-quantization, self-organizing map, adaptive resonance theory (contus-version), and modern hopfield network
* revisions to metric and model utils
* some additional clean-up, including supported retinal ganglion encoder
* fixed pkg-resource header (in both ngcsimlib and ngclearn) to account for newer python(s)


2 changes: 1 addition & 1 deletion ngclearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"currently installed!")

##################################################################################
## Needed to preload is called before anything in ngclearn
## Following are needed to preload is called before anything in ngclearn
from pathlib import Path
from sys import argv
import numpy
Expand Down
3 changes: 3 additions & 0 deletions ngclearn/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
from .input_encoders.ganglionCell import RetinalGanglionCell
from .input_encoders.latencyCell import LatencyCell
from .input_encoders.phasorCell import PhasorCell
#from .input_encoders.populationCoderCell import PopulationCoderCell
#from .input_encoders.gridCell import GridCell
#from .input_encoders.placeCell import PlaceCell

## point to synapse component types
from .synapses.denseSynapse import DenseSynapse
Expand Down
6 changes: 4 additions & 2 deletions ngclearn/components/input_encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .bernoulliCell import BernoulliCell
from .poissonCell import PoissonCell
from .latencyCell import LatencyCell
from .ganglionCell import RetinalGanglionCell
from .phasorCell import PhasorCell

from .ganglionCell import RetinalGanglionCell
#from .populationCoderCell import PopulationCoderCell
#from .gridCell import GridCell
#from .placeCell import PlaceCell

7 changes: 6 additions & 1 deletion ngclearn/components/input_encoders/bernoulliCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ class BernoulliCell(JaxComponent):
"""

def __init__(
self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs
self,
name: str,
n_units: int,
batch_size: int = 1,
key: Union[jax.Array, None] = None,
**kwargs
):
super().__init__(name=name, key=key)

Expand Down
152 changes: 90 additions & 62 deletions ngclearn/components/input_encoders/ganglionCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,28 @@
def _create_gaussian_filter(patch_shape, sigma):
## Create a 2D Gaussian kernel centered on patch_shape with given sigma.
px, py = patch_shape

x_ = jnp.linspace(0, px - 1, px)
y_ = jnp.linspace(0, py - 1, py)

x, y = jnp.meshgrid(x_, y_)

xc = px // 2
yc = py // 2

filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
return filter / jnp.sum(filter)
_filter = jnp.exp(-((x - xc) ** 2 + (y - yc) ** 2) / (2 * (sigma ** 2)))
return _filter / jnp.sum(_filter)

def _create_dog_filter(patch_shape, sigma, k=1.6, lmbda=1):
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)

dog = g1 - lmbda * g2

return dog #- jnp.mean(dog)


def _create_ratio_of_gauss_filter(patch_shape, sigma, k=1.6):
g1 = _create_gaussian_filter(patch_shape, sigma=sigma)
g2 = _create_gaussian_filter(patch_shape, sigma=sigma * k)
rog = g1 / (g2 + 1e-8)
return rog


def _create_patches(obs, patch_shape, step_shape):
"""
Extract 2D patches from a batch of images using a sliding window.
Expand Down Expand Up @@ -67,10 +69,36 @@ def _create_patches(obs, patch_shape, step_shape):

return patches

def _reconstruct(patches, nx_ny, area_shape, patch_shape, step_shape):
# patches: (N, nx * ny, px, py)

B = len(patches)
nx, ny = nx_ny
ix, iy = area_shape
px, py = patch_shape
sx, sy = step_shape
x = jnp.zeros((B, ix, iy))
counts = jnp.zeros((ix, iy))

idx = 0
for i in range(ny):
for j in range(nx):
di = i * sx
dj = j * sy
x = x.at[:, di:di + px, dj:dj + py].add(patches[:, idx])
counts = counts.at[di:di + px, dj:dj + py].add(1.0)
idx += 1

return x / counts[None, :, :]



class RetinalGanglionCell(JaxComponent):
"""
A group of retinal ganglion cell that senses the input stimuli and sends out the filtered signal to the brain.
A group of retinal ganglion cell that sense input stimuli and send out filtered
signals (as output). Note that these simulated cells employ internal generalized
filters based on either Gaussian or difference-of-Gaussian kernels) to recover
historical receptive field processing effects.

| --- Cell Input Compartments: ---
| inputs - input (takes in external signals)
Expand All @@ -85,32 +113,34 @@ class RetinalGanglionCell(JaxComponent):
filter_type: string name of filter function (Default: identity)
:Note: supported filters include "gaussian", "difference_of_gaussian"

sigma: standard deviation of gaussian kernel
sigma: standard deviation of (gaussian) kernel

area_shape: receptive field area of ganglion cells in this module all together
area_shape: shape of receptive field area of ganglion cells in this module (all together)

n_cells: number of ganglion cells in this module

patch_shape: each ganglion cell receptive field area
patch_shape: shape of each ganglion cell's receptive field area

step_shape: the non-overlapping area between each two ganglion cells
step_shape: the non-overlapping area between each pair (two) of ganglion cells

batch_size: batch size dimension of this cell (Default: 1)
batch_size: batch size dimension of this cell/module (Default: 1)
"""

def __init__(self, name: str,
filter_type: str,
area_shape: Tuple[int, int],
n_cells: int,
patch_shape: Tuple[int, int],
step_shape: Tuple[int, int],
batch_size: int = 1,
sigma: float = 1.0,
key: Union[jax.Array, None] = None,
**kwargs):
def __init__(
self,
name: str,
filter_type: str,
area_shape: Tuple[int, int],
n_cells: int,
patch_shape: Tuple[int, int],
step_shape: Tuple[int, int],
batch_size: int = 1,
sigma: float = 1.0,
key: Union[jax.Array, None] = None,
**kwargs
):
super().__init__(name=name, key=key)


## Layer Size Setup
self.filter_type = filter_type
self.n_cells = n_cells
Expand All @@ -121,36 +151,43 @@ def __init__(self, name: str,
self.patch_shape = patch_shape
self.step_shape = step_shape

filter = jnp.ones(self.patch_shape)
_filter = jnp.ones(self.patch_shape)

if filter_type == 'gaussian':
filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)
elif filter_type == 'difference_of_gaussian':
filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)
if self.filter_type == 'gaussian':
print("filter type is ", self.filter_type)
_filter = _create_gaussian_filter(patch_shape=self.patch_shape, sigma=self.sigma)

elif self.filter_type in ["difference_of_gaussian", "DoG"]:
print("filter type is difference of gaussian: f(x) = p1 - p2")
_filter = _create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)

elif self.filter_type in ["ratio_of_gaussian", "RoG"]:
print("filter type is ratio of gaussian: f(x) = p1 / p2")
_filter = _create_ratio_of_gauss_filter(patch_shape=self.patch_shape, sigma=sigma)

# ═════════════════ compartments initial values ════════════════════
in_restVals = jnp.zeros((batch_size,
*self.area_shape)) ## input: (B | ix | iy)
in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)

out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
out_restVals = jnp.zeros(
(batch_size, self.n_cells * self.patch_shape[0] * self.patch_shape[1])
) ## output.shape: (B | n_cells * px * py)

# ═══════════════════ set compartments ══════════════════════
self.inputs = Compartment(in_restVals, display_name="Input Stimulus") # input compartment
self.filter = Compartment(filter, display_name="Filter") # Filter compartment
self.filter = Compartment(_filter, display_name="Filter") # Filter compartment
self.outputs = Compartment(out_restVals, display_name="Output Signal") # output compartment

@compilable
def advance_state(self, t):
inputs = self.inputs.get()
filter = self.filter.get()
_filter = self.filter.get()
px, py = self.patch_shape

# ═══════════════════ extract pathches for filters ══════════════════
input_patches = _create_patches(inputs, patch_shape=self.patch_shape, step_shape=self.step_shape)

# ═══════════════════ apply filter to all pathches ══════════════════
filtered_input = input_patches * filter ## shape: (B | n_cells | px | py)
filtered_input = input_patches * _filter ## shape: (B | n_cells | px | py)

# ════════════ reshape all cells responses to a single input to brain ════════════
filtered_input = filtered_input.reshape(-1, self.n_cells * (px * py)) ## shape: (B | n_cells * px * py)
Expand All @@ -160,31 +197,20 @@ def advance_state(self, t):

self.outputs.set(outputs)



@compilable
def reset(self): ## reset core components/statistics
# self.batched_reset(batch_size=self.batch_size) ## arg = batch_size data-member
in_restVals = jnp.zeros((self.batch_size, *self.area_shape)) ## input: (B | ix | iy)
out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))
self.inputs.set(in_restVals)
self.outputs.set(out_restVals)

# Viet: NOTE: we should not need this function since the reset function
# one could set the batch size then do reset
# @compilable
# def batched_reset(self, batch_size):
# in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy)

# out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
# self.n_cells * self.patch_shape[0] * self.patch_shape[1]))

# self.inputs.set(in_restVals)
# self.outputs.set(out_restVals)

@classmethod
def help(cls): ## component help function
properties = {
"cell_type": "RetinalGanglionCell - filters the input stimuli, "
"cell_type": "RetinalGanglionCell - filters the input stimuli according retinal ganglion dynamics"
}
compartment_props = {
"inputs":
Expand All @@ -196,11 +222,11 @@ def help(cls): ## component help function
}
hyperparams = {
"filter_type": "Type of the filter for preprocessing the input",
"sigma": "Standard deviation of gaussian kernel",
"sigma": "Standard deviation of gaussian kernel/filter",
"area_shape": "Effective receptive field area shape of ganglion cells in this module",
"n_cells": "Number of Retinal Ganglion (center-surround) cells to model in this layer",
"patch_shape": "Classical Receptive field area shape of individual ganglion cells in this module",
"step_shape": "Extra-Classical Receptive field area shape each ganglion cell in this module",
"n_cells": "Number of retinal ganglion (center-surround) cells to model in this layer",
"patch_shape": "Classical receptive field area shape of individual ganglion cells in this module",
"step_shape": "Extra-classical receptive field area shape each ganglion cell in this module",
"batch_size": "Batch size dimension of this component"
}
info = {cls.__name__: properties,
Expand All @@ -212,13 +238,15 @@ def help(cls): ## component help function
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
X = RetinalGanglionCell("RGC", filter_type="gaussian",
sigma=2.3,
area_shape=(16, 26),
n_cells = 3,
patch_shape=(16, 16),
step_shape=(0, 5)
)
X = RetinalGanglionCell(
"RGC",
filter_type="gaussian",
sigma=2.3,
area_shape=(16, 26),
n_cells = 3,
patch_shape=(16, 16),
step_shape=(0, 5)
)
print(X)


Expand Down
9 changes: 7 additions & 2 deletions ngclearn/components/input_encoders/poissonCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ class PoissonCell(JaxComponent):

@deprecate_args(max_freq="target_freq")
def __init__(
self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
key: Union[jax.Array, None] = None, **kwargs
self,
name: str,
n_units: int,
target_freq: float = 63.75,
batch_size: int = 1,
key: Union[jax.Array, None] = None,
**kwargs
):
super().__init__(name=name, key=key)

Expand Down
Loading
Loading